# Instruction Fine Tuning

We'll now take our base model and fine-tune it into a chatbot like ChatGPT :).

AFAIK, OpenAI never created a fine-tuned instruct model of GPT2, so we're in uncharted territory now. It also means that this can either turn out great, or just horrible.

## Preparing the Dataset

I have a dataset that I synthetically generated with the included `dataset_generation.py` script, using `gpt-4o` through the OpenAI API.

You can run this script yourself to generate as many conversations as you need. Do note that this costs money; alternatively, you can generate the data with a local LLM such as Llama or Qwen.

I have included a small version of the dataset in this repo called `shared/data/conversations-sm.jsonl`. It contains about 10K multi-turn conversations of up to 1000 tokens each -- which is just under the maximum context length for GPT2 (1024 tokens).

In [None]:
import jsonl

# Location of the JSONL conversation dump, relative to the notebook's working directory.
# NOTE(review): the prose above names `shared/data/conversations-sm.jsonl` -- confirm which path is current.
dataset_file_path = "data/small-conversations/conversations.jsonl"

# JSON text is UTF-8 by specification; pin the encoding so the read does not
# depend on the platform's default locale encoding.
with open(dataset_file_path, "r", encoding="utf-8") as f:
  all_data = list(jsonl.load(f))

# Cell output: number of loaded conversations and the first entry as a sample.
len(all_data), all_data[0]

We'll also shuffle the entire dataset in place to get a good mix of long and short conversations.

In [None]:
import random

# The train/validation/test split below is taken from this order, so the
# shuffle must be reproducible: with an unseeded shuffle every kernel restart
# yields a different split, and a model resumed from an earlier checkpoint may
# already have trained on what is now validation or test data.
# A dedicated generator is used so the global `random` state is left untouched.
SHUFFLE_SEED = 42
random.Random(SHUFFLE_SEED).shuffle(all_data)

# Cell output: the conversation that is now first after shuffling.
all_data[0]

I'll do a training split of 90%. The validation and test portions of the dataset will be the remaining 10%. 

Of that remaining 10%, 90% (i.e. 9% of the whole dataset) will be used as validation data, and the rest (1% of the whole dataset) as test data.

In [None]:
# The first 90% of the shuffled conversations form the training split;
# everything after that index is held out.
train_end_idx = int(len(all_data) * 0.9)
train_data, val_and_test_data = all_data[:train_end_idx], all_data[train_end_idx:]

# The held-out tail is divided again: its first 90% is validation, the rest is test.
val_end_idx = int(len(val_and_test_data) * 0.9)
val_data, test_data = val_and_test_data[:val_end_idx], val_and_test_data[val_end_idx:]

# Cell output: size of each split.
len(train_data), len(val_data), len(test_data)

Let's filter out any strange entries (here: items that are `None`)...

In [None]:
# Drop entries that are None from every split; the order of the rest is kept.
train_data = [entry for entry in train_data if entry is not None]
val_data = [entry for entry in val_data if entry is not None]
test_data = [entry for entry in test_data if entry is not None]

# Cell output: split sizes after filtering.
len(train_data), len(val_data), len(test_data)

And shuffle our training data once more!

In [5]:
import random
# Reorder the training split in place. This uses the module-level `random`
# state, so the resulting order differs between runs unless that state is seeded.
random.shuffle(train_data)

In [7]:
import torch
from torch.utils.data import Dataset

class InstructionDataset(Dataset):
  """Tokenized multi-turn conversations together with a per-token mask.

  Every conversation is flattened into a single token list made of:

  * a fixed preamble sentence,
  * for each message with role "user": the "### Instruction:" header followed
    by the message content,
  * for each other message: the "### Response:" header followed by the message
    content plus a trailing blank line.

  A parallel list of booleans is built for each conversation. ``True`` marks a
  masked token (preamble, headers and user content); ``False`` marks the tokens
  of a response body including its trailing blank line.

  Args:
    data: sequence of conversations; each conversation is an iterable of
      dicts with the keys "role" and "content".
    tokenizer: object exposing ``encode(text, allowed_special=...)`` that
      returns a list of token ids.
  """

  # Prompt fragments. These are part of the training format and must not change.
  _PREAMBLE = "Below is an instruction that describes a task. Write a response that appropriately completes the request"
  _INSTRUCTION_HEADER = "\n\n### Instruction:\n"
  _RESPONSE_HEADER = "\n\n### Response:\n"
  _RESPONSE_SUFFIX = "\n\n"

  def __init__(self, data, tokenizer):
    self.data = data

    self.encoded_texts = []
    self.masks = []

    def encode(text):
      return tokenizer.encode(text, allowed_special={"<|endoftext|>"})

    # The scaffolding strings never change, so tokenize them once instead of
    # once per conversation / per message.
    preamble_tokens = encode(self._PREAMBLE)
    instruction_header_tokens = encode(self._INSTRUCTION_HEADER)
    response_header_tokens = encode(self._RESPONSE_HEADER)

    for item in self.data:
      # Fresh lists per conversation so samples never share storage.
      tokens = list(preamble_tokens)
      mask = [True] * len(preamble_tokens)

      for message in item:
        if message["role"] == "user":
          header_tokens = instruction_header_tokens
          content_tokens = encode(message["content"])
          content_masked = True
        else:
          # Any non-user role is treated as a response; its body is the only
          # part of the sequence that is left unmasked.
          header_tokens = response_header_tokens
          content_tokens = encode(message["content"] + self._RESPONSE_SUFFIX)
          content_masked = False

        tokens.extend(header_tokens)
        mask.extend([True] * len(header_tokens))
        tokens.extend(content_tokens)
        mask.extend([content_masked] * len(content_tokens))

      self.encoded_texts.append(tokens)
      self.masks.append(mask)

  def __getitem__(self, index):
    """Return ``(token_ids, mask)`` for the conversation at ``index``."""
    return self.encoded_texts[index], self.masks[index]

  def __len__(self):
    """Return the number of conversations."""
    return len(self.data)

In [8]:
def custom_collate_fn(
  batch,
  pad_token_id=50256,
  ignore_index=-100,
  allowed_max_length=None,
  device="cpu"
):
  """Collate ``(token_ids, mask)`` samples into next-token input/target tensors.

  Every sample gets one ``pad_token_id`` appended and is right-padded with
  ``pad_token_id`` to the longest sample in the batch (plus that one token).
  Inputs are the padded sequence without its last token; targets are the same
  sequence shifted left by one.

  Target positions are replaced by ``ignore_index`` when

  * the target token equals ``pad_token_id`` and is not the first such token
    in the sample, or
  * the mask entry belonging to the target token is ``True``. The appended
    token and all padding are flagged ``True``, so when a mask is as long as
    its token list those positions end up ignored as well.

  Args:
    batch: non-empty sequence of ``(token_ids, mask)`` pairs, where ``mask[i]``
      is ``True`` if token ``i`` must not contribute as a target.
    pad_token_id: id used for the appended token and for padding.
    ignore_index: value written into ignored target positions.
    allowed_max_length: if given, inputs and targets are cut to this many
      positions after padding and masking.
    device: device the stacked tensors are moved to.

  Returns:
    ``(inputs, targets)``, two integer tensors of identical shape
    ``(len(batch), sequence_length)``.
  """
  tokens_batch = [item[0] for item in batch]
  masks_batch = [item[1] for item in batch]

  # find the longest sequence in the batch (+1 for the appended pad token)
  batch_max_length = max(len(tokens) + 1 for tokens in tokens_batch)

  inputs_lst, targets_lst = [], []

  for tokens, mask in zip(tokens_batch, masks_batch):
    # Copy so the dataset's own lists are never modified.
    padded_tokens = list(tokens) + [pad_token_id]
    padded_tokens += [pad_token_id] * (batch_max_length - len(padded_tokens))

    padded_mask = list(mask) + [True]  # mask the added pad token
    padded_mask += [True] * (batch_max_length - len(padded_mask))

    inputs = torch.tensor(padded_tokens[:-1], dtype=torch.long)
    targets = torch.tensor(padded_tokens[1:], dtype=torch.long)

    # Keep only the first pad-valued target; later ones are ignored.
    pad_positions = torch.nonzero(targets == pad_token_id).flatten()
    if pad_positions.numel() > 1:
      targets[pad_positions[1:]] = ignore_index

    # targets[j] is token j + 1, so its mask flag is padded_mask[j + 1].
    # One boolean-mask assignment replaces a per-element Python loop.
    ignored = torch.tensor(padded_mask[1:batch_max_length], dtype=torch.bool)
    targets[ignored] = ignore_index

    if allowed_max_length is not None:
      inputs = inputs[:allowed_max_length]
      targets = targets[:allowed_max_length]

    inputs_lst.append(inputs)
    targets_lst.append(targets)

  inputs_tensor = torch.stack(inputs_lst).to(device)
  targets_tensor = torch.stack(targets_lst).to(device)

  return inputs_tensor, targets_tensor

In [20]:
# Prefer the second GPU when there is one, but fall back to the first GPU or
# the CPU so the notebook does not fail on machines with a different setup.
if torch.cuda.is_available():
  device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
  device = "cpu"


# Hand cached but unused GPU memory back to the driver before loading models.
torch.cuda.empty_cache()

In [10]:
from functools import partial

# Pre-bind the target device and the length cap so a DataLoader can call the
# collate function with the batch as its only argument.
customized_collate_fn = partial(
  custom_collate_fn,
  device=device,
  # samples are cut to at most 1024 positions -- presumably GPT-2's context window
  allowed_max_length=1024
)

In [14]:
from torch.utils.data import DataLoader
import tiktoken

tokenizer = tiktoken.get_encoding("gpt2")

num_workers = 0
batch_size = 32


def _build_loader(dataset, shuffle, drop_last):
  """Wrap `dataset` in a DataLoader using the notebook-wide batch settings."""
  return DataLoader(
    dataset,
    batch_size=batch_size,
    collate_fn=customized_collate_fn,
    shuffle=shuffle,
    drop_last=drop_last,
    num_workers=num_workers,
  )


# Training batches are reshuffled every epoch and an incomplete last batch is dropped.
train_dataset = InstructionDataset(train_data, tokenizer)
train_loader = _build_loader(train_dataset, shuffle=True, drop_last=True)

# Validation and test keep their order and keep every sample.
val_dataset = InstructionDataset(val_data, tokenizer)
val_loader = _build_loader(val_dataset, shuffle=False, drop_last=False)

test_dataset = InstructionDataset(test_data, tokenizer)
test_loader = _build_loader(test_dataset, shuffle=False, drop_last=False)

In [None]:

# Sanity check: print the input/target tensor shapes of the first 10 batches.
print("Train loader:")
i = 0
for inputs, targets in train_loader:
    # The check runs after a batch has been fetched, so when the loop breaks
    # `inputs`/`targets` hold an 11th batch whose shape was not printed.
    if i == 10:
        break
    print(inputs.shape, targets.shape)
    i+=1


In [None]:
# Show the first sample of whatever batch `inputs`/`targets` currently hold
# (presumably left over from the train-loader loop in the previous cell).
print(inputs[0]) 
print(targets[0])

In [None]:
from scripts.train import calc_loss_loader
from scripts.model_loader import load_model_from_path


# Load the base checkpoint to measure its loss before any fine-tuning.
model = load_model_from_path(
  "models/10b/gpt2-355M-bfloat16.pth",
  device=device 
)
model.eval()

# Move to the target device, cast the weights to bfloat16, then compile.
model = model.to(device).to(torch.bfloat16)
model = torch.compile(model)

# No gradients are needed for evaluation.
# NOTE(review): `num_batches=10` presumably limits the estimate to 10 batches
# per loader -- confirm against scripts.train.calc_loss_loader.
with torch.no_grad():
  train_loss = calc_loss_loader(train_loader, model, device, num_batches=10)
  val_loss = calc_loss_loader(val_loader, model, device, num_batches=10)

print("Training loss", train_loss)
print("Validation loss", val_loss)

In [24]:
# Release cached, currently unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()

For this run, I lowered the learning rate after every epoch:

- Epoch 1 - 8e-5
- Epoch 2 - 5e-5 
- Epoch 3 - 2e-5 

After epoch 3, I was starting to see diminishing returns: training loss was 1.317 with a validation loss of 1.564, and the gap was only widening. So I stopped there and consider this to be my complete model, `gpt2-355M-it-bfloat16.pth`.

In [None]:
import time
import tiktoken
import torch
from scripts.train import train_model_simple
from scripts.model_loader import load_model_from_path
from scripts.fine_tune import format_input

# Resume from the checkpoint saved after the previous epoch.
model = load_model_from_path(
  "models/gpt2-355M-model-it-ep2-long-v3.pth",
  device=device
)

model = model.to(device).to(torch.bfloat16)
model.train()

tokenizer = tiktoken.get_encoding("gpt2")

start_time = time.time()

# The learning rate is set by hand for each epoch; 2e-5 is the epoch-3 value.
optimizer = torch.optim.AdamW(
  model.parameters(),
  lr=2e-5,
  weight_decay=0.1,
  fused=True
)

# One epoch per execution of this cell; rerun with a new checkpoint and
# learning rate for the next epoch.
num_epochs=1

# First token id of a blank line ("\n\n"), the same text the dataset appends to every response.
double_new_line_id = tokenizer.encode("\n\n", allowed_special={"<|endoftext|>"})[0]

# NOTE(review): `start_context` and `stop_sequence` are presumably used for the
# sample generations train_model_simple prints during training -- confirm in scripts.train.
train_losses, val_losses = train_model_simple(
  model=model,
  train_loader=train_loader,
  val_loader=val_loader,
  optimizer=optimizer,
  num_epochs=num_epochs,
  eval_freq=100,
  eval_iter=50,
  start_context=format_input(val_data[0]),
  tokenizer=tokenizer,
  device=device,
  save_iters=200,
  stop_sequence=[double_new_line_id]
)

end_time = time.time()
execution_time_minutes = (end_time - start_time) / 60
print(f"Training completed in {execution_time_minutes:.2f} minutes.")

In [30]:
from scripts.gpt2_common import save_model_and_optimizer

# Save next to the checkpoints this notebook loads, using the same
# working-directory-relative "models" folder instead of a machine-specific
# absolute home path.
model_directory = "models"
save_model_and_optimizer(
  model_path=f"{model_directory}/gpt2-355M-model-it-ep3-long-v3.pth",
  model=model,
  optimizer_path=f"{model_directory}/optimizer-gpt2-355M-model-it-ep3-long-v3.pth",
  optimizer=optimizer
)

In [None]:
import tiktoken
from scripts.generate import generate
from scripts.util import text_to_token_ids, token_ids_to_text

# Step counts of the saved checkpoints to sample from.
checkpoints = [0, 500, 1000, 1500, 2000]

tokenizer = tiktoken.get_encoding("gpt2")
for checkpoint in checkpoints:
  print("-" * 20)
  # Checkpoint files carry an epoch number: steps below 4000 map to epoch 1,
  # everything else to epoch 2.
  ep = 1 if checkpoint // 4000 == 0 else 2

  print(f"Testing a message at checkpoint: {checkpoint}, Epoch: {ep}")
  model = load_model_from_path(f"models/checkpoint-model-ep{ep}-{checkpoint}.pth", device)
  model.eval()

  # Generate a completion for each of the first three held-out conversations.
  for i, test in enumerate(test_data[:3]):
    print(f"## Test message: {i}")
    test_message = format_input(test)
    prompt_ids = text_to_token_ids(test_message, tokenizer).to(device)

    token_ids = generate(
      model,
      idx=prompt_ids,
      max_new_tokens=256,
      context_size=1024,
      temperature=1.0,
      top_k=20,
      eos_id=50256,
      device=device
    )

    text = token_ids_to_text(token_ids, tokenizer)
    print("-" * 20)
    print(text)
    print()




# Continued Fine Tuning - More Facts