In [1]:
import torch as t
import torch.nn as nn
import math
import einops
import gpt_tests

In [2]:
import pdb
def upper_right_mask(x: t.Tensor):
    return t.tril(x, 0) + (t.triu(t.ones(x.shape), 1) * -1e4)

def lower_left_mask(x: t.Tensor):
    return t.triu(x, 0) + (t.tril(t.ones(x.shape), -1) * -1e4)

class UniMultiHeadAttention(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.attn_ll = nn.Linear(hidden_size, hidden_size*3)
        self.output_ll = nn.Linear(hidden_size, hidden_size)
        self.head_size = hidden_size // num_heads
        self.hidden_size = hidden_size # embedding size
        self.num_heads = num_heads

        
    def forward(self, x: t.Tensor): # [batch, seq_len, hidden_size]
        batch, seq_len, _ = x.shape
        qkv = self.attn_ll(x) # [batch, seq_len, 3 * hidden_size]
        
        q, k, v = einops.rearrange(qkv, 'b s (three e) -> three b e s', three=3) # e for embedding size (which is hidden size)
        q, k, v = [einops.rearrange(m, 'b (h n) s -> b n h s', n=self.num_heads) for m in (q, k, v)]
        #pdb.set_trace()
        
        raw_score = t.einsum('bnhs,bnhz->bnsz', k, v)
        
        scaled_score = raw_score / math.sqrt(self.head_size)
        
        masked_score = upper_right_mask(scaled_score) 
        
        softmaxed_score = masked_score.softmax(-1) # batch, num_heads, seq_len, seq_len
        
        Z = t.einsum('bnsz,bnhz -> bnhs', softmaxed_score, v)
        Z = einops.rearrange(Z, 'b n h s -> b s (n h)')
        
        output = self.output_ll(Z)
        
        return output, qkv, k, q, v, masked_score, softmaxed_score, Z
        # WhatWeWant = Z * WO

In [3]:
tensor = t.arange(1, 26).reshape((5,5))

In [4]:
tensor

tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10],
        [11, 12, 13, 14, 15],
        [16, 17, 18, 19, 20],
        [21, 22, 23, 24, 25]])

In [5]:
upper_right_mask(tensor)

tensor([[ 1.0000e+00, -1.0000e+04, -1.0000e+04, -1.0000e+04, -1.0000e+04],
        [ 6.0000e+00,  7.0000e+00, -1.0000e+04, -1.0000e+04, -1.0000e+04],
        [ 1.1000e+01,  1.2000e+01,  1.3000e+01, -1.0000e+04, -1.0000e+04],
        [ 1.6000e+01,  1.7000e+01,  1.8000e+01,  1.9000e+01, -1.0000e+04],
        [ 2.1000e+01,  2.2000e+01,  2.3000e+01,  2.4000e+01,  2.5000e+01]])

In [6]:
batch, seq_len, hidden_size = 1, 4, 100
x = t.randn((batch, seq_len, hidden_size))

module = UniMultiHeadAttention(hidden_size, 5)
output = module(x)

(output, qkv, k, q, v, masked_score, softmaxed_score, Z), (true_output, qkv2, k2, q2, v2, attn_scores, attn_prob, combined_v) = gpt_tests.test_unidirectional_attn(UniMultiHeadAttention)
#output

In [36]:
qkv.shape
einops.rearrange(qkv, 'b s (three e) -> three b e s', three=3).shape

torch.Size([3, 1, 24, 5])

In [43]:
qkv2.shape
t.split(qkv2, 24, dim=-1)[0].equal(q2)

False

In [31]:
einops.rearrange(k, 'b n h s -> b n s h').equal(k2)

False

In [33]:
k

tensor([[[[-0.5433, -0.3950,  0.0243,  0.1633, -0.4108],
          [-0.6218,  0.1771, -0.1220,  0.0649,  0.8572],
          [-0.1799, -1.3417, -0.3185,  0.5328,  0.1256],
          [-0.1425, -0.0530, -0.4618, -0.2510,  1.0349],
          [-0.4637,  0.9413, -0.5275,  0.0479, -0.7248],
          [ 0.6505,  0.5064, -0.6061, -0.4649,  0.7123]],

         [[ 0.6235, -0.0756, -0.1040,  0.4833,  0.7994],
          [-0.5811, -0.9073, -0.0224,  0.3460, -0.0421],
          [ 0.1105, -0.3790, -0.0938, -0.8144,  0.4575],
          [ 0.0680,  0.6321,  0.4341,  0.0670,  0.4790],
          [-0.3622, -0.2336,  0.8287,  0.1245,  0.1725],
          [ 0.6786,  0.2461, -0.2254,  0.1167, -0.1402]],

         [[-1.0921, -0.0949,  0.8208,  0.6795,  0.4688],
          [-0.2159,  0.2922,  1.1232,  0.1396,  0.1746],
          [-0.4640, -0.0530, -0.6113, -0.0201,  0.2239],
          [ 0.0427, -1.4614, -0.2217, -0.0060,  0.2423],
          [-0.8837, -0.4219, -1.8101, -1.1027, -0.0943],
          [ 0.0610, -0.6829

In [34]:
k2

tensor([[[[-0.5433,  0.6235, -1.0921,  0.6205, -0.6218, -0.5811],
          [-0.3950, -0.0756, -0.0949,  0.3078,  0.1771, -0.9073],
          [ 0.0243, -0.1040,  0.8208, -1.2910, -0.1220, -0.0224],
          [ 0.1633,  0.4833,  0.6795, -0.2533,  0.0649,  0.3460],
          [-0.4108,  0.7994,  0.4688,  0.4010,  0.8572, -0.0421]],

         [[-0.2159,  0.4255, -0.1799,  0.1105, -0.4640,  0.6904],
          [ 0.2922, -0.3060, -1.3417, -0.3790, -0.0530, -0.0963],
          [ 1.1232, -0.3564, -0.3185, -0.0938, -0.6113, -0.1102],
          [ 0.1396,  1.6150,  0.5328, -0.8144, -0.0201,  1.2167],
          [ 0.1746,  0.2288,  0.1256,  0.4575,  0.2239,  0.3959]],

         [[-0.1425,  0.0680,  0.0427, -0.1137, -0.4637, -0.3622],
          [-0.0530,  0.6321, -1.4614, -1.6315,  0.9413, -0.2336],
          [-0.4618,  0.4341, -0.2217,  1.5536, -0.5275,  0.8287],
          [-0.2510,  0.0670, -0.0060,  0.7757,  0.0479,  0.1245],
          [ 1.0349,  0.4790,  0.2423, -0.7889, -0.7248,  0.1725]],

    

In [19]:
qkv.equal(qkv2)

True

In [None]:
k.shape

In [None]:
einops.rearrange(k, 'b n h s -> b n s h')#.equal(k2)

In [None]:
k2

In [None]:
k2

In [None]:
attn_scores

In [None]:
masked_score

In [None]:
true_output

In [None]:
output

In [None]:
example = t.tensor([-6.7706e-01, -1e-4, -1e-4, -1e-4])

In [None]:
example.softmax(dim=0)

In [None]:
q, k, v = einops.rearrange(qkv, 'b s (three h) -> three b s h', three=3)

In [None]:
q.shape