In [1]:
import torch
import torch.nn as nn

print(torch.__version__)

torch.manual_seed(123)

2.4.0


<torch._C.Generator at 0x7b11d9c7c390>

### Use basic torch functions

In [2]:
x = torch.tensor(
    [[0.11, 0.12, 0.13],
     [0.21, 0.22, 0.23]]
)
print(x.shape)

torch.Size([2, 3])


In [3]:
torch.empty(x.shape), torch.zeros(x.shape), torch.ones(x.shape), torch.rand(x.shape)

(tensor([[3.6434e-44, 1.9346e-19, 1.9194e-07],
         [3.4698e-41, 1.5344e-42, 7.2143e+22]]),
 tensor([[0., 0., 0.],
         [0., 0., 0.]]),
 tensor([[1., 1., 1.],
         [1., 1., 1.]]),
 tensor([[0.2961, 0.5166, 0.2517],
         [0.6886, 0.0740, 0.8665]]))

In [4]:
# How to normalize x along each row
x / x.sum(dim=1).view(x.shape[0], 1)


tensor([[0.3056, 0.3333, 0.3611],
        [0.3182, 0.3333, 0.3485]])

In [5]:
# Dot product

res = 0
for ix in range(len(x[0])):
    res += x[0][ix] * x[1][ix]
print(res)

torch.dot(x[0], x[1])

tensor(0.0794)


tensor(0.0794)

In [6]:
# Matrix dot product

att = torch.empty(len(x), len(x))
for ix1 in range(len(x)):
    for ix2 in range(len(x)):
        att[ix1, ix2] = torch.dot(x[ix1], x[ix2])
print(att)

x @ x.T



tensor([[0.0434, 0.0794],
        [0.0794, 0.1454]])


tensor([[0.0434, 0.0794],
        [0.0794, 0.1454]])

In [7]:
# Softmax 

def softmax_custom(x):
    return torch.exp(x) / torch.exp(x).sum(dim=1).view(x.shape[0], 1)
print(softmax_custom(att))

att_w = torch.softmax(att, dim=-1)

tensor([[0.4910, 0.5090],
        [0.4835, 0.5165]])


In [8]:
# Contextual vectors: Multiply att_w with x

print(att_w.shape, x.shape)

ctx = torch.zeros(*x.shape)
print(ctx.shape)
for ix1 in range(x.shape[0]):
    for ix2 in range(att_w.shape[1]):
        ctx[ix1] += att_w[ix1, ix2] * x[ix2]
print(ctx)

ctx = att_w @ x
ctx

torch.Size([2, 2]) torch.Size([2, 3])
torch.Size([2, 3])
tensor([[0.1609, 0.1709, 0.1809],
        [0.1616, 0.1716, 0.1816]])


tensor([[0.1609, 0.1709, 0.1809],
        [0.1616, 0.1716, 0.1816]])

### Self attention

In [9]:
# In short
att = x @ x.T
ctx = torch.softmax(att, dim=-1) @ x
print(x.shape, att.shape, ctx.shape)
print(ctx)

torch.Size([2, 3]) torch.Size([2, 2]) torch.Size([2, 3])
tensor([[0.1609, 0.1709, 0.1809],
        [0.1616, 0.1716, 0.1816]])


### Self attention with trainable weight parameters (using torch.nn.Parameter)

In [None]:
class SelfAttn1(nn.Module):
    def __init__(self, din, d):
        super().__init__()
        self.wq = nn.Parameter(torch.rand(din, d), requires_grad=False)
        self.wk = nn.Parameter(torch.rand(din, d), requires_grad=False) 
        self.wv = nn.Parameter(torch.rand(din, d), requires_grad=False)  

    def forward(self, x):
        q, k, v = x @ self.wq, x @ self.wk, x @ self.wv
        att = q @ k.T
        ctx = torch.softmax(att / k.shape[-1]**0.5, dim=-1) @ v
        return ctx

din = x.shape[1]
d = 3
layer = SelfAttn1(din, d)
ctx = layer(x)
print(ctx.shape, ctx)

torch.Size([2, 3]) tensor([[0.2049, 0.2957, 0.2305],
        [0.2059, 0.2971, 0.2316]])


### Self attention with trainable weight parameters (using torch.nn.Linear)

In [14]:
# torch.nn.Linear has better weight initialization

class SelfAttn2(nn.Module):
    def __init__(self, din, d, bias=False):
        super().__init__()
        self.wq = nn.Linear(din, d, bias=bias)
        self.wk = nn.Linear(din, d, bias=bias) 
        self.wv = nn.Linear(din, d, bias=bias)  

    def forward(self, x):
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        att = q @ k.T
        ctx = torch.softmax(att / k.shape[-1]**0.5, dim=-1) @ v
        return ctx

din = x.shape[1]
d = 3
layer = SelfAttn2(din, d)
ctx = layer(x)
print(ctx.shape, ctx)

torch.Size([2, 3]) tensor([[-0.0986, -0.0090, -0.0202],
        [-0.0987, -0.0090, -0.0203]], grad_fn=<MmBackward0>)


### Causal attention

### Multi-head attention

### Efficient MHA