In [1]:
import torch
import torch.nn as nn

  cpu = _conversion_method_template(device=torch.device("cpu"))


In [2]:
# --- Model hyperparameters -------------------------------------------------
N_EMBD = 512       # embedding width used by the LayerNorm, attention and MLP layers
N_HEAD = 8         # number of attention heads; the attention layer asserts it divides N_EMBD
BLOCK_SIZE = 1024  # side length of the causal mask built by the attention layer

# --- Batch dimensions ------------------------------------------------------
# Not referenced by the model classes in this part of the notebook;
# presumably consumed by data-loading / training cells (confirm).
SEQ_LEN = 256
BATCH_SIZE = 64

In [5]:
class Block(nn.Module):
    """Pre-norm transformer block: causal self-attention followed by an MLP.

    Each sub-layer reads a LayerNorm-ed copy of its input and its output is
    added back onto the un-normalized input (residual connection).

    Args:
        n_embd: Size of the last input dimension. ``None`` (default) means
            "use the notebook-level ``N_EMBD``".
        n_head: Number of attention heads. ``None`` (default) means
            "use the notebook-level ``N_HEAD``".
        block_size: Size handed to the attention layer for its causal mask.
            ``None`` (default) means "use the notebook-level ``BLOCK_SIZE``".

    ``Block()`` with no arguments builds the same layers as before. The
    globals are looked up when the block is constructed, not when the class
    is defined, so re-running the configuration cell is enough to change the
    default sizes.
    """

    def __init__(self, n_embd=None, n_head=None, block_size=None):
        super().__init__()
        # Resolve defaults at construction time so the class can be defined
        # before (or independently of) the configuration cell.
        if n_embd is None:
            n_embd = N_EMBD
        if n_head is None:
            n_head = N_HEAD
        if block_size is None:
            block_size = BLOCK_SIZE

        self.ln_1 = nn.LayerNorm(n_embd)
        self.attn = CasualSelfAttention(n_embd, n_head, block_size)
        self.ln_2 = nn.LayerNorm(n_embd)
        # Pass the width explicitly: MLP's own default is bound when MLP is
        # defined and could otherwise disagree with the width used above.
        self.mlp = MLP(n_embd)

    def forward(self, x):
        """Apply the attention and MLP residual branches to ``x``.

        Args:
            x: Input tensor; the attention layer unpacks it as
                ``(batch, sequence, n_embd)``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

In [3]:
class CasualSelfAttention(nn.Module):
    """Multi-head causal self-attention layer.

    The class name keeps its original spelling so existing references to it
    continue to work.

    Args:
        n_embd: Size of the last input dimension (and of the output).
        n_head: Number of attention heads; must divide ``n_embd`` evenly.
        block_size: Side length of the lower-triangular mask stored in
            ``biais``.

    Attributes:
        c_attn: Linear layer producing the concatenated query/key/value
            projections (``3 * n_embd`` outputs).
        c_proj: Linear output projection (``n_embd`` -> ``n_embd``).
        biais: Lower-triangular matrix of ones with shape
            ``(1, 1, block_size, block_size)``. ``forward`` does not read it
            (masking is done through ``is_causal=True``). It is registered as
            a non-persistent buffer so it follows ``.to()`` / ``.cuda()`` /
            dtype casts while staying out of ``state_dict``.

    Raises:
        AssertionError: If ``n_embd`` is not divisible by ``n_head``.
    """

    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        assert n_embd % n_head == 0, (
            f"n_embd ({n_embd}) must be divisible by n_head ({n_head})"
        )
        self.n_embd = n_embd
        self.n_head = n_head
        self.block_size = block_size

        # Fused q/k/v projection first, output projection second (creation
        # order is kept so seeded initialisation is unchanged).
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        self.c_proj = nn.Linear(n_embd, n_embd)

        causal_mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer(
            "biais",
            causal_mask.view(1, 1, block_size, block_size),
            persistent=False,
        )

    def forward(self, x):
        """Run causal multi-head self-attention over ``x``.

        Args:
            x: Tensor of shape ``(B, T, C)`` with ``C == n_embd``.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        head_size = C // self.n_head

        # Compute query, key, value for all heads at once, then move the head
        # axis next to the batch axis: (B, T, C) -> (B, nh, T, hs).
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)

        # Referenced through ``nn.functional`` because the notebook imports
        # ``torch.nn as nn`` but never imports an ``F`` alias.
        y = nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Re-assemble the heads side by side: (B, nh, T, hs) -> (B, T, C).
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)

In [4]:
class MLP(nn.Module):
    """Two-layer feed-forward network: widen 4x, apply GELU, project back.

    Args:
        n_embd: Size of the last input dimension (and of the output).
            Defaults to the notebook-level ``N_EMBD``.

    Attributes:
        n_embed: The ``n_embd`` value the layer was built with.
        c_fc: Linear expansion, ``n_embd`` -> ``4 * n_embd``.
        gelu: GELU activation using the tanh approximation.
        c_proj: Linear projection, ``4 * n_embd`` -> ``n_embd``.
    """

    def __init__(self, n_embd=N_EMBD):
        super().__init__()
        width = n_embd
        hidden = width * 4
        self.n_embed = width
        # Sub-modules are registered in the order forward() applies them.
        self.c_fc = nn.Linear(width, hidden)
        self.gelu = nn.GELU(approximate="tanh")
        self.c_proj = nn.Linear(hidden, width)

    def forward(self, x):
        """Return ``c_proj(gelu(c_fc(x)))``; the last dimension keeps its size."""
        return self.c_proj(self.gelu(self.c_fc(x)))