In [1]:
import torch
import torch.nn as nn

print(torch.__version__)

torch.manual_seed(123)

2.4.0


<torch._C.Generator at 0x7bff2110e390>

### Basic torch functions

In [2]:
x = torch.tensor(
    [[0.11, 0.12, 0.13, 0.14],
     [0.21, 0.22, 0.23, 0.24],
     [0.31, 0.32, 0.33, 0.34]]
)
print(x.shape)

torch.Size([3, 4])


In [3]:
torch.empty(3,4), torch.zeros(x.shape), torch.ones(x.shape), torch.rand(x.shape), torch.randn(3, 4)

(tensor([[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]),
 tensor([[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]),
 tensor([[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]),
 tensor([[0.2961, 0.5166, 0.2517, 0.6886],
         [0.0740, 0.8665, 0.1366, 0.1025],
         [0.1841, 0.7264, 0.3153, 0.6871]]),
 tensor([[-0.9724, -0.7550,  0.3239, -0.1085],
         [ 0.2103, -0.3908,  0.2350,  0.6653],
         [ 0.3528,  0.9728, -0.0386, -0.8861]]))

In [4]:
torch.tril(torch.ones(3,3), diagonal=1), torch.triu(torch.ones(3,3), diagonal=1),

(tensor([[1., 1., 0.],
         [1., 1., 1.],
         [1., 1., 1.]]),
 tensor([[0., 1., 1.],
         [0., 0., 1.],
         [0., 0., 0.]]))

In [5]:
# How to normalize x along each row
x = torch.randn(3, 4)
print(x)
x / x.sum(dim=1, keepdim=True) # or.view(x.shape[0], 1)


tensor([[-0.4709, -0.4269, -0.0283,  1.4220],
        [-0.3886, -0.8903, -0.9601, -0.4087],
        [ 1.0764, -0.4015, -0.7291, -0.1218]])


tensor([[-0.9496, -0.8609, -0.0571,  2.8676],
        [ 0.1468,  0.3363,  0.3626,  0.1544],
        [-6.1105,  2.2795,  4.1392,  0.6917]])

In [6]:
# Dot product

x = torch.randn(3, 4)

res = 0
for ix in range(len(x[0])):
    res += x[0][ix] * x[1][ix]
print(res)

torch.dot(x[0], x[1])

tensor(-0.7915)


tensor(-0.7915)

In [7]:
# Matrix dot product

x = torch.randn(3, 4)

att = torch.empty(len(x), len(x))
for ix1 in range(len(x)):
    for ix2 in range(len(x)):
        att[ix1, ix2] = torch.dot(x[ix1], x[ix2])
print(att)

x @ x.T



tensor([[ 7.0545, -1.7496,  1.3280],
        [-1.7496,  1.5781,  0.5650],
        [ 1.3280,  0.5650,  1.5604]])


tensor([[ 7.0545, -1.7496,  1.3280],
        [-1.7496,  1.5781,  0.5650],
        [ 1.3280,  0.5650,  1.5604]])

In [8]:
# Softmax 

def softmax_custom(x):
    return torch.exp(x) / torch.exp(x).sum(dim=1, keepdim=True) # or .view(x.shape[0], 1)
print(softmax_custom(att))

x = torch.randn(3, 4)
att = x @ x.T

att_w = torch.softmax(att, dim=-1)

tensor([[9.9660e-01, 1.4961e-04, 3.2474e-03],
        [2.5644e-02, 7.1482e-01, 2.5953e-01],
        [3.6658e-01, 1.7092e-01, 4.6250e-01]])


In [9]:
# Contextual vectors: Multiply att_w with x

x = torch.randn(3, 4)
att = x @ x.T
att_w = torch.softmax(att, dim=-1)
print(att_w.shape, x.shape)

ctx = torch.zeros(*x.shape)
print(ctx.shape)
for ix1 in range(x.shape[0]):
    for ix2 in range(att_w.shape[1]):
        ctx[ix1] += att_w[ix1, ix2] * x[ix2]
print(ctx)

ctx = att_w @ x
ctx

torch.Size([3, 3]) torch.Size([3, 4])
torch.Size([3, 4])
tensor([[ 0.4221,  0.0596, -0.3723, -0.7592],
        [ 0.5339,  0.0740, -0.6985, -0.6406],
        [ 0.9379,  0.1913, -0.3909,  1.0240]])


tensor([[ 0.4221,  0.0596, -0.3723, -0.7592],
        [ 0.5339,  0.0740, -0.6985, -0.6406],
        [ 0.9379,  0.1913, -0.3909,  1.0240]])

### Self attention - Basic
- No trainable parameters

In [10]:
# In short

x = torch.randn(3, 4)

att = x @ x.T
att = torch.softmax(att, dim=-1)
ctx = att @ x
print(x.shape, att.shape, ctx.shape)
print(ctx)

torch.Size([3, 4]) torch.Size([3, 3]) torch.Size([3, 4])
tensor([[-0.9471,  0.4830, -0.2339, -0.9257],
        [-0.9380,  0.8607, -0.7171, -1.3901],
        [ 0.6743, -0.2244,  0.5628,  0.9100]])


### Self attention - Option 1
- With trainable weight parameters
- Using torch.nn.Parameter

In [11]:
class SelfAttn1(nn.Module):
    def __init__(self, din, d):
        super().__init__()
        self.wq = nn.Parameter(torch.rand(din, d), requires_grad=False)
        self.wk = nn.Parameter(torch.rand(din, d), requires_grad=False) 
        self.wv = nn.Parameter(torch.rand(din, d), requires_grad=False)  

    def forward(self, x):
        q, k, v = x @ self.wq, x @ self.wk, x @ self.wv
        att = q @ k.T
        att = torch.softmax(att / k.shape[-1]**0.5, dim=-1)
        ctx = att @ v
        return ctx

x = torch.randn(3, 4)
din = x.shape[1]
d = 3
sa_layer = SelfAttn1(din, d)
ctx = sa_layer(x)
print(ctx.shape, ctx)

torch.Size([3, 3]) tensor([[-0.6417, -0.0450, -1.1361],
        [-1.2040, -0.1568, -2.0606],
        [ 0.0228,  0.0875, -0.0442]])


### Self attention - Option 2
- With trainable weight parameters
- Using torch.nn.Linear

In [12]:
# torch.nn.Linear has better weight initialization

class SelfAttn2(nn.Module):
    def __init__(self, din, d, bias=False):
        super().__init__()
        self.wq = nn.Linear(din, d, bias=bias)
        self.wk = nn.Linear(din, d, bias=bias) 
        self.wv = nn.Linear(din, d, bias=bias)  

    def forward(self, x):
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        att = q @ k.T
        att = torch.softmax(att / k.shape[-1]**0.5, dim=-1)
        ctx = att @ v
        return ctx

x = torch.randn(3, 4)
din = x.shape[1]
d = 3
sa_layer = SelfAttn2(din, d)
ctx = sa_layer(x)
print(ctx.shape, ctx)

torch.Size([3, 3]) tensor([[-0.3736,  0.3227,  0.0807],
        [ 0.1649,  0.1902,  0.4977],
        [ 0.0631,  0.2223,  0.4000]], grad_fn=<MmBackward0>)


In [13]:
# Self attn forward internals for experiments below

x = torch.randn(3, 4)

q, k, v = sa_layer.wq(x), sa_layer.wk(x), sa_layer.wv(x)
att_orig = q @ k.T
att = torch.softmax( att_orig / k.shape[-1]**0.5, dim=-1)
ctx = att @ v

print(att.shape)
print(att_orig)
print(att)

torch.Size([3, 3])
tensor([[-0.4432,  0.1943,  0.0184],
        [-0.4422,  0.0369,  0.0991],
        [ 1.4514, -0.1574, -0.8511]], grad_fn=<MmBackward0>)
tensor([[0.2667, 0.3853, 0.3481],
        [0.2713, 0.3578, 0.3709],
        [0.6025, 0.2380, 0.1595]], grad_fn=<SoftmaxBackward0>)


### Dropout

In [14]:
dropout = torch.nn.Dropout(0.2)

y = torch.ones(4, 4)
print(y)

print("Dropout scales non dropout nodes by:", 1/(1-0.2))

print(dropout(y))

tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
Dropout scales non dropout nodes by: 1.25
tensor([[1.2500, 1.2500, 1.2500, 1.2500],
        [1.2500, 1.2500, 1.2500, 1.2500],
        [1.2500, 1.2500, 1.2500, 1.2500],
        [1.2500, 1.2500, 0.0000, 1.2500]])


### Causal attention - Option 1
- Apply mask after Softmax

In [15]:
att_len = att.shape[0]
att_mask = torch.tril(torch.ones(att_len, att_len))
att = att * att_mask
print(att_mask)
print(att)

# Need to normalize again
att = att / att.sum(dim=-1, keepdim=True)
print(att)

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
tensor([[0.2667, 0.0000, 0.0000],
        [0.2713, 0.3578, 0.0000],
        [0.6025, 0.2380, 0.1595]], grad_fn=<MulBackward0>)
tensor([[1.0000, 0.0000, 0.0000],
        [0.4313, 0.5687, 0.0000],
        [0.6025, 0.2380, 0.1595]], grad_fn=<DivBackward0>)


### Causal attention - Option 2 (preferred)
- Apply mask before Softmax (no need to re-normalize)


In [16]:
att_len = att_orig.shape[0]
att_mask = torch.triu(torch.ones(att_len, att_len), diagonal=1)
att = att_orig.masked_fill(att_mask.bool(), -torch.inf)
print(att_mask)
print(att)
 
# Apply softmax (first time)
att = torch.softmax(att / k.shape[-1]**0.5, dim=-1)
print(att)



tensor([[0., 1., 1.],
        [0., 0., 1.],
        [0., 0., 0.]])
tensor([[-0.4432,    -inf,    -inf],
        [-0.4422,  0.0369,    -inf],
        [ 1.4514, -0.1574, -0.8511]], grad_fn=<MaskedFillBackward0>)
tensor([[1.0000, 0.0000, 0.0000],
        [0.4313, 0.5687, 0.0000],
        [0.6025, 0.2380, 0.1595]], grad_fn=<SoftmaxBackward0>)


### Causal attention - Support batch inputs + Pytorch buffers 
- Batch inputs on dimension 0
- Using Pytorch buffers for attention mask
    - Saved in model state_dict()
    - Also gets transferred to device along with model

In [None]:
class CausalAttn(nn.Module):
    def __init__(self, din, dim, seq_len, dropout, bias=False):
        super().__init__()
        self.wq = nn.Linear(din, dim, bias=bias)
        self.wk = nn.Linear(din, dim, bias=bias) 
        self.wv = nn.Linear(din, dim, bias=bias)  
        self.dropout = nn.Dropout(dropout)
        # att mask as regular tensor
        # self.att_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
        # att mask as pytorch buffer
        att_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
        self.register_buffer('att_mask', att_mask)

    def forward(self, x):
        bs, seq_len, din = x.shape
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        att = q @ k.transpose(1, 2)
        att_mask = self.att_mask.bool()[:seq_len, :seq_len] # Select mask for seq_len & bool
        att.masked_fill_(att_mask, -torch.inf)      
        att = torch.softmax(att / k.shape[-1]**0.5, dim=-1)
        att = self.dropout(att)
        ctx = att @ v
        return ctx

x = torch.randn(3, 4)
batch_x = torch.stack((x, x), dim=0) # Create a batch of input
din = batch_x.shape[2] # Input dim
dim = 3 # dim of att layer embeddings
seq_len = 6 # Max ctx length supported
ca_layer = CausalAttn(din, dim, seq_len, 0.1)
ctx = ca_layer(batch_x)

print(ctx.shape, ctx)
print(ca_layer.state_dict())
print(type(ca_layer.att_mask), ca_layer.att_mask.device)

torch.Size([2, 3, 3]) tensor([[[-0.4459,  0.8970, -2.0975],
         [-0.6559, -0.2561, -0.6735],
         [-0.6442, -0.5407, -0.2109]],

        [[-0.4459,  0.8970, -2.0975],
         [-0.6559, -0.2561, -0.6735],
         [-0.6442, -0.5407, -0.2109]]], grad_fn=<UnsafeViewBackward0>)
OrderedDict([('att_mask', tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])), ('wq.weight', tensor([[ 0.1647,  0.4296, -0.1152,  0.4357],
        [-0.2384, -0.0656,  0.3323, -0.2590],
        [ 0.3815,  0.1226, -0.0098,  0.4279]])), ('wk.weight', tensor([[ 0.3751, -0.2057,  0.0485,  0.0583],
        [ 0.4096,  0.2810,  0.4049,  0.3048],
        [-0.4351,  0.3322, -0.1328,  0.4012]])), ('wv.weight', tensor([[ 0.3146, -0.2923, -0.0526,  0.0746],
        [ 0.1429, -0.4631,  0.0224,  0.2605],
        [ 0.2823,  0.2459,  0.0791, -0.4796]]))])
<class 'torch.Te

In [None]:
# Save model
torch.save(ca_layer.state_dict(), '../output_dir/ca_model.pth')

# Load model & check if buffer is stored
ca_layer = CausalAttn(din, dim, seq_len, 0.1)
ca_layer.load_state_dict(torch.load('../output_dir/ca_model.pth'))
print(ca_layer.state_dict())


OrderedDict([('att_mask', tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])), ('wq.weight', tensor([[ 0.1647,  0.4296, -0.1152,  0.4357],
        [-0.2384, -0.0656,  0.3323, -0.2590],
        [ 0.3815,  0.1226, -0.0098,  0.4279]])), ('wk.weight', tensor([[ 0.3751, -0.2057,  0.0485,  0.0583],
        [ 0.4096,  0.2810,  0.4049,  0.3048],
        [-0.4351,  0.3322, -0.1328,  0.4012]])), ('wv.weight', tensor([[ 0.3146, -0.2923, -0.0526,  0.0746],
        [ 0.1429, -0.4631,  0.0224,  0.2605],
        [ 0.2823,  0.2459,  0.0791, -0.4796]]))])


  ca_layer.load_state_dict(torch.load('./output_dir/ca_model.pth'))


### Multi-head attention - Option 1
- Do multi-head attn calculations serially  


In [None]:
class MHA1(nn.Module):
    def __init__(self, din, dim, seq_len, dropout, num_heads, bias=False):
        super().__init__()
        self.heads = nn.ModuleList([CausalAttn(din, dim, seq_len, dropout, bias) for _ in range(num_heads)])
    
    def forward(self, x):
        heads = [head(x) for head in self.heads]
        ctx = torch.cat(heads, dim=-1)
        return ctx

x = torch.randn(3, 4)
batch_x = torch.stack((x, x), dim=0) # Create a batch of input
din = batch_x.shape[2] # Input dim
dim = 4 # dim of att layer embeddings
seq_len = 10 # Max ctx length supported
num_heads = 2
ca_layer = MHA1(din, dim, seq_len, 0.1, num_heads)
ctx = ca_layer(batch_x) 

print(batch_x.shape)
print(ctx.shape, ctx)  # bs, seqlen, num_heads*dim

torch.Size([2, 3, 4])
torch.Size([2, 3, 8]) tensor([[[ 1.3275e-01,  1.8884e-02, -1.1431e+00,  5.8095e-01, -5.4552e-01,
          -9.2739e-01, -1.3008e-01, -6.0687e-02],
         [ 8.8851e-02, -3.9358e-03, -2.8078e-02, -2.2807e-01,  1.0267e-01,
          -3.2266e-05, -1.3571e-01,  1.2448e-01],
         [-2.8530e-02, -1.4310e-01, -4.4194e-01,  2.7705e-01, -7.6101e-02,
          -2.7175e-01, -1.4880e-01,  8.3744e-02]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -5.4552e-01,
          -9.2739e-01, -1.3008e-01, -6.0687e-02],
         [ 1.5331e-01,  5.2337e-03, -5.8312e-01,  5.4018e-02, -1.8933e-01,
          -4.9643e-01, -2.0533e-01,  9.1993e-02],
         [-2.8530e-02, -1.4310e-01, -4.4194e-01,  2.7705e-01, -2.9850e-02,
           5.4718e-02,  1.1047e-01,  2.9939e-02]]], grad_fn=<CatBackward0>)


### Multi-head attention - Option 2 (preferred) 
- Do multi-head attn calculations in parallel 
- For given input embedding, create q, k, v embeddings to be same size as input embedding
- Split q, k, v embeddings into num_heads 
- Perform attention on each head
- Compute context for each head
- Reshape & Concatenate context results from all heads
- Also, apply linear layer (proj) on context 

In [None]:

class MHA2(nn.Module):
    def __init__(self, din, dim, seq_len, dropout, num_heads, bias=False, dtype=None):
        super().__init__()

        assert dim % num_heads == 0, "Given dim should be multiple of num_heads"
        self.head_dim = dim // num_heads
        self.num_heads, self.din, self.dim = num_heads, din, dim
        
        self.wq = nn.Linear(din, dim, bias=bias, dtype=dtype)
        self.wk = nn.Linear(din, dim, bias=bias, dtype=dtype) 
        self.wv = nn.Linear(din, dim, bias=bias, dtype=dtype)  
        
        att_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
        self.register_buffer('att_mask', att_mask)

        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(dim, dim, bias=bias, dtype=dtype) # bias can be True here, even if qkv bias can be False

    def forward(self, x):
        bs, seq_len, din = x.shape
        
        q, k, v = self.wq(x), self.wk(x), self.wv(x)  # (bs, seq_len, dim)

        # Reshape to (bs, seq_len, num_heads, head_dim)
        q = q.view(bs, seq_len, self.num_heads, self.head_dim)
        k = k.view(bs, seq_len, self.num_heads, self.head_dim) 
        v = v.view(bs, seq_len, self.num_heads, self.head_dim) 

        # Reshape to calculate attn in parallel for all heads
        q = q.transpose(1, 2) # (bs, num_heads, seq_len, head_dim)
        k = k.transpose(1, 2) # (bs, num_heads, seq_len, head_dim)
        v = v.transpose(1, 2) # (bs, num_heads, seq_len, head_dim)
        
        # att matrix mult along seq_len, head_dim. 
        att = q @ k.transpose(2, 3) # (bs, num_heads, seq_len, seq_len)
        
        # causal attn + dropout 
        att_mask = self.att_mask.bool()[:seq_len, :seq_len] # Select mask for seq_len & convert to bool
        att.masked_fill_(att_mask, -torch.inf)      
        att = torch.softmax(att / k.shape[-1]**0.5, dim=-1)
        att = self.dropout(att)
        
        # Calc context & reshape from (bs, num_heads, seq_len, head_dim) & then to (bs, seq_len, num_heads, head_dim)
        ctx = (att @ v).transpose(1, 2)
        
        # Concatenate heads to get (bs, seq_len, dim) & make it contiguous in memory
        ctx = ctx.contiguous().view(bs, seq_len, self.dim)
        ctx = self.proj(ctx)

        return ctx

x = torch.randn(3, 4)
batch_x = torch.stack((x, x), dim=0) # Create a batch of input
din = batch_x.shape[2] # Input dim
dim = 4 # dim of att layer embeddings
seq_len = 10 # Max ctx length supported
num_heads = 2
mha_layer = MHA2(din, dim, seq_len, 0.1, num_heads)
ctx = mha_layer(batch_x) 

print(batch_x.shape)
print(ctx.shape, ctx)  # bs, seqlen, dim

torch.Size([2, 3, 4])
torch.Size([2, 3, 4]) tensor([[[ 0.1145,  0.3370,  0.1083,  0.1875],
         [-0.0768,  0.1804, -0.0021,  0.2917],
         [-0.3022,  0.2157, -0.0712,  0.3215]],

        [[ 0.1145,  0.3370,  0.1083,  0.1875],
         [-0.0768,  0.1804, -0.0021,  0.2917],
         [-0.2115,  0.1560, -0.0558,  0.3222]]], grad_fn=<ViewBackward0>)


### Efficient MHA

### References:

> https://github.com/rasbt/LLMs-from-scratch