In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

In [2]:
from teacher            import Teacher
from loss               import LogisticLoss
from loss               import DistillationLoss
from benchmark          import benchmark
from template_dataset   import Dataset

In [3]:
from itertools import chain

In [4]:
# https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding
import sys
sys.path.append('/users/iris/rsourty/github/KnowledgeGraphEmbedding/codes')
from dataloader import TrainDataset
from dataloader import TestDataset
from dataloader import BidirectionalOneShotIterator

In [5]:
from creme            import stats
from torch.utils.data import DataLoader

In [6]:
import torch
import tqdm
import pickle

import numpy as np

In [7]:
# Choose gpu
torch.manual_seed(0)
torch.cuda.set_device(1)

Load each knowledge bases

In [8]:
list_dataset = []

for dataset_name in ['teacher_1', 'teacher_2', 'teacher_3']:
    
    with open(f'/users/iris/rsourty/experiments/distillation/datasets/{dataset_name}.pickle', 'rb') as handle:   

        list_dataset.append(pickle.load(handle))

In [9]:
def get_training_generator(train, n_entity, n_relation, batch_size, negative_sample_size, cpu_num = 1):

        head_loader = DataLoader(
            TrainDataset(train, n_entity, n_relation, negative_sample_size, 'head-batch'),
            batch_size  = batch_size,
            shuffle     = True,
            num_workers = max(1, cpu_num//2),
            collate_fn  = TrainDataset.collate_fn
        )

        tail_loader = DataLoader(
            TrainDataset(train, n_entity, n_relation, negative_sample_size, 'tail-batch'),
            batch_size  = batch_size,
            shuffle     = True,
            num_workers = max(1, cpu_num//2),
            collate_fn  = TrainDataset.collate_fn
        )

        return BidirectionalOneShotIterator(head_loader, tail_loader)

In [10]:
def get_multiples_training_genarator(list_dataset, batch_size, n_entity, n_relation, negative_sample_size):
    return [get_training_generator(dataset.get_train(), n_entity, n_relation, batch_size, negative_sample_size) for dataset in list_dataset]

In [11]:
def get_testing_generator(train_test_valid, valid, n_entity, n_relation, batch_size, cpu_num):
    
    head_loader = DataLoader(
            TestDataset(
                valid,
                train_test_valid,
                n_entity,
                n_relation,
                'head-batch'
            ),
            batch_size  = batch_size,
            num_workers = cpu_num,
            collate_fn  = TestDataset.collate_fn
        )

    tail_loader = DataLoader(
        TestDataset(
            valid,
            train_test_valid,
            n_entity,
            n_relation,
            'tail-batch'
        ),
        batch_size  = batch_size,
        num_workers = cpu_num,
        collate_fn  = TestDataset.collate_fn,
    )  

    return head_loader, tail_loader

In [12]:
TEST_BATCH_SIZE      = 10

In [13]:
GAMMA                = 6
ALPHA                = 0.5
LEARNING_RATE        = 0.0005
MAX_STEPS            = 10000
VALID_STEPS          = 1000
WARM_UP_STEPS        = MAX_STEPS // 2
batch_size = 512
hidden_dim = 400
negative_sample_size = batch_size * 2

In [14]:
def get_distillation_sample(positive_sample):
    list_sample = []
    for _, idx, _ in positive_sample:
        x = torch.zeros(11) + 1e-5
        x[idx] = 1.
        list_sample.append(x)
    return torch.stack(list_sample).cuda()

In [22]:
torch.manual_seed(42)

list_teacher     = [] 
dic_optimizer    = {}
dic_optimizer_kd = {}

for idx, dataset in enumerate(list_dataset):
    
    n_entity, n_relation = dataset.get_metadata()
    
    model = Teacher(
        model      = 'RotatE', 
        n_entity   = n_entity, 
        n_relation = n_relation, 
        hidden_dim = hidden_dim, 
        gamma      = 6, 
        batch_size = batch_size,
    )

    model = model.cuda()
    
    list_teacher.append(model)

    dic_optimizer_kd[idx] = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr = 0.0005,
    )
    
    dic_optimizer[idx] = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr = 0.0005,
    )

In [23]:
list_generator = get_multiples_training_genarator(
    list_dataset, 
    batch_size = batch_size, 
    n_entity   = n_entity, 
    n_relation = n_relation, 
    negative_sample_size = negative_sample_size,
)

progress_bar  = tqdm.tqdm(range(MAX_STEPS), position=0)

metrics = {idx: stats.RollingMean(100) for idx, _ in enumerate(list_teacher)}

beta = np.array([i for i in range(10000)])

beta = beta / max(beta)

for step in progress_bar:
    
    sample_dic   = {}
    loss_teacher = {}
    
    for idx_teacher, (generator, teacher) in enumerate(zip(list_generator, list_teacher)):
        
        dic_optimizer[idx_teacher].zero_grad()
        teacher.train()
        
        positive_sample, negative_sample, subsampling_weight, mode = next(generator)
        
        teacher.batch_size = positive_sample.shape[0]
        
        """
        positive_sample    = positive_sample.cuda()
        negative_sample    = negative_sample.cuda()
        subsampling_weight = subsampling_weight.cuda()
    
        positive_score = teacher(positive_sample)
        negative_score = teacher((positive_sample, negative_sample), mode = mode)
        
        loss = LogisticLoss()(
            positive_score, 
            negative_score, 
            subsampling_weight, 
            adversarial_sampling = True,
            alpha = 0.5,
        )
        """
      
        # Construct dataset for KD
        sample_distillation = [[torch.tensor([w1, r, w2]) for r in range(n_relation)] for w1, _, w2 in positive_sample.cpu().detach()]
        sample_distillation = list(chain.from_iterable(sample_distillation))
        sample_distillation = torch.stack(sample_distillation).reshape(positive_sample.shape[0], n_relation, 3).cuda()
        sample_dic[f'teacher_{idx_teacher}'] = sample_distillation
        
        loss = DistillationLoss()(
            score_student = teacher(sample_distillation, mode = 'distillation'), 
            score_teacher = get_distillation_sample(positive_sample),
        ) * (1 - beta[step]) 
        
        metrics[idx_teacher].update(loss.item())
        
        loss.backward()
        
        dic_optimizer[idx_teacher].step()
    
   
    # Compute KD
    for idx_student, student in enumerate(list_teacher):
        
        student.train() 
        
        for idx_teacher, teacher in enumerate(list_teacher):
            
            if idx_student != idx_teacher:
            
                dic_optimizer_kd[idx_student].zero_grad()
                
                teacher.eval()

                sample = sample_dic[f'teacher_{idx_teacher}']
                
                loss_distillation = DistillationLoss()(
                    score_student = student(sample, mode = 'distillation'), 
                    score_teacher = teacher(sample, mode = 'distillation'),
                ) * beta[step]
                
                loss_distillation.backward()
         
                dic_optimizer_kd[idx_student].step()
  
    for teacher in list_teacher:
        
        teacher.train()

    progress_bar.set_description(f'teacher_1: {metrics[0].get():4f}, teacher_2: {metrics[1].get():4f}, teacher_3: {metrics[2].get():4f}')

    # Evaluation:
    if (step + 1) % VALID_STEPS == 0:
        
        for idx, teacher in enumerate(list_teacher):
            
            teacher.eval()
            
            train_test_valid = list_dataset[idx].get_train() + list_dataset[idx].get_valid() + list_dataset[idx].get_test()
            
            valid_dataloader_head, valid_dataloader_tail = get_testing_generator(train_test_valid, list_dataset[idx].get_valid(), n_entity, n_relation, batch_size = 5, cpu_num = 2)
            
            score = benchmark(model, valid_dataloader_head, valid_dataloader_tail)
            
            score = ', '.join(sorted({f'{metric}: {score:4f}' for metric, score in score.items()}))
            
            print(f'Teacher: {idx}, {score}')

            #with open(f'/users/iris/rsourty/experiments/distillation/models/full_distill_teacher_{idx}_{score}.pickle', 'wb') as handle:
            
            #    pickle.dump(teacher, handle, protocol = pickle.HIGHEST_PROTOCOL)

Traceback (most recent call last):
  File "/usr/local/miniconda3/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/usr/local/miniconda3/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/local/miniconda3/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/local/miniconda3/lib/python3.6/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
Traceback (most recent call last):
  File "/usr/local/miniconda3/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/usr/local/miniconda3/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/local/miniconda3/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(he

Teacher: 0, HITS@10: 0.001319, HITS@1: 0.000165, HITS@3: 0.000495, MR: 21659.902407, MRR: 0.000955
Teacher: 1, HITS@10: 0.001319, HITS@1: 0.000165, HITS@3: 0.000495, MR: 21659.907023, MRR: 0.000955


KeyboardInterrupt: 

In [30]:
DistillationLoss()(
    torch.tensor([[1., 0, 0, 0]]),
    torch.tensor([[1, 1000., 1, 1]]),
)

tensor(0.4359)

In [19]:
DistillationLoss()(
    score_student = teacher(sample_distillation, mode = 'distillation'), 
    score_teacher = get_distillation_sample(positive_sample),
)

tensor(4.0311e-05, device='cuda:1', grad_fn=<MeanBackward0>)

In [26]:
teacher(sample, mode = 'distillation').shape

torch.Size([512, 11])

#### Train the student:

In [40]:
list_dataset = []

for dataset_name in ['teacher_1.pickle', 'teacher_2.pickle', 'teacher_3.pickle']:
    
    with open(f'/users/iris/rsourty/experiments/distillation/datasets/{dataset_name}', 'rb') as handle:   
        
        list_dataset.append(pickle.load(handle))

In [41]:
list_teacher = []

for model_name in [
    'as_described_teacher_0_HITS@10: 0.452357, HITS@1: 0.209034, HITS@3: 0.371085, MR: 6727.271348, MRR: 0.300592.pickle',
    'as_described_teacher_1_HITS@10: 0.452852, HITS@1: 0.211672, HITS@3: 0.371579, MR: 6727.253544, MRR: 0.301767.pickle',
    'as_described_teacher_2_HITS@10: 0.455490, HITS@1: 0.242994, HITS@3: 0.384438, MR: 6726.608803, MRR: 0.322592.pickle',
]:
    
    with open(f'/users/iris/rsourty/experiments/distillation/models/{model_name}', 'rb') as handle:   
        
        list_teacher.append(pickle.load(handle))

In [44]:
n_entity, n_relation = list_dataset[0].get_metadata()

list_generator = get_multiples_training_genarator(
    list_dataset, 
    batch_size = batch_size, 
    n_entity   = n_entity, 
    n_relation = n_relation, 
    negative_sample_size = negative_sample_size,
)

progress_bar = tqdm.tqdm(range(MAX_STEPS), position=0)

metric = stats.RollingMean(100)

loss_dic = {}

student = Teacher(
        model      = 'RotatE', 
        n_entity   = n_entity, 
        n_relation = n_relation, 
        hidden_dim = hidden_dim, 
        gamma      = 6, 
        batch_size = batch_size,
)

student = student.cuda()

optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, student.parameters()),
    lr = 0.00005,
)

distillation_loss = DistillationLoss()

for step in progress_bar:
    
    for idx_teacher, (generator, teacher) in enumerate(zip(list_generator, list_teacher)):
        
        optimizer.zero_grad()
        
        student.train()
        teacher.eval()
        
        positive_sample, _, _, _ = next(generator)
        
        student.batch_size = positive_sample.shape[0]
        teacher.batch_size = positive_sample.shape[0]
        
        sample = [[torch.tensor([w1, r, w2]) for r in range(n_relation)] for w1, _, w2 in positive_sample]
        sample = list(chain.from_iterable(sample
                                         ))
        sample = torch.stack(sample).reshape(positive_sample.shape[0], n_relation, 3).cuda()
        
        loss = distillation_loss(
            score_student = student(sample, mode = 'distillation'), 
            score_teacher = teacher(sample, mode = 'distillation'),
        )
        
        metric.update(loss)
        
        loss.backward()
        
        optimizer.step()
            
    progress_bar.set_description(f'Student: {metric.get():4f}')
    
    # Evaluation:
    if (step + 1) % VALID_STEPS == 0:
        
        
        train_test_valid = list_dataset[0].get_valid() + list_dataset[0].get_test()
            
        for dataset in list_dataset:
            
            train_test_valid += dataset.get_train()    
            
        valid_dataloader_head, valid_dataloader_tail = get_testing_generator(train_test_valid, list_dataset[0].get_valid(), n_entity, n_relation, batch_size = 5, cpu_num = 2)
    
        score = benchmark(student, valid_dataloader_head, valid_dataloader_tail)

        score = ', '.join(sorted({f'{metric}: {score:4f}' for metric, score in score.items()}))

        print(f'Student: {score}')
        
        with open(f'/users/iris/rsourty/experiments/distillation/models/student_of_teacher_as_described_{score}.pickle', 'wb') as handle:
            
            pickle.dump(teacher, handle, protocol = pickle.HIGHEST_PROTOCOL)

Student: 0.040039:   1%|          | 999/100000 [01:53<2:58:20,  9.25it/s] 

Student: HITS@10: 0.336630, HITS@1: 0.106495, HITS@3: 0.329707, MR: 14805.443290, MRR: 0.219217


Student: 0.008317:   2%|▏         | 1999/100000 [04:27<2:47:42,  9.74it/s]  

Student: HITS@10: 0.363007, HITS@1: 0.296901, HITS@3: 0.352951, MR: 13588.881965, MRR: 0.327125


Student: 0.001809:   3%|▎         | 2999/100000 [07:03<2:47:59,  9.62it/s]  

Student: HITS@10: 0.374217, HITS@1: 0.339103, HITS@3: 0.363501, MR: 12878.399110, MRR: 0.353695


Student: 0.000495:   4%|▍         | 3999/100000 [09:40<3:11:32,  8.35it/s]  

Student: HITS@10: 0.381800, HITS@1: 0.353281, HITS@3: 0.371909, MR: 12410.072041, MRR: 0.364480


Student: 0.000141:   5%|▍         | 4999/100000 [12:14<2:56:16,  8.98it/s]  

Student: HITS@10: 0.386416, HITS@1: 0.357896, HITS@3: 0.373722, MR: 12075.963073, MRR: 0.368463


Student: 0.000041:   6%|▌         | 5998/100000 [14:44<2:40:01,  9.79it/s]  

Student: HITS@10: 0.389878, HITS@1: 0.361853, HITS@3: 0.376855, MR: 11860.862018, MRR: 0.372136


Student: 0.000019:   7%|▋         | 6999/100000 [17:09<2:35:42,  9.95it/s]  

Student: HITS@10: 0.391032, HITS@1: 0.362677, HITS@3: 0.376690, MR: 11695.227992, MRR: 0.373080


Student: 0.000011:   8%|▊         | 7999/100000 [19:38<3:15:28,  7.84it/s]  

Student: HITS@10: 0.393010, HITS@1: 0.362842, HITS@3: 0.377844, MR: 11605.498351, MRR: 0.373700


Student: 0.000008:   9%|▉         | 8999/100000 [22:22<2:31:08, 10.04it/s]  

Student: HITS@10: 0.393999, HITS@1: 0.362348, HITS@3: 0.377844, MR: 11544.926805, MRR: 0.373573


Student: 0.000007:  10%|▉         | 9999/100000 [24:37<2:34:00,  9.74it/s]  

Student: HITS@10: 0.393340, HITS@1: 0.361194, HITS@3: 0.378503, MR: 11473.787339, MRR: 0.373089


Student: 0.000008:  11%|█         | 10998/100000 [27:22<2:28:42,  9.98it/s]  

Student: HITS@10: 0.393834, HITS@1: 0.361194, HITS@3: 0.377844, MR: 11429.408836, MRR: 0.372892


Student: 0.000007:  12%|█▏        | 11998/100000 [29:40<2:32:36,  9.61it/s]  

Student: HITS@10: 0.394164, HITS@1: 0.362183, HITS@3: 0.379163, MR: 11396.493900, MRR: 0.373805


Student: 0.000010:  13%|█▎        | 12999/100000 [32:13<7:45:17,  3.12it/s]  

Student: HITS@10: 0.393999, HITS@1: 0.361523, HITS@3: 0.377844, MR: 11367.324596, MRR: 0.373095


Student: 0.000007:  14%|█▍        | 13998/100000 [34:44<2:44:40,  8.70it/s]  

Student: HITS@10: 0.393834, HITS@1: 0.361853, HITS@3: 0.378503, MR: 11334.284042, MRR: 0.373300


Student: 0.000007:  15%|█▍        | 14999/100000 [37:32<2:21:13, 10.03it/s]  

Student: HITS@10: 0.394164, HITS@1: 0.361029, HITS@3: 0.378009, MR: 11297.412463, MRR: 0.372716


Student: 0.000005:  16%|█▌        | 15999/100000 [39:50<2:44:53,  8.49it/s]  

Student: HITS@10: 0.393010, HITS@1: 0.360699, HITS@3: 0.377514, MR: 11276.348335, MRR: 0.372360


Student: 0.000006:  17%|█▋        | 16999/100000 [42:24<2:20:25,  9.85it/s]  

Student: HITS@10: 0.393999, HITS@1: 0.360864, HITS@3: 0.376690, MR: 11258.974448, MRR: 0.372546


Student: 0.000006:  17%|█▋        | 17156/100000 [43:26<2:25:18,  9.50it/s]  

KeyboardInterrupt: 

In [39]:
Student: HITS@10: 0.384438, HITS@1: 0.360204, HITS@3: 0.373228, MR: 12147.641444, MRR: 0.369487

Student: HITS@10: 0.384438, HITS@1: 0.360204, HITS@3: 0.373228, MR: 12147.641444, MRR: 0.369487
