In [1]:
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import AdamW

import numpy as np

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
import pytorch_lightning as pl

from transformers import AutoTokenizer
from transformers import GPT2Tokenizer
from transformers import get_linear_schedule_with_warmup
from transformers import GPT2LMHeadModel
from datasets import load_dataset, Dataset, DatasetDict


## Data

In [2]:
class XSumPreprocessor:
    def __init__(self, tokenizer, max_length):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Add special tokens to the tokenizer
        self.special_tokens_dict = {'bos_token': '<bos>', 'eos_token': '<eos>', 'sep_token': '<sep>', 'pad_token': '<pad>'}
        self.num_added_toks = self.tokenizer.add_special_tokens(self.special_tokens_dict)
    def preprocess(self, example):
        # Concatenate article and summary and add special tokens
        encoded_example = tokenizer.encode_plus(
            f'{self.special_tokens_dict["bos_token"]} {example["document"]} {self.special_tokens_dict["sep_token"]} {example["summary"]} {self.special_tokens_dict["eos_token"]}',
            truncation=True,
            padding='max_length',
            max_length=self.max_length
        )

        return encoded_example

    def filter(self, dataset):
        dataset = [sample for sample in dataset if self.tokenizer.sep_token_id in sample['input_ids']]
        return dataset

In [3]:
from transformers import GPT2Tokenizer
from datasets import load_dataset

# Load pre-trained model tokenizer (vocabulary)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
max_length=1024

preprocessor = XSumPreprocessor(
                tokenizer = tokenizer,
                max_length = max_length
)

# Load XSum dataset
xsum_dataset = load_dataset('xsum')

use_percent = 5
dataset_train = load_dataset("xsum", split=f"train[:{use_percent}%]")
dataset_val = load_dataset("xsum", split=f"validation[:{use_percent}%]")
dataset_test = load_dataset("xsum", split=f"test[:{use_percent}%]")

dataset = DatasetDict({'train': dataset_train, 'validation': dataset_val, 'test': dataset_test})

# Apply the function to all examples in the dataset
xsum_dataset = dataset.map(preprocessor.preprocess, remove_columns=['document', 'summary'])
# Format the dataset to PyTorch tensors and split into training, validation, and test sets
xsum_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])

train_dataset = preprocessor.filter(xsum_dataset['train'])
val_dataset = preprocessor.filter(xsum_dataset['validation'])
test_dataset = xsum_dataset['test']


Found cached dataset xsum (/home/studio-lab-user/.cache/huggingface/datasets/xsum/default/1.2.0/082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset xsum (/home/studio-lab-user/.cache/huggingface/datasets/xsum/default/1.2.0/082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71)
Found cached dataset xsum (/home/studio-lab-user/.cache/huggingface/datasets/xsum/default/1.2.0/082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71)
Found cached dataset xsum (/home/studio-lab-user/.cache/huggingface/datasets/xsum/default/1.2.0/082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71)


Map:   0%|          | 0/10202 [00:00<?, ? examples/s]

Map:   0%|          | 0/567 [00:00<?, ? examples/s]

Map:   0%|          | 0/567 [00:00<?, ? examples/s]

## Model

In [4]:
from torch.nn.utils.rnn import pad_sequence
from torch.nn.functional import cross_entropy



In [6]:

### TODO: Save losses while training ###
### TODO: Add Checkpointing ###

import torch
from transformers import GPT2LMHeadModel
from pytorch_lightning import LightningModule
from torch.utils.data import DataLoader
from transformers import AdamW

class GPT2FineTuner(LightningModule):
    def __init__(self):
        super().__init__()
        
        self.model = GPT2LMHeadModel.from_pretrained("gpt2")
        # Resize token embeddings in case you have added more tokens in the vocab
        self.model.resize_token_embeddings(len(tokenizer))
        self.train_losses = []
        self.validation_losses = []

    def forward(self, input_ids, attention_mask=None):
        return self.model(input_ids, attention_mask=attention_mask)

    def training_step(self, batch, batch_nb):
        input_ids, attention_mask = batch['input_ids'], batch['attention_mask']
        sep_positions = (input_ids == tokenizer.sep_token_id).nonzero(as_tuple=False)
        # Forward pass
        outputs = self(input_ids, attention_mask=attention_mask)
        logits = outputs.logits

        # Calculate loss only on the reference summary
        loss = 0
        for i, sep_position in enumerate(sep_positions):
            sep_position = sep_position[1]  # Use the single element from the tensor

            shift_logits = logits[i, sep_position:-1, :].contiguous()
            shift_labels = input_ids[i, sep_position+1:].contiguous()
            loss += cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        loss = loss / len(sep_positions)  # average loss
        self.train_losses.append(loss.item())
        self.log('loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)

        return {'loss': loss}
    def validation_step(self, batch, batch_nb):
        input_ids, attention_mask = batch['input_ids'], batch['attention_mask']
        sep_positions = (input_ids == tokenizer.sep_token_id).nonzero(as_tuple=False)
        # Forward pass
        outputs = self(input_ids, attention_mask=attention_mask)
        logits = outputs.logits

       # Calculate loss only on the reference summary
        val_loss = 0
        for i, sep_position in enumerate(sep_positions):
            sep_position = sep_position[1] # Use the single element from the tensor
            shift_logits = logits[i, sep_position:-1, :].contiguous()
            shift_labels = input_ids[i, sep_position+1:].contiguous()
            val_loss += cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        val_loss = val_loss / len(sep_positions)  # average loss
        self.validation_losses.append(val_loss.item())
        self.log('val_loss', val_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)

        return {'val_loss': val_loss}

    def configure_optimizers(self):
        return AdamW(self.parameters(), lr=1e-4)

    def train_dataloader(self):
        return DataLoader(train_dataset, batch_size=2, shuffle=True)

    def val_dataloader(self):
        return DataLoader(val_dataset, batch_size=2)
    

In [7]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

model = GPT2FineTuner()

# Define the checkpoint callback
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',  # Choose your metric here
    dirpath='./saved/models/',
    filename='GPT2FineTuner-{epoch:02d}-{val_loss:.2f}',
    save_top_k=1,
    mode='min',  # or 'max', depending on what you want to monitor
)


trainer = Trainer(max_epochs=5, accumulate_grad_batches=4, callbacks=[checkpoint_callback])
trainer.fit(model)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type            | Params
------------------------------------------
0 | model | GPT2LMHeadModel | 124 M 
------------------------------------------
124 M     Trainable params
0         Non-trainable params
124 M     Total params
497.772   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


epoch0: val_loss 0.167
epoch1: val_loss 0.163
epoch2: val_loss 0.165
epoch3: val_loss 0.174
epoch4: val_loss 

In [9]:
### NB: Will workd once losses are saved ###
#print('training loss', model.train_losses)
#print('validation loss', model.validation_losses)


In [10]:
# Save Model
tokenizer.save_pretrained("./saved/tokenizers/")


('./saved/tokenizers/tokenizer_config.json',
 './saved/tokenizers/special_tokens_map.json',
 './saved/tokenizers/vocab.json',
 './saved/tokenizers/merges.txt',
 './saved/tokenizers/added_tokens.json')

In [11]:
# The best model is saved at the path:
checkpoint_path = checkpoint_callback.best_model_path

In [12]:
model = GPT2FineTuner.load_from_checkpoint(checkpoint_path)
model.eval()  

GPT2FineTuner(
  (model): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50261, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-11): 12 x GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (lm_head): Linear(in_features=768, out_features=50261, bias=False)
  )
)

In [13]:
print(checkpoint_path)

/home/studio-lab-user/sagemaker-studiolab-notebooks/TextPressoMachine/saved/models/GPT2FineTuner-epoch=01-val_loss=0.16.ckpt


In [14]:
tokenizer = GPT2Tokenizer.from_pretrained("./saved/tokenizers")

In [15]:
gpt2_summarizer = model



# Inference

In [16]:
def summarize(model, text, length, device):
    ## From the blog ##
    text = torch.tensor(text, dtype=torch.long, device=device)
    text = text.unsqueeze(0)
    generated = text
    model = model.to(device)
    with torch.no_grad():
        for _ in tnrange(length):
            inputs = {'input_ids': generated}

            outputs = model(**inputs)
            next_token_logits = outputs[0][0, -1, :]
            next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
    return generated



### Test with one example

In [41]:
example=test_dataset[25]['input_ids']

In [42]:
sep_idx=(example == tokenizer.sep_token_id).nonzero(as_tuple=False).item()

In [43]:
print(sep_idx)

328


In [44]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [45]:
from tqdm import tnrange


In [46]:
text = example[:sep_idx].tolist()
summary =example[sep_idx+1:].tolist()
generated_text = summarize(gpt2_summarizer, text, length=100, device=device)


  for _ in tnrange(length):


  0%|          | 0/100 [00:00<?, ?it/s]

In [47]:
tokenized_summary = generated_text[0, sep_idx:].tolist()
id_summary = tokenizer.convert_ids_to_tokens(tokenized_summary,skip_special_tokens=True)
gpt2_summary = tokenizer.convert_tokens_to_string(id_summary)

In [48]:
print('######### Original Text #############')
print(tokenizer.decode(text), end='\n\n')
print('######### GPT2 Summary ##############')
print(gpt2_summary, end='\n\n')
print('######### Ground Truth Summary ###########')
print(tokenizer.decode(summary, skip_special_tokens=True), end='\n\n')

######### Original Text #############
<bos> The fine follows the conviction of former RBS trader, Shirlina Tsang, for fraud last year.
She was sentenced to 50 months in prison after being caught falsifying records of emerging markets trades.
Hong Kong regulators said RBS's controls were "seriously inadequate".
The Securities and Futures Commission (SFC) also said there were "significant weaknesses in its procedures, management systems and internal controls."
But the regulator said the fine took into account the bank's speedy action in alerting the authorities once it had discovered the illegal trades, which took place in its emerging markets rates business in 2011.
"This deserves substantial credit and is the reason why today's sanctions are not heavier ones," Mark Steward, the SFC's head of enforcement, said in a statement.
RBS responded with a statement, reading: "We put in place a comprehensive remediation programme that strengthened our governance and supervisory oversight, and our