In [152]:
%load_ext autoreload
%autoreload 2
from boxes import *
from learner import *
import math
import matplotlib.pyplot as plt
import os

%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [153]:
@dataclass
class Learner:
    """
    Bundles a data loader, model, loss function and optimizer and runs the
    training loop, notifying ``callbacks`` at every stage.

    When ``categories`` is True, each batch is partitioned by
    :meth:`split_data` (using ``split`` as the threshold) and the model is
    applied to each partition separately; otherwise the whole batch is fed to
    the model in one call.
    """
    train_dl: DataLoader
    model: Module
    loss_fn: Callable
    opt: optim.Optimizer
    callbacks: CallbackCollection = field(default_factory=CallbackCollection)
    recorder: Recorder = field(default_factory=Recorder)
    categories: bool = False
    reraise_keyboard_interrupt: bool = False
    reraise_stop_training_exceptions: bool = False
    # Threshold handed to split_data() when `categories` is True. It depends on
    # how the dataset is indexed, so change it when switching datasets.
    split: int = 2737

    def __post_init__(self):
        self.progress = Progress(0,0,len(self.train_dl))
        self.callbacks.learner_post_init(self)

    def split_data(self, batch_in, batch_out, split):
        """
        Partition a batch into three groups according to how many entries of
        each input row exceed ``split``.

        * no entry  > ``split``  -> stored as ``mouse_in`` / ``mouse_out``
        * all entries > ``split`` -> stored as ``human_in`` / ``human_out``
        * anything else           -> stored as ``align_in`` / ``align_out``

        ``split`` depends on how the dataset is indexed and has to be adapted
        when a different dataset is used.

        :param batch_in: 2-D tensor whose rows are compared against ``split``
        :param batch_out: targets, indexed along dim 0 with the same row masks
        :param split: threshold compared element-wise with ``batch_in``
        :return: ``(data_in, data_out)``, each a ``(mouse, human, align)`` tuple
        """
        above = batch_in > split
        # Count per row instead of looping over rows in Python: the loop forced
        # one host/device synchronisation per element on CUDA tensors.
        n_above = above.sum(dim=1)
        n_cols = above.shape[1]

        is_mouse = n_above == 0
        is_human = n_above == n_cols
        is_align = (n_above > 0) & (n_above < n_cols)

        self.mouse_in = batch_in[is_mouse]
        self.human_in = batch_in[is_human]
        self.align_in = batch_in[is_align]

        # Masks live on batch_in's device; move them in case the targets are
        # stored on a different one.
        out_device = batch_out.device
        self.mouse_out = batch_out[is_mouse.to(out_device)]
        self.human_out = batch_out[is_human.to(out_device)]
        self.align_out = batch_out[is_align.to(out_device)]

        # INPUT TO THE MODEL:
        data_in = (self.mouse_in, self.human_in, self.align_in)
        # TARGET/LABEL:
        data_out = (self.mouse_out, self.human_out, self.align_out)

        return data_in, data_out

    def TensorNaN(self, size:Union[None,List[int], Tuple[int]]=None, device=None, requires_grad:bool=True):
        """
        Build a tensor filled with NaN.

        :param size: shape of the result; ``None`` yields a 0-dim tensor
        :param device: device on which the tensor is created
        :param requires_grad: forwarded to the tensor constructor
        """
        if size is None:    
            return torch.tensor(float('nan'), device=device, requires_grad=requires_grad)
        else:
            return float('nan') * torch.zeros(size=size, device=device, requires_grad=requires_grad)


    def train(self, epochs, progress_bar = True):
        """
        Run ``epochs`` passes over ``train_dl``.

        ``StopTrainingException`` and ``KeyboardInterrupt`` are reported with
        ``print`` and swallowed unless the matching ``reraise_*`` flag is set.
        ``callbacks.train_end`` is always invoked, even on error.

        :param epochs: number of epochs to train for
        :param progress_bar: set to False to disable the tqdm progress bars
        """
        try:
            self.callbacks.train_begin(self)
            for _ in trange(epochs, desc="Overall Training:", disable=not progress_bar):
                self.callbacks.epoch_begin(self)
                for batch in tqdm(self.train_dl, desc="Current Batch:", leave=False, disable=not progress_bar):
                    if len(batch) == 2: # KLUDGE
                        self.batch_in, self.batch_out = batch
                    else:
                        self.batch_in = batch[0]
                        self.batch_out = None
                    self.progress.increment()
                    self.callbacks.batch_begin(self)
                    self.opt.zero_grad()

                    if self.categories:
                        self.data_in, self.data_out = self.split_data(self.batch_in, self.batch_out, split=self.split)
                        # An empty partition is not passed through the model; a NaN
                        # placeholder prediction is used in its place.
                        self.model_pred = [self.model(item) if len(item)>0 else {'P(A|B)':self.TensorNaN(device=self.batch_in.device)} for item in self.data_in]
                        self.loss = self.loss_fn(self.model_pred, self.data_out, self, self.recorder, categories=True)
                    else:
                        self.model_out = self.model(self.batch_in)
                        if self.batch_out is None:
                            self.loss = self.loss_fn(self.model_out, self, self.recorder)
                        else:
                            self.loss = self.loss_fn(self.model_out, self.batch_out, self, self.recorder)
                    self.loss.backward()
                    self.callbacks.backward_end(self)
                    self.opt.step()
                    self.callbacks.batch_end(self)
                self.callbacks.epoch_end(self)
        except StopTrainingException as e:
            print(e)
            if self.reraise_stop_training_exceptions:
                # Bare raise keeps the original traceback.
                raise
        except KeyboardInterrupt:
            print(f"Stopped training at {self.progress.partial_epoch_progress()} epochs due to keyboard interrupt.")
            if self.reraise_keyboard_interrupt:
                # Re-raise the caught interrupt instead of constructing a new one.
                raise
        finally:
            self.callbacks.train_end(self)


    def evaluation(self, trials, progress_bar=True):
        """
        Run the evaluation callbacks with gradient tracking disabled.

        :param trials: iterable; each element is passed to ``callbacks.eval_align``
        :param progress_bar: currently unused; kept for interface compatibility
        """
        with torch.no_grad():
            # self.callbacks.eval_begin(self)
            for t in trials:
                self.callbacks.eval_align(self, t)
            self.callbacks.metric_plots(self)
            self.callbacks.bias_metric(self)
            self.callbacks.eval_end(self)

In [154]:
# Pick the GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
use_cuda

True

In [155]:
PATH = '../data/ontologies/anatomy/'

# Alignment training split
ats = 0.8

# Transitive closure: when enabled, the positive training file carries a "tc_" infix.
Transitive_Closure = False
tc = "tc_" if Transitive_Closure else ""

# unary.tsv holds probabilities separated by newlines. The probability on
# line n is P(n), where n is the id assigned to the nth element.
unary_prob = torch.from_numpy(
    np.loadtxt(f'{PATH}unary/unary.tsv')
).float().to(device)
num_boxes = unary_prob.size(0)

# Random negative sampling is used during training, so the training data
# itself is loaded without negatives (ratio_neg = 0).
train = Probs.load_from_julia(
    PATH,
    f'tr_pos_{tc}{ats}.tsv',
    f'tr_neg_{ats}.tsv',
    ratio_neg = 0,
).to(device)

# The dev set, however, has a fixed set of negatives.
dev = Probs.load_from_julia(
    PATH,
    f'dev_align_pos_{ats}.tsv',
    f'dev_align_neg_{ats}.tsv',
    ratio_neg = 1,
).to(device)

# Used only for evaluation after training.
tr_align = Probs.load_from_julia(
    PATH,
    f'tr_align_pos_{ats}.tsv',
    f'tr_align_neg_{ats}.tsv',
    ratio_neg = 1,
).to(device)


In [156]:
# Per-ontology evaluation sets. Each variable now loads the files that match
# its name; previously `mouse_eval` read the human_* files and `human_eval`
# read the mouse_* files.
mouse_eval = Probs.load_from_julia(PATH, 'mouse_dev_pos.tsv', 'mouse_dev_neg.tsv', ratio_neg = 1).to(device)
human_eval = Probs.load_from_julia(PATH, 'human_dev_pos.tsv', 'human_dev_neg.tsv', ratio_neg = 1).to(device)

In [157]:
# Hyper-parameters
dims = 10
lr = 0.01
nEpochs = 1
rns_ratio = 1
box_type = MinMaxSigmoidBoxes
use_unary = False
unary_weight = 0.01

# Box model, moved to the selected device.
box_model = BoxModel(
    BoxParamType=box_type, vol_func=soft_volume, num_models=1,
    num_boxes=num_boxes, dims=dims, method="orig",
)
box_model = box_model.to(device)

#### IF YOU ARE LOADING FROM JULIA WITH ratio_neg=0, train_dl WILL ONLY CONTAIN POSITIVE EXAMPLES
#### THIS MEANS YOUR MODEL SHOULD USE NEGATIVE SAMPLING DURING TRAINING
train_dl = TensorDataLoader(train, batch_size=64, shuffle=True)

# Evaluation loaders (no shuffle argument is passed for these).
mouse_dl, human_dl = (
    TensorDataLoader(eval_set, batch_size=64)
    for eval_set in (mouse_eval, human_eval)
)
eval_dl = [mouse_dl, human_dl]

opt = torch.optim.Adam(box_model.parameters(), lr=lr)

In [158]:
def mean_cond_kl_loss(model_out: ModelOutput, target: Tensor, eps: float = torch.finfo(torch.float32).tiny) -> Tensor:
    """Mean of ``kl_div_sym`` between the predicted ``P(A|B)`` and ``target``."""
    cond_probs = model_out["P(A|B)"]
    per_example = kl_div_sym(cond_probs, target, eps)
    return per_example.mean()

def human_cond_kl_loss(model_out: ModelOutput, target: Tensor, eps: float = torch.finfo(torch.float32).tiny) -> Tensor:
    """
    Thresholded variant of the conditional loss.

    Computes ``kl_div_sym`` between the predicted ``P(A|B)`` and ``target``,
    sets every per-example value below 0.5 to zero, and returns the mean over
    all examples (zeroed ones included).

    :param model_out: model output; only the ``"P(A|B)"`` entry is read
    :param target: target values passed to ``kl_div_sym``
    :param eps: forwarded to ``kl_div_sym``
    """
    kl_loss = kl_div_sym(model_out["P(A|B)"], target, eps)

    # Scratch notes from the original author. They were bare expressions in the
    # function body and the last one is not valid Python, so they are kept
    # here as a comment only:
    #   [0.2, 0.6, 0.8]
    #   [1, 0, 1]
    #   (1-0.2)**2 + max(0, is (0.6 - 0) > 0.5)**2 +(1- 0.8)**2

    # Out-of-place masking: avoids mutating a tensor that autograd may still
    # need for the backward pass. Values and gradients are the same as with
    # `kl_loss[kl_loss < 0.5] = 0`.
    thresholded = torch.where(kl_loss < 0.5, torch.zeros_like(kl_loss), kl_loss)
    return thresholded.mean()

def mouse_cond_kl_loss(model_out: ModelOutput, target: Tensor, eps: float = torch.finfo(torch.float32).tiny) -> Tensor:
    """Mean of ``kl_div_sym`` between the predicted ``P(A|B)`` and ``target``."""
    predicted = model_out["P(A|B)"]
    losses = kl_div_sym(predicted, target, eps)
    return losses.mean()

def align_cond_kl_loss(model_out: ModelOutput, target: Tensor, eps: float = torch.finfo(torch.float32).tiny) -> Tensor:
    """Mean of ``kl_div_sym`` between the predicted ``P(A|B)`` and ``target``."""
    predicted = model_out["P(A|B)"]
    elementwise = kl_div_sym(predicted, target, eps)
    return elementwise.mean()

# See the boxes/loss_functions.py file for more options. Note that you may have to change them to fit your use case.
# Also note that "kl_div_sym" is just binary cross-entropy.

In [159]:
# For this dataset we had unary probabilities as well as conditional probabilities. Our loss function will be a sum of these, which is provided by the following loss function wrapper:

# if use_unary:
#     loss_func = LossPieces(mean_cond_kl_loss, (unary_weight, mean_unary_kl_loss(unary_prob)))
# else:
#     loss_func = LossPieces(mean_cond_kl_loss)

# Loss used for training: one piece per category, listed as mouse, human, align
# (the same order as the tuples returned by Learner.split_data).
# NOTE(review): the (5e-2, align_cond_kl_loss) tuple is presumably a
# (weight, loss_fn) pair -- confirm against LossPieces.
loss_func = LossPieces( mouse_cond_kl_loss, human_cond_kl_loss, (5e-2,align_cond_kl_loss))

# Metric functions applied by the MetricCallback / EvalAlignment callbacks below.
metrics = [metric_hard_accuracy, metric_hard_f1]
align_metrics = [metric_hard_accuracy_align, metric_hard_f1_align, metric_hard_accuracy_align_mean, metric_hard_f1_align_mean]

# Holder for the individual recorders referenced below as rec_col.<name>.
rec_col = RecorderCollection()

# NOTE(review): `threshold` is not referenced anywhere else in this cell.
threshold = np.arange(0.1, 1, 0.1)

# Callbacks are listed in a fixed order; keep the order when editing
# (NOTE(review): presumably CallbackCollection invokes them in this order -- confirm).
callbacks = CallbackCollection(
    # Loss tracking on the train and dev sets.
    LossCallback(rec_col.train, train),
    LossCallback(rec_col.dev, dev),
    # One MetricCallback per metric for each dataset; the human and mouse
    # evaluation sets both write to rec_col.onto.
    *(MetricCallback(rec_col.dev, dev, m) for m in metrics),
    *(MetricCallback(rec_col.train, train, m) for m in metrics),
    *(MetricCallback(rec_col.onto, human_eval, m) for m in metrics),
    *(MetricCallback(rec_col.onto, mouse_eval, m) for m in metrics),
    # Alignment metrics on the training-alignment and dev sets.
    *(EvalAlignment(rec_col.train_align, tr_align, m) for m in align_metrics),
    *(EvalAlignment(rec_col.dev_align, dev, m) for m in align_metrics),
    JustGiveMeTheData(rec_col.probs, dev, get_probabilities),
    BiasMetric(rec_col.bias, dev, pct_of_align_cond_on_human_as_min),
    # ROC / precision-recall plots for the dev and training-alignment sets.
    PlotMetrics(rec_col.dev_roc_plot, dev, roc_plot),
    PlotMetrics(rec_col.dev_pr_plot, dev, pr_plot),
    PlotMetrics(rec_col.tr_roc_plot, tr_align, roc_plot),
    PlotMetrics(rec_col.tr_pr_plot, tr_align, pr_plot),
    # Correlation metrics on train and dev.
    MetricCallback(rec_col.train, train, metric_pearson_r),
    MetricCallback(rec_col.train, train, metric_spearman_r),
    MetricCallback(rec_col.dev, dev, metric_pearson_r),
    MetricCallback(rec_col.dev, dev, metric_spearman_r),
    # Early stopping and gradient clipping are currently disabled:
#     PercentIncreaseEarlyStopping(rec_col.dev, "mean_cond_kl_loss", 0.25, 10),
#     PercentIncreaseEarlyStopping(rec_col.dev, "mean_cond_kl_loss", 0.5),
#     PercentIncreaseEarlyStopping(rec_col.dev, "mouse_cond_kl_loss", 0.25, 10),
#     PercentIncreaseEarlyStopping(rec_col.dev, "mouse_cond_kl_loss", 0.5),
#     GradientClipping(-1000,1000),
    # The training data was loaded with ratio_neg = 0, so negatives come from this callback.
    RandomNegativeSampling(num_boxes, rns_ratio),
    StopIfNaN(),
)

# Variant without per-category splitting, kept for reference:
# l = Learner(train_dl, box_model, loss_func, opt, callbacks, recorder = rec_col.learn)
# categories=True makes Learner.train split every batch before calling the model.
l = Learner(train_dl, box_model, loss_func, opt, callbacks, recorder = rec_col.learn, categories=True)

In [160]:
# Train for the configured number of epochs.
l.train(epochs=nEpochs)

tensor([0.1085, 0.1451, 0.1464,  ..., 0.1352, 0.1426, 0.1748], device='cuda:0') tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0')
tensor([0.1327, 0.0641, 0.1866,  ..., 0.0915, 0.1130, 0.1087], device='cuda:0') tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0')
tensor([0.0675, 0.0901, 0.0950,  ..., 0.2887, 0.1340, 0.0668], device='cuda:0') tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0')
tensor(nan, device='cuda:0', requires_grad=True) tensor([], device='cuda:0')
tensor(nan, device='cuda:0', requires_grad=True) tensor([], device='cuda:0')
tensor([0.0996, 0.0241, 0.2343,  ..., 0.1040, 0.0529, 0.2889], device='cuda:0') tensor([1., 1., 1.,  ..., 0., 0., 0.], device='cuda:0')


HBox(children=(IntProgress(value=0, description='Overall Training:', max=1, style=ProgressStyle(description_wi…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=177, style=ProgressStyle(description_wid…

tensor([0.0667, 0.1916, 0.0749, 0.1277, 0.1640, 0.2803, 0.0470, 0.1027, 0.1423,
        0.4296, 0.0985, 0.1148, 0.0962, 0.1390, 0.2425, 0.0508, 0.1800, 0.2582,
        0.0961, 0.0709, 0.1452, 0.1048, 0.1121, 0.1025, 0.1625],
       device='cuda:0', grad_fn=<ExpBackward>) tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.,
        0., 0., 0., 0., 0., 0., 0.], device='cuda:0')
tensor([0.0437, 0.0577, 0.2693, 0.1327, 0.0960, 0.1109, 0.1232, 0.0494, 0.1658,
        0.1972, 0.1139, 0.0366, 0.1843, 0.2693, 0.2171, 0.1095, 0.0901, 0.1784,
        0.3320, 0.0852, 0.1878, 0.2385, 0.1995, 0.2427, 0.0861, 0.1127, 0.1747,
        0.1443, 0.1495, 0.1969, 0.2526, 0.0873, 0.1112, 0.0570, 0.0879, 0.0585,
        0.1864, 0.0892, 0.0706, 0.0771, 0.1178, 0.0779, 0.1057, 0.3299, 0.1228,
        0.0854, 0.2385, 0.3048, 0.1310, 0.1673, 0.1172, 0.1574, 0.1397, 0.0490,
        0.1304, 0.0698, 0.1039], device='cuda:0', grad_fn=<ExpBackward>) tensor([1., 1., 1., 1., 1., 1., 1., 1., 1

tensor([0.1659, 0.0819, 0.1358, 0.2508, 0.0718, 0.4258, 0.0767, 0.2336, 0.1015,
        0.1407, 0.0776, 0.0965, 0.2136, 0.0820, 0.1596, 0.0469, 0.3811, 0.2008,
        0.1443, 0.1828, 0.0898, 0.3836, 0.1250, 0.1802, 0.1404, 0.0801, 0.2722,
        0.0737, 0.1609, 0.0616, 0.1004], device='cuda:0',
       grad_fn=<ExpBackward>) tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0')
tensor([0.0959, 0.0820, 0.1165, 0.0641, 0.1067, 0.0891, 0.1631, 0.2322, 0.1472,
        0.1020, 0.2292, 0.2404, 0.0894, 0.0940, 0.0667, 0.1271, 0.1388, 0.0928,
        0.0905, 0.1049, 0.1221, 0.1756, 0.1028, 0.1271, 0.1397, 0.0987, 0.1189,
        0.1836, 0.1606, 0.0531, 0.1510, 0.3956, 0.1002, 0.1748, 0.0385, 0.1252,
        0.1033, 0.0790, 0.1316, 0.1460, 0.1388, 0.1027, 0.1506, 0.0746, 0.0926,
        0.0920, 0.1823, 0.3324, 0.1052, 0.1676, 0.3688, 0.1197, 0.3634, 0.0670,
        0.1102, 0.1782], device='c

Stopped training at 0.096045197740113 epochs due to keyboard interrupt.
