In [1]:
# Install the library:
#!python setup.py install --user

In [1]:
from kdmkr import datasets
from kdmkr import distillation
from kdmkr import evaluation
from kdmkr import loss
from kdmkr import model

In [2]:
from creme import stats

In [3]:
import pickle 
import torch
import tqdm

In [4]:
# Run on the GPU when CUDA is available and fall back to the CPU otherwise, so
# the `.to(device)` calls further down do not fail on a machine without a GPU.
device     = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hyper-parameters kept in one place so they can be tuned from this cell.
hidden_dim = 500
batch_size = 512

In [6]:
# WN18RR training stream used to fit the teacher. The batch size is taken from
# the configuration cell (`batch_size`) instead of repeating the literal, so the
# value only has to be changed in one place.
wn18rr = datasets.WN18RR(
    batch_size=batch_size,
    negative_sample_size=1024,
    shuffle=True,
    seed=42,
)

In [6]:
# Teacher: a RotatE model sized from the dataset's entity / relation counts.
# The embedding size comes from the configuration cell (`hidden_dim`) rather
# than a repeated literal.
teacher = model.RotatE(
    hidden_dim=hidden_dim,
    n_entity=wn18rr.n_entity,
    n_relation=wn18rr.n_relation,
    gamma=6,
)

# Move the model to the device selected in the configuration cell.
teacher = teacher.to(device)

In [None]:
# Train the teacher with the adversarial loss, and evaluate + checkpoint it
# every 2000 steps.
optimizer_teacher = torch.optim.Adam(
    filter(lambda p: p.requires_grad, teacher.parameters()), lr=0.00005)

max_step = 6000

# Steps run from 1 to max_step inclusive. Stopping at max_step - 1 would skip
# the last multiple of the evaluation interval, so the final weights would
# never be evaluated or saved.
bar = tqdm.tqdm(range(1, max_step + 1), position=0)

# Creme online metric: rolling mean of the loss (window argument 1000),
# displayed in the progress bar.
metric = stats.RollingMean(1000)

# Bound to its own name on purpose: assigning the instance to `evaluation`
# would shadow the imported `evaluation` module and break every later
# `evaluation.Evaluation()` call made in the same kernel.
evaluator = evaluation.Evaluation()

teacher.train()

for step in bar:

    optimizer_teacher.zero_grad()

    positive_sample, negative_sample, weight, mode = next(wn18rr)

    positive_sample = positive_sample.to(device)
    negative_sample = negative_sample.to(device)
    weight = weight.to(device)

    # Score the true triples, then the corrupted ones for the sampled mode.
    positive_score = teacher(sample=positive_sample)
    negative_score = teacher(sample=(positive_sample, negative_sample), mode=mode)

    loss_teacher = loss.Adversarial()(positive_score, negative_score, weight, alpha=0.5)

    loss_teacher.backward()

    optimizer_teacher.step()

    metric.update(loss_teacher.item())

    # Refresh the progress-bar label every 5 steps only.
    if step % 5 == 0:
        bar.set_description(f'Adversarial loss: {metric.get():6f}')

    if step % 2000 == 0:

        # Switch to eval mode for scoring, then back to train mode.
        teacher = teacher.eval()

        score = evaluator(model=teacher, dataset=wn18rr.test_dataset(batch_size=8), device=device)

        teacher = teacher.train()

        print(score)

        # Set path HERE
        # The evaluation score is embedded in the file name; the directory
        # must already exist.
        with open(f'./models/teacher_wn18rr_{score}.pickle', 'wb') as handle:

            pickle.dump(teacher, handle, protocol = pickle.HIGHEST_PROTOCOL)

In [5]:
# Reload a teacher checkpoint produced by the previous training cell.
# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only load checkpoints that you created yourself.

teacher_name = 'teacher_wn18rr_HITS@10: 0.527760, HITS@1: 0.419113, HITS@3: 0.477664, MR: 5509.747288, MRR: 0.457429.pickle'

teacher_path = f'./models/{teacher_name}'

with open(teacher_path, 'rb') as checkpoint_file:
    teacher = pickle.load(checkpoint_file)

In [None]:
student_batch_size = 1000

# Number of entities to consider when distilling heads and tails:
batch_size_entity = 10

max_step = 40000

torch.manual_seed(42)

# Data stream for the student. The batch size is taken from
# `student_batch_size` above instead of repeating the literal.
wn18rr = datasets.WN18RR(
    batch_size=student_batch_size,
    negative_sample_size=1,
    seed=42,
    shuffle=True,
)

# Increasing the size of the student's latent representations improves results.
student = model.RotatE(
    hidden_dim=1000,
    n_entity=wn18rr.n_entity,
    n_relation=wn18rr.n_relation,
    gamma=0,
)

student = student.to(device)

# The distillation process handles different indexes between the student and
# the teacher, and pre-computes the batches dedicated to distillation.
# Here both models share the same entity and relation indexes.
distillation_process = distillation.Distillation(
    teacher_entities  = wn18rr.entities,
    student_entities  = wn18rr.entities,
    teacher_relations = wn18rr.relations,
    student_relations = wn18rr.relations,
)

# Only parameters that require gradients are handed to the optimizer.
optimizer_student = torch.optim.Adam(
    filter(lambda p: p.requires_grad, student.parameters()), lr=0.00005)

bar = tqdm.tqdm(range(1, max_step + 1), position=0)

# Creme online metric.
metric = stats.RollingMean(1000)

# The teacher is only queried for scores; the student is the model being trained.
teacher = teacher.eval()
student = student.train()

# Built once, before the first step: the same evaluator is reused at every
# checkpoint, and a problem resolving `evaluation.Evaluation` shows up right
# away instead of after the first 1000 training steps.
evaluator = evaluation.Evaluation()

for step in bar:

    # Only the positive triples drive distillation; the negative samples,
    # weights and mode returned by the iterator are not used in this loop.
    positive_sample, _, _, _ = next(wn18rr)

    # Per-sample index tensors, collected so that scores are computed once per
    # batch instead of once per triple.
    teacher_head = []
    student_head = []
    teacher_relation = []
    student_relation = []
    teacher_tail = []
    student_tail = []

    optimizer_student.zero_grad()

    for head, relation, tail in positive_sample:

        head, relation, tail = head.item(), relation.item(), tail.item()

        # Constructing the tensor dedicated to distillation of HEAD:
        teacher_common_head_sample, _ = distillation_process.mini_batch_teacher_head(
            relation=relation, tail=tail)

        # The Distillation class maps entity indexes between the student and
        # the teacher. For a single sample with relation r_1 and tail t_1, the
        # method below returns a tensor
        # [(e_1, r_1, t_1), (e_2, r_1, t_1), ..(e_n, r_1, t_1)] of shape (1, n_entity, 3)
        student_common_head_sample, _ = distillation_process.mini_batch_student_head(
            teacher_relation=relation, teacher_tail=tail)

        # Limit the number of candidate entities that are distilled.
        # Supposing the ground truth is (e3, r1, t4), instead of building
        # [(e1, r1, t4), (e2, r1, t4), ..(en, r1, t4)] we keep a window of
        # `batch_size_entity` candidates that contains e3.
        # The window start is clamped so the window never runs past the last
        # candidate: an unclamped slice is shorter for indexes close to the
        # end, and torch.stack / reshape below need equally sized slices.
        # Assumes teacher and student tensors are aligned along dim 1, as both
        # are sliced with the same indexes.
        n_head_candidates = teacher_common_head_sample.shape[1]
        head_start = max(0, min(head, n_head_candidates - batch_size_entity))
        head_stop = head_start + batch_size_entity
        teacher_common_head_sample = teacher_common_head_sample[:, head_start:head_stop]
        student_common_head_sample = student_common_head_sample[:, head_start:head_stop]

        # Storing indexes costs little memory, so all of them are kept and the
        # scores are computed once for the whole batch to speed up training.
        teacher_head.append(teacher_common_head_sample)
        student_head.append(student_common_head_sample)

        # Constructing the tensor dedicated to distillation of relation:
        teacher_common_relation_sample, _ = distillation_process.mini_batch_teacher_relation(
            head=head, tail=tail)

        student_common_relation_sample, _ = distillation_process.mini_batch_student_relation(
            teacher_head=head, teacher_tail=tail)

        teacher_relation.append(teacher_common_relation_sample)
        student_relation.append(student_common_relation_sample)

        # Constructing the tensor dedicated to distillation of TAIL:
        teacher_common_tail_sample, _ = distillation_process.mini_batch_teacher_tail(
            relation=relation, head=head)

        # tensor [(e_1, r_1, t_1), (e_1, r_1, t_2), ..(e_1, r_1, t_n)] of shape (1, n_entity, 3)
        student_common_tail_sample, _ = distillation_process.mini_batch_student_tail(
            teacher_relation=relation, teacher_head=head)

        # Same clamped window as for heads, centred on the true tail index.
        n_tail_candidates = teacher_common_tail_sample.shape[1]
        tail_start = max(0, min(tail, n_tail_candidates - batch_size_entity))
        tail_stop = tail_start + batch_size_entity
        teacher_common_tail_sample = teacher_common_tail_sample[:, tail_start:tail_stop]
        student_common_tail_sample = student_common_tail_sample[:, tail_start:tail_stop]

        teacher_tail.append(teacher_common_tail_sample)
        student_tail.append(student_common_tail_sample)

    # Create tensor [[(e_1, r_1, t_1),..(en, r_1, t_1)], ..,[(e1, r_3, t_2),..(e_n, r_3, t_2)]]
    # Tensor of size (batch_size, number of entities needed to compute the probability distribution, 3)
    teacher_head_tensor = torch.stack(teacher_head).reshape(len(teacher_head), batch_size_entity, 3).to(device)
    student_head_tensor = torch.stack(student_head).reshape(len(student_head), batch_size_entity, 3).to(device)

    teacher_relation_tensor = torch.stack(teacher_relation).reshape(len(teacher_relation), wn18rr.n_relation, 3).to(device)
    student_relation_tensor = torch.stack(student_relation).reshape(len(student_relation), wn18rr.n_relation, 3).to(device)

    teacher_tail_tensor = torch.stack(teacher_tail).reshape(len(teacher_tail), batch_size_entity, 3).to(device)
    student_tail_tensor = torch.stack(student_tail).reshape(len(student_tail), batch_size_entity, 3).to(device)

    # The teacher is not part of `optimizer_student`, so its scores are
    # computed without tracking gradients. The teacher / student call order of
    # the original loop is preserved.

    # Distillation loss of heads
    with torch.no_grad():
        teacher_head_score = teacher.distill(teacher_head_tensor)
    loss_head = loss.KlDivergence()(
        teacher_score=teacher_head_score,
        student_score=student.distill(student_head_tensor)
    )

    # Distillation loss of relations.
    with torch.no_grad():
        teacher_relation_score = teacher.distill(teacher_relation_tensor)
    loss_relation = loss.KlDivergence()(
        teacher_score=teacher_relation_score,
        student_score=student.distill(student_relation_tensor)
    )

    # Distillation loss of tails.
    with torch.no_grad():
        teacher_tail_score = teacher.distill(teacher_tail_tensor)
    loss_tail = loss.KlDivergence()(
        teacher_score=teacher_tail_score,
        student_score=student.distill(student_tail_tensor)
    )

    # The loss of the student is equal to the sum of all losses.
    loss_student = loss_head + loss_relation + loss_tail

    metric.update(loss_student.item())

    loss_student.backward()

    optimizer_student.step()

    # Refresh the progress-bar label every 5 steps only.
    if step % 5 == 0:

        bar.set_description(f'Metric: {metric.get():6f}')

    if step % 1000 == 0:

        # Evaluate in eval mode, save the checkpoint, then resume training mode.
        student = student.eval()

        score = evaluator(model=student, dataset=wn18rr.test_dataset(batch_size=2), device=device)

        print(score)

        with open(f'./models/student_wn18rr_{score}.pickle', 'wb') as handle:

            pickle.dump(student, handle, protocol = pickle.HIGHEST_PROTOCOL)

        student = student.train()