In [1]:
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification, AutoModel
from transformers import (
    DataCollatorWithPadding,
    Trainer,
    default_data_collator,
    set_seed,
    TrainingArguments,
    HfArgumentParser,
    EvalPrediction,
)
from datasets import load_dataset
import random
import numpy as np
import torch
import evaluate

import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import evaluate


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# data parameters
dataset_name="fancyzhx/ag_news"
text_column_name = "text"

# model parameters
model_name_or_path='prajjwal1/bert-mini'
use_fast_tokenizer=True
finetuning_task="text-classification",
max_seq_length=512

# training parameters
pad_to_max_length = True
max_train_samples=120000
fp16 = False

## Load data



In [None]:
raw_datasets = load_dataset(dataset_name)
label_list = raw_datasets['train'].unique("label")
# we will treat the label list as a list of string instead of int, consistent with model.config.label2id
label_list = [str(label) for label in label_list]
label_list.sort()
num_labels = len(label_list)


tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

# Padding strategy
if pad_to_max_length:
    padding = "max_length"
else:
    # We will pad later, dynamically at batch creation, to the max sequence length in each batch
    padding = False
    

label_to_id = {v: i for i, v in enumerate(label_list)}
# # update config with label infos
# if model.config.label2id != label_to_id:

#     model.config.label2id = label_to_id
#     model.config.id2label = {id: label for label, id in label_to_id.items()}


max_seq_length = min(max_seq_length, tokenizer.model_max_length)

def preprocess_function(examples):
# return a dict
    examples["sentence"] = examples[text_column_name]
    # Tokenize the texts
    result = tokenizer(examples["sentence"], padding=padding, max_length=max_seq_length, truncation=True)
    if label_to_id is not None and "label" in examples:
        result["label"] = [(label_to_id[str(l)] if l != -1 else -1) for l in examples["label"]]
    
    # add additional keys: 'input_ids','token_type_ids', 'attention_mask','label'  
    return result

# test = preprocess_function(raw_datasets['train'][0])
# Running the preprocessing pipeline on all the datasets

raw_datasets = raw_datasets.map(
    preprocess_function,
    batched=True,
    num_proc=64,
    desc="Running tokenizer on dataset",
)


In [4]:

train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["test"]


In [5]:
if pad_to_max_length:
    data_collator = default_data_collator
elif fp16:
    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
else:
    data_collator = None

In [6]:
train_dataloader = DataLoader(train_dataset, batch_size=32, collate_fn=data_collator)

## Model

In [7]:
config = AutoConfig.from_pretrained(model_name_or_path,
                                    num_labels=num_labels,
                                    finetuning_task=finetuning_task)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name_or_path,
    config=config
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-mini and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [45]:
class BertClassifier(nn.Module):
    def __init__(self, model_name, config, num_labels):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name, config)
        self.pred_head = nn.Linear(self.encoder.config.hidden_size, num_labels)
        
    def forward(self, batch):
        
        encoder_outputs = self.encoder(**batch)
        print(encoder_outputs.last_hidden_state.shape)
        pred_out = self.pred_head(encoder_outputs.last_hidden_state[:,0,:].squeeze())
        return pred_out
        
        

In [46]:
bert_model = BertClassifier(model_name=model_name_or_path, config=config, num_labels=num_labels)

## Training & Evaluation

In [9]:
def move_to_device(batch, device):
    
    for k, v in batch.items():
        batch[k] = v.to(device)
    return batch

In [10]:
def compute_batch_loss(model, batch, loss_fn):
    
    model_outputs = model(**batch)
    labels = batch['labels']
    loss = loss_fn(model_outputs.logits, labels)
    
    return loss

def train_one_epoch(train_dataloader, optimizer, loss_fn, model, device, output_freq):
    
    model.train()
    training_loss = []
    step = 0
    for batch in train_dataloader:
        
        batch = move_to_device(batch, device)
        loss = compute_batch_loss(model, batch, loss_fn)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        training_loss.append(loss.item())
        step += 1
        
        if step % output_freq == 0:
            print(f"steps: {step}, loss: {sum(training_loss)/step}")
            
    return model

In [14]:
lr = 2e-5
set_seed(1)
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()
device = 'cuda'
model = model.to(device)


model_trained = train_one_epoch(train_dataloader, optimizer, loss_fn, model, device, output_freq=500)

steps: 500, loss: 0.27647993139922616
steps: 1000, loss: 0.24354113303497435
steps: 1500, loss: 0.23178140292192498
steps: 2000, loss: 0.2284550635199994
steps: 2500, loss: 0.21780821318142116
steps: 3000, loss: 0.21243902171993007
steps: 3500, loss: 0.21010713723168842


In [67]:


def compute_metrics(eval_dataloader, metrics, model, device):
    model.eval()
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for batch in eval_dataloader:
            batch = move_to_device(batch, device)
            model_output = model(**batch)
            logits = model_output.logits
            labels = batch['labels']
            
            preds = torch.argmax(logits, dim=-1)
            
            all_labels.extend(labels.detach().cpu().tolist())
            all_preds.extend(preds.detach().cpu().tolist())
            
        result = metrics(predictions=all_preds, references = all_labels)
    print(result)
    print(len(all_labels))
    return all_preds, all_labels
            

In [68]:
def accuracy_metrics(predictions, references):
    
    count = len(predictions)
    correct_count = np.sum(np.array(predictions) == np.array(references))
    
    print(f"accuracy: {correct_count / count}; prediction count {count}")

In [69]:
def precision_recall(predictions, references, class_id):
    
    predictions, references = np.array(predictions), np.array(references)
    

    TP = np.sum((predictions == class_id) & (references == class_id))
    precision = TP / np.sum(predictions == class_id)
    
    recall = TP / np.sum(references == class_id)
    
    print(f"class id: {class_id}, precision: {precision}, recall: {recall}")
    

In [70]:
def metrics(predictions, references):
    labels = set(references)
    
    accuracy_metrics(predictions, references)
    for label in labels:
        precision_recall(predictions, references, label)

In [71]:
eval_dataloader = DataLoader(eval_dataset, batch_size=32, collate_fn=data_collator)
metric = evaluate.load("accuracy")
all_preds, all_labels = compute_metrics(eval_dataloader, metrics, model_trained, device)

# {'accuracy': 0.9269736842105263}

accuracy: 0.9269736842105263; prediction count 7600
class id: 0, precision: 0.9310889005786428, recall: 0.9315789473684211
class id: 1, precision: 0.9606942317508933, recall: 0.9905263157894737
class id: 2, precision: 0.9004259850905219, recall: 0.89
class id: 3, precision: 0.9140708915145005, recall: 0.8957894736842106
None
7600


In [72]:
from sklearn.metrics import classification_report

In [74]:
out = classification_report(all_labels, all_preds)

In [78]:
np.bincount(all_labels)

array([1900, 1900, 1900, 1900])

In [75]:
print(out)

              precision    recall  f1-score   support

           0       0.93      0.93      0.93      1900
           1       0.96      0.99      0.98      1900
           2       0.90      0.89      0.90      1900
           3       0.91      0.90      0.90      1900

    accuracy                           0.93      7600
   macro avg       0.93      0.93      0.93      7600
weighted avg       0.93      0.93      0.93      7600

