## Load Packages

In [None]:
import torch

import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.nn.functional import cross_entropy

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning as pl
from pytorch_lightning import LightningModule

from transformers import GPT2Tokenizer, GPT2Config

from transformers import get_linear_schedule_with_warmup
from transformers import GPT2LMHeadModel
from datasets import load_dataset, DatasetDict

from tqdm import tnrange

import matplotlib.pyplot as plt




## Data

In [None]:
class XSumPreprocessor:
    """Turn XSum (document, summary) pairs into fixed-length GPT-2 training sequences.

    Each example is encoded as ``<bos> document <sep> summary <eos>``, truncated and
    padded to ``max_length`` tokens. Constructing the preprocessor registers the four
    special tokens on ``tokenizer`` in place.

    Args:
        tokenizer: Hugging Face tokenizer to extend with the special tokens and to
            use for encoding.
        max_length: Length every encoded example is truncated/padded to.
    """

    def __init__(self, tokenizer, max_length):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Register the special tokens on the tokenizer (mutates the passed-in object).
        self.special_tokens_dict = {'bos_token': '<bos>', 'eos_token': '<eos>', 'sep_token': '<sep>', 'pad_token': '<pad>'}
        # Count of tokens newly added to the vocabulary, as reported by the tokenizer.
        self.num_added_toks = self.tokenizer.add_special_tokens(self.special_tokens_dict)

    def preprocess(self, example):
        """Encode one example (a mapping with "document" and "summary" keys).

        Returns:
            Whatever ``tokenizer.encode_plus`` returns; for Hugging Face tokenizers
            this mapping contains ``input_ids`` and ``attention_mask``.
        """
        tokens = self.special_tokens_dict
        text = f'{tokens["bos_token"]} {example["document"]} {tokens["sep_token"]} {example["summary"]} {tokens["eos_token"]}'
        # Use the tokenizer this instance was built with (and extended), never a module global.
        return self.tokenizer.encode_plus(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
        )

    def filter(self, dataset):
        """Return a list of the samples whose ``input_ids`` still contain ``<sep>``.

        Truncation to ``max_length`` can cut the separator (and therefore the summary)
        off long documents; such samples have no target summary and are dropped.
        """
        sep_id = self.tokenizer.sep_token_id
        return [sample for sample in dataset if sep_id in sample['input_ids']]

In [None]:
# Load the pre-trained GPT-2 tokenizer (vocabulary)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# Every encoded example is truncated/padded to exactly this many tokens
max_length=1024

# Building the preprocessor also adds <bos>/<eos>/<sep>/<pad> to `tokenizer` in place
preprocessor = XSumPreprocessor(
                tokenizer = tokenizer,
                max_length = max_length
)



# Use only a slice of XSum: use_percent% of train, twice that for validation and test
use_percent = 3
dataset_train = load_dataset("xsum", split=f"train[:{use_percent}%]")
dataset_val = load_dataset("xsum", split=f"validation[:{use_percent*2}%]")
dataset_test = load_dataset("xsum", split=f"test[:{use_percent*2}%]")


# Only train and validation are tokenized below; the test split is left as raw text
dataset = DatasetDict({'train': dataset_train, 'validation': dataset_val})

# Tokenize every example (document + summary -> input_ids / attention_mask) and drop the raw document text
xsum_dataset = dataset.map(preprocessor.preprocess, remove_columns=['document'])
# Return only input_ids and attention_mask, as PyTorch tensors, when the splits are indexed
xsum_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])

# Drop examples whose <sep> token was truncated away (no summary left to learn from);
# filter() returns plain Python lists of samples
train_dataset = preprocessor.filter(xsum_dataset['train'])
val_dataset = preprocessor.filter(xsum_dataset['validation'])
# Untokenized test split
test_dataset = dataset_test


## Model

In [None]:
class GPT2FineTuner(LightningModule):
    """Lightning module that fine-tunes GPT-2 to write a summary after a ``<sep>`` token.

    Training inputs are expected to look like ``<bos> document <sep> summary <eos>``
    followed by padding (as produced by ``XSumPreprocessor``). The language-modelling
    loss is computed only on the tokens after ``<sep>``; padded positions
    (``attention_mask == 0``) are excluded from the loss.

    Args:
        tokenizer: Tokenizer whose ``sep_token_id`` marks where the summary starts and
            whose length is used to size the embedding matrix.
    """

    def __init__(self, tokenizer):
        super().__init__()

        self.model = GPT2LMHeadModel.from_pretrained("gpt2")
        self.tokenizer = tokenizer
        # Resize token embeddings in case tokens were added to the vocabulary.
        self.model.resize_token_embeddings(len(self.tokenizer))

        # Per-step losses of the current epoch (detached), reset at epoch end.
        self.train_losses = []
        self.validation_losses = []
        # Per-epoch mean losses (Python floats), kept for plotting after training.
        self.train_losses_epoch = []
        self.validation_losses_epoch = []

        self.config = self.model.config
        self.save_hyperparameters()

    def forward(self, input_ids, attention_mask=None):
        return self.model(input_ids, attention_mask=attention_mask)

    def _summary_loss(self, batch):
        """Return the mean cross-entropy over summary tokens (those after ``<sep>``).

        Each ``<sep>`` occurrence contributes the mean loss over the non-padded tokens
        that follow it in its own row. Occurrences with nothing to predict after them
        are skipped; if none remain, a zero loss still attached to the graph is returned.
        """
        input_ids, attention_mask = batch['input_ids'], batch['attention_mask']
        logits = self(input_ids, attention_mask=attention_mask).logits

        # nonzero() yields (row, column) pairs; the row is the batch index of that <sep>.
        sep_positions = (input_ids == self.tokenizer.sep_token_id).nonzero(as_tuple=False)
        losses = []
        for row, col in sep_positions.tolist():
            # The logit at position t predicts token t+1, so targets start right after <sep>.
            shift_logits = logits[row, col:-1, :]
            # Padding is not part of the summary: mark it with ignore_index (-100).
            shift_labels = input_ids[row, col + 1:].masked_fill(
                attention_mask[row, col + 1:] == 0, -100
            )
            if not (shift_labels != -100).any():
                # <sep> is the last real token: nothing to learn, and cross_entropy
                # over zero targets would be NaN.
                continue
            losses.append(
                cross_entropy(
                    shift_logits.reshape(-1, shift_logits.size(-1)),
                    shift_labels.reshape(-1),
                    ignore_index=-100,
                )
            )

        if not losses:
            return logits.sum() * 0.0
        return torch.stack(losses).mean()

    def training_step(self, batch, batch_nb):
        loss = self._summary_loss(batch)
        # Keep a detached copy so the epoch-long list does not hold autograd state.
        self.train_losses.append(loss.detach())
        self.log('loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_nb):
        val_loss = self._summary_loss(batch)
        self.validation_losses.append(val_loss.detach())
        self.log('val_loss', val_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return val_loss

    def summarize(self, text, length, device):
        """Sample ``length`` tokens after the prompt and return only the new token ids.

        Args:
            text: Prompt as a sequence of token ids (converted with ``torch.tensor``).
            length: Number of tokens to sample.
            device: Device to generate on; the wrapped model is moved there.

        Returns:
            1-D ``torch.long`` tensor holding the ``length`` sampled token ids.
        """
        ## Adapted from a blog post ##
        prompt_length = len(text)
        generated = torch.tensor(text, dtype=torch.long, device=device).unsqueeze(0)
        model = self.model.to(device)
        # GPT-2 has a fixed number of positions; feed at most that many trailing tokens.
        max_positions = model.config.n_positions
        # Sample with dropout disabled, then restore the caller's train/eval mode.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                for _ in tnrange(length):
                    outputs = model(input_ids=generated[:, -max_positions:])
                    next_token_logits = outputs[0][0, -1, :]
                    next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1)
                    generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
        finally:
            model.train(was_training)

        return generated[0, prompt_length:]

    def on_train_epoch_end(self):
        # Record the mean step loss of the finished epoch, then reset the accumulator.
        if self.train_losses:
            avg_train_loss = torch.stack(self.train_losses).mean()
            self.train_losses_epoch.append(avg_train_loss.item())
        self.train_losses = []

    def on_validation_epoch_end(self):
        # Record the mean step loss of this validation run, then reset the accumulator.
        if self.validation_losses:
            avg_val_loss = torch.stack(self.validation_losses).mean()
            self.validation_losses_epoch.append(avg_val_loss.item())
        self.validation_losses = []

    def configure_optimizers(self):
        return AdamW(self.parameters(), lr=1e-4)

    def train_dataloader(self):
        # Reads the notebook-level `train_dataset`.
        return DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=2)

    def val_dataloader(self):
        # Reads the notebook-level `val_dataset`.
        return DataLoader(val_dataset, batch_size=2, num_workers=2)
    

In [None]:
model = GPT2FineTuner(tokenizer)

# Keep only the single checkpoint with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    dirpath='./saved/saved_checkpoints/',
    filename='GPT2FineTuner-{epoch:02d}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=1,
)

# Fine-tune for a fixed number of epochs, checkpointing as configured above.
trainer = Trainer(callbacks=[checkpoint_callback], max_epochs=5)
trainer.fit(model)



In [None]:

# Path of the best checkpoint kept by checkpoint_callback (lowest val_loss).
checkpoint_path = checkpoint_callback.best_model_path
# To continue from that checkpoint instead of the in-memory `model`, uncomment:
## model = GPT2FineTuner.load_from_checkpoint(checkpoint_path)
# model.eval()
#gpt2_summarizer = model

## Save Fine-Tuned Model

In [None]:
import os

# Directory that receives the exported weights and config (no error if it already exists).
parent_dir = "./models"
os.makedirs(name=parent_dir, exist_ok=True)


In [None]:
import json
model_file='./models/gpt2_model.bin'
config_file='./models/gpt2_config.json'

# Save the wrapped GPT2LMHeadModel's weights rather than the LightningModule's:
# the keys then carry no "model." prefix and load directly into GPT2LMHeadModel.
torch.save(model.model.state_dict(), model_file)
model.config.to_json_file(config_file)
# NOTE: the tokenizer goes to a different directory than the weights/config above.
tokenizer.save_pretrained("./saved/models/")

In [None]:
config = GPT2Config.from_json_file(config_file)
test_model = GPT2LMHeadModel(config)
# map_location lets weights saved from a GPU run be reloaded on a CPU-only machine.
state_dict = torch.load(model_file, map_location="cpu")
# A LightningModule state dict prefixes keys with the submodule name "model.";
# strip only that leading prefix (str.replace would also rewrite inner matches).
state_dict_prefix = "model."
new_state_dict = {
    (key[len(state_dict_prefix):] if key.startswith(state_dict_prefix) else key): value
    for key, value in state_dict.items()
}
test_model.load_state_dict(new_state_dict)
test_model.eval()

## Push Model to Hugging Face Hub

In [None]:
import huggingface_hub

# Interactive Hugging Face login from inside the notebook.
notebook_login = huggingface_hub.notebook_login
notebook_login()

In [None]:
# The `organization=` argument of push_to_hub is deprecated (slated for removal in
# Transformers v5); the namespace belongs in the repo id: "<user-or-org>/<repo-name>".
hub_repo_id = "ZinebSN/gpt2_test_config"
test_model.push_to_hub(hub_repo_id)
tokenizer.push_to_hub(hub_repo_id)

## Plotting the losses

In [None]:
# Plotting the losses
train_losses= model.train_losses_epoch
validation_losses= model.validation_losses_epoch
plt.plot(train_losses, label='Training Loss')
plt.plot(validation_losses[1:], label='Validation Loss')

# Adding labels and title
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('GPT2: Losses over Epochs')

# Adding legend
plt.legend()

plt.savefig('losses_plot.png')
# Displaying the plot
plt.show()