In [1]:
import torch
import pickle
from torch import Tensor
from torch.nn import Module, Parameter
import torch.nn.functional as F
import numpy as np
from numpy import genfromtxt
from sklearn.metrics import roc_curve, precision_recall_curve
import matplotlib.pyplot as plt
from learner import *
from boxes import *



In [3]:
PATH = "../data/ontologies/anatomy/"


def _load_pickle(filename):
    """Unpickle and return the object stored in the file ``PATH + filename``."""
    # NOTE(review): pickle.load can execute code embedded in the file; only
    # point this at pickles produced by this project.
    with open(f'{PATH}{filename}', 'rb') as handle:
        return pickle.load(handle)


human = _load_pickle('human.pickle')
mouse = _load_pickle('mouse.pickle')
entities = _load_pickle('entities.pickle')

# Show which fields each loaded dictionary provides (mouse, human, entities).
for loaded in (mouse, human, entities):
    print(loaded.keys())

dict_keys(['edges', 'parents_of', 'children_of', 'mouse_entities'])
dict_keys(['edges', 'parents_of', 'children_of', 'human_entities'])
dict_keys(['all_edges', 'alignments', 'name2idx', 'idx2name', 'label2idx', 'idx2label', 'set', 'align_dict'])


### Separating the data by edge type: human, mouse, or alignment (across ontologies)

In [4]:
# Training edges for the anatomy data, read as tab-separated integer tables:
# one file of positive edges and one file of negative edges.
pos_edges = np.loadtxt('../data/ontologies/anatomy/tr_pos_0.8.tsv', delimiter='\t', dtype=int)
neg_edges = np.loadtxt('../data/ontologies/anatomy/tr_neg_0.8.tsv', delimiter='\t', dtype=int)

In [5]:
# Labels: 1 for every positive edge, 0 for every negative edge.
pos_out = np.ones(pos_edges.shape[0], dtype=int)
neg_out = np.zeros(neg_edges.shape[0], dtype=int)

# Use at most as many negatives as there are positives, so the two classes
# contribute the same number of rows whenever enough negatives exist.
n_pos = len(pos_edges)
edges_stacked = np.concatenate((pos_edges, neg_edges[:n_pos]), axis=0)
labels_stacked = np.concatenate((pos_out, neg_out[:n_pos]), axis=0)

# Positives come first, followed by the retained negatives.
data_in = torch.from_numpy(edges_stacked)
data_out = torch.from_numpy(labels_stacked)

In [6]:
# Toy example of the three-way row categorisation, using 3 as the bound.
sample = torch.tensor([[1,2],[5,6],[2,4],[6,7],[2,4],[3,6],[1,3]])
print(sample)

# True where an entry lies above the bound.
sample_class = sample > 3

category = torch.zeros_like(sample[:,0])

# 0: both entries above the bound, 1: neither above, 2: exactly one above.
# The loop variables i, a and b stay bound at module level after the loop.
for i, (a,b) in enumerate(sample_class):
    category[i] = 0 if (a and b) else (1 if not (a or b) else 2)

category

tensor([[1, 2],
        [5, 6],
        [2, 4],
        [6, 7],
        [2, 4],
        [3, 6],
        [1, 3]])


tensor([1, 0, 2, 0, 2, 2, 1])

In [7]:
# mouse upper bound: the largest value found in mouse['mouse_entities']
mub = max(mouse['mouse_entities'])

# Serve the positive edges in shuffled mini-batches of four rows.
# NOTE(review): DataLoader is not imported by name in this notebook, so it
# presumably arrives through one of the wildcard imports (learner / boxes) — confirm.
dl = DataLoader(pos_edges, batch_size=4, shuffle=True)

def loss_categories(batch, upper_bound=None):
    """Label every row of ``batch`` by which side of a bound its entries fall on.

    Args:
        batch: 2-D tensor (or anything ``torch.as_tensor`` accepts) whose
            entries are compared element-wise against the bound.
        upper_bound: threshold for the comparison. When omitted, the notebook
            global ``mub`` is used, so existing ``loss_categories(batch)``
            calls behave exactly as before.

    Returns:
        A 1-D int64 CPU tensor with one entry per row of ``batch``:
        0 if every entry of the row is ``> upper_bound``,
        1 if no entry of the row is ``> upper_bound``,
        2 otherwise (the row has entries on both sides of the bound).
    """
    if upper_bound is None:
        upper_bound = mub

    # Vectorised replacement for the former per-row Python loop. The mask is
    # moved to the CPU because the result has always been allocated there.
    above = (torch.as_tensor(batch) > upper_bound).cpu()
    all_above = above.all(dim=1)
    none_above = ~above.any(dim=1)

    # Start from the "mixed" label and overwrite the two uniform cases.
    category = torch.full((above.shape[0],), 2, dtype=torch.int64)
    category[all_above] = 0
    category[none_above] = 1

    return category

# Draw every mini-batch from the loader once and keep them in a list.
batches = list(dl)

# The first batch, together with its per-row categories, serves as a small
# fixture for the experiments in the following cells.
test_batch = batches[0]
test_category = loss_categories(test_batch)

In [8]:
# Split the test batch into three groups using the per-row category labels.
# NOTE(review): loss_categories assigns 0 when BOTH ids are greater than mub
# (the maximum of mouse['mouse_entities']) and 1 when neither is, so category 0
# selects rows lying entirely above the mouse id range. Verify that the
# test_mouse / test_human names are not swapped.
test_mouse = test_batch[test_category == 0]
test_human = test_batch[test_category == 1]
test_align = test_batch[test_category == 2]

In [9]:
# Scratch check: build an int32 tensor with zero rows and two columns.
torch.empty((0,2), dtype=torch.int32)

tensor([], size=(0, 2), dtype=torch.int32)

In [19]:
# Scratch check: a NaN scalar tensor can be created with requires_grad=True.
nantensor = torch.tensor(float('nan'), requires_grad=True)
nantensor

tensor(nan, requires_grad=True)

In [21]:
# Scratch check: draw one uniform random number as a one-element tensor.
torch.rand(size=(1,))

tensor([0.1462])

In [13]:
# Scratch check: substring test against a loss name; 'mouse' does not occur in it.
'mouse' in "align_cond_kl_loss"

False

In [14]:
def isnan(x):
    """Return ``x != x``, which is truthy only where ``x`` is NaN.

    NaN is the one floating-point value that compares unequal to itself. For a
    tensor argument the comparison is element-wise and a tensor is returned.
    """
    differs_from_itself = x != x
    return differs_from_itself

# NOTE(review): `a` is not assigned in this cell. On a top-to-bottom run it is
# whatever the earlier module-level `for i, (a,b) in enumerate(sample_class)`
# loop left bound — confirm which value was actually meant to be tested here.
if isnan(a):
    print("true")

In [15]:
# Scratch check: a dictionary mixing plain numbers with a NaN tensor.
some_dict = {'a':2, 'b':torch.tensor(float('nan')), 'c':4}

# Only the values are printed, so iterate over them directly.
for v in some_dict.values():
    print(v)

2
tensor(nan)
4


In [16]:
# Scratch check: act only when at least one row of the test batch was placed
# in category 2.
if len(test_align) > 0:
    print('hi')

hi


In [17]:
# Scratch check: iterating the 1-D category tensor yields one
# zero-dimensional tensor per row.
for cat in test_category:
    print(cat)

tensor(2)
tensor(0)
tensor(0)
tensor(1)


In [18]:
# NOTE(review): `category` is one-dimensional (it was created from
# sample[:,0]), so indexing it with two indices raises the IndexError recorded
# in this cell's output. Presumably a two-column tensor was intended here —
# confirm which one, or remove the cell.
print(category[:,0])
print(category[:,1])

IndexError: too many indices for tensor of dimension 1

In [None]:
# Number of labels currently held in data_out.
data_out.shape

In [None]:
# Number of dimensions over which the minimum volume is spread.
dims = 100

# Use the smallest positive normal float32 as the minimum total volume.
init_min_vol = torch.finfo(torch.float32).tiny
# init_min_vol = 0.00000000000000000000000000000000001

# Taking the dims-th root gives a per-dimension value whose product over all
# `dims` dimensions equals init_min_vol.
per_dim_min = torch.tensor(init_min_vol) ** (1/dims)

print(init_min_vol, per_dim_min)

In [None]:
# Dev alignment edges, read as tab-separated tables. genfromtxt is called
# without a dtype, so the arrays are floating point.
negatives = genfromtxt('../data/ontologies/dev_align_neg_0.8.tsv', delimiter='\t')
positives = genfromtxt('../data/ontologies/dev_align_pos_0.8.tsv', delimiter='\t')

# One label per row: 0 for negatives, 1 for positives. The positive labels are
# sized from `positives` (previously from `negatives`), so the labels line up
# with the inputs even when the two files have different row counts.
neg_out = np.zeros(negatives.shape[0])
pos_out = np.ones(positives.shape[0])

# Positives first, then negatives, for both inputs and labels.
# NOTE: this rebinds data_in / data_out, which earlier cells built from the
# training edges.
data_in = torch.from_numpy(np.concatenate((positives, negatives), axis=0))
data_out = torch.from_numpy(np.concatenate((pos_out, neg_out), axis=0))

assert data_in.shape[0] == data_out.shape[0], "every input row needs exactly one label"

# model_in = torch.from_numpy(model_in)
# model_out = torch.from_numpy(model_out)

## Dummy model

In [None]:
def model(data_in: Tensor):
    """Stand-in model: return one uniform random score in [0, 1) per input row.

    The contents of ``data_in`` are ignored; only the size of its first
    dimension determines the length of the returned 1-D tensor.
    """
    n_rows = data_in.shape[0]
    return torch.rand(n_rows)

# model(model_in)

### Structure for comparing pairs of alignments

In [None]:
# Split the rows into the two directions of every alignment pair.

print(data_in.shape)

# Even-indexed rows form one direction and odd-indexed rows the other.
# NOTE(review): this presumes the two directions of each alignment sit on
# adjacent rows of the input — confirm against the data files.
A_given_B = data_in[::2]
B_given_A = data_in[1::2,:]

# Keep one label per pair. The guard makes the cell safe to re-run: once
# data_out already has one entry per pair it is not halved a second time.
if data_out.shape[0] != A_given_B.shape[0]:
    data_out = data_out[::2]

# Inputs are stacked along a new leading axis (direction, pair, column) ...
align_pair_in = torch.stack((A_given_B, B_given_A), dim=0)

# ... while the model scores are stacked as two columns (pair, direction).
align_pair_out = torch.stack((model(A_given_B), model(B_given_A)), dim=1)


### Different methods of comparison, taking the minimum or the mean

In [None]:
# Decision threshold applied to the combined pair score.
threshold = 0.5

# Combination 1: score a pair by the smaller of its two directional scores.
# `p` keeps both the minima (.values) and which direction produced them (.indices).
p = align_pair_out.min(dim=1)
hard_pred = p.values > threshold

# Combination 2: score a pair by the mean of its two directional scores.
p_mean = align_pair_out.mean(dim=1)
hard_pred_mean = p_mean > threshold


### Collect information on the classes
- Is the model always guessing one class?
- How often does each ontology supply the minimum score of a pair?

In [None]:
# A range with one index per entry of p.values (i.e. per pair).
range(p.values.shape[0])

In [None]:
# randt = torch.rand((4,2))
# print("probabilities", randt)

# mint = torch.min(randt, dim=1)
# print("mins", mint.indices)

# nodes1 = torch.randint(0, 10, (4,2))
# nodes2 = torch.randint(0, 10, (4,2))

# print("\n\n list of nodes")
# print(nodes1)
# print(nodes2)

# comb = torch.stack((nodes1, nodes2), dim=0)
# print(comb.shape)

# mint_new = mint.indices.repeat_interleave(randt.shape[1]).reshape(-1,randt.shape[1])

# torch.gather(comb, dim=0, index=mint_new.view(1,mint_new.shape[0],-1))

In [None]:
# p.indices records, for every pair, which of the two stacked directions
# (0 or 1) produced the smaller score. Repeat each index once per input column
# so that it can select a whole row of node ids.
n_cols = data_in.shape[1]
min_indices = p.indices.repeat_interleave(n_cols).reshape(-1, n_cols)

# align_pair_in is stacked along dim 0, so gathering along dim 0 with an index
# of leading size 1 picks one of the two candidate rows for each pair; the
# squeeze then drops that leading axis again.
n_pairs = align_pair_in.shape[1]
gather_index = min_indices.view(1, n_pairs, -1)
min_nodes = torch.gather(align_pair_in, dim=0, index=gather_index).squeeze(0)

In [None]:
# Count the selected (minimum-scoring) rows whose second-column id exceeds 2737.
# NOTE(review): 2737 is a hard-coded id; presumably it is meant to be the
# mouse upper bound (`mub`) — confirm and use the variable instead.
cnt = (min_nodes[:,1]>2737).sum().float()

# Express the count as a fraction of the number of pairs.
cnt / align_pair_in.shape[1]

In [None]:
# Sum of the labels over the pairs predicted positive by the min-score rule;
# with 0/1 labels this is the number of true positives.
true_pos = data_out[hard_pred==1].sum()

In [None]:
# Number of pairs labelled 1 and labelled 0, as plain Python integers.
total_actual_pos = int((data_out == 1).sum())
total_actual_neg = int((data_out == 0).sum())

In [None]:
# Number of pairs predicted positive (as a float tensor).
total_pred_pos = (hard_pred==1).sum().float()
# Recomputed as "everything that is not an actual positive"; this overwrites
# the value counted directly in the previous cell.
total_actual_neg = data_out.shape[0] - total_actual_pos


# Remaining confusion-matrix entries follow from the totals:
# predicted positives that are not true positives are false positives,
# actual positives that were missed are false negatives,
# and actual negatives that were not flagged are true negatives.
false_pos = total_pred_pos - true_pos
false_neg = total_actual_pos - true_pos
true_neg = total_actual_neg - false_pos


In [None]:
# Print every count computed above, one per line.
print(true_pos)
print(total_pred_pos)
print(false_pos)
print(total_actual_pos)
print(false_neg)
print(total_actual_neg)
print(true_neg)


# Sanity check: the four confusion-matrix cells must add up to the number of
# labelled pairs.
assert true_neg + true_pos + false_neg + false_pos == data_out.shape[0]
print(true_neg + true_pos + false_neg + false_pos)


### Get various pieces of information on the datasets -- based on varying the threshold value

In [None]:
# `p` is the (values, indices) result of torch.min, not a score vector; the
# curves need only the per-pair minimum scores, so use p.values. Both inputs
# are converted to plain NumPy arrays before being handed to scikit-learn.
pair_scores = p.values.detach().cpu().numpy()
pair_labels = data_out.detach().cpu().numpy()

fpr, tpr, thresholds1 = roc_curve(y_true=pair_labels, y_score=pair_scores)
# thresholds1.shape

# The scores are passed positionally because the keyword is `probas_pred` in
# older scikit-learn releases and `y_score` in newer ones.
precision, recall, thresholds2 = precision_recall_curve(pair_labels, pair_scores)

### Save the data from above in dataframes

In [None]:
# Store the curve data in a RecorderCollection (a project class whose
# implementation is not visible in this notebook).
# NOTE(review): presumably update_(values, index) records the given column
# against the supplied thresholds — confirm against the class definition.
rec_col = RecorderCollection()

rec_col.roc_plot.update_({'fpr':fpr}, thresholds1)
rec_col.roc_plot.update_({'tpr':tpr}, thresholds1)

# precision and recall are trimmed with [0:-1] before being paired with
# thresholds2; scikit-learn documents them as one element longer than the
# returned thresholds.
rec_col.pr_plot.update_({'precision':precision[0:-1]}, thresholds2)
rec_col.pr_plot.update_({'recall':recall[0:-1]}, thresholds2)

In [None]:
# rec_col.roc_plot.dataframe
# rec_col.pr_plot.dataframe

## Make some plots

In [None]:
# fig = plt.figure(figsize=(10,8), dpi=80, facecolor='white')

# plt.plot(rec_col.roc_plot.dataframe['fpr'],rec_col.roc_plot.dataframe['tpr'])

In [None]:
# fig = plt.figure(figsize=(10,8), dpi=80, facecolor='white')

# plt.plot(rec_col.pr_plot.dataframe['recall'], rec_col.pr_plot.dataframe['precision'])

In [None]:
# hard_pred.sum()
# hard_pred_mean.sum()
# the_means.shape

In [None]:
# m = model_out[0:10]
# h = hard_pred[0:10]

# mlast = model_out[-10:]
# hlast = hard_pred[-10:]

# print(mlast)
# print(hlast)
# print(align_pairs[-10:,:])

# true_pos = m[h==1]
# true_pos = mlast[hlast==1]

# print(true_pos)

# `model_out` is not defined anywhere in this notebook (it only appears in
# commented-out code); the labels are held in `data_out`, which is also what
# the earlier true-positive cell sums.
true_pos = data_out[hard_pred==1].sum()
print(true_pos)

In [None]:
# torch.range is deprecated; torch.arange is its replacement. arange excludes
# its end point, so the end values are one past the last wanted number
# (1..12 and 13..24, exactly as before).
a = torch.arange(1, 13, dtype=torch.int64).reshape((3,4))
b = torch.arange(13, 25, dtype=torch.int64).reshape((3,4))

print(a, a.shape,'\n', b, b.shape)

# Stacking two (3, 4) tensors at dim=-3 inserts the new axis in front,
# giving a (2, 3, 4) result.
c = torch.stack((a,b),dim=-3)
print(c, c.shape)

# First column of each stacked matrix.
print(c[:,:,0])

In [None]:
# model_out = model_out.reshape((-1,2))
# print(model_out.shape)
# print(torch.all(torch.eq(model_out[:,0], model_out[:,1])))

# pair = torch.cat((model(A_given_B), model(B_given_A)), dim=1)
# p_out = torch.cat((A_given_B_out, B_given_A_out), dim=1)

# p = torch.min(pair, dim=1)

# p_out.shape

In [None]:
# A_given_B_pred = model(A_given_B) > 0.5
# B_given_A_pred = model(B_given_A) > 0.5

# pair = torch.eq(A_given_B_pred, B_given_A_pred)
# print(A_given_B_pred[0:5])
# print(B_given_A_pred[0:5])
# print(pair[0:5])

In [None]:
# t = torch.rand((4,2))

# idxs = torch.randint(low=0, high=2, size=(4,))

# print(t)
# print(idxs)

# t.gather(dim=1, index=idxs.view(-1,1))

In [None]:
# u = torch.rand((8,))
# print(u)
# print(u.reshape(-1,2))