In [None]:
# Import necessary libraries (stdlib first, then third-party)
import copy
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import torchmetrics
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import TQDMProgressBar
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchmetrics.functional.classification import binary_accuracy, binary_f1_score, binary_precision, binary_recall, binary_confusion_matrix
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [None]:
# Suppress warnings
warnings.filterwarnings('ignore')

In [None]:
# Release memory cached by PyTorch's CUDA allocator, then print the GPU status table.
# NOTE(review): `!nvidia-smi` is an IPython shell escape (not plain Python) and presumably
# errors out on machines without the NVIDIA driver. Nothing has been put on the GPU yet at
# this point in the notebook, so empty_cache() likely frees little here.
torch.cuda.empty_cache()
!nvidia-smi

In [None]:
# Turn off Hugging Face tokenizers' internal parallelism so it does not emit warning
# messages once DataLoader worker processes are started.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

In [None]:
# Choose the compute device: CUDA when it is available (GPU 0's name is reported), else the CPU.
# NOTE(review): `device` is not referenced again in this notebook; the Lightning Trainer
# presumably handles device placement itself.
if not torch.cuda.is_available():
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
else:
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))

In [None]:
# Load the labelled training data ('#' is a placeholder for the CSV path).
train = pd.read_csv(filepath_or_buffer='#')

In [None]:
# Hold out 20% of the labelled rows for validation; the fixed random_state makes the split reproducible.
# NOTE(review): consider stratify=train[<label column>] so both splits keep the same class balance.
train_data, val_data = train_test_split(train, random_state=42, test_size=0.2)

In [None]:
# Hyperparameters shared by the tokenizer, datasets, model and trainer
epochs = 3            # Trainer max_epochs
batch_size = 8        # examples per DataLoader batch
learning_rate = 2e-5  # initial Adam learning rate
max_length = 512      # every example is padded/truncated to this many tokens
num_labels = 2        # binary classification

In [None]:
# Load tokenizer and pretrained model
model_ckp = '#' # Model checkpoint: microsoft/Multilingual-MiniLM-L12-H384
tokenizer = AutoTokenizer.from_pretrained(model_ckp)
# Seed BEFORE loading the model: from_pretrained randomly initialises the classification
# head, and the seed_everything call next to the model construction runs too late to make
# that initialisation reproducible.
seed_everything(42, workers=True)
# Size the classification head from the num_labels hyperparameter rather than relying on
# whatever default the checkpoint config carries.
pretrained_model = AutoModelForSequenceClassification.from_pretrained(model_ckp, num_labels=num_labels)

In [None]:
# Cross-entropy on the raw classifier logits, averaged over the batch.
loss_fn = nn.CrossEntropyLoss(reduction='mean')

In [None]:
# Define the dataset class
class MyDataset(Dataset):
    """Map-style dataset that tokenizes one DataFrame row per item.

    Args:
        data: DataFrame holding the text and label columns. Rows are read positionally
            with ``iloc``, so a shuffled index (e.g. from ``train_test_split``) is fine.
        tokenizer: Hugging Face-style tokenizer, called with ``max_length``,
            ``padding='max_length'``, ``truncation=True`` and ``return_tensors='pt'``; it must
            return a mapping with ``'input_ids'`` and ``'attention_mask'``.
        max_length: Length every example is padded or truncated to.
        text_column: Name of the text column (defaults to the notebook's '#' placeholder).
        label_column: Name of the label column (defaults to the notebook's '#' placeholder).

    Each item is ``(input_ids, attention_mask, label)``: the first row of each tokenizer
    output tensor, plus the label as a scalar ``torch.long`` tensor.
    """

    def __init__(self, data, tokenizer, max_length, text_column='#', label_column='#'):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.text_column = text_column
        self.label_column = label_column

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Materialise the row once instead of calling iloc twice.
        row = self.data.iloc[index]
        # Calling the tokenizer directly replaces the legacy encode_plus API.
        encoding = self.tokenizer(
            row[self.text_column],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # int() accepts NumPy integers, plain Python ints (e.g. from an object-dtype column)
        # and bools; the previous `.astype('int64')` raised AttributeError on Python ints.
        label = torch.tensor(int(row[self.label_column]), dtype=torch.long)
        return encoding['input_ids'][0], encoding['attention_mask'][0], label

In [None]:
# Wrap the train/validation splits as tokenizing datasets.
# (The DataLoaders themselves are built by MyModel's *_dataloader hooks.)
val_dataset = MyDataset(val_data, tokenizer, max_length)
train_dataset = MyDataset(train_data, tokenizer, max_length)

In [None]:
# Define the model class
class MyModel(pl.LightningModule):
    """LightningModule that fine-tunes a Hugging Face sequence classifier on binary labels.

    Args:
        num_labels: Number of target classes, stored as ``self.num_classes``. The metrics
            used below are the binary variants, so this is expected to be 2.
        batch_size: Batch size used by the ``*_dataloader`` hooks.
        learning_rate: Initial learning rate for Adam.
        backbone: Optional model to fine-tune. Defaults to a deep copy of the
            notebook-level ``pretrained_model``.
        loss_function: Optional loss module. Defaults to the notebook-level ``loss_fn``.

    Every step expects ``(input_ids, attention_mask, label)`` batches, as produced by
    ``MyDataset``.
    """

    def __init__(self, num_labels, batch_size, learning_rate, backbone=None, loss_function=None):
        super().__init__()
        # Copy instead of aliasing the global backbone. Otherwise every instance (including
        # those built by load_from_checkpoint) shares one nn.Module, and loading a checkpoint
        # into `best_model` silently overwrites the weights held by `model` as well.
        self.model = copy.deepcopy(pretrained_model) if backbone is None else backbone
        self.num_classes = num_labels
        self.loss_function = loss_fn if loss_function is None else loss_function
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        # Per-stage buffers of (preds, labels). Validation/test metrics are computed once over
        # the whole split at epoch end. Logging per-batch F1/precision/recall from *_step and
        # letting Lightning average them does not give the split's F1/precision/recall.
        self._epoch_outputs = {"val": [], "test": []}

    def forward(self, input_ids, attention_mask):
        """Return the classifier logits for a batch of token ids and attention masks."""
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.logits

    def _make_loader(self, dataset, shuffle, drop_last):
        """Build a DataLoader for `dataset` with this module's batch size and 2 workers."""
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle,
                          num_workers=2, drop_last=drop_last)

    def train_dataloader(self):
        # Dropping the ragged final batch is harmless while training.
        return self._make_loader(train_dataset, shuffle=True, drop_last=True)

    def val_dataloader(self):
        # drop_last=False so every validation example contributes to the metrics.
        return self._make_loader(val_dataset, shuffle=False, drop_last=False)

    def test_dataloader(self):
        # drop_last=False so every test example contributes to the metrics.
        return self._make_loader(test_dataset, shuffle=False, drop_last=False)

    def _forward_batch(self, batch):
        """Run one batch through the model; return (loss, predicted labels, true labels)."""
        input_ids, attention_mask, label = batch
        logits = self(input_ids, attention_mask)
        loss = self.loss_function(logits, label)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, label

    def training_step(self, batch, batch_idx):
        loss, preds, label = self._forward_batch(batch)
        # Per-step values on the current batch only; intended for progress monitoring.
        self.log("train_loss", loss, prog_bar=True, logger=True)
        self.log("train_accuracy", binary_accuracy(preds, label), prog_bar=True, logger=True)
        self.log("train_f1", binary_f1_score(preds, label), prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, preds, label = self._forward_batch(batch)
        # Explicit batch_size so the epoch average weights the smaller final batch correctly.
        self.log("val_loss", loss, prog_bar=True, logger=True, batch_size=label.size(0))
        self._epoch_outputs["val"].append((preds.detach(), label.detach()))
        return loss

    def on_validation_epoch_end(self):
        collected = self._pop_epoch_outputs("val")
        if collected is not None:
            self._log_binary_metrics("val", *collected)

    def test_step(self, batch, batch_idx):
        loss, preds, label = self._forward_batch(batch)
        self.log("test_loss", loss, prog_bar=True, logger=True, batch_size=label.size(0))
        self._epoch_outputs["test"].append((preds.detach(), label.detach()))
        return loss

    def on_test_epoch_end(self):
        collected = self._pop_epoch_outputs("test")
        if collected is None:
            return
        self._log_binary_metrics("test", *collected)
        # One confusion matrix for the whole test set instead of one figure per batch.
        self._log_confusion_matrix(*collected)

    def _pop_epoch_outputs(self, stage):
        """Concatenate and clear the buffered batches of `stage`; None if nothing was buffered."""
        outputs = self._epoch_outputs[stage]
        if not outputs:
            return None
        preds = torch.cat([p for p, _ in outputs])
        label = torch.cat([t for _, t in outputs])
        outputs.clear()
        return preds, label

    def _log_binary_metrics(self, stage, preds, label):
        """Log accuracy, F1, precision and recall computed over an entire split."""
        self.log(f"{stage}_accuracy", binary_accuracy(preds, label), prog_bar=True, logger=True)
        self.log(f"{stage}_f1", binary_f1_score(preds, label), prog_bar=True, logger=True)
        self.log(f"{stage}_precision", binary_precision(preds, label), logger=True)
        self.log(f"{stage}_recall", binary_recall(preds, label), logger=True)

    def _log_confusion_matrix(self, preds, label):
        """Draw the 2x2 confusion matrix (rows: true label, columns: prediction) and log it."""
        confusion_matrix = binary_confusion_matrix(preds, label).cpu().numpy()
        df_cm = pd.DataFrame(confusion_matrix, index=range(2), columns=range(2))
        fig_, ax = plt.subplots(figsize=(10, 7))
        # fmt='d' shows integer counts (the default '.2g' turns counts >= 100 into 1e+02).
        sns.heatmap(df_cm, annot=True, fmt='d', cmap='Spectral', ax=ax)
        ax.set(xlabel='Predicted label', ylabel='True label', title='Test confusion matrix')
        experiment = getattr(self.logger, 'experiment', None)
        # add_figure is the TensorBoard writer's API; skip it for other loggers or no logger.
        if hasattr(experiment, 'add_figure'):
            experiment.add_figure('confusion matrix', fig_, global_step=self.current_epoch)
        plt.close(fig_)

    def configure_optimizers(self):
        """Adam with a ReduceLROnPlateau scheduler driven by ``val_loss``."""
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        # `verbose` is deprecated for torch LR schedulers, so it is no longer passed.
        # NOTE(review): with epochs = 3 as configured in this notebook, patience=3 can never
        # trigger a reduction.
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
        return {
            'optimizer': optimizer,
            'lr_scheduler': lr_scheduler,
            'monitor': 'val_loss',
        }

# Seed the Python / NumPy / torch RNGs (and DataLoader workers) for reproducible training.
seed_everything(seed=42, workers=True)

# Build the LightningModule from the hyperparameters defined earlier.
model = MyModel(num_labels=num_labels, batch_size=batch_size, learning_rate=learning_rate)

In [None]:
# Keep only the checkpoint with the lowest validation loss, under ./best_model/best_model*.ckpt.
# mode='min' and save_top_k=1 are ModelCheckpoint's defaults, spelled out for the reader.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath='./best_model/',
    filename='best_model',
    monitor='val_loss',
    mode='min',
    save_top_k=1,
)

In [None]:
# Initialize the PyTorch Lightning trainer.
# `gpus=` was removed in PyTorch Lightning 2.0; accelerator/devices is the supported spelling.
# accelerator="auto" also keeps the notebook runnable on a CPU-only machine.
trainer = pl.Trainer(
    max_epochs=epochs,
    accelerator="auto",
    devices=1,
    callbacks=[checkpoint_callback, TQDMProgressBar(refresh_rate=5)],
    deterministic=True,
)

# Train the model
trainer.fit(model, train_dataloaders=model.train_dataloader(), val_dataloaders=model.val_dataloader())

In [None]:
# Run one validation pass with the trainer on the validation split.
# NOTE(review): no ckpt_path is given, so this presumably scores the weights currently held by
# `model` (last epoch), not the best checkpoint. Confirm that is the intent.
trainer.validate(model=model, dataloaders=model.val_dataloader())

In [None]:
# Restore the checkpoint the callback selected as best on val_loss.
# The constructor arguments are passed explicitly because MyModel does not save its hyperparameters.
best_model = MyModel.load_from_checkpoint(
    checkpoint_callback.best_model_path,
    learning_rate=learning_rate,
    batch_size=batch_size,
    num_labels=num_labels,
)

In [None]:
# Load the three held-out test datasets ('#' entries are placeholders for the CSV paths).
test01, test02, test03 = (pd.read_csv(path) for path in ('#', '#', '#'))

In [None]:
# Evaluate the best checkpoint on each held-out test dataset.
# The checkpoint is loaded once: its weights are the same for every dataset, so re-loading it
# inside the loop only repeated the same work.
best_model = MyModel.load_from_checkpoint(
    checkpoint_callback.best_model_path,
    num_labels=num_labels,
    batch_size=batch_size,
    learning_rate=learning_rate,
)
test_results = {}  # dataset number -> list of metric dicts returned by trainer.test
for i, test_data in enumerate([test01, test02, test03], start=1):
    # Kept as the notebook-level name `test_dataset`, which MyModel.test_dataloader() reads.
    test_dataset = MyDataset(test_data, tokenizer, max_length)
    # drop_last=False: with drop_last=True up to batch_size - 1 examples per dataset were
    # silently left out of the reported test metrics.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, drop_last=False)
    print(f"Testing on Dataset {i}...")
    test_results[i] = trainer.test(best_model, dataloaders=test_loader)