### import

In [None]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import rcParams
rcParams['text.usetex'] = True
rcParams['font.family'] = 'serif'
rcParams['font.serif'] = ['Times New Roman']

import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report
from torchmetrics import AUROC
from torchmetrics.functional import accuracy
from sklearn.metrics import f1_score, roc_auc_score, roc_curve, auc, confusion_matrix, accuracy_score
from sklearn import metrics
from torchmetrics import F1Score
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from torch.nn.utils import clip_grad_norm_
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
import torch.nn.functional as F
import torch
import torch.nn as nn
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from transformers import get_linear_schedule_with_warmup
from transformers import AdamW
from pytorch_lightning import Trainer
from transformers import AutoModel, AutoTokenizer
from transformers import BertTokenizerFast as AdamW, get_linear_schedule_with_warmup
import pytorch_lightning as pl
from torchmetrics.classification import F1Score
from sklearn.metrics import precision_recall_fscore_support
import torch.nn.functional as F
from sklearn.metrics import classification_report
import torchmetrics
from torchmetrics import Metric
import torchmetrics.classification
import re
from pytorch_lightning import LightningModule
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from torchmetrics import Accuracy, MetricCollection, Precision, Recall, F1Score, Metric
from sklearn.metrics import confusion_matrix
from torchmetrics.classification import BinaryF1Score, BinaryConfusionMatrix
from torch import tensor
from typing import Optional
from pytorch_lightning.callbacks import Callback
import time
import io
import random
from PIL import Image
import numpy as np
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning import Trainer, seed_everything


In [None]:
# Colour palette shared by the plots in this notebook (hex RGB strings).

# Per-split colours.
train_df_color = '#E9BFC0'          #light red
val_df_color = '#C3D5E7'            #light blue
test_df_color = '#F4D98F'           #light yellow
threshold_color = '#DC96AE'         #pink

# Per-class ROC line colours (used in LiteNorbertClassifierArch.on_test_end).
real_line_color = '#93c47d'         #green
fake_line_color = '#E9BFC0'         #light red
maybe_line_color = '#F4D98F'        #light yellow

# Per-class marker colours: darker shades of the line colours above.
real_dot_color = '#56943a'          #dark green
fake_dot_color = '#8a292c'          #dark red
maybe_dot_color = '#d69d00'         #dark yellow

# Colours for the logistic-regression plots (not used in the visible cells).
logreg_line_color = '#7db364'       #green
logreg_fill_color = '#97c283'       #lighter green
logreg_dot_color = '#48802e'        #darker green

# Colours for the binary-classifier plots (not used in the visible cells).
binary_line_color = '#aeb1cf'       #purple
binary_dot_color = '#8a8db6'        #darker purple
binary_fill_color = '#bec0d8'       #3 tint lighter

# Colours for the multiclass mean-ROC plot.
multiclass_line_color = '#a2c9ef'   #blue
multiclass_fill_color = '#7a8c9f'   #grey-blue
multiclass_dot_color = '#1e456b'    #dark blue

# Light-to-dark blue ramp passed to sns.heatmap for the confusion matrix.
colors = ['#d1e5f0', '#94c4f3', '#63a7e0', '#358ece', '#0a6ab7', '#084f8a'] #blues for CM
cmap = sns.color_palette(colors, as_cmap=True)

In [None]:
def preprocess_text(text):
    """Normalise a raw claim string before tokenisation.

    The text is lower-cased, then user tags, links, HTML entities and all
    characters outside a small whitelist (letters incl. æøå, digits,
    whitespace and ``: / £ $ % € , -``) are removed, and whitespace is
    collapsed to single spaces.

    Args:
        text: The raw text. ``None`` and float NaN (what pandas yields for an
            empty cell) are treated as an empty string; any other non-string
            value is converted with ``str``.

    Returns:
        The cleaned, lower-cased string (possibly empty).
    """
    # NaN is the only float that is not equal to itself.
    if text is None or (isinstance(text, float) and text != text):
        return ""
    text = str(text).lower()
    text = re.sub(r"@\w*", " ", text)                                           #tags
    # "&lt;3" must be listed before "&lt;": alternation takes the first match,
    # so the shorter entity would otherwise win and leave a stray "3" behind.
    text = re.sub(r"#|^\*|\*$|&quot;|&gt;|&lt;3|&lt;", " ", text)               #special characters
    text = text.replace("&amp;", " and ")                                       #&
    text = re.sub(r"ht+p+s?://\S*", " ", text)                                  #links
    # Drop everything outside the whitelist, plus a free-standing "å".
    text = re.sub(r"[^a-zA-Z0-9æøåÆØÅ\d\s:/£$%€,-]|(?<!\S)å(?!\S)", "", text)   #ascii
    text = re.sub(r",(?!\d)", "", text)                                         #comma (kept only before a digit)
    # NOTE(review): "." and all brackets are already removed by the whitelist
    # filter above, so the next two substitutions currently never match.
    # Confirm whether they were meant to run before that filter.
    text = re.sub(r"(\S)\.(\S)", r"\1. \2", text)                               #period
    text = re.sub(r"\(\)|\[\]|\{\}", " ", text)                                 #brackets
    text = re.sub(r"\s[b-gjz]\s", " ", text)                                    #single chars
    text = " ".join(text.split()).strip()                                       #spaces
    return text

In [None]:
# Fix the random seed for reproducible runs.
RANDOMSEED = 42
seed_everything(RANDOMSEED, workers=True) #for dataloaders
# Let float32 matrix multiplications use a lower internal precision
# (trades some numerical precision for speed on supporting GPUs).
torch.set_float32_matmul_precision('medium')

In [None]:
# Load the train / validation / test splits and the additional PFO test set.
train_df = pd.read_csv('/home/thombras/files/HUG_MULTI_LIAR/upsampled_train_df.csv')
val_df = pd.read_csv('/home/thombras/files/HUG_MULTI_LIAR/val_with_labels.csv')
test_df = pd.read_csv('/home/thombras/files/HUG_MULTI_LIAR/test_with_labels.csv')
PFO_test_df = pd.read_csv('/home/thombras/files/PFO/PFO_3_test_df.csv')
# Label columns are selected by position (every column from index 15 on).
# NOTE(review): this assumes all four CSVs contain these column names — confirm.
LABEL_COLUMNS_CLAS = train_df.columns.tolist()[15:] #fake, maybe, real

In [None]:
# Clean the claim text of every split in place with preprocess_text.
for _split_frame in (train_df, val_df, test_df, PFO_test_df):
    _split_frame['claim_no'] = _split_frame['claim_no'].apply(preprocess_text)
# Notebook display: shapes of both test sets and the label column names.
test_df.shape, PFO_test_df.shape, LABEL_COLUMNS_CLAS

### norbert3

In [None]:
# Load the pretrained NorBERT3 tokenizer and encoder from the Hugging Face hub.
BASE_MODEL = 'ltg/norbert3-base' #LTG = Language Technology Group (University of Oslo)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# NOTE(review): trust_remote_code=True runs model code shipped with the hub
# repository; consider pinning a revision.
norbert3 = AutoModel.from_pretrained(BASE_MODEL, trust_remote_code=True, output_hidden_states=True)

### encode

In [None]:
# Data / training constants used by the cells below.
BATCH_SIZE = 32                             # batch size of every DataLoader
MAX_TOKEN_COUNT = 60                        # claims are padded / truncated to this many tokens
N_EPOCHS = 20                               # upper bound on training epochs (Trainer max_epochs)
N_WORKERS = 2                               # DataLoader worker processes
INPUT_SIZE = norbert3.config.hidden_size    # encoder hidden size = input width of the classifier head

class data_encode(Dataset):
    """Map-style dataset that tokenises one claim per item.

    Args:
        data: Frame with a ``claim_no`` text column and the label columns.
        tokenizer: Callable Hugging Face tokenizer used to encode the text.
        MAX_TOKEN_COUNT: Every claim is padded / truncated to this length.
        labels: Optional sequence of label column names. When omitted, the
            module-level ``LABEL_COLUMNS_CLAS`` is used.
    """

    def __init__(self, data: pd.DataFrame,
                 tokenizer,
                 MAX_TOKEN_COUNT: int = 128,
                 labels=None):
        self.tokenizer = tokenizer
        self.data = data
        self.max_token_len = MAX_TOKEN_COUNT
        self.label_columns = None if labels is None else list(labels)

    def __len__(self):
        """Return the number of rows (claims) in the frame."""
        return len(self.data)

    def __getitem__(self, index: int):
        """Return text, token ids, attention mask and label vector of row ``index``."""
        data_row = self.data.iloc[index]
        text = data_row.claim_no
        # Resolve the global default lazily so it only has to exist when used.
        label_columns = LABEL_COLUMNS_CLAS if self.label_columns is None else self.label_columns
        labels = data_row[label_columns]
        # Use the tokenizer handed to this dataset (not the notebook global);
        # calling it directly replaces the deprecated ``encode_plus``.
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_token_len,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return dict(
            text=text,
            # return_tensors='pt' adds a leading batch axis of size 1; drop it
            # so that the DataLoader can stack items itself.
            input_ids=encoding["input_ids"].flatten(),
            attention_mask=encoding["attention_mask"].flatten(),
            # The row slice has object dtype (mixed columns), so cast explicitly.
            labels=torch.tensor(labels.to_numpy(dtype="float32")))

In [None]:
# Wrap every split in a data_encode dataset (same tokenizer and length limit).
train_dataset, val_dataset, test_dataset, PFO_test_dataset = (
    data_encode(split_frame, tokenizer, MAX_TOKEN_COUNT)
    for split_frame in (train_df, val_df, test_df, PFO_test_df)
)

# Only the training loader shuffles; the evaluation loaders keep the row order.
train_loader = DataLoader(
    dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=N_WORKERS)
val_loader = DataLoader(
    dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=N_WORKERS)
test_loader = DataLoader(
    dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=N_WORKERS)
PFO_test_loader = DataLoader(
    dataset=PFO_test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=N_WORKERS)

In [None]:
# Step bookkeeping for the learning-rate schedule.
# The training DataLoader keeps the last, partial batch (drop_last is not set),
# so the number of optimizer steps per epoch is the number of batches it
# yields, not len(train_df) // BATCH_SIZE.
steps_per_epoch = len(train_loader)
total_training_steps = steps_per_epoch * N_EPOCHS
WARMUP_STEPS = total_training_steps // 5                # first fifth of all steps
TRAINING_STEPS = total_training_steps - WARMUP_STEPS    # steps remaining after warm-up

print(f'Epochs:{N_EPOCHS:15}', f'\nBatch size:{BATCH_SIZE:12}')
print('-----------------------------')
print(f'Dataset len:{len(train_dataset):13}')
print(f'Steps per epoch:{steps_per_epoch:8}')
print('-----------------------------')
print(f'Total training steps:{total_training_steps:8}')
print('Warmup is 1/5 of the total training steps:')
print(f'- warmup steps:{WARMUP_STEPS:8}', f'\n- training steps:{TRAINING_STEPS:8}')

### architecture

In [None]:
class LiteNorbertClassifierArch(pl.LightningModule):
    """Multiclass classifier: a transformer encoder followed by an MLP head.

    The hidden state of the first token of the encoder output is passed
    through three Linear -> ReLU -> Dropout stages and a final Linear layer
    that yields ``n_classes`` logits. Targets are expected as one row per
    sample with one column per class (they are given to
    ``nn.CrossEntropyLoss`` unchanged and reduced with ``argmax`` for the
    metrics).

    Args:
        base_model: Encoder whose output exposes ``last_hidden_state``.
        input_layer: Width of the encoder hidden state.
        fc1, fc2, fc3: Widths of the three hidden layers of the head.
        n_classes: Number of output classes. The plots produced in
            ``on_test_end`` are written for the three classes
            Fake / Maybe / Real.
        drop_rate: Dropout probability used after every hidden layer.
        n_training_steps: ``num_training_steps`` given to the linear
            warm-up schedule.
        n_warmup_steps: ``num_warmup_steps`` given to the schedule.
        lr, wd, eps: Learning rate, weight decay and epsilon of AdamW.
    """

    def __init__(self, base_model,
                 input_layer=768, fc1=512, fc2=256, fc3=64, n_classes=3,
                 drop_rate=0.2, n_training_steps=None, n_warmup_steps=None,
                 lr=1e-5, wd=1e-2, eps=1e-8):
        super().__init__()
        self.DROP_RATE = drop_rate
        self.N_WARMUP_STEPS = n_warmup_steps
        self.N_TRAINING_STEPS = n_training_steps
        self.LR = lr
        self.WD = wd
        self.EPS = eps
        self.save_hyperparameters(ignore=['base_model']) #save params
        self.bert = base_model
        self.classification_head = nn.Sequential(
            nn.Linear(input_layer, fc1), nn.ReLU(), nn.Dropout(self.DROP_RATE), #fc1
            nn.Linear(fc1, fc2), nn.ReLU(), nn.Dropout(self.DROP_RATE),         #fc2
            nn.Linear(fc2, fc3), nn.ReLU(), nn.Dropout(self.DROP_RATE),         #fc3
            nn.Linear(fc3, n_classes))
        self.loss_func = nn.CrossEntropyLoss()
        # The metrics are nn.Module attributes, so they follow the module to
        # whatever device it is moved to. No hard-coded .cuda() is needed,
        # and the model can also be built on a machine without a GPU.
        self.accuracy = torchmetrics.classification.Accuracy(task='multiclass', num_classes=n_classes, average='weighted')
        self.precision = torchmetrics.classification.Precision(task='multiclass', num_classes=n_classes, average='weighted')
        self.recall = torchmetrics.classification.Recall(task='multiclass', num_classes=n_classes, average='weighted')
        self.f1_score = torchmetrics.classification.F1Score(task='multiclass', num_classes=n_classes, average='weighted')
        self.confusion_matrix = torchmetrics.classification.MulticlassConfusionMatrix(num_classes=n_classes)
        self.roc = torchmetrics.classification.ROC(task='multiclass', num_classes=n_classes)
        # Per-batch probabilities / targets, kept for inspection (train) and
        # for the report in on_test_end (test).
        self.train_y_prob = []
        self.train_labels = []
        self.test_y_prob = []
        self.test_labels = []

    def forward(self, input_ids, attention_mask, labels=None):
        """Return the class logits; ``labels`` is accepted but not used."""
        output = self.bert(input_ids, attention_mask=attention_mask)
        pooler = output.last_hidden_state[:, 0, :]   #CLS
        logits = self.classification_head(pooler)
        return logits

    def _recurrent_step(self, batch, batch_idx):
        """Run one forward pass; return ``(loss, logits, labels, y_prob)``."""
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        logits = self.forward(input_ids, attention_mask, labels=labels)
        y_prob = torch.softmax(logits, dim=1)
        loss = 0
        if labels is not None:
            loss = self.loss_func(logits, labels)
        return loss, logits, labels, y_prob

    def _calculate_metrics(self, batch, batch_idx, labels=None, y_prob=None):
        """Return ``(accuracy, precision, recall, f1)`` for one batch.

        When ``labels`` and ``y_prob`` of an already executed forward pass
        are supplied they are used directly; otherwise the batch is run
        through the model (the original behaviour).
        """
        if labels is None or y_prob is None:
            _, _, labels, y_prob = self._recurrent_step(batch, batch_idx)
        pred = y_prob.argmax(dim=1)
        label_arg = labels.argmax(dim=1)
        acc = self.accuracy(pred, label_arg)
        prec = self.precision(pred, label_arg)
        rec = self.recall(pred, label_arg)
        f1 = self.f1_score(pred, label_arg)
        return acc, prec, rec, f1

    def _step_with_metrics(self, batch, batch_idx):
        """Run a single forward pass and derive loss and metrics from it."""
        loss, _, labels, y_prob = self._recurrent_step(batch, batch_idx)
        # Detach: the stored / logged probabilities must not keep the autograd
        # graph of the step alive.
        y_prob = y_prob.detach()
        metrics = self._calculate_metrics(batch, batch_idx, labels=labels, y_prob=y_prob)
        return loss, labels, y_prob, metrics

    def on_train_epoch_start(self):
        # Keep only the outputs of the current epoch so that the buffers do
        # not grow over the whole training run.
        self.train_y_prob = []
        self.train_labels = []

    def training_step(self, batch, batch_idx):
        """Compute the loss, log the training metrics and buffer the outputs."""
        loss, labels, y_prob, (acc, prec, rec, f1) = self._step_with_metrics(batch, batch_idx)
        self.log_dict({'train_accuracy': acc, 'train_precision': prec,
                       'train_recall': rec, 'train_f1': f1},
                      logger=True, on_step=True, on_epoch=True, prog_bar=False)

        self.train_y_prob.append(y_prob)
        self.train_labels.append(labels)
        self.log('train_loss', loss, prog_bar=True, on_epoch=True, on_step=True)
        return {'loss': loss, 'predictions': y_prob, 'labels': labels}

    def configure_optimizers(self):
        """Return AdamW with a per-step linear warm-up / linear decay schedule."""
        # torch.optim.AdamW is referenced explicitly: the notebook's imports
        # rebind the bare name ``AdamW`` to ``BertTokenizerFast``.
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.LR, weight_decay=self.WD, eps=self.EPS)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.N_WARMUP_STEPS,
            num_training_steps=self.N_TRAINING_STEPS)
        return dict(optimizer=optimizer,
                    lr_scheduler=dict(scheduler=scheduler,
                                      interval='step'))

    def on_training_epoch_end(self, batch=None, batch_idx=None):
        # Placeholder kept for compatibility; it does nothing.
        # NOTE(review): Lightning's own hook is named on_train_epoch_end, so
        # this method is presumably never invoked by the Trainer.
        pass

    def validation_step(self, batch, batch_idx):
        """Log validation loss, accuracy and F1 for one batch."""
        loss, _, _, (acc, _, _, f1) = self._step_with_metrics(batch, batch_idx)
        self.log_dict({'val_accuracy': acc, 'val_f1': f1},
                      prog_bar=False, logger=True, on_step=False, on_epoch=True)
        self.log('val_loss', loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Log the test metrics and buffer the outputs for on_test_end."""
        loss, labels, y_prob, (acc, prec, rec, f1) = self._step_with_metrics(batch, batch_idx)
        self.log_dict({'test_accuracy': acc, 'test_precision': prec, 'test_recall': rec, 'test_f1': f1},
                      prog_bar=False, logger=True, on_step=True, on_epoch=True)
        self.test_y_prob.append(y_prob)
        self.test_labels.append(labels)
        self.log('test_loss', loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss

    def on_train_end(self):
        pass

    def _plot_class_roc(self, y_prob, target, class_names):
        """Plot one ROC curve per class and return the ``(fpr, tpr)`` pairs."""
        line_colors = [fake_line_color, maybe_line_color, real_line_color]
        dot_colors = [fake_dot_color, maybe_dot_color, real_dot_color]
        curves = []

        plt.figure(figsize=(8, 5))
        for i, (name, line_color, dot_color) in enumerate(zip(class_names, line_colors, dot_colors)):
            fpr, tpr, thresholds = roc_curve(target[:, i], y_prob[:, i])
            curves.append((fpr, tpr))
            roc_auc = auc(fpr, tpr)
            print('roc_auc', roc_auc)
            # Threshold with the largest distance between TPR and FPR.
            best_threshold_index = np.argmax(tpr - fpr)
            # The class name and AUC make the legend entries distinguishable.
            plt.plot(fpr, tpr, lw=2, color=line_color,
                     label=f'{name} ROC curve (AUC = {roc_auc:.2f})')
            plt.plot(fpr[best_threshold_index], tpr[best_threshold_index], 'o', color=dot_color,
                     label=f'Best Threshold: {thresholds[best_threshold_index]:.2f}')

        plt.plot([0, 1], [0, 1], color='gray', lw=1, linestyle='--', label='Reference Line')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate', fontsize=12)
        plt.ylabel('True Positive Rate', fontsize=12)
        plt.title('ROC for Fake-NorBERT-multi', fontsize=15)
        plt.legend(loc="lower right")
        plt.tight_layout()
        plt.show()
        return curves

    def _plot_mean_roc(self, curves):
        """Plot the average of the per-class ROC curves with its AUROC."""
        all_fpr = np.linspace(0, 1, 100)
        # Interpolate every curve onto a common FPR grid before averaging.
        mean_tpr = sum(np.interp(all_fpr, fpr, tpr) for fpr, tpr in curves) / len(curves)
        mean_auc = auc(all_fpr, mean_tpr)

        plt.figure(figsize=(8, 5))
        plt.plot(all_fpr, mean_tpr, lw=2, color=multiclass_line_color, linestyle='--', label='ROC curve')
        plt.fill_between(all_fpr, mean_tpr, color=multiclass_fill_color, alpha=0.3)
        plt.text(0.5, 0.5, f'AUROC = {mean_auc:.2f}', fontsize=13, ha='right', va='top')
        plt.plot([0, 1], [0, 1], color='gray', lw=1, linestyle='--', label='Reference Line')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate', fontsize=12)
        plt.ylabel('True Positive Rate', fontsize=12)
        plt.title('ROC for Fake-NorBERT-multi', fontsize=15)
        plt.legend(loc="upper left")
        plt.tight_layout()
        plt.show()

    def _plot_confusion_matrix(self, y_truth, y_pred, class_names):
        """Plot the confusion matrix as an annotated heat map."""
        class_ids = list(range(len(class_names)))
        # labels= keeps the matrix square even if a class never occurs.
        conf_matrix = confusion_matrix(y_truth, y_pred, labels=class_ids)
        # Cell text: "True <class>" on the diagonal, "False <predicted class>"
        # elsewhere, followed by the count.
        labels = np.array([
            [f"{'True' if truth == pred else 'False'} {class_names[pred]}\n{conf_matrix[truth][pred]}"
             for pred in class_ids]
            for truth in class_ids
        ])
        plt.figure(figsize=(8, 5))
        sns.heatmap(conf_matrix, annot=labels, cmap=cmap, fmt="", annot_kws={"size": 13}, cbar=False)
        plt.ylabel("True Class", fontsize=12)
        plt.xlabel("Predicted Class", fontsize=12)
        plt.title("CM for Fake-NorBERT-multi", fontsize=15)
        plt.show()

    def on_test_end(self):
        """Plot ROC curves and the confusion matrix and print a class report."""
        # Take the buffers over first, so that they are reset even if one of
        # the plots below raises.
        batch_probs, batch_labels = self.test_y_prob, self.test_labels
        self.test_y_prob = [] #reset arrays
        self.test_labels = []
        if not batch_probs:
            return  # no test batch was recorded; nothing to report

        test_y_prob_total = torch.cat(batch_probs, dim=0)
        test_labels_total = torch.cat(batch_labels, dim=0).long()

        print('test_y_prob_total', test_y_prob_total)
        print('test_labels_total', test_labels_total)

        class_names = ["Fake", "Maybe", "Real"]
        y_prob_cpu = test_y_prob_total.cpu().detach().numpy()
        target_cpu = test_labels_total.cpu().detach().numpy()

        #ROC
        curves = self._plot_class_roc(y_prob_cpu, target_cpu, class_names)

        #Mean AUROC
        self._plot_mean_roc(curves)

        #CM
        predicted_classes_cpu = test_y_prob_total.argmax(dim=1).cpu().numpy()
        y_truth_cpu = test_labels_total.argmax(dim=1).cpu().numpy()
        self._plot_confusion_matrix(y_truth_cpu, predicted_classes_cpu, class_names)

        #CLASSIFICATION REPORT
        print(classification_report(y_truth_cpu, predicted_classes_cpu,
                                    labels=list(range(len(class_names))),
                                    target_names=class_names))

class TrainingTime(Callback):
    """Callback that prints how long a training run took, in seconds."""

    def on_train_start(self, trainer, NB3RegressorArch):
        # Remember the wall-clock time at which fitting began.
        self.start = time.time()

    def on_train_end(self, trainer, NB3RegressorArch):
        # Remember when fitting stopped and report the difference.
        self.end = time.time()
        elapsed_seconds = self.end - self.start
        print(f'\nTraining finished: {elapsed_seconds:.2f} seconds\n')
        

### fine-tune

In [None]:
# Metric (and its direction) that drives checkpointing and early stopping.
MONITOR_METRIC = 'val_loss'
MON_MODE = 'min'

# Widths of the three hidden layers of the classification head.
FC1=1024
FC2=768
FC3=512

TUNER_LR = 3.31e-5
LR = TUNER_LR       # learning rate
WD = 1e-2           # AdamW weight decay
EPS = 1e-8          # AdamW epsilon
DR = 0.8            # dropout probability in the classification head
GC = 0              # NOTE(review): not passed to the Trainer in this cell
DEVICE = 'gpu'      # Trainer accelerator
PATIENCE = 1        # early-stopping patience (validation checks without improvement)
TOP_K = 1           # number of best checkpoints to keep
DEV_RUN = 0         # fast_dev_run setting (0 = off)

TB_LOGGER = TensorBoardLogger('/home/light_logs',
            name='multi')

checkpoint_callback = ModelCheckpoint(
  dirpath = '/home/checkpoints',
  filename = 'best_multic',
  save_top_k = TOP_K,
  verbose = True,
  monitor = MONITOR_METRIC,
  mode = MON_MODE)

LR_MONITOR = LearningRateMonitor(logging_interval='epoch')

# Pass the mode explicitly so that early stopping and checkpointing always
# agree on the direction of improvement when MONITOR_METRIC is changed.
early_stopping_callback = EarlyStopping(
    monitor=MONITOR_METRIC,
    mode=MON_MODE,
    patience=PATIENCE)

nb3 = norbert3
clas_model = LiteNorbertClassifierArch(
    nb3,
    input_layer=INPUT_SIZE,
    fc1=FC1,
    fc2=FC2,
    fc3=FC3,
    drop_rate=DR,
    n_warmup_steps=WARMUP_STEPS,
    # get_linear_schedule_with_warmup expects the TOTAL number of steps
    # (warm-up included). With TRAINING_STEPS (total minus warm-up) the
    # learning rate would already reach zero after 80 % of the schedule.
    n_training_steps=total_training_steps,
    lr=LR,
    wd=WD,
    eps=EPS)

trainer = pl.Trainer(
    fast_dev_run=DEV_RUN,
    logger=TB_LOGGER,
    min_epochs=1,
    max_epochs=N_EPOCHS,
    accelerator = DEVICE,
    callbacks=[checkpoint_callback, early_stopping_callback,
               TrainingTime(), LR_MONITOR]
)


print("\nFast_dev_run =", trainer.fast_dev_run)
print("Max epoch:", N_EPOCHS)
print("Batch size:", BATCH_SIZE)

In [None]:
# Release cached, currently unused GPU memory held by PyTorch before training.
torch.cuda.empty_cache()

### evaluate and test

In [None]:
# Fine-tune the classifier; validation runs on val_loader and feeds the
# checkpoint and early-stopping callbacks configured on the trainer.
trainer.fit(clas_model, train_loader, val_loader)

In [None]:
# Re-run validation with the weights currently held by clas_model.
# NOTE(review): after early stopping these are presumably the last-epoch
# weights, not necessarily the best checkpoint — confirm this is intended.
trainer.validate(clas_model, val_loader)

In [None]:
# Evaluate the best checkpoint of the run above.
# checkpoint_callback writes "best_multic.ckpt" (not "best_multi.ckpt"), so
# use the path it recorded. If no checkpoint was recorded in this session,
# fall back to that file name.
m_version = checkpoint_callback.best_model_path or '/home/checkpoints/best_multic.ckpt'
trainer.test(clas_model, dataloaders=test_loader, ckpt_path=m_version)
# Available evaluation loaders:
#test_loader
#PFO_test_loader