In [1]:
import torch
import torch.nn as nn

print(torch.__version__)

torch.manual_seed(123)

2.4.0


<torch._C.Generator at 0x74f3f22cb390>

### Basic torch functions

In [2]:
x = torch.tensor(
    [[0.11, 0.12, 0.13, 0.14],
     [0.21, 0.22, 0.23, 0.24],
     [0.31, 0.32, 0.33, 0.34]]
)
print(x.shape)

torch.Size([3, 4])


In [3]:
torch.empty(x.shape), torch.zeros(x.shape), torch.ones(x.shape), torch.rand(x.shape)

(tensor([[ 2.3822e-44,  4.7429e+30, -5.0954e+13,  3.1207e-41],
         [-5.0951e+13,  3.1207e-41,  0.0000e+00,  0.0000e+00],
         [ 1.4013e-45,  0.0000e+00,  1.4013e-45,  2.8026e-44]]),
 tensor([[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]),
 tensor([[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]),
 tensor([[0.2961, 0.5166, 0.2517, 0.6886],
         [0.0740, 0.8665, 0.1366, 0.1025],
         [0.1841, 0.7264, 0.3153, 0.6871]]))

In [4]:
torch.tril(torch.ones(3,3), diagonal=1), torch.triu(torch.ones(3,3), diagonal=1),

(tensor([[1., 1., 0.],
         [1., 1., 1.],
         [1., 1., 1.]]),
 tensor([[0., 1., 1.],
         [0., 0., 1.],
         [0., 0., 0.]]))

In [5]:
# How to normalize x along each row
x / x.sum(dim=1, keepdim=True) # or.view(x.shape[0], 1)


tensor([[0.2200, 0.2400, 0.2600, 0.2800],
        [0.2333, 0.2444, 0.2556, 0.2667],
        [0.2385, 0.2462, 0.2538, 0.2615]])

In [6]:
# Dot product

res = 0
for ix in range(len(x[0])):
    res += x[0][ix] * x[1][ix]
print(res)

torch.dot(x[0], x[1])

tensor(0.1130)


tensor(0.1130)

In [7]:
# Matrix dot product

att = torch.empty(len(x), len(x))
for ix1 in range(len(x)):
    for ix2 in range(len(x)):
        att[ix1, ix2] = torch.dot(x[ix1], x[ix2])
print(att)

x @ x.T



tensor([[0.0630, 0.1130, 0.1630],
        [0.1130, 0.2030, 0.2930],
        [0.1630, 0.2930, 0.4230]])


tensor([[0.0630, 0.1130, 0.1630],
        [0.1130, 0.2030, 0.2930],
        [0.1630, 0.2930, 0.4230]])

In [8]:
# Softmax 

def softmax_custom(x):
    return torch.exp(x) / torch.exp(x).sum(dim=1, keepdim=True) # or .view(x.shape[0], 1)
print(softmax_custom(att))

att_w = torch.softmax(att, dim=-1)

tensor([[0.3168, 0.3331, 0.3501],
        [0.3038, 0.3324, 0.3637],
        [0.2911, 0.3315, 0.3775]])


In [9]:
# Contextual vectors: Multiply att_w with x

print(att_w.shape, x.shape)

ctx = torch.zeros(*x.shape)
print(ctx.shape)
for ix1 in range(x.shape[0]):
    for ix2 in range(att_w.shape[1]):
        ctx[ix1] += att_w[ix1, ix2] * x[ix2]
print(ctx)

ctx = att_w @ x
ctx

torch.Size([3, 3]) torch.Size([3, 4])
torch.Size([3, 4])
tensor([[0.2133, 0.2233, 0.2333, 0.2433],
        [0.2160, 0.2260, 0.2360, 0.2460],
        [0.2186, 0.2286, 0.2386, 0.2486]])


tensor([[0.2133, 0.2233, 0.2333, 0.2433],
        [0.2160, 0.2260, 0.2360, 0.2460],
        [0.2186, 0.2286, 0.2386, 0.2486]])

### Self attention - Basic
- No trainable parameters

In [10]:
# In short
att = x @ x.T
att = torch.softmax(att, dim=-1)
ctx = att @ x
print(x.shape, att.shape, ctx.shape)
print(ctx)

torch.Size([3, 4]) torch.Size([3, 3]) torch.Size([3, 4])
tensor([[0.2133, 0.2233, 0.2333, 0.2433],
        [0.2160, 0.2260, 0.2360, 0.2460],
        [0.2186, 0.2286, 0.2386, 0.2486]])


### Self attention - Option 1
- With trainable weight parameters
- Using torch.nn.Parameter

In [11]:
class SelfAttn1(nn.Module):
    def __init__(self, din, d):
        super().__init__()
        self.wq = nn.Parameter(torch.rand(din, d), requires_grad=False)
        self.wk = nn.Parameter(torch.rand(din, d), requires_grad=False) 
        self.wv = nn.Parameter(torch.rand(din, d), requires_grad=False)  

    def forward(self, x):
        q, k, v = x @ self.wq, x @ self.wk, x @ self.wv
        att = q @ k.T
        att = torch.softmax(att / k.shape[-1]**0.5, dim=-1)
        ctx = att @ v
        return ctx

din = x.shape[1]
d = 3
sa_layer = SelfAttn1(din, d)
ctx = sa_layer(x)
print(ctx.shape, ctx)

torch.Size([3, 3]) tensor([[0.3725, 0.5042, 0.3365],
        [0.3807, 0.5154, 0.3438],
        [0.3888, 0.5264, 0.3510]])


### Self attention - Option 2
- With trainable weight parameters
- Using torch.nn.Linear

In [12]:
# torch.nn.Linear has better weight initialization

class SelfAttn2(nn.Module):
    def __init__(self, din, d, bias=False):
        super().__init__()
        self.wq = nn.Linear(din, d, bias=bias)
        self.wk = nn.Linear(din, d, bias=bias) 
        self.wv = nn.Linear(din, d, bias=bias)  

    def forward(self, x):
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        att = q @ k.T
        att = torch.softmax(att / k.shape[-1]**0.5, dim=-1)
        ctx = att @ v
        return ctx

din = x.shape[1]
d = 3
sa_layer = SelfAttn2(din, d)
ctx = sa_layer(x)
print(ctx.shape, ctx)

torch.Size([3, 3]) tensor([[-0.1125,  0.0133,  0.0639],
        [-0.1124,  0.0133,  0.0639],
        [-0.1123,  0.0133,  0.0638]], grad_fn=<MmBackward0>)


In [13]:
# Self attn forward internals for experiments below

q, k, v = sa_layer.wq(x), sa_layer.wk(x), sa_layer.wv(x)
att_orig = q @ k.T
att = torch.softmax( att_orig / k.shape[-1]**0.5, dim=-1)
ctx = att @ v

print(att.shape)
print(att_orig)
print(att)

torch.Size([3, 3])
tensor([[-0.0073, -0.0117, -0.0160],
        [-0.0135, -0.0217, -0.0298],
        [-0.0197, -0.0317, -0.0436]], grad_fn=<MmBackward0>)
tensor([[0.3342, 0.3333, 0.3325],
        [0.3349, 0.3333, 0.3318],
        [0.3356, 0.3333, 0.3310]], grad_fn=<SoftmaxBackward0>)


### Dropout

In [14]:
dropout = torch.nn.Dropout(0.2)

y = torch.ones(4, 4)
print(y)

print("Dropout scales non dropout nodes by:", 1/(1-0.2))

print(dropout(y))

tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
Dropout scales non dropout nodes by: 1.25
tensor([[1.2500, 1.2500, 0.0000, 1.2500],
        [0.0000, 0.0000, 1.2500, 1.2500],
        [1.2500, 1.2500, 1.2500, 1.2500],
        [1.2500, 1.2500, 0.0000, 1.2500]])


### Causal attention - Option 1
- Apply mask after Softmax

In [15]:
att_len = att.shape[0]
att_mask = torch.tril(torch.ones(att_len, att_len))
att = att * att_mask
print(att_mask)
print(att)

# Need to normalize again
att = att / att.sum(dim=-1, keepdim=True)
print(att)

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
tensor([[0.3342, 0.0000, 0.0000],
        [0.3349, 0.3333, 0.0000],
        [0.3356, 0.3333, 0.3310]], grad_fn=<MulBackward0>)
tensor([[1.0000, 0.0000, 0.0000],
        [0.5012, 0.4988, 0.0000],
        [0.3356, 0.3333, 0.3310]], grad_fn=<DivBackward0>)


### Causal attention - Option 2 (preferred)
- Apply mask before Softmax (no need to re-normalize)


In [16]:
att_len = att_orig.shape[0]
att_mask = torch.triu(torch.ones(att_len, att_len), diagonal=1)
att = att_orig.masked_fill(att_mask.bool(), -torch.inf)
print(att_mask)
print(att)
 
# Apply softmax (first time)
att = torch.softmax(att / k.shape[-1]**0.5, dim=-1)
print(att)



tensor([[0., 1., 1.],
        [0., 0., 1.],
        [0., 0., 0.]])
tensor([[-0.0073,    -inf,    -inf],
        [-0.0135, -0.0217,    -inf],
        [-0.0197, -0.0317, -0.0436]], grad_fn=<MaskedFillBackward0>)
tensor([[1.0000, 0.0000, 0.0000],
        [0.5012, 0.4988, 0.0000],
        [0.3356, 0.3333, 0.3310]], grad_fn=<SoftmaxBackward0>)


### Causal attention - Support batch inputs + Pytorch buffers 
- Batch inputs on dimension 0
- Using Pytorch buffers for attention mask
    - Saved in model state_dict()
    - Also gets transferred todevice along with model

In [17]:
class CausalAttn(nn.Module):
    def __init__(self, din, dim, ctx_len, dropout, bias=False):
        super().__init__()
        self.wq = nn.Linear(din, dim, bias=bias)
        self.wk = nn.Linear(din, dim, bias=bias) 
        self.wv = nn.Linear(din, dim, bias=bias)  
        self.dropout = nn.Dropout(dropout)
        # att mask as regular tensor
        # self.att_mask = torch.triu(torch.ones(ctx_len, ctx_len), diagonal=1)
        # att mask as pytorch buffer
        self.register_buffer('att_mask', torch.triu(torch.ones(ctx_len, ctx_len), diagonal=1))

    def forward(self, x):
        bs, seq_len, din = x.shape
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        att = q @ k.transpose(1, 2)
        att_mask = self.att_mask.bool()[:seq_len, :seq_len] # Select mask for seq_len & bool
        att.masked_fill_(att_mask, -torch.inf)      
        att = torch.softmax(att / k.shape[-1]**0.5, dim=-1)
        att = self.dropout(att)
        ctx = att @ v
        return ctx

batch_x = torch.stack((x, x), dim=0) # Create a batch of input
din = batch_x.shape[2] # Input dim
dim = 3 # dim of att layer embeddings
ctx_len = 6 # Max ctx length supported
ca_layer = CausalAttn(din, dim, ctx_len, 0.1)
ctx = ca_layer(batch_x)

print(ctx.shape, ctx)
print(ca_layer.state_dict())
print(type(ca_layer.att_mask), ca_layer.att_mask.device)

torch.Size([2, 3, 3]) tensor([[[ 0.0501, -0.0684,  0.0670],
         [ 0.0699, -0.0939,  0.0947],
         [ 0.0897, -0.1196,  0.1226]],

        [[ 0.0501, -0.0684,  0.0670],
         [ 0.0699, -0.0939,  0.0947],
         [ 0.0897, -0.1196,  0.1226]]], grad_fn=<UnsafeViewBackward0>)
OrderedDict([('att_mask', tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])), ('wq.weight', tensor([[-0.1889, -0.3045,  0.4153,  0.2751],
        [ 0.1749, -0.3834,  0.3858,  0.1568],
        [ 0.3459, -0.1967,  0.1060,  0.4882]])), ('wk.weight', tensor([[ 0.3363,  0.4010, -0.1050,  0.3809],
        [-0.3916,  0.0432, -0.2815, -0.1166],
        [-0.1280,  0.0374,  0.4551,  0.2475]])), ('wv.weight', tensor([[-0.0021,  0.3549, -0.2562,  0.2577],
        [-0.0464, -0.0870,  0.0585, -0.3830],
        [ 0.0578,  0.1681,  0.4275, -0.1557]]))])
<class 'torch.Te

In [18]:
# Save model
torch.save(ca_layer.state_dict(), './output_dir/ca_model.pth')

# Load model & check if buffer is stored
ca_layer = CausalAttn(din, dim, ctx_len, 0.1)
ca_layer.load_state_dict(torch.load('./output_dir/ca_model.pth'))
print(ca_layer.state_dict())


OrderedDict([('att_mask', tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])), ('wq.weight', tensor([[-0.1889, -0.3045,  0.4153,  0.2751],
        [ 0.1749, -0.3834,  0.3858,  0.1568],
        [ 0.3459, -0.1967,  0.1060,  0.4882]])), ('wk.weight', tensor([[ 0.3363,  0.4010, -0.1050,  0.3809],
        [-0.3916,  0.0432, -0.2815, -0.1166],
        [-0.1280,  0.0374,  0.4551,  0.2475]])), ('wv.weight', tensor([[-0.0021,  0.3549, -0.2562,  0.2577],
        [-0.0464, -0.0870,  0.0585, -0.3830],
        [ 0.0578,  0.1681,  0.4275, -0.1557]]))])


  ca_layer.load_state_dict(torch.load('./output_dir/ca_model.pth'))


### Multi-head attention - Option 1
- Do multi-head attn calculations serially  


In [19]:
class MHA1(nn.Module):
    def __init__(self, din, dim, ctx_len, dropout, num_heads, bias=False):
        super().__init__()
        self.heads = nn.ModuleList([CausalAttn(din, dim, ctx_len, dropout, bias) for _ in range(num_heads)])
    
    def forward(self, x):
        heads = [head(x) for head in self.heads]
        ctx = torch.cat(heads, dim=-1)
        return ctx

batch_x = torch.stack((x, x), dim=0) # Create a batch of input
din = batch_x.shape[2] # Input dim
dim = 4 # dim of att layer embeddings
ctx_len = 10 # Max ctx length supported
num_heads = 2
ca_layer = MHA1(din, dim, ctx_len, 0.1, num_heads)
ctx = ca_layer(batch_x) 

print(ctx.shape, ctx)  # bs, seqlen, num_heads*dim

torch.Size([2, 3, 8]) tensor([[[-0.0269, -0.1061,  0.0524, -0.0600,  0.0007,  0.0862,  0.0180,
           0.0029],
         [-0.0311, -0.0962,  0.0501, -0.0522, -0.0016,  0.1145,  0.0197,
           0.0074],
         [-0.0622, -0.1926,  0.1002, -0.1045, -0.0039,  0.1430,  0.0214,
           0.0119]],

        [[-0.0269, -0.1061,  0.0524, -0.0600,  0.0007,  0.0862,  0.0180,
           0.0029],
         [-0.0446, -0.1495,  0.0764, -0.0823,  0.0000,  0.0000,  0.0000,
           0.0000],
         [-0.0622, -0.1926,  0.1002, -0.1045, -0.0010,  0.0762,  0.0131,
           0.0049]]], grad_fn=<CatBackward0>)


### Multi-head attention - Option 2 (preferred) 
- Do multi-head attn calculations in parallel 
- Split q, k, v embeddings into num_heads and perform attention. 
- Concatenate results

In [20]:

class MHA2(nn.Module):
    def __init__(self, din, dim, ctx_len, dropout, num_heads, bias=False):
        super().__init__()
        assert dim % num_heads == 0, "Given dim should be multiple of num_heads"
        self.head_dim = dim // num_heads
        self.num_heads, self.din, self.dim = num_heads, din, dim
        self.wq = nn.Linear(din, dim, bias=bias)
        self.wk = nn.Linear(din, dim, bias=bias) 
        self.wv = nn.Linear(din, dim, bias=bias)  
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('att_mask', torch.triu(torch.ones(ctx_len, ctx_len), diagonal=1))
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        bs, seq_len, din = x.shape
        
        q, k, v = self.wq(x), self.wk(x), self.wv(x)  # (bs, seq_len, dim)

        # Reshape to (bs, seq_len, num_heads, head_dim) & then to (bs, num_heads, seq_len, head_dim)
        q = q.view(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2) 
        k = k.view(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2) 
        v = v.view(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2) 
        
        # att matrix mult along seq_len, head_dim. 
        att = q @ k.transpose(2, 3) # (bs, num_heads, seq_len, seq_len)
        
        # causal attn + dropout 
        att_mask = self.att_mask.bool()[:seq_len, :seq_len] # Select mask for seq_len & bool
        att.masked_fill_(att_mask, -torch.inf)      
        att = torch.softmax(att / k.shape[-1]**0.5, dim=-1)
        att = self.dropout(att)
        
        # Reshape to (bs, num_heads, seq_len, head_dim) & then to (bs, seq_len, num_heads, head_dim)
        ctx = (att @ v).transpose(1, 2)
        
        # Reshape to (bs, seq_len, dim) & make it contiguous in memory
        ctx = ctx.contiguous().view(bs, seq_len, self.dim)
        ctx = self.proj(ctx)

        return ctx

batch_x = torch.stack((x, x), dim=0) # Create a batch of input
din = batch_x.shape[2] # Input dim
dim = 4 # dim of att layer embeddings
ctx_len = 10 # Max ctx length supported
num_heads = 2
mha_layer = MHA2(din, dim, ctx_len, 0.1, num_heads)
ctx = mha_layer(batch_x) 

print(ctx.shape, ctx)  # bs, seqlen, dim

torch.Size([2, 3, 4]) tensor([[[-0.3442, -0.1607,  0.4049,  0.2980],
         [-0.3538, -0.1427,  0.4228,  0.2968],
         [-0.3374, -0.1368,  0.4419,  0.2958]],

        [[-0.3442, -0.1607,  0.4049,  0.2980],
         [-0.3232, -0.1725,  0.4033,  0.2986],
         [-0.3634, -0.1247,  0.4406,  0.2957]]], grad_fn=<ViewBackward0>)


### Efficient MHA