In [3]:

import time
import wandb
import torch
from torch.utils.data import DataLoader
import math
import argparse
import os
import json
from utils import *
from dataset import *

# Preferred GPU for this run. It is only selected when CUDA is usable and that
# index exists, so the notebook still starts on CPU-only or single-GPU machines
# instead of raising before the CPU fallback below can take effect.
PREFERRED_GPU = 1
if torch.cuda.is_available():
    # Fall back to GPU 0 when the preferred index is not present on this host.
    torch.cuda.set_device(PREFERRED_GPU if PREFERRED_GPU < torch.cuda.device_count() else 0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [4]:
# Command-line style configuration for the run. The options are declared as a
# table and registered in order, so adding a hyper-parameter is a one-line change.
parser = argparse.ArgumentParser()

_OPTION_SPECS = [
    ('--no_cuda', dict(action='store_true', default=False, help='disables CUDA training (eg. no nvidia GPU)')),
    ('--epochs', dict(type=int, default=1000, help='number of epochs to train')),
    # model parameters
    ('--model', dict(type=str, default='softbox', help='model type: choose from softbox, gumbel')),
    ('--box_embedding_dim', dict(type=int, default=40, help='box embedding dimension')),
    ('--softplus_temp', dict(type=float, default=1.0, help='beta of softplus function')),
    # gumbel box parameters
    ('--gumbel_beta', dict(type=float, default=1.0, help='beta value for gumbel distribution')),
    ('--scale', dict(type=float, default=1.0, help='scale value for gumbel distribution')),
]
for _flag, _kwargs in _OPTION_SPECS:
    parser.add_argument(_flag, **_kwargs)

# An explicit empty argument list makes argparse ignore sys.argv (which belongs to
# the Jupyter kernel), so every option takes its default value.
args = parser.parse_args(args=[])
# Checkpoint location is derived from the selected model type.
args.save_to = "./checkpoints/" + args.model


In [5]:



# Tiny offset used inside the hinge terms of the regularizer below.
eps = 1e-23


def l2_side_regularizer(box, log_scale: bool = True):
    """Regularization term computed from a box's ``min_embed`` and ``delta_embed``.

    Args:
        box: object exposing tensor attributes ``min_embed`` and ``delta_embed``.
        log_scale: selects which penalty is returned (default ``True``).

    Returns:
        A scalar tensor.

        * ``log_scale=False``: the mean of ``delta_embed ** 2``.
        * ``log_scale=True``: the sum of two hinge means,
          ``mean(relu(min_embed + delta_embed - 1 + eps))`` and
          ``mean(relu(-min_embed - eps))``. These are zero while
          ``min_embed + delta_embed`` stays below 1 and ``min_embed`` stays
          above 0 (up to ``eps``), i.e. this presumably keeps boxes inside the
          unit cube -- confirm against the box model.
    """
    lower = box.min_embed
    side = box.delta_embed

    if log_scale:
        above_one = F.relu(lower + side - 1 + eps)
        below_zero = F.relu(-lower - eps)
        return torch.mean(above_one) + torch.mean(below_zero)

    return torch.mean(side ** 2)
    
    

In [None]:
from box_model import BoxEL 
from box_model import BoxEL as Model
import torch
from data_loader import load_cls,load_valid_data
from data_loader import load_data
from data_loader import Generator
from evaluation import compute_mean_rank,compute_rank,compute_accuracy


In [None]:

# Dataset selection and file locations; everything is looked up under data/<dataset>/.
dataset = 'GALEN'

train_file = f"data/{dataset}/{dataset}_train.txt"
va_file = f"data/{dataset}/{dataset}_valid.txt"
test_file = f"data/{dataset}/{dataset}_test.txt"

# Read each split in train / valid / test order; load_cls returns a pair per file.
(train_sub_cls, train_samples), (valid_sub_cls, valid_samples), (test_sub_cls, test_samples) = [
    load_cls(split_path) for split_path in (train_file, va_file, test_file)
]

# Concatenation of the three splits, and the same entries with duplicates removed.
total_sub_cls = train_sub_cls + valid_sub_cls + test_sub_cls
all_subcls = list(set(total_sub_cls))

print("Training data samples:",train_samples)
print("Training data classes:",len(train_sub_cls))

# Training axioms plus the class / relation index mappings come from the ontology file.
gdata_file=f"data/{dataset}/{dataset}_latest_norm_mod.owl"
train_data, classes, relations = load_data(gdata_file)

# The validation file is the same path as va_file; it is re-read here through
# load_valid_data using the class / relation mappings, then moved to the device
# as an integer tensor.
valid_data_file = va_file
valid_triples = load_valid_data(valid_data_file, classes, relations)
valid_data = torch.Tensor(valid_triples).long().to(device)


Training data samples: 19511
Training data classes: 18290


In [106]:
# Inspect the loaded training axioms: a dict of integer index arrays keyed by
# axiom type ('nf1' ... 'nf4', 'disjoint', 'nf_inclusion', ...).
train_data

{'nf1': array([[21499,     0,  6102],
        [15730,     0, 22324],
        [ 5727,     0, 19714],
        ...,
        [22156,     0,  1050],
        [  341,     0,  1171],
        [12441,     0,  9906]]), 'nf2': array([[16603,    54, 20427],
        [ 7789,  6948,  3171],
        [15112,  1435,  2670],
        ...,
        [ 7510,   826, 21574],
        [    3,  8633,  6233],
        [ 8616,   536,  7936]]), 'nf3': array([[15641,     8,   190],
        [ 8868,    11,  2333],
        [16621,    27,  3046],
        ...,
        [  688,   143,  2890],
        [10425,     8, 18603],
        [ 7068,   136,  6205]]), 'nf4': array([[   57, 16006, 17463],
        [   35,  7247, 13880],
        [  118, 10623,  4373],
        ...,
        [    6, 15205, 14928],
        [   35, 15171, 19209],
        [   68, 14056, 16778]]), 'disjoint': array([[0, 0]]), 'nf_inclusion': array([[ 24, 501],
        [753,  28],
        [261, 512],
        ...,
        [ 78, 933],
        [291, 294],
        [793, 

In [109]:
batch_size = 256

nb_classes, nb_relations = len(classes), len(relations)

print("no. classes:",nb_classes)
print("no. relations:",nb_relations)

# Steps per epoch are derived from the longest array in train_data
# (0 when the dict is empty), rounded up to whole batches.
nb_train_data = max((len(axioms) for axioms in train_data.values()), default=0)
train_steps = int(math.ceil(nb_train_data / float(batch_size)))
train_generator = Generator(train_data, batch_size, steps=train_steps)

# Invert the class / relation mappings so entries can be looked up by integer index.
cls_dict = {index: name for name, index in classes.items()}
rel_dict = {index: name for name, index in relations.items()}

# Position i of each list holds the entry mapped to index i.
cls_list = [cls_dict[position] for position in range(nb_classes)]
rel_list = [rel_dict[position] for position in range(nb_relations)]

# All class indices as a column vector of integers on the training device.
classes_index = torch.Tensor(list(classes.values())).to(device).reshape(-1,1).long()


no. classes: 24352
no. relations: 1010


In [110]:
# Start a Weights & Biases run for this experiment and record the parsed
# arguments as the run configuration.
wandb.init(
    project="basic_box",
    reinit=True,
    config=args,
)


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train loss,49.82947
nf1_loss,0.21455
nf2_loss,0.17634
nf3_loss,0.10427
nf4_loss,0.05738
mean_rank,830.44
valid_accuracy,0.99318
_runtime,1350.0
_timestamp,1633959464.0
_step,362.0


0,1
train loss,██▆▅▄▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
nf1_loss,▁▅█▆▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂
nf2_loss,▁▅█▆▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
nf3_loss,▁▆█▅▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▂▁▁▁▁▁▂▂▂▁▁▁▁▁▁▂▁
nf4_loss,▁▆█▅▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▂▂▁▁▁▂▁▂▁▁▁▁▂▁
mean_rank,█▇▅▄▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid_accuracy,█▃▁▂▅▅▅▅▆▆▆▇▇▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▇▇▇▇▇
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


wandb: wandb version 0.12.4 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


In [114]:

import random

# Hyper-parameters for this run.
dimension = 50
learning_rate = 0.001
seed = 1111

# Seed Python's and PyTorch's RNGs before the model is constructed.
random.seed(seed)
torch.manual_seed(seed)

# The four numeric ranges are passed positionally to BoxEL; their meaning is
# defined in box_model (presumably parameter initialisation ranges -- confirm).
model = BoxEL(
    nb_classes,
    nb_relations,
    dimension,
    [1e-4, 0.2],
    [-0.1, 0],
    [-0.1,0.1],
    [0.9,1.1],
    args,
).to(device)
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
wandb.watch(model)


[]

In [116]:

# Train the model for args.epochs epochs. Each epoch runs up to steps_per_epoch
# batches, then logs per-batch mean losses plus validation metrics to wandb.
# The whole model is checkpointed every 100 epochs and after the last epoch.
model.train()
steps_per_epoch = train_steps
train_generator.reset()

# The checkpoint path is constant for the run: build it once and make sure the
# target directory exists so the first torch.save cannot fail on a missing folder.
PATH = './models/box_el/dim' + str(dimension) + "_lr" + str(learning_rate) + '_batch'+str(batch_size)+'_seed'+ str(seed)+ '.pt'
os.makedirs(os.path.dirname(PATH), exist_ok=True)

# wandb metric names of the loss terms that are averaged over each epoch.
tracked_losses = ('train loss', 'nf1_loss', 'nf2_loss', 'nf3_loss', 'nf4_loss',
                  'nf1_neg_loss', 'role_inclusion_loss', 'role_chain_loss')

for epoch in range(args.epochs):
    epoch_totals = dict.fromkeys(tracked_losses, 0.0)
    # Count the batches that were actually trained on. The enumerate index is not
    # reliable for averaging: when the loop exits through the break branch it is
    # already one past the last trained batch.
    num_batches = 0

    for step, batch in enumerate(train_generator):
        if step >= steps_per_epoch:
            train_generator.reset()
            break

        (nf1_loss, nf1_neg_loss, nf2_loss, nf3_loss, nf4_loss, disjoint_loss,
         role_inclusion_loss, role_chain_loss, nf1_reg_loss, nf2_reg_loss,
         nf3_reg_loss, nf4_reg_loss, disjoint_reg_loss, nf1_neg_reg_loss) = model(batch[0])

        # NOTE(review): disjoint_loss and disjoint_reg_loss are returned by the model
        # but are not part of the optimised objective -- confirm this is intended.
        loss = (nf1_loss + nf1_reg_loss + nf1_neg_loss + nf1_neg_reg_loss
                + nf2_loss + nf2_reg_loss + nf3_loss + nf3_reg_loss
                + nf4_loss + nf4_reg_loss + role_inclusion_loss + role_chain_loss)
        assert not torch.isnan(loss).any()

        optimizer.zero_grad()
        loss.backward()
        # Check the gradient before it is applied, so a NaN never reaches the weights.
        assert not torch.isnan(model.min_embedding.grad).any()
        optimizer.step()
        assert not torch.isnan(model.min_embedding).any()

        # Accumulate plain floats: summing the loss tensors themselves would keep
        # every batch's autograd history alive for the whole epoch.
        batch_losses = {
            'train loss': loss,
            'nf1_loss': nf1_loss,
            'nf2_loss': nf2_loss,
            'nf3_loss': nf3_loss,
            'nf4_loss': nf4_loss,
            'nf1_neg_loss': nf1_neg_loss,
            'role_inclusion_loss': role_inclusion_loss,
            # Previously role_inclusion_loss was added here by mistake.
            'role_chain_loss': role_chain_loss,
        }
        for loss_name, loss_value in batch_losses.items():
            epoch_totals[loss_name] += float(loss_value)
        num_batches += 1

    # Validation needs no gradients; mean rank is computed on the first 100
    # validation rows only, accuracy on all of them.
    with torch.no_grad():
        mean_rank = compute_mean_rank(model, valid_data[0:100])
        valid_accuracy = compute_accuracy(model, valid_data)

    # Guard against an epoch that produced no batches at all.
    denominator = max(num_batches, 1)
    epoch_means = {loss_name: total / denominator for loss_name, total in epoch_totals.items()}
    wandb.log({**epoch_means, 'mean_rank': mean_rank, 'valid_accuracy': valid_accuracy})

    # Checkpoint periodically and once more after the final epoch, so the fully
    # trained weights are always on disk.
    is_last_epoch = epoch == args.epochs - 1
    if epoch % 100 == 0 or is_last_epoch:
        torch.save(model, PATH)
        print('Epoch:%d' %(epoch + 1), "Train loss: %f" %(epoch_means['train loss']), f'Valid Mean Rank: {mean_rank}\n')



Epoch:1 Train loss: 391.558575 Valid Mean Rank: 12176.5

Epoch:101 Train loss: 80.678258 Valid Mean Rank: 2326.455

Epoch:201 Train loss: 60.651415 Valid Mean Rank: 1381.465

Epoch:301 Train loss: 52.412652 Valid Mean Rank: 1049.61

Epoch:401 Train loss: 48.849852 Valid Mean Rank: 975.51

Epoch:501 Train loss: 46.280659 Valid Mean Rank: 922.1

Epoch:601 Train loss: 45.127429 Valid Mean Rank: 903.185

Epoch:701 Train loss: 44.549500 Valid Mean Rank: 872.805

Epoch:801 Train loss: 43.797099 Valid Mean Rank: 841.485

Epoch:901 Train loss: 42.983799 Valid Mean Rank: 849.08



In [87]:

# Read the test split through load_valid_data with the class / relation mappings,
# then move it to the device as an integer tensor.
test_file = f'data/{dataset}/{dataset}_test.txt'
test_triples = load_valid_data(test_file,classes,relations)
test_data = torch.Tensor(test_triples).long().to(device)

# Display the tensor as the cell output.
test_data


tensor([[ 8158,     0,   221],
        [14869,     0,   647],
        [ 3469,     0, 10560],
        ...,
        [18267,     0,  7361],
        [ 1337,     0,  9386],
        [ 9959,     0,   336]], device='cuda:3')

In [86]:
# Evaluate the in-memory model on the test set: ranking metrics first, then accuracy.
# The third argument (90) is passed straight to compute_rank -- presumably the
# percentile reported as per_rank; confirm in evaluation.py.
rank_metrics = compute_rank(model, test_data, 90)
acc = compute_accuracy(model, test_data)
top1, top10, top100, mean_rank, median_rank, per_rank, auc, rank_values = rank_metrics
print(top1,top10, top100, mean_rank, median_rank, per_rank, auc, acc)


0.0 0.0 0.0 11499.770982783357 11544.5 11569.5 0.5010938176341759
0.9917503586800573


In [54]:

# seed 3333
# Evaluate a previously saved checkpoint on the test set.
PATH = './models/box_el_galen_50_seed3333.pt'
# SECURITY: torch.load unpickles the whole model object, which can execute
# arbitrary code -- only load checkpoint files you trust.
# map_location places the restored tensors on the device used in this session.
# Without it they are restored to the GPU they were saved from, which may not
# exist on this host or may differ from the device that holds test_data.
model = torch.load(PATH, map_location=device)

top1,top10, top100, mean_rank, median_rank, per_rank, auc,rank_values = compute_rank(model, test_data, 90)
print(top1,top10, top100, mean_rank, median_rank, per_rank, auc)


0.0 0.006456241032998565 0.10043041606886657 2980.317969870875 1051.25 9149.200000000003 0.8703737719912009
