In [None]:
# pip install torch torchvision torchaudio numpy tqdm transformers scikit-learn #pickle5


Note: you may need to restart the kernel to use updated packages.


In [5]:
# pip install torch==1.6.0 tensorflow==1.15.0 horovod==0.19.5 transformers==3.0.2


In [6]:
import torch
import torch.nn as nn
import os
import numpy as np
import sys
import random
import torch.optim as optim
import torch.nn.functional as F
import pickle
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, Dataset

sys.path.insert(0, './PLM-NR')
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from utils import MODEL_CLASSES
from tnlrv3.tokenization_tnlrv3 import TuringNLRv3Tokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# Directory holding the TNLRv3 config, weights and tokenizer vocabulary.
path_turing = './tnlrv3/'
# Resolve the vocabulary relative to `path_turing` instead of a machine-specific
# absolute path so the notebook runs on any checkout.
# NOTE(review): assumes the vocab file lives at <path_turing>/tokenizer/ - confirm.
tokenizer = TuringNLRv3Tokenizer.from_pretrained(
    os.path.join(path_turing, 'tokenizer', 'tnlrv3-base-uncased-vocab.txt'),
    do_lower_case=True)



In [8]:
class Args:
    """Plain container for the hyper-parameters and paths used in this notebook."""

    def __init__(self):
        # Data / teacher ensemble
        self.corpus_path = './docs_filter.tsv'
        self.num_teachers = 4
        # Student encoder
        self.num_hidden_layers = 4
        self.news_query_vector_dim = 200
        self.news_dim = 256
        self.drop_rate = 0.2


args = Args()

# Pretrain Dataset

In [9]:
# Tokenizer max_length used for titles.
MAX_TITLE_LEN = 24
# Tokenizer max_length used for bodies.
MAX_BODY_LEN = 512
# Number of negative titles sampled per positive title.
NPRATIO = 9
# DataLoader batch size.
BATCH_SIZE = 32

In [10]:
# Read the whole corpus into memory, one list element per line (newlines kept).
with open(args.corpus_path, encoding='utf-8') as corpus_file:
    total_lines = list(corpus_file)
len(total_lines)

0

In [11]:
# Split each tab-separated corpus line and keep column 3 (title) and column 4 (body).
titles, bodies = [], []
for raw_line in total_lines:
    columns = raw_line.strip('\n').split('\t')
    title_text, body_text = columns[3], columns[4]
    titles.append(title_text)
    bodies.append(body_text)

In [12]:
# Must run Domain-specific_Post-train.ipynb first
# NOTE(review): pickle.load can execute arbitrary code; only load files you produced yourself.
def _read_pickle(path):
    """Return the object stored in the pickle file at `path`."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# One entry per teacher, in teacher order.
title_scorings, body_scorings = [], []
for i in range(args.num_teachers):
    title_scorings.append(_read_pickle(f'./teacher_title_emb_{i}.pkl'))
    body_scorings.append(_read_pickle(f'./teacher_body_emb_{i}.pkl'))

FileNotFoundError: [Errno 2] No such file or directory: './teacher_title_emb_0.pkl'

In [None]:
class DistillDataset(Dataset):
    """Title/body matching dataset with sampled negatives and teacher embeddings.

    Each item pairs one body with its own title (the positive, always at position 0)
    and NPRATIO randomly sampled titles of other documents (the negatives).

    Args:
        titles: sequence of title strings.
        bodies: sequence of body strings, aligned with `titles`.
        teacher_titles: one entry per teacher; each entry is indexed with a list of
            row indices (presumably a numpy array of title embeddings - confirm).
        teacher_bodies: one entry per teacher; each entry is indexed with a single
            row index (presumably a numpy array of body embeddings - confirm).
    """

    def __init__(self, titles, bodies, teacher_titles, teacher_bodies):
        self.titles = titles
        self.bodies = bodies
        self.teacher_titles = teacher_titles
        self.teacher_bodies = teacher_bodies
        self.len = len(titles)

    def __getitem__(self, idx):
        """Build one training sample.

        Returns:
            input_titles: array (1 + NPRATIO, 2 * MAX_TITLE_LEN), token ids followed
                by the attention mask for the positive and each negative title.
            input_body: array (2 * MAX_BODY_LEN,), token ids followed by the mask.
            label: always 0, the position of the positive title.
            input_teacher_titles: per-teacher rows for [idx] + negatives.
            input_teacher_bodies: per-teacher row for idx.

        Raises:
            ValueError: from random.sample when fewer than NPRATIO other documents exist.
        """
        if idx < 0:
            # Normalise so the exclusion of the positive below stays correct.
            idx += self.len

        # Sample negatives from every index except `idx` without materialising an
        # O(len) candidate list per item: draw positions from range(len - 1) and
        # shift those at or after `idx` up by one. This selects exactly the same
        # indices as sampling from the explicit list with the same RNG state.
        neg_idx = [pos if pos < idx else pos + 1
                   for pos in random.sample(range(self.len - 1), NPRATIO)]
        total_idx = [idx] + neg_idx

        titles = [tokenizer(self.titles[i], max_length=MAX_TITLE_LEN, pad_to_max_length=True,
                            truncation=True) for i in total_idx]
        # Pack ids and mask side by side; the encoder splits them in half again.
        input_titles = np.array([title['input_ids'] + title['attention_mask'] for title in titles])
        body = tokenizer(self.bodies[idx], max_length=MAX_BODY_LEN, pad_to_max_length=True,
                         truncation=True)
        input_body = np.array(body['input_ids'] + body['attention_mask'])

        input_teacher_titles = [x[total_idx] for x in self.teacher_titles]
        input_teacher_bodies = [x[idx] for x in self.teacher_bodies]
        label = 0
        return input_titles, input_body, label, input_teacher_titles, input_teacher_bodies

    def __len__(self):
        """Number of documents in the dataset."""
        return self.len

In [None]:
distill_ds = DistillDataset(titles, bodies, title_scorings, body_scorings)
# Cap the worker count at the number of CPU cores (32 workers oversubscribe smaller
# machines) and only pin host memory when a CUDA device is present to consume it.
distill_dl = DataLoader(distill_ds, batch_size=BATCH_SIZE,
                        num_workers=min(32, os.cpu_count() or 1),
                        shuffle=True, pin_memory=torch.cuda.is_available())

# Pretrain Model

In [None]:
class AttentionPooling(nn.Module):
    """Additive attention that pools a sequence of vectors into a single vector."""

    def __init__(self, emb_size, hidden_size):
        super().__init__()
        self.att_fc1 = nn.Linear(emb_size, hidden_size)
        self.att_fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x, attn_mask=None):
        """Pool `x` along dimension 1 with learned attention weights.

        Args:
            x: (batch_size, candidate_size, emb_dim)
            attn_mask: optional (batch_size, candidate_size); multiplied into the
                unnormalised weights, so zeros remove a position.
        Returns:
            (batch_size, emb_dim)
        """
        hidden = torch.tanh(self.att_fc1(x))
        weights = self.att_fc2(hidden).exp()  # (batch_size, candidate_size, 1)

        if attn_mask is not None:
            weights = weights * attn_mask.unsqueeze(2)

        # The epsilon keeps the normalisation finite when every weight is masked out.
        normaliser = weights.sum(dim=1, keepdim=True) + 1e-8
        weights = weights / normaliser
        pooled = torch.bmm(x.transpose(1, 2), weights)
        return pooled.squeeze(dim=-1)

class NewsEncoder(nn.Module):
    """Encode a packed token sequence (ids followed by attention mask) into a news vector.

    The backbone is the 'tnlrv3' entry of MODEL_CLASSES, loaded from `path_turing`
    with `args.num_hidden_layers` layers. Its hidden states are attention-pooled and
    projected to `args.news_dim`.
    """

    def __init__(self, args):
        super().__init__()
        config_class, model_class, _ = MODEL_CLASSES['tnlrv3']
        config_file = os.path.join(path_turing, 'unilm2-base-uncased-config.json')
        weights_file = os.path.join(path_turing, 'unilm2-base-uncased.bin')

        self.bert_config = config_class.from_pretrained(
            config_file,
            output_hidden_states=True,
            num_hidden_layers=args.num_hidden_layers)
        self.bert_model = model_class.from_pretrained(weights_file, config=self.bert_config)

        hidden_size = self.bert_config.hidden_size
        self.attn = AttentionPooling(hidden_size, args.news_query_vector_dim)
        self.dense = nn.Linear(hidden_size, args.news_dim)

    def forward(self, x):
        '''
            x: batch_size, word_num * 2 (first half token ids, second half attention mask)
            returns: batch_size, news_dim
        '''
        _, packed_len = x.shape
        seq_len = packed_len // 2
        token_ids = x.narrow(1, 0, seq_len)
        token_mask = x.narrow(1, seq_len, seq_len)

        outputs = self.bert_model(token_ids, token_mask)
        # Presumably outputs[3] is the tuple of per-layer hidden states and this picks
        # the last layer - TODO confirm against the tnlrv3 model's return value.
        word_vecs = outputs[3][self.bert_config.num_hidden_layers]

        # NOTE(review): the pooling is called without the attention mask, so padded
        # positions also receive attention weight.
        pooled = self.attn(word_vecs)
        return self.dense(pooled)

In [None]:
class TitleBodySimModel(nn.Module):
    """Score candidate titles against a body using one shared news encoder."""

    def __init__(self, args):
        super().__init__()
        self.news_encoder = NewsEncoder(args)

    def forward(self, title, body):
        '''
            title: bz, 1+K, MAX_TITLE_WORD * 2
            body: bz, MAX_BODY_WORD * 2
            returns: (scores [bz, 1+K], title_emb [bz, 1+K, emb_dim], body_emb [bz, emb_dim])
        '''
        # The body is encoded first, then all candidate titles in one flat batch.
        body_emb = self.news_encoder(body)

        bz, candi_num, _ = title.shape
        flat_title_emb = self.news_encoder(title.flatten(0, 1))
        title_emb = flat_title_emb.reshape(bz, candi_num, -1)

        # Dot product between every candidate title and its body.
        body_column = body_emb[:, :, None]
        scores = torch.bmm(title_emb, body_column).squeeze(-1)

        return scores, title_emb, body_emb

In [None]:
def kd_ce_loss(logits_S, logits_T, temperature=1):
    '''
    Calculate the cross entropy between logits_S and logits_T
    :param logits_S: Tensor of shape (batch_size, length, num_labels) or (batch_size, num_labels)
    :param logits_T: Tensor of shape (batch_size, length, num_labels) or (batch_size, num_labels)
    :param temperature: A float or a tensor of shape (batch_size, length) or (batch_size,)
    :return: scalar tensor, the soft cross entropy averaged over all leading dimensions
    '''
    # A per-sample temperature has one dimension fewer than the logits; add the
    # trailing label axis so the division broadcasts per sample, not per label.
    if (torch.is_tensor(temperature) and temperature.dim() > 0
            and temperature.dim() == logits_S.dim() - 1):
        temperature = temperature.unsqueeze(-1)
    beta_logits_T = logits_T / temperature
    beta_logits_S = logits_S / temperature
    p_T = F.softmax(beta_logits_T, dim=-1)
    loss = -(p_T * F.log_softmax(beta_logits_S, dim=-1)).sum(dim=-1).mean()
    return loss

def hid_mse_loss(state_S, state_T, mask=None, reduce=True):
    '''
    Mean squared error between the hidden states `state_S` and `state_T`.

    * With `mask`, positions where ``mask == 0`` contribute zero error.
    * With ``reduce=True`` a scalar is returned (with a mask, the sum is divided by
      ``mask.sum() * hidden_size``); with ``reduce=False`` only the last dimension
      is averaged.
    * Both states must already have matching hidden sizes.
    :param torch.Tensor state_S: tensor of shape  (*batch_size*, *length*, *hidden_size*)
    :param torch.Tensor state_T: tensor of shape  (*batch_size*, *length*, *hidden_size*)
    :param torch.Tensor mask:    tensor of shape  (*batch_size*, *length*)
    :param bool reduce:          return a scalar when True, per-position losses otherwise
    '''
    # Fast path: plain scalar MSE over every element.
    if mask is None and reduce:
        return F.mse_loss(state_S, state_T)

    squared_error = F.mse_loss(state_S, state_T, reduction='none')
    if mask is not None:
        squared_error = squared_error * mask.unsqueeze(-1)

    if not reduce:
        return squared_error.mean(dim=-1)

    valid_count = mask.sum() * state_S.size(-1)
    return squared_error.sum() / valid_count

In [None]:
class DistillModel(nn.Module):
    """Student model trained against an ensemble of teacher embeddings.

    The total loss is the sum of:
      * the cross entropy of the student scores against the labels,
      * a soft cross entropy between the student scores and the weighted teacher scores,
      * a weighted MSE between student embeddings and linearly projected teacher embeddings.
    Teachers are weighted per sample by a softmax over their negated cross-entropy loss.
    """

    def __init__(self, args):
        super(DistillModel, self).__init__()
        self.student = TitleBodySimModel(args)
        self.target_loss = nn.CrossEntropyLoss()
        self.distill_loss = kd_ce_loss
        self.emb_loss = hid_mse_loss
        # One projection per teacher, applied to that teacher's embeddings.
        self.transform_matrix = nn.ModuleList([nn.Linear(args.news_dim, args.news_dim) for _ in range(args.num_teachers)])
        for module in self.transform_matrix:
            nn.init.xavier_uniform_(module.weight, gain=1.)
            nn.init.constant_(module.bias, 0.0)

    def forward(self, title, body, labels, teacher_titles, teacher_bodies):
        '''
            title, body: packed student inputs, passed through to the student model
            labels: (batch_size,) index of the positive title
            teacher_titles: [(batch_size, 1+K, news_emb) * num_teachers]
            teacher_bodies: [(batch_size, news_emb) * num_teachers]
            returns: (loss, target_loss, distill_loss, emb_loss, student_score)
        '''
        student_score, student_title, student_body = self.student(title, body)
        target_loss = self.target_loss(student_score, labels)

        teacher_scores, teacher_losses, teacher_MSEs = [], [], []
        for i, (teacher_title, teacher_body) in enumerate(zip(teacher_titles, teacher_bodies)):
            teacher_score = torch.bmm(teacher_title, teacher_body.unsqueeze(dim=-1)).squeeze(dim=-1)
            teacher_loss = F.cross_entropy(teacher_score, labels, reduction='none')
            teacher_scores.append(teacher_score)
            teacher_losses.append(teacher_loss)

            teacher_title_proj = self.transform_matrix[i](teacher_title)
            teacher_body_proj = self.transform_matrix[i](teacher_body)
            # Per-sample MSE: title error averaged over the candidates, plus body error.
            teacher_MSE = \
                self.emb_loss(student_title, teacher_title_proj, reduce=False).mean(dim=-1) + \
                    self.emb_loss(student_body, teacher_body_proj, reduce=False)
            teacher_MSEs.append(teacher_MSE)

        # A lower teacher loss gives a higher weight: (batch_size, num_teachers).
        teacher_losses = - torch.stack(teacher_losses, dim=-1)
        teacher_weights = F.softmax(teacher_losses, dim=-1)

        # (batch_size, 1+K, num_teachers) x (batch_size, num_teachers, 1) -> (batch_size, 1+K)
        teacher_scores = torch.stack(teacher_scores, dim=-1)
        teacher_scores = torch.bmm(teacher_scores, teacher_weights.unsqueeze(dim=-1)).squeeze(dim=-1)
        distill_loss = self.distill_loss(student_score, teacher_scores)

        # The per-teacher MSEs are collected in a Python list; stack them into a
        # (batch_size, num_teachers) tensor before weighting. Multiplying the list
        # by a tensor directly raises a TypeError.
        teacher_MSEs = torch.stack(teacher_MSEs, dim=-1)
        emb_loss = (teacher_MSEs * teacher_weights).sum(dim=-1).mean()

        loss = target_loss + distill_loss + emb_loss
        return loss, target_loss, distill_loss, emb_loss, student_score

In [None]:
distill_model = DistillModel(args)
# Use the GPU when one is available and fall back to CPU otherwise, instead of
# failing on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
distill_model.to(device)

In [None]:
def acc(y_true, y_hat):
    """Fraction of rows whose arg-max over the last dimension of `y_hat` equals `y_true`.

    Returns a 0-dim float tensor detached from the autograd graph.
    """
    predicted = y_hat.argmax(dim=-1)
    num_correct = (predicted == y_true).sum()
    num_total = y_true.shape[0]
    return num_correct.detach().float() / num_total

In [None]:
bert_backbone = distill_model.student.news_encoder.bert_model

# Freeze the entire backbone first ...
for weight in bert_backbone.parameters():
    weight.requires_grad = False

# ... then re-enable gradients only for encoder layers 2 and 3.
for layer_no, encoder_layer in enumerate(bert_backbone.bert.encoder.layer):
    if layer_no not in (2, 3):
        continue
    for weight in encoder_layer.parameters():
        weight.requires_grad = True

# Show which parameters will be trained.
for param_name, weight in distill_model.named_parameters():
    print(param_name, weight.requires_grad)

In [None]:
# Split the parameters into the backbone and everything else so each group can
# use its own learning rate. The backbone ids are collected once into a set;
# rebuilding the id list for every parameter was quadratic in the parameter count.
bert_params = list(distill_model.student.news_encoder.bert_model.parameters())
bert_param_ids = {id(p) for p in bert_params}
# A list rather than a one-shot `filter` iterator, so it can be inspected or reused.
rest_param = [p for p in distill_model.parameters() if id(p) not in bert_param_ids]

optimizer = optim.Adam(
    [{'params': bert_params, 'lr': 1e-6},
     {'params': rest_param, 'lr': 1e-5}]
)

In [None]:
def _to_device(batch_part):
    """Move a collated batch element (a tensor, or a list/tuple of tensors) to `device`."""
    if isinstance(batch_part, (list, tuple)):
        return [part.to(device, non_blocking=True) for part in batch_part]
    return batch_part.to(device, non_blocking=True)


for ep in range(1):
    # Running sums over the epoch; divided by the batch count when reported.
    loss, target_loss, distill_loss, emb_loss = 0.0, 0.0, 0.0, 0.0
    accuary = 0.0
    cnt = 1
    tqdm_util = tqdm(distill_dl)
    distill_model.train()
    for title,body,labels,teacher_title,teacher_body in tqdm_util:
        title = _to_device(title)
        body = _to_device(body)
        labels = _to_device(labels)
        # The dataset returns one array per teacher, so the default collate function
        # yields a *list* of tensors here; calling .cuda() on the list itself raised
        # AttributeError. Each tensor is moved individually instead.
        teacher_title = _to_device(teacher_title)
        teacher_body = _to_device(teacher_body)

        bz_loss, t_loss, d_loss, e_loss, y_hat = distill_model(title, body, labels, teacher_title, teacher_body)

        loss += bz_loss.data.float()
        target_loss += t_loss.data.float()
        distill_loss += d_loss.data.float()
        emb_loss += e_loss.data.float()
        accuary += acc(labels, y_hat)
        optimizer.zero_grad()
        bz_loss.backward()
        optimizer.step()

        # Refresh the progress-bar summary every 10 batches.
        if cnt % 10 == 0:
            tqdm_util.set_description('ed: {}, loss: {:.5f}, t_loss: {:.5f}, d_loss: {:.5f}, e_loss: {:.5f}, acc: {:.5f}'.format(
                cnt * BATCH_SIZE, loss.data / cnt, target_loss.data / cnt, distill_loss.data / cnt, emb_loss.data / cnt, accuary / cnt))

        cnt += 1

    # Save the full distillation model (student plus teacher projections) after the epoch.
    ckpt_path = f'./first_stage_{args.num_hidden_layers}_layer.pt'
    torch.save({'model_state_dict': distill_model.state_dict()}, ckpt_path)
