In [1]:
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
# Pick the best available accelerator instead of hardcoding "mps", so the
# notebook still runs (on CUDA or CPU) on machines without Apple's MPS backend.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
def calculate_masked_attention(
    values: torch.Tensor,
    keys: torch.Tensor,
    query: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
):
    """Scaled dot-product attention with an optional mask.

    Note the argument order: ``values`` first, ``query`` last.

    Args:
        values: (..., L_k, d_v) tensor mixed by the attention weights.
        keys: (..., L_k, d_k) tensor compared against ``query``.
        query: (..., L_q, d_k) tensor.
        mask: optional tensor broadcastable against the (..., L_q, L_k) score
            tensor; positions where ``mask == 0`` receive (near-)zero weight.

    Returns:
        ``(attention, attention_scores)`` -- the weighted sum of ``values``
        with shape (..., L_q, d_v), and the softmax weights (..., L_q, L_k).
    """
    attention_scores = torch.matmul(query, keys.transpose(-2, -1)) / math.sqrt(keys.shape[-1])
    if mask is not None:
        # -1e9 acts as "minus infinity" for float32/float64, but it overflows to
        # -inf in float16, which makes a fully masked row softmax to NaN.
        # Clamp the fill value to the dtype's most negative finite number.
        fill_value = max(-1e9, torch.finfo(attention_scores.dtype).min)
        attention_scores = attention_scores.masked_fill(mask == 0, fill_value)
    attention_scores = F.softmax(attention_scores, dim=-1)
    attention = torch.matmul(attention_scores, values)
    return attention, attention_scores

In [3]:
class FeedForward(nn.Module):
    """Two-layer MLP with a GELU non-linearity, applied over the last axis.

    Args:
        embed_size: Input and output feature size.
        hidden_size: Width of the intermediate layer. Defaults to
            ``embed_size`` (the previous fixed behaviour); transformer
            feed-forward layers commonly use a wider value such as
            ``4 * embed_size``.
    """

    def __init__(self, embed_size: int, hidden_size: Optional[int] = None):
        super().__init__()
        if hidden_size is None:
            hidden_size = embed_size
        self.layer1 = nn.Linear(embed_size, hidden_size)
        self.layer2 = nn.Linear(hidden_size, embed_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` (last dim ``embed_size``) to a tensor of the same shape."""
        return self.layer2(F.gelu(self.layer1(x)))
class AttentionLayer(nn.Module):
    """Single-head causal self-attention.

    Projects the input into queries, keys and values, runs scaled dot-product
    attention under a lower-triangular mask (each position attends only to
    itself and earlier positions), then applies the output projection
    ``output_dense``.

    Args:
        embed_size: Feature size of the input and of every projection.
    """

    def __init__(self, embed_size: int):
        super().__init__()
        self.embed_size = embed_size
        self.query_dense = nn.Linear(embed_size, embed_size)
        self.key_dense = nn.Linear(embed_size, embed_size)
        self.value_dense = nn.Linear(embed_size, embed_size)
        self.output_dense = nn.Linear(embed_size, embed_size)

    def forward(self, embeddings: torch.Tensor):
        """Apply causal self-attention.

        Args:
            embeddings: presumably (batch, seq_len, embed_size); the
                second-to-last axis is treated as the sequence axis.

        Returns:
            ``(output, attention_scores)`` -- ``output`` has the same shape as
            ``embeddings``; ``attention_scores`` holds the softmax weights
            with trailing dims (seq_len, seq_len).
        """
        seq_length = embeddings.shape[-2]
        query = self.query_dense(embeddings)
        key = self.key_dense(embeddings)
        value = self.value_dense(embeddings)
        # Lower-triangular (causal) mask, built directly on the input's device;
        # a 2-D mask broadcasts over any leading batch dimensions.
        causal_mask = torch.tril(
            torch.ones(seq_length, seq_length, device=embeddings.device)
        )
        attention, attention_scores = calculate_masked_attention(value, key, query, causal_mask)
        # Final output projection (previously constructed but never applied).
        return self.output_dense(attention), attention_scores



In [4]:
class TransformerBlock(nn.Module):
    """One transformer layer.

    Data flow: causal attention -> LayerNorm -> feed-forward -> GELU, with a
    single residual connection adding the block input back at the end.

    NOTE(review): this differs from the usual pre-/post-LN layout (a residual
    around attention and another around the feed-forward, no activation after
    the FFN) -- confirm the ordering is intentional.

    Args:
        embed_size: Feature size shared by every sub-layer.
    """

    def __init__(self, embed_size: int):
        super().__init__()
        # Sub-modules are registered in this order so parameter order and
        # initialisation match the existing layout.
        self.attention_layer = AttentionLayer(embed_size)
        self.feed_forward = FeedForward(embed_size)
        self.layer_norm1 = nn.LayerNorm(embed_size)

    def forward(self, x: torch.Tensor):
        """Return ``(gelu(ffn(norm(attn(x)))) + x, attention_scores)``."""
        attended, scores = self.attention_layer(x)
        normalised = self.layer_norm1(attended)
        activated = F.gelu(self.feed_forward(normalised))
        return activated + x, scores
class Transformer(nn.Module):
    """A stack of ``num_layers`` TransformerBlocks applied in sequence.

    Args:
        embed_size: Feature size passed to every block.
        num_layers: Number of stacked blocks.
    """

    # Annotations, not defaults: the previous ``embed_size=int`` made the
    # *type* ``int`` the default value, so omitting an argument failed later
    # inside ``range()`` / ``nn.Linear`` with a confusing error.
    def __init__(self, embed_size: int, num_layers: int):
        super().__init__()
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(embed_size) for _ in range(num_layers)]
        )

    def forward(self, x: torch.Tensor):
        """Run every block in order.

        Returns:
            ``(output, attention_scores)`` where ``attention_scores`` is a list
            with one entry per block, in block order.
        """
        attention_scores = []
        for block in self.transformer_blocks:
            x, block_scores = block(x)
            attention_scores.append(block_scores)
        return x, attention_scores


In [None]:
class sinusodialPositionalEncoding(nn.Module):
    def __init__(self,embed_size:int,max_sequence_length:int):
        super().__init__()
        position=torch