# Xaiqo Core Model Implementation

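This notebook defines the core Transformer model for the Xaiqo chatbot. It contains multi-head self-attention, pre-norm Transformer blocks, and the full `XaiqoModel`, which maps token ids to vocabulary logits.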

In [None]:
# Install dependencies
%pip install torch transformers

In [None]:
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Reads ``config.EMBED_DIM``, ``config.HEADS`` and ``config.DROPOUT_RATE``.
    Queries, keys and values come from one fused linear projection of the
    input. The per-head results are concatenated and passed through an
    output projection.

    Raises:
        ValueError: if ``config.EMBED_DIM`` is not divisible by ``config.HEADS``
            (otherwise the per-head reshape in ``forward`` would fail later
            with an opaque shape error).
    """

    def __init__(self, config):
        super().__init__()
        if config.EMBED_DIM % config.HEADS != 0:
            raise ValueError(
                f"EMBED_DIM ({config.EMBED_DIM}) must be divisible by "
                f"HEADS ({config.HEADS})"
            )
        self.num_heads = config.HEADS
        self.head_dim = config.EMBED_DIM // config.HEADS
        # Single fused projection producing Q, K and V side by side.
        self.qkv = nn.Linear(config.EMBED_DIM, 3 * config.EMBED_DIM)
        self.out = nn.Linear(config.EMBED_DIM, config.EMBED_DIM)
        # Applied to the attention weights (not to the projected output).
        self.dropout = nn.Dropout(config.DROPOUT_RATE)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: tensor of shape ``(B, L, D)`` where ``D`` equals
                ``config.EMBED_DIM``.
            mask: optional tensor broadcastable to ``(B, num_heads, L, L)``
                (query x key). Key positions where the mask equals 0 are
                excluded from attention.

        Returns:
            Tensor of shape ``(B, L, D)``.
        """
        B, L, D = x.shape
        # Split the fused projection and reshape each part to
        # (B, num_heads, L, head_dim).
        q, k, v = (
            t.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
            for t in self.qkv(x).chunk(3, dim=-1)
        )

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            # Use the dtype's most negative finite value instead of -inf:
            # softmax over a row that is entirely -inf is NaN, and a NaN row
            # would poison every position through attn @ v in later layers.
            # With a finite fill a fully masked row becomes uniform instead.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn = torch.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        # Merge heads back into (B, L, D).
        out = out.transpose(1, 2).contiguous().view(B, L, D)
        return self.out(out)

class TransformerBlock(nn.Module):
    """Pre-norm Transformer layer: self-attention followed by a position-wise MLP.

    Each sub-layer receives a LayerNorm-ed copy of its input, and its output
    is added back onto the un-normalised residual stream.

    Reads ``config.EMBED_DIM``, ``config.HIDDEN_SIZE`` and
    ``config.DROPOUT_RATE`` directly, and passes ``config`` on to
    ``MultiHeadAttention``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.EMBED_DIM
        self.attention = MultiHeadAttention(config)
        self.norm1 = nn.LayerNorm(width)
        self.norm2 = nn.LayerNorm(width)
        # Expand to HIDDEN_SIZE, apply GELU, project back, then dropout.
        mlp_layers = [
            nn.Linear(width, config.HIDDEN_SIZE),
            nn.GELU(),
            nn.Linear(config.HIDDEN_SIZE, width),
            nn.Dropout(config.DROPOUT_RATE),
        ]
        self.feed_forward = nn.Sequential(*mlp_layers)

    def forward(self, x, mask=None):
        """Run the attention and MLP sub-layers, each with a residual connection.

        ``mask`` is handed unchanged to the attention sub-layer.
        """
        residual = x + self.attention(self.norm1(x), mask)
        mlp_out = self.feed_forward(self.norm2(residual))
        return residual + mlp_out

class XaiqoModel(nn.Module):
    """Transformer model mapping token ids to per-position vocabulary logits.

    Architecture: token embedding + learned absolute position embedding,
    dropout, a stack of ``TransformerBlock``s, a final LayerNorm, and a
    bias-free linear projection onto the vocabulary.

    Reads ``config.BPE_VOCAB_SIZE``, ``config.EMBED_DIM``,
    ``config.MAX_SEQ_LENGTH``, ``config.DROPOUT_RATE`` and
    ``config.NUM_BLOCKS``, and passes ``config`` to each ``TransformerBlock``.
    """

    def __init__(self, config):
        super().__init__()
        self.token_embedding = nn.Embedding(config.BPE_VOCAB_SIZE, config.EMBED_DIM)
        # One learned vector per position, zero-initialised; this fixes the
        # maximum sequence length the model accepts.
        self.position_embedding = nn.Parameter(torch.zeros(1, config.MAX_SEQ_LENGTH, config.EMBED_DIM))
        self.dropout = nn.Dropout(config.DROPOUT_RATE)

        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.NUM_BLOCKS)
        ])

        self.norm = nn.LayerNorm(config.EMBED_DIM)
        self.head = nn.Linear(config.EMBED_DIM, config.BPE_VOCAB_SIZE, bias=False)

    def forward(self, input_ids, attention_mask=None):
        """Compute vocabulary logits for every position of ``input_ids``.

        Args:
            input_ids: integer tensor of shape ``(B, L)`` with
                ``L <= MAX_SEQ_LENGTH``.
            attention_mask: optional mask forwarded unchanged to every block.
                No causal mask is built here. NOTE(review): for next-token
                training the caller must supply one (e.g. ``torch.tril``);
                confirm against the training code. A ``(B, L)`` padding mask
                likely needs reshaping to ``(B, 1, 1, L)`` so that it
                broadcasts over the ``(B, heads, L, L)`` attention scores.

        Returns:
            Logits of shape ``(B, L, BPE_VOCAB_SIZE)``.

        Raises:
            ValueError: if ``L`` exceeds the number of learned positions.
        """
        B, L = input_ids.shape
        max_len = self.position_embedding.size(1)
        if L > max_len:
            # Without this check the slice below silently truncates and the
            # addition fails with an opaque broadcasting error.
            raise ValueError(
                f"sequence length {L} exceeds MAX_SEQ_LENGTH ({max_len})"
            )

        x = self.token_embedding(input_ids)
        x = x + self.position_embedding[:, :L, :]
        x = self.dropout(x)

        for block in self.transformer_blocks:
            x = block(x, attention_mask)

        x = self.norm(x)
        logits = self.head(x)

        return logits