# Train a BERT-flow model for sentence embeddings (3 epochs)

In [22]:
import pandas as pd
import pickle
import numpy as np
from torch.utils.data import Dataset,DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AdamW
from operator import itemgetter
from sklearn.model_selection import StratifiedKFold

import os
# Expose only GPU 0 to this process and select it as the current CUDA device.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
torch.cuda.set_device(0)

In [23]:
# Load the label vocabularies. `label2id` is indexed by label string and
# `id2label` is used for its length (number of labels) further down.
# NOTE(review): pickle.load can execute arbitrary code -- only load files
# produced by this project.
with open('../temp_results/mini_label2id_dict.pkl', 'rb') as label2id_file:
    label2id = pickle.load(label2id_file)
with open('../temp_results/mini_id2label_lst.pkl', 'rb') as id2label_file:
    id2label = pickle.load(id2label_file)

In [24]:
# Load the Patent14K train and test splits as DataFrames.
train_data = pd.read_csv('../data/Patent14K/train.csv')
test_data = pd.read_csv('../data/Patent14K/test.csv')

# Define Model

In [25]:
from tflow_utils import TransformerGlow, AdamWeightDecayOptimizer
from transformers import AutoTokenizer

model_name_or_path = 'anferico/bert-for-patents'

# pooling may be 'mean', 'max', 'cls' or 'first-last-avg'
# (mean pooling over the first and the last layers).
bertflow = TransformerGlow(model_name_or_path, pooling='first-last-avg')
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

# Only the parameters inside bertflow.glow are handed to the optimizer, so the
# Transformer stays frozen during training. Parameters whose name contains one
# of the `no_decay` fragments are excluded from weight decay.
no_decay = ["bias", "LayerNorm.weight"]
decay_params = []
no_decay_params = []
for param_name, param in bertflow.glow.named_parameters():
    if any(fragment in param_name for fragment in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

optimizer_grouped_parameters = [
    {"params": decay_params, "weight_decay": 0.01},
    {"params": no_decay_params, "weight_decay": 0.0},
]
optimizer = AdamWeightDecayOptimizer(
    params=optimizer_grouped_parameters,
    lr=1e-3,
    eps=1e-6,
)
num_epochs = 3


Some weights of the model checkpoint at anferico/bert-for-patents were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# Data Loader

In [26]:
def str2id_lst(str_label):
    """Turn a comma-separated label string into a list of label ids.

    The string is split on ',' and every piece is looked up, as-is (no
    whitespace stripping), in the module-level ``label2id`` mapping.

    Args:
        str_label: labels joined by commas, e.g. ``"A,B"``.

    Returns:
        list with one ``label2id`` entry per piece, in input order.
    """
    return [label2id[piece] for piece in str_label.split(',')]

class PatentDataset(Dataset):
    """Dataset over a patent DataFrame yielding ``(text, label_ids)`` pairs.

    Args:
        df: DataFrame with a ``text`` column and, when ``labeled`` is true,
            a ``cpc_ids`` column of comma-separated label strings.
        labeled: when False the ``cpc_ids`` column is never read and ``None``
            is returned in place of the label ids, so frames without labels
            can be used.
    """

    def __init__(self,df,labeled = True):
        self.df = df
        self.labeled = labeled

    def __len__(self):
        """Return the number of rows in the wrapped frame."""
        return self.df.shape[0]

    def __getitem__(self,idx):
        """Return ``(text, label_ids)`` for row ``idx`` (``label_ids`` is None if unlabeled)."""
        row = self.df.iloc[idx]
        # The first three characters of the text are dropped.
        # NOTE(review): presumably a fixed prefix in the raw data -- confirm.
        text = row['text'][3:]

        if not self.labeled:
            # Do not touch 'cpc_ids' here: unlabeled frames may not have it.
            return text, None
        return text, str2id_lst(row['cpc_ids'])
        

In [27]:
# Wrap the training frame; `labeled` defaults to True, so each item carries its label ids.
train_dataset = PatentDataset(train_data)

In [28]:
def collate_fn(data):
    """Tokenize a batch of ``(text, label_ids)`` pairs and build multi-hot labels.

    Args:
        data: list of ``(text, label_ids)`` tuples; ``label_ids`` is a list of
            integer label indices, or ``None`` for an unlabeled sample.

    Returns:
        Tuple ``(input_ids, attention_mask, token_type_ids, batch_label)``.
        The first three come straight from the tokenizer output;
        ``batch_label`` is a float32 tensor of shape
        ``(len(data), len(id2label))`` with ones at each sample's label
        indices. Unlabeled samples get an all-zero row.
    """
    sents = [sample[0] for sample in data]
    labels = [sample[1] for sample in data]

    encoded = tokenizer.batch_encode_plus(batch_text_or_text_pairs=sents,
                                          truncation=True,
                                          padding='longest',
                                          max_length=512,
                                          return_tensors='pt',
                                          return_length=True,
                                          )
    input_ids = encoded['input_ids']
    attention_mask = encoded['attention_mask']
    token_type_ids = encoded['token_type_ids']

    batch_label = np.zeros((len(labels), len(id2label)))
    for row, label_ids in enumerate(labels):
        # A None label must be skipped: as a NumPy index it acts as
        # np.newaxis, so `batch_label[row, None] = 1` would set the whole row.
        if label_ids is None:
            continue
        batch_label[row, label_ids] = 1

    batch_label = torch.tensor(batch_label, dtype=torch.float32)

    return input_ids, attention_mask, token_type_ids, batch_label
    

In [29]:
# Shuffled batches of 32 samples, tokenized and label-encoded by collate_fn.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
)

# Training

In [30]:
from tqdm import tqdm

# Move the model to the GPU once; repeating it every epoch is redundant.
bertflow = bertflow.cuda()

for epoch in range(num_epochs):
    train_loss = 0
    bertflow.train()
    # The batch also carries token_type_ids and labels, but neither is passed
    # to bertflow, so they are left on the CPU. No loop index is bound, which
    # avoids overwriting the builtin `iter` in the kernel namespace.
    for input_ids, attention_mask, _token_type_ids, _batch_label in tqdm(train_dataloader):
        input_ids = input_ids.cuda()
        attention_mask = attention_mask.cuda()
        z, loss = bertflow(input_ids, attention_mask, return_loss=True)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.detach().item()

    # train_loss is the sum of the per-batch losses over the epoch, not the mean.
    print(f'epoch:{epoch+1}/{num_epochs},the training loss is {train_loss}')

bertflow.save_pretrained('output')  # Save model

100%|██████████| 438/438 [03:56<00:00,  1.85it/s]


epoch:1/3,the training loss is -456.16820472478867


100%|██████████| 438/438 [03:58<00:00,  1.83it/s]


epoch:2/3,the training loss is -508.12320351600647


100%|██████████| 438/438 [03:57<00:00,  1.84it/s]


epoch:3/3,the training loss is -525.2040989398956


In [31]:
# Reload the flow model from the 'output' path written by save_pretrained above.
bertflow = TransformerGlow.from_pretrained('output')  # Load model

In [16]:
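# NOTE(review): leftover from a different experiment -- `patentModel` is not defined
# in the visible notebook and the execution count is out of order; confirm and remove.
# patentModel.load_state_dict(torch.load('ckpt/001/best_model.pth'))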

<All keys matched successfully>

In [15]:
# torch.save(patentModel.state_dict(), 'ckpt/001/best_model.pth')