In [2]:
# Install the library:
#!python setup.py install --user

In [3]:
from kdmkr import datasets
from kdmkr import distillation
from kdmkr import evaluation
from kdmkr import loss
from kdmkr import model

In [4]:
from creme import stats

In [5]:
import os
import pickle

import torch
import tqdm

In [6]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes (slowly) on machines without CUDA.
device     = 'cuda' if torch.cuda.is_available() else 'cpu'
# Hidden (latent) dimension for the RotatE models.
hidden_dim = 1000

In [6]:
# FB15k-237 loader used to train the teacher: batches of 1024 with 256
# negative samples, shuffled with a fixed seed for reproducibility.
fb15k237 = datasets.FB15K237(
    seed=42,
    shuffle=True,
    batch_size=1024,
    negative_sample_size=256,
)

In [6]:
# Teacher RotatE model sized by the shared `hidden_dim` setting from the config
# cell, with entity/relation counts taken from the dataset.
# gamma=9 is passed straight to RotatE (presumably its margin — confirm in kdmkr).
teacher = model.RotatE(
    hidden_dim=hidden_dim,
    n_entity=fb15k237.n_entity,
    n_relation=fb15k237.n_relation,
    gamma=9,
)
teacher = teacher.to(device)

In [None]:
# Train the teacher with the adversarial loss, evaluating on the test split and
# checkpointing every 2000 steps.
optimizer_teacher = torch.optim.Adam(
    filter(lambda p: p.requires_grad, teacher.parameters()), lr=0.00005)

max_step = 40000

# Checkpoint directory (set path HERE); created up front so the first save
# does not fail on a fresh checkout.
model_dir = './models'
os.makedirs(model_dir, exist_ok=True)

# range() excludes its end, so use max_step + 1 to make step == max_step run
# (and be evaluated), matching the student loop.
bar = tqdm.tqdm(range(1, max_step + 1), position=0)

# Rolling mean of the training loss over the last 1000 steps.
metric = stats.RollingMean(1000)

# Named `evaluator`, NOT `evaluation`: rebinding `evaluation` would shadow the
# kdmkr module, and the student cell later calls `evaluation.Evaluation()`.
evaluator = evaluation.Evaluation()

teacher.train()

for step in bar:

    optimizer_teacher.zero_grad()

    positive_sample, negative_sample, weight, mode = next(fb15k237)

    positive_sample = positive_sample.to(device)
    negative_sample = negative_sample.to(device)
    weight = weight.to(device)

    positive_score = teacher(sample=positive_sample)
    negative_score = teacher(sample=(positive_sample, negative_sample), mode=mode)

    loss_teacher = loss.Adversarial()(positive_score, negative_score, weight, alpha=0.5)

    loss_teacher.backward()
    optimizer_teacher.step()

    metric.update(loss_teacher.item())

    if step % 5 == 0:
        bar.set_description(f'Adversarial loss: {metric.get():6f}')

    if step % 2000 == 0:

        teacher = teacher.eval()
        score = evaluator(model=teacher, dataset=fb15k237.test_dataset(batch_size=8), device=device)
        teacher = teacher.train()

        print(score)

        # The file name embeds the evaluation score so checkpoints can be told apart.
        with open(f'{model_dir}/teacher_fb15k237_{score}.pickle', 'wb') as handle:
            pickle.dump(teacher, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [7]:
# Reload a teacher checkpoint produced by the training loop; its file name
# embeds the test metrics that were printed when it was saved.
# SECURITY: pickle.load can execute arbitrary code — only load files you created.
teacher_name = 'teacher_fb15k237_HITS@10: 0.519716, HITS@1: 0.233778, HITS@3: 0.364971, MR: 192.433475, MRR: 0.329106.pickle'
teacher_path = f'./models/{teacher_name}'

with open(teacher_path, 'rb') as handle:
    teacher = pickle.load(handle)

In [None]:
import numpy as np
import copy

def init_tensor(batch_size_entity, head, relation, tail):
    """Build a (1, batch_size_entity, 3) tensor of (head, relation, tail) rows.

    Each of ``head``, ``relation`` and ``tail`` is written into its own column
    (0, 1 and 2 respectively) and may be a scalar, which is broadcast to every
    row, or a tensor broadcastable to shape (1, batch_size_entity).

    Returns a tensor of torch's default floating dtype; callers cast it to an
    integer dtype before using it as indices.
    """
    triples = torch.zeros((1, batch_size_entity, 3))
    for column, values in enumerate((head, relation, tail)):
        triples[:, :, column] = values
    return triples

In [None]:
# Candidate pool sizes: for every training triple, its head/tail are scored
# against `batch_size_entity` sampled entities and its relation against
# `batch_size_relation` sampled relations.
batch_size_entity = 20
batch_size_relation = 20

max_step = 40000

# Seed torch and numpy; the training loop samples candidates with np.random.
torch.manual_seed(42)
np.random.seed(42)

# The distillation loop only reads the positive triples, so the negative
# sample size is kept at 1.
fb15k237 = datasets.FB15K237(
    seed=42,
    shuffle=True,
    batch_size=1000,
    negative_sample_size=1,
)

# Increasing the student's latent dimension improves results (author's note);
# here it is 1000.
student = model.RotatE(
    gamma=0,
    hidden_dim=1000,
    n_entity=fb15k237.n_entity,
    n_relation=fb15k237.n_relation,
).to(device)

# Distillation object meant to map teacher indexes to student indexes and to
# pre-compute distillation batches.
# NOTE(review): `distillation_process` is not referenced by the training loop;
# teacher and student share the same entity/relation vocabularies here, so raw
# indexes are fed to both models directly.
distillation_process = distillation.Distillation(
    teacher_entities=fb15k237.entities,
    teacher_relations=fb15k237.relations,
    student_entities=fb15k237.entities,
    student_relations=fb15k237.relations,
)

optimizer_student = torch.optim.Adam(
    (p for p in student.parameters() if p.requires_grad),
    lr=0.00005,
)

bar = tqdm.tqdm(range(1, max_step + 1), position=0)

# Rolling mean (creme) of the distillation loss over the last 1000 steps.
metric = stats.RollingMean(1000)

teacher = teacher.eval()
student = student.train()

def expand_candidates(triples, candidates, column):
    """Build, for every triple, a block of triples that vary a single column.

    Parameters
    ----------
    triples : LongTensor of shape (n_triples, 3)
        Observed (head, relation, tail) index triples.
    candidates : LongTensor of shape (n_candidates,)
        Sampled indices substituted into ``column``.
    column : int
        0 to vary heads, 1 to vary relations, 2 to vary tails.

    Returns
    -------
    LongTensor of shape (n_triples, n_candidates, 3)
        Row ``i`` repeats ``triples[i]`` with ``column`` replaced by
        ``candidates``, except slot 0, which keeps the true value
        ``triples[i, column]``.
    """
    n_candidates = candidates.shape[0]
    block = triples.unsqueeze(1).repeat(1, n_candidates, 1)
    block[:, :, column] = candidates
    # Slot 0 holds the observed element; the remaining slots hold the samples.
    block[:, 0, column] = triples[:, column]
    return block


# Reused across evaluations (the teacher loop reuses one Evaluation instance the
# same way); created up front so a shadowed `evaluation` name fails immediately.
evaluator = evaluation.Evaluation()

for step in bar:

    positive_sample, negative_sample, weight, mode = next(fb15k237)

    # Entities are drawn before relations; keep this order so seeded runs
    # consume np.random identically.
    entity_candidates = torch.as_tensor(
        np.random.randint(low=0, high=fb15k237.n_entity, size=batch_size_entity),
        dtype=torch.long,
    )
    relation_candidates = torch.as_tensor(
        np.random.randint(low=0, high=fb15k237.n_relation, size=batch_size_relation),
        dtype=torch.long,
    )

    optimizer_student.zero_grad()

    # positive_sample is iterated as (head, relation, tail) rows, i.e. shape (batch, 3).
    triples = positive_sample.long().cpu()

    # Tensors of shape (batch, n_candidates, 3). Teacher and student share the same
    # vocabularies, so the same index tensors are fed to both models. The relation
    # tensor is sized by batch_size_relation, not batch_size_entity.
    head_tensor     = expand_candidates(triples, entity_candidates, column=0).to(device)
    relation_tensor = expand_candidates(triples, relation_candidates, column=1).to(device)
    tail_tensor     = expand_candidates(triples, entity_candidates, column=2).to(device)

    # The teacher is frozen: scoring it without autograd avoids back-propagating
    # into its parameters and accumulating .grad on them at every step.
    with torch.no_grad():
        teacher_head_score     = teacher.distill(head_tensor)
        teacher_relation_score = teacher.distill(relation_tensor)
        teacher_tail_score     = teacher.distill(tail_tensor)

    # Distillation loss of heads.
    loss_head = loss.KlDivergence()(
        teacher_score=teacher_head_score,
        student_score=student.distill(head_tensor),
    )

    # Distillation loss of relations.
    loss_relation = loss.KlDivergence()(
        teacher_score=teacher_relation_score,
        student_score=student.distill(relation_tensor),
    )

    # Distillation loss of tails.
    loss_tail = loss.KlDivergence()(
        teacher_score=teacher_tail_score,
        student_score=student.distill(tail_tensor),
    )

    # The student's loss is the sum of the three distillation losses.
    loss_student = loss_head + loss_relation + loss_tail

    metric.update(loss_student.item())

    loss_student.backward()

    optimizer_student.step()

    if step % 5 == 0:

        bar.set_description(f'Metric: {metric.get():6f}')

    if step % 1000 == 0:

        student = student.eval()

        score = evaluator(model=student, dataset=fb15k237.test_dataset(batch_size=8), device=device)

        print(score)

        with open(f'./models/student_fb15k237_{score}.pickle', 'wb') as handle:

            pickle.dump(student, handle, protocol=pickle.HIGHEST_PROTOCOL)

        student = student.train()

Metric: 0.171015:   2%|▏         | 999/40000 [06:45<4:25:28,  2.45it/s]

HITS@10: 0.332967, HITS@1: 0.129850, HITS@3: 0.215846, MR: 447.010554, MRR: 0.198185


Metric: 0.042698:   5%|▍         | 1999/40000 [17:10<4:16:00,  2.47it/s]  

HITS@10: 0.362577, HITS@1: 0.139817, HITS@3: 0.231872, MR: 318.034447, MRR: 0.214035


Metric: 0.024577:   7%|▋         | 2999/40000 [27:37<4:01:18,  2.56it/s]  

HITS@10: 0.391625, HITS@1: 0.152057, HITS@3: 0.253029, MR: 272.911365, MRR: 0.231538


Metric: 0.016539:  10%|▉         | 3999/40000 [37:56<4:06:58,  2.43it/s]  

HITS@10: 0.413393, HITS@1: 0.157676, HITS@3: 0.266833, MR: 251.435014, MRR: 0.242154


Metric: 0.011931:  12%|█▏        | 4999/40000 [48:21<3:50:06,  2.54it/s]  

HITS@10: 0.436187, HITS@1: 0.174485, HITS@3: 0.286695, MR: 236.269789, MRR: 0.260792


Metric: 0.008846:  15%|█▍        | 5999/40000 [58:47<3:58:51,  2.37it/s]  

HITS@10: 0.444249, HITS@1: 0.178174, HITS@3: 0.294953, MR: 225.818211, MRR: 0.266622


Metric: 0.006723:  17%|█▋        | 6999/40000 [1:09:14<3:40:40,  2.49it/s]  

HITS@10: 0.459567, HITS@1: 0.191049, HITS@3: 0.310246, MR: 217.258575, MRR: 0.280450


Metric: 0.005277:  20%|█▉        | 7999/40000 [1:19:41<3:38:02,  2.45it/s]  

HITS@10: 0.471196, HITS@1: 0.198818, HITS@3: 0.320727, MR: 212.727744, MRR: 0.289335


Metric: 0.004231:  22%|██▏       | 8999/40000 [1:30:08<3:26:15,  2.51it/s]  

HITS@10: 0.477963, HITS@1: 0.201285, HITS@3: 0.326029, MR: 214.279390, MRR: 0.293161


Metric: 0.003501:  25%|██▍       | 9999/40000 [1:40:36<3:23:31,  2.46it/s]  

HITS@10: 0.486343, HITS@1: 0.207613, HITS@3: 0.333162, MR: 210.750733, MRR: 0.300148


Metric: 0.003016:  27%|██▋       | 10999/40000 [1:51:02<3:17:14,  2.45it/s]  

HITS@10: 0.490790, HITS@1: 0.210642, HITS@3: 0.337487, MR: 209.026972, MRR: 0.303541


Metric: 0.002590:  30%|██▉       | 11999/40000 [2:01:30<3:07:46,  2.49it/s]  

HITS@10: 0.491865, HITS@1: 0.211106, HITS@3: 0.337267, MR: 209.040311, MRR: 0.304476


Metric: 0.002337:  32%|███▏      | 12999/40000 [2:11:54<3:02:43,  2.46it/s]  

HITS@10: 0.495773, HITS@1: 0.215406, HITS@3: 0.342080, MR: 205.032688, MRR: 0.308445


Metric: 0.002109:  35%|███▍      | 13999/40000 [2:22:20<2:54:04,  2.49it/s]  

HITS@10: 0.504178, HITS@1: 0.216603, HITS@3: 0.346892, MR: 203.544098, MRR: 0.311779


Metric: 0.001961:  37%|███▋      | 14999/40000 [2:32:46<2:49:53,  2.45it/s]  

HITS@10: 0.504788, HITS@1: 0.219217, HITS@3: 0.349311, MR: 204.286671, MRR: 0.313773


Metric: 0.001830:  40%|███▉      | 15999/40000 [2:43:13<2:38:36,  2.52it/s]  

HITS@10: 0.503811, HITS@1: 0.218900, HITS@3: 0.348920, MR: 204.092299, MRR: 0.313747


Metric: 0.001746:  42%|████▏     | 16999/40000 [2:53:40<2:37:23,  2.44it/s]  

HITS@10: 0.505668, HITS@1: 0.220268, HITS@3: 0.349189, MR: 201.964258, MRR: 0.314533


Metric: 0.001658:  45%|████▍     | 17999/40000 [3:04:00<2:27:20,  2.49it/s]  

HITS@10: 0.506205, HITS@1: 0.222027, HITS@3: 0.353904, MR: 202.229234, MRR: 0.316922


Metric: 0.001574:  47%|████▋     | 18999/40000 [3:14:24<2:21:48,  2.47it/s]  

HITS@10: 0.506914, HITS@1: 0.219413, HITS@3: 0.350386, MR: 202.682571, MRR: 0.314565


Metric: 0.001551:  50%|████▉     | 19999/40000 [3:24:47<2:16:50,  2.44it/s]  

HITS@10: 0.505008, HITS@1: 0.218753, HITS@3: 0.348822, MR: 201.537257, MRR: 0.313770


Metric: 0.001535:  52%|█████▏    | 20999/40000 [3:35:12<2:04:23,  2.55it/s]  

HITS@10: 0.509186, HITS@1: 0.222051, HITS@3: 0.351876, MR: 202.310442, MRR: 0.317018


Metric: 0.001474:  55%|█████▍    | 21999/40000 [3:45:37<2:01:43,  2.46it/s]  

HITS@10: 0.506694, HITS@1: 0.218435, HITS@3: 0.350044, MR: 203.966261, MRR: 0.314046


Metric: 0.001452:  57%|█████▋    | 22999/40000 [3:56:01<1:54:52,  2.47it/s]  

HITS@10: 0.506425, HITS@1: 0.221025, HITS@3: 0.350191, MR: 205.250293, MRR: 0.315864


Metric: 0.001431:  60%|█████▉    | 23999/40000 [4:06:24<1:44:50,  2.54it/s]  

HITS@10: 0.508771, HITS@1: 0.222369, HITS@3: 0.353196, MR: 205.137667, MRR: 0.317222


Metric: 0.001401:  62%|██████▏   | 24999/40000 [4:16:50<1:42:13,  2.45it/s]  

HITS@10: 0.508013, HITS@1: 0.221514, HITS@3: 0.351143, MR: 205.272256, MRR: 0.316201


Metric: 0.001395:  65%|██████▍   | 25999/40000 [4:27:16<1:32:11,  2.53it/s]  

HITS@10: 0.510554, HITS@1: 0.222931, HITS@3: 0.354417, MR: 202.126869, MRR: 0.317738


Metric: 0.001375:  67%|██████▋   | 26999/40000 [4:37:40<1:28:51,  2.44it/s]  