In [1]:
import torch
import torch.nn as nn
from dataclasses import dataclass

In [2]:
@dataclass
class ModelArgs:
    device: str = 'cuda'
    epochs: int = 5
    max_lr: float = 2e-5
    rank: int = 4

In [3]:
!pip install datasets
# !pip install evaluate



In [4]:

import os
from transformers import AutoTokenizer, AutoModelForCausalLM

from google.colab import userdata
HF_TOKEN=userdata.get('HF_TOKEN')

model_id = "openai-community/gpt2"


tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", token=HF_TOKEN)

In [5]:
model.config

GPT2Config {
  "_attn_implementation_autoset": true,
  "_name_or_path": "openai-community/gpt2",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.47.1",
  "use_cache": true,
  "vocab_size": 50257
}

In [6]:
#Collab setup
from pathlib import Path
data_path = Path('/content/data')
data_path.mkdir(exist_ok=True)
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
!cp input.txt data/input.txt


--2025-01-24 23:42:04--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.2’


2025-01-24 23:42:05 (21.7 MB/s) - ‘input.txt.2’ saved [1115394/1115394]



In [7]:
#Datasets

# Using tinyshakespeare

with open('/content/data/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

####################################################################

#Using BookCorpus
# from datasets import load_dataset
# data = load_dataset('bookcorpus/bookcorpus')

In [8]:

#Loading custom trained BPE
# Load the tokenizer
# tokenizer = Tokenizer.from_file("bpe_tokenizer_tinyshakespeare_20k.json")
# vocab_size = tokenizer.get_vocab_size()
# Encode and decode functions
# encode = lambda s: tokenizer.encode(s).ids
# decode = lambda l: tokenizer.decode(l)

###############################################################################
#Character level tokenization

# # here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)


# create a mapping from characters to integers
stoi = { ch: i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string


In [9]:
# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]
block_size = 1024
batch_size = 4
# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(ModelArgs.device), y.to(ModelArgs.device)
    return x, y

In [10]:
count = 0

count = sum(p.numel() for p in model.parameters() if p.requires_grad)

for param in model.parameters():
    param.requires_grad = False



In [11]:
class LoRALayer(nn.Module):
    def __init__(self) -> None:
        super().__init__()


        self.rank = ModelArgs.rank
        self.model_weight_dims = model.config.n_embd
        self.query_A = nn.Parameter(torch.ones((self.model_weight_dims, self.rank), requires_grad=True))
        self.query_B = nn.Parameter(torch.zeros((self.rank, self.model_weight_dims), requires_grad=True))
        self.key_A = nn.Parameter(torch.ones((self.model_weight_dims, self.rank), requires_grad=True))
        self.key_B = nn.Parameter(torch.zeros((self.rank, self.model_weight_dims), requires_grad=True))
        self.value_A = nn.Parameter(torch.ones((self.model_weight_dims, self.rank), requires_grad=True))
        self.value_B = nn.Parameter(torch.zeros((self.rank, self.model_weight_dims), requires_grad=True))
        self.output_A = nn.Parameter(torch.ones((self.model_weight_dims, self.rank), requires_grad=True))
        self.output_B = nn.Parameter(torch.zeros((self.rank, self.model_weight_dims), requires_grad=True))
        # self.linear_q = nn.Linear(in_features=model.config.n_ctx, out_features=self.model_weight_dims, bias=False)
        # self.linear_k = nn.Linear(in_features=model.config.n_ctx, out_features=self.model_weight_dims, bias=False)
        # self.linear_v = nn.Linear(in_features=model.config.n_ctx, out_features=self.model_weight_dims, bias=False)
        # self.linear_o = nn.Linear(in_features=model.config.n_ctx, out_features=self.model_weight_dims, bias=False)
        torch.nn.init.normal_(self.query_A, mean=0.0, std=1)
        torch.nn.init.normal_(self.key_A, mean=0.0, std=1)
        torch.nn.init.normal_(self.output_A, mean=0.0, std=1)
        torch.nn.init.normal_(self.value_A, mean=0.0, std=1)


    def forward(self, w_o, q_o, k_o, v_o):
        # print((self.output_B).shape)
        final_weight_WO = w_o + self.output_B.T @ self.output_A.T
        final_weight_QO = q_o + self.query_B.T @ self.query_A.T
        final_weight_KO = k_o + self.key_B.T @ self.key_A.T
        final_weight_VO = v_o + self.value_B.T @ self.value_A.T
        # out_q = self.linear_q(final_weight_QO)
        # out_k = self.linear_k(final_weight_KO)
        # out_v = self.linear_v(final_weight_VO)

        # out_o
        #
        #
        # = self.linear_o(final_weight_WO)

        return final_weight_WO, final_weight_QO , final_weight_KO, final_weight_VO

In [None]:
class LoRAWrapper(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lora_layer = LoRALayer()
        # self.linear = nn.Linear(in_features=model.config.vocab_size, out_features=2)
        self.config = model.config

    def forward(self, x):
        qkv_layers = [model.transformer.h[i].attn.c_attn for i in range(self.config.n_layer)]
        o_layers = [model.transformer.h[i].attn.c_proj for i in range(self.config.n_layer)]

        for i in range(len(qkv_layers)):
            hidden_size = qkv_layers[i].weight.size(-1) // 3
            Q, K, V = torch.split(qkv_layers[i].weight, hidden_size, dim=-1)
            O = o_layers[i].weight
            out_o, out_q, out_k, out_v = self.lora_layer(O,Q,K,V)
            combined_qkv = torch.concat([out_q, out_k, out_v], dim=-1)
            # print(combined_qkv.shape)
            # Update the model's attention weights
            # # with torch.no_grad():
            # qkv_layers[i].weight.copy_(combined_qkv)
            # o_layers[i].weight.copy_(out_o)
            # Assign the updated weights back to the model
            qkv_layers[i].weight.data.copy_(combined_qkv)
            o_layers[i].weight.data.copy_(out_o)
        return model(x)


In [13]:
class LoRAModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lora_wrapper = LoRAWrapper()

        self.config = model.config

    def forward(self,x):

        out = self.lora_wrapper(x)

        return out


In [14]:
lora_model = LoRAModel()
lora_model.to(ModelArgs.device)

LoRAModel(
  (lora_wrapper): LoRAWrapper(
    (lora_layer): LoRALayer()
  )
)

In [16]:
# final_model = outputLayer()
# final_model.to(ModelArgs.device)

In [17]:
#Printing a summary of the architecture
!pip install torchinfo
from torchinfo import summary
input_ids = torch.randint(
    0, model.config.vocab_size,
    (1, model.config.n_ctx)
).to(ModelArgs.device)
# idx = idx.to(device)
summary(model=lora_model,
        input_data=input_ids,
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])



Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
LoRAModel (LoRAModel)                    [1, 1024]            [1, 12, 1024, 64]    --                   True
├─LoRAWrapper (lora_wrapper)             [1, 1024]            [1, 12, 1024, 64]    --                   True
│    └─LoRALayer (lora_layer)            [768, 768]           [768, 768]           24,576               True
│    └─LoRALayer (lora_layer)            [768, 768]           [768, 768]           (recursive)          True
│    └─LoRALayer (lora_layer)            [768, 768]           [768, 768]           (recursive)          True
│    └─LoRALayer (lora_layer)            [768, 768]           [768, 768]           (recursive)          True
│    └─LoRALayer (lora_layer)            [768, 768]           [768, 768]           (recursive)          True
│    └─LoRALayer (lora_layer)            [768, 768]           [768, 768]           (recursive)          True
│    └─LoRALay

In [18]:
#Printing a summary of the architecture
from torchinfo import summary
input_ids = torch.randint(
    0, model.config.vocab_size,
    (1, model.config.n_ctx)
).to(ModelArgs.device)
# idx = idx.to(device)
summary(model=model,
        input_data=input_ids,
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

Layer (type (var_name))                            Input Shape          Output Shape         Param #              Trainable
GPT2LMHeadModel (GPT2LMHeadModel)                  [1, 1024]            [1, 12, 1024, 64]    --                   False
├─GPT2Model (transformer)                          [1, 1024]            [1, 12, 1024, 64]    --                   False
│    └─Embedding (wte)                             [1, 1024]            [1, 1024, 768]       (38,597,376)         False
│    └─Embedding (wpe)                             [1, 1024]            [1, 1024, 768]       (786,432)            False
│    └─Dropout (drop)                              [1, 1024, 768]       [1, 1024, 768]       --                   --
│    └─ModuleList (h)                              --                   --                   --                   False
│    │    └─GPT2Block (0)                          [1, 1024, 768]       [1, 1024, 768]       (7,087,872)          False
│    │    └─GPT2Block (1)              

In [19]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Total trainable parameters:", count_parameters(lora_model) ," which is: " , (count_parameters(lora_model) / 163037184 )*100 , "%\ of" , 163037184 , "trainable params")

Total trainable parameters: 24576  which is:  0.015073861923424782 %\ of 163037184 trainable params


In [20]:
# Optimizer setup and scheduler steup

optimizer = torch.optim.AdamW(lora_model.parameters(), lr=ModelArgs.max_lr)
# optimizer = torch.optim.Adam(model.parameters(), lr=max_lr, weight_decay=weight_decay_optim)
initial_iters = 2000
total_steps = 10000
eval_iters = 100

@torch.inference_mode()
def estimate_loss():
    out = {}
    lora_model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            idx, targets = get_batch(split=split)
            logits = lora_model(idx).logits
            batch_size, block_size, embeddings_dims = logits.shape
            logits = logits.view(batch_size*block_size, embeddings_dims) # Total tokens(words) => batch_size * block_size
            targets = targets.view(batch_size * block_size)
            loss = nn.functional.cross_entropy(logits, targets)
            losses[k] = loss.item()
        out[split] = losses.mean()
    lora_model.train()
    return out

In [None]:
#Train the  model
from tqdm import tqdm

lora_model.train()
for step in tqdm(range(total_steps)):

    # every once in a while evaluate the loss on train and val sets
    if (step  % eval_iters == 0 and step != 0) or step == total_steps - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")


    idx, targets = get_batch(split='train')
    logits = lora_model(idx).logits
    batch_size, block_size, embeddings_dims = logits.shape
    logits = logits.view(batch_size*block_size, embeddings_dims)
    targets = targets.view(batch_size * block_size)
    loss = nn.functional.cross_entropy(logits, targets)
    # print(loss.requires_grad)
    loss.requires_grad = True
    # print(count_parameters(lora_model))
    # print(loss.requires_grad)
    # break
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    # print(loss.item())
    # break

    # if step != 0 and (step % eval_iters == 0 or step == total_steps -1) :
    #     loss_values = estimate_loss()
    #     print("Train Loss at {} steps : {}".format(step, loss.item()), "Val Loss at {} steps : {}".format(step, loss_values['val']))

  1%|          | 101/10000 [01:49<60:29:47, 22.00s/it]

step 100: train loss 3.5244, val loss 3.4753


  1%|          | 111/10000 [01:52<2:46:03,  1.01s/it]

In [None]:
for name, param in lora_model.named_parameters():
    print(f"{name}: requires_grad={param.requires_grad}")
