In [1]:
from typing import List
import math
import time
import torch
from torch import nn
import torch.nn.functional as F
from transformers import RobertaConfig, RobertaModel

from diagonaled_mm_tvm import diagonaled_mm as diagonaled_mm_tvm, mask_invalid_locations
from sliding_chunks import sliding_chunks_matmul_qk, sliding_chunks_matmul_pv, sliding_chunks_no_overlap_matmul_qk, sliding_chunks_no_overlap_matmul_pv
from longformer import LongformerConfig,LongformerSelfAttention

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
config = LongformerConfig.from_pretrained("C:/Users/Pankaj Deb Roy/Documents/DeepLearning/LLM_large_seq_length/longformer-base/longformer-base-4096/longformer-base-4096")
tensor = torch.randn([1,1024*6,config.hidden_size])

In [4]:
longformer_attention = LongformerSelfAttention(config,layers_id=1)
start = time.time()
output = longformer_attention(tensor)
print(f"Time taken : {(time.time()-start):.2f}s")

Time taken : 0.58s


In [5]:
class Attention(nn.Module):
    def __init__(self,config,layer_id):
        super().__init__()
        assert config.hidden_size%config.num_attention_heads ==0,ValueError(f'Hidden Size {config.hidden_size} is not a multiple of number of attention heads {config.num_attention_heads}')
        self.num_heads = config.num_attention_heads
        self.head_dim = int(config.hidden_size / config.num_attention_heads)
        self.embed_dim = config.hidden_size

        self.wq = nn.Linear(self.embed_dim,self.embed_dim)
        self.wk = nn.Linear(self.embed_dim,self.embed_dim)
        self.wv = nn.Linear(self.embed_dim,self.embed_dim)
        self.wo = nn.Linear(self.embed_dim,self.embed_dim)
    
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        bsz,seq_len,embed_dims = hidden_states.shape
        q = self.wq(hidden_states).view(bsz,seq_len,self.num_heads,self.head_dim).transpose(1,2)
        k = self.wk(hidden_states).view(bsz,seq_len,self.num_heads,self.head_dim).transpose(1,2)
        v = self.wv(hidden_states).view(bsz,seq_len,self.num_heads,self.head_dim).transpose(1,2)

        attention_score = q@(k.transpose(-1,-2)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attention_score = attention_score.masked_fill(attention_mask == 0, -1e9)
        
        attention_weights = F.softmax(attention_score,dim=-1)

        output = (attention_weights @ v).transpose(1,2).reshape(bsz,seq_len,embed_dims)
        output = self.wo(output)
         
        print(attention_weights.shape)
        print(output.shape)

attention = Attention(config,0)
start = time.time()
output = attention(tensor)
print(f"Time taken : {(time.time()-start):.2f}s")

torch.Size([1, 12, 6144, 6144])
torch.Size([1, 6144, 768])
Time taken : 0.94s


Time taken : 0.59s


torch.Size([1, 12, 6144, 6144])
Time taken : 0.42s
