In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import warnings
# NOTE(review): with no category argument this installs an 'ignore' filter for
# every warning, so deprecation and numerical warnings raised by later cells
# are hidden too — confirm that blanket suppression is intended.
warnings.simplefilter('ignore')

In [74]:
class CausalSelfAttentionHead(nn.Module):
    """One head of causal (masked) scaled dot-product self-attention.

    Each position is projected to a query, key and value vector. Attention
    scores between positions are masked so that a position can only attend to
    itself and to earlier positions, then normalised with a softmax and used
    to take a weighted sum of the value vectors.

    Args:
        embed_dim: Size of the last dimension of the input.
        head_size: Size of the query/key/value projections and of the output.
        context_len: Longest sequence length the causal mask is built for.
        attn_drop_val: Dropout probability applied to the attention weights.
    """

    def __init__(self, embed_dim, head_size, context_len, attn_drop_val):
        super().__init__()
        self.embed_dim = embed_dim
        self.head_size = head_size
        self.context_len = context_len

        # Bias-free projections from the embedding space into the head space.
        # Creation order (key, query, value) is kept so seeded initialisation
        # produces the same weights as before.
        self.key = nn.Linear(embed_dim, head_size, bias=False)
        self.query = nn.Linear(embed_dim, head_size, bias=False)
        self.value = nn.Linear(embed_dim, head_size, bias=False)
        self.attn_drop = nn.Dropout(attn_drop_val)
        # Lower-triangular matrix of ones; zeros mark the "future" positions.
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer('tril', torch.tril(torch.ones(context_len, context_len)))

    def forward(self, x):
        """Apply causal self-attention to ``x``.

        Args:
            x: Tensor of shape ``(B, T, embed_dim)`` with ``T <= context_len``.

        Returns:
            Tensor of shape ``(B, T, head_size)``.

        Raises:
            ValueError: If the sequence length ``T`` exceeds ``context_len``.
        """
        _, T, _ = x.shape
        # Without this guard a too-long sequence either fails with an opaque
        # broadcast error inside masked_fill or, when context_len == 1, has
        # its (1, 1) mask broadcast and silently attends to future positions.
        if T > self.context_len:
            raise ValueError(
                f"sequence length {T} exceeds context_len {self.context_len}"
            )

        # Pass the input data through the q, k, v layers
        q, k, v = self.query(x), self.key(x), self.value(x)

        # Scaled dot-product scores, (B, T, T). The scale is a plain Python
        # float, so no tensor is allocated for it on every call.
        scores = torch.bmm(q, k.transpose(1, 2)) / (self.head_size ** 0.5)
        # Future positions get -inf so the softmax assigns them zero weight.
        scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        weights = self.attn_drop(weights)

        # Aggregate by values
        # weights: (B, T, T), v: (B, T, head_size), output: (B, T, head_size)
        return torch.bmm(weights, v)

In [73]:
# Random stand-in batch shaped (batch, context length, embedding dim).
# Use Apple's Metal backend when it is available and fall back to the CPU
# otherwise, so the cell also runs on machines without MPS.
dummy_device = 'mps' if torch.backends.mps.is_available() else 'cpu'
dummy_data = torch.randn([64, 16, 32], device=dummy_device)

In [86]:
# Hyperparameters for the single attention head under test
attn_drop = 0.2
head_size = 16
# Batch size, context length and embedding dim are read off the dummy batch
bs, cl, ed = dummy_data.shape
attention_block = CausalSelfAttentionHead(
    embed_dim=ed,
    head_size=head_size,
    context_len=cl,
    attn_drop_val=attn_drop
)
# Follow the data's device instead of repeating a hard-coded device string,
# so the module and its input can never end up on different devices.
attention_block = attention_block.to(dummy_data.device)

In [102]:
# Two independent 5x5 standard-normal matrices; `a` is sampled before `b`
a, b = (torch.randn(5, 5) for _ in range(2))

In [103]:
# Bare expression as the cell's last line: the notebook renders the tensor `a`
a

tensor([[-0.8491,  0.9288, -0.1314, -0.8019,  0.1869],
        [ 1.1707,  0.5597, -0.2477,  1.5830,  0.7145],
        [ 0.0821, -2.5598, -0.8892,  1.6925, -0.4810],
        [ 0.4280,  0.1001, -0.8228, -0.5716,  0.1815],
        [-1.0965, -0.2318,  1.4492, -0.4824,  0.0934]])

In [104]:
# Bare expression as the cell's last line: the notebook renders the tensor `b`
b

tensor([[ 0.1432, -0.3256,  1.1720,  0.1256,  0.4844],
        [-0.2341,  1.4385,  1.7822,  0.0292, -1.2843],
        [ 1.0318,  0.4665,  1.5819,  0.3102,  1.0315],
        [ 1.0084, -0.0718, -0.9305, -1.3787, -0.5911],
        [ 1.5977,  0.5849,  0.7514,  1.5093, -0.1699]])

In [107]:
import numpy as np