In [1]:
%load_ext autoreload
%autoreload 2
from boxes import *
from learner import *
import math



In [2]:
PATH = '../data/wordnet/mjb/rigorous_sampling/mammal_pos_0.5_neg_10_samp_uniform_num_1182/0/'
# Data in unary.tsv are probabilites separated by newlines. The probability on line n is P(n), where n is the id assigned to the nth element.
unary_prob = torch.from_numpy(np.loadtxt(f'{PATH}train_tc_unary.tsv')).float().to("cpu")
num_boxes = unary_prob.shape[0]

# We're going to use random negative sampling during training, so no need to include negatives in our training data itself
train = Probs.load_from_julia(PATH, 'train_tc_pos.tsv', 'train_neg.tsv', ratio_neg = 0).to("cpu")

# The dev set will have a fixed set of negatives, however.
dev = Probs.load_from_julia(PATH, 'dev_pos.tsv', 'dev_neg.tsv', ratio_neg = 1).to("cpu")

In [4]:
box_model = BoxModel(
    BoxParamType=MinMaxSigmoidBoxes,
    vol_func=soft_volume,
    num_models=1,
    num_boxes=num_boxes,
    dims=50,
    method="orig").to("cpu")

train_dl = TensorDataLoader(train, batch_size=2**6, shuffle=True)

opt = torch.optim.Adam(box_model.parameters(), lr=1e-2)

In [5]:
def mean_cond_kl_loss(model_out: ModelOutput, target: Tensor, eps: float = torch.finfo(torch.float32).tiny) -> Tensor:
    return kl_div_sym(model_out["P(A|B)"], target, eps).mean()

# See boxes/loss_functions.py file for more options. Note that you may have to changed them to fit your use case.
# Also note that "kl_div_sym" is just binary cross-entropy.

In [6]:
# For this dataset we had unary probabilities as well as conditional probabilities. Our loss function will be a sum of these, which is provided by the following loss function wrapper:
loss_func = LossPieces(mean_cond_kl_loss, (1e-2, mean_unary_kl_loss(unary_prob)))

metrics = [metric_hard_accuracy, metric_hard_f1]

rec_col = RecorderCollection()

callbacks = CallbackCollection(
    LossCallback(rec_col.train, train),
    LossCallback(rec_col.dev, dev),
    *(MetricCallback(rec_col.dev, dev, m) for m in metrics),
    *(MetricCallback(rec_col.train, train, m) for m in metrics),
    MetricCallback(rec_col.dev, dev, metric_pearson_r),
    MetricCallback(rec_col.train, dev, metric_spearman_r),
    PercentIncreaseEarlyStopping(rec_col.dev, "mean_cond_kl_loss", 0.25, 10),
    PercentIncreaseEarlyStopping(rec_col.dev, "mean_cond_kl_loss", 0.5),
#     GradientClipping(-1000,1000),
    RandomNegativeSampling(num_boxes, 1),
    StopIfNaN(),
)

l = Learner(train_dl, box_model, loss_func, opt, callbacks, recorder = rec_col.learn)

In [None]:
l.train(10)

HBox(children=(IntProgress(value=0, description='Overall Training:', max=10, style=ProgressStyle(description_w…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=54, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=54, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=54, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=54, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=54, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=54, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=54, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=54, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=54, style=ProgressStyle(description_widt…

In [7]:
rec_col.train

Unnamed: 0,0.01*mean_unary_kl_loss,loss,mean_cond_kl_loss,metric_hard_accuracy,metric_hard_f1,metric_spearman_r
0.0,0.000495,8.190474,8.189979,0.0,,-0.047588
1.0,0.000399,5.486872,5.486473,0.0,,0.718769
2.0,0.000332,3.633273,3.632942,0.0,,0.78588
3.0,0.000285,2.478737,2.478452,0.001453,0.002902,0.803919
4.0,0.000253,1.780278,1.780025,0.032258,0.0625,0.812009
5.0,0.000233,1.348903,1.34867,0.198198,0.330827,0.816082
6.0,0.000219,1.069617,1.069398,0.415867,0.587438,0.818446
7.0,0.000209,0.879246,0.879037,0.589654,0.741865,0.819667
8.0,0.000203,0.746561,0.746359,0.694566,0.819756,0.820307
9.0,0.000199,0.649739,0.64954,0.75356,0.859463,0.820281


In [8]:
rec_col.dev

Unnamed: 0,0.01*mean_unary_kl_loss,loss,mean_cond_kl_loss,metric_hard_accuracy,metric_hard_f1,metric_pearson_r
0.0,0.000495,4.094842,4.094347,0.5,,-0.060927
1.0,0.000399,2.855434,2.855035,0.5,,0.536233
2.0,0.000332,1.998673,1.998341,0.5,,0.608633
3.0,0.000285,1.45979,1.459505,0.500202,0.000808,0.65778
4.0,0.000253,1.130311,1.130058,0.5093,0.036522,0.700758
5.0,0.000233,0.927085,0.926852,0.560857,0.218143,0.731539
6.0,0.000219,0.798469,0.79825,0.631217,0.416507,0.753592
7.0,0.000209,0.714136,0.713927,0.682167,0.535186,0.769105
8.0,0.000203,0.657529,0.657326,0.715528,0.603997,0.780063
9.0,0.000199,0.620478,0.62028,0.734735,0.640941,0.787405
