In [2]:
import torch
import torch.nn.functional as F
from torch import nn

In [93]:
# Config
BATCH = 4
VOCAB_SIZE = 50257
N_EMBED = 768 
CONTEXT_SIZE = 1024
N_LAYERS = 12
HEAD_SIZE = 64
N_HEADS = int(N_EMBED // HEAD_SIZE)

#-------------------------

class FFWD(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(N_EMBED, N_EMBED*4)
        self.l2 = nn.Linear(N_EMBED*4, N_EMBED)
        self.act = nn.ReLU()
    
    def forward(self, x):
        return self.l2(self.act(self.l1(x)))


class MHA(nn.Module):
    def __init__(self):
        super().__init__()
        self.kqv = nn.Linear(N_EMBED, HEAD_SIZE*3*N_HEADS)
        self.tril = torch.tril(torch.ones((CONTEXT_SIZE, CONTEXT_SIZE)))
    
    def forward(self, x):
        B, T, C = x.shape

        k, q, v = torch.split(self.kqv(x), HEAD_SIZE*N_HEADS, -1)
        
        k = k.view(B, T, N_HEADS, HEAD_SIZE).transpose(1, 2)
        q = q.view(B, T, N_HEADS, HEAD_SIZE).transpose(1, 2)
        v = v.view(B, T, N_HEADS, HEAD_SIZE).transpose(1, 2)

        print(k.shape, q.shape, v.shape)

        wei = k @ q.transpose(-1, -2) * HEAD_SIZE**-0.5
        wei = torch.masked_fill(wei, self.tril == 0, float('-inf'))
        wei = F.softmax(wei, dim = -1)
    
        out = wei @ v
        out = out.view(BATCH, CONTEXT_SIZE, N_EMBED)
        
        return out

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln1 = nn.LayerNorm(N_EMBED)
        self.mh_a = MHA()
        self.ln2 = nn.LayerNorm(N_EMBED)
        self.ffwd = FFWD()
    
    def forward(self, x):
        x = self.ln1(x)
        x = self.mh_a(x)
        x = self.ln2(x)
        x = self.ffwd(x)
        
        return x



class gpt(nn.Module):
    def __init__(self):
        super().__init__()
        pass

n = Block()


torch.Size([4, 12, 1024, 64]) torch.Size([4, 12, 1024, 64]) torch.Size([4, 12, 1024, 64])


torch.Size([4, 1024, 768])

In [89]:
x = torch.randn((BATCH, CONTEXT_SIZE, N_EMBED))

print(x.shape)


torch.Size([4, 1024, 768])


tensor([-0.8952, -0.4869, -2.2513,  ..., -1.6232, -0.6060, -1.5344])

In [63]:
# Add this to a python notebook showing the difference between the two methods

import timeit

# Define the first operation
def operation_1():
    return q.view(BATCH, CONTEXT_SIZE, N_HEADS, HEAD_SIZE).transpose(1, 2)

# Define the second operation
def operation_2():
    return q.reshape(BATCH, N_HEADS, CONTEXT_SIZE, HEAD_SIZE)

# Time the first operation
time_1 = timeit.timeit(operation_1, number=10000)

# Time the second operation
time_2 = timeit.timeit(operation_2, number=10000)

print(f"Time for operation 1: {time_1:.6f} seconds")
print(f"Time for operation 2: {time_2:.6f} seconds")

Time for operation 1: 0.049138 seconds
Time for operation 2: 24.167030 seconds
