In [1]:
#!python setup.py install --user

In [3]:
from kdmkr import datasets
from kdmkr import distillation
from kdmkr import evaluation
from kdmkr import loss
from kdmkr import model

In [4]:
from creme import stats

In [5]:
import pickle 
import torch
import tqdm

In [6]:
# Run configuration shared by the cells below.
# Use the GPU when one is visible to torch, otherwise fall back to the CPU so the
# notebook does not crash on `.to(device)` on a machine without CUDA.
device     = 'cuda' if torch.cuda.is_available() else 'cpu'
hidden_dim = 500  # embedding size passed to the teacher model
batch_size = 512  # batch size of the teacher training dataset

In [7]:
# WN18RR batch iterator used to train the teacher; the negative sample size is
# twice the batch size, and shuffling is seeded for repeatable runs.
wn18rr = datasets.WN18RR(
    batch_size           = batch_size,
    negative_sample_size = batch_size * 2,
    shuffle              = True,
    seed                 = 42,
)

In [8]:
# Build the RotatE teacher sized to the dataset's entity/relation counts and
# place it on `device`.
teacher = model.RotatE(
    hidden_dim = hidden_dim,
    n_entity   = wn18rr.n_entity,
    n_relation = wn18rr.n_relation,
    gamma      = 6,
).to(device)

In [None]:
# Train the teacher with loss.Adversarial on batches drawn from `wn18rr`.
# Every `eval_every` steps the teacher is scored on wn18rr.test_dataset and
# pickled under a file name that embeds the score.
optimizer_teacher = torch.optim.Adam(filter(lambda p: p.requires_grad, teacher.parameters()), lr = 0.00005)

max_step   = 6000
eval_every = 2000

# Inclusive upper bound: with range(1, max_step) the last step was skipped, so the
# final evaluation/checkpoint at step == max_step never happened. This now matches
# the student training loop.
bar = tqdm.tqdm(range(1, max_step + 1), position=0)

metric = stats.RollingMean(1000)

# Deliberately NOT named `evaluation`: rebinding that name would shadow the
# kdmkr.evaluation module and break later cells that call evaluation.Evaluation().
evaluate_teacher = evaluation.Evaluation()

# Build the loss object once instead of on every step.
adversarial_loss = loss.Adversarial()

teacher.train()

for step in bar:

    optimizer_teacher.zero_grad()

    positive_sample, negative_sample, weight, mode = next(wn18rr)

    positive_sample = positive_sample.to(device)

    negative_sample = negative_sample.to(device)

    weight = weight.to(device)

    positive_score = teacher(sample=positive_sample)

    negative_score = teacher(sample=(positive_sample, negative_sample), mode=mode)

    loss_teacher = adversarial_loss(positive_score, negative_score, weight, alpha=0.5)

    loss_teacher.backward()

    optimizer_teacher.step()

    metric.update(loss_teacher.item())

    # Refresh the progress-bar label with the rolling mean of the loss.
    if step % 30 == 0:

        bar.set_description(f'Metric: {metric.get():6f}')

    if step % eval_every == 0:

        # Switch to eval mode for scoring, then back to train mode.
        teacher = teacher.eval()

        score = evaluate_teacher(model=teacher, dataset=wn18rr.test_dataset(batch_size=8), device=device)

        teacher = teacher.train()

        print(score)

        # The checkpoint name embeds the printed score; the reload cell relies on this format.
        with open(f'/users/iris/rsourty/experiments/kdmkr/models/teacher_wn18rr_{score}.pickle', 'wb') as handle:

            pickle.dump(teacher, handle, protocol = pickle.HIGHEST_PROTOCOL)

In [10]:
# Reload a teacher checkpoint produced by the previous training run.
# NOTE: pickle.load can execute arbitrary code; only load checkpoints you created yourself.

teacher_name = 'teacher_wn18rr_HITS@10: 0.527760, HITS@1: 0.419113, HITS@3: 0.477664, MR: 5509.747288, MRR: 0.457429.pickle'

teacher_path = f'/users/iris/rsourty/experiments/kdmkr/models/{teacher_name}'

with open(teacher_path, 'rb') as checkpoint:

    teacher = pickle.load(checkpoint)

In [12]:
# Distil the teacher into a student RotatE model (hidden_dim=1500).
# For every positive triple, `distillation_process` builds one relation mini-batch
# for the teacher and one for the student; the student is trained on the KL
# divergence between the two models' `distill` scores on those mini-batches.
student_batch_size = 1000

max_step = 40000

torch.manual_seed(42)

wn18rr = datasets.WN18RR(batch_size=student_batch_size, negative_sample_size=1, seed=42, shuffle=True)

student = model.RotatE(hidden_dim=1500, n_entity=wn18rr.n_entity, n_relation=wn18rr.n_relation, gamma=0)

student = student.to(device)

# Teacher and student share the same entity and relation vocabularies here.
distillation_process = distillation.Distillation(
    teacher_entities  = wn18rr.entities, 
    student_entities  = wn18rr.entities, 
    teacher_relations = wn18rr.relations, 
    student_relations = wn18rr.relations,
)

optimizer_student = torch.optim.Adam(filter(lambda p: p.requires_grad, student.parameters()), lr = 0.00005)

bar = tqdm.tqdm(range(1, max_step + 1), position=0)

metric = stats.RollingMean(1000)

teacher.eval()

student.train()

kl_divergence = loss.KlDivergence()

for step in bar:

    # Only the positive triples are used; the remaining items of the batch are ignored.
    positive_sample, *_ = next(wn18rr)

    teacher_relation = []
    student_relation = []

    optimizer_student.zero_grad()

    # tolist() converts the whole batch to Python ints in one call instead of
    # calling .item() three times per triple. The relation id is not used.
    for head, _relation, tail in positive_sample.tolist():

        teacher_common_relation_sample, _ = distillation_process.mini_batch_teacher_relation(
            head=head, tail=tail)

        student_common_relation_sample, _ = distillation_process.mini_batch_student_relation(
            teacher_head=head, teacher_tail=tail)

        teacher_relation.append(teacher_common_relation_sample)
        student_relation.append(student_common_relation_sample)

    teacher_relation_tensor = torch.stack(teacher_relation).reshape(len(teacher_relation), wn18rr.n_relation, 3).to(device)
    student_relation_tensor = torch.stack(student_relation).reshape(len(student_relation), wn18rr.n_relation, 3).to(device)

    # The teacher is not optimised here (optimizer_student only holds student
    # parameters), so score it without autograd: otherwise loss_student.backward()
    # would also backpropagate through the teacher and accumulate gradients on its
    # parameters that are never zeroed or used.
    with torch.no_grad():
        teacher_score = teacher.distill(teacher_relation_tensor)

    loss_relation = kl_divergence(
        teacher_score=teacher_score, 
        student_score=student.distill(student_relation_tensor)
    ) 

    loss_student = loss_relation

    metric.update(loss_student.item())

    loss_student.backward()

    optimizer_student.step()

    # Refresh the progress-bar label with the rolling mean of the loss.
    if step % 7 == 0:

        bar.set_description(f'Metric: {metric.get():6f}')

    if step % 1000 == 0:

        # Switch to eval mode for scoring; train mode is restored after the checkpoint is written.
        student = student.eval()

        score = evaluation.Evaluation()(model=student, dataset=wn18rr.test_dataset(batch_size=2), device=device)

        print(score)

        with open(f'/users/iris/rsourty/experiments/kdmkr/models/student_wn18rr_{score}.pickle', 'wb') as handle:

            pickle.dump(student, handle, protocol = pickle.HIGHEST_PROTOCOL)    

        student = student.train()

Metric: 0.037566:   2%|▏         | 999/40000 [02:53<1:51:06,  5.85it/s]

HITS@10: 0.414008, HITS@1: 0.369177, HITS@3: 0.396937, MR: 10494.610881, MRR: 0.386363


Metric: 0.000516:   5%|▍         | 1999/40000 [07:51<1:48:58,  5.81it/s]  

HITS@10: 0.417996, HITS@1: 0.373484, HITS@3: 0.402680, MR: 10160.950064, MRR: 0.390909


Metric: 0.000107:   7%|▋         | 2999/40000 [12:50<1:46:10,  5.81it/s]  

HITS@10: 0.419751, HITS@1: 0.373165, HITS@3: 0.402521, MR: 10075.413369, MRR: 0.390913


Metric: 0.000131:  10%|▉         | 3999/40000 [17:49<1:46:49,  5.62it/s]  

HITS@10: 0.414805, HITS@1: 0.368858, HITS@3: 0.398692, MR: 9969.183153, MRR: 0.386589


Metric: 0.000186:  12%|█▏        | 4999/40000 [22:47<1:38:26,  5.93it/s]  

HITS@10: 0.405712, HITS@1: 0.361040, HITS@3: 0.390077, MR: 9816.358328, MRR: 0.378677


Metric: 0.000213:  15%|█▍        | 5999/40000 [27:46<1:37:57,  5.78it/s]  

HITS@10: 0.392310, HITS@1: 0.353063, HITS@3: 0.377473, MR: 9399.855775, MRR: 0.368555


Metric: 0.000229:  17%|█▋        | 6999/40000 [32:45<1:32:44,  5.93it/s]  

HITS@10: 0.377313, HITS@1: 0.341576, HITS@3: 0.364231, MR: 8933.670230, MRR: 0.355502


Metric: 0.000236:  20%|█▉        | 7999/40000 [37:44<1:32:33,  5.76it/s]  

HITS@10: 0.358488, HITS@1: 0.325622, HITS@3: 0.345884, MR: 8481.211072, MRR: 0.338464


Metric: 0.000236:  22%|██▏       | 8999/40000 [42:43<1:27:41,  5.89it/s]  

HITS@10: 0.342055, HITS@1: 0.310147, HITS@3: 0.330408, MR: 8047.845565, MRR: 0.322884


Metric: 0.000237:  25%|██▍       | 9999/40000 [47:42<1:26:41,  5.77it/s]  

HITS@10: 0.317486, HITS@1: 0.291481, HITS@3: 0.307275, MR: 7705.363433, MRR: 0.301756


Metric: 0.000230:  30%|██▉       | 11999/40000 [57:39<1:18:49,  5.92it/s]  

HITS@10: 0.276643, HITS@1: 0.252712, HITS@3: 0.267869, MR: 7161.599394, MRR: 0.262655


Metric: 0.000231:  32%|███▏      | 12999/40000 [1:02:39<1:19:55,  5.63it/s]

HITS@10: 0.263401, HITS@1: 0.244257, HITS@3: 0.254946, MR: 6923.241863, MRR: 0.252174


Metric: 0.000225:  34%|███▍      | 13660/40000 [1:06:39<1:16:14,  5.76it/s]  IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Metric: 0.000203:  47%|████▋     | 18999/40000 [1:32:31<59:10,  5.92it/s]  

HITS@10: 0.172304, HITS@1: 0.156988, HITS@3: 0.165603, MR: 6277.749043, MRR: 0.163818


Metric: 0.000193:  50%|████▉     | 19999/40000 [1:37:30<56:57,  5.85it/s]    

HITS@10: 0.154435, HITS@1: 0.141353, HITS@3: 0.149011, MR: 6150.708679, MRR: 0.147561


Metric: 0.000190:  52%|█████▏    | 20999/40000 [1:42:29<55:01,  5.76it/s]    

HITS@10: 0.142789, HITS@1: 0.129068, HITS@3: 0.137364, MR: 6173.437779, MRR: 0.135442


Metric: 0.000187:  53%|█████▎    | 21389/40000 [1:45:42<53:27,  5.80it/s]    IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Metric: 0.000147:  67%|██████▋   | 26999/40000 [2:12:23<36:51,  5.88it/s]

HITS@10: 0.090140, HITS@1: 0.080728, HITS@3: 0.086790, MR: 6074.925814, MRR: 0.085867


Metric: 0.000145:  70%|██████▉   | 27999/40000 [2:17:21<33:39,  5.94it/s]    

HITS@10: 0.091576, HITS@1: 0.082642, HITS@3: 0.087907, MR: 6096.272495, MRR: 0.087373


Metric: 0.000139:  72%|███████▏  | 28999/40000 [2:22:20<31:19,  5.85it/s]    

HITS@10: 0.084716, HITS@1: 0.076101, HITS@3: 0.081366, MR: 6089.404116, MRR: 0.080664


Metric: 0.000135:  73%|███████▎  | 29252/40000 [2:25:09<30:50,  5.81it/s]    IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Metric: 0.000114:  87%|████████▋ | 34999/40000 [2:52:13<14:51,  5.61it/s]

HITS@10: 0.058711, HITS@1: 0.050574, HITS@3: 0.055680, MR: 6160.609126, MRR: 0.055069


Metric: 0.000111:  90%|████████▉ | 35999/40000 [2:57:12<11:28,  5.81it/s]   

HITS@10: 0.059349, HITS@1: 0.052648, HITS@3: 0.056318, MR: 6175.678845, MRR: 0.056539


Metric: 0.000109:  92%|█████████▏| 36999/40000 [3:02:11<08:41,  5.75it/s]   

HITS@10: 0.056637, HITS@1: 0.051053, HITS@3: 0.054882, MR: 6173.936981, MRR: 0.054732


Metric: 0.000109:  93%|█████████▎| 37036/40000 [3:04:23<08:22,  5.90it/s]   IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

