In [None]:
%load_ext autoreload
%autoreload 2
from boxes import *
from learner import *
import math
import matplotlib.pyplot as plt
import os

%matplotlib inline

In [2]:
PATH = '../data/ontologies/anatomy/'

# Alignment training split (0.8 presumably = fraction of alignments used for training);
# it is embedded in every split filename below.
ats = 0.8

# Transitive closure: when enabled, the "tc_"-prefixed positive training file is used.
Transitive_Closure = False
tc = "tc_" if Transitive_Closure else ""

# unary.tsv holds newline-separated probabilities: the value on line n is P(n),
# where n is the id assigned to the nth element.
unary_prob = (
    torch.from_numpy(np.loadtxt(f'{PATH}unary/unary.tsv'))
    .float()
    .to("cpu")
)
num_boxes = unary_prob.shape[0]  # one box per entry of unary.tsv

# Random negative sampling supplies negatives during training, so none are loaded here (ratio_neg=0).
train = Probs.load_from_julia(
    PATH, f'tr_pos_{tc}{ats}.tsv', f'tr_neg_{ats}.tsv', ratio_neg=0,
).to("cpu")

# The dev set, by contrast, uses a fixed set of negatives.
dev = Probs.load_from_julia(
    PATH, f'dev_align_pos_{ats}.tsv', f'dev_align_neg_{ats}.tsv', ratio_neg=1,
).to("cpu")

# Training alignments, used only for evaluation after training.
tr_align = Probs.load_from_julia(
    PATH, f'tr_align_pos_{ats}.tsv', f'tr_align_neg_{ats}.tsv', ratio_neg=1,
).to("cpu")

In [3]:
# NOTE(review): these names look swapped -- `mouse_eval` reads the human_* files and
# `human_eval` reads the mouse_* files. Fix the pairing (and the "cuda" device, which
# differs from the "cpu" used everywhere else) before re-enabling.
# mouse_eval = Probs.load_from_julia(PATH, 'human_dev_pos.tsv', 'human_dev_neg.tsv', ratio_neg = 1).to("cuda")
# human_eval = Probs.load_from_julia(PATH, 'mouse_dev_pos.tsv', 'mouse_dev_neg.tsv', ratio_neg = 1).to("cuda")

In [5]:
# Training hyperparameters.
dims = 10             # number of box dimensions (passed to BoxModel)
lr = 1e-2             # Adam learning rate
nEpochs = 20
rns_ratio = 1         # ratio passed to RandomNegativeSampling
box_type = MinMaxSigmoidBoxes
# NOTE(review): the unary loss term is commented out in the loss cell, so these two
# settings currently only affect the results-directory name, not training.
use_unary = False
unary_weight = 1e-2
batch_size = 2**6
seed = 0

# Seed the RNGs before the model is built so box initialisation (and, assuming the
# boxes library draws from torch/numpy RNGs, shuffling and negative sampling) is
# reproducible across Restart & Run All.
torch.manual_seed(seed)
np.random.seed(seed)

box_model = BoxModel(
    BoxParamType=box_type,
    vol_func=soft_volume,
    num_models=1,
    num_boxes=num_boxes,
    dims=dims,
    method="orig",
).to("cpu")

# `train` was loaded with ratio_neg=0, so train_dl yields ONLY positive examples;
# the model must rely on negative sampling during training.
train_dl = TensorDataLoader(train, batch_size=batch_size, shuffle=True)

opt = torch.optim.Adam(box_model.parameters(), lr=lr)

In [6]:
# Smallest positive normal float32; the default `eps` forwarded to kl_div_sym.
_EPS = torch.finfo(torch.float32).tiny


def _cond_kl_loss(model_out: ModelOutput, target: Tensor, eps: float = _EPS) -> Tensor:
    """Shared body of the conditional losses below.

    Applies ``kl_div_sym`` to the model's ``"P(A|B)"`` output and ``target``
    (with ``eps`` forwarded) and averages the result with ``.mean()``.
    """
    return kl_div_sym(model_out["P(A|B)"], target, eps).mean()


# The wrappers below compute the same quantity but are kept as separately named
# functions: their names appear to label the loss columns in the recorders
# (e.g. "mean_cond_kl_loss", "mouse_cond_kl_loss").
# NOTE(review): mouse/human/align losses are mathematically identical here; presumably
# the Learner (categories=True) routes a different subset of each batch to each one.

def mean_cond_kl_loss(model_out: ModelOutput, target: Tensor, eps: float = _EPS) -> Tensor:
    """Mean conditional-probability loss over the batch it is given."""
    return _cond_kl_loss(model_out, target, eps)


def human_cond_kl_loss(model_out: ModelOutput, target: Tensor, eps: float = _EPS) -> Tensor:
    """Conditional-probability loss term named for the human ontology."""
    return _cond_kl_loss(model_out, target, eps)


def mouse_cond_kl_loss(model_out: ModelOutput, target: Tensor, eps: float = _EPS) -> Tensor:
    """Conditional-probability loss term named for the mouse ontology."""
    return _cond_kl_loss(model_out, target, eps)


def align_cond_kl_loss(model_out: ModelOutput, target: Tensor, eps: float = _EPS) -> Tensor:
    """Conditional-probability loss term named for the alignment edges."""
    return _cond_kl_loss(model_out, target, eps)

# See boxes/loss_functions.py for more options; note that you may have to change them to fit your use case.
# Also note that "kl_div_sym" is just binary cross-entropy.

In [7]:
# The training loss combines the per-ontology conditional losses (mouse, human) with a
# down-weighted alignment loss. LossPieces takes each term either bare or as a
# (weight, loss) tuple, which presumably scales that term.
align_weight = 5e-2  # weight applied to the alignment loss term

# NOTE(review): the unary-probability term is disabled, so `use_unary` / `unary_weight`
# do not affect training. To re-enable it, add
# (unary_weight, mean_unary_kl_loss(unary_prob)) to the LossPieces arguments.
loss_func = LossPieces(mouse_cond_kl_loss, human_cond_kl_loss, (align_weight, align_cond_kl_loss))

# Hard-decision metrics on individual pairs, and their alignment-specific counterparts.
metrics = [metric_hard_accuracy, metric_hard_f1]
align_metrics = [
    metric_hard_accuracy_align,
    metric_hard_f1_align,
    metric_hard_accuracy_align_mean,
    metric_hard_f1_align_mean,
]

rec_col = RecorderCollection()

# Thresholds 0.1, 0.2, ..., 0.9 handed to l.evaluation after training.
threshold = np.arange(0.1, 1, 0.1)

callbacks = CallbackCollection(
    LossCallback(rec_col.train, train),
    LossCallback(rec_col.dev, dev),
    *(MetricCallback(rec_col.dev, dev, m) for m in metrics),
    # NOTE(review): `train` holds positives only (ratio_neg=0), so hard accuracy/F1 on it
    # presumably cannot penalise false positives -- interpret these numbers with care.
    *(MetricCallback(rec_col.train, train, m) for m in metrics),
#     *(MetricCallback(rec_col.onto, human_eval, m) for m in metrics),
#     *(MetricCallback(rec_col.onto, mouse_eval, m) for m in metrics),
    *(EvalAlignment(rec_col.train_align, tr_align, m) for m in align_metrics),
    *(EvalAlignment(rec_col.dev_align, dev, m) for m in align_metrics),
    JustGiveMeTheData(rec_col.probs, dev, get_probabilities),
    BiasMetric(rec_col.bias, dev, pct_of_align_cond_on_human_as_min),
    PlotMetrics(rec_col.dev_roc_plot, dev, roc_plot),
    PlotMetrics(rec_col.dev_pr_plot, dev, pr_plot),
    PlotMetrics(rec_col.tr_roc_plot, tr_align, roc_plot),
    PlotMetrics(rec_col.tr_pr_plot, tr_align, pr_plot),
    # Both correlations are computed on `dev`, so both are recorded in the dev
    # recorder alongside the other dev metrics.
    MetricCallback(rec_col.dev, dev, metric_pearson_r),
    MetricCallback(rec_col.dev, dev, metric_spearman_r),
#     PercentIncreaseEarlyStopping(rec_col.dev, "mean_cond_kl_loss", 0.25, 10),
#     PercentIncreaseEarlyStopping(rec_col.dev, "mean_cond_kl_loss", 0.5),
#     PercentIncreaseEarlyStopping(rec_col.dev, "mouse_cond_kl_loss", 0.25, 10),
#     PercentIncreaseEarlyStopping(rec_col.dev, "mouse_cond_kl_loss", 0.5),
#     GradientClipping(-1000,1000),
    # Supplies the negatives that `train` (ratio_neg=0) does not contain.
    RandomNegativeSampling(num_boxes, rns_ratio),
    StopIfNaN(),
)

l = Learner(train_dl, box_model, loss_func, opt, callbacks, recorder = rec_col.learn, categories=True)

In [None]:
# Train for nEpochs epochs; the callbacks registered on the Learner write into rec_col.
l.train(nEpochs)

HBox(children=(IntProgress(value=0, description='Overall Training:', max=20, style=ProgressStyle(description_w…

HBox(children=(IntProgress(value=0, description='Current Batch:', max=177, style=ProgressStyle(description_wid…

In [None]:
# Evaluate the trained model at each value in `threshold` (0.1 ... 0.9),
# presumably as hard-classification cut-offs.
l.evaluation(threshold)

In [None]:
# Loss and metrics recorded on the training data.
rec_col.train

In [None]:
# Loss and metrics recorded on the dev alignment set.
rec_col.dev

In [None]:
# Alignment metrics evaluated on the training alignments (tr_align).
rec_col.train_align

In [None]:
# Alignment metrics evaluated on the dev alignments.
rec_col.dev_align

In [None]:
# Metric names ending in `_1` refer to the mouse ontology; names without the `_1`
# suffix refer to the human ontology.
# NOTE(review): the MetricCallbacks that write to rec_col.onto are commented out in the
# callbacks cell, so this recorder may be empty.

rec_col.onto

In [None]:
# Bias metric (pct_of_align_cond_on_human_as_min) recorded on the dev set.
rec_col.bias

In [None]:
# Probabilities on the dev set collected by JustGiveMeTheData(get_probabilities).
rec_col.probs

## Make plots

In [None]:
df_train = rec_col.train_align.dataframe

# One line per recorded alignment-metric column on the training alignments.
fig1 = plt.figure(num=1, figsize=(10, 8), dpi=80, facecolor='white')
ax1 = fig1.gca()
ax1.plot(df_train, linewidth=3)
ax1.set_ylim((0, 1))
ax1.set_xlabel("Alignment Factor")
ax1.set_ylabel("Score")
# Iterating a DataFrame yields its column names, which become the legend labels.
ax1.legend(df_train, loc=0)

# plt.show()

In [None]:
df_dev = rec_col.dev_align.dataframe

# One line per recorded alignment-metric column on the dev alignments.
fig2 = plt.figure(num=2, figsize=(10, 8), dpi=80, facecolor='white')
ax2 = fig2.gca()
ax2.plot(df_dev, linewidth=3)
ax2.set_ylim((0, 1))
ax2.set_xlabel("Alignment Factor")
ax2.set_ylabel("Score")
ax2.set_title("Dev alignment metrics")
# Label the curves with df_dev's own column names so the legend always matches
# the lines actually plotted in this figure.
ax2.legend(df_dev, loc=0)

# plt.show()

In [None]:
dev_roc = rec_col.dev_roc_plot.dataframe
tr_roc = rec_col.tr_roc_plot.dataframe

fig3 = plt.figure(num=3, figsize=(10, 8), dpi=80, facecolor='white')
ax3 = fig3.gca()
ax3.plot(dev_roc['fpr'], dev_roc['tpr'], linewidth=3)
ax3.plot(tr_roc['fpr'], tr_roc['tpr'], color='g', linewidth=3)
# The diagonal is the ROC curve of a classifier with no skill.
ax3.plot([0, 1], [0, 1], linestyle='--', color='xkcd:orange', linewidth=3)
ax3.set_ylim((0, 1.05))
ax3.set_xlabel("False Positive Rate")
ax3.set_ylabel("True Positive Rate")
ax3.legend(("Box Model - Dev Alignments", "Box Model - Train Alignments", "No Skill"), loc=0)

# plt.show()

In [None]:
# Precision of a no-skill classifier: the number of positive (probs == 1) dev pairs
# divided by dev.probs.shape[0]. Drawn as a horizontal baseline.
# NOTE(review): the baseline is computed on dev only; tr_align may have a different positive rate.
no_skill_pr = (dev.probs == 1).float().sum().item() / dev.probs.shape[0]

dev_pr = rec_col.dev_pr_plot.dataframe
tr_pr = rec_col.tr_pr_plot.dataframe

fig4 = plt.figure(num=4, figsize=(10, 8), dpi=80, facecolor='white')
ax4 = fig4.gca()
ax4.plot(dev_pr['recall'], dev_pr['precision'], linewidth=3)
ax4.plot(tr_pr['recall'], tr_pr['precision'], color='g', linewidth=3)
ax4.plot([0, 1], [no_skill_pr, no_skill_pr], linestyle='--', color='xkcd:orange', linewidth=3)
ax4.set_ylim((0, 1.05))
# Recall is plotted on the x-axis and precision on the y-axis, so label them accordingly.
ax4.set_xlabel("Recall")
ax4.set_ylabel("Precision")
ax4.set_title("Precision-recall curves")
ax4.legend(("Box Model - Dev Alignments", "Box Model - Train Alignments", "No Skill"), loc=0)

# plt.show()

In [None]:
# average_min_probability = np.mean(rec_col.probs.dataframe['Minimum Probablity'])
# align_pair_probs = np.stack((rec_col.probs.dataframe['Alignment 1 Probablity'], 
#                              rec_col.probs.dataframe['Alignment 2 Probablity']), axis=1)

# average_align_pair_probs = np.mean(align_pair_probs, axis=1)
# print(average_align_pair_probs.shape)

# fig = plt.figure(figsize=(10,8), dpi=80, facecolor='white')
# plt.plot(range(0, 606), rec_col.probs.dataframe['Alignment 1 Probablity'])
# plt.plot(range(0, 606), rec_col.probs.dataframe['Alignment 2 Probablity'])
# plt.plot(range(0, 606), average_align_pair_probs)
# # plt.plot(range(0, 606), rec_col.probs.dataframe['Minimum Probablity'])
# # plt.plot([0, 606], [average_min_probability, average_min_probability], linestyle='--')
# plt.ylim((0,1))
# plt.xlim((200,300))
# plt.xlabel("Alignment Number")
# plt.ylabel("Probability")
# plt.legend(("Alignment Conditional Probability 1", 
#             "Alignment Conditional Probability 2", 
# # #             "Minimum Probability", "Average Minimum Probability"
#            ), loc=0)

# plt.show()


In [None]:
def count_parameters(model):
    """Return the total number of scalar elements in *model*'s trainable parameters.

    Parameters with ``requires_grad=False`` (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

# Number of trainable parameters in the box model.
count_parameters(box_model)

### Save the model

In [None]:
# Output directory for this run; its name encodes the main hyperparameters.
# NOTE(review): settings such as the alignment-loss weight are not part of the name, so
# runs differing only in those map to the same directory (and the save below is skipped).
fpath = f"../results/{box_type.__name__}_nEpochs{nEpochs}_lr{lr}_dims{dims}_ratio{rns_ratio}_{tc}split{ats}_Unary{use_unary}{unary_weight}/"

# Recorder attributes of rec_col to export; each is written to f"{fpath}{name}.csv".
RECORDER_NAMES = (
    "learn", "train", "dev", "onto", "train_align", "dev_align",
    "dev_roc_plot", "dev_pr_plot", "tr_roc_plot", "tr_pr_plot",
    "probs", "bias",
)


def save_recorders():
    """Write every recorder listed in RECORDER_NAMES to a CSV file inside fpath."""
    for name in RECORDER_NAMES:
        getattr(rec_col, name).dataframe.to_csv(f"{fpath}{name}.csv")


def save_plots():
    """Save the figures fig1-fig4 from the plotting cells as PNG files inside fpath."""
    figures = {
        "train_alignment_plot.png": fig1,
        "dev_alignment_plot.png": fig2,
        "roc_curve.png": fig3,
        "pr_curve.png": fig4,
    }
    for filename, fig in figures.items():
        fig.savefig(f"{fpath}{filename}")


try:
    os.makedirs(fpath)
except OSError as err:
    # e.g. FileExistsError when a run with the same settings was already saved, or a
    # permission problem -- report which, and refuse to overwrite anything.
    print("Creation of the directory %s failed" % fpath)
    print(f"Reason: {err}")
    print("Did not save any of the files.")
else:
    print("Successfully created the directory %s " % fpath)
    print("Saving files ...")

    fmodel = f"{fpath}model.pth"

    # torch.save pickles the whole dict (including rec_col and the datasets), so loading
    # it needs the boxes/learner classes importable -- only load checkpoints you trust.
    save_model = {
        'state_dict': box_model.state_dict(),
        'optimizer': opt.state_dict(),
        'nEpochs': nEpochs,
        'recorders': rec_col,
        'train': train,
        'tr_align': tr_align,
        'dev': dev,
    }
    torch.save(save_model, fmodel)

    save_recorders()
    save_plots()

    print("Save complete")