In [1]:
# Install the library:
#!python setup.py install --user

In [2]:
from kdmkr import datasets
from kdmkr import distillation
from kdmkr import evaluation
from kdmkr import loss
from kdmkr import model

In [3]:
from creme import stats

In [4]:
import pickle 
import torch
import tqdm

In [5]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU so the
# notebook does not crash on the first `.to(device)` on a machine without CUDA.
device     = 'cuda' if torch.cuda.is_available() else 'cpu'
hidden_dim = 500
batch_size = 512

In [6]:
# WN18RR iterator used to train the teacher. The batch size is read from the
# configuration cell instead of being repeated as a literal, so the two cannot drift.
wn18rr = datasets.WN18RR(
    batch_size=batch_size,
    negative_sample_size=1024,
    shuffle=True,
    seed=42,
)

In [6]:
# RotatE teacher. The embedding size is read from the configuration cell
# instead of being repeated as a literal.
teacher = model.RotatE(
    hidden_dim=hidden_dim,
    n_entity=wn18rr.n_entity,
    n_relation=wn18rr.n_relation,
    gamma=6,
)
teacher = teacher.to(device)

In [None]:
# Train the teacher with the adversarial loss and checkpoint it every 2000 steps.
optimizer_teacher = torch.optim.Adam(
    filter(lambda p: p.requires_grad, teacher.parameters()), lr=0.00005)

max_step = 6000

# Steps run from 1 to max_step inclusive, so the checkpoint scheduled on the
# last multiple of 2000 (step == max_step) is actually evaluated and saved.
bar = tqdm.tqdm(range(1, max_step + 1), position=0)

# Creme online metric: rolling mean of the loss with a window of 1000.
metric = stats.RollingMean(1000)

# Bound to its own name: rebinding `evaluation` would shadow the imported module
# and break any later `evaluation.Evaluation()` call in the notebook.
evaluator = evaluation.Evaluation()

# The loss object is built once and reused at every step.
adversarial_loss = loss.Adversarial()

teacher.train()

for step in bar:

    optimizer_teacher.zero_grad()

    positive_sample, negative_sample, weight, mode = next(wn18rr)

    positive_sample = positive_sample.to(device)
    negative_sample = negative_sample.to(device)
    weight = weight.to(device)

    positive_score = teacher(sample=positive_sample)
    negative_score = teacher(sample=(positive_sample, negative_sample), mode=mode)

    loss_teacher = adversarial_loss(positive_score, negative_score, weight, alpha=0.5)

    loss_teacher.backward()
    optimizer_teacher.step()

    metric.update(loss_teacher.item())

    if step % 5 == 0:
        bar.set_description(f'Adversarial loss: {metric.get():6f}')

    if step % 2000 == 0:

        # Switch to eval mode for scoring, then back to train mode.
        teacher = teacher.eval()

        score = evaluator(model=teacher, dataset=wn18rr.test_dataset(batch_size=8), device=device)

        teacher = teacher.train()

        print(score)

        # Set path HERE
        with open(f'./models/teacher_wn18rr_{score}.pickle', 'wb') as handle:
            pickle.dump(teacher, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [7]:
# Checkpoint produced by the teacher training cell; its file name embeds the
# evaluation scores printed at save time.

teacher_name = 'teacher_wn18rr_HITS@10: 0.527760, HITS@1: 0.419113, HITS@3: 0.477664, MR: 5509.747288, MRR: 0.457429.pickle'

# Unpickle the whole model object from the models directory.
with open(f'./models/{teacher_name}', mode='rb') as checkpoint:
    teacher = pickle.load(checkpoint)

In [12]:
import numpy as np
import copy

def init_tensor(batch_size_entity, head, relation, tail):
    """Build a (1, batch_size_entity, 3) tensor of (head, relation, tail) ids.

    Parameters
    ----------
    batch_size_entity : int
        Number of triples (rows) in the returned tensor.
    head, relation, tail : int or torch.Tensor
        Value written to column 0, 1 and 2 respectively. A scalar is broadcast
        to every row; a tensor must be broadcastable to (1, batch_size_entity).

    Returns
    -------
    torch.Tensor
        int64 tensor of shape (1, batch_size_entity, 3).
    """
    # Ids are indices, so they are stored as int64: a float32 buffer cannot
    # represent every integer above 2**24 and would silently round large ids.
    x = torch.zeros((1, batch_size_entity, 3), dtype=torch.long)
    x[:, :, 0] = head
    x[:, :, 1] = relation
    x[:, :, 2] = tail
    return x

In [None]:
# Distil the trained teacher into a student by matching, for every positive
# triple, the teacher's scores over candidate heads, relations and tails.
student_batch_size = 1000

# Number of entities to consider to distill:
batch_size_entity = 20

max_step = 40000

torch.manual_seed(42)
np.random.seed(42)

wn18rr = datasets.WN18RR(
    batch_size=student_batch_size,
    negative_sample_size=1,
    seed=42,
    shuffle=True
)

# Increasing the size of the student's latent representations improves results.
student = model.RotatE(
    hidden_dim=1000,
    n_entity=wn18rr.n_entity,
    n_relation=wn18rr.n_relation,
    gamma=0
)

student = student.to(device)

# The distillation process handles different indexes between the student and
# the teacher, and pre-computes the batches dedicated to distillation.
distillation_process = distillation.Distillation(
    teacher_entities  = wn18rr.entities,
    student_entities  = wn18rr.entities,
    teacher_relations = wn18rr.relations,
    student_relations = wn18rr.relations,
)

optimizer_student = torch.optim.Adam(
    filter(lambda p: p.requires_grad, student.parameters()), lr=0.00005)

bar = tqdm.tqdm(range(1, max_step + 1), position=0)

# Creme online metric.
metric = stats.RollingMean(1000)

# Loss and evaluator objects are built once and reused at every step.
kl_divergence = loss.KlDivergence()
evaluator = evaluation.Evaluation()


def stack_samples(samples, width):
    """Stack per-triple samples into a (len(samples), width, 3) int64 tensor on `device`."""
    return torch.stack(samples).reshape(len(samples), width, 3).to(device=device, dtype=torch.long)


teacher = teacher.eval()
student = student.train()

for step in bar:

    # Only the positive triples are used for distillation.
    positive_sample, *_ = next(wn18rr)

    batch_tensor_head = []
    teacher_relation  = []
    student_relation  = []
    batch_tensor_tail = []

    # Random candidate entities shared by every triple of the batch.
    entity_distribution = torch.tensor(
        np.random.randint(low=0, high=wn18rr.n_entity, size=batch_size_entity)
    ).view(1, batch_size_entity)

    optimizer_student.zero_grad()

    for head, relation, tail in positive_sample:

        head, relation, tail = head.item(), relation.item(), tail.item()

        # Candidate heads: the random entities with the true head in slot 0.
        head_distribution = entity_distribution.clone()
        head_distribution[0][0] = head

        tensor_head = init_tensor(
            head     = head_distribution,
            relation = relation,
            tail     = tail,
            batch_size_entity = batch_size_entity
        )

        # Candidate tails: the random entities with the true tail in slot 0.
        tail_distribution = entity_distribution.clone()
        tail_distribution[0][0] = tail

        tensor_tail = init_tensor(
            head     = head,
            relation = relation,
            tail     = tail_distribution,
            batch_size_entity = batch_size_entity
        )

        batch_tensor_head.append(tensor_head)
        batch_tensor_tail.append(tensor_tail)

        # Constructing the tensor dedicated to distillation of relation:
        teacher_common_relation_sample, _ = distillation_process.mini_batch_teacher_relation(
            head=head, tail=tail)

        student_common_relation_sample, _ = distillation_process.mini_batch_student_relation(
            teacher_head=head, teacher_tail=tail)

        teacher_relation.append(teacher_common_relation_sample)
        student_relation.append(student_common_relation_sample)

    # Create tensor [[(e_1, r_1, t_1),..(en, r_1, t_1)], ..,[(e1, r_3, t_2),..(e_n, r_3, t_2)]]
    # Tensor of size (batch_size, number of entity needed to compute the distribution probability, 3)
    teacher_head_tensor = stack_samples(batch_tensor_head, batch_size_entity)
    student_head_tensor = stack_samples(batch_tensor_head, batch_size_entity)

    teacher_relation_tensor = stack_samples(teacher_relation, wn18rr.n_relation)
    student_relation_tensor = stack_samples(student_relation, wn18rr.n_relation)

    teacher_tail_tensor = stack_samples(batch_tensor_tail, batch_size_entity)
    student_tail_tensor = stack_samples(batch_tensor_tail, batch_size_entity)

    # The teacher is frozen: score it without autograd so backward() neither builds
    # a graph through it nor accumulates gradients on its parameters.
    with torch.no_grad():
        teacher_head_score     = teacher.distill(teacher_head_tensor)
        teacher_relation_score = teacher.distill(teacher_relation_tensor)
        teacher_tail_score     = teacher.distill(teacher_tail_tensor)

    # Distillation loss of heads
    loss_head = kl_divergence(
        teacher_score=teacher_head_score,
        student_score=student.distill(student_head_tensor)
    )

    # Distillation loss of relations.
    loss_relation = kl_divergence(
        teacher_score=teacher_relation_score,
        student_score=student.distill(student_relation_tensor)
    )

    # Distillation loss of tails.
    loss_tail = kl_divergence(
        teacher_score=teacher_tail_score,
        student_score=student.distill(student_tail_tensor)
    )

    # The loss of the student is equal to the sum of all losses.
    loss_student = loss_head + loss_relation + loss_tail

    metric.update(loss_student.item())

    loss_student.backward()

    optimizer_student.step()

    if step % 5 == 0:

        bar.set_description(f'Metric: {metric.get():6f}')

    if step % 1000 == 0:

        # Switch to eval mode for scoring and checkpointing, then back to train mode.
        student = student.eval()

        score = evaluator(model=student, dataset=wn18rr.test_dataset(batch_size=8), device=device)

        print(score)

        with open(f'./models/student_wn18rr_{score}.pickle', 'wb') as handle:
            pickle.dump(student, handle, protocol=pickle.HIGHEST_PROTOCOL)

        student = student.train()

Metric: 0.198259:   2%|▏         | 999/40000 [06:27<4:10:35,  2.59it/s]

HITS@10: 0.463944, HITS@1: 0.370932, HITS@3: 0.432355, MR: 6846.283982, MRR: 0.406958


Metric: 0.039777:   5%|▍         | 1999/40000 [14:17<4:01:37,  2.62it/s]  

HITS@10: 0.478622, HITS@1: 0.398213, HITS@3: 0.449426, MR: 6146.584397, MRR: 0.429363


Metric: 0.014932:   7%|▋         | 2999/40000 [22:05<3:49:00,  2.69it/s]  

HITS@10: 0.487556, HITS@1: 0.409860, HITS@3: 0.455648, MR: 5543.849234, MRR: 0.438895


Metric: 0.007452:  10%|▉         | 3999/40000 [29:35<3:36:52,  2.77it/s]  

HITS@10: 0.493618, HITS@1: 0.412412, HITS@3: 0.460753, MR: 5318.804244, MRR: 0.442989


Metric: 0.004285:  12%|█▏        | 4999/40000 [37:11<3:28:44,  2.79it/s]  

HITS@10: 0.493140, HITS@1: 0.413210, HITS@3: 0.461391, MR: 5185.552489, MRR: 0.443610


Metric: 0.004266:  13%|█▎        | 5010/40000 [38:36<10:25:28,  1.07s/it] 