# Training GPT-J-6B with 8-bit weights for Headline Generation
This notebook contains the code to train a GPT Model using Transfer Learning on the task of Headline Generation from News Articles.


* GPT-J-6B: A 6 billion parameter, autoregressive text generation model by EleutherAI trained on [The Pile](https://pile.eleuther.ai/) using [Mesh Transformer JAX](https://github.com/kingoflolz/mesh-transformer-jax/) (Ben Wang and Aran Komatsuzaki).
* GPT-J-6B-8bit: A quantized GPT-J-6B with 8-bit weights for scalable and cost-efficient fine-tuning by Hivemind with [LoRA](https://arxiv.org/pdf/2106.09685.pdf) and [8-bit Adam](https://arxiv.org/abs/2110.02861).
* LoRA Adapter Implementation is created by [Denis Mazur](https://github.com/deniskamazur): [Notebook](https://colab.research.google.com/drive/1ft6wQU0BhqG5PRlwgaZJv2VukKKjU4Es)



## Installs & Imports
Pin library versions for reproducibility.

In [None]:
!pip install transformers==4.25.1
!pip install bitsandbytes==0.35.4
!pip install datasets==2.7.1
!pip install accelerate==0.15.0

In [None]:
import transformers
import pandas as pd

# Torch Imports
import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import custom_fwd, custom_bwd
from torch.utils.data import DataLoader

import accelerate

from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
from bitsandbytes.optim import Adam8bit

from datasets import load_dataset

from google.colab import drive

## Convert Model to 8 bits
We convert EleutherAI's GPT-J-6B model to 8 bits using Facebook's [bitsandbytes](https://github.com/facebookresearch/bitsandbytes) library. This reduces the model's size from about 20 GB down to about 6 GB.
* large weight tensors are quantized using dynamic 8-bit quantization and de-quantized just-in-time for multiplication
* gradient checkpointing stores only one activation per layer, using dramatically less memory at the cost of roughly 30% slower training

In [None]:
class FrozenBNBLinear(nn.Module):
    """Linear layer whose weight is held as frozen, blockwise-quantized int8 buffers.

    The int8 ``weight`` plus its quantization state (``absmax``, ``code``) are registered
    as buffers, so they are saved in the state dict but are not parameters. The optional
    ``bias`` must be an ``nn.Parameter`` (or None) and stays a trainable parameter.
    An optional ``adapter`` module, when set, is evaluated on the same input and its
    result is added to the layer output.
    """

    def __init__(self, weight, absmax, code, bias=None):
        assert bias is None or isinstance(bias, nn.Parameter)
        super().__init__()
        self.out_features, self.in_features = weight.shape
        # Quantized weight and its state never receive gradients.
        for buf_name, tensor in (("weight", weight), ("absmax", absmax), ("code", code)):
            self.register_buffer(buf_name, tensor.requires_grad_(False))
        self.adapter = None
        self.bias = bias

    def forward(self, input):
        # Weight is dequantized just-in-time inside the autograd function.
        result = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)
        adapter = self.adapter
        if adapter:
            result += adapter(input)
        return result

    @classmethod
    def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
        """Quantize an existing ``nn.Linear`` and wrap it, keeping its bias."""
        quantized, (absmax, code) = quantize_blockise_lowmemory(linear.weight)
        return cls(quantized, absmax, code, linear.bias)

    def __repr__(self):
        return f"{type(self).__name__}({self.in_features}, {self.out_features})"
 
 
class DequantizeAndLinear(torch.autograd.Function):
    """``F.linear`` over an int8 blockwise-quantized weight, re-dequantizing in backward.

    Only the compact quantized weight and its state are saved for backward, so the
    full-precision weight matrix exists only transiently in each pass. Gradients are
    produced for ``input`` and ``bias``; the quantized weight, ``absmax`` and ``code``
    are treated as frozen.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
                absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
        ctx.save_for_backward(input, weights_quantized, absmax, code)
        ctx._has_bias = bias is not None
        dense_weight = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        result = F.linear(input, dense_weight, bias)
        return result.clone()

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output: torch.Tensor):
        # The quantized weight and its quantization state must never require grads.
        assert not any(ctx.needs_input_grad[1:4])
        _, weights_quantized, absmax, code = ctx.saved_tensors
        dense_weight = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        # grad_output: [*batch, out_features]; dense_weight: [out_features, in_features]
        grad_input = torch.matmul(grad_output, dense_weight)
        grad_bias = None
        if ctx._has_bias:
            # bias is broadcast over all leading dimensions, so sum them out
            grad_bias = grad_output.flatten(0, -2).sum(dim=0)
        return grad_input, None, None, None, grad_bias
 
 
class FrozenBNBEmbedding(nn.Module):
    """Embedding table held as frozen, blockwise-quantized int8 buffers.

    ``weight``, ``absmax`` and ``code`` are registered as non-trainable buffers. The
    lookup dequantizes the table without tracking gradients; an optional ``adapter``,
    when set, is evaluated on the same indices and its result is added to the output.
    """

    def __init__(self, weight, absmax, code):
        super().__init__()
        self.num_embeddings, self.embedding_dim = weight.shape
        for buf_name, tensor in (("weight", weight), ("absmax", absmax), ("code", code)):
            self.register_buffer(buf_name, tensor.requires_grad_(False))
        self.adapter = None

    def forward(self, input, **kwargs):
        # Neither the int8 table nor integer indices are differentiable, so the
        # dequantize + lookup runs without building an autograd graph.
        with torch.no_grad():
            table = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
            output = F.embedding(input, table, **kwargs)
        adapter = self.adapter
        if adapter:
            output += adapter(input)
        return output

    @classmethod
    def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
        """Quantize an existing ``nn.Embedding`` and wrap it."""
        quantized, (absmax, code) = quantize_blockise_lowmemory(embedding.weight)
        return cls(quantized, absmax, code)

    def __repr__(self):
        return f"{type(self).__name__}({self.num_embeddings}, {self.embedding_dim})"
 
 
def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
    """Blockwise-quantize ``matrix`` to int8 one slice at a time to bound peak memory.

    The flattened tensor is processed in slices of ``chunk_size`` elements. Each slice is
    copied and passed to bitsandbytes' ``quantize_blockwise``. The ``code`` returned by
    one call is passed into the next, and the per-slice results are concatenated.

    Args:
        matrix: tensor to quantize, of any shape. Non-contiguous tensors are accepted
            and are copied once when flattened.
        chunk_size: number of elements per slice. Must be a multiple of 4096 so that
            slice boundaries line up with quantization blocks. The 4096 block size is
            assumed (TODO confirm against bitsandbytes' default blocksize).

    Returns:
        ``(matrix_i8, (absmax, code))``. ``matrix_i8`` has ``matrix``'s shape, and
        ``absmax`` concatenates the per-slice ``absmax`` tensors in order.
    """
    assert chunk_size % 4096 == 0
    # reshape() is a free view for contiguous tensors (same as the previous view()),
    # but unlike view() it does not raise on non-contiguous input such as a transposed weight.
    flat_tensor = matrix.reshape(-1)
    code = None
    chunks = []
    absmaxes = []
    for start in range(0, flat_tensor.numel(), chunk_size):
        input_chunk = flat_tensor[start:start + chunk_size].clone()
        quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)
        chunks.append(quantized_chunk)
        absmaxes.append(absmax_chunk)

    matrix_i8 = torch.cat(chunks).reshape_as(matrix)
    absmax = torch.cat(absmaxes)
    return matrix_i8, (absmax, code)
 
 
def convert_to_int8(model):
    """Replace every nn.Linear / nn.Embedding in ``model`` with a frozen 8-bit placeholder.

    Replacements get zero-filled uint8 weights and zero ``absmax`` / ``code`` buffers.
    ``absmax`` has one entry per (possibly partial) 4096-element block, and ``code`` has
    256 entries. The buffers are placeholders, presumably populated later when a
    pre-quantized state dict is loaded. Linear biases are carried over unchanged. Each
    replaced Linear is printed. Module traversal is snapshotted up front, so the newly
    inserted modules are not revisited.
    """

    def _blank_quant_state(weight):
        # Zero absmax (one slot per 4096-element block) plus a zero 256-entry code table.
        n_blocks = (weight.numel() - 1) // 4096 + 1
        return torch.zeros(n_blocks), torch.zeros(256)

    for parent in list(model.modules()):
        for child_name, child in parent.named_children():
            if isinstance(child, nn.Linear):
                print(child_name, child)
                absmax, code = _blank_quant_state(child.weight)
                replacement = FrozenBNBLinear(
                    weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
                    absmax=absmax,
                    code=code,
                    bias=child.bias,
                )
            elif isinstance(child, nn.Embedding):
                absmax, code = _blank_quant_state(child.weight)
                replacement = FrozenBNBEmbedding(
                    weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
                    absmax=absmax,
                    code=code,
                )
            else:
                continue
            setattr(parent, child_name, replacement)

class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
    """GPT-J block whose attention and MLP sub-modules are converted to frozen 8-bit layers."""

    def __init__(self, config):
        super().__init__(config)
        # Only the attention and MLP sub-modules are passed to the converter here.
        for submodule in (self.attn, self.mlp):
            convert_to_int8(submodule)


class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
    """GPT-J base model with every nn.Linear / nn.Embedding converted to frozen 8-bit placeholders.

    NOTE(review): this class is not referenced elsewhere in this notebook; only
    GPTJForCausalLM is instantiated. It is presumably kept for base-model use (confirm).
    """
    def __init__(self, config):
        super().__init__(config)
        # Convert all remaining Linear/Embedding modules (e.g. the token embedding).
        convert_to_int8(self)
        

class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
    """GPT-J causal-LM whose Linear/Embedding layers are frozen 8-bit placeholders.

    The placeholders are zero-filled at construction. In this notebook the model is built
    via ``from_pretrained("hivemind/gpt-j-6B-8bit")``, which presumably fills them with
    the pre-quantized weights.
    """
    def __init__(self, config):
        super().__init__(config)
        # Catch anything not already converted by the patched GPTJBlock (e.g. lm_head, wte).
        convert_to_int8(self)


# Monkey-patch transformers so GPT-J models built after this point use the 8-bit GPTJBlock above.
# NOTE(review): presumably relies on GPTJModel.__init__ resolving GPTJBlock from the
# modeling_gptj module namespace at construction time — confirm for the pinned version.
transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock

We use the model configuration and the tokenizer of GPT-J-6B and set `pad_token` as `eos_token`.


In [None]:
config = transformers.GPTJConfig.from_pretrained("EleutherAI/gpt-j-6B")
# GPT-J defines no padding token, so reuse end-of-sequence as padding.
config.pad_token_id = config.eos_token_id

tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
# `pad_token` expects a token string. The integer id was assigned here before;
# use the tokenizer's own EOS token string instead.
tokenizer.pad_token = tokenizer.eos_token

## Load pretrained model
We load the pre-trained GPT-J-6B with 8-bit weights from [Hugging Face](https://huggingface.co/hivemind/gpt-j-6B-8bit). To reduce peak RAM usage, we pass `low_cpu_mem_usage=True` to `from_pretrained`.

In [None]:
# Build the 8-bit GPT-J subclass defined above and load the pre-quantized Hivemind checkpoint;
# low_cpu_mem_usage=True lowers peak host RAM during loading.
gpt = GPTJForCausalLM.from_pretrained("hivemind/gpt-j-6B-8bit", low_cpu_mem_usage=True)

## Add LoRA Adapter
Low-Rank Adaptation (LoRA) freezes the pretrained model weights and injects trainable rank-decomposition matrices into each layer of the Transformer architecture, greatly reducing the number of trainable parameters for downstream tasks.
* We reduce `adapter_dim` from 16 to 4
* We increase the dropout probability `p` from 0 to 0.1

In [None]:
def add_adapters(model, adapter_dim=4, p=0.1):
    """Attach zero-initialised low-rank (LoRA-style) adapters to frozen 8-bit layers.

    Every FrozenBNBLinear whose qualified name contains "attn", "mlp" or "head" gets a
    Linear(in, adapter_dim) -> Dropout(p) -> Linear(adapter_dim, out) adapter. Other
    FrozenBNBLinear modules are skipped, with a message. Every FrozenBNBEmbedding gets an
    Embedding(num_embeddings, adapter_dim) -> Dropout(p) -> Linear(adapter_dim,
    embedding_dim) adapter. The final Linear of each adapter is zero-initialised, so each
    adapter initially adds 0 to its layer's output.

    Args:
        model: module tree containing FrozenBNBLinear / FrozenBNBEmbedding layers.
        adapter_dim: rank of the adapter bottleneck; must be positive.
        p: dropout probability applied inside each adapter.
    """
    assert adapter_dim > 0
    targeted = ("attn", "mlp", "head")

    for name, module in model.named_modules():
        if isinstance(module, FrozenBNBLinear):
            if not any(tag in name for tag in targeted):
                print("Not adding adapter to", name)
                continue
            print("Adding adapter to", name)
            module.adapter = nn.Sequential(
                nn.Linear(module.in_features, adapter_dim, bias=False),
                nn.Dropout(p=p),
                nn.Linear(adapter_dim, module.out_features, bias=False),
            )
        elif isinstance(module, FrozenBNBEmbedding):
            print("Adding adapter to", name)
            module.adapter = nn.Sequential(
                nn.Embedding(module.num_embeddings, adapter_dim),
                nn.Dropout(p=p),
                nn.Linear(adapter_dim, module.embedding_dim, bias=False),
            )
        else:
            continue

        # Zero the up-projection so the adapter starts as a no-op.
        print("Initializing", name)
        nn.init.zeros_(module.adapter[2].weight)

In [None]:
# Attach trainable low-rank adapters (defaults: adapter_dim=4, dropout p=0.1).
add_adapters(gpt)
# NOTE(review): training below uses torch.cuda.amp (autocast / GradScaler); the CPU
# fallback is presumably only useful for inference.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gpt.to(device)

## Load Dataset
We load our dataset from CSV files on Google Drive. Each file has a single `pair` column with the following structure:

| pair                                              |
|---------------------------------------------------|
| [Text:] Lorem Ipsum<br>[Title:] Title of Lorem    |
| [Text:] Dolor Sit Amet<br>[Title:] Title of Dolor |

In [None]:
# Mount Google Drive; the dataset CSVs and the checkpoint path live under /content/drive/MyDrive.
drive.mount('/content/drive')

In [None]:
# Load 'train' and 'test' splits; each CSV must provide the `pair` column tokenized below.
dataset = load_dataset('csv', data_files={'train': '/content/drive/MyDrive/train.csv', 'test': '/content/drive/MyDrive/test.csv'})

## Tokenize Data
We tokenize our dataset with `max_length=2048`. GPT-J models have a maximum context of 2048 tokens, so each text–title pair must not exceed this token count; longer pairs are truncated, which cuts off the title at the end. You can use [this tool](https://beta.openai.com/tokenizer) to see how a piece of text is tokenized and how many tokens it contains.

In [None]:
def tokenize_function(examples):
    """Tokenize the `pair` text of a batch of rows for causal-LM training.

    `padding=True` pads to the longest sequence in each call (i.e. per `map` batch);
    `truncation=True` with `max_length=2048` cuts longer pairs to GPT-J's context size.

    NOTE(review): since the padded length differs between `map` batches, the default
    DataLoader collation only works while every DataLoader batch stays inside a single
    `map` batch — confirm before enabling shuffling or changing batch sizes.
    """
    return tokenizer(examples["pair"], padding=True, truncation=True, max_length=2048)

# Tokenize in batches, then drop the raw text so only tokenizer outputs remain
# (presumably input_ids and attention_mask).
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(["pair"])

# Convert to Torch format so rows come back as torch tensors
tokenized_datasets.set_format("torch")

In [None]:
full_train_dataset = tokenized_datasets["train"]
# NOTE(review): shuffle=False keeps dataset order. Rows were padded per `map` batch, so
# shuffling could put different lengths in one batch and break default collation;
# a padding collator would be needed before switching to shuffle=True.
train_dataloader = DataLoader(full_train_dataset, shuffle=False, batch_size=8)

## Training Params
We define the training parameters and initialize the optimizer, scheduler and scaler.

In [None]:
num_epochs = 5
# One optimizer step per batch per epoch.
num_training_steps = num_epochs * len(train_dataloader)
# Warm up over the first 10% of steps.
num_warmup_steps = int(num_training_steps*0.1)
learning_rate = 1e-5 # Base learning rate (1e-5) scaled by the warm-up/decay schedule below

We activate gradient checkpointing for the model to reduce memory load.

In [None]:
# Recompute activations during backward instead of storing them, trading compute for memory.
gpt.gradient_checkpointing_enable()

We set a save path for the model checkpoint.

In [None]:
# Checkpoint destination on the mounted Drive; written by the training loop.
filepath = '/content/drive/MyDrive/model.pt'

We initialize the optimizer, scheduler and scaler for training.

In [None]:
# 8-bit Adam. gpt.parameters() covers the trainable parameters: adapters, Linear biases and
# presumably any other full-precision modules; the quantized weights are buffers, not parameters.
optimizer = Adam8bit(gpt.parameters(), lr=learning_rate, weight_decay=0.01)
# Linear warm-up for num_warmup_steps, then linear decay over the remaining steps.
lr_scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
# Loss scaler for mixed-precision (autocast) training.
scaler = torch.cuda.amp.GradScaler()

## Training Loop
We train the model and save it 👏

In [None]:
# The bare `tqdm` name was never imported in this notebook; transformers (already imported)
# exposes a tqdm wrapper, so use that for the progress bar.
progress_bar = transformers.utils.logging.tqdm(total=num_training_steps)
gpt.train()
step = 0  # completed optimizer steps across all epochs
checkpoint_every = 500


def _save_checkpoint(completed_steps, epoch_index):
    """Save model, optimizer, LR-scheduler and AMP-scaler state to `filepath`.

    'k' is the number of completed optimizer steps and 'epoch' the 0-based epoch index
    in which the checkpoint was taken.
    """
    state = {
        'k': completed_steps,
        'epoch': epoch_index,
        'lr_scheduler': lr_scheduler.state_dict(),
        'state_dict': gpt.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scaler': scaler.state_dict(),
    }
    torch.save(state, filepath)


# Iterate through epochs
for epoch in range(num_epochs):
    # Iterate through our training data in batches
    for batch in train_dataloader:
        # Move the batch to the training device
        batch = {name: tensor.to(device) for name, tensor in batch.items()}

        # Zero out gradients left over from the previous step
        # Explanation: https://stackoverflow.com/a/48009142
        optimizer.zero_grad()

        # Forward pass under mixed-precision autocasting
        with torch.cuda.amp.autocast():
            out = gpt(**batch)

            # Next-token loss: logits at position t predict the token at t + 1.
            # NOTE(review): padded positions (pad == eos) are not excluded from the loss;
            # consider masking them via attention_mask if that is not intended.
            loss = F.cross_entropy(out.logits[:, :-1, :].flatten(0, -2),
                                   batch['input_ids'][:, 1:].flatten(),
                                   reduction='mean',
                                   label_smoothing=0.1)

        print(loss)

        # Backward on the scaled loss to produce scaled gradients
        scaler.scale(loss).backward()

        # Unscale gradients in place so clipping sees their true magnitude
        scaler.unscale_(optimizer)

        # Clip the global gradient norm over all parameters
        # documentation: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
        torch.nn.utils.clip_grad_norm_(gpt.parameters(), 1.0)

        # Optimizer step (skipped by the scaler if gradients contain inf/NaN),
        # then adjust the loss scale for the next iteration
        scaler.step(optimizer)
        scaler.update()

        # Advance the learning-rate schedule
        lr_scheduler.step()

        progress_bar.update(1)

        # Checkpoint after the update, so 'k' matches the saved optimizer state
        step += 1
        if step % checkpoint_every == 0:
            print(step)
            _save_checkpoint(step, epoch)

# Persist the final state; otherwise updates after the last multiple of
# `checkpoint_every` would be lost.
if step % checkpoint_every != 0:
    _save_checkpoint(step, num_epochs - 1)
progress_bar.close()

## Evaluate Model
We evaluate the model and generate a headline.

In [None]:
gpt.eval()

# The prompt must follow the training format "[Text:] ... [Title:]" so the model continues
# with a title. "[Titel:]" previously did not match the training data.
# NOTE(review): confirm spacing/newlines match the `pair` column exactly.
input_text = "[Text:]Lorem Ipsum\n\n[Title:]"

with torch.no_grad():
    # Tokenize the prompt defined above (previously referenced an undefined `cleaned_text`).
    prompt = tokenizer(input_text, truncation=True, padding=True, return_tensors='pt')
    prompt = {key: value.to(device) for key, value in prompt.items()}
    out = gpt.generate(**prompt, min_length=5, max_new_tokens=256, top_p=0.7, temperature=1.0, do_sample=True)
    print(tokenizer.decode(out[0]))