In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as f
import math

In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (as in "Attention Is All You Need").

    Precomputes a ``(1, max_length, d_model)`` table whose even feature columns
    hold ``sin(pos / base**(2i/d_model))`` and whose odd columns hold the matching
    ``cos``. ``forward`` adds the first ``seq_len`` rows of that table to its input.

    Args:
        d_model: feature size of the inputs this module is applied to.
        max_length: number of positions to precompute (longest supported sequence).
        max_len: keyword-only alias for ``max_length`` (the spelling used by
            ``Encoder``/``Decoder``); takes precedence when given.
        base: wavelength base of the sinusoids (10000 in the original Transformer).
    """

    def __init__(self, d_model, max_length=5000, *, max_len=None, base=10000.0):
        super().__init__()
        if max_len is not None:
            max_length = max_len
        pe = torch.zeros(max_length, d_model)
        position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)  # (max_length, 1)
        # 1 / base**(2i/d_model) for every even column index 2i, computed in log space.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(base) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # An odd d_model has one fewer odd column than even column, so trim div_term.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Registered as a buffer: saved in state_dict and moved by .to(), but not trained.
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_length, d_model)

    def forward(self, x):
        """Return ``x`` plus the encodings for positions ``0 .. x.size(1) - 1``.

        Raises:
            ValueError: if ``x.size(1)`` exceeds the precomputed ``max_length``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_length {self.pe.size(1)}"
            )
        # Slice (not index) so each position gets its own encoding row.
        return x + self.pe[:, :seq_len].to(x.device)



In [None]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension, delegating to ``nn.LayerNorm``.

    Args:
        dim: size of the normalized (last) dimension.
        eps: value added to the variance for numerical stability.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        # Stored as ``ln`` so the state_dict keys are "ln.weight" / "ln.bias".
        norm_layer = nn.LayerNorm(dim, eps=eps)
        self.ln = norm_layer

    def forward(self, x):
        normalized = self.ln(x)
        return normalized

In [None]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear.

    Args:
        d_model: input and output feature size.
        d_ff: hidden (expansion) feature size.
        dropout: dropout probability applied after the activation.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = f.gelu

    def forward(self, x):
        """Apply the network to the last dimension of ``x`` (size ``d_model``)."""
        x = self.fc1(x)
        x = self.activation(x)
        x = self.dropout(x)
        return self.fc2(x)

    # The method was originally misspelled ``forwad``, so nn.Module.__call__ never
    # found a ``forward`` and raised NotImplementedError. The old name is kept as an
    # alias for any code that called it directly.
    forwad = forward

In [None]:
# Attention module
def scaled_dot_product_attention(q, k, v, mask=None, dropout=None):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d_k)) @ v``.

    Args:
        q: queries, shape (..., seq_q, d_k).
        k: keys, shape (..., seq_k, d_k).
        v: values, shape (..., seq_k, d_v).
        mask: optional tensor broadcastable to (..., seq_q, seq_k); positions
            where ``mask == 0`` (or ``False``) receive no attention.
        dropout: optional callable (e.g. an ``nn.Dropout``) applied to the
            attention weights.

    Returns:
        ``(output, attn)`` with shapes (..., seq_q, d_v) and (..., seq_q, seq_k).
        ``attn`` is returned after dropout has been applied.
    """
    # Previously an nn.Module whose __init__ returned a value (always a TypeError);
    # every caller uses it as a function, so it is one now.
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # The dtype's most negative finite value: unlike -inf, a fully masked row
        # gives uniform weights instead of NaN; unlike a fixed -1e9 it cannot
        # overflow low-precision dtypes such as float16.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    attn = torch.softmax(scores, dim=-1)
    if dropout is not None:
        attn = dropout(attn)
    output = torch.matmul(attn, v)
    return output, attn


In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with learned Q/K/V/output projections.

    Args:
        d_model: model feature size; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_heads = d_model // n_heads  # feature size of each head

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x, B):
        """Reshape (B, seq, d_model) -> (B, n_heads, seq, d_heads)."""
        return x.view(B, -1, self.n_heads, self.d_heads).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: (B, seq_q, d_model).
            key, value: (B, seq_k, d_model).
            mask: optional 3-D mask (gets a head axis added) or a tensor already
                broadcastable to (B, n_heads, seq_q, seq_k); zero/False entries
                are excluded from attention.

        Returns:
            ``(output, attn_weights)``: output is (B, seq_q, d_model); the weights
            are (B, n_heads, seq_q, seq_k), taken after dropout.
        """
        B = query.size(0)
        # Previously the query was reshaped with n_heads in place of the head size,
        # and k/v read a non-existent ``self.d_head`` attribute.
        q = self._split_heads(self.w_q(query), B)
        k = self._split_heads(self.w_k(key), B)
        v = self._split_heads(self.w_v(value), B)

        if mask is not None and mask.dim() == 3:
            mask = mask.unsqueeze(1)  # add a head axis so it broadcasts over heads

        # Scaled dot-product attention is computed inline so this class does not
        # depend on any other helper in the notebook.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_heads)
        if mask is not None:
            # Most negative finite value: no NaN on fully masked rows, no fp16 overflow.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn_weights = self.dropout(torch.softmax(scores, dim=-1))
        attn_output = torch.matmul(attn_weights, v)  # (B, n_heads, seq_q, d_heads)

        # Merge heads back: (B, seq_q, n_heads * d_heads) == (B, seq_q, d_model).
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, -1, self.d_model)
        return self.w_o(attn_output), attn_weights

    # The method was originally misspelled ``forwad`` (so calling the module raised
    # NotImplementedError); the old name is kept as an alias.
    forwad = forward



In [None]:
class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    Sub-layers, in order: self-attention (Q = K = V = input) and a feed-forward
    network. Each output goes through dropout, is added back to its input
    (residual), and is then layer-normalized.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        # Creation order is kept as-is so parameter initialization order is unchanged.
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout=dropout)
        self.layernorm1 = LayerNorm(d_model)
        self.ff = FeedForward(d_model, d_ff, dropout=dropout)
        self.layernorm2 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)  # shared by both residual branches

    def forward(self, x, src_mask=None):
        # x: (B, seq_len, d_model), per the original author's note.
        attended, _ = self.self_attn(x, x, x, mask=src_mask)
        hidden = self.layernorm1(x + self.dropout(attended))
        transformed = self.ff(hidden)
        return self.layernorm2(hidden + self.dropout(transformed))

In [None]:
class DecoderLayer(nn.Module):
    """One post-norm Transformer decoder layer.

    Sub-layers, in order: masked self-attention over the target, cross-attention
    from the target onto the encoder output, and a feed-forward network. Each
    output goes through dropout, a residual add, and its own layer normalization.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        # Creation order is kept as-is so parameter initialization order is unchanged.
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout=dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout=dropout)
        self.layernorm1 = LayerNorm(d_model)
        self.layernorm2 = LayerNorm(d_model)
        self.ff = FeedForward(d_model, d_ff, dropout=dropout)
        self.layernorm3 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, residual, sublayer_out, norm):
        """Dropout on the sub-layer output, add the residual, then apply ``norm``."""
        return norm(residual + self.dropout(sublayer_out))

    def forward(self, x, enc_out, tgt_mask=None, memory_mask=None):
        # x: (B, tgt_seq, d_model); enc_out: (B, src_seq, d_model), per the original author's note.
        self_out = self.self_attn(x, x, x, mask=tgt_mask)[0]
        hidden = self._add_and_norm(x, self_out, self.layernorm1)

        cross_out = self.cross_attn(hidden, enc_out, enc_out, mask=memory_mask)[0]
        hidden = self._add_and_norm(hidden, cross_out, self.layernorm2)

        return self._add_and_norm(hidden, self.ff(hidden), self.layernorm3)


In [None]:
# -------------------------- Full Encoder & Decoder stacks --------------------------

class Encoder(nn.Module):
    """Transformer encoder stack: token embedding + positional encoding + N EncoderLayers.

    Args:
        vocab_size: number of token ids the embedding table covers.
        d_model: model feature size.
        n_heads: attention heads per layer.
        num_layers: number of stacked EncoderLayers.
        d_ff: hidden size of each feed-forward network.
        max_len: longest source sequence the positional encoding supports.
        dropout: dropout probability (embeddings and inside every layer).
    """

    def __init__(self, vocab_size, d_model=512, n_heads=8, num_layers=6, d_ff=2048, max_len=1024, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        # PositionalEncoding names its length argument ``max_length``; the previous
        # ``max_len=`` keyword raised TypeError at construction time.
        self.pos_embedding = PositionalEncoding(d_model, max_length=max_len)
        self.layers = nn.ModuleList(
            [EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout)  # applied to the summed embeddings

    def forward(self, src_tokens, src_mask=None):
        """Encode ``src_tokens`` (B, src_seq) into memory of shape (B, src_seq, d_model).

        ``src_mask`` is forwarded unchanged to every layer's self-attention.
        """
        x = self.token_embedding(src_tokens)
        x = self.pos_embedding(x)
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x, src_mask=src_mask)
        return x


# Decoder
class Decoder(nn.Module):
    """Transformer decoder stack: token embedding + positional encoding + N DecoderLayers.

    Args:
        vocab_size: number of token ids the embedding table covers.
        d_model: model feature size.
        n_heads: attention heads per layer.
        num_layers: number of stacked DecoderLayers.
        d_ff: hidden size of each feed-forward network.
        max_len: longest target sequence the positional encoding supports.
        dropout: dropout probability (embeddings and inside every layer).
    """

    def __init__(self, vocab_size, d_model=512, n_heads=8, num_layers=6, d_ff=2048, max_len=1024, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        # PositionalEncoding names its length argument ``max_length``; the previous
        # ``max_len=`` keyword raised TypeError at construction time.
        self.pos_embedding = PositionalEncoding(d_model, max_length=max_len)
        self.layers = nn.ModuleList(
            [DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout)  # applied to the summed embeddings

    def forward(self, tgt_tokens, enc_out, tgt_mask=None, memory_mask=None):
        """Decode ``tgt_tokens`` (B, tgt_seq) against encoder output ``enc_out``.

        ``tgt_mask`` drives each layer's self-attention and ``memory_mask`` its
        cross-attention. Returns hidden states of shape (B, tgt_seq, d_model).
        """
        x = self.token_embedding(tgt_tokens)
        x = self.pos_embedding(x)
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x, enc_out, tgt_mask=tgt_mask, memory_mask=memory_mask)
        return x

In [None]:
class SimpleBART(nn.Module):
    """Minimal BART-style encoder-decoder Transformer that outputs vocabulary logits.

    Args:
        vocab_size: vocabulary size used by both embedding tables and the output head.
        d_model, n_heads, d_ff, dropout: hyper-parameters shared by both stacks.
        num_encoder_layers: depth of the encoder stack.
        num_decoder_layers: depth of the decoder stack.
        max_len: maximum sequence length passed to both stacks.
    """

    def __init__(self, vocab_size, d_model=512, n_heads=8,
                 num_encoder_layers=6, num_decoder_layers=6,
                 d_ff=2048, max_len=1024, dropout=0.1):
        super().__init__()
        shared = dict(d_model=d_model, n_heads=n_heads, d_ff=d_ff,
                      max_len=max_len, dropout=dropout)
        # Registration order (encoder, decoder, lm_head) determines the order in
        # which _reset_parameters draws its random values.
        self.encoder = Encoder(vocab_size, num_layers=num_encoder_layers, **shared)
        self.decoder = Decoder(vocab_size, num_layers=num_decoder_layers, **shared)
        # Output head: decoder hidden states -> vocabulary logits, no bias.
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform init for every parameter with more than one dimension."""
        matrices = (param for param in self.parameters() if param.dim() > 1)
        for weight in matrices:
            nn.init.xavier_uniform_(weight)

    def _make_src_mask(self, src_tokens, pad_token_id=1):
        """
        Build the encoder attention mask.
        src_tokens: (B, src_seq)
        Returns: bool mask (B, 1, 1, src_seq), True where the token is not padding.
        """
        return (src_tokens != pad_token_id)[:, None, None, :]

    def _make_tgt_mask(self, tgt_tokens, pad_token_id=1):
        """
        Build the decoder self-attention mask (B, 1, tgt_seq, tgt_seq).
        A position may attend to key j only if j is not padding AND j is not in
        the future (causal, lower-triangular).
        """
        length = tgt_tokens.size(1)
        not_pad = (tgt_tokens != pad_token_id)[:, None, None, :]  # (B, 1, 1, T)
        causal = torch.ones((length, length), device=tgt_tokens.device).tril().bool()
        return not_pad & causal[None, None]  # broadcast to (B, 1, T, T)

    def forward(self, src_tokens, tgt_tokens, src_pad_token_id=1, tgt_pad_token_id=1):
        """
        Run encoder and decoder and project to the vocabulary.
        src_tokens: (B, src_seq)
        tgt_tokens: (B, tgt_seq)
        Returns: logits (B, tgt_seq, vocab_size)
        """
        src_mask = self._make_src_mask(src_tokens, pad_token_id=src_pad_token_id)
        tgt_mask = self._make_tgt_mask(tgt_tokens, pad_token_id=tgt_pad_token_id)

        memory = self.encoder(src_tokens, src_mask=src_mask)
        # Cross-attention reuses the source padding mask.
        hidden = self.decoder(tgt_tokens, memory,
                              tgt_mask=tgt_mask, memory_mask=src_mask)
        return self.lm_head(hidden)


In [None]:
def greedy_decode(model, src_tokens, src_pad_token_id=1,
                  max_len=50, start_token_id=2, end_token_id=3,
                  tgt_pad_token_id=1):
    """
    Simple greedy decoding loop.

    Args:
        model: SimpleBART-like model exposing ``_make_src_mask``, ``_make_tgt_mask``,
            ``encoder``, ``decoder`` and ``lm_head``.
        src_tokens: (B, src_seq) input source token IDs
        src_pad_token_id: pad token ID for source
        max_len: maximum length of the generated sequence, start token included
        start_token_id: ID for <BOS> / start token
        end_token_id: ID for <EOS> / end token
        tgt_pad_token_id: pad token ID for the target side. Used for the decoder
            padding mask (same default as ``SimpleBART.forward``) and to fill
            positions after a row has produced ``end_token_id``.

    Returns:
        ys: (B, generated_seq_len) tensor of token IDs, starting with
        ``start_token_id``; rows that finished early are right-padded with
        ``tgt_pad_token_id``.

    The model is switched to eval mode while decoding and restored to its
    previous train/eval mode afterwards.
    """
    was_training = model.training
    model.eval()  # disable dropout while decoding
    try:
        B = src_tokens.size(0)
        device = src_tokens.device
        with torch.no_grad():
            # Encode the source once; the memory is reused at every step.
            src_mask = model._make_src_mask(src_tokens, pad_token_id=src_pad_token_id)
            enc_out = model.encoder(src_tokens, src_mask=src_mask)

            ys = torch.full((B, 1), start_token_id, dtype=torch.long, device=device)
            # Per-row flag: has this row already emitted end_token_id?
            finished = torch.zeros((B, 1), dtype=torch.bool, device=device)

            for _ in range(max_len - 1):
                # Use the same target pad id as training (previously hard-coded 0).
                tgt_mask = model._make_tgt_mask(ys, pad_token_id=tgt_pad_token_id)
                dec_out = model.decoder(ys, enc_out,
                                        tgt_mask=tgt_mask,
                                        memory_mask=src_mask)
                # Only the last position is needed for the next token.
                logits = model.lm_head(dec_out[:, -1, :])  # (B, vocab_size)
                next_tok = logits.argmax(dim=-1, keepdim=True)  # (B, 1)

                # Rows that already ended get padding instead of new tokens.
                next_tok = next_tok.masked_fill(finished, tgt_pad_token_id)
                ys = torch.cat([ys, next_tok], dim=1)

                # Stop once every row has emitted EOS at some step, not just this one.
                finished = finished | (next_tok == end_token_id)
                if finished.all():
                    break
    finally:
        model.train(was_training)

    return ys  # (B, generated_seq_len)
