https://heartbeat.comet.ml/how-to-build-a-text-classification-model-using-huggingface-transformers-and-comet-4d40236e8f84



In [1]:

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, DataCollatorWithPadding

from transformers import TrainingArguments, Trainer
from datasets import Dataset
from captum.attr import visualization as viz
from captum.attr import LayerConductance, LayerIntegratedGradients

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
ds_train = Dataset.from_csv('train_twitter.csv')
ds_val = Dataset.from_csv('validate_twitter.csv')
ds_test = Dataset.from_csv('test_twitter.csv')
ds = {"train": ds_train, "validation": ds_val, "test": ds_test}


Found cached dataset csv (C:/Users/Steve Toner/.cache/huggingface/datasets/csv/default-3830c33275af6959/0.0.0)
Found cached dataset csv (C:/Users/Steve Toner/.cache/huggingface/datasets/csv/default-57c0cef75a27644e/0.0.0)
Found cached dataset csv (C:/Users/Steve Toner/.cache/huggingface/datasets/csv/default-a72a0ce3803e3720/0.0.0)


In [4]:

id2label = {0: "United States", 1: "United Kingdom", 2: "Canada", 3: "Australia", 4: "India", 5: "Nigeria"}
label2id = {"United States": 0, "United Kingdom": 1, "Canada": 2, "Australia": 3, "India": 4, "Nigeria": 5}


In [6]:
model_path = 'my_awesome_model'

model = AutoModelForSequenceClassification.from_pretrained("Twitter/twhin-bert-base", num_labels=6, id2label=id2label, label2id=label2id)
model.to(device)
model.eval()
model.zero_grad()

# load tokenizer
tokenizer = AutoTokenizer.from_pretrained('Twitter/twhin-bert-base')

def preprocess_function(examples):
    label = examples["country"] 
    examples = tokenizer(examples["tweet_text"], truncation=True, padding="max_length", max_length=256, return_tensors='pt')
    for key in examples:
        examples[key] = examples[key].squeeze(0)
    examples["label"] = label
    return examples

for split in ds:
    ds[split] = ds[split].map(preprocess_function, remove_columns=['user_id', 'tweet_id', 'tweet_text', 'country'])
    ds[split].set_format('pt')


Some weights of the model checkpoint at Twitter/twhin-bert-base were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at Twitter/twhin-bert-ba

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Loading cached processed dataset at C:\Users\Steve Toner\.cache\huggingface\datasets\csv\default-a72a0ce3803e3720\0.0.0\cache-e61a526aad7c9132.arrow


In [12]:
import evaluate

accuracy = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return {"accuracy": accuracy.compute(predictions=predictions, references=labels), "f1":f1_metric.compute(predictions=predictions, references=labels, average="weighted")}

In [13]:
import torch
class TwitterTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        # print ("inputs: ", inputs)
        outputs = model(**inputs)
        logits = outputs.get("logits")
        loss = torch.nn.functional.cross_entropy(logits, labels)
        return (loss, outputs) if return_outputs else loss

In [14]:
from transformers import TrainingArguments
from transformers import Trainer

training_args = TrainingArguments(
    output_dir="my_awesome_model",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

trainer = TwitterTrainer(
    model=model,
    args=training_args,
    train_dataset=ds["train"],
    eval_dataset=ds["validation"],
    compute_metrics=compute_metrics
)


In [18]:
  
def forward_func(inputs, position = 0):
    """
        Wrapper around prediction method of pipeline
    """
    pred = model(inputs, attention_mask=torch.ones_like(inputs))
    return pred[position]
    
def visualize(inputs: list, attributes: list):
    """
        Visualization method.
        Takes list of inputs and correspondent attributs for them to visualize in a barplot
    """
    attr_sum = attributes.sum(-1) 
    
    attr = attr_sum / torch.norm(attr_sum)
    
    a = pd.Series(attr.numpy()[0], 
                        index = tokenizer.convert_ids_to_tokens(inputs.detach().numpy()[0]))
    
    plt.show(a.plot.barh(figsize=(10,20)))
                    
def explain(text: str):
    """
        Main entry method. Passes text through series of transformations and through the model. 
        Calls visualization method.
    """
    prediction = trainer.predict(text)
    inputs = generate_inputs(text)
    baseline = generate_baseline(sequence_len = inputs.shape[1])
    
    lig = LayerIntegratedGradients(forward_func, getattr(model, 'Twitter/twhin-bert-base').embeddings)
    
    attributes, delta = lig.attribute(inputs=inputs,
                                baselines=baseline,
                                target = model.config.label2id[prediction[0]['label']], 
                                return_convergence_delta = True)
    
    visualize(inputs, attributes, prediction)
    
def generate_inputs(text: str):
    """
        Convenience method for generation of input ids as list of torch tensors
    """
    return torch.tensor(tokenizer.encode(text, add_special_tokens=False), device = device).unsqueeze(0)

def generate_baseline(sequence_len: int):
    """
        Convenience method for generation of baseline vector as list of torch tensors
    """        
    return torch.tensor([tokenizer.cls_token_id] + [tokenizer.pad_token_id] * (sequence_len - 2) + [tokenizer.sep_token_id], device = device).unsqueeze(0)


In [16]:
example = next(iter(ds_test))

In [22]:

viz_ds = preprocess_function(example)

In [24]:
viz_ds

{'input_ids': tensor([    0,  6561,  1556,   163,    15,   246, 18770,    16,   136,  1632,
         1556, 22758,  1520, 27012,    15, 47924, 16852, 96967,    28, 46429,
           16,  3975,   696,    18,     5,   587, 11583,   127,   418, 16444,
          127,   169,   966,   127,   594,     2,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1, 

In [23]:
explain(viz_ds)

IndexError: list index out of range