<a href="https://colab.research.google.com/github/tomonari-masada/course2024-nlp/blob/main/embedding_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install datasets captum

Collecting datasets
  Downloading datasets-2.19.2-py3-none-any.whl (542 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m542.1/542.1 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting captum
  Downloading captum-0.7.0-py3-none-any.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
Collecting requests>=2.32.1 (from datasets)
  Downloading requests-2.32.3-py3-none-any.whl (64 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.9/64.9 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import os
from collections import Counter
import re
import numpy as np
from sklearn.metrics import classification_report

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from transformers import set_seed
from datasets import load_dataset

from captum.attr import visualization as vis
from captum.attr import TokenReferenceBase, LayerIntegratedGradients

set_seed(1234)

In [3]:
dataset_id = "dair-ai/emotion"
dataset = load_dataset(dataset_id, trust_remote_code=True)

CATEGORIES = np.array(["sadness", "joy", "love", "anger", "fear", "surprise"])

ds_train = dataset['train']
ds_val = dataset['validation']
ds_test = dataset['test']

num_class = len(set(ds_train['label']))

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading builder script:   0%|          | 0.00/3.97k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/3.28k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/8.78k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/592k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/74.0k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/74.9k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/16000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/2000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/2000 [00:00<?, ? examples/s]

In [4]:
patterns = [r"\'", r"\"", r"\.", r"<br \/>", r",", r"\(", r"\)", r"\!", r"\?", r"\;", r"\:", r"\s+"]
replacements = [" '  ", "", " . ", " ", " , ", " ( ", " ) ", " ! ", " ? ", " ", " ", " "]
patterns_dict = list((re.compile(p), r) for p, r in zip(patterns, replacements))

def basic_english_normalize(line):
    line = line.lower()
    for pattern_re, replaced_str in patterns_dict:
        line = pattern_re.sub(replaced_str, line)
    return line.split()

tokenizer = basic_english_normalize

In [5]:
UNK_TOKEN = "<unk>"
BOS_TOKEN = "<bos>"
EOS_TOKEN = "<eos>"
PAD_TOKEN = "<pad>"

word_counter = Counter()
for doc in ds_train['text']:
    word_counter.update(tokenizer(doc))

for word in list(word_counter):
    if word_counter[word] < 5:
        del word_counter[word]

id2token = [PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN] + list(word_counter.keys())
voc = {k: v for v, k in enumerate(id2token)}

vocab_size = len(voc)

unknown_idx = voc[UNK_TOKEN]
bos_idx = voc[BOS_TOKEN]
eos_idx = voc[EOS_TOKEN]
padding_idx = voc[PAD_TOKEN]

print(f"vocabulary size: {vocab_size}")

vocabulary size: 3498


In [6]:
BATCH_SIZE = 64

def collate_batch(batch):
    labels = torch.tensor([b['label'] for b in batch])
    token_sequences = [tokenizer(b['text']) for b in batch]
    max_len = max([len(seq) for seq in token_sequences])
    token_ids = []
    for text in token_sequences:
        token_ids.append(
            [bos_idx] + [voc.get(token, voc[UNK_TOKEN]) for token in text]
            + [padding_idx] * (max_len - len(text)) + [eos_idx]
        )
    token_ids = torch.tensor(token_ids)
    return labels.to("cuda"), token_ids.to("cuda")

train_loader = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
val_loader = DataLoader(ds_val, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
test_loader = DataLoader(ds_test, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

In [7]:
class EmbeddingModel(nn.Module):

    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.linear = nn.Linear(embed_dim, num_class)

    def forward(self, inputs):
        out = self.embedding(inputs).mean(1)
        out = self.linear(out)
        return out

In [8]:
EMB_SIZE = 64
model = EmbeddingModel(vocab_size, EMB_SIZE, num_class)

USE_PRETRAINED = False

CHECKPOINT_DIR = f"./models/{dataset_id}/"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
CHECKPOINT = CHECKPOINT_DIR + "/embedding_bag.pt"
if USE_PRETRAINED:
    model.load_state_dict(torch.load(CHECKPOINT))

model = model.to("cuda")
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-3)

In [9]:
EPOCHS = 50

def train_model(train_loader, val_loader):
    best_val_acc = 0
    best_epoch = 0
    for epoch in range(1, EPOCHS + 1):
        model.train()
        for labels, token_ids in train_loader:
            optimizer.zero_grad()
            logits = model(token_ids)
            loss(logits, labels).backward()
            optimizer.step()
        model.eval()
        total_acc, total_count = 0, 0
        with torch.no_grad():
            for labels, token_ids in val_loader:
                logits = model(token_ids)
                total_acc += (logits.argmax(1) == labels).sum().item()
                total_count += labels.size(0)
        val_acc = total_acc / total_count
        print(f'epoch {epoch:3d} | validation accuracy {val_acc:8.3f} ')
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_epoch = epoch
            torch.save(model.state_dict(), CHECKPOINT)
    print(f"best validation accuracy {best_val_acc:.3f} at epoch {best_epoch}")

In [10]:
train_model(train_loader, val_loader)

epoch   1 | validation accuracy    0.480 
epoch   2 | validation accuracy    0.615 
epoch   3 | validation accuracy    0.762 
epoch   4 | validation accuracy    0.840 
epoch   5 | validation accuracy    0.871 
epoch   6 | validation accuracy    0.873 
epoch   7 | validation accuracy    0.877 
epoch   8 | validation accuracy    0.887 
epoch   9 | validation accuracy    0.890 
epoch  10 | validation accuracy    0.886 
epoch  11 | validation accuracy    0.887 
epoch  12 | validation accuracy    0.888 
epoch  13 | validation accuracy    0.887 
epoch  14 | validation accuracy    0.888 
epoch  15 | validation accuracy    0.887 
epoch  16 | validation accuracy    0.886 
epoch  17 | validation accuracy    0.887 
epoch  18 | validation accuracy    0.884 
epoch  19 | validation accuracy    0.893 
epoch  20 | validation accuracy    0.887 
epoch  21 | validation accuracy    0.887 
epoch  22 | validation accuracy    0.885 
epoch  23 | validation accuracy    0.885 
epoch  24 | validation accuracy   

In [11]:
model.load_state_dict(torch.load(CHECKPOINT))
model.eval()

y_test = torch.tensor([]).to("cuda")
y_test_pred = torch.tensor([]).to("cuda")

with torch.no_grad():
    for labels, token_ids in test_loader:
        predicted_labels = model(token_ids).argmax(1)
        y_test = torch.cat((y_test, labels), dim=0)
        y_test_pred = torch.cat((y_test_pred, predicted_labels), dim=0)

y_test = y_test.to('cpu').numpy()
y_test_pred = y_test_pred.to('cpu').numpy()

report = classification_report(
    y_true=y_test,
    y_pred=y_test_pred,
)

print(report)

              precision    recall  f1-score   support

         0.0       0.92      0.91      0.92       581
         1.0       0.89      0.92      0.90       695
         2.0       0.76      0.75      0.75       159
         3.0       0.85      0.87      0.86       275
         4.0       0.85      0.81      0.83       224
         5.0       0.73      0.67      0.70        66

    accuracy                           0.88      2000
   macro avg       0.83      0.82      0.83      2000
weighted avg       0.87      0.88      0.87      2000



In [12]:
token_reference = TokenReferenceBase(reference_token_idx=padding_idx)

In [13]:
def predict_prob(token_ids):
    return F.softmax(model(token_ids), dim=-1)

def class_prob_forward_func(token_ids, label):
    return predict_prob(token_ids)[:,label]

In [14]:
lig = LayerIntegratedGradients(
    forward_func=class_prob_forward_func,
    layer=model.embedding,
)

In [15]:
def add_attributions_to_visualizer(attributions, text, pred_prob, pred_class, true_class,
                                   attr_class, convergence_scores, vis_data_records):
    attributions = attributions.cpu()
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions.cpu().detach().numpy()
    vis_data_records.append(
        vis.VisualizationDataRecord(
            attributions,
            pred_prob,
            pred_class,
            true_class,
            attr_class,
            attributions.sum(),
            text,
            convergence_scores,
        )
    )

In [16]:
def interpret_text(example, attr_class=None, n_steps=50):
    text = example["text"]
    true_class = example["label"]
    labels, token_ids = collate_batch([{"text": text, "label": true_class}])
    true_class = labels[0]
    tokens = [id2token[id] for id in token_ids[0]]
    reference_input_ids = token_reference.generate_reference(
        token_ids.shape[-1],
        device="cuda",
    ).unsqueeze(0)
    probs = predict_prob(token_ids)[0]
    prediction = probs.argmax().item()
    if attr_class is None:
        attr_class = prediction
    print(
        f"prediction={prediction} "
        f"probability={probs.max().item():.3f} ",
        end=""
    )

    attributions_ig, delta = lig.attribute(
        token_ids,
        reference_input_ids,
        additional_forward_args=(attr_class),
        n_steps=n_steps,
        return_convergence_delta=True,
    )
    print(f"convergence delta={delta.item():.3e} when n_steps={n_steps}")

    add_attributions_to_visualizer(
        attributions_ig,
        tokens,
        probs.max().item(),
        CATEGORIES[prediction],
        CATEGORIES[true_class],
        CATEGORIES[attr_class],
        delta,
        vis_data_records_ig,
    )
    return prediction

In [17]:
vis_data_records_ig = []
for i in range(10):
    interpret_text(ds_test[i])

prediction=0 probability=1.000 convergence delta=-2.303e-07 when n_steps=50
prediction=0 probability=1.000 convergence delta=-1.301e-07 when n_steps=50
prediction=0 probability=1.000 convergence delta=-6.360e-08 when n_steps=50
prediction=1 probability=1.000 convergence delta=9.927e-08 when n_steps=50
prediction=0 probability=1.000 convergence delta=-3.411e-08 when n_steps=50
prediction=4 probability=0.746 convergence delta=-1.991e-07 when n_steps=50
prediction=3 probability=0.997 convergence delta=-1.557e-08 when n_steps=50
prediction=1 probability=0.989 convergence delta=-2.019e-09 when n_steps=50
prediction=1 probability=1.000 convergence delta=-9.119e-08 when n_steps=50
prediction=1 probability=0.562 convergence delta=-5.413e-08 when n_steps=50


In [18]:
vis.visualize_text(vis_data_records_ig);

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
sadness,sadness (1.00),sadness,1.05,#bos im feeling rather rotten so im not very #unk right now #eos
,,,,
sadness,sadness (1.00),sadness,1.21,#bos im #unk my blog because i feel shitty #eos
,,,,
sadness,sadness (1.00),sadness,1.03,#bos i never make her separate from me because i don t ever want her to feel like i m ashamed with her #eos
,,,,
joy,joy (1.00),joy,0.95,#bos i left with my #unk of red and #unk #unk under my arm feeling slightly more optimistic than when i arrived #eos
,,,,
sadness,sadness (1.00),sadness,1.07,#bos i was feeling a little vain when i did this one #eos
,,,,
