# This notebook contains 4 sections: 
#### 1. Installation of libraries and model initialisation
#### 2. Training of model
#### 3. Loading of model
#### 4. Testing of model

#### To load and test the model, execute all cells sequentially, skipping the cells in the 'Training of model' section.
#### To train and test the model, execute all cells sequentially, skipping the cells in the 'Loading of model' section.

## Install and import the necessary libraries

In [None]:
# Use the %pip magic instead of !pip: %pip installs into the environment of
# the running kernel, whereas !pip runs whichever pip is first on the shell PATH.
%pip install -U evaluate
%pip install -U datasets
%pip install -U accelerate
%pip install -U transformers
%pip install scikit-learn matplotlib

In [None]:
import os
import torch
import numpy as np
import pandas as pd
import evaluate
import accelerate
import matplotlib.pyplot as plt
from data_preprocessing import CustomDataset
from transformers import AutoTokenizer, pipeline
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
from transformers import TrainerCallback, EarlyStoppingCallback
from models import BaseModel, CustomClassifier, SimpleTransformerModel, EnsembleModel
from safetensors.torch import load_file

# Print whether PyTorch can see a CUDA device (True/False).
print(torch.cuda.is_available())

### GradualUnfreezeCallback: The class that defines gradual unfreezing (no changes needed)
### EarlyStoppingCallback: The class that defines early stopping during training (no changes needed)

In [None]:
class GradualUnfreezeCallback(TrainerCallback):
    """Trainer callback that freezes the backbone and unfreezes it step by step.

    When ``enable_unfreezing`` is true, every parameter of the backbone is
    frozen at construction time and only the classification head stays
    trainable. At the start of each epoch, if the epoch has reached the next
    entry of ``unfreeze_schedule``, one more direct child module of the
    backbone (taken from last to first) is made trainable.

    Args:
        model: Wrapper model exposing a ``pretrained_model`` attribute and,
            optionally, a ``final_classifier`` attribute.
        enable_unfreezing: If false, the callback does nothing at all.
        total_epochs: Total number of training epochs (stored, not otherwise used).
        unfreeze_schedule: Sequence of epoch numbers; entry ``i`` is the epoch
            from which the ``i``-th child module (counting from the end) may be
            unfrozen. At most one module is unfrozen per epoch.
    """

    def __init__(self, model, enable_unfreezing, total_epochs, unfreeze_schedule):
        self.model = model
        self.enable_unfreezing = enable_unfreezing
        self.total_epochs = total_epochs
        self.unfreeze_schedule = unfreeze_schedule
        self.unfrozen_layers = 0  # Tracks the number of unfrozen layers

        # Freeze all layers except classification head initially
        if self.enable_unfreezing:
            base_model = self._resolve_base_model()
            # Compare against None explicitly: container modules define
            # __len__, so truthiness is not a reliable "is present" test.
            classifier = getattr(self.model.pretrained_model, "classifier", None)
            if classifier is None:
                classifier = getattr(self.model, "final_classifier", None)

            for param in base_model.parameters():
                param.requires_grad = False
            if classifier is not None:
                for param in classifier.parameters():
                    param.requires_grad = True

    def _resolve_base_model(self):
        """Return ``pretrained_model.base_model`` if it exists, else ``pretrained_model``."""
        pretrained = self.model.pretrained_model
        return getattr(pretrained, "base_model", pretrained)

    @staticmethod
    def _register_with_optimizer(optimizer, params):
        """Add newly trainable ``params`` to ``optimizer`` if it does not hold them yet.

        The Trainer builds its optimizer from the parameters that are trainable
        when training starts (see ``Trainer.create_optimizer``), so parameters
        unfrozen later would receive gradients but never be updated unless they
        are registered here.
        """
        if optimizer is None or not hasattr(optimizer, "add_param_group"):
            return

        # add_param_group rejects parameters that are already registered.
        known_ids = {id(p) for group in optimizer.param_groups for p in group["params"]}
        fresh = [p for p in params if id(p) not in known_ids]
        if not fresh:
            return

        new_group = {"params": fresh}
        if optimizer.param_groups:
            # Start at the currently scheduled learning rate, not the initial one.
            # NOTE(review): the LR scheduler was created before this group
            # existed, so it may not keep decaying this group's LR - confirm.
            new_group["lr"] = optimizer.param_groups[0]["lr"]
        optimizer.add_param_group(new_group)

    def on_init_end(self, args, state, control, **kwargs):
        """Required method to avoid the AttributeError."""
        pass  # No action needed on initialization

    def on_epoch_begin(self, args, state, control, **kwargs):
        """Unfreezes layers based on the predefined schedule."""
        if not self.enable_unfreezing or self.unfrozen_layers >= len(self.unfreeze_schedule):
            return  # Either all layers are unfrozen or unfreezing is disabled

        current_epoch = int(state.epoch or 0)  # Ensure it's an integer (None -> 0)
        next_unfreeze_epoch = self.unfreeze_schedule[self.unfrozen_layers]

        if current_epoch >= next_unfreeze_epoch:
            # Unfreeze one more layer; resolve the backbone the same way as in
            # __init__ so both code paths agree on what "the base model" is.
            layers = list(self._resolve_base_model().children())[::-1]  # Reverse list to start from last layers
            if self.unfrozen_layers < len(layers):
                newly_trainable = []
                for param in layers[self.unfrozen_layers].parameters():
                    param.requires_grad = True
                    newly_trainable.append(param)

                # The Trainer passes the live optimizer to callbacks via kwargs.
                self._register_with_optimizer(kwargs.get("optimizer"), newly_trainable)

                self.unfrozen_layers += 1
                print(f"Epoch {current_epoch}: Unfroze layer {self.unfrozen_layers}")

# Early-stopping callback used by train_model:
#   early_stopping_patience  = 2   -> number of evals to wait before stopping
#   early_stopping_threshold = 0.0 -> minimum delta to consider improvement
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=2, early_stopping_threshold=0.0)

### Define a model checkpoint (e.g. "distilroberta-base" or "bert-base-uncased")
This cell creates the preprocessed dataset based on the model checkpoint used.

In [None]:
model_checkpoint = "distilroberta-base" # EDIT THIS
max_len = 512  # NOTE(review): not referenced by any other cell visible in this notebook

# No changes needed to the code below in this cell
# Build the dataset wrapper for the chosen checkpoint and split it three ways.
dataset = CustomDataset('HateSpeechDatasetBalanced.csv', model_checkpoint=model_checkpoint, seed=43)
train_dataset, val_dataset, test_dataset = dataset.get_splits()

tokenizer = dataset.get_tokenizer()
# Use len(tokenizer), not tokenizer.vocab_size: vocab_size reports only the base
# vocabulary, while len() also counts added tokens. The pretrained-model cell
# resizes its embeddings to len(tokenizer) for the same reason, so a model
# sized from vocab_size must cover the same id range.
# NOTE(review): weights saved from a model built with the old, smaller
# vocab_size will no longer match this embedding shape - confirm before loading.
vocab_size = len(tokenizer)

print(train_dataset, val_dataset, test_dataset)

### Creation of a pretrained model
#### If you are intending to train/test the Simple Transformer Layer model (individual transformer layers), run the first cell
#### If you are intending to train/test the distilBERT/RoBERTa or BERT/RoBERTa models (any other pretrained ones), run the second cell

In [None]:
# Run this cell to build the model made of individual transformer layers (3 layers).
# The target device is chosen first; the freshly built model is then moved onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleTransformerModel(
    vocab_size,
    num_labels=2,
    dropout=0.1,
    num_layers=3,
).to(device)

In [None]:
# Run this cell to build the wrapper around a pretrained checkpoint.
# (Plain Hugging Face alternative, kept for reference:
#  AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2,
#      hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BaseModel(model_checkpoint, num_labels=2, hidden_dropout_prob=0.1).to(device)

# The tokenizer had new tokens added, so the embedding matrix must be resized to match.
embedding_rows = len(dataset.get_tokenizer())
model.resize_token_embeddings(embedding_rows)

## Training the model

### Run the following cells that define the train function and its arguments

In [None]:
# Metric loaded for compute_metrics, and the short checkpoint name used for the output dir.
metric_name = 'accuracy'
model_name = model_checkpoint.split("/")[-1]

total_epochs = 20
args = TrainingArguments(
    f"./snapshots/{model_name}-finetuned",
    # Evaluate and checkpoint once per epoch; keep at most 3 checkpoints on disk.
    eval_strategy = "epoch",
    save_strategy = "epoch",
    save_total_limit = 3,
    learning_rate=1e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=total_epochs,
    weight_decay=0.01,
    # The best checkpoint is selected by lowest validation loss (not by metric_name).
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    push_to_hub=False,
    # Mixed precision only when a GPU is present: the notebook falls back to CPU
    # elsewhere, and an unconditional fp16=True would break that fallback.
    fp16=torch.cuda.is_available()
)

In [None]:
# Metric object used by compute_metrics below.
metric = evaluate.load(metric_name)


def compute_metrics(eval_pred):
    """Turn an evaluation result into the metric dictionary.

    Args:
        eval_pred: Pair that unpacks into (model scores, reference labels).

    Returns:
        Whatever ``metric.compute`` returns for the arg-max class predictions.
    """
    scores, references = eval_pred
    predicted_classes = np.argmax(scores, axis=-1)  # highest-scoring class along the last axis
    return metric.compute(predictions=predicted_classes, references=references)

In [None]:
def train_model(model, args, train_dataset, val_dataset, enable_unfreezing, total_epochs, unfreeze_schedule):
    """Train ``model`` with the Hugging Face ``Trainer`` and return the trainer.

    Args:
        model: Model to train.
        args: ``TrainingArguments`` controlling the run.
        train_dataset: Dataset used for training.
        val_dataset: Dataset used for evaluation during training.
        enable_unfreezing: Forwarded to ``GradualUnfreezeCallback``.
        total_epochs: Forwarded to ``GradualUnfreezeCallback``.
        unfreeze_schedule: Forwarded to ``GradualUnfreezeCallback``.

    Returns:
        The ``Trainer`` instance after ``train()`` has finished.
    """
    # A fresh unfreezing callback per run, plus the notebook-level early-stopping callback.
    run_callbacks = [
        GradualUnfreezeCallback(model, enable_unfreezing, total_epochs, unfreeze_schedule),
        early_stopping_callback,
    ]

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        callbacks=run_callbacks,
    )
    trainer.train()
    return trainer

### Execute the cell below to initiate the training of the model. This may take some time.

In [None]:
# Start training; the returned Trainer is kept for plotting and saving.
# unfreeze_schedule is still passed, but enable_unfreezing=False disables the callback.
train_log = train_model(
    model=model,
    args=args,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    enable_unfreezing=False,
    total_epochs=total_epochs,
    unfreeze_schedule=[1, 2],
)

### Upon completion of the execution of the previous cell (model training), run the 2 cells below to plot the graphs of loss/accuracy vs epoch

In [None]:
def plot_losses(trainer):
    """Plot training loss and validation loss against epoch.

    Args:
        trainer: Object whose ``state.log_history`` is a list of log dicts.
            Only records containing ``"epoch"`` together with ``"loss"``
            (training) or ``"eval_loss"`` (validation) are plotted.
    """
    history = trainer.state.log_history

    def _series(key):
        # (epoch, value) for every record carrying both the key and an epoch.
        points = [(record["epoch"], record[key]) for record in history
                  if key in record and "epoch" in record]
        return [epoch for epoch, _ in points], [value for _, value in points]

    train_epochs, train_loss = _series("loss")
    val_epochs, val_loss = _series("eval_loss")

    plt.figure(figsize=(8, 5))
    for xs, ys, curve_label in (
        (train_epochs, train_loss, "Training Loss"),
        (val_epochs, val_loss, "Validation Loss"),
    ):
        plt.plot(xs, ys, label=curve_label)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Train vs Validation Loss")
    plt.legend()
    plt.grid(True)
    plt.show()


def plot_eval_metric(trainer, metric_name='eval_accuracy'):
    """Plot one logged evaluation metric against epoch.

    Args:
        trainer: Object whose ``state.log_history`` is a list of log dicts.
        metric_name: Key to look up in each log record; records lacking this
            key or ``"epoch"`` are skipped.
    """
    selected = [
        record for record in trainer.state.log_history
        if metric_name in record and "epoch" in record
    ]
    metric_vals = [record[metric_name] for record in selected]
    epochs = [record["epoch"] for record in selected]
    print(metric_vals, epochs)

    heading = metric_name.upper()  # reused for the legend, y-axis label and title

    plt.figure(figsize=(8, 5))
    plt.plot(epochs, metric_vals, label=f"{heading} Score")
    plt.xlabel("Epoch")
    plt.ylabel(heading)
    plt.title(f"{heading} Over Epochs")
    plt.legend()
    plt.grid(True)
    plt.show()

In [None]:
# Uncomment the next line to inspect the raw log records before plotting.
#print(train_log.state.log_history)
# Loss curves first, then the evaluation metric (default key: 'eval_accuracy').
plot_losses(train_log)
plot_eval_metric(train_log)

### Run the cell below to save the model, edit the path name based on your needs

In [None]:
# train_log is the Trainer returned by train_model; replace the placeholder path before running.
train_log.save_model("./models/CHANGE_THIS") # for saving your model

## Loading the model (before testing)
### Assign `weights_location` to the location of the weights (a '.safetensors' file) and run the cell below to load the model

In [None]:
# Path to the saved weights; load_file reads the .safetensors file into a state dict.
weights_location = "model.safetensors" #EDIT THIS
state_dict = load_file(weights_location)
# strict=False tolerates keys that are missing from or unexpected in the file.
# The returned report of such keys is this cell's output, so check it for surprises.
# `model` must already exist from one of the model-creation cells above.
model.load_state_dict(state_dict, strict=False)

## Testing the model

### Run the following 2 cells to test the model

In [None]:
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score
from torch.utils.data import Dataset, DataLoader

class CustomTorchDataset(Dataset):
    """Torch ``Dataset`` view over an indexable dataset of tokenised examples.

    Each item of the wrapped dataset must provide ``'input_ids'``,
    ``'attention_mask'``, ``'label'`` and ``'text'``. The first three are
    returned as ``torch.long`` tensors; ``'text'`` is passed through unchanged.
    """

    # Fields converted to long tensors (as required by nn.Embedding), in output order.
    _LONG_FIELDS = ('input_ids', 'attention_mask', 'label')

    def __init__(self, hf_dataset):
        self.dataset = hf_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        row = self.dataset[idx]
        sample = {
            field: torch.tensor(row[field], dtype=torch.long)
            for field in self._LONG_FIELDS
        }
        sample['text'] = row['text']
        return sample

In [None]:
# Wrap the held-out split for torch and serve it in batches of 32 (DataLoader's default order).
test_torch_dataset = CustomTorchDataset(hf_dataset=test_dataset)
test_loader = DataLoader(dataset=test_torch_dataset, batch_size=32)

# Evaluation function
def evaluate_model(model, dataloader, device):
    """Evaluate ``model`` on every batch of ``dataloader``.

    Args:
        model: Model called as ``model(input_ids=..., attention_mask=..., labels=...)``
            whose output supports ``['logits']`` and ``['loss']``.
        dataloader: Iterable of batches with keys ``'input_ids'``,
            ``'attention_mask'``, ``'label'`` and ``'text'``.
        device: Device the model and the batch tensors are moved to.

    Returns:
        Tuple ``(avg_loss, accuracy, f1, results_df)``: the per-example average
        loss, accuracy, macro-averaged F1, and a DataFrame with the columns
        ``text``, ``label`` and ``prediction``.

    Raises:
        ValueError: If ``dataloader`` yields no examples.
    """
    model.eval()
    model.to(device)

    all_preds, all_labels, all_texts = [], [], []
    loss_sum, example_count = 0.0, 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            texts = batch['text']

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            logits = outputs['logits']
            loss = outputs['loss']

            preds = torch.argmax(logits, dim=-1)

            # Weight each batch loss by its size so a smaller final batch does
            # not skew the average.
            # NOTE(review): assumes the model returns a mean-reduced batch loss - confirm.
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            example_count += batch_size

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_texts.extend(texts)

    if example_count == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError.
        raise ValueError("evaluate_model received an empty dataloader; nothing to evaluate")

    avg_loss = loss_sum / example_count
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='macro')

    results_df = pd.DataFrame({
        'text': all_texts,
        'label': all_labels,
        'prediction': all_preds
    })

    return avg_loss, accuracy, f1, results_df

# Run evaluation
# Score the current `model` on the test loader, then report the summary and a sample of rows.
loss, acc, f1, df_results = evaluate_model(model=model, dataloader=test_loader, device=device)

summary_line = f"Loss: {loss:.4f}, Accuracy: {acc:.4f}, F1: {f1:.4f}"
print(summary_line)
print(df_results.head())