# Obtenir une prédiction à partir d'un checkpoint

In [None]:
!pip install lightning
!pip install -U 'wandb>=0.12.10'
!pip install pytorch-lightning --quiet

In [6]:
# importing libraries
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
import lightning as L
import torch
import torchmetrics
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast as BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, multilabel_confusion_matrix

import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc

In [7]:
class ToxicCommentTagger(L.LightningModule):
    """BERT encoder followed by a linear head for multi-label classification.

    The module feeds BERT's pooled output through a single linear layer and
    applies a sigmoid, so each class receives an independent score in [0, 1].

    Args:
        n_classes: Number of output labels (width of the linear head).
        n_training_steps: Total number of optimizer steps, used by the linear
            learning-rate schedule. When ``None`` no scheduler is configured.
        n_warmup_steps: Number of warm-up steps of the schedule. ``None`` is
            treated as 0.

    Note:
        ``BERT_MODEL_NAME`` is read from module scope, so it must be defined
        before the class is instantiated (including via
        ``load_from_checkpoint``).
    """

    def __init__(self, n_classes: int, n_training_steps=None, n_warmup_steps=None):
        super().__init__()
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME, return_dict=True)
        self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)
        self.n_training_steps = n_training_steps
        self.n_warmup_steps = n_warmup_steps
        # The loss is computed on raw logits: BCEWithLogitsLoss fuses the
        # sigmoid with the cross-entropy, which is numerically more stable
        # than sigmoid + BCELoss and is also usable under autocast.
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the encoder and the classification head.

        Args:
            input_ids: Token ids passed to BERT.
            attention_mask: Attention mask passed to BERT.
            labels: Optional targets with the same shape as the output
                (presumably float multi-hot vectors -- confirm against the
                dataset code).

        Returns:
            A ``(loss, probabilities)`` tuple. ``loss`` is the integer ``0``
            when ``labels`` is ``None``; ``probabilities`` are the
            sigmoid-activated class scores.
        """
        encoded = self.bert(input_ids, attention_mask=attention_mask)
        logits = self.classifier(encoded.pooler_output)
        probabilities = torch.sigmoid(logits)
        loss = 0
        if labels is not None:
            loss = self.criterion(logits, labels)
        return loss, probabilities

    def _shared_step(self, batch, stage):
        """Run a forward pass on ``batch`` and log ``"<stage>_loss"``.

        Returns:
            A ``(loss, outputs, labels)`` tuple.
        """
        labels = batch["labels"]
        loss, outputs = self(batch["input_ids"], batch["attention_mask"], labels)
        self.log(f"{stage}_loss", loss, prog_bar=True, logger=True)
        return loss, outputs, labels

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch.

        Returns:
            A dict with the ``loss`` plus the batch ``predictions`` and
            ``labels``.
        """
        loss, outputs, labels = self._shared_step(batch, "train")
        return {"loss": loss, "predictions": outputs, "labels": labels}

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        loss, _, _ = self._shared_step(batch, "val")
        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss for one batch."""
        loss, _, _ = self._shared_step(batch, "test")
        return loss

    def on_train_epoch_end(self, *args):
        """Do nothing at the end of a training epoch.

        The hook is kept so that the class interface is unchanged. The
        per-class ROC-AUC logging that used to live here was permanently
        disabled and relied on names that this file does not define, so the
        unreachable code has been removed.
        """

    def configure_optimizers(self):
        """Build the AdamW optimizer and, when possible, the LR schedule.

        Returns:
            The optimizer alone when ``n_training_steps`` is ``None`` (a
            linear schedule cannot be built without a step count); otherwise
            a dict holding the optimizer and a per-step linear
            warm-up/decay scheduler.
        """
        # torch.optim.AdamW replaces the deprecated transformers.AdamW; eps
        # and weight_decay are set to that class's former defaults so the
        # optimizer configuration stays the same.
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=2e-5, eps=1e-6, weight_decay=0.0
        )

        if self.n_training_steps is None:
            return optimizer

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.n_warmup_steps or 0,
            num_training_steps=self.n_training_steps,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                # Step the scheduler after every optimizer step, not per epoch.
                "interval": "step",
            },
        }
     

In [16]:
# Name of the pretrained backbone; read by ToxicCommentTagger.__init__ and
# used to load the matching tokenizer.
BERT_MODEL_NAME = "bert-base-multilingual-cased"

# Label names; pipeline() pairs them positionally with the model outputs.
classes = [
    "Compression",
    "Explanation",
    "Modulation",
    "Omission",
    "Substitution",
    "Synonymy",
    "Syntactic",
    "Transcription",
    "Transposition",
]

# n_classes is passed explicitly so the classification head is rebuilt with
# one output per label before the checkpoint weights are loaded.
model = ToxicCommentTagger.load_from_checkpoint(
    "/home/onyxia/work/genai-for-public-good/notebooks/lightning_logs/version_3/checkpoints/epoch=0-step=7.ckpt",
    n_classes=len(classes),
)
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)

# Put the module in evaluation mode (disables dropout and other
# training-only behaviour).
model.eval()

ToxicCommentTagger(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, eleme

In [25]:
def pipeline(text: str) -> dict:
    """Score a single text against every label in ``classes``.

    The text is tokenized with the module-level ``tokenizer``, run through
    the module-level ``model`` and the resulting sigmoid scores are paired
    positionally with the names in ``classes``.

    Args:
        text: The text to classify.

    Returns:
        A dict mapping each label name to its score as a Python ``float``.
    """
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        max_length=512,
        # Without truncation, inputs longer than max_length would exceed
        # BERT's 512 position embeddings and fail in the forward pass.
        truncation=True,
        return_token_type_ids=False,
        padding="max_length",
        return_attention_mask=True,
        return_tensors="pt",
    )
    input_ids = encoding["input_ids"].to(model.device)
    attention_mask = encoding["attention_mask"].to(model.device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        _, probabilities = model(input_ids, attention_mask)

    scores = probabilities.flatten().cpu().tolist()
    return dict(zip(classes, scores))

# Sample French sentence used as the input of the pipeline() call below.
test_comment = "Une adhésion de 15 € / an à l'association est demandée lors de l'inscription."


In [26]:
# Score the sample sentence; the cell output is the label -> score mapping.
pipeline(test_comment)

{'Compression': 0.4562477767467499,
 'Explanation': 0.5056126713752747,
 'Modulation': 0.5023998022079468,
 'Omission': 0.5226243138313293,
 'Substitution': 0.5072996020317078,
 'Synonymy': 0.4382634162902832,
 'Syntactic': 0.4507046043872833,
 'Transcription': 0.5091156959533691,
 'Transposition': 0.4874115288257599}