In [1]:
from SimpleGPT import SimpleGPT
from llm_utils import text_generation, train
from InstructionDataset import InstructionDataset, collate_fn, to_alpaca_format
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from functools import partial
import tiktoken
import json
import tensorflow as tf
import numpy as np
import os
import json

In [2]:
MODEL_DIR = "./117M"
USE_MLA = True
GPT2_CONFIG_124M = {
     "vocab_size": 50257,
     "context_length": 256,
     "emb_dim": 768,
     "n_heads": 12,
     "n_layers": 12,
     "dropout_rate": 0.1,
     'batch_size': 12,
     'd_R': 16, # For Multihead Latent Attention
     'd_c': 256, # For Multihead Latent Attention
     "qkv_bias": False 
}
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = SimpleGPT(GPT2_CONFIG_124M, use_mla=USE_MLA)

In [3]:
print(f'Using {DEVICE}')

Using cuda


In [4]:
# Since the original model is trained with TensorFlow, we read it using TensorFlow

tf_checkpoint_path = tf.train.latest_checkpoint(MODEL_DIR)
tf_model_settings = json.load(open(os.path.join(MODEL_DIR, "hparams.json")))

params = {"blocks": [{} for _ in range(tf_model_settings['n_layer'])]}

for name, _ in tf.train.list_variables(tf_checkpoint_path):
    var = np.squeeze(tf.train.load_variable(tf_checkpoint_path, name))
    var_name_no_prefix = name.split("/")[1:]
    
    target_dict = params
    
    if var_name_no_prefix[0].startswith("h"):
        layer_id = int(var_name_no_prefix[0][1:])
        target_dict = params["blocks"][layer_id]
    
    for key in var_name_no_prefix[1:-1]:
        target_dict = target_dict.setdefault(key, {})
    
    last_key = var_name_no_prefix[-1]
    target_dict[last_key] = var

In [5]:
model.pos_emb.weight = nn.Parameter(torch.tensor(params['wpe']))
model.token_emb.weight = nn.Parameter(torch.tensor(params['wte']))

# Load pretrained model parameters
for i in range(len(params['blocks'])):
    q_w, k_w, v_w = np.split((params['blocks'][i]['attn']['c_attn'])['w'], 3, axis=-1)
    q_b, k_b, v_b = np.split((params['blocks'][i]['attn']['c_attn'])['b'], 3, axis=-1)
    
    if not USE_MLA:
        model.transformer_blocks[i].attn.W_q.weight = nn.Parameter(torch.tensor(q_w.T))
        model.transformer_blocks[i].attn.W_k.weight = nn.Parameter(torch.tensor(k_w.T))
        model.transformer_blocks[i].attn.W_v.weight = nn.Parameter(torch.tensor(v_w.T))
        
        model.transformer_blocks[i].attn.W_q.bias = nn.Parameter(torch.tensor(q_b))
        model.transformer_blocks[i].attn.W_k.bias = nn.Parameter(torch.tensor(k_b))
        model.transformer_blocks[i].attn.W_v.bias = nn.Parameter(torch.tensor(v_b))
    
    model.transformer_blocks[i].attn.output_projection.weight = nn.Parameter(torch.tensor(params['blocks'][i]['attn']['c_proj']['w'].T))
    model.transformer_blocks[i].attn.output_projection.bias = nn.Parameter(torch.tensor(params['blocks'][i]['attn']['c_proj']['b']))
    
    model.transformer_blocks[i].ff_block[1].weight = nn.Parameter(torch.tensor(params['blocks'][i]['mlp']['c_fc']['w'].T))
    model.transformer_blocks[i].ff_block[1].bias = nn.Parameter(torch.tensor(params['blocks'][i]['mlp']['c_fc']['b']))
    model.transformer_blocks[i].ff_block[3].weight = nn.Parameter(torch.tensor(params['blocks'][i]['mlp']['c_proj']['w'].T))
    model.transformer_blocks[i].ff_block[3].bias = nn.Parameter(torch.tensor(params['blocks'][i]['mlp']['c_proj']['b']))
    
    model.transformer_blocks[i].attn_block[0].scale = nn.Parameter(torch.tensor(params['blocks'][i]['ln_1']['g']))
    model.transformer_blocks[i].attn_block[0].shift = nn.Parameter(torch.tensor(params['blocks'][i]['ln_1']['b']))
    model.transformer_blocks[i].ff_block[0].scale = nn.Parameter(torch.tensor(params['blocks'][i]['ln_2']['g']))
    model.transformer_blocks[i].ff_block[0].shift = nn.Parameter(torch.tensor(params['blocks'][i]['ln_2']['b']))

model.final_norm.scale = nn.Parameter(torch.tensor(params['g']))
model.final_norm.shift = nn.Parameter(torch.tensor(params['b']))
model.out.weight = nn.Parameter(torch.tensor(params['wte']))

In [6]:
print(model)

SimpleGPT(
  (token_emb): Embedding(50257, 768)
  (pos_emb): Embedding(256, 768)
  (dropout): Dropout(p=0.1, inplace=False)
  (transformer_blocks): Sequential(
    (0): Transformer(
      (attn): MultiheadLatentAttention(
        (W_DKV): Linear(in_features=768, out_features=256, bias=False)
        (W_DQ): Linear(in_features=768, out_features=256, bias=False)
        (W_UK): Linear(in_features=256, out_features=768, bias=False)
        (W_UV): Linear(in_features=256, out_features=768, bias=False)
        (W_UQ): Linear(in_features=256, out_features=768, bias=False)
        (W_KR): Linear(in_features=768, out_features=16, bias=False)
        (W_QR): Linear(in_features=256, out_features=192, bias=False)
        (output_projection): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (attn_block): Sequential(
        (0): LayerNorm()
        (1): MultiheadLatentAttention(
          (W_DKV): Linear(in_features=768, out_featur

In [7]:
model.to(DEVICE)
tokenizer = tiktoken.get_encoding("gpt2")

model.eval()

with torch.no_grad():
    response = text_generation(
        model=model,
        query="Who is the first president of the United States?",
        tokenizer=tokenizer,
        max_generated_tokens=30,
        context_size=GPT2_CONFIG_124M['context_length'],
        temperature=1.4,
        top_k=20,
        device=DEVICE
    )

print("Response:", response)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (120x64 and 768x16)

## Fine tuning

In [8]:
alpaca_data = json.load(open("alpaca_data.json", "r"))
print(alpaca_data[:5])

[{'instruction': 'Give three tips for staying healthy.', 'input': '', 'output': '1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule.'}, {'instruction': 'What are the three primary colors?', 'input': '', 'output': 'The three primary colors are red, blue, and yellow.'}, {'instruction': 'Describe the structure of an atom.', 'input': '', 'output': 'An atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom.'}, {'instruction': 'How can we reduce air pollution?', 'input': '', 'output': 'There are a number of ways to reduce air pollution, such as shifting to ren

In [9]:
train_size = int(len(alpaca_data) * 0.8)
test_size = int(len(alpaca_data) * 0.1)
val_size = len(alpaca_data) - train_size - test_size

dataset = InstructionDataset(alpaca_data, tokenizer)
train_dataset, test_dataset, val_dataset = random_split(dataset, [train_size, test_size, val_size])

# Pre-fill the collate function with the context length
collate_fn = partial(collate_fn, max_context_length=GPT2_CONFIG_124M['context_length'])

train_loader = DataLoader(train_dataset, 
                          batch_size=GPT2_CONFIG_124M['batch_size'],
                          collate_fn=collate_fn,
                          shuffle=True,
                          num_workers=8,
                          pin_memory=True)

test_loader = DataLoader(test_dataset,
                         batch_size=GPT2_CONFIG_124M['batch_size'],
                         collate_fn=collate_fn,
                         shuffle=False,
                         num_workers=8,
                         pin_memory=True)

val_loader = DataLoader(val_dataset,
                        batch_size=GPT2_CONFIG_124M['batch_size'],
                        collate_fn=collate_fn,
                        shuffle=False,
                        num_workers=8,
                        pin_memory=True)

In [10]:
print(next(iter(train_loader)))

[tensor([[21106,   318,   281,  ..., 50256, 50256, 50256],
        [21106,   318,   281,  ..., 50256, 50256, 50256],
        [21106,   318,   281,  ..., 50256, 50256, 50256],
        ...,
        [21106,   318,   281,  ..., 26411,     0,   628],
        [21106,   318,   281,  ..., 50256, 50256, 50256],
        [21106,   318,   281,  ..., 50256, 50256, 50256]]), tensor([[  318,   281, 12064,  ...,  -100,  -100,  -100],
        [  318,   281, 12064,  ...,  -100,  -100,  -100],
        [  318,   281, 12064,  ...,  -100,  -100,  -100],
        ...,
        [  318,   281, 12064,  ...,     0,   628, 50256],
        [  318,   281, 12064,  ...,  -100,  -100,  -100],
        [  318,   281, 12064,  ...,  -100,  -100,  -100]])]


In [11]:
print(len(min(val_loader.dataset, key=lambda x: len(x))), len(max(val_loader.dataset, key=lambda x: len(x))))

37 669


In [12]:
test_text = to_alpaca_format(alpaca_data[0])

res = text_generation(model=model,
                     query=test_text,
                     tokenizer=tokenizer,
                     max_generated_tokens=50,
                     context_size=GPT2_CONFIG_124M['context_length'],
                     temperature=1.4,
                     top_k=20,
                     device=DEVICE)

print(test_text)
print(res[0][len(test_text):])

Below is an instruction that describes a task
Write a response that appropriately completes the request.

### Instruction:
Give three tips for staying healthy.



Step 2

Include as many items as possible:

A good body temperature

A well fed diet

Caffeine supplementation on active eating

Dieting a healthy diet regularly

Keeping healthy in the gym


In [13]:
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00005, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

In [None]:
train_losses, val_losses, tokens_seen = train(model=model,
      train_loader=train_loader,
      val_loader=test_loader,
      optimizer=optimizer,
      criterion=criterion,
      tokenizer=tokenizer,
      device=DEVICE,
      n_epochs=1,
      eval_freq=10,
      early_stop=1)

Epoch: 0 - 2, train_loss: 2.988710567965413, val_loss: 2.9911671506441557
Epoch: 0 - 4, train_loss: 2.626174870029685, val_loss: 2.6297641192949737
Epoch: 0 - 6, train_loss: 2.3467894093253663, val_loss: 2.3518476427518404
Epoch: 0 - 8, train_loss: 2.1436020390296573, val_loss: 2.149673807804401
Epoch: 0 - 10, train_loss: 1.9918972559099908, val_loss: 1.9991346045640799
Epoch: 0 - 12, train_loss: 1.9281936444312053, val_loss: 1.9356513654268706
Epoch: 0 - 14, train_loss: 1.8986405657245846, val_loss: 1.9057244638296273
Epoch: 0 - 16, train_loss: 1.8723402119122934, val_loss: 1.8795431436025178
Epoch: 0 - 18, train_loss: 1.8429103487387366, val_loss: 1.8498335020358745
Epoch: 0 - 20, train_loss: 1.8239694289853081, val_loss: 1.8318410557966966
Epoch: 0 - 22, train_loss: 1.8187418733278482, val_loss: 1.82621607707097
Epoch: 0 - 24, train_loss: 1.8062148175086454, val_loss: 1.8138213025606595
Epoch: 0 - 26, train_loss: 1.7988520615423123, val_loss: 1.8072266184366665
Epoch: 0 - 28, train_

KeyboardInterrupt: 

In [15]:
torch.save(model.state_dict(), "simple_gpt_117M.pth")
model.load_state_dict(torch.load("simple_gpt_117M.pth"))

<All keys matched successfully>

In [21]:
test_text = to_alpaca_format(alpaca_data[0])

res = text_generation(model=model,
                     query=test_text,
                     tokenizer=tokenizer,
                     max_generated_tokens=50,
                     context_size=GPT2_CONFIG_124M['context_length'],
                     temperature=1.4,
                     top_k=20,
                     device=DEVICE)

print(test_text)
print(res[0][len(test_text):])

Below is an instruction that describes a task
Write a response that appropriately completes the request.

### Instruction:
Give three tips for staying healthy.


<|endoftext|>It was just before Thanksgiving weekend, and I thought I'd start writing a story I thought I'd never have time for. I've been feeling pretty good lately and want to continue doing well. So, after an hourlong break, what would
