In [1]:
# Install the library:
#!python setup.py install --user

In [2]:
from kdmkr import datasets
from kdmkr import distillation
from kdmkr import evaluation
from kdmkr import loss
from kdmkr import model

In [3]:
from creme import stats

In [4]:
import os
import pickle

import numpy as np
import torch
import tqdm

In [5]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU so the
# notebook can still be executed (slowly) on a machine without CUDA.
device     = 'cuda' if torch.cuda.is_available() else 'cpu'
# Embedding size and mini-batch size used for the teacher.
hidden_dim = 500
batch_size = 512

In [6]:
# WN18RR dataset used to train the teacher. The batch size comes from the
# configuration cell instead of repeating the literal value here.
wn18rr = datasets.WN18RR(
    batch_size=batch_size,
    negative_sample_size=1024,
    shuffle=True,
    seed=42
)

In [7]:
# Teacher model: RotatE sized from the dataset vocabulary. The embedding size comes
# from the configuration cell instead of repeating the literal value here.
teacher = model.RotatE(
    hidden_dim=hidden_dim,
    n_entity=wn18rr.n_entity,
    n_relation=wn18rr.n_relation,
    gamma=6
)
teacher = teacher.to(device)

In [None]:
# Train the teacher with `loss.Adversarial`, evaluating and checkpointing it periodically.
optimizer_teacher = torch.optim.Adam(
    filter(lambda p: p.requires_grad, teacher.parameters()), lr=0.00005)

max_step = 6000

# Steps are 1-based and include `max_step`, so the last step also reaches the periodic
# evaluation/checkpoint below (`range(1, max_step)` would stop at 5999 and skip it).
bar = tqdm.tqdm(range(1, max_step + 1), position=0)

# Creme rolling mean (constructed with 1000) of the training loss, shown in the progress bar.
metric = stats.RollingMean(1000)

# Build the loss once instead of on every iteration.
# NOTE(review): assumes `loss.Adversarial` keeps no per-call state - confirm.
adversarial_loss = loss.Adversarial()

# Create the checkpoint directory up front so that saving cannot fail on a missing folder.
os.makedirs('./models', exist_ok=True)

teacher.train()

for step in bar:

    optimizer_teacher.zero_grad()

    positive_sample, negative_sample, weight, mode = next(wn18rr)

    # Move the batch to the device the model lives on.
    positive_sample = positive_sample.to(device)
    negative_sample = negative_sample.to(device)
    weight = weight.to(device)

    # Score the true triples, then the corrupted ones for the given corruption mode.
    positive_score = teacher(sample=positive_sample)
    negative_score = teacher(sample=(positive_sample, negative_sample), mode=mode)

    loss_teacher = adversarial_loss(positive_score, negative_score, weight, alpha=0.5)

    loss_teacher.backward()

    optimizer_teacher.step()

    metric.update(loss_teacher.item())

    # Refresh the progress-bar label every few steps only.
    if step % 5 == 0:
        bar.set_description(f'Adversarial loss: {metric.get():6f}')

    # Periodic evaluation on the test set followed by a checkpoint.
    if step % 2000 == 0:

        teacher = teacher.eval()

        score = evaluation.Evaluation()(model=teacher, dataset=wn18rr.test_dataset(batch_size=8), device=device)

        teacher = teacher.train()

        print(score)

        # Set path HERE
        # The evaluation scores are embedded in the file name of the checkpoint.
        with open(f'./models/teacher_wn18rr_{score}.pickle', 'wb') as handle:

            pickle.dump(teacher, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [8]:
# Checkpoint produced by the teacher training cell; its file name embeds the
# evaluation scores that were printed when it was saved.
teacher_name = 'teacher_wn18rr_HITS@10: 0.527760, HITS@1: 0.419113, HITS@3: 0.477664, MR: 5509.747288, MRR: 0.457429.pickle'

# NOTE(review): unpickling can execute code stored in the file - only load
# checkpoints that you produced yourself.
with open(f'./models/{teacher_name}', 'rb') as checkpoint:
    teacher = pickle.load(checkpoint)

In [9]:
max_step = 40000

# Seed torch and numpy before building the dataset and the student.
torch.manual_seed(42)
np.random.seed(42)

# Dataset iterator used for distillation (its own batch size, one negative per triple).
wn18rr = datasets.WN18RR(
    batch_size=1000,
    negative_sample_size=1,
    seed=42,
    shuffle=True
)

# Increasing the size of the student's latent representations improves results.
student = model.RotatE(
    hidden_dim=1000,
    n_entity=wn18rr.n_entity,
    n_relation=wn18rr.n_relation,
    gamma=0
)

student = student.to(device)

# The distillation process handles different indexes between the student and the teacher.
# It also pre-computes the batches dedicated to distillation.
distillation_process = distillation.Distillation(
    teacher_entities  = wn18rr.entities,
    student_entities  = wn18rr.entities,
    teacher_relations = wn18rr.relations,
    student_relations = wn18rr.relations,
    batch_size_entity = 20,
    batch_size_relation = 11,
)


optimizer_student = torch.optim.Adam(
    filter(lambda p: p.requires_grad, student.parameters()), lr=0.00005)

bar = tqdm.tqdm(range(1, max_step + 1), position=0)

# Creme online metric.
metric = stats.RollingMean(1000)

# Build the loss once instead of three times per iteration.
# NOTE(review): assumes `loss.KlDivergence` keeps no per-call state - confirm.
kl_divergence = loss.KlDivergence()

teacher = teacher.eval()
student = student.train()

for step in bar:

    optimizer_student.zero_grad()

    # Only `positive_sample` is used below; the other items are unpacked but ignored.
    positive_sample, negative_sample, weight, mode = next(wn18rr)

    batch_head_teacher = []
    batch_head_student = []

    batch_relation_teacher = []
    batch_relation_student = []

    batch_tail_teacher = []
    batch_tail_student = []

    # Draw the entity/relation subsets shared by every triple of this step.
    (
        entity_distribution_teacher,
        relation_distribution_teacher,
        entity_distribution_student,
        relation_distribution_student
    ) = distillation_process.uniform_subsampling()

    for head, relation, tail in positive_sample:

        head, relation, tail = head.item(), relation.item(), tail.item()

        distillation_available = distillation_process.available(head=head, relation=relation, tail=tail)

        # Triples for which distillation is not available are skipped.
        if distillation_available:

            tensor_head_teacher, tensor_head_student = distillation_process.get_distillation_sample_head(
                entity_distribution_teacher=entity_distribution_teacher,
                entity_distribution_student=entity_distribution_student,
                head_teacher=head, relation_teacher=relation, tail_teacher=tail
            )

            tensor_relation_teacher, tensor_relation_student = distillation_process.get_distillation_sample_relation(
                relation_distribution_teacher=relation_distribution_teacher,
                relation_distribution_student=relation_distribution_student,
                head_teacher=head, relation_teacher=relation, tail_teacher=tail
            )

            tensor_tail_teacher, tensor_tail_student = distillation_process.get_distillation_sample_tail(
                entity_distribution_teacher=entity_distribution_teacher,
                entity_distribution_student=entity_distribution_student,
                head_teacher=head, relation_teacher=relation, tail_teacher=tail
            )

            batch_head_teacher.append(tensor_head_teacher)
            batch_head_student.append(tensor_head_student)

            batch_relation_teacher.append(tensor_relation_teacher)
            batch_relation_student.append(tensor_relation_student)

            batch_tail_teacher.append(tensor_tail_teacher)
            batch_tail_student.append(tensor_tail_student)

    teacher_head_tensor = distillation_process.stack_entity(batch_head_teacher, device=device)
    student_head_tensor = distillation_process.stack_entity(batch_head_student, device=device)

    teacher_relation_tensor = distillation_process.stack_relations(batch_relation_teacher, device=device)
    student_relation_tensor = distillation_process.stack_relations(batch_relation_student, device=device)

    teacher_tail_tensor = distillation_process.stack_entity(batch_tail_teacher, device=device)
    student_tail_tensor = distillation_process.stack_entity(batch_tail_student, device=device)

    # Only `optimizer_student` steps in this loop, so the teacher scores are plain
    # targets: compute them without building an autograd graph. This avoids storing
    # activations and accumulating unused gradients on the teacher's parameters.
    with torch.no_grad():
        teacher_head_score = teacher.distill(teacher_head_tensor)
        teacher_relation_score = teacher.distill(teacher_relation_tensor)
        teacher_tail_score = teacher.distill(teacher_tail_tensor)

    # Distillation loss of heads.
    loss_head = kl_divergence(
        teacher_score=teacher_head_score,
        student_score=student.distill(student_head_tensor)
    )

    # Distillation loss of relations.
    loss_relation = kl_divergence(
        teacher_score=teacher_relation_score,
        student_score=student.distill(student_relation_tensor)
    )

    # Distillation loss of tails.
    loss_tail = kl_divergence(
        teacher_score=teacher_tail_score,
        student_score=student.distill(student_tail_tensor)
    )

    # The loss of the student is equal to the sum of all losses.
    loss_student = loss_head + loss_relation + loss_tail

    metric.update(loss_student.item())

    loss_student.backward()

    optimizer_student.step()

    # Refresh the progress-bar label every few steps only.
    if step % 5 == 0:

        bar.set_description(f'Metric: {metric.get():6f}')

    # Periodic evaluation on the test set followed by a checkpoint.
    if step % 1000 == 0:

        student = student.eval()

        score = evaluation.Evaluation()(model=student, dataset=wn18rr.test_dataset(batch_size=8), device=device)

        print(score)

        # The evaluation scores are embedded in the file name of the checkpoint.
        with open(f'./models/student_wn18rr_{score}.pickle', 'wb') as handle:

            pickle.dump(student, handle, protocol=pickle.HIGHEST_PROTOCOL)

        student = student.train()

Metric: 0.188103:   2%|▏         | 999/40000 [10:41<6:44:07,  1.61it/s]

HITS@10: 0.464582, HITS@1: 0.371091, HITS@3: 0.432674, MR: 6778.074186, MRR: 0.407070


Metric: 0.041273:   5%|▍         | 1999/40000 [22:54<6:17:31,  1.68it/s]  

HITS@10: 0.481174, HITS@1: 0.401404, HITS@3: 0.450862, MR: 5879.874442, MRR: 0.431126


Metric: 0.015341:   7%|▋         | 2999/40000 [35:06<6:15:39,  1.64it/s]  

HITS@10: 0.490587, HITS@1: 0.412412, HITS@3: 0.458360, MR: 5445.871251, MRR: 0.441519


Metric: 0.007602:  10%|▉         | 3999/40000 [47:15<6:02:13,  1.66it/s]  

HITS@10: 0.493299, HITS@1: 0.415922, HITS@3: 0.460913, MR: 5267.437460, MRR: 0.444803


Metric: 0.004311:  12%|█▏        | 4999/40000 [59:27<7:28:31,  1.30it/s]  

HITS@10: 0.496809, HITS@1: 0.413050, HITS@3: 0.462987, MR: 5066.776005, MRR: 0.444453


Metric: 0.002574:  15%|█▍        | 5999/40000 [1:11:46<6:10:11,  1.53it/s]  

HITS@10: 0.498086, HITS@1: 0.413210, HITS@3: 0.462987, MR: 4950.202936, MRR: 0.445204


Metric: 0.001651:  17%|█▋        | 6999/40000 [1:23:55<6:02:08,  1.52it/s]  

HITS@10: 0.498564, HITS@1: 0.414646, HITS@3: 0.465380, MR: 5058.032068, MRR: 0.447478


Metric: 0.001199:  20%|█▉        | 7999/40000 [1:36:10<5:19:25,  1.67it/s]  

HITS@10: 0.499043, HITS@1: 0.415922, HITS@3: 0.466496, MR: 5152.536535, MRR: 0.448861


Metric: 0.001016:  22%|██▏       | 8999/40000 [1:48:26<5:20:14,  1.61it/s]  

HITS@10: 0.499362, HITS@1: 0.410976, HITS@3: 0.463625, MR: 5214.182993, MRR: 0.446302


Metric: 0.000981:  25%|██▍       | 9999/40000 [2:00:42<5:29:18,  1.52it/s]  

HITS@10: 0.508775, HITS@1: 0.411455, HITS@3: 0.462827, MR: 5242.070676, MRR: 0.446687


Metric: 0.001066:  27%|██▋       | 10999/40000 [2:12:53<4:50:15,  1.67it/s]  

HITS@10: 0.514199, HITS@1: 0.407307, HITS@3: 0.463784, MR: 5342.028398, MRR: 0.445512


Metric: 0.001068:  30%|██▉       | 11999/40000 [2:25:10<4:44:35,  1.64it/s]  

HITS@10: 0.517869, HITS@1: 0.405392, HITS@3: 0.469368, MR: 5427.585354, MRR: 0.447117


Metric: 0.001068:  32%|███▏      | 12999/40000 [2:37:23<5:24:24,  1.39it/s]  

HITS@10: 0.525048, HITS@1: 0.407307, HITS@3: 0.470325, MR: 5414.229419, MRR: 0.449145


Metric: 0.001056:  34%|███▍      | 13712/40000 [2:46:32<4:37:20,  1.58it/s]  

KeyboardInterrupt: 