In [1]:
#| include: false
!pip install -q datasets transformers deepspeed

[0m

In [2]:
#| include: false
import datetime
import os
from pathlib import Path
import random
from typing import Any, Callable, Dict, List, Optional, Tuple

import datasets
from deepspeed.ops import adam
import matplotlib.pyplot as plt
import numpy as np
import pydantic
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchvision import models
from tqdm.auto import tqdm
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import wandb

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Wandb login:
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()
wandb.login(key=user_secrets.get_secret("wandb_api_key"))
hf_token = user_secrets.get_secret("wandb_api_key")

print(pl.__version__, torch.__version__, transformers.__version__)

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


1.7.7 1.11.0 4.20.1


![](../images/fine_tune_t5.jpeg){fig-alt="Image saying fine tune flan t5"}

## Introduction
You may have read my [previous post on fine tuning GPT-2](https://sachinruk.github.io/blog/deep-learning/2022/09/25/grammar-correction-via-gpt2.html) for grammar correction. Well, I am here to tell you I have made a terrible mistake. While it was a fun exercise to understand the intricacies of GPT-2, I butchered it into correcting grammar. Let me explain why.

Firstly, GPT-2 is a decoder only model. Meaning the current token can only attend previous tokens. While this is fine for our task by adding a seperator token, this also means that the decoder model needs to understand AND reconstruct the sentence. Therefore performing two tasks. T5 on the other has an encoder-decoder architecture. The encoder only contains an input task and the decoder only has a output/ generative task. Therefore dividing the responsibilities.

T5 is also trained to as a multi task model which does things like summarization, translation etc.
![t5 paper summary](https://production-media.paperswithcode.com/methods/new_text_to_text.jpg)

[Code for this blog.](https://www.kaggle.com/code/sachin/t5-for-grammar-correction)

In [3]:
#| include: false
LEARNING_RATE = 1e-4
EPOCHS = 1
BATCH_SIZE = 4
ACCUMULATE_GRADIENTS = 8
MAX_LEN = 256
LANGUAGE_MODEL = "google/flan-t5-base"
# LANGUAGE_MODEL = "bigscience/bloom-560m"
LOG_PATH = "/kaggle/working/logs/"
FREEZE_LAYERS = 2
IS_FREEZE_LAYERS = True
UNFREEZE_BATCH_IDX = 1000
LABEL_MASK = -100
NUM_BATCHES = 10_000

wandb.init(
    project="t5_grammar",
    entity="sachinruk",
    name=str(datetime.datetime.now()),
    config = {
        "language_model": LANGUAGE_MODEL,
        "batch_size": BATCH_SIZE,
        "accumulate_gradient": ACCUMULATE_GRADIENTS,
        "learning_rate": LEARNING_RATE,
        "is_freeze_layers": IS_FREEZE_LAYERS,
        "freeze_layers": FREEZE_LAYERS,
    }
)

class Tokenizer:
    def __init__(self, tokenizer, max_len: int):
        self.tokenizer = tokenizer
        self.max_len = max_len
        
    def __getattr__(self, attribute: str):
        if hasattr(self.tokenizer, attribute):
            return getattr(self.tokenizer, attribute)
        else:
            raise AttributeError(f"{attribute} not found")

    def __call__(self, sentences: List[str], device:torch.device=None) -> AutoTokenizer:  
        tokenized = self.tokenizer(
            sentences, 
            truncation=True,
            padding=True,
            return_tensors="pt",
            max_length=self.max_len,
        )
        if device is not None:
            return {key: tensor.to(device) for key, tensor in tokenized.items()}
        return tokenized

    def decode(self, x: Dict[str, torch.LongTensor]):
        return [self.tokenizer.decode(sentence[:sentence_len]) for sentence, sentence_len in 
                zip(x["input_ids"], target["attention_mask"].sum(axis=-1))]
    
    def batch_decode(self, encoded_outputs: torch.LongTensor) -> List[str]:
        return self.tokenizer.batch_decode(encoded_outputs.cpu(), skip_special_tokens=True)
    
    def __len__(self):
        return len(self.tokenizer)


# get text base and transform
language_model = transformers.T5ForConditionalGeneration.from_pretrained(LANGUAGE_MODEL)
tokenizer = Tokenizer(
    AutoTokenizer.from_pretrained(
        LANGUAGE_MODEL, 
    ),
    MAX_LEN,
)
# move below to lightning_trainer
language_model.resize_token_embeddings(len(tokenizer))

class GeneratorConfig(pydantic.BaseModel):
    repetition_penalty: float = 1.2
    beam_search: bool = True
    num_beam: int = 5
    early_stopping: bool = True
    max_generated_len: int = MAX_LEN
    no_repeat_ngram_size: int = 2
    top_k: int = 2000
    top_p: float = 0.95

    def build_generator_kwargs(self) -> Dict[str, Any]:
#         common_params = {
#             "bos_token_id": self.bos_token_id,
#             "pad_token_id": self.pad_token_id,
#             "eos_token_id": self.eos_token_id,
#         }   
        if self.beam_search:
            return {
#                 **common_params,
                **{
                    "max_length": self.max_generated_len,
                    "num_beams": self.num_beam,
                    "no_repeat_ngram_size": self.no_repeat_ngram_size,
                    "early_stopping": self.early_stopping,
                    "repetition_penalty": self.repetition_penalty,
                }
            }
        else:
            return {
#                 **common_params,  
                **{
                    "max_length": self.max_generated_len,
                    "do_sample": True,
                    "top_k": self.top_k,
                    "top_p": self.top_p,
                    "early_stopping": self.early_stopping,
                    "repetition_penalty": self.repetition_penalty,
                }
            }

[34m[1mwandb[0m: Currently logged in as: [33msachinruk[0m. Use [1m`wandb login --relogin`[0m to force relogin


Downloading:   0%|          | 0.00/1.41k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/945M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.48k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/773k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.15k [00:00<?, ?B/s]

In [4]:
#| include: false
data = datasets.load_dataset("liweili/c4_200m", cache_dir="/kaggle/working/", streaming=True, split="train")\
        .shuffle(seed=42, buffer_size=10_000)
c4_train = data.skip(100000)
c4_valid = data.take(100000)
def group_batch(batch):
    return {k: [v] for k, v in batch.items()}
train_dl = c4_train.map(group_batch, batched=True, batch_size=BATCH_SIZE)
valid_dl = c4_valid.map(group_batch, batched=True, batch_size=BATCH_SIZE)

Downloading builder script:   0%|          | 0.00/2.79k [00:00<?, ?B/s]

## Data

Setting up the data is no different to what we did during GPT-2. We will still use the `c4_200m` dataset from huggingface datasets, and we will still try to match input to a corrected output, while also trying to teach the model when to leave it alone when it sees a good sentence.

## Loss function
However, we now come to the first gotcha. The loss function. While in GPT-2 we can use `outputs.loss` we cannot do so here. That is because HF does not interally shift the tokens for a autoregressive task like generation. Instead, it is expecting a missing token prediction task by default. The following function simply shifts the labels 1 across, so that we can predict one token ahead. The loss function is your standard cross entropy loss.

To dive into this deeper, the model predicts what the next token ought to be. Therefore, the shape of the output is `[batch_size, sequence_length, all_possible_tokens]`. If you are wondering why it's not `[batch_size, sequence_length, hidden_dim_size]`, that's because in this class of HF models (generative, specifically in this case, `transformers.T5ForConditionalGeneration`) have one more layer of shape `[hidden_dim_size, all_possible_tokens]` over which a softmax layer is used to get a probability over the tokens. This last layer is often simply the input embeddings transposed.

In [5]:
def calculate_loss_fn(loss_fn, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    return loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

The naive method to use the training data here would be to simply use the bad sentence as an input and the good sentence as the output. However, the model should also recognise when to leave a good sentence as is. Therefore, we calculate loss once for each scenario described and sum them up.

This brings us to the second gotcha. When I trained the first time around I simply used the tokenized sentence of the input/ output sentence. It fortunately trained well and masked a serious mathematical error. That was during inference, when we call the `model.generate(...)` function, it always starts off with a special token in the decoder. This could be accessed via `model.config.decoder_start_token_id`. Therefore, during training we need to prepend this token to the output sentence. We can see this in the `LightningModule` below.

In [6]:
#| code-fold: true
class LightningModule(pl.LightningModule):
    def __init__(
        self,
        model: nn.Module,
        tokenizer: Tokenizer,
        generation_kwargs: Dict[str, Any],
        lr: float,
        loss_fn: Callable = nn.CrossEntropyLoss(),
    ) -> None:
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.lr = lr
        self.generation_kwargs = generation_kwargs
        self.loss_fn = loss_fn
        self.prepend_sentence = PREPEND_SENTENCE
        
        decoder_start_token_id = model.config.decoder_start_token_id
        self.prepend_input_ids = torch.LongTensor([decoder_start_token_id] * BATCH_SIZE)[:, None]
        self.prepend_attention_masks = torch.LongTensor([1] * BATCH_SIZE)[:, None]
        
        self.model.train()
        if IS_FREEZE_LAYERS:
            for layer in self.model.encoder.block[:FREEZE_LAYERS]:
                layer.eval()
                for p in layer.parameters():
                    p.requires_grad = False
            for layer in self.model.decoder.block[:FREEZE_LAYERS]:
                layer.eval()
                for p in layer.parameters():
                    p.requires_grad = False
        
        self.table_logging = 0
        
    def prepend_tokens(self, tokenized_batch: Dict[str, torch.LongTensor], len_batch: int) -> Dict[str, torch.LongTensor]:
        input_ids = torch.cat(
            [
                self.prepend_input_ids[:len_batch, :].to(self.device),
                tokenized_batch["input_ids"],
            ],
            dim=-1,
        )
        attention_mask = torch.cat(
            [
                self.prepend_attention_masks[:len_batch, :].to(self.device),
                tokenized_batch["attention_mask"],
            ],
            dim=-1,
        )
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        
    def get_loss(self, input_sentences: List[str], output_sentences: List[str]) -> torch.FloatTensor:
        tokenized_input = self.tokenizer(
            [self.prepend_sentence + sentence for sentence in input_sentences], 
            self.device
        )
        tokenized_output = self.prepend_tokens(self.tokenizer(output_sentences, self.device), len(output_sentences))
        labels = tokenized_output["input_ids"].clone()
        labels[tokenized_output["attention_mask"] == 0] == LABEL_MASK
        
        out = self.model(
            input_ids=tokenized_input["input_ids"],
            attention_mask=tokenized_input["attention_mask"],
            decoder_input_ids=tokenized_output["input_ids"],
            decoder_attention_mask=tokenized_output["attention_mask"],
        )
        return calculate_loss_fn(self.loss_fn, out.logits, labels)
        
        
    def common_step(self, batch: Dict[str, str]) -> torch.Tensor:
        bad_grammar_loss = self.get_loss(batch["input"], batch["output"])
        good_grammar_loss = self.get_loss(batch["output"], batch["output"])
    
        return good_grammar_loss + bad_grammar_loss
        
    def training_step(
        self, batch: Dict[str, torch.LongTensor], batch_idx: int,
    ) -> torch.Tensor:
        loss = self.common_step(batch)     
        self.log("training_loss", loss, on_step=True, on_epoch=True, batch_size=len(batch["input"]))
             
        return loss

    def validation_step(
        self, batch: Tuple[torch.Tensor, List[str]], batch_idx: int,
    ) -> torch.Tensor:
        loss = self.common_step(batch)
        self.log("validation_loss", loss, on_step=False, on_epoch=True, batch_size=len(batch["input"]))
        
        if batch_idx == 0:
            self.log_examples(batch)
            
    def log_examples(self, batch):
        good_grammar_batch = self.tokenizer(batch["output"], device=self.device)
        bad_grammar_batch = self.tokenizer(batch["input"], device=self.device)
        encoded_good_outputs = self.model.generate(**good_grammar_batch, **self.generation_kwargs)
        encoded_bad_outputs = self.model.generate(**bad_grammar_batch, **self.generation_kwargs)
        generated_good_sentences = self.tokenizer.batch_decode(encoded_good_outputs)
        generated_bad_sentences = self.tokenizer.batch_decode(encoded_bad_outputs)
        
        columns = ["good input", "good output", "bad input", "bad output"]
        data = [
            [good_input, good_output, bad_input, bad_output]
            for good_input, good_output, bad_input, bad_output in zip(
                batch["output"], generated_good_sentences, batch["input"], generated_bad_sentences
            )
        ]
        table = wandb.Table(data=data, columns=columns)
        if self.logger is not None:
            self.table_logging += 1
            self.logger.experiment.log({f"epoch {self.table_logging} results": table})

    def configure_optimizers(self) -> torch.optim.Optimizer:
        if IS_FREEZE_LAYERS:
            return adam.FusedAdam(
                [
                    {"params": layer.parameters(), "lr": self.lr} for layer in language_model.encoder.block[FREEZE_LAYERS:]
                ] + \
                [
                    {"params": layer.parameters() , "lr": self.lr} for layer in language_model.decoder.block[FREEZE_LAYERS:]
                ]
            )
        else:
            return adam.FusedAdam(self.model.parameters(), self.lr)

## Training the model
The last gotcha comes from during the training, it seems using 16 bit training is unstable. Therefore, I was forced to use 32 bit with a smaller batch size, but this can be remedied by increasing the `accumulate_grad_batches`.

Also as a side note, I do like to freeze certain layers, a trick I picked up in [fast.ai](https://www.fast.ai). This is done in order to not overfit my training data. Conceptually, it makes sense not to train embeddings since some words will be seen (and therefore updated) more often than other.

In [None]:
#| code-fold: show
adam.FusedAdam(
    [
        {"params": layer.parameters(), "lr": self.lr} for layer in language_model.encoder.block[FREEZE_LAYERS:]
    ] + \
    [
        {"params": layer.parameters() , "lr": self.lr} for layer in language_model.decoder.block[FREEZE_LAYERS:]
    ]
)

In [7]:
#| include: false
generator_config = GeneratorConfig()

lightning_module = LightningModule(
    language_model, 
    tokenizer, 
    generation_kwargs=generator_config.build_generator_kwargs(), 
    lr=LEARNING_RATE,
)
is_interactive = os.environ["KAGGLE_KERNEL_RUN_TYPE"] == "Interactive"
logger = None if is_interactive else pl.loggers.WandbLogger(log_path=LOG_PATH)
trainer = pl.Trainer(
    max_epochs=EPOCHS,
    accumulate_grad_batches=ACCUMULATE_GRADIENTS,
    gpus=torch.cuda.device_count(),
    gradient_clip_val=1.0,
    precision=32,
    logger=logger,
    enable_progress_bar=is_interactive,
    log_every_n_steps=200,
    limit_train_batches=20 if is_interactive else NUM_BATCHES,
    limit_val_batches=3 if is_interactive else 1.0,
    val_check_interval=UNFREEZE_BATCH_IDX if not is_interactive else 4,
)
trainer.fit(lightning_module, train_dl, valid_dl) #

  "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"
  f"Setting `Trainer(gpus={gpus!r})` is deprecated in v1.7 and will be removed"


Using /root/.cache/torch_extensions/py37_cu110 as PyTorch extensions root...
Creating extension directory /root/.cache/torch_extensions/py37_cu110/fused_adam...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/py37_cu110/fused_adam/build.ninja...
Building extension module fused_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/3] /usr/local/cuda/bin/nvcc  -DTORCH_EXTENSION_NAME=fused_adam -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1013\" -I/opt/conda/lib/python3.7/site-packages/deepspeed/ops/csrc/includes -I/opt/conda/lib/python3.7/site-packages/deepspeed/ops/csrc/adam -isystem /opt/conda/lib/python3.7/site-packages/torch/include -isystem /opt/conda/lib/python3.7/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/lib/python3.7/site-packages/torch/include/TH -isyst

## Results
Accoring to the limited experiments that I ran, it seems T5 does better than `Flan-T5`. Keep in mind that this was a drop in replacement during training. In the following examples it seems to be doing a decent job of leaving some text as is while redoing others.
![results](https://i.imgur.com/PVZzGp2.png)

However, let me add that Flan-T5 is magical! If you read the paper it is supposed to generalise to unseen tasks. The original T5 paper was only tasked with 5 tasks, whereas there were many tasks than Flan-T5 was trained with the same footprint in terms of number of weights. As you can see in the following, simply by adding `Correct grammar in following sentence: ` I was able to get a corrected sentence in [Flan-T5-large model](https://huggingface.co/google/flan-t5-large?text=Correct+grammar+in+this+sentence%3A+According+to+my+last+paper%2C+meke+sure+you+have+Git+installed+and+on+Path.).
![flan-t5 result](https://i.imgur.com/4EGRzgN.png)

## Summary
In closing the two main take away points are:
1. Redo the loss calculation.
2. Change to 32 bit training/ try bfloat16.

If you can afford to pay for training it is worth trying to train Flan-T5 for longer and see where it gets to. My wandb logs can be seen [here](https://wandb.ai/sachinruk/t5_grammar) and the kaggle kernel can be found [here](https://www.kaggle.com/code/sachin/t5-for-grammar-correction).

## Shameless Self Promotion
If you enjoyed the tutorial [buy my course](https://www.udemy.com/course/machine-learning-and-data-science-2021/?referralCode=E79228C7436D74315787) (30 days moneyback).