In [None]:
import torch
import torch.nn as nn

In [None]:
import os
# Restrict CUDA to device index 0 for this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
import numpy as np
from tqdm.auto import tqdm

In [None]:
# Directory holding the vocabulary, config and checkpoint files that the cells below load.
path_turing = './unilmv2'

In [None]:
import sys
# Put the local ./PLM-NR directory first on the import path
# (presumably where the `tnlrv3` package imported below lives — TODO confirm).
sys.path.insert(0, './PLM-NR')

In [None]:
from tnlrv3.modeling import TuringNLRv3ForSequenceClassification
from tnlrv3.configuration_tnlrv3 import TuringNLRv3Config
from tnlrv3.tokenization_tnlrv3 import TuringNLRv3Tokenizer

In [None]:
# Build the lower-casing tokenizer from the vocabulary file stored under `path_turing`.
tokenizer = TuringNLRv3Tokenizer.from_pretrained(os.path.join(path_turing, 'unilm2-base-uncased-vocab.txt'),
                                            do_lower_case=True)

In [None]:
# Registry mapping a model key to its (config class, model class, tokenizer class) triple.
# NOTE(review): a later cell runs `from utils import MODEL_CLASSES`, which rebinds this name —
# confirm which definition is meant to be used by `Model`.
MODEL_CLASSES = {
    'tnlrv3': (TuringNLRv3Config, TuringNLRv3ForSequenceClassification, TuringNLRv3Tokenizer),
}

# Pretraining Corpus

In [None]:
# Token length every text is padded/truncated to by the dataset's tokenizer call.
MAX_TEXT_LEN = 512
# Number of samples per DataLoader batch (also used for the progress read-out during training).
BATCH_SIZE = 32

In [None]:
# Read the whole TSV corpus into memory: one raw line per list entry, trailing newline kept.
file = './docs_filter.tsv'
with open(file, encoding='utf-8') as f:
    total_lines = list(f)

In [None]:
# Sanity check: number of lines read from the corpus file.
len(total_lines)

In [None]:
# Collect the article body of every corpus line.
# Each line must have exactly eight tab-separated columns:
#   nid, cate, subcate, title, body, abstract, url, time
# A malformed line raises ValueError naming the offending line, instead of an
# anonymous tuple-unpacking error.
corpus = []
for line_no, line in enumerate(total_lines, start=1):
    fields = line.strip('\n').split('\t')
    if len(fields) != 8:
        raise ValueError(
            'line {} of the corpus has {} tab-separated fields, expected 8'.format(line_no, len(fields))
        )
    # Column index 4 is the body; the other columns are not needed for MLM pretraining.
    corpus.append(fields[4])

# Pretrain Dataset

In [None]:
from torch.utils.data import DataLoader, Dataset
import random

In [None]:
class PretrainDataset(Dataset):
    """Masked-language-modelling dataset over a list of raw text strings.

    Each item is tokenized with the notebook-level ``tokenizer`` (padded/truncated to
    ``MAX_TEXT_LEN``) and then corrupted by :meth:`create_masked_lm_labels`.

    ``__getitem__`` returns ``(processed_ids, attn_mask, masked_lm_labels)`` as numpy arrays.
    """

    # Token id written at masked positions (presumably the [MASK] id of this vocabulary — TODO confirm).
    MASK_TOKEN_ID = 104
    # Largest id (inclusive) a random replacement token may take
    # (presumably vocab_size - 1 — TODO confirm).
    MAX_RANDOM_TOKEN_ID = 30521
    # Fraction of candidate positions selected for prediction.
    MLM_PROBABILITY = 0.15
    # Label stored at positions that are not predicted.
    IGNORE_INDEX = -1

    def __init__(self, corpus):
        """Store the list of texts; ``corpus[i]`` is the raw string for item ``i``."""
        self.corpus = corpus
        self.len = len(corpus)

    def __getitem__(self, idx):
        """Tokenize text ``idx`` and return its masked ids, attention mask and MLM labels."""
        text = self.corpus[idx]
        tokenized_text = tokenizer(text, max_length=MAX_TEXT_LEN, pad_to_max_length=True, truncation=True)
        input_ids = np.array(tokenized_text['input_ids'])
        attn_mask = np.array(tokenized_text['attention_mask'])
        processed_ids, masked_lm_labels = self.create_masked_lm_labels(input_ids)
        return processed_ids, attn_mask, masked_lm_labels

    def __len__(self):
        """Number of texts in the corpus."""
        return self.len

    def create_masked_lm_labels(self, tokens):
        """Corrupt a token-id sequence for MLM training.

        Args:
            tokens: 1-D sequence of token ids. Ids ``> 0`` are counted as real tokens; the
                first and last real positions are never selected.

        Returns:
            ``(processed_ids, masked_lm_labels)``: a corrupted copy of ``tokens`` and an
            int64 array holding the original id at every selected position and
            ``IGNORE_INDEX`` elsewhere. ``tokens`` itself is not modified.

        When there is no candidate position (two or fewer real tokens, e.g. an empty
        text), the sequence is returned unchanged with all labels ignored rather than
        raising from ``random.sample``.
        """
        processed_ids = np.array(tokens)
        masked_lm_labels = np.full(len(processed_ids), self.IGNORE_INDEX, dtype=np.int64)

        total_len = int(np.sum(processed_ids > 0))
        candi_index = list(range(1, total_len - 1))
        if not candi_index:
            return processed_ids, masked_lm_labels

        # At least one prediction per sequence, never more than there are candidates.
        num_to_predict = max(1, int(np.floor(len(candi_index) * self.MLM_PROBABILITY)))
        num_to_predict = min(num_to_predict, len(candi_index))

        selected_index = random.sample(candi_index, k=num_to_predict)

        for i in selected_index:
            original_token = processed_ids[i]
            masked_lm_labels[i] = original_token
            if random.random() < 0.8:
                # 80% of the time, replace with the mask token
                masked_token = self.MASK_TOKEN_ID
            elif random.random() < 0.5:
                # 10% of the time, keep original
                masked_token = original_token
            else:
                # 10% of the time, replace with random word
                masked_token = random.randint(0, self.MAX_RANDOM_TOKEN_ID)
            processed_ids[i] = masked_token

        return processed_ids, masked_lm_labels

In [None]:
# Wrap the collected article bodies in the MLM dataset.
pretrain_ds = PretrainDataset(corpus)

In [None]:
# Shuffled batches of BATCH_SIZE samples, produced by 32 worker processes into pinned memory.
# NOTE(review): num_workers=32 presumably matches the original machine's CPU count — confirm for yours.
pretrain_dl = DataLoader(pretrain_ds, batch_size=BATCH_SIZE, num_workers=32, shuffle=True, pin_memory=True)

# Pretrain Model

In [None]:
from utils import MODEL_CLASSES
from tnlrv3.modeling import BertOnlyMLMHead

def init_weights(module):
    """Initialise a single module in place; intended for use with ``Module.apply``.

    * ``nn.LayerNorm``: weight set to 1, bias set to 0.
    * ``nn.Linear`` / ``nn.Embedding``: weight drawn from N(0, 0.02**2);
      a ``nn.Linear`` bias, when present, is zeroed.
    * Any other module is left untouched.
    """
    if isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
        return

    if not isinstance(module, (nn.Linear, nn.Embedding)):
        return

    module.weight.data.normal_(mean=0.0, std=0.02)
    has_bias = isinstance(module, nn.Linear) and module.bias is not None
    if has_bias:
        module.bias.data.zero_()
        
class Model(torch.nn.Module):
    """Pretrained encoder with a masked-LM head and a cross-entropy MLM loss.

    Args:
        args: object exposing ``config_name``, ``model_name`` and ``num_hidden_layers``.
    """

    def __init__(self, args):
        super().__init__()
        config_cls, encoder_cls, _ = MODEL_CLASSES['tnlrv3']

        self.bert_config = config_cls.from_pretrained(
            args.config_name, num_hidden_layers=args.num_hidden_layers
        )
        self.bert_model = encoder_cls.from_pretrained(args.model_name, config=self.bert_config)

        # The head receives the encoder's word-embedding matrix.
        word_embedding_weight = self.bert_model.bert.embeddings.word_embeddings.weight
        self.cls = BertOnlyMLMHead(self.bert_config, word_embedding_weight)
        # NOTE(review): if BertOnlyMLMHead ties its decoder weight to the matrix passed in,
        # this re-initialisation would also overwrite those embeddings — confirm.
        self.cls.apply(init_weights)

        # Positions labelled -1 do not contribute to the loss.
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=-1)

    def forward(self, input_ids, attention_mask, masked_lm_labels):
        """Return ``(masked_lm_loss, prediction_scores)`` for one batch.

        ``prediction_scores`` is flattened to ``(-1, vocab_size)`` and ``masked_lm_labels``
        to ``(-1,)`` before the loss is computed.
        """
        encoder_outputs = self.bert_model(input_ids, attention_mask)
        # Element 1 of the encoder output is what gets fed to the MLM head.
        sequence_output = encoder_outputs[1]
        prediction_scores = self.cls(sequence_output)

        flat_scores = prediction_scores.view(-1, self.bert_config.vocab_size)
        flat_labels = masked_lm_labels.view(-1)
        masked_lm_loss = self.loss_fn(flat_scores, flat_labels)
        return masked_lm_loss, prediction_scores

In [None]:
class Args:
    """Settings consumed by ``Model``: checkpoint path, config path and encoder depth."""

    def __init__(self):
        def in_turing_dir(filename):
            # Resolve a file name relative to the pretrained-model directory.
            return os.path.join(path_turing, filename)

        self.model_name = in_turing_dir('unilm2-base-uncased.bin')
        self.num_hidden_layers = 12
        self.config_name = in_turing_dir('unilm2-base-uncased-config.json')

In [None]:
# Instantiate the settings object handed to Model.
args = Args()

In [None]:
# Build the model from the config and checkpoint paths in `args`.
pretrain_model = Model(args)

In [None]:
# Train on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) instead of failing when the model is moved.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Move the model to the selected device (as the last expression, the cell also displays the model).
pretrain_model.to(device)

In [None]:
def acc(y_true, y_hat, ignore_index=-1):
    """Accuracy of the arg-max predictions over the labelled positions only.

    Args:
        y_true: integer label tensor; positions equal to ``ignore_index`` are skipped.
        y_hat: score tensor with one extra trailing dimension over the classes;
            the prediction is the arg-max over that last dimension.
        ignore_index: label value marking positions that are not scored (default -1).

    Returns:
        0-dim float tensor ``hits / labelled``; ``0.0`` (not NaN) when no position
        is labelled.
    """
    predictions = torch.argmax(y_hat, dim=-1)
    scored = y_true != ignore_index
    tot = torch.sum(scored)
    hit = torch.sum((y_true == predictions) & scored)
    # Clamp the denominator so a batch with no labelled position yields 0 instead of 0/0.
    return hit.float() / tot.clamp(min=1)

In [None]:
import torch.optim as optim
import torch.nn.functional as F

In [None]:
# Freeze every encoder parameter first ...
for param in pretrain_model.bert_model.parameters():
    param.requires_grad_(False)

# ... then re-enable gradients for encoder layers 9, 10 and 11 only.
for index, layer in enumerate(pretrain_model.bert_model.bert.encoder.layer):
    if 9 <= index <= 11:
        for param in layer.parameters():
            param.requires_grad_(True)

In [None]:
# List every parameter with its requires_grad flag to check which ones will be trained.
for name, p in pretrain_model.named_parameters():
    print(name, p.requires_grad)

In [None]:
# Parameters registered on the model outside `bert_model` get their own learning rate.
# The encoder parameter ids are collected once into a set so each membership test is
# O(1); `rest_param` is a list, so it can be inspected or reused after the optimizer
# has consumed it.
bert_param_ids = {id(p) for p in pretrain_model.bert_model.parameters()}
rest_param = [p for p in pretrain_model.parameters() if id(p) not in bert_param_ids]

# Two parameter groups: encoder at 1e-6, everything else at 1e-5.
optimizer = optim.Adam(
    [{'params': pretrain_model.bert_model.parameters(), 'lr': 1e-6},
     {'params': rest_param, 'lr': 1e-5}]
)

In [None]:
# One epoch of MLM training with a running loss/accuracy read-out on the progress bar.
for ep in range(1):
    # Running sums over the epoch; divided by the number of batches seen when reported.
    loss = 0.0
    accuary = 0.0
    cnt = 0
    tqdm_util = tqdm(pretrain_dl)
    pretrain_model.train()
    for input_ids, attn_mask, mlm_labels in tqdm_util:
        input_ids = input_ids.to(device)
        attn_mask = attn_mask.to(device)
        mlm_labels = mlm_labels.to(device)

        bz_loss, y_hat = pretrain_model(input_ids, attn_mask, mlm_labels)
        # Detach before accumulating so the running sum does not hold on to the graph.
        loss += bz_loss.detach().float()
        bz_acc = acc(mlm_labels, y_hat)
        accuary += bz_acc

        optimizer.zero_grad()
        bz_loss.backward()
        optimizer.step()

        if cnt % 10 == 0:
            # The current batch is already in the sums, so cnt + 1 batches have been
            # seen; dividing by cnt would be a division by zero on the first report
            # and off by one afterwards.
            seen = cnt + 1
            tqdm_util.set_description('ed: {}, train_loss: {:.5f}, acc: {:.5f}'.format(
                seen * BATCH_SIZE, (loss / seen).item(), (accuary / seen).item()))

        cnt += 1


In [None]:
# Save the full model state dict under the 'model_state_dict' key.
ckpt_path = './FP_12_layer.pt'
torch.save({'model_state_dict': pretrain_model.state_dict()}, ckpt_path)