In [31]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
# Pick the accelerator once for the whole notebook. Apple's Metal backend
# ("mps") stays the first choice; CUDA and CPU are fallbacks so the later
# `.to(device)` calls do not fail on machines without MPS support.
_mps_backend = getattr(torch.backends, "mps", None)  # absent on old torch builds
if _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [32]:
def calculate_masked_attention(
    values: torch.Tensor,
    keys: torch.Tensor,
    query: torch.Tensor,
    mask: torch.Tensor = None,
):
    """Scaled dot-product attention with an optional mask.

    Args:
        values: Tensor of shape ``(..., L_k, d_v)``.
        keys: Tensor of shape ``(..., L_k, d)``; ``d`` sets the scaling factor.
        query: Tensor of shape ``(..., L_q, d)``.
        mask: Optional tensor broadcastable against the ``(..., L_q, L_k)``
            score matrix. Positions where ``mask == 0`` are excluded from the
            softmax. ``None`` disables masking.

    Returns:
        A tuple ``(attention, attention_scores)``: the weighted sum of
        ``values`` and the softmax-normalised weights used to compute it.
    """
    scale = math.sqrt(keys.shape[-1])
    attention_scores = torch.matmul(query, keys.transpose(-2, -1)) / scale

    if mask is not None:
        # Use the most negative finite value of the score dtype rather than a
        # literal -1e9: -1e9 lies outside float16's representable range.
        # A 0-dim tensor broadcasts inside torch.where, so no full-size
        # constant tensor has to be allocated.
        masked_value = attention_scores.new_tensor(
            torch.finfo(attention_scores.dtype).min
        )
        attention_scores = torch.where(mask == 0, masked_value, attention_scores)

    attention_scores = F.softmax(attention_scores, dim=-1)
    attention = torch.matmul(attention_scores, values)
    return attention, attention_scores

In [33]:
class FeedForward(nn.Module):
    """Position-wise block: Linear -> GELU -> Linear, all of width ``embed_size``."""

    def __init__(self, embed_size: int):
        super().__init__()
        self.layer1 = nn.Linear(embed_size, embed_size)
        self.layer2 = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        """Apply both linear layers to ``x`` with a GELU between them."""
        hidden = F.gelu(self.layer1(x))
        return self.layer2(hidden)
class AttentionLayer(nn.Module):
    """Single-head causal self-attention over a sequence of embeddings.

    Query, key and value are separate linear projections of the input. Each
    position may only attend to itself and earlier positions. The attended
    values are passed through ``output_dense`` before being returned.
    """

    def __init__(self, embed_size: int):
        super().__init__()
        self.embed_size = embed_size
        self.query_dense = nn.Linear(embed_size, embed_size)
        self.key_dense = nn.Linear(embed_size, embed_size)
        self.value_dense = nn.Linear(embed_size, embed_size)
        self.output_dense = nn.Linear(embed_size, embed_size)

    def forward(self, embeddings: torch.Tensor):
        """Run causal self-attention.

        Args:
            embeddings: Input tensor; dimension 1 is taken as the sequence
                length and the last dimension must equal ``embed_size``.

        Returns:
            ``(attention, attention_scores)``: the output-projected attended
            values and the softmax attention weights.
        """
        seq_length = embeddings.shape[1]

        query = self.query_dense(embeddings)
        key = self.key_dense(embeddings)
        value = self.value_dense(embeddings)

        # Lower-triangular ones: row i keeps columns 0..i, hiding later
        # positions. Built directly on the input's device to avoid a
        # host-to-device copy; the leading 1 broadcasts over the batch.
        causal_mask = torch.tril(
            torch.ones(seq_length, seq_length, device=embeddings.device)
        ).unsqueeze(0)

        context, attention_scores = calculate_masked_attention(
            value, key, query, causal_mask
        )

        # Apply the output projection, which was registered as a parameter
        # but previously never used in the forward pass.
        attention = self.output_dense(context)
        return attention, attention_scores



In [34]:
class TransformerBlock(nn.Module):
    """Attention -> LayerNorm -> FeedForward -> GELU, with a residual from the input."""

    def __init__(self, embed_size: int):
        super().__init__()
        self.attention_layer = AttentionLayer(embed_size)
        self.feed_forward = FeedForward(embed_size)
        self.layer_norm1 = nn.LayerNorm(embed_size)

    def forward(self, x: torch.Tensor):
        """Return ``(output, attention_scores)`` for the input ``x``."""
        attended, attention_scores = self.attention_layer(x)
        normalized = self.layer_norm1(attended)
        activated = F.gelu(self.feed_forward(normalized))
        # Residual connection: the block input is added to the transformed signal.
        return x + activated, attention_scores
class Transformer(nn.Module):
    """A stack of ``num_layers`` ``TransformerBlock`` modules applied in sequence."""

    def __init__(self, embed_size: int, num_layers: int):
        # These are type annotations, not `=int` defaults: a default of the
        # `int` type itself would only fail later inside the layer
        # constructors with a confusing error.
        super().__init__()
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(embed_size) for _ in range(num_layers)]
        )

    def forward(self, x: torch.Tensor):
        """Feed ``x`` through every block in order.

        Returns:
            ``(x, attention_scores)``: the output of the last block (or the
            unchanged input when there are no blocks) and a list holding each
            block's attention weights, in layer order.
        """
        attention_scores = []
        for transformer_block in self.transformer_blocks:
            x, block_scores = transformer_block(x)
            attention_scores.append(block_scores)
        return x, attention_scores


In [35]:

class sinusodialPositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to token embeddings.

    Even feature columns hold ``sin(pos * freq_i)`` and odd columns hold
    ``cos(pos * freq_i)``, with ``freq_i = 10000 ** (-2i / embed_size)``.
    The table is stored in the non-trainable buffer ``positional_embedding``
    of shape ``(max_seq_length, embed_size)``.
    """

    def __init__(self, embed_size: int, max_seq_length: int):
        super().__init__()
        position = torch.arange(max_seq_length).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_size, 2) * (-math.log(10000.0) / embed_size)
        )
        angles = position * div_term  # (max_seq_length, ceil(embed_size / 2))

        pe = torch.zeros(max_seq_length, embed_size)
        pe[:, 0::2] = torch.sin(angles)
        # For odd embed_size there is one fewer odd column than even column,
        # so drop the last frequency for the cosine half. For even sizes the
        # slice covers all of `angles`.
        pe[:, 1::2] = torch.cos(angles[:, : embed_size // 2])
        self.register_buffer('positional_embedding', pe)

    def forward(self, x: torch.Tensor):
        """Add the positional table to ``x``.

        Args:
            x: Tensor whose dimension 1 is the sequence length and whose last
                dimension is ``embed_size``.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        seq_length = x.size(1)
        max_seq_length = self.positional_embedding.size(0)
        if seq_length > max_seq_length:
            # Without this check the truncated slice fails with an opaque
            # broadcasting error, or broadcasts silently when the table has
            # a single row.
            raise ValueError(
                f"sequence length {seq_length} exceeds max_seq_length {max_seq_length}"
            )
        return x + self.positional_embedding[:seq_length, :].unsqueeze(0)

class CasualLanguageModel(nn.Module):
    """Causal language model with tied input/output embeddings.

    Token ids are looked up in ``embedding_layer``, combined with sinusoidal
    positional encodings, passed through the transformer stack, and projected
    back onto the vocabulary with the transpose of the same embedding matrix.

    Args:
        embed_size: Width of the token embeddings and of every layer.
        vocab_size: Number of distinct token ids.
        num_layers: Number of transformer blocks.
        max_seq_length: Longest sequence the positional table covers.
            Defaults to 20, the value that was previously hard-coded.
    """

    def __init__(
        self,
        embed_size: int,
        vocab_size: int,
        num_layers: int,
        max_seq_length: int = 20,
    ):
        super().__init__()
        # One (vocab_size, embed_size) matrix serves as both the input lookup
        # table and the output projection.
        self.embedding_layer = nn.Parameter(torch.randn(vocab_size, embed_size))
        self.transformer = Transformer(embed_size, num_layers)
        self.positional_encoding = sinusodialPositionalEncoding(
            embed_size, max_seq_length=max_seq_length
        )

    def forward(self, x: torch.Tensor, return_attention_scores: bool = False):
        """Compute next-token logits for a tensor of token ids.

        Args:
            x: Integer token ids; dimension 1 is the sequence length.
            return_attention_scores: When true, also return the list of
                per-layer attention weights.

        Returns:
            ``logits``, or ``(logits, attention_scores)`` when
            ``return_attention_scores`` is true. No softmax is applied.
        """
        x = F.embedding(x, self.embedding_layer)
        x = self.positional_encoding(x)
        x, attention_scores = self.transformer(x)
        # Weight tying: score each hidden state against every token embedding.
        logits = torch.matmul(x, self.embedding_layer.T)
        if return_attention_scores:
            return logits, attention_scores
        return logits

In [36]:
# Toy batch: a single sequence of five token ids drawn uniformly from [0, 5).
int_tokens = torch.randint(low=0, high=5, size=(1, 5), dtype=torch.long)
int_tokens = int_tokens.to(device)


In [37]:
# Build a tiny two-layer model, move it to the selected device, and run a
# single forward pass that returns logits only.
next_token_prediction = CasualLanguageModel(embed_size=4, vocab_size=5, num_layers=2)
next_token_prediction = next_token_prediction.to(device)
logits = next_token_prediction(int_tokens, return_attention_scores=False)


In [38]:
# Display the raw (no softmax applied) logits produced by the forward pass above.
logits


tensor([[[-2.2827,  1.6855,  0.5489,  0.7084,  1.2900],
         [-6.0782,  3.0891, -1.6240, -0.2409,  5.0601],
         [ 6.1015, -1.2128,  1.3850, -1.4173, -4.7350],
         [ 6.9063, -0.5099,  1.4042,  0.4223, -5.2702],
         [ 0.2946,  1.6437,  1.8448,  3.7889, -1.1823]]], device='mps:0',
       grad_fn=<UnsafeViewBackward0>)