In [1]:
import torch
import torch.nn as nn
import os
import numpy as np
import sys
import random
import pickle
import torch.optim as optim
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, Dataset

os.environ["CUDA_VISIBLE_DEVICES"] = "5"

from transformers import BertTokenizer, BertConfig, BertModel
from transformers import AutoTokenizer
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP


In [3]:
tokenizer = AutoTokenizer.from_pretrained('prajjwal1/bert-small')

In [4]:
class Args:
    def __init__(self):
        self.news_query_vector_dim = 200
        self.drop_rate = 0.2
        self.news_dim = 256
        self.T = 500
        self.corpus_path = './docs_filter.tsv'
args = Args()

# Pretrain Dataset

In [5]:
MAX_TITLE_LEN = 24
MAX_BODY_LEN = 512
NPRATIO=9
BATCH_SIZE=32

In [6]:
with open(args.corpus_path, encoding='utf-8') as f:
    total_lines = f.readlines()
len(total_lines)

101527

In [7]:
titles, bodies = [], []
for line in total_lines:
    splited = line.strip('\n').split('\t')
    titles.append(splited[3])
    bodies.append(splited[4])

In [8]:
class PretrainDataset(Dataset):
    def __init__(self, titles, bodies):
        self.titles = titles
        self.bodies = bodies
        self.len = len(titles)
        
    def __getitem__(self, idx):
        select_list = list(range(0, idx)) + list(range(idx+1, self.len))
        neg_idx = random.sample(select_list, NPRATIO)
        neg_titles = [self.titles[i] for i in neg_idx]
        pos_title = self.titles[idx]
        titles = [tokenizer(title, max_length=MAX_TITLE_LEN, padding='max_length',
                            truncation=True) for title in [pos_title] + neg_titles]
        input_titles = np.array([title['input_ids'] + title['attention_mask'] for title in titles])
        body = tokenizer(self.bodies[idx], max_length=MAX_BODY_LEN, padding='max_length',
                         truncation=True)
        input_body = np.array(body['input_ids'] + body['attention_mask'])

        label=0
        return input_titles, input_body, label

    def __len__(self):
        return self.len

In [9]:
pretrain_ds = PretrainDataset(titles, bodies)
pretrain_dl = DataLoader(pretrain_ds, batch_size=BATCH_SIZE, num_workers=32, shuffle=True, pin_memory=True)

# Pretrain Model

In [10]:
class AttentionPooling(nn.Module):
    def __init__(self, emb_size, hidden_size):
        super(AttentionPooling, self).__init__()
        self.att_fc1 = nn.Linear(emb_size, hidden_size)
        self.att_fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x, attn_mask=None):
        """
        Args:
            x: batch_size, candidate_size, emb_dim
            attn_mask: batch_size, candidate_size
        Returns:
            (shape) batch_size, emb_dim
        """
        bz = x.shape[0]
        e = self.att_fc1(x)
        e = nn.Tanh()(e)
        alpha = self.att_fc2(e)
        alpha = torch.exp(alpha)

        if attn_mask is not None:
            alpha = alpha * attn_mask.unsqueeze(2)
        
        alpha = alpha / (torch.sum(alpha, dim=1, keepdim=True) + 1e-8)
        x = torch.bmm(x.permute(0, 2, 1), alpha).squeeze(dim=-1)
        return x

class NewsEncoder(nn.Module):
    def __init__(self, args):
        super(NewsEncoder, self).__init__()
        config_class, model_class = BertConfig, BertModel
        self.bert_config = config_class.from_pretrained(
            'prajjwal1/bert-small', 
            output_hidden_states=True)
        self.bert_model = model_class.from_pretrained(
            'prajjwal1/bert-small', config=self.bert_config)
        self.attn = AttentionPooling(self.bert_config.hidden_size, args.news_query_vector_dim)
        self.dense = nn.Linear(self.bert_config.hidden_size, args.news_dim)

    def forward(self, x):
        '''
            x: batch_size, word_num * 2
            mask: batch_size, word_num
        '''
        batch_size, num_words = x.shape
        num_words = num_words // 2
        text_ids = torch.narrow(x, 1, 0, num_words)
        text_attmask = torch.narrow(x, 1, num_words, num_words)
        word_vecs = self.bert_model(text_ids, text_attmask)[0]
        news_vec = self.attn(word_vecs)
        news_vec = self.dense(news_vec)
        return news_vec

In [16]:
class TitleBodySimModel(nn.Module):
    def __init__(self, args):
        super(TitleBodySimModel, self).__init__()
        self.news_encoder = NewsEncoder(args)
        self.loss = nn.CrossEntropyLoss() 
        
    def forward(self, title, body, labels):
        '''
            title: bz, 1+K, MAX_TITLE_WORD * 2
            body: bz, MAX_BODY_WORD * 2
            labels: bz
        '''
        body_emb = self.news_encoder(body)             #bz,emb_dim
        bz, candi_num, input_num = title.shape
        title = title.reshape(-1, input_num)
        title_emb = self.news_encoder(title)
        title_emb = title_emb.reshape(bz, candi_num, -1) #bz, 1+K, emb_dim
        scores = torch.bmm(title_emb, body_emb.unsqueeze(dim=-1)).squeeze(-1)
        
        loss = self.loss(scores, labels)
        return scores, loss

In [17]:
pretrain_model = TitleBodySimModel(args)
device = torch.device('cuda')
pretrain_model.to(device)

Some weights of the model checkpoint at prajjwal1/bert-small were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


TitleBodySimModel(
  (news_encoder): NewsEncoder(
    (bert_model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 512, padding_idx=0)
        (position_embeddings): Embedding(512, 512)
        (token_type_embeddings): Embedding(2, 512)
        (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=512, out_features=512, bias=True)
                (key): Linear(in_features=512, out_features=512, bias=True)
                (value): Linear(in_features=512, out_features=512, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=512, out_features=512, bias=True

In [18]:
def acc(y_true, y_hat):
    y_hat = torch.argmax(y_hat, dim=-1)
    tot = y_true.shape[0]
    hit = torch.sum(y_true == y_hat)
    return hit.data.float() * 1.0 / tot

In [19]:
optimizer = optim.Adam(
    params=pretrain_model.parameters(), lr=1e-6
)

In [20]:
for ep in range(1):
    loss = 0.0
    accuary = 0.0
    cnt = 1
    tqdm_util = tqdm(pretrain_dl)
    pretrain_model.train()
    for title,body,labels in tqdm_util: 
        title = title.to(device)
        body = body.to(device)
        labels = labels.to(device)
        y_hat, bz_loss = pretrain_model(title, body, labels)
        loss += bz_loss.data.float()
        accuary += acc(labels, y_hat)
        optimizer.zero_grad()
        bz_loss.backward()
        optimizer.step()

        if cnt % 10 == 0:
            tqdm_util.set_description('ed: {}, train_loss: {:.5f}, acc: {:.5f}'.format(cnt * BATCH_SIZE, loss.data / cnt, accuary / cnt))

        if cnt % args.T == 0:
            ckpt_path = f'./BERT_finetune_{cnt}.pt'
            torch.save({'model_state_dict': pretrain_model.state_dict()}, ckpt_path)

        cnt += 1

ckpt_path = './BERT_finetune.pt'
torch.save({'model_state_dict': pretrain_model.state_dict()}, ckpt_path)

  0%|          | 0/3173 [00:02<?, ?it/s]

# Generate Title and Body Representations

In [21]:
ckpt_paths = [
    './BERT_finetune.pt',
    './BERT_finetune_3000.pt'
]

In [22]:
pretrain_model.eval()
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f73bca855e0>

In [23]:
class NewsDataset(Dataset):
    def __init__(self, data, max_len):
        self.data = data
        self.max_len = max_len

    def __getitem__(self, idx):
        res = tokenizer(self.data[idx], max_length=self.max_len, pad_to_max_length=True, truncation=True)
        return np.array(res['input_ids'] + res['attention_mask'])

    def __len__(self):
        return len(self.data)

title_dataset = NewsDataset(titles, MAX_TITLE_LEN)
title_dataloader = DataLoader(title_dataset,
                            batch_size=512,
                            num_workers=32)

body_dataset = NewsDataset(bodies, MAX_BODY_LEN)
body_dataloader = DataLoader(body_dataset,
                            batch_size=512,
                            num_workers=32)

In [24]:
for i, ckpt_path in enumerate(ckpt_paths):
    pretrain_model.load_state_dict(torch.load(ckpt_path))
    title_scoring = []
    with torch.no_grad():
        for input_ids in tqdm(title_dataloader):
            input_ids = input_ids.cuda()
            news_vec = pretrain_model.news_encoder(input_ids)
            news_vec = news_vec.to(torch.device("cpu")).detach().numpy()
            title_scoring.extend(news_vec)

    title_scoring = np.array(title_scoring)
    with open(f'./teacher_title_emb_{i}.pkl', 'wb') as f:
        pickle.dump(title_scoring, f)

RuntimeError: Error(s) in loading state_dict for TitleBodySimModel:
	Missing key(s) in state_dict: "news_encoder.bert_model.embeddings.position_ids", "news_encoder.bert_model.embeddings.word_embeddings.weight", "news_encoder.bert_model.embeddings.position_embeddings.weight", "news_encoder.bert_model.embeddings.token_type_embeddings.weight", "news_encoder.bert_model.embeddings.LayerNorm.weight", "news_encoder.bert_model.embeddings.LayerNorm.bias", "news_encoder.bert_model.encoder.layer.0.attention.self.query.weight", "news_encoder.bert_model.encoder.layer.0.attention.self.query.bias", "news_encoder.bert_model.encoder.layer.0.attention.self.key.weight", "news_encoder.bert_model.encoder.layer.0.attention.self.key.bias", "news_encoder.bert_model.encoder.layer.0.attention.self.value.weight", "news_encoder.bert_model.encoder.layer.0.attention.self.value.bias", "news_encoder.bert_model.encoder.layer.0.attention.output.dense.weight", "news_encoder.bert_model.encoder.layer.0.attention.output.dense.bias", "news_encoder.bert_model.encoder.layer.0.attention.output.LayerNorm.weight", "news_encoder.bert_model.encoder.layer.0.attention.output.LayerNorm.bias", "news_encoder.bert_model.encoder.layer.0.intermediate.dense.weight", "news_encoder.bert_model.encoder.layer.0.intermediate.dense.bias", "news_encoder.bert_model.encoder.layer.0.output.dense.weight", "news_encoder.bert_model.encoder.layer.0.output.dense.bias", "news_encoder.bert_model.encoder.layer.0.output.LayerNorm.weight", "news_encoder.bert_model.encoder.layer.0.output.LayerNorm.bias", "news_encoder.bert_model.encoder.layer.1.attention.self.query.weight", "news_encoder.bert_model.encoder.layer.1.attention.self.query.bias", "news_encoder.bert_model.encoder.layer.1.attention.self.key.weight", "news_encoder.bert_model.encoder.layer.1.attention.self.key.bias", "news_encoder.bert_model.encoder.layer.1.attention.self.value.weight", "news_encoder.bert_model.encoder.layer.1.attention.self.value.bias", "news_encoder.bert_model.encoder.layer.1.attention.output.dense.weight", "news_encoder.bert_model.encoder.layer.1.attention.output.dense.bias", "news_encoder.bert_model.encoder.layer.1.attention.output.LayerNorm.weight", "news_encoder.bert_model.encoder.layer.1.attention.output.LayerNorm.bias", "news_encoder.bert_model.encoder.layer.1.intermediate.dense.weight", "news_encoder.bert_model.encoder.layer.1.intermediate.dense.bias", "news_encoder.bert_model.encoder.layer.1.output.dense.weight", "news_encoder.bert_model.encoder.layer.1.output.dense.bias", "news_encoder.bert_model.encoder.layer.1.output.LayerNorm.weight", "news_encoder.bert_model.encoder.layer.1.output.LayerNorm.bias", "news_encoder.bert_model.encoder.layer.2.attention.self.query.weight", "news_encoder.bert_model.encoder.layer.2.attention.self.query.bias", "news_encoder.bert_model.encoder.layer.2.attention.self.key.weight", "news_encoder.bert_model.encoder.layer.2.attention.self.key.bias", "news_encoder.bert_model.encoder.layer.2.attention.self.value.weight", "news_encoder.bert_model.encoder.layer.2.attention.self.value.bias", "news_encoder.bert_model.encoder.layer.2.attention.output.dense.weight", "news_encoder.bert_model.encoder.layer.2.attention.output.dense.bias", "news_encoder.bert_model.encoder.layer.2.attention.output.LayerNorm.weight", "news_encoder.bert_model.encoder.layer.2.attention.output.LayerNorm.bias", "news_encoder.bert_model.encoder.layer.2.intermediate.dense.weight", "news_encoder.bert_model.encoder.layer.2.intermediate.dense.bias", "news_encoder.bert_model.encoder.layer.2.output.dense.weight", "news_encoder.bert_model.encoder.layer.2.output.dense.bias", "news_encoder.bert_model.encoder.layer.2.output.LayerNorm.weight", "news_encoder.bert_model.encoder.layer.2.output.LayerNorm.bias", "news_encoder.bert_model.encoder.layer.3.attention.self.query.weight", "news_encoder.bert_model.encoder.layer.3.attention.self.query.bias", "news_encoder.bert_model.encoder.layer.3.attention.self.key.weight", "news_encoder.bert_model.encoder.layer.3.attention.self.key.bias", "news_encoder.bert_model.encoder.layer.3.attention.self.value.weight", "news_encoder.bert_model.encoder.layer.3.attention.self.value.bias", "news_encoder.bert_model.encoder.layer.3.attention.output.dense.weight", "news_encoder.bert_model.encoder.layer.3.attention.output.dense.bias", "news_encoder.bert_model.encoder.layer.3.attention.output.LayerNorm.weight", "news_encoder.bert_model.encoder.layer.3.attention.output.LayerNorm.bias", "news_encoder.bert_model.encoder.layer.3.intermediate.dense.weight", "news_encoder.bert_model.encoder.layer.3.intermediate.dense.bias", "news_encoder.bert_model.encoder.layer.3.output.dense.weight", "news_encoder.bert_model.encoder.layer.3.output.dense.bias", "news_encoder.bert_model.encoder.layer.3.output.LayerNorm.weight", "news_encoder.bert_model.encoder.layer.3.output.LayerNorm.bias", "news_encoder.bert_model.pooler.dense.weight", "news_encoder.bert_model.pooler.dense.bias", "news_encoder.attn.att_fc1.weight", "news_encoder.attn.att_fc1.bias", "news_encoder.attn.att_fc2.weight", "news_encoder.attn.att_fc2.bias", "news_encoder.dense.weight", "news_encoder.dense.bias". 
	Unexpected key(s) in state_dict: "model_state_dict". 

In [None]:
for i, ckpt_path in enumerate(ckpt_paths):
    pretrain_model.load_state_dict(torch.load(ckpt_path))

    body_scoring = []
    with torch.no_grad():
        for input_ids in tqdm(body_dataloader):
            input_ids = input_ids.cuda()
            news_vec = pretrain_model.news_encoder(input_ids)
            news_vec = news_vec.to(torch.device("cpu")).detach().numpy()
            body_scoring.extend(news_vec)

    body_scoring = np.array(body_scoring)
    with open(f'./teacher_body_emb_{i}.pkl', 'wb') as f:
        pickle.dump(body_scoring, f)