<a href="https://colab.research.google.com/github/sjpark0605/NLP-FYP/blob/main/NER_Training_Loop.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
!pip install datasets evaluate transformers[sentencepiece] seqeval accelerate

In [None]:
# Imports for Data Processing
import pandas as pd
import torch
from datasets import load_from_disk

In [None]:
from google.colab import drive

# Mount Google Drive inside the Colab VM; the prepared datasets live under `dataset_dir`.
drive.mount("/content/drive")
dataset_dir = "/content/drive/MyDrive/COMP0029/datasets/"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
corpus_datasets = load_from_disk(f"{dataset_dir}english-recipe-ner")  # DatasetDict saved with save_to_disk

In [None]:
ner_feature = corpus_datasets["train"].features["ner_tags"]
label_names = ner_feature.feature.names

# Entity types without the trailing "-B"/"-I" marker (e.g. "Ac-B" -> "Ac").
# dict.fromkeys de-duplicates while keeping the order of `label_names`, so the
# result is identical on every run; iterating a set of strings is not, because
# string hashing is randomised per interpreter process.
pure_label_names = list(dict.fromkeys(
    label[:-2] if label.endswith(("-B", "-I")) else label
    for label in label_names
))

In [None]:
from transformers import AutoTokenizer

# One checkpoint name shared by the tokenizer here and the model created further down.
model_checkpoint = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_checkpoint)

In [None]:
# def align_labels_with_tokens(labels, word_ids):
#   new_labels = []
#   current_word = None
#   for word_id in word_ids:
#     if word_id != current_word:
#       # Start of a new word!
#       current_word = word_id
#       label = -100 if word_id is None else labels[word_id]
#       new_labels.append(label)
#     elif word_id is None:
#       # Special token
#       new_labels.append(word_id)
#     else:
#       new_labels.append(labels[word_id])

#       #new_labels.append(labels[word_id])
#       # # Same word as previous token
#       # label = labels[word_id]      
#       # # If the label is XXX-B we change it to XXX-I
#       # if label % 2 == 0 and label != 20:
#       #   label += 1
#       #   new_labels.append(label)

#   if len(new_labels) != len(word_ids):
#     print("FATAL LENGTH MATCHING ERROR")

#   return new_labels

In [None]:
def align_labels_with_tokens(labels, word_ids, label_all_subtokens=True):
  """Expand word-level label ids to one label id per tokenizer token.

  Args:
    labels: label ids, one per input word.
    word_ids: for each token, the index of the word it belongs to, or None for
      special tokens (as produced by `BatchEncoding.word_ids()`).
    label_all_subtokens: if True (default, the original behaviour) every
      sub-token inherits its word's label; if False only the first sub-token of
      each word keeps the label and the remaining sub-tokens get -100.

  Returns:
    A list with exactly one entry per element of `word_ids`; -100 marks
    positions that are ignored downstream (special tokens, and continuation
    sub-tokens when `label_all_subtokens` is False).
  """
  new_labels = []
  previous_word_id = None
  for word_id in word_ids:
    if word_id is None:
      new_labels.append(-100)
    elif label_all_subtokens or word_id != previous_word_id:
      new_labels.append(labels[word_id])
    else:
      # Continuation sub-token of the same word.
      new_labels.append(-100)
    previous_word_id = word_id
  # Exactly one label is appended per token, so the output length always
  # matches len(word_ids); the former runtime length check was unreachable.
  return new_labels

In [None]:
def tokenize_and_align_labels(examples):
  """Tokenize a batch of pre-split sentences and attach token-aligned NER labels.

  Used with `Dataset.map(..., batched=True)`: `examples["tokens"]` holds one list
  of words per sentence and `examples["ner_tags"]` the matching word-level label
  ids. Sequences are truncated to at most 150 tokens.

  Returns the tokenizer's encoding with an added "labels" column.
  """
  encoded = tokenizer(
      examples["tokens"],
      is_split_into_words=True,
      truncation=True,
      max_length=150,
  )
  encoded["labels"] = [
      align_labels_with_tokens(word_labels, encoded.word_ids(batch_index))
      for batch_index, word_labels in enumerate(examples["ner_tags"])
  ]
  return encoded

In [None]:
# Tokenize every split, replacing the raw columns with model inputs + aligned labels.
tokenized_datasets = corpus_datasets.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns=corpus_datasets["train"].column_names,
)

Map:   0%|          | 0/2201 [00:00<?, ? examples/s]

Map:   0%|          | 0/551 [00:00<?, ? examples/s]

In [None]:
from transformers import DataCollatorForTokenClassification

# Collate function for the DataLoaders below; pads each batch using `tokenizer`.
data_collator = DataCollatorForTokenClassification(tokenizer)

In [None]:
# Imported under an alias: this notebook later defines a helper function named
# `evaluate`, which would otherwise silently shadow the library module.
import evaluate as hf_evaluate

# seqeval metric; the training loop reads overall precision/recall/f1/accuracy from it.
metric = hf_evaluate.load("seqeval")

In [None]:
# Label id <-> tag name lookups, passed to the model when it is created.
id2label = dict(enumerate(label_names))
label2id = {name: index for index, name in id2label.items()}

In [None]:
from torch.utils.data import DataLoader

BATCH_SIZE = 32
# Seeded generator so the training shuffle order is reproducible across
# Restart & Run All. This seeds only the shuffle order, not any other source of
# randomness (e.g. initialisation of the new classification head).
DATALOADER_SEED = 42

train_dataloader = DataLoader(
  tokenized_datasets["train"],
  shuffle=True,
  collate_fn=data_collator,
  batch_size=BATCH_SIZE,
  generator=torch.Generator().manual_seed(DATALOADER_SEED),
)

# Validation order is left as-is (no shuffling).
eval_dataloader = DataLoader(
  tokenized_datasets["valid"],
  collate_fn=data_collator,
  batch_size=BATCH_SIZE,
)

In [None]:
# Calculate Weights for Cross Entropy Loss
# Inverse-frequency ("balanced") weight per label id, counted over every
# non-ignored (!= -100) training label and normalised to sum to 1.
# NOTE(review): the training loop uses the model's own `outputs.loss`, so these
# weights only affect training if they are explicitly wired into the loss.
num_labels = len(label_names)
label_counts = torch.zeros(num_labels, dtype=torch.long)

for batch in train_dataloader:
  batch_labels = batch["labels"]
  # -100 marks ignored positions (special tokens via align_labels_with_tokens;
  # presumably also the collator's label padding).
  label_counts += torch.bincount(batch_labels[batch_labels != -100], minlength=num_labels)

frequencies = label_counts.tolist()
total_samples = sum(frequencies)

# A label that never occurs in the training split gets weight 0 instead of
# raising ZeroDivisionError.
weights = torch.tensor([
    total_samples / (num_labels * count) if count > 0 else 0.0
    for count in frequencies
])
weights /= weights.sum()

print(frequencies)
print(weights)

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


[6064, 623, 145, 155, 247, 66, 11, 8, 512, 830, 6020, 1270, 518, 302, 1105, 304, 879, 308, 1929, 522, 15287]
tensor([6.0928e-04, 5.9305e-03, 2.5481e-02, 2.3837e-02, 1.4958e-02, 5.5980e-02,
        3.3588e-01, 4.6184e-01, 7.2162e-03, 4.4514e-03, 6.1374e-04, 2.9092e-03,
        7.1326e-03, 1.2234e-02, 3.3436e-03, 1.2154e-02, 4.2033e-03, 1.1996e-02,
        1.9153e-03, 7.0779e-03, 2.4169e-04])


In [None]:
from transformers import AutoModelForTokenClassification

# Pretrained `model_checkpoint` encoder with a token-classification head that is
# newly initialised (see the warning printed below), labelled via the maps above.
ner_model = AutoModelForTokenClassification.from_pretrained(
    model_checkpoint, label2id=label2id, id2label=id2label
)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForTokenClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cas

In [None]:
from torch.optim import AdamW

# Standard BERT fine-tuning setup: weight-decay every parameter except biases and
# LayerNorm weights.
# - torch.optim.AdamW reads the per-group key "weight_decay". The previous
#   "weight_decay_rate" key was ignored, so every group silently used AdamW's
#   default decay.
# - PyTorch BERT names its LayerNorm parameters "LayerNorm.weight" and
#   "LayerNorm.bias". "gamma"/"beta" are kept only for checkpoints that use
#   those older names.
WEIGHT_DECAY = 0.01

param_optimizer = list(ner_model.named_parameters())
no_decay = ['bias', 'LayerNorm.weight', 'gamma', 'beta']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
     'weight_decay': WEIGHT_DECAY},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0},
]

optimizer = AdamW(
    optimizer_grouped_parameters,
    lr=3e-5,
    eps=1e-8
)

In [None]:
from accelerate import Accelerator

accelerator = Accelerator()
# The prepared objects replace the originals; everything below uses these.
(
    ner_model,
    optimizer,
    train_dataloader,
    eval_dataloader,
) = accelerator.prepare(ner_model, optimizer, train_dataloader, eval_dataloader)

In [None]:
print(f"{len(train_dataloader)}")  # batches per epoch; the scheduler cell derives its step count from this

69


In [None]:
from transformers import get_scheduler

# Linear schedule over every optimizer step of the run, with no warmup.
num_train_epochs = 10
num_update_steps_per_epoch = len(train_dataloader)  # one optimizer step per batch
num_training_steps = num_update_steps_per_epoch * num_train_epochs

lr_scheduler = get_scheduler(
  name="linear",
  optimizer=optimizer,
  num_training_steps=num_training_steps,
  num_warmup_steps=0,
)

In [None]:
def postprocess(predictions, labels, names=None):
    """Convert predicted / gold label-id tensors into seqeval-style tag lists.

    Positions whose gold label is -100 (special tokens, and padding added with
    pad_index=-100) are dropped from both outputs.

    Args:
        predictions: integer tensor of predicted label ids, same shape as `labels`
            (assumed (batch, seq_len) — TODO confirm for other callers).
        labels: integer tensor of gold label ids, -100 where ignored.
        names: optional sequence mapping label id -> tag string; defaults to the
            notebook-level `label_names`.

    Returns:
        (true_predictions, true_labels): lists (one per sequence) of tag strings.
    """
    if names is None:
        names = label_names

    # .cpu() already gives a host tensor and the arrays are only read below,
    # so the previous extra .clone() copy was unnecessary.
    predictions = predictions.detach().cpu().numpy()
    labels = labels.detach().cpu().numpy()

    true_labels = [[names[l] for l in label if l != -100] for label in labels]
    true_predictions = [
        [names[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    return true_predictions, true_labels

In [None]:
from tqdm.auto import tqdm
from torch import nn
import torch

# Opt-in: train with the inverse-frequency class weights (`weights`, from the
# "Calculate Weights for Cross Entropy Loss" cell) instead of the model's own
# `outputs.loss`. False keeps the original training behaviour.
USE_CLASS_WEIGHTS = False

weighted_loss_fct = (
    nn.CrossEntropyLoss(weight=weights.to(accelerator.device), ignore_index=-100)
    if USE_CLASS_WEIGHTS
    else None
)

progress_bar = tqdm(range(num_training_steps))

for epoch in range(num_train_epochs):

    # Training
    train_loss_val = 0

    ner_model.train()
    for batch in train_dataloader:
        outputs = ner_model(**batch)

        if weighted_loss_fct is not None:
            # Token positions labelled -100 are ignored by the weighted loss.
            logits = outputs.logits
            loss = weighted_loss_fct(
                logits.reshape(-1, logits.size(-1)), batch["labels"].reshape(-1)
            )
        else:
            loss = outputs.loss

        train_loss_val += loss.item()

        accelerator.backward(loss)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

    print(f"Training Loss: {train_loss_val / len(train_dataloader)}")

    # Evaluation (validation loss always comes from the model's own `outputs.loss`)
    eval_loss_val = 0

    ner_model.eval()
    for batch in eval_dataloader:
        with torch.no_grad():
            outputs = ner_model(**batch)

        eval_loss_val += outputs.loss.item()

        predictions = outputs.logits.argmax(dim=-1)
        labels = batch["labels"]

        # Necessary to pad predictions and labels (with the ignore index) for being gathered
        predictions = accelerator.pad_across_processes(predictions, dim=1, pad_index=-100)
        labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)

        # Unlike gather, gather_for_metrics drops the samples Accelerate duplicates
        # to even out the final batch in multi-process runs, so each example is
        # scored exactly once.
        predictions_gathered = accelerator.gather_for_metrics(predictions)
        labels_gathered = accelerator.gather_for_metrics(labels)

        true_predictions, true_labels = postprocess(predictions_gathered, labels_gathered)
        metric.add_batch(predictions=true_predictions, references=true_labels)

    print(f"Validation Loss: {eval_loss_val / len(eval_dataloader)}")

    # suffix=True: tags carry the B/I marker as a suffix (e.g. "Ac-B"), not a prefix.
    results = metric.compute(suffix=True)
    print(
        f"epoch {epoch}:",
        {
            f"overall_{key}": results[f"overall_{key}"]
            for key in ["precision", "recall", "f1", "accuracy"]
        },
        "\n"
    )

  0%|          | 0/690 [00:00<?, ?it/s]

Training Loss: 1.0164752922196318
Validation Loss: 0.38953473170598346


  _warn_prf(average, modifier, msg_start, len(result))


epoch 0: {'overall_precision': 0.8101574628520737, 'overall_recall': 0.8563056727613689, 'overall_f1': 0.8325925925925927, 'overall_accuracy': 0.8928097345132744} 

Training Loss: 0.3225571701060171
Validation Loss: 0.2820766758587625
epoch 1: {'overall_precision': 0.8609906414060716, 'overall_recall': 0.8842006563525551, 'overall_f1': 0.8724413091245519, 'overall_accuracy': 0.9179203539823009} 

Training Loss: 0.22057598872461182
Validation Loss: 0.27001895010471344
epoch 2: {'overall_precision': 0.8756607676396231, 'overall_recall': 0.8931082981715893, 'overall_f1': 0.8842984797493327, 'overall_accuracy': 0.923783185840708} 

Training Loss: 0.16937695044106332
Validation Loss: 0.2647460682524575
epoch 3: {'overall_precision': 0.8812471343420449, 'overall_recall': 0.9010782934833568, 'overall_f1': 0.8910523875753361, 'overall_accuracy': 0.927212389380531} 

Training Loss: 0.1306333093755487
Validation Loss: 0.2665892475181156
epoch 4: {'overall_precision': 0.8867006487488415, 'overall

In [None]:
import numpy as np

def evaluate(dataloader_val):
    """Run the global `ner_model` over a dataloader and collect per-token logits and gold labels.

    NOTE: this name shadows the `evaluate` library module imported in an earlier
    cell (`import evaluate`); later code that needs the library must re-import it.

    Args:
        dataloader_val: sized iterable of batches; each batch provides
            'input_ids', 'attention_mask' and 'labels' tensors. Only these three
            keys are forwarded to the model.

    Returns:
        (loss_val_avg, flat_predictions, flat_true_vals):
            - loss_val_avg: the sum of per-batch losses divided by the number of batches.
            - flat_predictions: one logits vector per token whose label is not -100.
            - flat_true_vals: the matching gold label ids, in the same order.
    """
    ner_model.eval()

    batch_losses = []
    flat_predictions, flat_true_vals = [], []

    for batch in dataloader_val:
        model_inputs = {key: batch[key] for key in ('input_ids', 'attention_mask', 'labels')}

        with torch.no_grad():
            outputs = ner_model(**model_inputs)

        batch_losses.append(outputs.get("loss").item())

        batch_logits = outputs.get("logits").detach().cpu().numpy()
        batch_labels = model_inputs['labels'].cpu().numpy()

        # Boolean-mask indexing walks sequences then positions (row-major), i.e.
        # the same order as nested per-sequence / per-token loops.
        keep = batch_labels != -100
        flat_predictions.extend(batch_logits[keep])
        flat_true_vals.extend(batch_labels[keep])

    loss_val_avg = sum(batch_losses) / len(dataloader_val)
    return loss_val_avg, flat_predictions, flat_true_vals

In [None]:
from sklearn.metrics import classification_report

def _entity_type(label):
  """Strip a trailing "-B"/"-I" marker (e.g. "Ac-B" -> "Ac"); other tags such as "O" pass through."""
  return label[:-2] if label.endswith(("-B", "-I")) else label

_, predictions, true_vals = evaluate(eval_dataloader)

labeled_preds = np.argmax(predictions, axis=1)
labeled_preds = [_entity_type(id2label[prediction]) for prediction in labeled_preds]
labeled_trues = [_entity_type(id2label[true]) for true in true_vals]
# sklearn's signature is classification_report(y_true, y_pred): gold labels must
# come first, otherwise precision/recall are swapped and "support" counts predictions.
print(classification_report(labeled_trues, labeled_preds))

              precision    recall  f1-score   support

          Ac       0.95      0.96      0.96      1672
         Ac2       0.56      0.45      0.50        74
          Af       0.57      0.73      0.64        77
          At       0.80      1.00      0.89         4
           D       0.97      0.99      0.98       381
           F       0.96      0.97      0.97      1705
           O       0.96      0.96      0.96      3647
           Q       0.92      0.79      0.85       194
          Sf       0.75      0.77      0.76       328
          St       0.86      0.88      0.87       345
           T       0.97      0.91      0.94       613

    accuracy                           0.94      9040
   macro avg       0.84      0.85      0.85      9040
weighted avg       0.94      0.94      0.94      9040

