Adapted From: https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/BERT/Fine_tuning_BERT_(and_friends)_for_multi_label_text_classification.ipynb#scrollTo=mjJGEXShp7te

In [None]:
!pip install accelerate -U

In [1]:
import torch
from torch.utils.data import DataLoader
import pandas as pd
from transformers import BertTokenizer, BertModel
from torch import nn
from transformers import DataCollatorWithPadding
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
import pickle

In [7]:
classnames = ["mistral", "rag", "3hop-rag", "no-label"]
# classnames = ["mistral", "rag"]

num_labels=len(classnames)
batch_size = 16

In [2]:
with open("labeled_data.pkl", "rb") as f:
  labeled_data = pickle.load(f)


In [8]:
labels = {cls: 0 for cls in classnames}
for example in labeled_data:
  label = example["labels"]
  labels[label] += 1


In [9]:
labels

{'mistral': 380, 'rag': 256, '3hop-rag': 63, 'no-label': 301}

In [None]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

In [None]:
# Custom Model
def encode_astensors(data):
  text_lower = data["question"].lower()
  label = data["labels"]
  encoded_text = tokenizer(text_lower,
                            padding='max_length',
                            truncation=True,
                            max_length=256,
                            return_tensors="pt"
                            )
  # set unlabeled to most difficult level
  if label not in classnames:
    label = classnames[-1]

  encoded_text['labels'] = torch.LongTensor([classnames.index(label)])
  return encoded_text

# Use this for HF trainer
def encode_aslist(data):
  text_lower = data["question"].lower()
  label = data["labels"]
  encoded_text = tokenizer(text_lower,
                            padding='max_length',
                            truncation=True,
                            max_length=128,
                            # return_tensors="pt"
                            )
  if label not in classnames:
    label = classnames[-1]

  encoded_text['labels'] = [classnames.index(label)]
  return encoded_text

In [None]:
# encoded_data = list(map(encode_aslist, labeled_data))
encoded_data = list(map(encode_astensors, labeled_data))

In [None]:
torch.random.manual_seed(42)
train_set, val_set = torch.utils.data.random_split(list(encoded_data), [0.85, 0.15])

In [None]:
print(len(encoded_data))

1000


# Using BERTForSequenceClassification + HF Trainer

In [None]:
import numpy as np

def get_accuracy(output):
  pred = np.argmax(output.predictions, axis=1)
  labels = output.label_ids
  accuracy = (pred[:, np.newaxis] == labels).sum() / pred.shape[0]
  return {"accuracy": accuracy}

In [None]:
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased",
                                                           problem_type="single_label_classification",
                                                           num_labels=num_labels,
                                                           )

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
for param in model.base_model.parameters():
  param.requires_grad = False

In [None]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)

In [None]:
from transformers import TrainingArguments, Trainer

args = TrainingArguments(
    f"bert-hotpotqa-classifier-frozen",
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # push_to_hub=True,
)

In [None]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_set,
    eval_dataset=val_set,
    tokenizer=tokenizer,
    compute_metrics=get_accuracy
)

In [None]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,No log,1.133072,0.44
2,No log,1.109798,0.453333
3,No log,1.098037,0.46
4,No log,1.09329,0.46
5,No log,1.092224,0.46


TrainOutput(global_step=270, training_loss=1.1308730513961227, metrics={'train_runtime': 49.5774, 'train_samples_per_second': 85.725, 'train_steps_per_second': 5.446, 'total_flos': 279558006336000.0, 'train_loss': 1.1308730513961227, 'epoch': 5.0})

# Custom Model (performed worse than default BERT Classifier)

In [None]:
train_dataloader = DataLoader(train_set, batch_size=batch_size)
val_dataloader = DataLoader(val_set, batch_size=batch_size)

In [None]:
class BertQuestionClassifier(nn.Module):

  def __init__(self, hidden_dim=768, device=None):
    super(BertQuestionClassifier, self).__init__()

    if device is None:
      self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
    else:
      self.device = device

    self.bert = BertModel.from_pretrained('bert-base-uncased').to(self.device)
    self.classifier = nn.Sequential(
      nn.Linear(hidden_dim, hidden_dim),
      nn.ReLU(),
      nn.Dropout(0.1),
      nn.Linear(hidden_dim, num_labels),
    ).to(self.device)

  def forward(self, input_ids, attention_mask=None, labels=None,
              token_type_ids=None):
    outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
    sequence_output = outputs[1]
    logits = self.classifier(sequence_output)
    return logits


In [None]:
model = BertQuestionClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

In [None]:
for param in model.bert.parameters():
  param.requires_grad = False

In [None]:
from tqdm import tqdm

def train_single_epoch():
  n_examples = len(train_set)
  model.train()
  total_loss = 0
  total_accuracy = 0
  for i ,batch in tqdm(enumerate(train_dataloader)):
    #print([v[0] for k,v in batch.items()])
    batch = {k: v.to(model.device).squeeze(1) for k,v in batch.items()}
    logits = model(**batch)
    loss = criterion(logits, batch["labels"])
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    total_loss += loss.sum().item()
    total_accuracy += (logits.argmax(dim=1) == batch["labels"]).sum().item()
  return total_loss / n_examples, total_accuracy / n_examples


def eval_loop():
  n_examples = len(val_set)
  model.eval()
  total_loss = 0
  total_accuracy = 0
  for i ,batch in tqdm(enumerate(val_dataloader)):
    batch = {k: v.to(model.device).squeeze(1) for k,v in batch.items()}
    logits = model(**batch)
    loss = criterion(logits, batch["labels"])
    total_loss += loss.sum().item()
    total_accuracy += (logits.argmax(dim=1) == batch["labels"]).sum().item()
  return total_loss / n_examples, total_accuracy / n_examples


In [None]:
best_acc = 0.0
for epoch in range(5):
  print(epoch, "train (loss, acc)", train_single_epoch())
  print(epoch, "val (loss, acc)", eval_loop())

54it [00:34,  1.57it/s]


0 train (loss, acc) (0.04211344010689679, 0.6058823529411764)


10it [00:02,  4.49it/s]


0 val (loss, acc) (0.045101087093353275, 0.5866666666666667)


54it [00:34,  1.54it/s]


1 train (loss, acc) (0.03533679548431845, 0.7282352941176471)


10it [00:02,  4.70it/s]


1 val (loss, acc) (0.042889880339304604, 0.6666666666666666)


54it [00:34,  1.57it/s]


2 train (loss, acc) (0.029017725166152506, 0.8094117647058824)


10it [00:02,  4.64it/s]


2 val (loss, acc) (0.04265058239301046, 0.7133333333333334)


54it [00:34,  1.56it/s]


3 train (loss, acc) (0.025109538339516697, 0.84)


10it [00:02,  4.64it/s]


3 val (loss, acc) (0.050702066818873084, 0.6733333333333333)


54it [00:34,  1.56it/s]


4 train (loss, acc) (0.017865023972357022, 0.8929411764705882)


10it [00:02,  4.68it/s]

4 val (loss, acc) (0.05723710079987844, 0.6866666666666666)



