In [15]:
%load_ext autoreload
%autoreload 2
from boxes import *
from learner import *
import math

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
PATH = '../data/ontologies/'
# Data in unary.tsv are probabilites separated by newlines. The probability on line n is P(n), where n is the id assigned to the nth element.
unary_prob = torch.from_numpy(np.loadtxt(f'{PATH}mouse_unary.tsv')).float().to("cuda")
num_boxes = unary_prob.shape[0]

# We're going to use random negative sampling during training, so no need to include negatives in our training data itself
train = Probs.load_from_julia(PATH, 'mouse_adj_tr_pos.tsv', 'mouse_adj_tr_neg.tsv', ratio_neg = 0).to("cuda")

# The dev set will have a fixed set of negatives, however.
dev = Probs.load_from_julia(PATH, 'mouse_adj_dev_pos.tsv', 'mouse_adj_dev_neg.tsv', ratio_neg = 1).to("cuda")

In [23]:
box_model = BoxModel(
    BoxParamType=MinMaxSigmoidBoxes,
    vol_func=soft_volume,
    num_models=1,
    num_boxes=num_boxes,
    dims=75,
    method="orig").to("cuda")

train_dl = TensorDataLoader(train, batch_size=2**6, shuffle=True)

opt = torch.optim.Adam(box_model.parameters(), lr=1e-2)

In [24]:
def mean_cond_kl_loss(model_out: ModelOutput, target: Tensor, eps: float = torch.finfo(torch.float32).tiny) -> Tensor:
    return kl_div_sym(model_out["P(A|B)"], target, eps).mean()

# See boxes/loss_functions.py file for more options. Note that you may have to changed them to fit your use case.
# Also note that "kl_div_sym" is just binary cross-entropy.

In [25]:
# For this dataset we had unary probabilities as well as conditional probabilities. Our loss function will be a sum of these, which is provided by the following loss function wrapper:
loss_func = LossPieces(mean_cond_kl_loss, (1e-2, mean_unary_kl_loss(unary_prob)))

metrics = [metric_hard_accuracy, metric_hard_f1]

rec_col = RecorderCollection()

callbacks = CallbackCollection(
    LossCallback(rec_col.train, train),
    LossCallback(rec_col.dev, dev),
    *(MetricCallback(rec_col.dev, dev, m) for m in metrics),
    *(MetricCallback(rec_col.train, train, m) for m in metrics),
    MetricCallback(rec_col.dev, dev, metric_pearson_r),
    MetricCallback(rec_col.train, dev, metric_spearman_r),
    PercentIncreaseEarlyStopping(rec_col.dev, "mean_cond_kl_loss", 0.25, 10),
    PercentIncreaseEarlyStopping(rec_col.dev, "mean_cond_kl_loss", 0.5),
#     GradientClipping(-1000,1000),
    RandomNegativeSampling(num_boxes, 1),
    StopIfNaN(),
)

l = Learner(train_dl, box_model, loss_func, opt, callbacks, recorder = rec_col.learn)

In [26]:
l.train(20)

HBox(children=(IntProgress(value=0, description='Overall Training:', max=20, style=ProgressStyle(description_w…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=43, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=43, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=43, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=43, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=43, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=43, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=43, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=43, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=43, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=43, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=43, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=43, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=43, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=43, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=43, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=43, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=43, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=43, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=43, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=43, style=ProgressStyle(description_widt…




In [27]:
rec_col.train

Unnamed: 0,0.01*mean_unary_kl_loss,loss,mean_cond_kl_loss,metric_hard_accuracy,metric_hard_f1,metric_spearman_r
0.0,5.3e-05,9.759809,9.759756,0.0,,0.047735
1.0,5.1e-05,8.856616,8.856565,0.0,,0.107965
2.0,5e-05,7.874654,7.874604,0.0,,0.125365
3.0,4.9e-05,6.948559,6.94851,0.0,,0.119864
4.0,4.9e-05,6.0972,6.097152,0.0,,0.113542
5.0,4.8e-05,5.330537,5.330489,0.0,,0.109052
6.0,4.7e-05,4.649946,4.649899,0.005492,0.010925,0.106832
7.0,4.7e-05,4.04662,4.046574,0.047602,0.090877,0.112718
8.0,4.6e-05,3.512738,3.512691,0.093006,0.170184,0.115774
9.0,4.6e-05,3.045454,3.045408,0.113512,0.20388,0.120169


In [28]:
rec_col.dev

Unnamed: 0,0.01*mean_unary_kl_loss,loss,mean_cond_kl_loss,metric_hard_accuracy,metric_hard_f1,metric_pearson_r
0.0,5.3e-05,4.860791,4.860738,0.5,,0.024612
1.0,5.1e-05,4.678537,4.678485,0.5,,0.120056
2.0,5e-05,4.496689,4.496638,0.5,,0.163623
3.0,4.9e-05,4.325918,4.325869,0.5,,0.176754
4.0,4.9e-05,4.17711,4.177062,0.5,,0.178708
5.0,4.8e-05,4.055986,4.055938,0.5,,0.178014
6.0,4.7e-05,3.962523,3.962476,0.5,,0.186494
7.0,4.7e-05,3.87557,3.875524,0.500701,0.002801,0.189883
8.0,4.6e-05,3.811825,3.811778,0.500701,0.002801,0.191558
9.0,4.6e-05,3.773464,3.773418,0.500701,0.002801,0.195844
