<a href="https://colab.research.google.com/github/soumyadip1995/BabyGPT/blob/main/Notebook/lora.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import math
import torch
import torch.nn as nn
from torch.nn import functional as F 
from math import sqrt



torch.manual_seed(1337)
class Attention(nn.Module):
  def __init__(self, embedded_dim, num_heads, rank):
    super(Attention, self).__init__()
    self.rank = rank
    self.atten = nn.Linear(embedded_dim, 3 * embedded_dim)
    self.projection = nn.Linear(embedded_dim, embedded_dim)
    self.num_heads = num_heads
    self.embedded_dim = embedded_dim
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
    self.W_A = nn.Parameter(torch.empty(embedded_dim, rank))




  def forward(self, x):
    B,T,C = x.size()
    q, k ,v  = self.atten(x).split(self.embedded_dim, dim=2)
    q = q.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)
    k = k.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)
    v = v.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)
    

    
    # manual implementation of attention
    # from karpathy
    att = (q @ k.transpose(-2, -1)) * ((1.0 / math.sqrt(k.size(-1)) * self.rank))
    att = att.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
    att = F.softmax(att, dim=-1)
    y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
    y = att @ (self.W_A)
    y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

    # output projection
    y = self.projection(y)
    return y

In [2]:
dropout = 0.2
class FeedForward(nn.Module):
  def __init__(self, embedded_dim):
    super(FeedForward, self).__init__()
    self.net = nn.Sequential(nn.Linear(embedded_dim, 4 * embedded_dim),
    nn.Linear(4 * embedded_dim, embedded_dim),
    nn.GELU(),
    nn.Dropout(dropout))

  def forward(self, x):
    return self.net(x)

In [3]:
### A simple Transformer Block    
class Transformer(nn.Module):
  def __init__(self, embedded_dim, num_heads, rank):
    super(Transformer, self).__init__()
    self.attention = Attention(embedded_dim, num_heads, rank)
    self.feed_forward = FeedForward(embedded_dim)
    self.layer_norm_1 = nn.LayerNorm(embedded_dim)
    self.layer_norm_2 = nn.LayerNorm(embedded_dim)

  def forward(self, x):
    
    x = x + self.attention(self.layer_norm_1(x))
    x = x + self.feed_forward(self.layer_norm_2(x))
    return x

In [4]:
class BabyGPTmodel(nn.Module):
  def __init__(self, vocab_size, block_size, num_layers, embedded_dim, num_heads, rank):
    super(BabyGPTmodel, self).__init__()
    self.token = nn.Embedding(vocab_size, embedded_dim)
    self.positional_embeddings = nn.Embedding(block_size, embedded_dim)
    self.layers1 = nn.ModuleList([Transformer(embedded_dim, num_heads, rank) for _ in range(num_layers)])
    self.ln_f = nn.LayerNorm(embedded_dim, eps = 1e-12) # final layer 
    self.ln_head = nn.Linear(embedded_dim, vocab_size)


    # init all weights
    ## from karpathy
    self.apply(self._init_weights)
    # apply special scaled init to the residual projections, per GPT-2 paper
    for pn, p in self.named_parameters():
      if pn.endswith('projection.weight'):
        torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * num_layers))

        # report number of parameters
        print("number of parameters: %d" % (sum(p.nelement() for p in self.parameters()),))

  def _init_weights(self, module):
      if isinstance(module, nn.Linear):
          torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
          if module.bias is not None:
              torch.nn.init.zeros_(module.bias)
      elif isinstance(module, nn.Embedding):
          torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

  def forward(self, idx):
    device = idx.device
    b, t = idx.size()
    tok_emb = self.token(idx)
    position_ids = torch.arange(0, t, dtype = torch.long).unsqueeze(0)
    pos_emb = self.positional_embeddings(position_ids)
    x = tok_emb + pos_emb
    for layers1 in self.layers1:
      x = layers1(x)
    x = self.ln_f(x)
    logits = self.ln_head(x[:, -1, :])
    return logits


In [5]:
vocab_size =  478
block_size = 4
embedded_dim = 16
num_heads = 4
num_layers = 4
rank = 4
gpt = BabyGPTmodel(vocab_size, block_size, num_layers, embedded_dim, num_heads, rank)


number of parameters: 29246
number of parameters: 29246
number of parameters: 29246
number of parameters: 29246


A comparison between BabyGPT and Low rank adaptation. BabyGPT is b/w  28k-29k parametres. Low rank adaptation improves the parametre efficiency. 15k parametres. 

Note: Parametre size is also directly linked to context length.

In [None]:
input_dim = 16

W_A = nn.Parameter(torch.empty(input_dim, rank))
W_A.shape

In [None]:
import torch
import torch.nn as nn

class LowRankAttention(nn.Module):
    def __init__(self, dim, rank):
        super(LowRankAttention, self).__init__()
        self.rank = rank
        self.Wq = nn.Linear(dim, rank, bias=False)
        self.Wk = nn.Linear(dim, rank, bias=False)
        self.Wv = nn.Linear(dim, rank, bias=False)
        self.Wo = nn.Linear(rank, dim, bias=False)

    def forward(self, q, k, v):
        Q = self.Wq(q)
        K = self.Wk(k)
        V = self.Wv(v)

        # Compute the attention scores using low-rank approximation
        A = torch.bmm(Q, K.transpose(-2, -1)) / (self.rank ** 0.5)

        # Softmax along the key dimension
        A = torch.softmax(A, dim=-1)

        # Compute the attention-weighted values using low-rank approximation
        AV = torch.bmm(A, V)

        # Apply the output layer to the attention-weighted values
        out = self.Wo(AV)

        return out

class LowRankTransformerLayer(nn.Module):
    def __init__(self, dim, rank, dropout=0.2):
        super(LowRankTransformerLayer, self).__init__()
        self.attention = LowRankAttention(dim, rank)
        self.norm1 = nn.LayerNorm(dim)
        self.dropout1 = nn.Dropout(dropout)
        self.feedforward = nn.Sequential(
            nn.Linear(dim, dim * 3),
            nn.GELU(),
            nn.Linear(dim * 3, dim)
        )
        self.norm2 = nn.LayerNorm(dim)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        # Compute the self-attention layer
        attention_out = self.attention(x, x, x)

        # Add residual connection and normalize
        x = self.norm1(x + self.dropout1(attention_out))

        # Feed-forward layer
        ff_out = self.feedforward(x)

        # Add residual connection and normalize
        x = self.norm2(x + self.dropout2(ff_out))

        return x

class LowRankTransformer(nn.Module):
    def __init__(self, vocab_size, num_layers, dim, rank, num_heads, dropout= 0.2):
        super(LowRankTransformer, self).__init__()
        self.layers = nn.ModuleList([LowRankTransformerLayer(dim, rank, dropout) for _ in range(num_layers)])
        self.num_layers = num_layers
        self.dim = dim
        self.rank = rank
        self.num_heads = num_heads
        self.pos_embedding = nn.Embedding(vocab_size, dim)
        self.dropout = nn.Dropout(dropout)




        # init all weights
        ## from karpathy
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
          if pn.endswith('Wo.weight'):
            torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * num_layers))

            # report number of parameters
            print("number of parameters: %d" % (sum(p.nelement() for p in self.parameters()),))

    def _init_weights(self, module):
      if isinstance(module, nn.Linear):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
          torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
          torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x):
        # Add positional embeddings
        x = x + self.pos_embedding[:, :x.size(1)]

        # Apply dropout
        x = self.dropout(x)

        # Apply the transformer layers
        for layer in self.layers:
            x = layer(x)

        return x

        


In [None]:
lrt = LowRankTransformer(478, 4, 16, 4, 8)

number of parameters: 15328
number of parameters: 15328
number of parameters: 15328
number of parameters: 15328


In [None]:
words = open(r"/content/text.txt", 'r', encoding='utf-8').read().split()

chars = sorted(list(set(words)))
vocab_size = len(chars)
lrt = LowRankTransformer(vocab_size, 4, 16, 4, 8)

number of parameters: 15328
number of parameters: 15328
number of parameters: 15328
number of parameters: 15328
