The BERT+LSTM model is used for processing the claims information from patent documents. 

In [1]:
import torch 
import torch.nn as nn 
import torch.functional as f 
from torch.optim import AdamW 
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler  
from sklearn.model_selection import train_test_split 
import time 
import datetime 
import re 
import math 
import os 
import json 
from tqdm import tqdm 
import numpy as np 
import pandas as pd 
import seaborn as sns
import matplotlib.pyplot as plt
from transformers import BigBirdModel, BigBirdTokenizer, get_linear_schedule_with_warmup, BertModel, BertTokenizer

# Inspect Sample Data

In [2]:
files = os.listdir('sample_data') 
print("number of sample files = {}".format(len(files)))

number of sample files = 1471


In [3]:
with open('sample_data/'+files[0]) as f: 
    data = json.load(f) 
    print(data.keys())

dict_keys(['patent1id', 'patent2id', 'fullpatent1id', 'fullpatent2id', 'patent1abstract', 'patent2abstract', 'patent1claims', 'patent2claims'])


In [4]:
sample_train = [] 

cnt = 0 

for idx, file in enumerate(files): 
    filename = 'sample_data/' + file
    if '.json' not in filename: 
        continue
    try: 
        with open(filename) as f:  
            data = json.load(f) 
            sample_train.append(data)
    except Exception as e: 
        print("problem occured with file {}".format(file))
        print(e)
        cnt += 1 

problem occured with file 12001939_20030108880.json
Extra data: line 89 column 2 (char 14144)
problem occured with file 12126663_20080250190.json
Extra data: line 85 column 2 (char 20614)
problem occured with file 12001631_5382185.json
Extra data: line 106 column 2 (char 29769)
problem occured with file 12001003_4589656.json
Extra data: line 114 column 2 (char 25639)
problem occured with file 12002104_6609036.json
Extra data: line 191 column 2 (char 37994)
problem occured with file 12002154_20010029974.json
Extra data: line 114 column 2 (char 36254)
problem occured with file 12001021_20060183900.json
Extra data: line 395 column 2 (char 126007)
problem occured with file 12002137_7338160.json
Extra data: line 77 column 2 (char 15328)
problem occured with file 12000641_20070108429.json
Extra data: line 125 column 2 (char 30036)
problem occured with file 12001962_20040263061.json
Extra data: line 104 column 2 (char 31536)
problem occured with file 12000959_20050276053.json
Extra data: line

problem occured with file 12001062_20060257718.json
Extra data: line 49 column 2 (char 14449)
problem occured with file 12126762_6777660.json
Extra data: line 80 column 2 (char 16719)
problem occured with file 12000298_6851645.json
Extra data: line 120 column 2 (char 29592)
problem occured with file 12126762_6587142.json
Extra data: line 75 column 2 (char 12556)
problem occured with file 12000716_20030156467.json
Extra data: line 172 column 2 (char 34440)
problem occured with file 12002154_3879229.json
Extra data: line 47 column 2 (char 31372)


In [5]:
print("number of problematic files = {}/{}".format(cnt, len(files)))

number of problematic files = 105/1471


# Preprocess Data

In [6]:
def split_text(s, overlap, chunk_size): 
    total = [] 
    partial = [] 
    if len(s.split()) // (chunk_size - overlap) > 0: 
        n = len(s.split()) // (chunk_size - overlap) 
    else: 
        n = 1 
    for w in range(n): 
        if w == 0: 
            partial = s.split()[:chunk_size]
            total.append(" ".join(partial)) 
        else: 
            partial = s.split()[w*(chunk_size - overlap):w*(chunk_size - overlap) + chunk_size] 
            total.append(" ".join(partial))  
    return total 

In [7]:
def chunk(in_string,num_chunks):
    chunk_size = len(in_string)//num_chunks
    if len(in_string) % num_chunks: chunk_size += 1
    iterator = iter(in_string)
    for _ in range(num_chunks):
        accumulator = list()
        for _ in range(chunk_size):
            try: accumulator.append(next(iterator))
            except StopIteration: break
        yield ''.join(accumulator)

In [8]:
### define tokenizer ### 
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [9]:
def split_into_parts(s, n): 
    splitted_ids, splitted_masks = [], [] 
    encoded_dict = tokenizer.encode_plus(
        text = s, 
        add_special_tokens = True, 
        pad_to_max_length = False,  
        return_attention_mask = True 
    ) 
    input_id = encoded_dict['input_ids'] 
    attention_mask = encoded_dict['attention_mask'] 
    
    ### split into n chunks ### 
    splitted_input_id = np.array_split(np.asarray(input_id), n) 
    splitted_attention_mask = np.array_split(np.asarray(attention_mask), n) 
    
    return splitted_input_id, splitted_attention_mask
    

In [10]:
splitted_ids1, splitted_ids2 = [], [] 
splitted_masks1, splitted_masks2 = [], []   
labels = [] 
longcnts = 0 
for i in tqdm(range(len(sample_train)), position=0, leave=True):  

    text = sample_train[i]['patent1claims']
    text = ' '.join(text)
    splitted_id, splitted_attention_mask = split_into_parts(text, 50)
    splitted_id1, splitted_mask1 = [], []  
    for j in range(len(splitted_id)): 
        if len(splitted_id[j]) < 100: 
            splitted_id1.append(np.concatenate([splitted_id[j], np.zeros((100-len(splitted_id[j])))]))
            splitted_mask1.append(np.concatenate([splitted_attention_mask[j], np.zeros((100-len(splitted_attention_mask[j])))]))
        elif len(splitted_id[j]) >= 100: 
            splitted_id1.append(splitted_id[j][:100]) 
            splitted_mask1.append(splitted_attention_mask[j][:100])
    splitted_ids1.append(splitted_id1) 
    splitted_masks1.append(splitted_mask1)
        
    text2 = sample_train[i]['patent2claims']
    text2 = ' '.join(text2) 
    splitted_id, splitted_attention_mask = split_into_parts(text2, 50)
    splitted_id2, splitted_mask2 = [], [] 
    for j in range(len(splitted_id)): 
        if len(splitted_id[j]) < 100: 
            splitted_id2.append(np.concatenate([splitted_id[j], np.zeros((100-len(splitted_id[j])))])) 
            splitted_mask2.append(np.concatenate([splitted_attention_mask[j], np.zeros((100-len(splitted_attention_mask[j])))]))
        elif len(splitted_id[j]) >= 100:
            splitted_id2.append(splitted_id[j][:100])  
            splitted_mask2.append(splitted_attention_mask[j][:100])  
    
    
    splitted_ids2.append(splitted_id2) 
    splitted_masks2.append(splitted_mask2)
    
    ### currently all data are positive samples ### 
    labels.append(1.0)
    

  0%|          | 0/1365 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1959 > 512). Running this sequence through the model will result in indexing errors
100%|██████████| 1365/1365 [02:19<00:00,  9.75it/s]


In [11]:
### shape should be (batch, 50, 100) ### 

splitted_ids1 = torch.tensor(splitted_ids1, dtype=int) 
splitted_ids2 = torch.tensor(splitted_ids2, dtype=int) 
splitted_masks1 = torch.tensor(splitted_masks1, dtype=int) 
splitted_masks2 = torch.tensor(splitted_masks2, dtype=int)
labels = torch.tensor(labels, dtype=float)
labels = torch.reshape(labels, (-1,1))

splitted_ids1.shape, splitted_ids2.shape, splitted_masks1.shape, splitted_masks2.shape, labels.shape

(torch.Size([1365, 50, 100]),
 torch.Size([1365, 50, 100]),
 torch.Size([1365, 50, 100]),
 torch.Size([1365, 50, 100]),
 torch.Size([1365, 1]))

# Create DataLoader

In [12]:
train_claims1_ids, val_claims1_ids, train_claims1_attention_masks, val_claims1_attention_masks = train_test_split(splitted_ids1, 
                                                                                                                  splitted_masks1, 
                                                                                                                  random_state=888, 
                                                                                                                  test_size=0.1)

train_claims2_ids, val_claims2_ids, train_claims2_attention_masks, val_claims2_attention_masks = train_test_split(splitted_ids2,
                                                                                                                  splitted_masks2,
                                                                                                                  random_state=888, 
                                                                                                                  test_size=0.1)

_, _, train_labels, val_labels = train_test_split(splitted_ids1, 
                                                  labels, 
                                                  random_state=888, 
                                                  test_size=0.1)


train_claims1_ids.shape, val_claims1_ids.shape, train_claims1_attention_masks.shape, val_claims1_attention_masks.shape, train_claims2_ids.shape, val_claims2_ids.shape, train_claims2_attention_masks.shape, val_claims2_attention_masks.shape, train_labels.shape, val_labels.shape

(torch.Size([1228, 50, 100]),
 torch.Size([137, 50, 100]),
 torch.Size([1228, 50, 100]),
 torch.Size([137, 50, 100]),
 torch.Size([1228, 50, 100]),
 torch.Size([137, 50, 100]),
 torch.Size([1228, 50, 100]),
 torch.Size([137, 50, 100]),
 torch.Size([1228, 1]),
 torch.Size([137, 1]))

In [13]:
batch_size = 2
train_data = TensorDataset(train_claims1_ids, 
                           train_claims1_attention_masks, 
                           train_claims2_ids, 
                           train_claims2_attention_masks, 
                           train_labels)

train_sampler = RandomSampler(train_data) 

train_dataloader = DataLoader(train_data, sampler = train_sampler, batch_size = batch_size) 

In [14]:
validation_data = TensorDataset(val_claims1_ids, 
                                val_claims1_attention_masks, 
                                val_claims2_ids, 
                                val_claims2_attention_masks,  
                                val_labels) 

val_sampler = SequentialSampler(validation_data) 

val_dataloader = DataLoader(validation_data, sampler = val_sampler, batch_size = batch_size)

# Claims model

In [15]:
class BERT_LSTM(nn.Module): 
    def __init__(self): 
        super(BERT_LSTM, self).__init__() 
        self.bert1 = BertModel.from_pretrained("bert-base-uncased") 
        self.bert2 = BertModel.from_pretrained("bert-base-uncased") 
        self.lstm1 = nn.LSTM(768, 256, batch_first=True, bidirectional=True) 
        self.lstm2 = nn.LSTM(768, 256, batch_first=True, bidirectional=True) 
        self.fc1 = nn.Linear(1024, 256) 
        self.batchnorm1 = nn.BatchNorm1d(256) 
        self.fc2 = nn.Linear(256, 64) 
        self.batchnorm2 = nn.BatchNorm1d(64) 
        self.fc3 = nn.Linear(64, 1)
        self.activation = nn.Sigmoid() 
        self.seq_len = 50
        
    def forward(self, ids1, masks1, ids2, masks2): 
        first_seq = [] 
        batch_size = ids1.shape[0]
        for i in range(self.seq_len): 
            outputs1 = self.bert1(input_ids = ids1[:,i,:], 
                                  attention_mask = masks1[:,i,:]) 
            pooler1 = outputs1.pooler_output
            pooler1 = pooler1.flatten()
            first_seq.append(pooler1)
        first_seq = torch.stack(first_seq) 
        first_seq = torch.reshape(first_seq, (batch_size, self.seq_len, 768))
        # print(first_seq.shape)
        
        second_seq = [] 
        for i in range(self.seq_len): 
            outputs2 = self.bert2(input_ids = ids2[:,i,:], 
                                  attention_mask = masks2[:,i,:])
            pooler2 = outputs2.pooler_output
            pooler2 = pooler2.flatten() 
            second_seq.append(pooler2) 
        second_seq = torch.stack(second_seq) 
        second_seq = torch.reshape(second_seq, (batch_size, self.seq_len, 768))
        # print(second_seq.shape)
        
        lstm1_output, (h, c) = self.lstm1(first_seq) 
        lstm2_output, (h, c) = self.lstm2(second_seq) 
        
        lstm1_output = lstm1_output[:,-1,:] 
        lstm2_output = lstm2_output[:,-1,:] 
        
        # print(lstm1_output.shape, lstm2_output.shape) 
        
        hidden = torch.cat((lstm1_output, lstm2_output), axis=1) 
        # print("hidden shape = {}".format(hidden.shape)) 
        fc1 = self.fc1(hidden) 
        bn1 = self.batchnorm1(fc1) 
        fc2 = self.fc2(bn1)
        bn2 = self.batchnorm2(fc2) 
        fc3 = self.activation(self.fc3(bn2)) 
        return fc3 

In [16]:
model = BERT_LSTM() 
model.cuda() 

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_

BERT_LSTM(
  (bert1): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
 

In [32]:
cnt = 0
device = torch.device('cuda')
for batch in train_dataloader: 
    if cnt > 0: 
        break 
    batch = tuple(t.to(device) for t in batch) 
    
    b_ids1, b_masks1, b_ids2, b_masks2, b_labels = batch  
    outputs = model(ids1 = b_ids1, 
                    masks1 = b_masks1, 
                    ids2 = b_ids2,  
                    masks2 = b_masks2) 
    
    print(outputs)
    
    cnt += 1


2
tensor([[0.4926],
        [0.5562]], device='cuda:0', grad_fn=<SigmoidBackward>)


# Train

In [17]:
def flat_accuracy(preds, labels):  
    pred_flat = np.argmax(preds, axis=1).flatten() 
    labels_flat = labels.flatten() 
    return np.sum(pred_flat == labels_flat) / len(labels_flat) 

def format_time(elapsed):  
    elapsed_rounded = int(round(elapsed))
    return str(datetime.timedelta(seconds=elapsed_rounded)) 

In [18]:
class EarlyStopping: 
    ''' if validation loss does not decrease anymore, we stop training '''
    def __init__(self, patience, verbose, delta, path): 
        self.patience = patience 
        self.verbose = verbose 
        self.counter = 0 
        self.best_score = None 
        self.early_stop = False 
        self.val_loss_min = np.Inf 
        self.delta = delta 
        self.path = path 
    
    def __call__(self, val_loss, model): 
        score = -val_loss 
        if self.best_score is None:  
            self.best_score = score 
            self.save_checkpoint(val_loss, model) 
        elif score < self.best_score + self.delta: 
            self.counter += 1 
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience: 
                self.early_stop = True 
        else: 
            self.best_score = score 
            self.save_checkpoint(val_loss, model) 
            self.counter = 0 
    
    def save_checkpoint(self, val_loss, model): 
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [19]:
device = torch.device('cuda')

optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)

epochs = 2

total_steps = len(train_dataloader) * epochs 

scheduler = get_linear_schedule_with_warmup(optimizer, 
                                            num_warmup_steps = 0, 
                                            num_training_steps = total_steps)

train_losses, val_losses = [], [] 
### binary crossentropy loss ### 
criterion = nn.BCELoss() 

### early stopping ### 
early_stopping = EarlyStopping(patience = 3, verbose = True, delta=0, path="BERT_CLAIMS.pt")
tolerance = True 

### initialize gradient ###
model.zero_grad() 

for epoch_i in range(0, epochs):  
    ### Training ### 
    print("")
    print("===== Epoch {:} / {:} =====".format(epoch_i + 1, epochs))
    print("Training...")
    t0 = time.time() 
    total_loss = 0 
    model.train() 
    for step, batch in enumerate(train_dataloader): 
        if step%20 == 0 and not step == 0: 
            elapsed = format_time(time.time() - t0)
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))
            print('  current average loss = {}'.format(total_loss / step))
        
        batch = tuple(t.to(device) for t in batch) 
        
        b_ids1, b_masks1, b_ids2, b_masks2, b_labels = batch 
        outputs = model(ids1 = b_ids1, 
                        masks1 = b_masks1, 
                        ids2 = b_ids2, 
                        masks2 = b_masks2) 
        
        
        
        loss = criterion(outputs.float(), b_labels.float()) 
        total_loss += loss.item() 
        loss.backward() 
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 
        optimizer.step() 
        scheduler.step() 
        model.zero_grad() 
        
    avg_train_loss = total_loss / len(train_dataloader) 
    train_losses.append(avg_train_loss)
    print("")
    print(" Average Training Loss: {}".format(avg_train_loss)) 
    print(" Training epoch took: {:}".format(format_time(time.time() - t0))) 
    ### Validation ### 
    print("")
    print("Running Validation") 
    t0 = time.time() 
    model.eval() 
    
    eval_loss = 0 
    
    for batch in val_dataloader: 
        batch = tuple(t.to(device) for t in batch) 
        b_ids1, b_masks1, b_ids2, b_masks2, b_labels = batch 
        with torch.no_grad(): 
            outputs = model(ids1 = b_ids1, 
                            masks1 = b_masks1, 
                            ids2 = b_ids2, 
                            masks2 = b_masks2)
        
        loss = criterion(outputs.float(), b_labels.float())
        eval_loss += loss.item() 
        
    avg_val_loss = eval_loss / len(val_dataloader)  
    val_losses.append(avg_val_loss)
    print(" Average validation loss: {}".format(avg_val_loss))  
    if tolerance == True:  
        early_stopping(avg_val_loss, model) 
    elif tolerance == False: 
        if avg_val_loss == np.min(val_losses): 
            print("saving best checkpoint after early stopping!") 
            torch.save(model.state_dict(), "BERT_CLAIMS_" + str(epoch_i))
    
    if early_stopping.early_stop: 
        print("We are out of patience") 
        tolerance = False 


===== Epoch 1 / 2 =====
Training...
  Batch    20  of    614.    Elapsed: 0:01:20.
  current average loss = 0.7439177662134171
  Batch    40  of    614.    Elapsed: 0:02:31.
  current average loss = 0.7469433724880219
  Batch    60  of    614.    Elapsed: 0:03:41.
  current average loss = 0.7452892124652862
  Batch    80  of    614.    Elapsed: 0:04:55.
  current average loss = 0.7434857599437237
  Batch   100  of    614.    Elapsed: 0:06:04.
  current average loss = 0.739684824347496
  Batch   120  of    614.    Elapsed: 0:07:13.
  current average loss = 0.7373640323678653
  Batch   140  of    614.    Elapsed: 0:08:22.
  current average loss = 0.7356224200555257
  Batch   160  of    614.    Elapsed: 0:09:38.
  current average loss = 0.7345329008996486
  Batch   180  of    614.    Elapsed: 0:10:52.
  current average loss = 0.7349275218115913
  Batch   200  of    614.    Elapsed: 0:12:15.
  current average loss = 0.733034462928772
  Batch   220  of    614.    Elapsed: 0:13:26.
  curren