# Multi-label Text Classification with BERT and PyTorch Lightning

## Installing and importing libraries

In [None]:
!pip install transformers lightning torchmetrics 'wandb>=0.12.10' 'tensorboard' --quiet

In [None]:
# importing libraries
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
import lightning as L
import torch
import torchmetrics
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast as BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, multilabel_confusion_matrix

import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc

In [None]:
%matplotlib inline
%config InlineBackend.figure_format='retina'

RANDOM_SEED = 42
import lightning as L
sns.set(style='whitegrid', palette='muted', font_scale=1.2)
HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#ADFF02", "#8F00FF"]
sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))
rcParams['figure.figsize'] = 12, 8
pl.seed_everything(RANDOM_SEED)

## Loading the Data

In [None]:
import pandas as pd

# Annotated corpus; later cells read its "Version initiale",
# "Version retraitée" and "Catégorie" columns.
file_path = 'ri_annotated_texts_final.csv'
data = pd.read_csv(file_path)
# text cleaning
def clean_text(text: str):
    """Trim surrounding whitespace and drop one leading "- " list marker.

    The text is stripped first, so the dash is only removed when it is the
    first non-blank character and is followed by at least one whitespace
    character; that whitespace is removed together with the dash.
    """
    import re

    return re.sub(r"^-\s+", "", text.strip())
# Apply the same cleaning rule to both text columns.
for text_column in ("Version initiale", "Version retraitée"):
    data[text_column] = data[text_column].apply(clean_text)

# One row per distinct initial text: keep the first reworked version and
# concatenate the group's categories into a single ", "-separated string.
data = (
    data.groupby("Version initiale")
    .agg({"Version retraitée": "first", "Catégorie": lambda x: ", ".join(x)})
    .reset_index()
)

# Sorted vocabulary of category names and the two lookup tables built on it.
every_category = ", ".join(data["Catégorie"]).split(", ")
classes = sorted(set(every_category))
class2id = {name: position for position, name in enumerate(classes)}
id2class = dict(enumerate(classes))

# Per-row list of class ids, in the order the names appear in "Catégorie".
data["classes"] = [
    [class2id[name] for name in joined.split(", ")]
    for joined in data["Catégorie"]
]
data



In [None]:
# Step 2: MultiLabel Binarization
from sklearn.preprocessing import MultiLabelBinarizer

# Turn each row's list of class ids into a fixed-width 0/1 indicator vector.
mlb = MultiLabelBinarizer()
labels_binarized = mlb.fit_transform(data['classes'])

# Model input: both versions of a text joined by a literal " [SEP] " marker.
initial_version = data['Version initiale'].str.strip()
reworked_version = data['Version retraitée'].str.strip()
texts = initial_version + " [SEP] " + reworked_version

# Step 3: Train-test split
# 20% of the rows are held out; the fixed random_state keeps the split stable.
split_parts = train_test_split(
    texts,
    labels_binarized,
    test_size=0.2,
    random_state=42,
)
train_texts, val_texts, train_labels, val_labels = split_parts

## Preprocessing

## Tokenization

In [None]:
# Load the tokenizer that belongs to the selected pretrained checkpoint.
# Flip the flag to switch from the multilingual checkpoint to "bert-base-cased".
USE_BASE_CASED = False
BERT_MODEL_NAME = "bert-base-cased" if USE_BASE_CASED else "bert-base-multilingual-cased"
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)

## Wrapping the tokenization process in a PyTorch Dataset

In [None]:
class ToxicCommentsDataset(Dataset):
    """Map-style dataset pairing each text with its label vector.

    Texts are tokenized lazily in ``__getitem__`` and padded or truncated to
    ``max_token_len`` so every item has the same sequence length.
    """

    def __init__(self, texts, labels, tokenizer: BertTokenizer, max_token_len: int = 128 * 4):
        """Store the samples and the tokenizer settings.

        Args:
            texts: Iterable of input strings.
            labels: Iterable of per-sample label vectors, aligned with ``texts``.
            tokenizer: Tokenizer called on each text.
            max_token_len: Length every encoded sequence is padded/truncated to.

        Raises:
            ValueError: If ``texts`` and ``labels`` do not have the same length.
        """
        # Materialise as tuples so items are addressed by position, whatever
        # index the incoming containers (e.g. a shuffled pandas Series) carry.
        self.texts = tuple(texts)
        self.labels = tuple(labels)
        if len(self.texts) != len(self.labels):
            raise ValueError(
                f"texts and labels must have the same length, "
                f"got {len(self.texts)} and {len(self.labels)}"
            )
        self.tokenizer = tokenizer
        self.max_token_len = max_token_len

    def __len__(self):
        """Return the number of samples."""
        return len(self.texts)

    def __getitem__(self, index: int):
        """Tokenize sample ``index``.

        Returns:
            dict with the raw ``comment_text``, the 1-D ``input_ids`` and
            ``attention_mask`` tensors, and ``labels`` as a float32 tensor.
        """
        comment_text = self.texts[index]
        labels = self.labels[index]

        # Calling the tokenizer directly is the documented entry point; for a
        # single string it performs the same encoding as ``encode_plus``.
        encoding = self.tokenizer(
            comment_text,
            add_special_tokens=True,
            max_length=self.max_token_len,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )

        return dict(
            comment_text=comment_text,
            # return_tensors="pt" adds a leading batch axis of size 1; drop it.
            input_ids=encoding["input_ids"].flatten(),
            attention_mask=encoding["attention_mask"].flatten(),
            # Explicit dtype conversion instead of the legacy FloatTensor
            # constructor, which treats a bare integer as a size.
            labels=torch.as_tensor(labels, dtype=torch.float32),
        )


In [None]:
train_dataset = ToxicCommentsDataset(train_texts, train_labels, tokenizer)

In [None]:
sample_item = train_dataset[0]

In [None]:
sample_item.keys()

In [None]:
sample_item["comment_text"]

In [None]:
sample_item["labels"]

In [None]:
sample_item["input_ids"].shape

## Loading the BERT Model

In [None]:
bert_model = BertModel.from_pretrained(BERT_MODEL_NAME, return_dict=True)

In [None]:
sample_item["input_ids"].unsqueeze(dim=0).shape

In [None]:
prediction = bert_model(sample_item["input_ids"].unsqueeze(dim=0), sample_item["attention_mask"].unsqueeze(dim=0))

In [None]:
prediction.last_hidden_state.shape, prediction.pooler_output.shape

### Wrapping our custom dataset into a LightningDataModule

In [None]:
class ToxicCommentsDataModule(L.LightningDataModule):
    """Lightning data module serving the training and validation splits.

    ``setup`` reads the module-level ``train_texts``/``train_labels`` and
    ``val_texts``/``val_labels`` and wraps them in ``ToxicCommentsDataset``.
    """

    # NOTE(review): the ``batch_size=0`` default is kept for interface
    # compatibility, but DataLoader rejects a batch size of 0 -- callers must
    # pass a positive value.
    def __init__(self, tokenizer, batch_size=0, max_token_len=128 * 4):
        """
        Args:
            tokenizer: Tokenizer forwarded to the datasets.
            batch_size: Batch size used by both dataloaders.
            max_token_len: Sequence length forwarded to the datasets.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_token_len = max_token_len

    def setup(self, *args, **kwargs):
        """Build the datasets; any arguments passed by the caller are ignored."""
        # max_token_len is now forwarded; previously it was stored on the
        # module but the datasets silently fell back to their own default.
        self.train_dataset = ToxicCommentsDataset(
            train_texts, train_labels, self.tokenizer, self.max_token_len
        )
        self.test_dataset = ToxicCommentsDataset(
            val_texts, val_labels, self.tokenizer, self.max_token_len
        )

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4,
        )

    def val_dataloader(self):
        """Return the validation loader over the held-out split."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            # Evaluation must not depend on sample order, so no shuffling.
            shuffle=False,
            num_workers=4,
        )

In [None]:
N_EPOCHS = 10
BATCH_SIZE = 3

data_module = ToxicCommentsDataModule(tokenizer, batch_size=BATCH_SIZE)
data_module.setup()

## Modeling

### Evaluation

In [None]:
# Binary cross-entropy on probabilities; the raw scores below are passed
# through a sigmoid before being given to this criterion.
criterion = nn.BCELoss()

# Hand-written raw scores for six labels.
prediction = torch.tensor(
    [10.95873564, 1.07321467, 1.58524066, 0.03839076, 15.72987556, 1.09513213],
    dtype=torch.float32,
)

# Matching ground truth: labels 0 and 4 are active.
labels = torch.tensor([1., 0., 0., 0., 1., 0.], dtype=torch.float32)

In [None]:
torch.sigmoid(prediction)

In [None]:
output = criterion(torch.sigmoid(prediction), labels)
output

### Adapting the BERT representation to a classification task and packing it into a LightningModule

In [None]:
import math

# The training DataLoader keeps the final partial batch (drop_last is left at
# its default), so an epoch has ceil(N / B) batches. Floor division
# undercounted the steps and made the linear schedule reach zero before
# training actually ended.
steps_per_epoch = math.ceil(len(train_texts) / BATCH_SIZE)
total_training_steps = steps_per_epoch * N_EPOCHS

In [None]:
warmup_steps = total_training_steps // 5
warmup_steps, total_training_steps

In [None]:
from torcheval.metrics import MultilabelAUPRC

class ToxicCommentTagger(L.LightningModule):
  """Multi-label classifier: BERT pooled output -> linear layer -> sigmoid.

  The encoder is the module-level ``bert_model``, registered here as the
  ``bert`` submodule. Registration makes it part of ``self.parameters()`` (so
  the optimizer actually updates it), lets ``.to()``, ``.eval()`` and
  ``.freeze()`` reach it, and includes its weights in checkpoints. Because
  the same object is shared, code that moves ``bert_model`` directly keeps
  working.
  """

  def __init__(self, n_classes: int, n_training_steps=None, n_warmup_steps=None):
    """
    Args:
      n_classes: Number of output labels.
      n_training_steps: Total number of scheduler steps.
      n_warmup_steps: Number of warm-up steps of the linear schedule.
    """
    super().__init__()
    self.bert = bert_model
    self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)
    self.n_training_steps = n_training_steps
    self.n_warmup_steps = n_warmup_steps
    self.criterion = nn.BCELoss()
    # Validation AUPRC is accumulated over the whole epoch; scoring each tiny
    # batch on its own and averaging gives a very noisy estimate.
    self.val_auprc = MultilabelAUPRC(num_labels=n_classes, average=None)
    self._val_auprc_has_data = False

  def forward(self, input_ids, attention_mask, labels=None):
    """Return ``(loss, probabilities)``; ``loss`` is 0 when no labels are given."""
    output = self.bert(input_ids, attention_mask=attention_mask)
    output = self.classifier(output.pooler_output)
    output = torch.sigmoid(output)
    loss = 0
    if labels is not None:
      loss = self.criterion(output, labels)
    return loss, output

  def _shared_step(self, batch):
    """Run the forward pass on a batch dict and return ``(loss, outputs, labels)``."""
    labels = batch["labels"]
    loss, outputs = self(batch["input_ids"], batch["attention_mask"], labels)
    return loss, outputs, labels

  def training_step(self, batch, batch_idx):
    """Log ``train_loss`` and return the loss with predictions and labels."""
    loss, outputs, labels = self._shared_step(batch)
    self.log("train_loss", loss, prog_bar=True, logger=True)
    # Predictions are detached so the returned dict does not keep the
    # autograd graph alive; only "loss" is needed for the backward pass.
    return {"loss": loss, "predictions": outputs.detach(), "labels": labels}

  def validation_step(self, batch, batch_idx):
    """Log ``val_loss`` and feed the batch into the epoch-level AUPRC."""
    loss, outputs, labels = self._shared_step(batch)
    self.log("val_loss", loss, prog_bar=True, logger=True)
    # The metric object is not a submodule, so it stays on CPU.
    self.val_auprc.update(outputs.detach().cpu(), labels.detach().cpu())
    self._val_auprc_has_data = True
    return loss

  def on_validation_epoch_end(self):
    """Log one ``aupc_<class>`` value per class for the finished epoch."""
    if not self._val_auprc_has_data:
      return
    per_class = self.val_auprc.compute().tolist()
    for name_class, score in zip(classes, per_class):
      self.log(f"aupc_{name_class}", score, logger=True)
    self.val_auprc.reset()
    self._val_auprc_has_data = False

  def test_step(self, batch, batch_idx):
    """Log ``test_loss`` for the batch and return it."""
    loss, _, _ = self._shared_step(batch)
    self.log("test_loss", loss, prog_bar=True, logger=True)
    return loss

  def on_train_epoch_end(self, *args):
    """No-op hook, kept so the class interface is unchanged.

    The previous body was disabled with ``if 0:`` and referenced names that
    are not defined in the visible notebook (``outputs``, ``LABEL_COLUMNS``,
    ``auroc``), so it has been removed.
    """

  def configure_optimizers(self):
    """Return AdamW plus a per-step linear warm-up/decay schedule."""
    # torch.optim.AdamW replaces the deprecated transformers.AdamW; eps and
    # weight_decay are pinned to the values that implementation defaulted to,
    # so the optimisation hyper-parameters stay the same.
    optimizer = torch.optim.AdamW(
      self.parameters(),
      lr=2e-5,
      eps=1e-6,
      weight_decay=0.0,
    )

    scheduler = get_linear_schedule_with_warmup(
      optimizer,
      num_warmup_steps=self.n_warmup_steps,
      num_training_steps=self.n_training_steps
    )

    return dict(
      optimizer=optimizer,
      lr_scheduler=dict(
        scheduler=scheduler,
        # The schedule is defined in optimizer steps, not epochs.
        interval='step'
      )
    )
     

In [None]:
model = ToxicCommentTagger(
  n_classes=len(classes),
  n_warmup_steps=warmup_steps,
  n_training_steps=total_training_steps 
)

## Training

In [None]:
# The Trainer in this notebook comes from the `lightning` package, so the
# callback is taken from the same package rather than from the separate
# `pytorch_lightning` distribution (the two should not be mixed).
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

# TensorBoard event files are written under tb_logs/my_model.
logger = TensorBoardLogger("tb_logs", name="my_model")

# Keeps the two checkpoints with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
  dirpath="./checkpoints",
  filename="best-checkpoint",
  save_top_k=2,
  verbose=True,
  monitor="val_loss",
  mode="min"
)

In [None]:
# Imported from `lightning`, the package the Trainer in this notebook comes
# from, instead of the separate `pytorch_lightning` distribution.
from lightning.pytorch.callbacks import EarlyStopping

# Monitors val_loss with a patience of 2.
early_stopping_callback = EarlyStopping(monitor='val_loss', patience=2)

In [None]:
import lightning as L
from lightning.pytorch.loggers import WandbLogger

# Alternative logger; created but not handed to the Trainer below.
wandb_logger = WandbLogger(project="MNIST")

# max_epochs uses N_EPOCHS because the scheduler's total step count is
# derived from the same constant; a separate literal could silently drift
# from the learning-rate schedule.
# NOTE: checkpoint_callback and early_stopping_callback are not attached here.
trainer = L.Trainer(
  logger=logger,
  max_epochs=N_EPOCHS,
)

In [None]:
import os

# Turn off the tokenizers library's own parallelism before the DataLoader
# worker processes are started.
os.environ["TOKENIZERS_PARALLELISM"]="false"

# bert_model is the encoder used by the tagger's forward pass. Put it on the
# GPU only when one exists, so the cell also runs on CPU-only machines instead
# of failing on a hard-coded "cuda".
fit_device = "cuda" if torch.cuda.is_available() else "cpu"
bert_model.to(fit_device)
trainer.fit(model, data_module)
bert_model.to("cpu")


In [None]:
model.freeze()
model.eval()

## Predictions

In [None]:
test_comment = "Une adhésion de 15 € / an à l'association est demandée lors de l'inscription."

In [None]:
# Encode the single test sentence with the same settings the dataset uses.
# truncation=True is explicit so an over-long text is cut to max_length
# deliberately rather than through the tokenizer's implicit fallback.
encoding = tokenizer.encode_plus(
  test_comment,
  add_special_tokens=True,
  max_length=512,
  return_token_type_ids=False,
  padding="max_length",
  truncation=True,
  return_attention_mask=True,
  return_tensors='pt',
)

In [None]:
# Use the GPU when available instead of assuming one exists.
inference_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(inference_device)
# bert_model is the encoder used by the tagger's forward pass; move it
# explicitly so it always sits on the same device as the inputs.
bert_model.to(inference_device)

# Inference only: skip building the autograd graph.
with torch.no_grad():
  _, test_prediction = model(
    encoding["input_ids"].to(inference_device),
    encoding["attention_mask"].to(inference_device),
  )

model.to("cpu")
bert_model.to("cpu")
test_prediction = test_prediction.detach().cpu().flatten().numpy()

# One probability per class, in the order of `classes`.
for label, prediction in zip(classes, test_prediction):
  print(f"{label}: {prediction}")

In [None]:
test_prediction

In [None]:
# Toy batch for trying out the metrics: two samples with nine scores each.
a = torch.tensor(
    [
        [0.47587207, 1., 0.48619938, 0.5708344, 0.46402225,
         0.45795614, 0.46034777, 0.53842413, 0.5092292],
        [0.57587207, 1., 0.48619938, 0.5708344, 0.46402225,
         0.65795614, 0.66034777, 0.32413, 0.5092292],
    ],
    dtype=torch.float32,
)

# Matching 0/1 targets, stored as floats.
b = torch.tensor(
    [
        [1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 0, 1, 0, 0, 1, 0, 1],
    ],
    dtype=torch.float32,
)

In [None]:
!pip install torcheval
from torcheval.metrics import MultilabelAccuracy
metric = MultilabelAccuracy(threshold=0.5, criteria="hamming")

In [None]:
metric.update(a, b).compute()

In [None]:
from torcheval.metrics import MultilabelAUPRC
metric_auprc = MultilabelAUPRC(num_labels=len(classes), average=None)
h = metric_auprc.update(a, b).compute()

In [None]:
h.tolist()

In [None]:
for k in h:
    print(k.item())

## Evaluation

In [None]:
MAX_TOKEN_COUNT = 512

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trained_model = model.to(device)
# bert_model is the encoder used by the tagger's forward pass; keep it on the
# same device as the inputs and in evaluation mode.
bert_model.to(device)
bert_model.eval()
trained_model.eval()

# The dataset takes (texts, labels, tokenizer, max_token_len). The previous
# call passed an undefined `val_df` and put the tokenizer in the labels slot.
val_dataset = ToxicCommentsDataset(
  val_texts,
  val_labels,
  tokenizer,
  max_token_len=MAX_TOKEN_COUNT
)

predictions = []
labels = []

# Inference only: no autograd graph is needed for the validation pass.
with torch.no_grad():
  for item in tqdm(val_dataset):
    # Items are single samples, so add a batch axis before the forward pass.
    _, prediction = trained_model(
      item["input_ids"].unsqueeze(dim=0).to(device),
      item["attention_mask"].unsqueeze(dim=0).to(device)
    )
    predictions.append(prediction.flatten())
    labels.append(item["labels"].int())

# Stack into (num_samples, num_labels) tensors on CPU for the metric cells.
predictions = torch.stack(predictions).detach().cpu()
labels = torch.stack(labels).detach().cpu()

In [None]:
# `accuracy` and `THRESHOLD` were never defined in this notebook; use the
# torchmetrics functional multilabel accuracy instead.
from torchmetrics.functional.classification import multilabel_accuracy

# Probability above which a label counts as predicted; 0.5 is the threshold
# already used with MultilabelAccuracy earlier in the notebook.
THRESHOLD = 0.5

multilabel_accuracy(predictions, labels, num_labels=len(classes), threshold=THRESHOLD)

In [None]:
# `auroc` and `LABEL_COLUMNS` were never defined in this notebook; compute the
# per-label AUROC with torchmetrics and name the rows after `classes`.
from torchmetrics.functional.classification import multilabel_auroc

print("AUROC per tag")
# average=None returns one score per label instead of a single aggregate.
per_tag_auroc = multilabel_auroc(
  predictions, labels, num_labels=len(classes), average=None
)
for name, tag_auroc in zip(classes, per_tag_auroc):
  print(f"{name}: {tag_auroc}")

In [None]:
# Probability above which a label counts as predicted. `THRESHOLD` was not
# defined anywhere in the notebook; 0.5 is the threshold already used with
# MultilabelAccuracy earlier.
THRESHOLD = 0.5

y_pred = predictions.numpy()
y_true = labels.numpy()

upper, lower = 1, 0

# Binarise the probabilities so they can be compared with the 0/1 targets.
y_pred = np.where(y_pred > THRESHOLD, upper, lower)

# The class names come from `classes`; the previously referenced
# `LABEL_COLUMNS` does not exist in this notebook.
print(classification_report(
  y_true,
  y_pred,
  target_names=classes,
  zero_division=0
))

We now know that the model makes mistakes on the tags that have few samples.