# Instruction Fine Tuning

In [1]:
import jsonl

# Every record in the file is one conversation: a list of
# {"role": ..., "content": ...} messages (see the output below).
with open("data/small-conversations/conversations.jsonl", "r") as conversations_file:
  all_data = [conversation for conversation in jsonl.load(conversations_file)]

len(all_data), all_data[0]

(79808,
 [{'role': 'user', 'content': "Hi there, how's your day going?"},
  {'role': 'assistant',
   'content': 'Hello! My day is going great, thank you for asking. How about yours?'},
  {'role': 'user',
   'content': "It's been a bit busy, but I'm managing. Any interesting facts to share?"},
  {'role': 'assistant',
   'content': 'Did you know that honey never spoils? Archaeologists have found pots of honey in ancient Egyptian tombs that are over 3000 years old and still edible!'},
  {'role': 'user',
   'content': "Wow, that's fascinating! I never would have guessed. What makes honey last so long?"},
  {'role': 'assistant',
   'content': "Honey's longevity is due to its low moisture content and acidic nature, which make it inhospitable for bacteria and microorganisms to grow."},
  {'role': 'user',
   'content': "That makes sense. Nature really is amazing, isn't it?"},
  {'role': 'assistant',
   'content': "Absolutely! There's so much complexity and wonder in the natural world. Is there

In [2]:
import random

# Seed before shuffling so the train/val/test split below (and the
# train/val/test .jsonl files written from it) is identical on every
# Restart & Run All. Without a seed, a re-run silently produces a new split
# and overwrites the saved files.
SEED = 42
random.seed(SEED)
random.shuffle(all_data)

all_data[0]

[{'role': 'user',
  'content': 'Hey there! Can you suggest a good movie for tonight?'},
 {'role': 'assistant',
  'content': "Sure! Have you seen 'The Grand Budapest Hotel'? It's a delightful mix of comedy and adventure."},
 {'role': 'user', 'content': "I haven't watched it yet. What's it about?"},
 {'role': 'assistant',
  'content': "It's about the adventures of a legendary concierge and his friendship with a young lobby boy at a famous European hotel. It's quite entertaining with a unique visual style."},
 {'role': 'user',
  'content': "Sounds interesting! I'll definitely give it a try. Thanks!"},
 {'role': 'assistant', 'content': "You're welcome! Enjoy the movie!"}]

**Train / validation / test split** (90% / 9% / 1%)

In [3]:
# 90% of conversations go to training; the remaining 10% is split 90/10
# into validation and test.
train_end_idx = int(0.9 * len(all_data))
train_data, val_and_test_data = all_data[:train_end_idx], all_data[train_end_idx:]

val_end_idx = int(0.9 * len(val_and_test_data))
val_data, test_data = val_and_test_data[:val_end_idx], val_and_test_data[val_end_idx:]

len(train_data), len(val_data), len(test_data)

(71827, 7182, 799)

In [4]:
# Drop any missing (None) conversations from each split.
train_data = [conversation for conversation in train_data if conversation is not None]
val_data = [conversation for conversation in val_data if conversation is not None]
test_data = [conversation for conversation in test_data if conversation is not None]

len(train_data), len(val_data), len(test_data)

(71827, 7182, 799)

In [5]:
import random
# Re-shuffle only the training split, in place. This is the order written to
# train_data.jsonl in the next cell; val/test keep their split-time order.
random.shuffle(train_data)

In [6]:
import jsonl

# Persist each split as data/small-conversations/<split_name>.jsonl.
splits = {"train_data": train_data, "val_data": val_data, "test_data": test_data}
for split_name, split_rows in splits.items():
  jsonl.dump(split_rows, f"data/small-conversations/{split_name}.jsonl")

In [7]:
import torch
from torch.utils.data import Dataset

class InstructionDataset(Dataset):
  r"""Tokenized multi-turn conversations in an Alpaca-style prompt format.

  Every conversation in ``data`` (a list of ``{"role": ..., "content": ...}``
  message dicts) is rendered as::

      <preamble>
      "\n\n### Instruction:\n" + user content              (role == "user")
      "\n\n### Response:\n" + assistant content + "\n\n"   (any other role)
      ...

  and tokenized once, up front. Alongside each token list a same-length
  boolean mask is stored: ``True`` for the preamble, both headers and user
  content; ``False`` for assistant content (including its trailing "\n\n").
  ``custom_collate_fn`` replaces targets at ``True`` positions with
  ``ignore_index``, so only assistant tokens are trained on.

  Args:
    data: sequence of conversations; kept as ``self.data``.
    tokenizer: object whose ``encode(text, allowed_special=...)`` returns a
      list of token ids (e.g. ``tiktoken.get_encoding("gpt2")``).
  """

  def __init__(self, data, tokenizer):
    self.data = data

    self.encoded_texts = []
    self.masks = []

    allowed = {"<|endoftext|>"}

    # The preamble and the two section headers are identical for every
    # example, so tokenize them once instead of once per conversation/message.
    preamble_tokens = tokenizer.encode(
      "Below is an instruction that describes a task. Write a response that appropriately completes the request",
      allowed_special=allowed,
    )
    instruction_header_tokens = tokenizer.encode("\n\n### Instruction:\n", allowed_special=allowed)
    response_header_tokens = tokenizer.encode("\n\n### Response:\n", allowed_special=allowed)

    for item in self.data:
      # Fresh lists per example so no two examples share the same list object.
      tokens = list(preamble_tokens)
      mask = [True] * len(preamble_tokens)

      for message in item:
        if message["role"] == "user":
          content_tokens = tokenizer.encode(message["content"], allowed_special=allowed)
          tokens.extend(instruction_header_tokens)
          tokens.extend(content_tokens)
          # Header and user text are context only: masked out of the loss.
          mask.extend([True] * (len(instruction_header_tokens) + len(content_tokens)))
        else:
          # The trailing "\n\n" is part of the (unmasked) target; presumably so
          # the model learns to emit it as an end-of-turn marker (training
          # passes "\n\n" as stop_sequence).
          content_tokens = tokenizer.encode(message["content"] + "\n\n", allowed_special=allowed)
          tokens.extend(response_header_tokens)
          mask.extend([True] * len(response_header_tokens))
          # Don't mask the assistant response: these are the trained tokens.
          tokens.extend(content_tokens)
          mask.extend([False] * len(content_tokens))

      self.encoded_texts.append(tokens)
      self.masks.append(mask)

  def __getitem__(self, index):
    """Return ``(token_ids, mask)`` for conversation ``index``."""
    return self.encoded_texts[index], self.masks[index]

  def __len__(self):
    """Number of conversations."""
    return len(self.data)

In [8]:
def custom_collate_fn(
   batch,
   pad_token_id=50256,
   ignore_index=-100,
   allowed_max_length=None,
   device="cpu"
):
  """Pad a batch of ``(tokens, mask)`` pairs into next-token-prediction tensors.

  Each sequence gets one ``pad_token_id`` appended and is then right-padded to
  the longest sequence in the batch (+1). Inputs are the padded sequence
  without its last token; targets are the padded sequence shifted left by one.
  A target is replaced by ``ignore_index`` when the token it predicts is
  ``True`` in ``mask`` or is padding, so a loss using the same
  ``ignore_index`` only scores ``False``-masked tokens.

  Args:
    batch: iterable of ``(tokens, mask)``; ``tokens`` is a list of token ids
      and ``mask`` a same-length list of bools (``True`` = exclude from loss).
    pad_token_id: id used for padding (50256 is GPT-2's ``<|endoftext|>``).
    ignore_index: value written into excluded target positions.
    allowed_max_length: if set, inputs/targets are truncated to this length.
    device: device the stacked tensors are moved to.

  Returns:
    ``(inputs, targets)`` int64 tensors of shape ``(len(batch), seq_len)``.

  Raises:
    ValueError: if a sample's token and mask lengths differ (or batch is empty).
  """
  tokens_batch = [item[0] for item in batch]
  masks_batch = [item[1] for item in batch]

  # Longest sequence in the batch, plus one slot for the pad token that is
  # always appended so the last real token still has a target.
  batch_max_length = max(len(tokens) + 1 for tokens in tokens_batch)

  inputs_lst, targets_lst = [], []

  for tokens, mask in zip(tokens_batch, masks_batch):
    if len(tokens) != len(mask):
      # A mismatch would silently shift the loss mask against the tokens.
      raise ValueError(
        f"token/mask length mismatch: {len(tokens)} tokens vs {len(mask)} mask entries"
      )

    num_pad = batch_max_length - len(tokens)  # always >= 1
    padded_tokens = list(tokens) + [pad_token_id] * num_pad
    # The appended pad token and all further padding are masked out (True).
    padded_mask = list(mask) + [True] * num_pad

    # Explicit dtype: an empty list would otherwise become a float tensor.
    inputs = torch.tensor(padded_tokens[:-1], dtype=torch.long)
    targets = torch.tensor(padded_tokens[1:], dtype=torch.long)

    # Keep the first pad-valued target, ignore any later ones. (With the mask
    # below all padding is ignored anyway; this only matters if pad_token_id
    # also occurs inside the content.)
    pad_indices = torch.nonzero(targets == pad_token_id).flatten()
    if pad_indices.numel() > 1:
      targets[pad_indices[1:]] = ignore_index

    # targets[j] predicts token j + 1, so it is governed by padded_mask[j + 1].
    # One vectorized boolean index replaces a per-token Python loop.
    target_mask = torch.tensor(padded_mask[1:], dtype=torch.bool)
    targets[target_mask] = ignore_index

    if allowed_max_length is not None:
      inputs = inputs[:allowed_max_length]
      targets = targets[:allowed_max_length]

    inputs_lst.append(inputs)
    targets_lst.append(targets)

  inputs_tensor = torch.stack(inputs_lst).to(device)
  targets_tensor = torch.stack(targets_lst).to(device)

  return inputs_tensor, targets_tensor

In [9]:
# Prefer the second GPU (the original configuration), but fall back to the
# default GPU or the CPU so the notebook still starts on machines without a
# cuda:1 (a hardcoded "cuda:1" fails there on first use).
if torch.cuda.device_count() > 1:
  device = "cuda:1"
elif torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"


torch.cuda.empty_cache()

In [10]:
from functools import partial

# Bind the notebook device and a 1024-token cap (the same context_size used
# for generation later) so DataLoader can call the collate fn with a batch only.
customized_collate_fn = partial(custom_collate_fn, allowed_max_length=1024, device=device)

In [14]:
from torch.utils.data import DataLoader
import tiktoken

tokenizer = tiktoken.get_encoding("gpt2")

# customized_collate_fn moves every batch to `device`; keep this at 0 when
# that device is a GPU (CUDA in DataLoader worker processes is problematic).
num_workers = 0
batch_size = 32


def make_instruction_loader(split_data, *, shuffle, drop_last):
  """Build an InstructionDataset for one split and wrap it in a DataLoader.

  Uses the notebook-level ``tokenizer``, ``batch_size``, ``num_workers`` and
  ``customized_collate_fn``.

  Returns:
    ``(dataset, loader)``.
  """
  dataset = InstructionDataset(split_data, tokenizer)
  loader = DataLoader(
    dataset,
    batch_size=batch_size,
    collate_fn=customized_collate_fn,
    shuffle=shuffle,
    drop_last=drop_last,
    num_workers=num_workers
  )
  return dataset, loader


# Only training shuffles and drops the ragged final batch; val/test keep
# every example in a fixed order so their losses are comparable across runs.
train_dataset, train_loader = make_instruction_loader(train_data, shuffle=True, drop_last=True)
val_dataset, val_loader = make_instruction_loader(val_data, shuffle=False, drop_last=False)
test_dataset, test_loader = make_instruction_loader(test_data, shuffle=False, drop_last=False)

In [15]:

print("Train loader:")
# Shapes of the first 10 training batches. Each batch is padded only to its
# own longest example, so seq_len varies from batch to batch.
for i, (inputs, targets) in enumerate(train_loader):
    if i == 10:
        break
    print(inputs.shape, targets.shape)


Train loader:
torch.Size([32, 857]) torch.Size([32, 857])
torch.Size([32, 775]) torch.Size([32, 775])
torch.Size([32, 980]) torch.Size([32, 980])
torch.Size([32, 875]) torch.Size([32, 875])
torch.Size([32, 975]) torch.Size([32, 975])
torch.Size([32, 932]) torch.Size([32, 932])
torch.Size([32, 914]) torch.Size([32, 914])
torch.Size([32, 860]) torch.Size([32, 860])
torch.Size([32, 1014]) torch.Size([32, 1014])
torch.Size([32, 995]) torch.Size([32, 995])


In [16]:
# First example of the last batch fetched above: input ids vs. shifted
# targets (-100 marks target positions excluded from the loss).
first_inputs, first_targets = inputs[0], targets[0]
print(first_inputs)
print(first_targets)

tensor([21106,   318,   281, 12064,   326,  8477,   257,  4876,    13, 19430,
          257,  2882,   326, 20431, 32543,   262,  2581,   198,   198, 21017,
        46486,    25,   198, 10814,    11,   466,   345,   760,  1521,   262,
         6766,   318,  4171,    30,   198,   198, 21017, 18261,    25,   198,
         5297,    11,   262,  6766,  3568,  4171,   780,   286,  7760, 42342,
        45765,    13,   770, 10733,  8833,   618,   262,  3668,   338,  8137,
          629, 34387, 19606,    11,   290,  4171,  1657,   318, 16830,   287,
          477, 11678,   517,   621,   584,  7577,   780,   340, 17781,   287,
        12238,    11,  4833,  9813,    13,   628,   198,   198, 21017, 46486,
           25,   198,  2504,   338,  3499,     0,  1867,   546,  4252, 28709,
          852,  2266,   290, 10912,    30,   198,   198, 21017, 18261,    25,
          198,  7191, 26428,    11,   262,  4252,   318,  2793,   287,   262,
         6766,    11,   290,   262,  1657,   468,   284,  1208, 

In [None]:
from scripts.train import calc_loss_loader
from scripts.model_loader import load_model_from_path


# Baseline: loss of the checkpoint that fine-tuning starts from, on the first
# 10 train and validation batches, for comparison after training.
model = load_model_from_path("models/10b/gpt2-355M-bfloat16.pth", device=device)
model.eval()

model = torch.compile(model.to(device).to(torch.bfloat16))

with torch.no_grad():
  train_loss = calc_loss_loader(train_loader, model, device, num_batches=10)
  val_loss = calc_loss_loader(val_loader, model, device, num_batches=10)

print("Training loss", train_loss)
print("Validation loss", val_loss)

In [17]:
# Release unused cached GPU memory before the training cell loads a fresh model.
# NOTE(review): `model` from the baseline-loss cell is still referenced here,
# so its weights are not freed by this call; `del model` first if memory is tight.
torch.cuda.empty_cache()

Learning rate per epoch:

- Epoch 1 - 5e-5
- Epoch 2 - 2e-5 <-- good!
- Epoch 3 - 1e-5 <-- diminishing returns; repetitions start to appear

In [None]:
import time
import tiktoken
import torch
from scripts.train import train_model_simple
from scripts.model_loader import load_model_from_path
from scripts.fine_tune import format_input

# Reload the base checkpoint (same file as the baseline-loss cell) so training
# starts from fresh, uncompiled weights.
model = load_model_from_path(
  "models/10b/gpt2-355M-bfloat16.pth",
  device=device
)

model = model.to(device).to(torch.bfloat16)
model.train()

tokenizer = tiktoken.get_encoding("gpt2")

# perf_counter is monotonic, so wall-clock adjustments can't skew the timing.
start_time = time.perf_counter()

optimizer = torch.optim.AdamW(
  model.parameters(),
  lr=5e-5,  # epoch-1 learning rate (see the learning-rate notes above)
  weight_decay=0.1,
  fused=True
)

num_epochs = 1

# Every assistant turn in InstructionDataset ends with "\n\n", so that token
# is passed as the stop sequence (presumably used by train_model_simple when
# sampling from start_context). It must encode to exactly one GPT-2 token,
# otherwise [0] would silently pick "\n".
double_new_line_ids = tokenizer.encode("\n\n", allowed_special={"<|endoftext|>"})
assert len(double_new_line_ids) == 1, f"expected a single token for '\\n\\n', got {double_new_line_ids}"
double_new_line_id = double_new_line_ids[0]

train_losses, val_losses = train_model_simple(
  model=model,
  train_loader=train_loader,
  val_loader=val_loader,
  optimizer=optimizer,
  num_epochs=num_epochs,
  eval_freq=100,
  eval_iter=50,
  start_context=format_input(val_data[0]),
  tokenizer=tokenizer,
  device=device,
  save_iters=200,
  stop_sequence=[double_new_line_id]
)

end_time = time.perf_counter()
execution_time_minutes = (end_time - start_time) / 60
print(f"Training completed in {execution_time_minutes:.2f} minutes.")

In [130]:
import os
from pathlib import Path

from scripts.gpt2_common import save_model_and_optimizer

# Defaults to the original absolute location; set MODEL_DIR to save elsewhere
# on other machines instead of editing a hardcoded user path.
model_directory = Path(
  os.environ.get("MODEL_DIR", "/home/rngo/code/ttnn-sandbox/notebooks/models")
)
save_model_and_optimizer(
  model_path=str(model_directory / "gpt2-355M-model-it-ep1-long-v3.pth"),
  model=model,
  optimizer_path=str(model_directory / "optimizer-gpt2-355M-model-it-ep1-long-v3.pth"),
  optimizer=optimizer
)

In [None]:
import tiktoken
import torch
from scripts.generate import generate
from scripts.fine_tune import format_input
from scripts.model_loader import load_model_from_path
from scripts.util import text_to_token_ids, token_ids_to_text

# Import load_model_from_path / format_input here as well: the cells that
# originally imported them may not have been executed in this kernel.

checkpoints = [0, 500, 1000, 1500, 2000]

# NOTE(review): 4000 looks like an assumed steps-per-epoch. With
# batch_size=32 and drop_last=True, len(train_loader) is 71827 // 32 = 2244;
# confirm against how train_model_simple numbers its checkpoint files.
STEPS_PER_EPOCH_GUESS = 4000

tokenizer = tiktoken.get_encoding("gpt2")
for checkpoint in checkpoints:
  print("-" * 20)
  ep = 1 if checkpoint // STEPS_PER_EPOCH_GUESS == 0 else 2

  print(f"Testing a message at checkpoint: {checkpoint}, Epoch: {ep}")
  model = load_model_from_path(f"models/checkpoint-model-ep{ep}-{checkpoint}.pth", device)
  model.eval()

  for i, test in enumerate(test_data[:3]):
    print(f"## Test message: {i}")
    test_message = format_input(test)

    # Inference only: avoid building an autograd graph while sampling.
    # NOTE(review): custom_collate_fn masks the appended <|endoftext|> target,
    # so the model is never trained to emit eos_id=50256; generation will
    # likely run to max_new_tokens rather than stopping at "\n\n".
    with torch.no_grad():
      token_ids = generate(
        model,
        idx=text_to_token_ids(test_message, tokenizer).to(device),
        max_new_tokens=256,
        context_size=1024,
        temperature=1.0,
        top_k=20,
        eos_id=50256,
        device=device
      )

    text = token_ids_to_text(token_ids, tokenizer)
    print("-" * 20)
    print(text)
    print()
    print()




# Continued Fine Tuning - More Facts