<a href="https://colab.research.google.com/gist/slachitoff/ebb532421da6ca4fe10911164adeca45/patent-transformer_fine-tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook was submitted by NYU student  Sky Achitoff

In [1]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import TensorDataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from collections import defaultdict
import random

import logging
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
logging.getLogger("transformers").setLevel(logging.WARNING)

In [2]:
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
model = AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

BATCH_SIZE = 32

def encodeText(text, max_length=512):
    encodedDict = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        truncation=True,
        padding='max_length',
        max_length=max_length,
        return_attention_mask=True,
        return_tensors='pt'
    )

    return encodedDict

def encodeData(data, max_length=512):
    encodedData = []
    for example in data:
        text = ' '.join([example[section] for section in ['abstract', 'claims']])
        decision = example['decision']
        if decision == 'ACCEPTED':
            label = 1
            encodedExample = encodeText(text, max_length=max_length)
            encodedData.append((encodedExample['input_ids'], encodedExample['attention_mask'], label))
        elif decision == 'REJECTED':
            label = 0
            encodedExample = encodeText(text, max_length=max_length)
            encodedData.append((encodedExample['input_ids'], encodedExample['attention_mask'], label))
        else:
            continue
    return encodedData

def oversampleData(data):

    classCounts = {}
    for example in data:
        label = example[2]
        if label not in classCounts:
            classCounts[label] = 0
        classCounts[label] += 1

    minCount = min(classCounts.values())
    minClass = None
    for label, count in classCounts.items():
        if count == minCount:
            minClass = label

    oversampledData = []
    for example in data:
        oversampledData.append(example)
        if example[2] == minClass:
            oversampledData.append(example)

    return oversampledData

def getDataLoader(oversampledTrainData, valDataset, BATCH_SIZE):
    trainDataLoader = DataLoader(oversampledTrainData, batch_size=BATCH_SIZE, shuffle=True)
    valDataLoader = None
    
    if valDataset is not None:
        valDataLoader = DataLoader(valDataset, batch_size=BATCH_SIZE)
        valDatasetTensors = []
        for x in valDataset:
            inputIds = torch.tensor(x[0]).squeeze()
            attentionMask = torch.tensor(x[1]).squeeze()
            label = torch.tensor(x[2])
            valDatasetTensors.append((inputIds, attentionMask, label))
        valDatasetTensors = tuple([torch.stack(t) for t in zip(*valDatasetTensors)])
        valDataLoader = DataLoader(TensorDataset(*valDatasetTensors), batch_size=BATCH_SIZE)
    
    return trainDataLoader, valDataLoader



Downloading tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Downloading vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

In [3]:
def evaluate(model, dataloader):
    model.eval()
    totalCount = 0
    totalCorrect = 0
    tp = 0
    fp = 0
    fn = 0

    with torch.no_grad():
        for batch in dataloader:
            batchInputIds = batch[0].to(device)
            batchAttentionMask = batch[1].to(device)
            batchLabels = batch[2].to(device)

            outputs = model(batchInputIds, attention_mask=batchAttentionMask, labels=batchLabels)
            logits = outputs.logits
            predictions = torch.argmax(logits, dim=1)

            totalCount += batchLabels.size(0)
            totalCorrect += torch.sum(predictions == batchLabels)

            tp += ((predictions == 1) & (batchLabels == 1)).sum().item()
            fp += ((predictions == 1) & (batchLabels == 0)).sum().item()
            fn += ((predictions == 0) & (batchLabels == 1)).sum().item()

    accuracy = totalCorrect / totalCount
    
    if (tp + fp) == 0:
        precision = 0
    else:
        precision = tp / (tp + fp)
    
    if (tp + fn) == 0:
        recall = 0
    else:
        recall = tp / (tp + fn)

    print(f'Accuracy = {accuracy}')
    print(f'Precision = {precision}')
    print(f'Recall = {recall}')



In [4]:
from pynvml import *


def print_gpu_utilization():
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"GPU memory occupied: {info.used//1024**2} MB.")


def print_summary(result):
    print(f"Time: {result.metrics['train_runtime']:.2f}")
    print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
    print_gpu_utilization()

In [5]:
def main():

    datasetDict = load_dataset('HUPD/hupd',
                                name='sample',
                                #data_files="https://huggingface.co/datasets/HUPD/hupd/blob/main/hupd_metadata_2022-02-22.feather",
                                train_filing_start_date='2016-01-01',
                                train_filing_end_date='2016-01-21',
                                val_filing_start_date='2016-01-22',
                                val_filing_end_date='2016-01-31')

    #datasetDict = load_dataset('HUPD/hupd', name='sample')
    trainData = datasetDict['train']
    valData = datasetDict['validation']
    trainData = encodeData(trainData)
    valData = encodeData(valData)
    valDataloader = getDataLoader(trainData, valData, BATCH_SIZE)
    trainData = oversampleData(trainData)
    trainDataLoader, valDataLoader = getDataLoader(trainData, valData, BATCH_SIZE)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    lossFn = torch.nn.BCEWithLogitsLoss()
    model.train()
    for epoch in range(40):
        totalLoss = 0
        for batch in trainDataLoader:
            batchInputIds = batch[0].to(device)
            batchAttentionMask = batch[1].to(device)
            batchInputIds = batchInputIds.squeeze(1)
            batchAttentionMask = batchAttentionMask.squeeze(1)
            batchLabels = batch[2].to(device)
            optimizer.zero_grad()
            outputs = model(batchInputIds, attention_mask=batchAttentionMask, labels=batchLabels)
            logits = outputs.logits
            loss = lossFn(logits[:, 1], batchLabels.float())
            loss.backward()
            optimizer.step()
            totalLoss += loss.item()
        print(f'Total loss = {totalLoss}')
        evaluate(model, valDataLoader)
        valAccuracy = evaluate(model, valDataLoader)
        print(f'Epoch {epoch}')

    model.save_pretrained('Patent-Tuned-distilbert')
    tokenizer.save_pretrained('Patent-Tuned-distilbert')


In [6]:
if __name__ == '__main__':
    main()

Downloading builder script:   0%|          | 0.00/14.7k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/10.9k [00:00<?, ?B/s]

Loading dataset with config: PatentsConfig(name='sample', version=0.0.0, data_dir='sample', data_files=None, description='Patent data from January 2016, for debugging')


Downloading data:   0%|          | 0.00/6.67M [00:00<?, ?B/s]

Using metadata file: /home/vscode/.cache/huggingface/datasets/downloads/bac34b767c2799633010fa78ecd401d2eeffd62eff58abdb4db75829f8932710


Downloading data:   0%|          | 0.00/388M [00:00<?, ?B/s]

Reading metadata file: /home/vscode/.cache/huggingface/datasets/downloads/bac34b767c2799633010fa78ecd401d2eeffd62eff58abdb4db75829f8932710
Filtering train dataset by filing start date: 2016-01-01
Filtering train dataset by filing end date: 2016-01-21
Filtering val dataset by filing start date: 2016-01-22
Filtering val dataset by filing end date: 2016-01-31


Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

  inputIds = torch.tensor(x[0]).squeeze()
  attentionMask = torch.tensor(x[1]).squeeze()
