In [36]:
import numpy as np 
import torch
from torch import nn
from torch.utils.data import DataLoader
import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import DataCollatorForTokenClassification
from transformers import AutoModelForTokenClassification
from transformers import TrainingArguments, Trainer
import evaluate


# Prepare data

In [2]:
raw_datasets = load_dataset("conll2003")

Found cached dataset conll2003 (/Users/wgw/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98)


  0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
model_checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [5]:
def align_labels_with_tokens(labels, word_ids):
    # Input of NER will be list of words, i.e., splitted text
    # Each word has an associated NER label, but with tokenization,
    # some words will be splitted into >1 tokens. This function is used to
    # align labels with input after tokenization
    new_labels = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            current_word = word_id
            label = -100 if word_id is None else labels[word_id]
            new_labels.append(label)
        elif word_id is None:
            new_labels.append(-100)
        else:
            label = labels[word_id]
            if label % 2== 1:
                label += 1
            new_labels.append(label)
    return new_labels


In [6]:
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(
        examples["tokens"], truncation=True, is_split_into_words=True
    )
    all_labels = examples["ner_tags"]
    new_labels=[]
    for i, labels in enumerate(all_labels):
        word_ids = tokenized_inputs.word_ids(i)
        new_labels.append(align_labels_with_tokens(labels, word_ids))
    tokenized_inputs["labels"] = new_labels
    return tokenized_inputs

In [7]:
tokenized_datasets = raw_datasets.map(
    tokenize_and_align_labels, 
    batched=True, 
    remove_columns=raw_datasets["train"].column_names)
tokenized_datasets.set_format("torch")

Map:   0%|          | 0/14041 [00:00<?, ? examples/s]

Map:   0%|          | 0/3250 [00:00<?, ? examples/s]

Map:   0%|          | 0/3453 [00:00<?, ? examples/s]

In [11]:
data_collator = DataCollatorForTokenClassification(tokenizer = tokenizer)
train_dataloader = DataLoader(
    tokenized_datasets["train"],
    shuffle=True,
    collate_fn=data_collator,
    batch_size=64
)
val_dataloader = DataLoader(
    tokenized_datasets["validation"],
    collate_fn=data_collator,
    batch_size=64
)
test_dataloader = DataLoader(
    tokenized_datasets["test"],
    collate_fn=data_collator,
    batch_size=64
)

# Base Bert Class for Token Classification 

In [47]:
class BertNERModel(pl.LightningModule):
    def __init__(self, pretrained_model, label_names, finetune_encoder = True):
        super().__init__()
        self.metric = evaluate.load("seqeval")
        self.label_names = label_names
        self.id2label = {i: label for i, label in enumerate(self.label_names)}
        self.label2id = {v: k for k, v in self.id2label.items()}
        self.transformer = AutoModelForTokenClassification.from_pretrained(
            pretrained_model,
            id2label=self.id2label,
            label2id=self.label2id,
        )
        if not finetune_encoder:
            for name, param in self.transformer.named_parameters():
                if not name.startswith("classifier"):
                    param.requires_grad = False
    def compute_metrics(self, eval_preds, stage):
        logits, labels = eval_preds
        predictions = torch.argmax(logits, dim=-1)

        # Remove ignored index (special tokens) and convert to labels
        true_labels = [[self.label_names[l] for l in label if l != -100] for label in labels]
        true_predictions = [
            [self.label_names[p] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]
        all_metrics = self.metric.compute(predictions=true_predictions, references=true_labels)
        return {
            stage+"_"+"precision": torch.tensor(all_metrics["overall_precision"],dtype=torch.float32),
            stage+"_"+"recall": torch.tensor(all_metrics["overall_recall"],dtype=torch.float32),
            stage+"_"+"f1": torch.tensor(all_metrics["overall_f1"],dtype=torch.float32),
            stage+"_"+"accuracy": torch.tensor(all_metrics["overall_accuracy"],dtype=torch.float32),
        }
    
    def forward(self, **inputs):
        return self.transformer(**inputs)
    
    def training_step(self, batch, batch_idx):
        outputs = self.transformer(**batch)
        loss = outputs.loss
        self.log("training loss", loss, on_epoch=True, prog_bar=True, logger=True)
        return loss
    
    def validation_step(self,batch, batch_idx):
        y= batch["labels"]
        outputs = self.transformer(**batch)
        loss, logits = outputs[:2]
        metric_dic = self.compute_metrics((logits,y),"validation")
        metric_dic["validation loss"] = loss
        #print(metric_dic)
        self.log_dict(metric_dic, on_epoch=True, logger=True)
    
    def test_step(self,batch, batch_idx):
        y = batch["labels"]
        outputs = self.transformer(**batch)
        loss, logits = outputs[:2]
        metric_dic = self.compute_metrics((logits,y),"test")
        metric_dic["test loss"] = loss
        self.log_dict(metric_dic)
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr = 2e-5)
        return optimizer
    
        

# Fully Finetuning DstilBert for NER

In [48]:
label_names = raw_datasets["train"].features["ner_tags"].feature.names
net = BertNERModel(model_checkpoint, label_names)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForTokenClassification: ['vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN t

In [49]:
trainer = pl.Trainer(callbacks=[EarlyStopping(monitor="validation_f1", mode="min", patience=2)])
trainer.fit(model=net, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name        | Type                             | Params
-----------------------------------------------------------------
0 | transformer | DistilBertForTokenClassification | 66.4 M
-----------------------------------------------------------------
66.4 M    Trainable params
0         Non-trainable params
66.4 M    Total params
265.479   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [50]:
trainer.test(dataloaders=test_dataloader)

Restoring states from the checkpoint path at /Users/wgw/Documents/Projects/nlp-tutorial/lightning_logs/version_22/checkpoints/epoch=2-step=660.ckpt
Loaded model weights from the checkpoint at /Users/wgw/Documents/Projects/nlp-tutorial/lightning_logs/version_22/checkpoints/epoch=2-step=660.ckpt


Testing: 0it [00:00, ?it/s]

[{'test_precision': 0.8515589833259583,
  'test_recall': 0.8795795440673828,
  'test_f1': 0.8648689389228821,
  'test_accuracy': 0.9741219282150269,
  'test loss': 0.03447655215859413}]

# TODO
1. Impletment validation_epoch_end method to print validation metrics at the end of each epoch
2. Change dataloaders to a LightningDateModule