In [1]:
%load_ext autoreload
%autoreload 2
from boxes import *
from learner import *
import math



In [2]:
PATH = '../data/ontologies/anatomy/'
# Data in unary.tsv are probabilites separated by newlines. The probability on line n is P(n), where n is the id assigned to the nth element.
unary_prob = torch.from_numpy(np.loadtxt(f'{PATH}unary/human_unary.tsv')).float().to("cpu")
num_boxes = unary_prob.shape[0]

# We're going to use random negative sampling during training, so no need to include negatives in our training data itself
train = Probs.load_from_julia(PATH, 'individual_analysis/human_adj_tr_pos.tsv', 'individual_analysis/human_adj_tr_neg.tsv', ratio_neg = 0).to("cpu")

# The dev set will have a fixed set of negatives, however.
dev = Probs.load_from_julia(PATH, 'individual_analysis/human_adj_dev_pos.tsv', 'individual_analysis/human_adj_dev_neg.tsv', ratio_neg = 1).to("cpu")

In [3]:
box_model = BoxModel(
    BoxParamType=MinMaxSigmoidBoxes,
    vol_func=soft_volume,
    num_models=1,
    num_boxes=num_boxes,
    dims=50,
    method="orig").to("cpu")

train_dl = TensorDataLoader(train, batch_size=2**6, shuffle=True)

opt = torch.optim.Adam(box_model.parameters(), lr=1e-2)

In [4]:
def mean_cond_kl_loss(model_out: ModelOutput, target: Tensor, eps: float = torch.finfo(torch.float32).tiny) -> Tensor:
    return kl_div_sym(model_out["P(A|B)"], target, eps).mean()

# See boxes/loss_functions.py file for more options. Note that you may have to changed them to fit your use case.
# Also note that "kl_div_sym" is just binary cross-entropy.

In [5]:
# For this dataset we had unary probabilities as well as conditional probabilities. Our loss function will be a sum of these, which is provided by the following loss function wrapper:
loss_func = LossPieces(mean_cond_kl_loss, (1e-2, mean_unary_kl_loss(unary_prob)))

metrics = [metric_hard_accuracy, metric_hard_f1]

rec_col = RecorderCollection()

callbacks = CallbackCollection(
    LossCallback(rec_col.train, train),
    LossCallback(rec_col.dev, dev),
    *(MetricCallback(rec_col.dev, dev, m) for m in metrics),
    *(MetricCallback(rec_col.train, train, m) for m in metrics),
    MetricCallback(rec_col.dev, dev, metric_pearson_r),
    MetricCallback(rec_col.train, dev, metric_spearman_r),
    PercentIncreaseEarlyStopping(rec_col.dev, "mean_cond_kl_loss", 0.25, 10),
    PercentIncreaseEarlyStopping(rec_col.dev, "mean_cond_kl_loss", 0.5),
#     GradientClipping(-1000,1000),
    RandomNegativeSampling(num_boxes, 1),
    StopIfNaN(),
)

l = Learner(train_dl, box_model, loss_func, opt, callbacks, recorder = rec_col.learn)

In [6]:
l.train(20)

HBox(children=(IntProgress(value=0, description='Overall Training:', max=20, style=ProgressStyle(description_w…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=68, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=68, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=68, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=68, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=68, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=68, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=68, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=68, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=68, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=68, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=68, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=68, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=68, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=68, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=68, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=68, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=68, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=68, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=68, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=68, style=ProgressStyle(description_widt…




In [7]:
rec_col.train

Unnamed: 0,0.01*mean_unary_kl_loss,loss,mean_cond_kl_loss,metric_hard_accuracy,metric_hard_f1,metric_spearman_r
0.0,3e-05,8.081954,8.081924,0.0,,-0.022125
1.0,2.9e-05,7.135197,7.135168,0.0,,0.138873
2.0,2.8e-05,6.106736,6.106708,0.0,,0.241461
3.0,2.7e-05,5.155363,5.155335,0.0,,0.295832
4.0,2.7e-05,4.304075,4.304049,0.000231,0.000461,0.322863
5.0,2.6e-05,3.569152,3.569126,0.011078,0.021913,0.342317
6.0,2.6e-05,2.94199,2.941965,0.057466,0.108686,0.355511
7.0,2.6e-05,2.424571,2.424545,0.133395,0.23539,0.362881
8.0,2.5e-05,1.995613,1.995588,0.201939,0.336022,0.367635
9.0,2.5e-05,1.651311,1.651286,0.27579,0.432344,0.370271


In [8]:
rec_col.dev

Unnamed: 0,0.01*mean_unary_kl_loss,loss,mean_cond_kl_loss,metric_hard_accuracy,metric_hard_f1,metric_pearson_r
0.0,3e-05,4.048131,4.0481,0.5,,-0.004201
1.0,2.9e-05,3.779072,3.779042,0.5,,0.139024
2.0,2.8e-05,3.493037,3.493009,0.5,,0.20603
3.0,2.7e-05,3.224127,3.2241,0.5,,0.219987
4.0,2.7e-05,2.982553,2.982526,0.5,,0.235346
5.0,2.6e-05,2.774221,2.774195,0.500459,0.001833,0.251156
6.0,2.6e-05,2.598428,2.598402,0.501835,0.009124,0.271093
7.0,2.6e-05,2.462819,2.462794,0.507339,0.032432,0.28627
8.0,2.5e-05,2.35899,2.358965,0.516514,0.072183,0.296494
9.0,2.5e-05,2.278745,2.27872,0.522936,0.097222,0.308827
