In [36]:
#!python setup.py install --user

In [28]:
from kdmkr import datasets
from kdmkr import distillation
from kdmkr import evaluation
from kdmkr import loss
from kdmkr import model

In [29]:
from creme import stats

In [30]:
import torch
import tqdm

In [31]:
# Run configuration shared by every cell below.

# Device string handed to `.to(device)` and to the evaluation calls.
device = "cpu"

# Passed as `hidden_dim` to both RotatE models.
hidden_dim = 5

# Deliberately tiny batch; 512 was the value noted alongside it
# (presumably the full-size setting -- confirm before a real run).
batch_size = 3

In [35]:
# Build the WN18RR dataset object. The later cells pull training batches from
# it with `next(wn18rr)` and read `n_entity`, `n_relation`, `entities`,
# `relations` and `test_dataset(...)` from it.
# The negative sample size is twice the batch size.
wn18rr = datasets.WN18RR(
    batch_size=batch_size,
    negative_sample_size=2 * batch_size,
    seed=42,
)

In [33]:
# Teacher model: a RotatE instance sized to the dataset's entity and relation
# counts, placed on the configured device. It is trained first and later used
# as the source of scores for distillation.
teacher = model.RotatE(
    n_entity=wn18rr.n_entity,
    n_relation=wn18rr.n_relation,
    hidden_dim=hidden_dim,
    gamma=6,
).to(device)

In [34]:
# Build the distillation helper that produces the teacher/student mini-batches
# used by the student training cell. Teacher and student are given the same
# entity and relation vocabularies (both come from `wn18rr`).
#
# The module is re-imported under an alias because this cell rebinds the name
# `distillation` to the helper *instance* (later cells call
# `distillation.distillation_mode(...)` etc. on it). Without the alias, running
# this cell a second time would look up `.Distillation` on the instance and
# raise AttributeError.
from kdmkr import distillation as distillation_module

distillation = distillation_module.Distillation(
    teacher_entities  = wn18rr.entities,
    student_entities  = wn18rr.entities,
    teacher_relations = wn18rr.relations,
    student_relations = wn18rr.relations,
)

In [17]:
# Train the teacher on batches drawn from `wn18rr` using `loss.Adversarial`,
# tracking a rolling mean of the loss in the progress bar and running the
# evaluation on the test split every 500 steps.

# Re-import the module under an alias: this cell rebinds `evaluation` to a
# callable instance (the student cell calls `evaluation(...)` too), so a plain
# `evaluation.Evaluation()` would fail when the cell is run a second time.
from kdmkr import evaluation as evaluation_module

optimizer_teacher = torch.optim.Adam(
    filter(lambda p: p.requires_grad, teacher.parameters()),
    lr=0.0001,
)

max_step = 1000

# The iterable must be tqdm's first positional argument (the second one is the
# description). Steps are numbered 1..max_step so exactly `max_step` updates
# run and the periodic evaluation never fires on the untrained model.
bar = tqdm.tqdm(range(1, max_step + 1), position=0)

metric = stats.RollingMean(1000)
evaluation = evaluation_module.Evaluation()

# Instantiate the loss once instead of on every step.
adversarial_loss = loss.Adversarial()

teacher.train()

for step in bar:

    optimizer_teacher.zero_grad()

    positive_sample, negative_sample, weight, mode = next(wn18rr)

    positive_sample = positive_sample.to(device)
    negative_sample = negative_sample.to(device)
    weight = weight.to(device)

    positive_score = teacher(sample=positive_sample)
    negative_score = teacher(sample=(positive_sample, negative_sample), mode=mode)

    loss_teacher = adversarial_loss(positive_score, negative_score, weight, alpha=1.5)

    loss_teacher.backward()
    optimizer_teacher.step()

    # Store a plain float: keeping the tensor would pin its autograd graph in
    # the rolling window for up to 1000 steps.
    metric.update(loss_teacher.item())

    bar.set_description(f'Metric: {metric.get():6f}')

    if step % 500 == 0:

        # NOTE(review): the return value is discarded here -- confirm that
        # Evaluation reports its own results, otherwise capture and print it.
        evaluation(model=teacher, dataset=wn18rr.test_dataset(batch_size=8), device=device)

Metric: 2.576545: 100%|██████████| 1000/1000 [00:13<00:00, 75.96it/s]


In [18]:
# Student model: a RotatE instance built with the same hyper-parameters as the
# teacher and placed on the configured device. It is the model optimised in
# the distillation cell.
student = model.RotatE(
    n_entity=wn18rr.n_entity,
    n_relation=wn18rr.n_relation,
    hidden_dim=hidden_dim,
    gamma=6,
).to(device)

In [19]:
# Distil the teacher into the student. For every positive triple, the
# distillation helper decides which of head / relation / tail to distil on;
# for each active part the KL divergence between teacher and student scores on
# the "common" samples is added to the loss, and the student takes one
# optimiser step per triple.
optimizer_student = torch.optim.Adam(
    filter(lambda p: p.requires_grad, student.parameters()),
    lr=0.0001,
)

max_step = 1000

# Steps are numbered 1..max_step so that exactly `max_step` batches are used
# (range(1, max_step) stopped one short).
bar = tqdm.tqdm(range(1, max_step + 1), position=0)

metric = stats.RollingMean(1000)

# Only the student is optimised in this cell.
teacher.eval()
student.train()

kl_divergence = loss.KlDivergence()


def distillation_loss(teacher_sample, student_sample):
    """Return the KL divergence between teacher and student scores.

    Both samples are moved to `device` first. `.to()` returns the moved
    object rather than modifying it in place, so the result must be kept.
    The teacher is scored under `torch.no_grad()` because only the student
    is updated; this avoids building an autograd graph for the teacher.
    """
    teacher_sample = teacher_sample.to(device)
    student_sample = student_sample.to(device)

    with torch.no_grad():
        teacher_score = teacher.distill(teacher_sample)

    return kl_divergence(
        teacher_score=teacher_score,
        student_score=student.distill(student_sample),
    )


for step in bar:

    positive_sample, _, _, _ = next(wn18rr)

    for head, relation, tail in positive_sample:

        head, relation, tail = head.item(), relation.item(), tail.item()

        mode = distillation.distillation_mode(head=head, relation=relation, tail=tail)

        # TODO: Handle distinct entities
        # (the second element returned by each mini_batch_* call is the
        # "distinct" sample, currently unused).

        losses = []

        if mode['head']:

            teacher_common, _ = distillation.mini_batch_teacher_head(
                relation=relation, tail=tail)

            student_common, _ = distillation.mini_batch_student_head(
                teacher_relation=relation, teacher_tail=tail)

            losses.append(distillation_loss(teacher_common, student_common))

        if mode['relation']:

            teacher_common, _ = distillation.mini_batch_teacher_relation(
                head=head, tail=tail)

            student_common, _ = distillation.mini_batch_student_relation(
                teacher_head=head, teacher_tail=tail)

            losses.append(distillation_loss(teacher_common, student_common))

        if mode['tail']:

            teacher_common, _ = distillation.mini_batch_teacher_tail(
                head=head, relation=relation)

            student_common, _ = distillation.mini_batch_student_tail(
                teacher_head=head, teacher_relation=relation)

            losses.append(distillation_loss(teacher_common, student_common))

        # Nothing to distil for this triple: skip it instead of calling
        # .backward() on the integer 0.
        if not losses:
            continue

        loss_student = sum(losses)

        optimizer_student.zero_grad()
        loss_student.backward()
        optimizer_student.step()

        # Store a plain float so the rolling window does not keep autograd
        # graphs alive.
        metric.update(loss_student.item())

    bar.set_description(f'Metric: {metric.get():6f}')

    if step % 500 == 0:

        # Evaluate the model being trained here; the teacher is frozen in this
        # cell, so scoring it again would not reflect distillation progress.
        # NOTE(review): confirm the student was the intended target.
        evaluation(model=student, dataset=wn18rr.test_dataset(batch_size=8), device=device)

Metric: 0.153018:  16%|█▌        | 162/1000 [03:49<19:26,  1.39s/it]

KeyboardInterrupt: 