In [None]:
import torch
from torch import nn

# Toy batch of token ids: 4 sequences of 8 tokens, ids drawn uniformly from [0, 19).
dummy_data = torch.randint(low=0, high=19, size=(4, 8))

print(dummy_data)

tensor([[ 7,  3,  1,  1, 14, 11, 12,  2],
        [ 3, 15, 17, 17, 16,  2, 17,  1],
        [11, 12, 17, 12, 11,  3,  8, 18],
        [ 7, 16, 16,  0,  5,  7, 11,  6]])


In [None]:
import torch.nn.functional as F

class Attention(nn.Module):
  """Single-head causal (masked) self-attention.

  Every position attends only to itself and to earlier positions in the
  sequence. Attention scores are scaled by ``C ** -0.5``, where ``C`` is the
  embedding dimension.

  Args:
    embedding_dim: size ``C`` of the last input dimension; the query, key and
      value projections all map ``C -> C``.
  """
  def __init__(self, embedding_dim):
    super().__init__()
    self.query = nn.Linear(embedding_dim, embedding_dim)
    self.keys = nn.Linear(embedding_dim, embedding_dim)
    self.values = nn.Linear(embedding_dim, embedding_dim)

  def forward(self, idx):
    """Apply causal self-attention.

    Args:
      idx: tensor of shape [B, T, C] with C equal to ``embedding_dim``.

    Returns:
      Tensor of shape [B, T, C].
    """
    B, T, C = idx.shape
    q = self.query(idx)  # [B, T, C] -> [B, T, C]
    k = self.keys(idx)   # [B, T, C] -> [B, T, C]
    v = self.values(idx) # [B, T, C] -> [B, T, C]

    # q @ k would be [T, C] x [T, C] per batch element, whose inner dims do not
    # match, so k is transposed to [B, C, T]: [B, T, C] x [B, C, T] -> [B, T, T]
    wei = q @ k.transpose(-2, -1) * C**-0.5

    # Lower-triangular causal mask. It is created on the input's device so that
    # masked_fill does not hit a CPU/accelerator device mismatch.
    tril = torch.tril(torch.ones(T, T, device=idx.device))
    wei = wei.masked_fill(tril == 0, float('-inf'))
    # -inf scores become exactly 0 after softmax, so future positions get no weight
    wei = F.softmax(wei, dim=-1)  # [B, T, T]

    # [B, T, T] x [B, T, C] -> [B, T, C]
    out = wei @ v
    return out

class FFN(nn.Module):
  """Position-wise feed-forward network: widen to 4*C, apply ReLU, project back to C."""
  def __init__(self, embedding_dim):
    super().__init__()
    hidden_dim = 4 * embedding_dim
    self.big = nn.Linear(embedding_dim, hidden_dim)
    self.act = nn.ReLU()
    self.back_down = nn.Linear(hidden_dim, embedding_dim)

  def forward(self, idx):
    """Map idx of shape [B, T, C] to a tensor of shape [B, T, C]."""
    # [B, T, C] -> [B, T, 4*C] -> ReLU -> [B, T, C]
    return self.back_down(self.act(self.big(idx)))

class AttentionBlock(nn.Module):
  """One block: an Attention layer followed by an FFN layer.

  No residual connection is applied inside the block; forward returns the FFN
  output directly.
  """
  def __init__(self, embedding_dim):
    super().__init__()
    self.attn = Attention(embedding_dim)
    self.ffn = FFN(embedding_dim)

  def forward(self, idx):
    """Map idx of shape [B, T, C] to a tensor of shape [B, T, C]."""
    # attention first, then the feed-forward network on its output
    return self.ffn(self.attn(idx))

class GilmoreGirlsModel(nn.Module):
    """Small causal language model: token + learned positional embeddings,
    a stack of residual AttentionBlocks, and a linear projection to logits.

    Args:
        vocab_size: number of distinct token ids; also the size of the logits.
        embedding_dim: embedding width C used throughout the model.
        max_seqlen: number of rows in the positional embedding table, i.e. the
            longest sequence length T the model accepts.
        num_blocks: number of AttentionBlocks applied in sequence.
    """
    def __init__(self, vocab_size, embedding_dim, max_seqlen=384, num_blocks=4):
        super().__init__()
        self.embeddings_table = nn.Embedding(vocab_size, embedding_dim)
        self.pos_emb = nn.Embedding(max_seqlen, embedding_dim)
        self.proj = nn.Linear(embedding_dim, vocab_size)
        self.blocks = nn.ModuleList([AttentionBlock(embedding_dim) for _ in range(num_blocks)])
        self.max_seqlen = max_seqlen

    def forward(self, index):
        """Compute next-token logits.

        Args:
            index: integer tensor of token ids with shape [B, T].

        Returns:
            Logits tensor of shape [B, T, vocab_size].

        Raises:
            IndexError: if T exceeds ``max_seqlen`` (there is no positional
                embedding row for those positions).
        """
        _, T = index.shape
        if T > self.max_seqlen:
            # Fail early with a readable message instead of an opaque
            # out-of-range error from inside the embedding lookup.
            raise IndexError(
                f"sequence length {T} exceeds max_seqlen {self.max_seqlen}"
            )

        x = self.embeddings_table(index)  # [B, T, C]

        # One [T, C] lookup is enough: it broadcasts over the batch dimension
        # in the sum below, so the position ids need not be repeated B times.
        pe = self.pos_emb(torch.arange(T, device=index.device))  # [T, C]

        x = x + pe

        # Residual connection around every block.
        for block in self.blocks:
            x = x + block(x)

        logits = self.proj(x)  # [B, T, vocab_size]
        return logits

In [None]:
# Position ids 0..3 as a [1, 4] row, tiled into one identical row per batch element -> [4, 4].
torch.arange(4).unsqueeze(0).repeat(4, 1)

tensor([[0, 1, 2, 3],
        [0, 1, 2, 3],
        [0, 1, 2, 3],
        [0, 1, 2, 3]])

In [None]:
# Demo: positional information is added to the token embeddings element-wise,
# so both tensors must share the same [B, T, C] shape.
# `data` is created here so the cell runs on a fresh kernel instead of relying
# on a variable that is only assigned by a later cell.
data = torch.randn(1, 5, 16) # B, T, C

pos_emb = torch.randn_like(data) # B, T, C

embed = data + pos_emb

In [None]:
# Smoke test: push one random [B, T, C] batch through a single attention layer.
B, T, C = 1, 5, 16

attn = Attention(embedding_dim=C)
data = torch.randn((B, T, C))

attn(data)

tensor([[[ 0.4889,  0.0213,  0.2273, -0.5746,  0.3502,  0.2317, -0.3004,
          -0.4665,  0.0458,  0.0026, -0.1012,  0.1075,  0.0457, -0.3787,
           0.6303, -0.3164],
         [ 0.3537,  0.2263, -0.5231, -0.2912, -0.4993,  0.4828,  0.3862,
          -0.1133,  0.4472,  0.2432, -0.1466,  0.2350, -0.2628,  0.2256,
           0.1077,  1.0802],
         [ 0.3368, -0.0499, -0.3979, -0.0053, -0.1191,  0.4282,  0.2202,
          -0.2333,  0.1812,  0.0992, -0.4765,  0.1799, -0.0421,  0.2814,
           0.1335,  0.7313],
         [ 0.2808, -0.3724, -0.1599,  0.0786,  0.1022,  0.2207,  0.0816,
          -0.4529,  0.1907,  0.0626, -0.3384,  0.2092,  0.1368, -0.0632,
           0.4214,  0.0902],
         [-0.0304, -0.3717, -0.2522,  0.1635, -0.1021,  0.2009,  0.1112,
          -0.6200,  0.2763,  0.3410, -0.2824,  0.1682, -0.0302, -0.0927,
           0.4609,  0.1211]]], grad_fn=<UnsafeViewBackward0>)

In [None]:
# Smoke test: the full model maps token ids [B, T] to per-token outputs [B, T, vocab_size].
vocab_size, embedding_dim = 20, 10

model = GilmoreGirlsModel(vocab_size=vocab_size, embedding_dim=embedding_dim)

dummy_data = torch.randint(low=0, high=vocab_size, size=(4, 8))

# NOTE: despite the name, this holds the model's output logits, not embeddings.
embedded_data = model(dummy_data)

print(embedded_data.shape)


torch.Size([4, 8, 20])


In [None]:
# Shape of the model output: [B, T, vocab_size]
embedded_data.size()

torch.Size([4, 8, 20])

In [None]:
# Output vector for the first token of the first sequence (length vocab_size)
embedded_data[0][0]

tensor([-0.4465,  1.3808, -1.5511,  0.0703,  0.3038,  0.7585,  1.4891, -0.1262,
         1.0671, -0.1898,  0.0757,  1.1880,  0.4566,  0.5843, -0.1075, -0.4119,
         0.2808,  0.2050,  1.3213, -1.0443], grad_fn=<SelectBackward0>)

In [None]:
# Loss calculation: cross-entropy between the model's logits and random targets.
vocab_size, seq_len = 20, 8

# Named `inputs` rather than `input` so the builtin input() is not shadowed.
# Token ids are drawn from the full vocabulary [0, vocab_size).
inputs = torch.randint(0, vocab_size, (4, seq_len))
targets = torch.randint(0, vocab_size, (4, seq_len))

m = GilmoreGirlsModel(vocab_size, 100, max_seqlen=seq_len)

logits = m(inputs) # [B, T, C], C = vocab_size
B, T, C = logits.shape

# F.cross_entropy takes [N, C] logits and [N] class indices, so the batch and
# time dimensions are flattened together.
logits = logits.reshape(B*T, C)
targets = targets.reshape(B*T)

loss = F.cross_entropy(logits, targets)
print(loss)


tensor(3.2074, grad_fn=<NllLossBackward0>)


In [None]:
# Training loop (smoke test on random token ids).
# Inputs and targets are sampled independently, so there is nothing to learn;
# the loss is expected to stay around ln(vocab_size) rather than go to zero.
vocab_size, seq_len = 20, 8
batch_size = 128

m = GilmoreGirlsModel(vocab_size, 100, max_seqlen=seq_len)
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

for steps in range(10000):
    # Fresh random batch each step: [batch_size, seq_len] token ids.
    xb = torch.randint(0, vocab_size, (batch_size, seq_len))
    yb = torch.randint(0, vocab_size, (batch_size, seq_len))

    logits = m(xb) # [B, T, C]
    B, T, C = logits.shape
    logits = logits.reshape(B*T, C) # [B*T, C]
    yb = yb.reshape(B*T) # [B*T]
    loss = F.cross_entropy(logits, yb) # scalar

    # Clear gradients left over from the previous step before accumulating new ones.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if steps % 500 == 0:
        print('step', steps, 'loss', loss.item())

print(loss.item())

NameError: name 'GilmoreGirlsModel' is not defined