In [4]:
import pandas as pd
import pickle
import numpy as np
from torch.utils.data import Dataset,DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AdamW
from operator import itemgetter
from sklearn.model_selection import StratifiedKFold

import os
# Expose only physical GPU 0 to this process. NOTE(review): this only takes
# effect if it runs before CUDA is initialised in this kernel -- confirm when
# re-running cells out of order.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Make device 0 the current CUDA device for the `.cuda()` calls used below.
torch.cuda.set_device(0)

In [5]:
# Load the label <-> id lookup tables. `label2id` is indexed by label string in
# `str2id_lst`; `len(id2label)` is used as the number of classes by the collate
# function and the classifier head.
# NOTE: pickle can execute arbitrary code on load -- only open trusted files.
# Context managers make sure the file handles are closed right after loading.
with open('temp_results/mini_label2id_dict.pkl','rb') as _label2id_file:
    label2id = pickle.load(_label2id_file)
with open('temp_results/mini_id2label_lst.pkl','rb') as _id2label_file:
    id2label = pickle.load(_id2label_file)

In [3]:
# Train / test splits as DataFrames. The code below reads the 'text' and
# 'cpc_ids' columns from these frames.
train_data = pd.read_csv('data/mini_train_data.csv')
test_data = pd.read_csv('data/mini_test_data.csv')

In [5]:
from transformers import AutoModelForMaskedLM,AutoTokenizer

# Pretrained masked-LM checkpoint; it is passed to PatentClsModel as the
# backbone, and the same `model` object is reused for every fold.
model = AutoModelForMaskedLM.from_pretrained("anferico/bert-for-patents")
# Matching tokenizer, used by `collate_fn` to encode each batch of texts.
tokenizer = AutoTokenizer.from_pretrained("anferico/bert-for-patents")

Some weights of the model checkpoint at anferico/bert-for-patents were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# Data Loader

In [17]:
def str2id_lst(str_label):
    """Map a comma-separated label string to a list of integer ids.

    Each comma-delimited piece is looked up in the module-level `label2id`
    mapping; a piece that is not a key of that mapping raises KeyError.
    """
    return [label2id[piece] for piece in str_label.split(',')]

class PatentDataset(Dataset):
    """Dataset over a DataFrame of patent texts and (optionally) CPC labels.

    Args:
        df: DataFrame with a 'text' column and, when `labeled` is True, a
            'cpc_ids' column holding comma-separated label strings.
        labeled: whether items carry labels. When False the 'cpc_ids' column
            is never read, so it may be absent from `df`.

    Each item is a `(text, label)` pair where `label` is a list of label ids
    (via `str2id_lst`) for labeled data and `None` otherwise.
    """
    def __init__(self,df,labeled = True):
        self.df = df
        self.labeled = labeled

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self,idx):
        row = self.df.iloc[idx]
        # The first three characters of every text are dropped.
        # NOTE(review): presumably a fixed prefix in the raw data -- confirm.
        text = row['text'][3:]

        if not self.labeled:
            # Do not touch 'cpc_ids' here: unlabeled frames may not have it.
            return text,None
        return text,str2id_lst(row['cpc_ids'])
        

In [18]:
# `labeled` defaults to True, so the test set also yields labels; they are used
# further down to compute a test loss for each fold.
test_dataset = PatentDataset(test_data)

In [19]:
def collate_fn(data):
    """Turn a list of `(text, label_ids)` items into model-ready tensors.

    Args:
        data: list of `(text, label)` pairs as produced by PatentDataset.
            `label` is a list of class ids, or None for unlabeled items.

    Returns:
        `(input_ids, attention_mask, token_type_ids, batch_label)` where the
        first three come from the module-level `tokenizer` (every sequence is
        truncated / padded to 500 tokens) and `batch_label` is a float32
        multi-hot tensor of shape `(len(data), len(id2label))`. Rows of
        unlabeled items are all zeros.
    """
    sents = [i[0] for i in data]
    labels = [i[1] for i in data]

    encoded = tokenizer.batch_encode_plus(batch_text_or_text_pairs=sents,
                                          truncation=True,
                                          padding='max_length',
                                          max_length=500,
                                          return_tensors='pt',
                                          return_length=True)
    input_ids = encoded['input_ids']
    attention_mask = encoded['attention_mask']
    token_type_ids = encoded['token_type_ids']

    batch_label = np.zeros((len(labels),len(id2label)),dtype=np.float32)
    for i,_label in enumerate(labels):
        if _label is None:
            # As a NumPy index, None means `np.newaxis`: `batch_label[i, None] = 1`
            # would set the whole row to 1 instead of leaving it empty.
            continue
        batch_label[i,_label]=1

    batch_label = torch.from_numpy(batch_label)

    return input_ids, attention_mask, token_type_ids, batch_label
    

In [20]:
# Test batches of 4, built with the same `collate_fn` as the train / val loaders.
# `shuffle` is not passed, so the DataLoader default applies here.
test_dataloader = DataLoader(dataset = test_dataset,
                            batch_size = 4,
                            collate_fn = collate_fn)

# Define Model

In [21]:
class PatentClsModel(nn.Module):
    """Multi-label classifier: BERT backbone + a small fully-connected head.

    The head consumes the last hidden state of the first token, so the
    backbone's hidden size must be 1024, and it outputs one sigmoid
    probability per entry of the module-level `id2label`.

    Args:
        bert_model: backbone called with `input_ids`, `attention_mask`,
            `token_type_ids` and `output_hidden_states=True`; its output must
            expose `.hidden_states`.
        backbone_fixed: when True the backbone is used as a frozen feature
            extractor: it runs without gradients and is kept in eval mode
            even while the head is training.
    """
    def __init__(self,bert_model,backbone_fixed = True):
        super().__init__()
        self.fc = nn.Sequential(nn.BatchNorm1d(1024),
                               nn.Dropout(0.5),
                               nn.Linear(1024,768),
                               nn.ReLU(),
                               nn.BatchNorm1d(768),
                               nn.Dropout(0.5),
                               nn.Linear(768,len(id2label)))

        self.bert_model = bert_model
        self.sig = nn.Sigmoid()
        self.backbone_fixed = backbone_fixed

        # Explicit initialisation of the head: identity-like batch norm,
        # Kaiming-normal weights and zero biases for the linear layers.
        for module in self.fc:
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                if getattr(module, "weight_v", None) is not None:
                    # Weight-normalised layer: initialise direction and gain
                    # separately. The check is on this layer itself (the
                    # original indexed the unrelated global `model`).
                    assert module.weight_g is not None
                    nn.init.uniform_(module.weight_g, 0, 1)
                    nn.init.kaiming_normal_(module.weight_v)
                else:
                    nn.init.kaiming_normal_(module.weight)
                nn.init.constant_(module.bias, 0)

    def train(self, mode=True):
        """Set train / eval mode; a fixed backbone always stays in eval mode.

        Without this, `.train()` would switch on the backbone's dropout and
        feed noisy features to the head although the backbone is not trained.
        """
        super().train(mode)
        if self.backbone_fixed:
            self.bert_model.eval()
        return self

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return per-class probabilities of shape `(batch, len(id2label))`."""
        # Disable autograd for a fixed backbone; otherwise keep whatever grad
        # mode the caller chose (e.g. an enclosing `torch.no_grad()`).
        track_grad = torch.is_grad_enabled() and not self.backbone_fixed
        with torch.set_grad_enabled(track_grad):
            hidden = self.bert_model(input_ids = input_ids,
                                     attention_mask = attention_mask,
                                     token_type_ids = token_type_ids,
                                     output_hidden_states=True).hidden_states[-1]

        # Classify from the representation of the first token of each sequence.
        logits = self.fc(hidden[:,0])
        return self.sig(logits)
        

# Training

In [22]:
# 5-fold stratified split; the training cell stratifies on the raw 'cpc_ids'
# string of each row.
kfold = StratifiedKFold(n_splits=5)
# Number of training epochs run for every fold.
total_epochs = 30
# One test-set prediction tensor is appended per fold; the training cell also
# uses the current length of this list to derive the fold number.
test_predict_lst = []

In [None]:
from tqdm import tqdm


def _evaluate(net, dataloader, criterion, collect_predictions=False):
    """Run `net` over `dataloader` in eval mode without gradients.

    Returns `(mean_loss, predictions)`. `mean_loss` is weighted by batch size
    so a smaller final batch does not skew it (NaN for an empty loader).
    `predictions` is the concatenation of all batch outputs when
    `collect_predictions` is True, else None.
    """
    total_loss = 0.0
    n_samples = 0
    outputs = []
    net.eval()
    with torch.no_grad():
        for input_ids, attention_mask, token_type_ids, batch_label in tqdm(dataloader):
            input_ids = input_ids.cuda()
            attention_mask = attention_mask.cuda()
            token_type_ids = token_type_ids.cuda()
            batch_label = batch_label.cuda()

            prediction = net(input_ids, attention_mask, token_type_ids)
            n_batch = batch_label.shape[0]
            total_loss += criterion(prediction, batch_label).item() * n_batch
            n_samples += n_batch
            if collect_predictions:
                outputs.append(prediction)

    mean_loss = total_loss / n_samples if n_samples else float('nan')
    # Concatenate once at the end instead of re-allocating on every batch.
    predictions = torch.cat(outputs, dim=0) if outputs else None
    return mean_loss, predictions


# Checkpoints are written below; make sure the target directory exists.
os.makedirs('ckpt/004', exist_ok=True)

for train_index, valid_index in kfold.split(train_data,train_data['cpc_ids']):
    # Folds are numbered from 1, based on how many folds have finished so far.
    fold_id = len(test_predict_lst) + 1
    ckpt_path = 'ckpt/004/best_model_mini_{}.pth'.format(fold_id)

    print('*'*20)
    print(f'Fold{fold_id}')
    print('*'*20)
    train_dataset = PatentDataset(train_data.iloc[train_index])
    val_dataset = PatentDataset(train_data.iloc[valid_index])

    # drop_last stays on for training: BatchNorm1d in train mode cannot handle
    # a trailing batch that contains a single sample.
    train_dataloader = DataLoader(train_dataset,
                                 collate_fn = collate_fn,
                                 batch_size = 4,
                                 shuffle = True,
                                 drop_last = True)
    # Validation must see every sample exactly once, so neither shuffle nor
    # drop; otherwise the val loss is computed on a varying subset.
    val_dataloader = DataLoader(val_dataset,
                               collate_fn = collate_fn,
                               batch_size = 4,
                               shuffle = False,
                               drop_last = False)

    # A fresh head is created per fold; the backbone `model` is shared and frozen.
    patentModel = PatentClsModel(model,backbone_fixed = True).cuda()
    loss_func = nn.BCELoss()
    optimizer = AdamW(patentModel.parameters(), lr=5e-4)

    print('Dataloader Success---------------------')

    best_val_loss = float('inf')
    for epoch in range(total_epochs):
        if epoch%5==0:
            print('|',">" * epoch," "*(total_epochs-epoch),'|')

        patentModel.train()
        for input_ids, attention_mask, token_type_ids, batch_label in tqdm(train_dataloader):
            input_ids = input_ids.cuda()
            attention_mask = attention_mask.cuda()
            token_type_ids = token_type_ids.cuda()
            batch_label = batch_label.cuda()

            prediction = patentModel(input_ids, attention_mask, token_type_ids)
            loss = loss_func(prediction,batch_label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        val_loss, _ = _evaluate(patentModel, val_dataloader, loss_func)

        if epoch%10 == 0:
            print('Epoch {}, val_loss {:.4f}'.format(epoch, val_loss))

        # Keep only the checkpoint with the lowest validation loss of this fold.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(patentModel.state_dict(), ckpt_path)
            print('Best val loss found: ', best_val_loss)

    print('This fold, the best val loss is: ', best_val_loss)

    # Reload the best checkpoint of this fold and predict on the test set.
    patentModel = PatentClsModel(model,backbone_fixed = True).cuda()
    patentModel.load_state_dict(torch.load(ckpt_path))

    test_loss, test_predict = _evaluate(patentModel, test_dataloader, loss_func,
                                        collect_predictions=True)
    print('This fold, the test loss is: ', test_loss)

    test_predict_lst.append(test_predict)

In [None]:
# Persist the list of per-fold test prediction tensors.
# NOTE(review): the file name says "50e" while `total_epochs` is 30 -- confirm
# which is intended.
torch.save(test_predict_lst,'outputs/mini_test_pred_lst/mini_50e_004.pt')

# Last All Last

In [16]:
# patentModel.load_state_dict(torch.load('ckpt/001/best_model.pth'))

<All keys matched successfully>

In [15]:
# torch.save(patentModel.state_dict(), 'ckpt/001/best_model.pth')