In [24]:
%load_ext autoreload
%autoreload 2
from boxes import *
from learner import *
import math

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [25]:
PATH = '../data/wordnet/mjb/rigorous_sampling/mammal_pos_0.5_neg_10_samp_uniform_num_1182/0/'
# Data in unary.tsv are probabilites separated by newlines. The probability on line n is P(n), where n is the id assigned to the nth element.
unary_prob = torch.from_numpy(np.loadtxt(f'{PATH}train_tc_unary.tsv')).float().to("cuda")
num_boxes = unary_prob.shape[0]

# We're going to use random negative sampling during training, so no need to include negatives in our training data itself
train = Probs.load_from_julia(PATH, 'train_tc_pos.tsv', 'train_neg.tsv', ratio_neg = 0).to("cuda")

# The dev set will have a fixed set of negatives, however.
dev = Probs.load_from_julia(PATH, 'dev_pos.tsv', 'dev_neg.tsv', ratio_neg = 1).to("cuda")

In [26]:
box_model = BoxModel(
    BoxParamType=MinMaxSigmoidBoxes,
    vol_func=soft_volume,
    num_models=1,
    num_boxes=num_boxes,
    dims=50,
    method="orig").to("cuda")

train_dl = TensorDataLoader(train, batch_size=2**6, shuffle=True)

opt = torch.optim.Adam(box_model.parameters(), lr=1e-2)

In [27]:
def mean_cond_kl_loss(model_out: ModelOutput, target: Tensor, eps: float = torch.finfo(torch.float32).tiny) -> Tensor:
    return kl_div_sym(model_out["P(A|B)"], target, eps).mean()

# See boxes/loss_functions.py file for more options. Note that you may have to changed them to fit your use case.
# Also note that "kl_div_sym" is just binary cross-entropy.

In [28]:
# For this dataset we had unary probabilities as well as conditional probabilities. Our loss function will be a sum of these, which is provided by the following loss function wrapper:
loss_func = LossPieces(mean_cond_kl_loss, (1e-2, mean_unary_kl_loss(unary_prob)))

metrics = [metric_hard_accuracy, metric_hard_f1]

rec_col = RecorderCollection()

callbacks = CallbackCollection(
    LossCallback(rec_col.train, train),
    LossCallback(rec_col.dev, dev),
    *(MetricCallback(rec_col.dev, dev, m) for m in metrics),
    *(MetricCallback(rec_col.train, train, m) for m in metrics),
    MetricCallback(rec_col.dev, dev, metric_pearson_r),
    MetricCallback(rec_col.train, dev, metric_spearman_r),
    PercentIncreaseEarlyStopping(rec_col.dev, "mean_cond_kl_loss", 0.25, 10),
    PercentIncreaseEarlyStopping(rec_col.dev, "mean_cond_kl_loss", 0.5),
#     GradientClipping(-1000,1000),
    RandomNegativeSampling(num_boxes, 1),
    StopIfNaN(),
)

l = Learner(train_dl, box_model, loss_func, opt, callbacks, recorder = rec_col.learn)

In [29]:
l.train(10)

odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])


HBox(children=(IntProgress(value=0, description='Overall Training:', max=10, style=ProgressStyle(description_w…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=54, style=ProgressStyle(description_widt…

odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unar

odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unar

HBox(children=(IntProgress(value=0, description='Current Batch:', max=54, style=ProgressStyle(description_widt…

odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unar

odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])


HBox(children=(IntProgress(value=0, description='Current Batch:', max=54, style=ProgressStyle(description_widt…

odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unar

odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unar

HBox(children=(IntProgress(value=0, description='Current Batch:', max=54, style=ProgressStyle(description_widt…

odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unary_kl_loss', <function func_list_to_dict.<locals>.<lambda> at 0x2aab68b93a70>)])
odict_items([('mean_cond_kl_loss', <function mean_cond_kl_loss at 0x2aab68b93290>), ('0.01*mean_unar

In [7]:
rec_col.train

Unnamed: 0,0.01*mean_unary_kl_loss,loss,mean_cond_kl_loss,metric_hard_accuracy,metric_hard_f1,metric_spearman_r
0.0,0.000489,7.991005,7.990516,0.0,,0.038463
1.0,0.000395,5.354831,5.354435,0.0,,0.720873
2.0,0.000329,3.562744,3.562414,0.0,,0.783711
3.0,0.000283,2.447748,2.447465,0.001162,0.002322,0.800982
4.0,0.000252,1.770337,1.770085,0.038361,0.073887,0.809101
5.0,0.000232,1.347736,1.347505,0.213601,0.352012,0.813386
6.0,0.000218,1.07259,1.072372,0.438826,0.609978,0.815888
7.0,0.000208,0.884799,0.884591,0.594013,0.745305,0.817628
8.0,0.000202,0.749176,0.748974,0.690788,0.817119,0.818795
9.0,0.000198,0.652249,0.652052,0.750654,0.85757,0.819491


In [8]:
rec_col.dev

Unnamed: 0,0.01*mean_unary_kl_loss,loss,mean_cond_kl_loss,metric_hard_accuracy,metric_hard_f1,metric_pearson_r
0.0,0.000489,4.002032,4.001543,0.5,,-0.012014
1.0,0.000395,2.795372,2.794977,0.5,,0.526993
2.0,0.000329,1.963452,1.963123,0.5,,0.596902
3.0,0.000283,1.439865,1.439582,0.500202,0.000808,0.650625
4.0,0.000252,1.121082,1.120829,0.513546,0.052756,0.694091
5.0,0.000232,0.923374,0.923143,0.57036,0.247788,0.727191
6.0,0.000218,0.798089,0.797872,0.636474,0.430292,0.750311
7.0,0.000208,0.715076,0.714867,0.686413,0.545561,0.766012
8.0,0.000202,0.65692,0.656718,0.722604,0.61804,0.777236
9.0,0.000198,0.618716,0.618518,0.740194,0.65072,0.784891
