In [1]:
import torch
from torch import nn
from torch.nn import functional as F

  cpu = _conversion_method_template(device=torch.device("cpu"))


In [10]:
# Attention-layer hyper-parameters (passed to CasualSelfAttention below).
N_EMBD = 512       # embedding width of each token vector
N_HEAD = 8         # number of attention heads; N_EMBD must be divisible by it
BLOCK_SIZE = 1024  # side length of the lower-triangular mask

# Shape of the random example batch fed through the layer.
BATCH_SIZE, SEQ_LEN = 64, 256

In [11]:

class CasualSelfAttention(nn.Module):
    """Multi-head causal self-attention layer.

    The input is projected to queries, keys and values by a single fused
    linear layer, attended per head with a causal mask, and the re-assembled
    heads are passed through an output projection.

    The class name (sic, "Casual") and the ``biais`` attribute keep their
    original spelling so that existing references keep working.

    Args:
        n_embd: Embedding size of the input and output; must be divisible
            by ``n_head``.
        n_head: Number of attention heads.
        block_size: Side length of the lower-triangular ``biais`` mask.

    Raises:
        AssertionError: If ``n_embd`` is not divisible by ``n_head``.
    """

    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        assert n_embd % n_head == 0, (
            f"n_embd ({n_embd}) must be divisible by n_head ({n_head})"
        )
        self.n_embd = n_embd
        self.n_head = n_head
        self.block_size = block_size

        # Fused projection producing q, k and v in one matmul, then the output projection.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        self.c_proj = nn.Linear(n_embd, n_embd)

        # Lower-triangular mask of shape (1, 1, block_size, block_size).
        # Registered as a buffer so it follows the module across .to()/.cuda()/
        # dtype casts; persistent=False keeps the state_dict keys unchanged.
        # forward() does not read it: the causal masking there is done by
        # scaled_dot_product_attention(is_causal=True).
        self.register_buffer(
            "biais",
            torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size),
            persistent=False,
        )

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape (B, T, C) where C equals ``n_embd``.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        head_size = C // self.n_head

        # Compute q, k, v for all heads at once, then move the head axis next
        # to the batch axis so attention runs independently per head.
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)

        # Fused attention kernel with causal masking.
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Re-assemble the heads side by side: (B, nh, T, hs) -> (B, T, C).
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)

In [12]:
# Smoke test: push one random batch through the layer; the output should
# have the same (batch, sequence, embedding) shape as the input.
# The input is drawn before the model is built so the RNG draw order is fixed.
x_example = torch.randn((BATCH_SIZE, SEQ_LEN, N_EMBD))
model = CasualSelfAttention(n_embd=N_EMBD, n_head=N_HEAD, block_size=BLOCK_SIZE)

output = model(x_example)
print("Output shape:", output.shape)

Output shape: torch.Size([64, 256, 512])


In [13]:
# Build the lower-triangular mask on its own to inspect its shape:
# ones -> keep the lower triangle -> add two leading singleton axes.
biais = torch.ones(BLOCK_SIZE, BLOCK_SIZE).tril().view(1, 1, BLOCK_SIZE, BLOCK_SIZE)
print("Biais shape:", biais.shape)  # expected: (1, 1, BLOCK_SIZE, BLOCK_SIZE)


Biais shape: torch.Size([1, 1, 1024, 1024])


In [None]:
# Small 4x4 all-ones matrix used to illustrate torch.tril in the next cell.
m1 = torch.ones((4, 4))
m1

tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])

In [None]:
# Keep the lower triangle (diagonal included) and zero everything above it.
m2 = m1.tril()
m2

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [None]:
# Hand-made key/query/value tensors laid out as (B, nh, T, hs), drawn in the
# order k, q, v.
# NOTE(review): the last dimension here is 512, whereas CasualSelfAttention
# uses a per-head size of C // n_head (64 with N_EMBD=512, N_HEAD=8). Each
# tensor holds about 67M elements - confirm the large head size is intended.
k, q, v = (torch.randn(64, 8, 256, 512) for _ in range(3))
tuple(t.shape for t in (k, q, v))

(torch.Size([64, 8, 256, 512]),
 torch.Size([64, 8, 256, 512]),
 torch.Size([64, 8, 256, 512]))

In [None]:
# Run causal attention directly on the hand-made tensors.
attention = F.scaled_dot_product_attention(query=q, key=k, value=v, is_causal=True)
attention.size()  # expected: (64, 8, 256, 512)

torch.Size([64, 8, 256, 512])