In [1]:
import torch 
import torch.nn as nn
import math 
import torch.functional as F 


In [None]:
torch.manual_seed(1337)
x = torch.randn(size=(2,8,4)) #bs, T, dim 8 tokens 
print(x)

tensor([[[ 0.1808, -0.0700, -0.3596, -0.9152],
         [ 0.6258,  0.0255,  0.9545,  0.0643],
         [ 0.3612,  1.1679, -1.3499, -0.5102],
         [ 0.2360, -0.2398, -0.9211,  1.5433],
         [ 1.3488, -0.1396,  0.2858,  0.9651],
         [-2.0371,  0.4931,  1.4870,  0.5910],
         [ 0.1260, -1.5627, -1.1601, -0.3348],
         [ 0.4478, -0.8016,  1.5236,  2.5086]],

        [[-0.6631, -0.2513,  1.0101,  0.1215],
         [ 0.1584,  1.1340, -1.1539, -0.2984],
         [-0.5075, -0.9239,  0.5467, -1.4948],
         [-1.2057,  0.5718, -0.5974, -0.6937],
         [ 1.6455, -0.8030,  1.3514, -0.2759],
         [-1.5108,  2.1048,  2.7630, -1.7465],
         [ 1.4516, -1.5103,  0.8212, -0.2115],
         [ 0.7789,  1.5333,  1.6097, -0.4032]]])


In [9]:
def split_into_windows(x, window_size=2):
    bs, seq_len, dim = x.shape 
    num_windows = seq_len // window_size 
    x = x.view(bs, num_windows, window_size, dim)
    return x
y = split_into_windows(x)
print(y)
print()
batch_1 = y[0]
win_1_batch_1 = y[0][0]
chunk_1_win_1_batch_1 = y[0][0][0]
chunk_2_win_1_batch_1 = y[0][0][1]
print(chunk_1_win_1_batch_1,chunk_2_win_1_batch_1)

tensor([[[[ 0.1808, -0.0700, -0.3596, -0.9152],
          [ 0.6258,  0.0255,  0.9545,  0.0643]],

         [[ 0.3612,  1.1679, -1.3499, -0.5102],
          [ 0.2360, -0.2398, -0.9211,  1.5433]],

         [[ 1.3488, -0.1396,  0.2858,  0.9651],
          [-2.0371,  0.4931,  1.4870,  0.5910]],

         [[ 0.1260, -1.5627, -1.1601, -0.3348],
          [ 0.4478, -0.8016,  1.5236,  2.5086]]],


        [[[-0.6631, -0.2513,  1.0101,  0.1215],
          [ 0.1584,  1.1340, -1.1539, -0.2984]],

         [[-0.5075, -0.9239,  0.5467, -1.4948],
          [-1.2057,  0.5718, -0.5974, -0.6937]],

         [[ 1.6455, -0.8030,  1.3514, -0.2759],
          [-1.5108,  2.1048,  2.7630, -1.7465]],

         [[ 1.4516, -1.5103,  0.8212, -0.2115],
          [ 0.7789,  1.5333,  1.6097, -0.4032]]]])

tensor([ 0.1808, -0.0700, -0.3596, -0.9152]) tensor([0.6258, 0.0255, 0.9545, 0.0643])


In [None]:
class LocalWindowAttention(nn.Module):
    def __init__(self, 
                 window_size = 2,
                 embedding_dim = 4,
                 num_attention_heads = 2,
                 causal = False,
                 look_backward = 1,
                 look_forward=1,
                 attention_dropout = 0.0
                ):
        

        super(LocalWindowAttention, self).__init__()
        self.causal = causal
        self.look_backward = look_backward
        self.look_forward = look_forward 
        self.window_size = window_size
        self.embed_dim = embedding_dim
        self.num_heads = num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads 

        # Projections
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.dropout = nn.Dropout(attention_dropout)

        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def split_into_windows(self, x):
        if x.dim()==4:
            b, h, seq_len, head_dim = x.shape
            num_windows = seq_len // self.window_size
            x = x.view(b, h, num_windows, self.window_size, head_dim) 
            return x
    
        b, seq_len, d = x.shape 
        num_windows = seq_len // self.window_size 
        x = x.view(b, num_windows, self.window_size, d) 
        return x 
    
    
    def collect_windows(self, x, backward =1, forward = 1, pad_value = -1):
        if x.dim() == 4:
            batch_heads, num_windows, window_size, embed_dim = x.shape 
            pad = (0, 0, 0, 0, backward, forward)

        elif x.dim()==3:
            bath_head_dim, num_windows, window_size = x.shape 

            pad = (0, 0, backward, forward)

        x = nn.functional.pad(x, pad=pad, value=pad_value)
        gathered = []

        for i in range(num_windows): #2 windows
            start_idx = i 
            end_idx = i+forward+backward 

            grabbed_win = x[:, start_idx:end_idx+1] #[[pad], W0, W1] bs, win=3, window_size, embed_dim
            grabbed_win = grabbed_win.flatten(1, 2).unsqueeze(1) #bs, 1, 3*window_size, embed_dim

            gathered.append(grabbed_win) 
        gathered = torch.cat(gathered, dim=1) #bs, num_win, win_size, embed_dim 
        ''' 
        earlier we were having for each window window_size tokens 
        but for k and v it should find the affinity to backward and forward window
        then the net length of each window should incvrease from window_size to 
        window_size*(backward+forward+1=3)
        the again make it for each window the context length to 3*window_size 
        for such num_windows 
        
        '''
        return gathered 





    def forward(self, x, attention_mask = None):
        b, ori_seq_len, d = x.shape 
        h = self.num_heads 
        head_dim = self.head_dim 

        q = self.q_proj(x).view(b, ori_seq_len, h, head_dim).permute(0, 2, 1, 3).contiguous()
        k = self.k_proj(x).view(b, ori_seq_len, h, head_dim).permute(0,2,1,3).contiguous()
        v = self.k_proj(x).view(b, ori_seq_len, h, head_dim).permute(0,2,1,3).contiguous()

        ### Merge together Head/Batch Dimension ###
        q = q.flatten(0,1) #batched_head(b*h), seq_len, head_dim
        k = k.flatten(0,1)
        v = v.flatten(0,1)
        device = q.device()
        if attention_mask is not None:
            attention_mask = attention_mask.repeat(self.num_heads, 1) 

        if ori_seq_len % self.window_size == 0:
            difference = self.window_size*math.ceil(ori_seq_len/self.window_size) - ori_seq_len 
            q = nn.functional.pad(q, pad=(0,0,0, difference))
            k = nn.functional.pad(k, pad=(0,0,0, difference))
            v = nn.functional.pad(v, pad=(0,0,0, difference))

        seq_len = q.shape[1]
        num_windows = seq_len // self.window_size 
        idx = torch.arange(seq_len, device=device)
        bucketed_idx = idx.reshape(1, num_windows, self.window_size) #bs,n, win_size of indexes


        ### Bucket our Q,K,V into the Chunked Windows ###
        bucketed_q = q.reshape(b*self.num_heads, num_windows, self.window_size, self.head_dim)
        bucketed_k = k.reshape(b*self.num_heads, num_windows, self.window_size, self.head_dim)
        bucketed_v = v.reshape(b*self.num_heads, num_windows, self.window_size, self.head_dim)

        bucketed_k = self.collect_windows(bucketed_k, self.look_backward, self.look_forward)
        bucketed_v = self.collect_windows(bucketed_v, self.look_backward, self.look_forward)
        # to know the pad token lets collect through collect window 
        # batchsize*heads, num_win, win_size
        collected_bucket_idx = self.collect_windows(bucketed_idx,self.look_backward, self.look_forward)
        bucket_pad_mask = (collected_bucket_idx == -1) #b, num_win, 3*win_size 


        attention_scores = bucketed_q @ bucketed_k.transpose(-1, -2) 
        # b*h, num_win, win_size, embed_dim @ b*h, num_win, 3*win_size, embed_dim
        # attention_score has dim b*h, num_win, win_size(query tokens list length), 3*win_size(key tokens length list) 
        
        # b=1, num_win, win_size=512, 3*win_size=1536 same as attetion scores
        bucket_pad_mask = bucket_pad_mask.unsqueeze(-2).repeat(1,1,self.window_size, 1)
        ''' 
        bucket_pad_mask originally told us:
        In this window, which key tokens are fake pads?
        But attention is computed query × key.
        So we need:
        For every query token in this window, which key tokens are fake pads?
        
        '''
        attention_scores = attention_scores.masked_fill(bucket_pad_mask, float("-inf"))
        collected_bucket_idx = collected_bucket_idx.unsqueeze(-2).repeat(1,1,self.window_size, 1)

        # causal masking 


        #non causal masking



        q = self.split_into_windows(q) #b, h, num_w, win_size,head_dim
        k = self.split_into_windows(k)
        v = self.split_into_windows(v) 

        #attention in each windows 
        scores = q @ k.transpose(-1, -2) #b, h, num_w, win_size, win_size 
        scores = scores/math.sqrt(head_dim) 
        attn = torch.softmax(scores, dim=-1)
        out = attn @ v #b, h, num_w, win_size, head_dim 

        out = out.reshape(b, h, seq_len, head_dim).permute(0,2,1,3).reshape(b, seq_len, d)
        return scores, self.out_proj(out)

In [14]:
attn = LocalWindowAttention()
scores, out = attn(x) 
print("Input:", x.shape)
print("Output:", out.shape)
print(scores)

Input: torch.Size([2, 8, 4])
Output: torch.Size([8, 4])
tensor([[ 0.0159,  0.1055, -0.0799, -0.1230],
        [ 0.0427,  0.0783, -0.0897, -0.1552],
        [ 0.3784, -0.3329, -0.2693, -0.4014],
        [ 0.4179, -0.3910, -0.2918, -0.4502],
        [ 0.1549, -0.0497, -0.1635, -0.4414],
        [ 0.1166,  0.0926, -0.1035, -0.3945],
        [ 0.0485, -0.0642, -0.0954, -0.1074],
        [ 0.1839, -0.1954, -0.1418, -0.2772]], grad_fn=<UnbindBackward0>)
