In [1]:
#!python setup.py install --user

In [2]:
#!pip uninstall kdmkr -y

In [3]:
from kdmkr import datasets
from kdmkr import distillation
from kdmkr import evaluation
from kdmkr import loss
from kdmkr import model

In [4]:
from creme import stats

In [5]:
import pickle 
import torch
import tqdm

In [6]:
# Run configuration shared by the cells below.
# Use the GPU when one is visible to torch, otherwise fall back to the CPU so that
# the later `.to(device)` calls do not fail on a machine without CUDA.
device     = 'cuda' if torch.cuda.is_available() else 'cpu'
# Embedding size passed to the student model.
hidden_dim = 500
# NOTE(review): batch_size is not referenced by any of the cells shown here — confirm it is still needed.
batch_size = 512

In [7]:
# File name of the pickled teacher model; the name itself records the metrics
# (HITS@k, MR, MRR) the checkpoint was saved with.
teacher_name = 'teacher_wn18rr_HITS@10: 0.527760, HITS@1: 0.419113, HITS@3: 0.477664, MR: 5509.747288, MRR: 0.457429.pickle'

teacher_path = f'/users/iris/rsourty/experiments/kdmkr/{teacher_name}'

# SECURITY: pickle.load can execute arbitrary code embedded in the file —
# only load checkpoints that come from a trusted source.
with open(teacher_path, 'rb') as checkpoint:
    teacher = pickle.load(checkpoint)

In [8]:
# WN18RR dataset object; its `entities` / `relations` are what the distillation
# helper is built from, and `n_entity` / `n_relation` size the student model.
wn18rr = datasets.WN18RR(
    batch_size=1,
    negative_sample_size=1,
    seed=42,
    shuffle=True,
)

In [9]:
# Student RotatE model sized from the dataset's entity / relation counts.
# NOTE(review): the training cell further down builds a new `student`, so this
# instance is replaced before any training happens.
student = model.RotatE(
    n_entity=wn18rr.n_entity,
    n_relation=wn18rr.n_relation,
    hidden_dim=hidden_dim,
    gamma=6,
)

# Move the freshly built model onto the configured device.
student = student.to(device)

In [10]:
# Distillation helper used by the training loop to pick the distillation mode and
# to draw the teacher / student mini-batches. Teacher and student are both given
# the same WN18RR entity and relation collections here.
distillation_process = distillation.Distillation(
    teacher_entities=wn18rr.entities,
    teacher_relations=wn18rr.relations,
    student_entities=wn18rr.entities,
    student_relations=wn18rr.relations,
)

In [None]:
# Distillation training loop.
#
# For every positive triple drawn from WN18RR, the distillation helper says which
# of head / relation / tail can be distilled; for each enabled mode a teacher and
# a student mini-batch are drawn, both models score them through `.distill`, and
# the student is optimised on the summed KL-divergence losses.
student_batch_size = 2

max_step = 100000

wn18rr = datasets.WN18RR(batch_size=student_batch_size, negative_sample_size=1, seed=42, shuffle=True)

# Use the shared `hidden_dim` setting instead of repeating the literal value.
# NOTE(review): gamma=0 here whereas the earlier student cell used gamma=6 — confirm which is intended.
student = model.RotatE(hidden_dim=hidden_dim, n_entity=wn18rr.n_entity, n_relation=wn18rr.n_relation, gamma=0)

student = student.to(device)

# Only parameters that require gradients are handed to the optimizer.
optimizer_student = torch.optim.Adam(filter(lambda p: p.requires_grad, student.parameters()), lr = 0.0005)

bar = tqdm.tqdm(range(1, max_step + 1), position=0)

# Rolling mean of the training loss over a window of 1000 updates, shown in the progress bar.
metric = stats.RollingMean(1000)

teacher.eval()

student.train()

kl_divergence = loss.KlDivergence()


def _distillation_loss(teacher_samples, student_samples, n_candidates):
    """KL-divergence between teacher and student scores for one distillation mode.

    Parameters:
        teacher_samples: list of sample tensors drawn for the teacher.
        student_samples: list of sample tensors drawn for the student
            (appended in lockstep with `teacher_samples`).
        n_candidates: size of the second axis after reshaping, i.e.
            `wn18rr.n_entity` for head/tail and `wn18rr.n_relation` for relation.

    Returns:
        The loss produced by `kl_divergence`, or None when no sample was
        collected for this mode (torch.stack raises on an empty list).
    """
    if not teacher_samples:
        return None

    # Each stacked batch is reshaped to (n_samples, n_candidates, 3).
    teacher_tensor = torch.stack(teacher_samples).reshape(len(teacher_samples), n_candidates, 3).to(device)
    student_tensor = torch.stack(student_samples).reshape(len(student_samples), n_candidates, 3).to(device)

    # The teacher is not optimised here: skip autograd bookkeeping for its forward pass.
    with torch.no_grad():
        teacher_score = teacher.distill(teacher_tensor)

    return kl_divergence(
        teacher_score=teacher_score,
        student_score=student.distill(student_tensor),
    )


for step in bar:

    positive_sample, _, _, _ = next(wn18rr)

    teacher_head, student_head = [], []
    teacher_relation, student_relation = [], []
    teacher_tail, student_tail = [], []

    optimizer_student.zero_grad()

    for head, relation, tail in positive_sample:

        head, relation, tail = head.item(), relation.item(), tail.item()

        # Dict with 'head' / 'relation' / 'tail' flags telling which modes to distill for this triple.
        mode = distillation_process.distillation_mode(head=head, relation=relation, tail=tail)

        if mode['head']:

            teacher_common_head_sample, _ = distillation_process.mini_batch_teacher_head(
                relation=relation, tail=tail)

            student_common_head_sample, _ = distillation_process.mini_batch_student_head(
                teacher_relation=relation, teacher_tail=tail)

            teacher_head.append(teacher_common_head_sample)
            student_head.append(student_common_head_sample)

        if mode['relation']:

            teacher_common_relation_sample, _ = distillation_process.mini_batch_teacher_relation(
                head=head, tail=tail)

            student_common_relation_sample, _ = distillation_process.mini_batch_student_relation(
                teacher_head=head, teacher_tail=tail)

            teacher_relation.append(teacher_common_relation_sample)
            student_relation.append(student_common_relation_sample)

        if mode['tail']:

            teacher_common_tail_sample, _ = distillation_process.mini_batch_teacher_tail(
                head=head, relation=relation)

            student_common_tail_sample, _ = distillation_process.mini_batch_student_tail(
                teacher_head=head, teacher_relation=relation)

            teacher_tail.append(teacher_common_tail_sample)
            student_tail.append(student_common_tail_sample)

    # Same order as before (head, relation, tail); modes without samples are dropped.
    partial_losses = [
        partial_loss
        for partial_loss in (
            _distillation_loss(teacher_head, student_head, wn18rr.n_entity),
            _distillation_loss(teacher_relation, student_relation, wn18rr.n_relation),
            _distillation_loss(teacher_tail, student_tail, wn18rr.n_entity),
        )
        if partial_loss is not None
    ]

    # Nothing to optimise when no triple of the batch was distillable.
    if partial_losses:

        loss_student = partial_losses[0]

        for partial_loss in partial_losses[1:]:
            loss_student = loss_student + partial_loss

        metric.update(loss_student.item())

        loss_student.backward()

        optimizer_student.step()

    if step % 10 == 0:

        bar.set_description(f'Metric: {metric.get():6f}')

    if step % 5000 == 0:

        # Switch to eval mode for the periodic test-set evaluation, then resume training.
        student = student.eval()

        score = evaluation.Evaluation()(model=student, dataset=wn18rr.test_dataset(batch_size=8), device=device)

        print(score)

        student = student.train()

Metric: 0.162908:   2%|▏         | 1695/100000 [09:41<9:24:49,  2.90it/s]