In [15]:

import time
import wandb
import torch
from torch.utils.data import DataLoader
import math
import argparse
import os
import json

torch.cuda.set_device(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [16]:

parser = argparse.ArgumentParser()
parser.add_argument('--no_cuda', action='store_true', default=False, help='disables CUDA training (eg. no nvidia GPU)')
parser.add_argument('--epochs', type=int, default=1000, help='number of epochs to train')
# model parameters
parser.add_argument('--model', type=str, default='softbox', help='model type: choose from softbox, gumbel')
parser.add_argument('--box_embedding_dim', type=int, default=40, help='box embedding dimension')
parser.add_argument('--softplus_temp', type=float, default=1.0, help='beta of softplus function')
# gumbel box parameters
parser.add_argument('--gumbel_beta', type=float, default=1.0, help='beta value for gumbel distribution')
parser.add_argument('--scale', type=float, default=1.0, help='scale value for gumbel distribution')

parser.add_argument('--dataset', type=str, default='PPI', help='dataset')
parser.add_argument('--using_rbox', type=int, default=1, help='using_rbox')
parser.add_argument('--gpu', type=int, default=1, help='gpu')

parser.add_argument('--dimension', type=int, default=50, help='dimension')
parser.add_argument('--learning_rate', type=int, default=0.001, help='learning_rate')
parser.add_argument('--batch_size', type=int, default=256, help='batch_size')
parser.add_argument('--seed', type=int, default=1111, help='seed')

args = parser.parse_args(args=['--no_cuda'] )
args.save_to = "./checkpoints/" + args.model

gpu = args.gpu
dimension = args.dimension
learning_rate = args.learning_rate
batch_size = args.batch_size
seed = args.seed
dataset = args.dataset
using_rbox = args.using_rbox

torch.cuda.set_device(2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [17]:

torch.manual_seed(seed)
import random
random.seed(seed)

import torch
from ppi_data_loader import load_cls,load_valid_data
from ppi_data_loader import load_data
from ppi_data_loader import Generator
from evaluation import compute_mean_rank,compute_rank,compute_accuracy
from box_model import BoxELPPI as Model

In [4]:

dataset = 'PPI'

total_sub_cls = []
train_file = f"data/{dataset}/{dataset}_train.txt"
va_file = f"data/{dataset}/{dataset}_valid.txt"
test_file = f"data/{dataset}/{dataset}_test.txt"
train_sub_cls, train_samples = load_cls(train_file)
valid_sub_cls, valid_samples = load_cls(va_file)
test_sub_cls,test_samples = load_cls(test_file)
total_sub_cls = train_sub_cls + valid_sub_cls + test_sub_cls
all_subcls = list(set(total_sub_cls))

print("Training data samples:",train_samples)
print("Training data classes:",len(train_sub_cls))

gdata_file=f"data/{dataset}/{dataset}_latest_norm_mod.owl"
train_data, classes, relations = load_data(gdata_file)
valid_data_file=f"data/{dataset}/{dataset}_valid.txt"
valid_data = load_valid_data(valid_data_file, classes, relations)
valid_data = torch.Tensor(valid_data).long().to(device)


Training data samples: 171740
Training data classes: 5502


In [5]:
train_data

{'nf1': array([[ 6410,    -1, 35259],
        [42837,    -1, 18989],
        [  707,    -1,    32],
        ...,
        [40420,    -1, 38998],
        [35844,    -1,    32],
        [43256,    -1, 14563]]), 'nf2': array([[   32, 26966,  3507],
        [   32, 29579,  2607],
        [   32, 45921, 10815],
        ...,
        [46530,    32, 21459],
        [   32, 45472,  5301],
        [40939, 44451, 41792]]), 'nf3': array([[20136,     0, 33676],
        [43756,     6,  2971],
        [ 8610,     3, 46676],
        ...,
        [14405,     5, 29897],
        [25936,     5, 31069],
        [ 6191,     3, 38848]]), 'nf4': array([[    6, 29313,  7624],
        [    3, 19797,  3072],
        [    0, 15187, 31853],
        ...,
        [    3, 47475, 31855],
        [    5,  9012, 21716],
        [    5,  1564, 20949]]), 'disjoint': array([[33686, 12492, 12493],
        [26018, 11978, 12493],
        [12492, 48917, 12493],
        [25016, 12492, 12493],
        [15106, 16938, 12493],
     

In [15]:
batch_size = 256 

org = 4932
proteins = {}
for k, v in classes.items():
    if k.startswith(f'<http://{org}'):
        proteins[k] = v
print('Proteins:', len(proteins))

nb_classes = len(classes)
nb_relations = len(relations)

print("no. classes:",nb_classes)
print("no. relations:",nb_relations)
nb_train_data = 0

for key, val in train_data.items():
    nb_train_data = max(len(val), nb_train_data)
    
train_steps = int(math.ceil(nb_train_data / (1.0 * batch_size)))
train_generator = Generator(train_data, batch_size, steps=train_steps)

cls_dict = {v: k for k, v in classes.items()}
rel_dict = {v: k for k, v in relations.items()}

cls_list = []
rel_list = []
for i in range(nb_classes):
    cls_list.append(cls_dict[i])
for i in range(nb_relations):
    rel_list.append(rel_dict[i])

classes_index = list(classes.values())
classes_index = torch.Tensor(classes_index).to(device).reshape(-1,1).long()
classes_index.shape

protein_index = list(proteins.values())
protein_index = torch.Tensor(protein_index).to(device).reshape(-1,1).long()
protein_index


Proteins: 5502
no. classes: 51689
no. relations: 10


tensor([[    2],
        [    3],
        [    4],
        ...,
        [51236],
        [51496],
        [51497]], device='cuda:2')

In [29]:
wandb.init(project="basic_box", reinit=False, config=args)


In [19]:

model = Model(nb_classes, nb_relations, 50, [1e-4, 0.2], [-0.1, 0], [-0.1,0.1], [0.9,1.1], args).to(device)
criterion = torch.nn.CrossEntropyLoss().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
wandb.watch(model)


[]

In [32]:


model.train()
steps_per_epoch = train_steps
train_generator.reset()

epochs = 1000
for epoch in range(epochs):
    train_loss = 0.0
    nf1_total_loss = 0.0
    nf2_total_loss = 0.0
    nf3_total_loss = 0.0
    nf4_total_loss = 0.0
    disjoint_total_loss = 0.0
    nf3_neg_total_loss = 0.0
    ri_total_loss = 0.0
    rc_total_loss = 0.0
    int_total_loss = 0.0
    mem_total_loss = 0.0
    
    for step, batch in enumerate(train_generator):
        if step < steps_per_epoch:
            n1,n2,n3,n4,dis,n3n,ri,rc,inter,member,reg1,reg2,reg3,reg4,reg_dis,reg3_n,reg_inter, reg_member = model(batch[0])
            loss = n1+n2+n3+n4+dis+n3n+ri+rc+inter+member+reg1+reg2+reg3+reg4+reg_dis+reg3_n+reg_inter+reg_member
            
            assert torch.isnan(loss).sum() == 0
            optimizer.zero_grad()
            loss.backward()
            assert torch.isnan(model.min_embedding).sum() == 0
            optimizer.step()
            assert torch.isnan(model.min_embedding).sum() == 0
            assert torch.isnan(model.min_embedding.grad).sum() == 0
            train_loss += loss
            nf1_total_loss += n1
            nf2_total_loss += n2
            nf3_total_loss += n3
            nf4_total_loss += n4
            disjoint_total_loss += dis
            nf3_neg_total_loss += n3n
            ri_total_loss +=ri
            rc_total_loss += rc
            int_total_loss += inter
            mem_total_loss = member
        else:
            train_generator.reset()
            break
            
    mean_rank = 0
    valid_accuracy = compute_accuracy(model, valid_data)
    
#     wandb.log({'train loss': train_loss.item()/(step+1),'nf1_loss':nf1_total_loss.item()/(step+1), 'nf2_loss':nf2_total_loss.item()/(step+1), 'nf3_loss':nf3_total_loss.item()/(step+1), 'nf4_loss':nf4_total_loss.item()/(step+1), 'nf3_neg_loss':nf3_neg_total_loss.item()/(step+1), 
#                 'role_inclusion_loss': ri_total_loss.item()/(step+1),'role_chain_loss': rc_total_loss.item()/(step+1), 'interaction_loss': int_total_loss.item()/(step+1),'function_loss': mem_total_loss.item()/(step+1), 'mean_rank':mean_rank,'valid_accuracy':valid_accuracy })

    PATH = './models/box_el_ppi.pt'
    if epoch % 10 == 0:
        torch.save(model, PATH)
        print('Epoch:%d' %(epoch + 1), "Train loss: %f" %(train_loss.item()/(step+1)), f'Valid Mean Rank: {mean_rank}\n')
   


Epoch:1 Train loss: 3.478361 Valid Mean Rank: 0

Epoch:11 Train loss: 3.474711 Valid Mean Rank: 0

Epoch:21 Train loss: 3.479704 Valid Mean Rank: 0

Epoch:31 Train loss: 3.482792 Valid Mean Rank: 0

Epoch:41 Train loss: 3.479377 Valid Mean Rank: 0

Epoch:51 Train loss: 3.494027 Valid Mean Rank: 0

Epoch:61 Train loss: 3.480729 Valid Mean Rank: 0

Epoch:71 Train loss: 3.487980 Valid Mean Rank: 0

Epoch:81 Train loss: 3.491811 Valid Mean Rank: 0

Epoch:91 Train loss: 3.486822 Valid Mean Rank: 0

Epoch:101 Train loss: 3.475150 Valid Mean Rank: 0

Epoch:111 Train loss: 3.494037 Valid Mean Rank: 0

Epoch:121 Train loss: 3.490071 Valid Mean Rank: 0

Epoch:131 Train loss: 3.488562 Valid Mean Rank: 0

Epoch:141 Train loss: 3.481843 Valid Mean Rank: 0

Epoch:151 Train loss: 3.500453 Valid Mean Rank: 0

Epoch:161 Train loss: 3.469411 Valid Mean Rank: 0

Epoch:171 Train loss: 3.484949 Valid Mean Rank: 0

Epoch:181 Train loss: 3.485209 Valid Mean Rank: 0

Epoch:191 Train loss: 3.479152 Valid Mean 

KeyboardInterrupt: 

In [33]:

PATH = './models/box_el_ppi.pt'
# PATH = './models/box_el_ppi_boxes.pt'
model = torch.load(PATH)

test_file = f'data/{dataset}/{dataset}_test.txt'
test_data = load_valid_data(test_file,classes,relations)
test_data = torch.Tensor(test_data).long().to(device)


In [34]:

top1,top10, top100, mean_rank, median_rank, per_rank, auc,rank_values = compute_rank(model, test_data, 90, proteins, device)
print(top1,top10, top100, mean_rank, median_rank, per_rank, auc)

compute_accuracy(model, test_data)



0.0 0.00586756077116513 0.06198193163826022 1718.2797569153395 1329.0 4000.0 0.9665397516576638


0.0

In [24]:

top1,top10, top100, mean_rank, median_rank, per_rank, auc,rank_values = compute_rank(model, test_data, 90,classes,device)
print(top1,top10, top100, mean_rank, median_rank, per_rank, auc)

compute_accuracy(model, test_data)



0.0 0.0005122473689112415 0.009173884697774052 2831.051154884977 2277.0 5164.700000000001 0.9451738970891046


0.1102728881438018

In [13]:

top1,top10, top100, mean_rank, median_rank, per_rank, auc,rank_values = compute_rank(model, test_data, 90,classes,device)
print(top1,top10, top100, mean_rank, median_rank, per_rank, auc)

compute_accuracy(model, test_data)



0.0 0.0005122473689112415 0.009173884697774052 2831.051154884977 2277.0 5164.700000000001 0.9451738970891046


0.1102728881438018

In [35]:



testing: 1111 0.0 0.00430416068866571 0.11477761836441894 3118.3922166427546 921.0 10303.800000000001 0.8714267973960943 0.6603299856527977
inference: 1111 0.0 0.010766194150367845 0.2015072671810515 1067.8097075183923 415.5 2987.4 0.9551401890471426 0.9944374663556432
testing: 2222 0.0 0.005380200860832138 0.10616929698708752 3174.333213773314 919.0 10598.800000000007 0.8691295137086096 0.6535150645624104
inference: 2222 0.0 0.01040732101202225 0.2013278306118787 1080.09976673246 413.5 3016.6000000000004 0.9553046959013611 0.9942580297864705
testing: 3333 0.0 0.006097560975609756 0.1133428981348637 3161.330164992826 935.25 10594.400000000001 0.8700184004504289 0.6574605451936872
inference: 3333 0.0 0.013637179257132604 0.20455768885698905 1084.1268616544053 424.5 2986.5 0.9548764847827171 0.9938991566481249
testing: 4444 0.0 0.007173601147776184 0.11406025824964132 3117.961262553802 894.25 10008.6 0.8717968588033929 0.6660688665710186
inference: 4444 0.0 0.013278306118787008 0.20240445002691548 1077.5013457742689 404.5 3060.2000000000003 0.9555499005819137 0.9949757760631617
testing: 5555 0.0 0.004662840746054519 0.1093974175035868 3128.779949784792 898.0 10143.950000000006 0.8713564583707556 0.6477761836441894
inference: 5555 0.0 0.011304503857886236 0.19845684550511394 1065.7015072671811 411.5 2936.4 0.9558965683789529 0.9930019738022609

testing: 6666 0.0 0.005021520803443328 0.10616929698708752 3153.3669296987086 905.75 10360.600000000002 0.8703439527157969 0.6581779053084649
inference: 6666 0.0 0.01040732101202225 0.19773909922842275 1093.2905078054907 423.5 3051.8 0.9547617037341984 0.9938991566481249
    
testing: 7777 0.0 0.006097560975609756 0.11011477761836441 3099.3681850789094 920.75 9882.500000000011 0.8714952769643954 0.6588952654232425
inference: 7777 0.0 0.012739996411268616 0.19630360667504038 1081.3397631437288 427.0 3037.2000000000003 0.9552559924265435 0.9947963394939888
    
testing: 8888 0.0 0.006097560975609756 0.10975609756097561 3091.743364418938 910.75 9711.90000000002 0.872874780525156 0.6588952654232425
inference: 8888 0.0 0.01184281356540463 0.1941503678449668 1080.7852144267001 418.5 3030.8000000000006 0.9550102572391543 0.9942580297864705
    
    
    

SyntaxError: invalid syntax (<ipython-input-35-281f82d2bf01>, line 1)

In [None]:



1111 0.0 0.005738880918220947 0.10294117647058823 2876.125717360115 972.0 8990.45000000001 0.8751919911650198 0.6578192252510761
1111 0.0 0.016328727794724565 0.1955858603983492 1030.250134577427 428.0 2725.600000000001 0.9554435551497155 0.9953346492015073

2222 0.0 0.007173601147776184 0.0979196556671449 2907.119081779053 972.75 9664.400000000001 0.8742102474870697 0.6653515064562411
2222 0.0 0.014713798672169388 0.19630360667504038 1031.4783778934147 430.5 2654.8 0.9553963485689416 0.9944374663556432

3333 0.0 0.005021520803443328 0.11011477761836441 2939.3970588235293 972.75 9499.400000000005 0.872463057751984 0.6700143472022956
3333 0.0 0.012739996411268616 0.19343262156827562 1017.0012560559842 422.5 2652.6000000000004 0.9560219302626464 0.9951552126323344

4444 0.0 0.007890961262553802 0.11047345767575323 2962.374820659971 964.25 9546.800000000001 0.8714725880835268 0.6631994261119082
4444 0.0 0.013996052395478199 0.2013278306118787 1008.298492732819 421.0 2591.6000000000004 0.9563884582602485 0.9942580297864705

5555 0.0 0.011119081779053085 0.09935437589670014 2902.1370157819224 951.5 9373.7 0.8726586253331841 0.6664275466284074
5555 0.0 0.014354925533823793 0.19127938273820205 1018.201955858604 422.0 2700.1000000000045 0.9559743631339312 0.9960523954781985

6666 0.0 0.007173601147776184 0.0979196556671449 2935.169835007174 955.0 8658.600000000002 0.8726464159210671 0.6664275466284074
6666 0.0 0.016149291225551768 0.20473712542616185 1046.2881751300915 432.5 2732.8 0.9547285653201916 0.9940785932172976

7777 0.0 0.006456241032998565 0.10222381635581061 2949.4409971305595 969.25 8905.900000000009 0.8716735997218593 0.6553084648493543
7777 0.0 0.013816615826305402 0.2086847299479634 1050.7840480890006 437.5 2785.3 0.9545360947490835 0.9946169029248161

8888 0.0 0.0068149210903873745 0.10688665710186514 2885.4922883787663 911.75 9153.100000000002 0.8744365730223742 0.6667862266857962
8888 0.0 0.012022250134577427 0.19989233805849632 1024.9979364794544 418.5 2673.9000000000005 0.9556383828514529 0.9946169029248161

9999 0.0 0.007890961262553802 0.09505021520803443 2947.022238163558 954.0 9308.000000000013 0.8724852059587944 0.6685796269727403
9999 0.0 0.013098869549614211 0.20473712542616185 1045.9545128297148 430.5 2767.9 0.9543495848495924 0.993719720078952

1010 0.0 0.007173601147776184 0.10186513629842181 2919.2385222381636 911.75 9282.400000000003 0.873687974481546 0.6696556671449068
1010 0.0 0.013996052395478199 0.2147855732998385 1026.7513009151264 419.5 2772.4 0.9555555421157754 0.996411268616544

