In [6]:
import torch
import torch.nn as nn
import os
import numpy as np
import sys
import random
import pickle
import torch.optim as optim
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, Dataset

# Put the local ./PLM-NR directory first on the import path; the `utils` and
# `tnlrv3` imports that follow are presumably resolved from there — TODO confirm.
sys.path.insert(0, './PLM-NR')
# Expose only GPU 0 to CUDA. This is set before the first CUDA call in the notebook.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from utils import MODEL_CLASSES
from tnlrv3.tokenization_tnlrv3 import TuringNLRv3Tokenizer


In [None]:
# Load the tokenizer by model identifier.
# NOTE(review): `tokenizer` is assigned again by the following cell (from a local
# vocab file), so whichever of the two cells runs last wins — keep only one.
tokenizer = TuringNLRv3Tokenizer.from_pretrained("unilmv2/unilm2-base-uncased")


In [7]:
# Directory holding the local UniLMv2 files. NewsEncoder reads its config and
# weights from the same directory.
path_turing = './unilmv2'

# NOTE(review): the output recorded for this cell is an OSError rejecting this
# path — confirm the vocab file actually exists at this location before running.
vocab_file = os.path.join(path_turing, 'unilm2-base-uncased-vocab.txt')
tokenizer = TuringNLRv3Tokenizer.from_pretrained(vocab_file, do_lower_case=True)

OSError: Incorrect path_or_model_id: './unilmv2/unilm2-base-uncased-vocab.txt'. Please provide either the path to a local folder or the repo_id of a model on the Hub.

In [None]:
class Args:
    """Hyper-parameters and paths shared by the cells of this notebook.

    Every setting is a keyword argument whose default is the value the
    notebook was written with, so ``Args()`` behaves exactly as before while a
    single setting can be overridden for an experiment, e.g. ``Args(T=100)``.

    Attributes:
        news_query_vector_dim: hidden size handed to the attention-pooling layer.
        drop_rate: dropout rate (not read by any cell visible in this notebook).
        news_dim: output size of the news encoder's final linear layer.
        T: a checkpoint is written every ``T`` training batches.
        corpus_path: tab-separated corpus file read to build titles and bodies.
    """

    def __init__(self, news_query_vector_dim=200, drop_rate=0.2, news_dim=256,
                 T=500, corpus_path='./docs_filter.tsv'):
        self.news_query_vector_dim = news_query_vector_dim
        self.drop_rate = drop_rate
        self.news_dim = news_dim
        self.T = T
        self.corpus_path = corpus_path


args = Args()

# Pretrain Dataset

In [None]:
MAX_TITLE_LEN = 24   # tokenizer max_length applied to titles
MAX_BODY_LEN = 512   # tokenizer max_length applied to bodies
NPRATIO = 9          # negative titles sampled for every positive title
BATCH_SIZE = 32      # batch size of the pre-training DataLoader

In [None]:
# Read the whole corpus into memory: one list entry per line, newline kept.
with open(args.corpus_path, encoding='utf-8') as corpus_file:
    total_lines = list(corpus_file)

# Show the number of corpus lines as the cell output.
len(total_lines)

In [None]:
# Split every corpus line on tabs; column 3 is taken as the title and column 4
# as the body. The two lists stay index-aligned with `total_lines`.
titles = []
bodies = []
for raw_line in total_lines:
    fields = raw_line.strip('\n').split('\t')
    titles.append(fields[3])
    bodies.append(fields[4])

In [None]:
class PretrainDataset(Dataset):
    """Title/body matching dataset with on-the-fly negative sampling.

    Item ``idx`` pairs the body of document ``idx`` with its own title followed
    by ``NPRATIO`` titles of other, randomly drawn documents.

    Args:
        titles: sequence of title strings.
        bodies: sequence of body strings, index-aligned with ``titles``.
    """

    def __init__(self, titles, bodies):
        self.titles = titles
        self.bodies = bodies
        self.len = len(titles)

    @staticmethod
    def _encode(text, max_len):
        """Tokenize ``text`` padded/truncated to ``max_len``; return ids followed by mask."""
        encoded = tokenizer(text, max_length=max_len, pad_to_max_length=True,
                            truncation=True)
        # `+` concatenates the two sequences — assumes the tokenizer returns
        # plain Python lists here; TODO confirm.
        return encoded['input_ids'] + encoded['attention_mask']

    def _sample_negative_indices(self, idx):
        """Draw ``NPRATIO`` distinct document indices, none of them equal to ``idx``."""
        # Sample positions from a range one shorter than the corpus and shift
        # every draw at or after ``idx`` up by one. This selects exactly what
        # sampling from "all indices except idx" would, without materialising
        # an O(corpus size) list on every __getitem__ call.
        draws = random.sample(range(self.len - 1), NPRATIO)
        return [pos if pos < idx else pos + 1 for pos in draws]

    def __getitem__(self, idx):
        """Return ``(input_titles, input_body, label)`` for document ``idx``.

        Returns:
            input_titles: array with ``1 + NPRATIO`` rows, each the token ids of
                one title followed by its attention mask; row 0 is the positive.
            input_body: 1-D array of the body's token ids followed by its mask.
            label: always 0, the row of the positive title.

        Raises:
            ValueError: if the corpus has fewer than ``NPRATIO`` other documents.
        """
        neg_titles = [self.titles[i] for i in self._sample_negative_indices(idx)]
        candidates = [self.titles[idx]] + neg_titles
        input_titles = np.array([self._encode(text, MAX_TITLE_LEN) for text in candidates])
        input_body = np.array(self._encode(self.bodies[idx], MAX_BODY_LEN))

        label = 0
        return input_titles, input_body, label

    def __len__(self):
        return self.len

In [None]:
# Wrap the corpus in the pre-training dataset and serve it in shuffled batches.
pretrain_ds = PretrainDataset(titles, bodies)
pretrain_dl = DataLoader(
    pretrain_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=32,
    pin_memory=True,
)

# Pretrain Model

In [None]:
class AttentionPooling(nn.Module):
    """Additive attention that collapses a sequence of vectors into one vector.

    Args:
        emb_size: size of each input vector.
        hidden_size: size of the hidden projection used to score each position.
    """

    def __init__(self, emb_size, hidden_size):
        super().__init__()
        self.att_fc1 = nn.Linear(emb_size, hidden_size)
        self.att_fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x, attn_mask=None):
        """Pool ``x`` along dimension 1 with learned attention weights.

        Args:
            x: batch_size, candidate_size, emb_dim
            attn_mask: batch_size, candidate_size; positions holding 0 get zero weight
        Returns:
            (shape) batch_size, emb_dim
        """
        hidden = torch.tanh(self.att_fc1(x))
        weights = self.att_fc2(hidden).exp()  # unnormalised score per position

        if attn_mask is not None:
            weights = weights * attn_mask.unsqueeze(2)

        # The epsilon keeps the division finite when every position is masked out.
        weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-8)
        # (B, D, L) @ (B, L, 1) -> (B, D, 1) -> (B, D)
        return torch.bmm(x.transpose(1, 2), weights).squeeze(dim=-1)

class NewsEncoder(nn.Module):
    """Encode a tokenized text into one vector of size ``args.news_dim``.

    The text goes through the 'tnlrv3' model, its per-token vectors are merged
    by attention pooling and the result is projected by a linear layer.
    Config and weights are loaded from files under ``path_turing``.
    """

    def __init__(self, args):
        super().__init__()
        config_class, model_class, tokenizer_class = MODEL_CLASSES['tnlrv3']

        config_file = os.path.join(path_turing, 'unilm2-base-uncased-config.json')
        self.bert_config = config_class.from_pretrained(
            config_file,
            output_hidden_states=True,
            num_hidden_layers=12)

        weights_file = os.path.join(path_turing, 'unilm2-base-uncased.bin')
        self.bert_model = model_class.from_pretrained(weights_file, config=self.bert_config)

        hidden_size = self.bert_config.hidden_size
        self.attn = AttentionPooling(hidden_size, args.news_query_vector_dim)
        self.dense = nn.Linear(hidden_size, args.news_dim)

    def forward(self, x):
        '''
            x: batch_size, word_num * 2  (token ids followed by the attention mask)
            returns: batch_size, news_dim
        '''
        half = x.shape[1] // 2
        text_ids = x[:, :half]
        text_attmask = x[:, half:2 * half]

        outputs = self.bert_model(text_ids, text_attmask)
        # Element 3 of the model output is indexed by layer number; presumably it
        # is the tuple of per-layer hidden states requested through
        # output_hidden_states=True — TODO confirm against the tnlrv3 model.
        word_vecs = outputs[3][self.bert_config.num_hidden_layers]

        # NOTE(review): the attention mask is not forwarded to the pooling layer,
        # so padded positions take part in the pooled vector — confirm intended.
        pooled = self.attn(word_vecs)
        return self.dense(pooled)

In [None]:
class TitleBodySimModel(nn.Module):
    """Score candidate titles against a body and train with cross-entropy.

    One shared NewsEncoder embeds both the body and every candidate title; a
    candidate's score is the dot product of its embedding with the body's.
    """

    def __init__(self, args):
        super().__init__()
        self.news_encoder = NewsEncoder(args)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, title, body, labels):
        '''
            title: bz, 1+K, MAX_TITLE_WORD * 2
            body: bz, MAX_BODY_WORD * 2
            labels: bz
            returns: (scores of shape bz, 1+K; scalar cross-entropy loss)
        '''
        bz, candi_num, input_num = title.shape

        body_emb = self.news_encoder(body)  # bz, emb_dim

        # Encode all candidates of all samples in one pass, then regroup per sample.
        flat_titles = title.reshape(-1, input_num)
        title_emb = self.news_encoder(flat_titles).reshape(bz, candi_num, -1)  # bz, 1+K, emb_dim

        # (bz, 1+K, emb_dim) @ (bz, emb_dim, 1) -> (bz, 1+K)
        scores = torch.bmm(title_emb, body_emb.unsqueeze(dim=-1)).squeeze(-1)

        return scores, self.loss(scores, labels)

In [None]:
# Build the title/body matching model from the shared settings.
pretrain_model = TitleBodySimModel(args)
# Use the GPU when one is visible; fall back to the CPU instead of failing on
# machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pretrain_model.to(device)

In [None]:
def acc(y_true, y_hat):
    """Return the fraction of rows whose highest-scoring column equals the label.

    Args:
        y_true: tensor of integer labels, one per row.
        y_hat: tensor of scores; the arg-max is taken over the last dimension.
    Returns:
        0-dim float tensor: number of correct rows divided by ``y_true.shape[0]``.
    """
    predicted = y_hat.argmax(dim=-1)
    n_correct = (predicted == y_true).sum()
    return n_correct.data.float() / y_true.shape[0]

In [None]:
# Freeze the whole pretrained encoder, then re-enable training for encoder
# layers 9, 10 and 11 only.
bert = pretrain_model.news_encoder.bert_model
for weight in bert.parameters():
    weight.requires_grad = False

for layer_no, block in enumerate(bert.bert.encoder.layer):
    if layer_no not in (9, 10, 11):
        continue
    for weight in block.parameters():
        weight.requires_grad = True

# Print the trainable flag of every parameter as a check.
for param_name, weight in pretrain_model.named_parameters():
    print(param_name, weight.requires_grad)

In [None]:
# Two learning rates: a small one for the pretrained encoder and a larger one
# for every other parameter of the model.
bert_params = list(pretrain_model.news_encoder.bert_model.parameters())
# Build the id set once. The previous filter rebuilt the full id list for every
# parameter it tested, and left `rest_param` as a one-shot iterator.
bert_param_ids = {id(p) for p in bert_params}
rest_param = [p for p in pretrain_model.parameters() if id(p) not in bert_param_ids]

optimizer = optim.Adam(
    [{'params': bert_params, 'lr': 1e-6},
     {'params': rest_param, 'lr': 1e-5}]
)

In [None]:
# One pass over the pre-training data. Running sums of loss and accuracy are
# shown on the progress bar every 10 batches; a numbered checkpoint is written
# every `args.T` batches and a final one after the loop.
for ep in range(1):
    running_loss = 0.0
    running_acc = 0.0
    progress = tqdm(pretrain_dl)
    pretrain_model.train()
    for step, (title, body, labels) in enumerate(progress, start=1):
        title, body, labels = title.to(device), body.to(device), labels.to(device)

        y_hat, bz_loss = pretrain_model(title, body, labels)
        running_loss += bz_loss.data.float()
        running_acc += acc(labels, y_hat)

        optimizer.zero_grad()
        bz_loss.backward()
        optimizer.step()

        if step % 10 == 0:
            # Averages are over all batches processed so far.
            progress.set_description('ed: {}, train_loss: {:.5f}, acc: {:.5f}'.format(
                step * BATCH_SIZE, running_loss.data / step, running_acc / step))

        if step % args.T == 0:
            ckpt_path = f'./DP_12_layer_{step}.pt'
            torch.save({'model_state_dict': pretrain_model.state_dict()}, ckpt_path)

# Final checkpoint. Like the numbered ones, it stores the weights under the
# 'model_state_dict' key.
ckpt_path = './DP_12_layer.pt'
torch.save({'model_state_dict': pretrain_model.state_dict()}, ckpt_path)

# Generate Title and Body Representations

In [None]:
# Checkpoints used to embed the corpus: the final one, followed by three
# step-numbered snapshots. Position in this list becomes the index in the
# output file names.
_SNAPSHOT_STEPS = (61500, 61000, 60500)
ckpt_paths = ['./DP_12_layer.pt'] + [f'./DP_12_layer_{step}.pt' for step in _SNAPSHOT_STEPS]

In [None]:
# Switch the model to evaluation mode for the embedding passes that follow.
pretrain_model.eval()
# Called as a plain function (not as a context manager), this turns autograd off
# for all code that runs after this cell.
torch.set_grad_enabled(False)

In [None]:
class NewsDataset(Dataset):
    """Serve a list of texts as fixed-length tokenized arrays.

    Args:
        data: sequence of strings.
        max_len: tokenizer ``max_length`` every text is padded/truncated to.
    """

    def __init__(self, data, max_len):
        self.data, self.max_len = data, max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return a 1-D array: the text's token ids followed by its attention mask."""
        encoded = tokenizer(self.data[idx], max_length=self.max_len,
                            pad_to_max_length=True, truncation=True)
        ids_then_mask = encoded['input_ids'] + encoded['attention_mask']
        return np.array(ids_then_mask)

# Datasets over all titles and all bodies, each with its own maximum length.
title_dataset = NewsDataset(titles, MAX_TITLE_LEN)
body_dataset = NewsDataset(bodies, MAX_BODY_LEN)

# No shuffling is requested, so batches follow the order of `titles` / `bodies`.
title_dataloader = DataLoader(title_dataset, batch_size=512, num_workers=32)
body_dataloader = DataLoader(body_dataset, batch_size=512, num_workers=32)

In [None]:
# For every checkpoint, embed all titles and pickle the resulting matrix as
# ./teacher_title_emb_<i>.pkl, where <i> is the checkpoint's position in ckpt_paths.
for i, ckpt_path in enumerate(ckpt_paths):
    # The training cell saves checkpoints as {'model_state_dict': state_dict};
    # unwrap that key before loading (a bare state dict is accepted as well).
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    pretrain_model.load_state_dict(checkpoint.get('model_state_dict', checkpoint))
    pretrain_model.eval()
    # Send the batches to wherever the model lives instead of assuming CUDA.
    model_device = next(pretrain_model.parameters()).device

    title_scoring = []
    with torch.no_grad():
        for input_ids in tqdm(title_dataloader):
            input_ids = input_ids.to(model_device)
            news_vec = pretrain_model.news_encoder(input_ids)
            title_scoring.extend(news_vec.detach().cpu().numpy())

    # One row per title, in dataloader order.
    title_scoring = np.array(title_scoring)
    with open(f'./teacher_title_emb_{i}.pkl', 'wb') as f:
        pickle.dump(title_scoring, f)

In [None]:
# For every checkpoint, embed all bodies and pickle the resulting matrix as
# ./teacher_body_emb_<i>.pkl, where <i> is the checkpoint's position in ckpt_paths.
for i, ckpt_path in enumerate(ckpt_paths):
    # The training cell saves checkpoints as {'model_state_dict': state_dict};
    # unwrap that key before loading (a bare state dict is accepted as well).
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    pretrain_model.load_state_dict(checkpoint.get('model_state_dict', checkpoint))
    pretrain_model.eval()
    # Send the batches to wherever the model lives instead of assuming CUDA.
    model_device = next(pretrain_model.parameters()).device

    body_scoring = []
    with torch.no_grad():
        for input_ids in tqdm(body_dataloader):
            input_ids = input_ids.to(model_device)
            news_vec = pretrain_model.news_encoder(input_ids)
            body_scoring.extend(news_vec.detach().cpu().numpy())

    # One row per body, in dataloader order.
    body_scoring = np.array(body_scoring)
    with open(f'./teacher_body_emb_{i}.pkl', 'wb') as f:
        pickle.dump(body_scoring, f)