<a href="https://colab.research.google.com/github/sjpark0605/NLP-FYP/blob/main/NER_Training_Loop.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [47]:
%%capture
!pip install datasets evaluate transformers[sentencepiece] seqeval accelerate

In [48]:
# Imports for Data Processing
import pandas as pd
import torch
from datasets import load_from_disk

In [49]:
# Mount Google Drive so the pre-built datasets stored there can be read.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)
dataset_dir = DRIVE_MOUNT_POINT + '/MyDrive/COMP0029/datasets/'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [50]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [51]:
# Load the recipe NER dataset splits from Drive ('train' and 'valid' are used below).
recipe_ner_path = dataset_dir + 'english-recipe-ner'
corpus_datasets = load_from_disk(recipe_ner_path)

In [52]:
# Label names of the "ner_tags" column (its per-token feature exposes `.names`).
ner_feature = corpus_datasets["train"].features["ner_tags"]
label_names = ner_feature.feature.names


def strip_position_suffix(label):
    """Return `label` without a trailing "-B" or "-I" position marker.

    Only a suffix is removed ("Ac-B" -> "Ac"); labels without one, such as
    "O", are returned unchanged.
    """
    for suffix in ("-B", "-I"):
        if label.endswith(suffix):
            return label[:-len(suffix)]
    return label


# Entity types without the position marker. dict.fromkeys de-duplicates while
# keeping first-seen order, so this list (and anything printed by iterating
# it) is identical on every run. A set of strings is not, because string
# hashing is randomised per interpreter process.
pure_label_names = list(dict.fromkeys(strip_position_suffix(label) for label in label_names))

In [53]:
from transformers import AutoTokenizer

# Cased BERT checkpoint; the same name is reused below to load the model.
model_checkpoint = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_checkpoint)

In [54]:
def align_labels_with_tokens(labels, word_ids):
  """Map word-level NER label ids onto sub-word tokens.

  Only the first sub-token of each word keeps that word's label. Special
  tokens (word id None) and every continuation sub-token get -100 so they
  are skipped by the loss and by `postprocess`.

  Args:
    labels: label ids, one per word of the original sentence.
    word_ids: for each token, the index of the word it came from, or None
      for special tokens (as produced by the tokenizer's `word_ids()`).

  Returns:
    A list with one label id (or -100) per token.
  """
  aligned = []
  previous_word_id = None
  for word_id in word_ids:
    is_first_subtoken = word_id is not None and word_id != previous_word_id
    aligned.append(labels[word_id] if is_first_subtoken else -100)
    previous_word_id = word_id
  return aligned

In [55]:
def tokenize_and_align_labels(examples):
  """Tokenize a batch of pre-split sentences and attach token-level labels.

  Used with `Dataset.map(batched=True)`: `examples["tokens"]` holds one word
  list per sentence and `examples["ner_tags"]` the matching word-label lists.
  Encoded sequences are truncated to at most 75 tokens.

  Returns:
    The tokenizer's batch encoding with an added "labels" entry holding one
    label id (or -100) per token.
  """
  encoding = tokenizer(
      examples["tokens"],
      is_split_into_words=True,
      truncation=True,
      max_length=75,
  )
  encoding["labels"] = [
      align_labels_with_tokens(word_labels, encoding.word_ids(batch_index))
      for batch_index, word_labels in enumerate(examples["ner_tags"])
  ]
  return encoding

In [56]:
# Tokenize every split. The raw columns are dropped so only the tokenizer
# outputs and the aligned "labels" column remain.
raw_columns = corpus_datasets["train"].column_names
tokenized_datasets = corpus_datasets.map(
  tokenize_and_align_labels,
  batched=True,
  remove_columns=raw_columns,
)

Map:   0%|          | 0/2201 [00:00<?, ? examples/s]

Map:   0%|          | 0/551 [00:00<?, ? examples/s]

In [57]:
from transformers import DataCollatorForTokenClassification

# Pads each batch dynamically; label padding uses the collator's default pad id.
data_collator = DataCollatorForTokenClassification(tokenizer)

In [58]:
import evaluate

# seqeval: entity-level precision/recall/F1 plus overall token accuracy.
metric = evaluate.load(path="seqeval")

In [59]:
# Lookup tables between label ids and label names, passed to the model at load time.
id2label = dict(enumerate(label_names))
label2id = {name: index for index, name in enumerate(label_names)}

In [60]:
from torch.utils.data import DataLoader

# Both loaders pad with the same collator and use batches of 32;
# only the training loader is shuffled.
loader_kwargs = dict(collate_fn=data_collator, batch_size=32)
train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, **loader_kwargs)
eval_dataloader = DataLoader(tokenized_datasets["valid"], **loader_kwargs)

In [61]:
# Inverse-frequency class weights over the training labels, normalised to sum
# to 1. Only consumed by the optional weighted loss in the training loop.
from collections import Counter

num_labels = len(label_names)

# Count label ids straight from the tokenized training split. Iterating the
# DataLoader would shuffle, collate and pad every batch just to discard the
# -100 padding again.
label_counts = Counter(
    label
    for sequence_labels in tokenized_datasets["train"]["labels"]
    for label in sequence_labels
    if label != -100
)
frequencies = [label_counts.get(label_id, 0) for label_id in range(num_labels)]

total_samples = sum(frequencies)
# A label that never occurs in the training split would divide by zero.
# Give it weight 0 instead: it is never a training target, so its weight is
# never used by the loss.
weights = [
    total_samples / (num_labels * count) if count else 0.0
    for count in frequencies
]

weights = torch.tensor(weights)
weights /= weights.sum()

print(weights)

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


tensor([8.9126e-04, 6.3804e-03, 2.4547e-02, 2.9213e-02, 1.7412e-02, 6.2013e-02,
        3.2134e-01, 4.4184e-01, 7.9254e-03, 4.5375e-03, 9.0310e-04, 3.5849e-03,
        8.8368e-03, 2.3565e-02, 4.2231e-03, 1.4083e-02, 5.2600e-03, 1.2901e-02,
        2.3803e-03, 7.9254e-03, 2.3680e-04])


In [62]:
from transformers import AutoModelForTokenClassification

# Pre-trained BERT encoder with a token-classification head sized from the
# label maps; the head weights are newly initialised (see the load warning).
ner_model = AutoModelForTokenClassification.from_pretrained(
  pretrained_model_name_or_path=model_checkpoint,
  label2id=label2id,
  id2label=id2label,
)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForTokenClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cas

In [63]:
from torch.optim import AdamW

# AdamW over all model parameters. Note: weight decay is applied uniformly,
# including to biases and LayerNorm weights.
optimizer = AdamW(params=ner_model.parameters(), lr=3e-5, weight_decay=0.1, eps=1e-8)

In [64]:
from accelerate import Accelerator

# Accelerate wraps the model, optimizer and loaders for the available device(s);
# the prepared objects replace the originals.
accelerator = Accelerator()
prepared = accelerator.prepare(ner_model, optimizer, train_dataloader, eval_dataloader)
ner_model, optimizer, train_dataloader, eval_dataloader = prepared

In [65]:
# Number of training batches per epoch, measured on the prepared loader.
n_train_batches = len(train_dataloader)
print(n_train_batches)

69


In [66]:
from transformers import get_scheduler

num_train_epochs = 10
# One optimizer step per training batch, so the schedule spans epochs x batches.
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_update_steps_per_epoch * num_train_epochs

# Linear learning-rate schedule over the whole run, without warm-up.
lr_scheduler = get_scheduler(
  name="linear",
  optimizer=optimizer,
  num_warmup_steps=0,
  num_training_steps=num_training_steps,
)

In [67]:
def postprocess(predictions, labels):
    """Turn predicted and gold label-id tensors into lists of label names.

    Positions whose gold label is -100 (special tokens, continuation
    sub-tokens, padding) are dropped from both outputs.

    Args:
        predictions: tensor of predicted label ids, one row per sequence.
        labels: tensor of gold label ids with the same layout.

    Returns:
        (true_predictions, true_labels): one list of label names per sequence.
    """
    pred_array = predictions.detach().cpu().clone().numpy()
    gold_array = labels.detach().cpu().clone().numpy()

    true_labels = []
    for gold_row in gold_array:
        true_labels.append([label_names[g] for g in gold_row if g != -100])

    true_predictions = []
    for pred_row, gold_row in zip(pred_array, gold_array):
        true_predictions.append(
            [label_names[p] for p, g in zip(pred_row, gold_row) if g != -100]
        )

    return true_predictions, true_labels

In [68]:
from tqdm.auto import tqdm
from torch import nn
import torch

# Set to True to train with the inverse-frequency class weights (`weights`)
# computed earlier instead of the model's built-in, unweighted loss.
use_class_weights = False

progress_bar = tqdm(range(num_training_steps))

for epoch in range(num_train_epochs):

    # ---- Training ----
    train_loss_val = 0

    ner_model.train()
    for batch in train_dataloader:
        outputs = ner_model(**batch)

        if use_class_weights:
            # Weighted cross-entropy over every token position. -100 entries
            # are skipped through CrossEntropyLoss's default ignore_index.
            # The class count comes from the logits, because a
            # distributed-wrapped model does not expose `.config` directly.
            logits = outputs.logits
            loss_fct = nn.CrossEntropyLoss(weight=weights.to(logits.device))
            loss = loss_fct(logits.reshape(-1, logits.size(-1)), batch["labels"].reshape(-1))
        else:
            loss = outputs.loss

        train_loss_val += loss.item()

        accelerator.backward(loss)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

    print(f"Training Loss: {train_loss_val / len(train_dataloader)}")

    # ---- Evaluation ----
    eval_loss_val = 0

    ner_model.eval()
    for batch in eval_dataloader:
        with torch.no_grad():
            outputs = ner_model(**batch)

        eval_loss_val += outputs.loss.item()

        predictions = outputs.logits.argmax(dim=-1)
        labels = batch["labels"]

        # Pad to a common sequence length so tensors from different
        # processes can be gathered.
        predictions = accelerator.pad_across_processes(predictions, dim=1, pad_index=-100)
        labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)

        # Unlike plain gather, gather_for_metrics drops the samples that are
        # duplicated to even out the final batch across processes, so each
        # validation example is scored exactly once.
        predictions_gathered, labels_gathered = accelerator.gather_for_metrics(
            (predictions, labels)
        )

        true_predictions, true_labels = postprocess(predictions_gathered, labels_gathered)
        metric.add_batch(predictions=true_predictions, references=true_labels)

    print(f"Validation Loss: {eval_loss_val / len(eval_dataloader)}")

    # Labels use the "TYPE-B"/"TYPE-I" suffix scheme, hence suffix=True.
    results = metric.compute(suffix=True)
    print(
        f"epoch {epoch}:",
        {
            f"overall_{key}": results[f"overall_{key}"]
            for key in ["precision", "recall", "f1", "accuracy"]
        },
        "\n"
    )

progress_bar.close()

  0%|          | 0/690 [00:00<?, ?it/s]

Training Loss: 0.9305899233921714
Validation Loss: 0.3523300058311886
epoch 0: {'overall_precision': 0.8018565941101152, 'overall_recall': 0.8394772117962467, 'overall_f1': 0.8202357563850687, 'overall_accuracy': 0.9036208732694355} 

Training Loss: 0.3059661243704782
Validation Loss: 0.2561356880598598
epoch 1: {'overall_precision': 0.8458049886621315, 'overall_recall': 0.875, 'overall_f1': 0.8601548344589031, 'overall_accuracy': 0.9239882854100107} 

Training Loss: 0.20484370738267899
Validation Loss: 0.2491388867298762
epoch 2: {'overall_precision': 0.8609206660137121, 'overall_recall': 0.8837131367292225, 'overall_f1': 0.8721680171986108, 'overall_accuracy': 0.9274494142705005} 

Training Loss: 0.15507678968318994
Validation Loss: 0.24344073401557076
epoch 3: {'overall_precision': 0.8683519369665135, 'overall_recall': 0.8863941018766756, 'overall_f1': 0.8772802653399668, 'overall_accuracy': 0.9306443024494143} 

Training Loss: 0.12116050374680672
Validation Loss: 0.2536125977834065

In [69]:
# Per-entity scores from `results`, i.e. the most recent epoch whose
# evaluation finished.
for label in sorted(pure_label_names):
    # Entity types that seqeval did not report (e.g. a type absent from the
    # validation split, or "O", which is never an entity) have no entry in
    # `results`. Skip them by lookup rather than hard-coding names to exclude.
    if label not in results:
        continue
    print(
        f"{label}:",
        {key: results[label][key] for key in ["precision", "recall", "f1"]},
    )

Ac2: {'precision': 0.3611111111111111, 'recall': 0.4482758620689655, 'f1': 0.39999999999999997}
D: {'precision': 0.9230769230769231, 'recall': 0.9375, 'f1': 0.9302325581395349}
T: {'precision': 0.9190600522193212, 'recall': 0.946236559139785, 'f1': 0.9324503311258278}
Q: {'precision': 0.7448979591836735, 'recall': 0.7934782608695652, 'f1': 0.7684210526315789}
St: {'precision': 0.8494623655913979, 'recall': 0.8144329896907216, 'f1': 0.8315789473684212}
Ac: {'precision': 0.9226856561546287, 'recall': 0.9255102040816326, 'f1': 0.924095771777891}
Sf: {'precision': 0.6837209302325581, 'recall': 0.75, 'f1': 0.7153284671532846}
F: {'precision': 0.9135932560590094, 'recall': 0.9342672413793104, 'f1': 0.9238145977623867}
Af: {'precision': 0.6122448979591837, 'recall': 0.4838709677419355, 'f1': 0.5405405405405406}
