In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchtext.datasets import IMDB
from collections import Counter
import re



In [2]:
def tokenize(text):
    """Lowercase ``text`` and return its word tokens as a list of strings.

    A token is a maximal run of regex word characters; everything else
    (punctuation, whitespace) only acts as a separator.
    """
    lowered = text.lower()
    return [match.group(0) for match in re.finditer(r'\b\w+\b', lowered)]

def build_vocab(data, max_size=10000):
    """Build a token -> id mapping from the most frequent tokens in ``data``.

    Args:
        data: iterable of ``(label, text)`` pairs; only the text is used.
        max_size: upper bound on the vocabulary size, including the two
            reserved entries.

    Returns:
        dict with ``'<PAD>'`` at id 0, ``'<UNK>'`` at id 1, and the remaining
        tokens numbered from 2 in the order returned by
        ``Counter.most_common``.
    """
    frequencies = Counter()
    for _, text in data:
        frequencies.update(tokenize(text))

    vocab = {'<PAD>': 0, '<UNK>': 1}
    # Two slots are already taken by the reserved tokens above.
    ranked = frequencies.most_common(max_size - 2)
    for index, (token, _count) in enumerate(ranked, start=2):
        vocab[token] = index
    return vocab

class IMDBDataset(Dataset):
    """In-memory dataset of fixed-length token-id sequences with binary labels.

    Every item is a pair ``(ids, label)``:
      * ``ids``   -- int64 tensor of length ``max_len``; the tokenized text is
                     truncated to ``max_len`` and right-padded with ``PAD_ID``.
      * ``label`` -- scalar int64 tensor, 1 for a positive review, 0 otherwise.

    Args:
        data: iterable of ``(label, text)`` pairs.
        vocab: mapping with a ``.get(token, default)`` method, e.g. the dict
            returned by ``build_vocab``.
        max_len: length every id sequence is truncated / padded to.
    """

    PAD_ID = 0  # id of '<PAD>' in the vocabulary built by build_vocab
    UNK_ID = 1  # id of '<UNK>'; used for tokens missing from the vocabulary

    # Labels treated as "positive". 'pos' is the string form; 2 is the integer
    # form -- recent torchtext releases document IMDB labels as 1 (neg) / 2 (pos),
    # and comparing those against 'pos' alone would silently label everything 0.
    # NOTE(review): confirm against the torchtext version actually installed.
    POSITIVE_LABELS = ('pos', 2)

    def __init__(self, data, vocab, max_len=200):
        self.data = []
        for label, text in data:
            tokens = tokenize(text)
            ids = [vocab.get(token, self.UNK_ID) for token in tokens[:max_len]]
            ids.extend([self.PAD_ID] * (max_len - len(ids)))
            target = 1 if label in self.POSITIVE_LABELS else 0
            # Explicit dtype so the ids stay int64 even if the list is empty.
            self.data.append((
                torch.tensor(ids, dtype=torch.long),
                torch.tensor(target, dtype=torch.long),
            ))

    def __len__(self):
        """Number of encoded examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(ids, label)`` tensor pair stored at position ``idx``."""
        return self.data[idx]

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Positional Encoding (reuse)
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a batch of embeddings.

    Even feature columns hold ``sin(pos * freq)`` and odd columns hold
    ``cos(pos * freq)``. The table is precomputed once and stored as the
    non-trainable buffer ``pe`` of shape ``(1, max_len, embed_dim)``.

    Args:
        embed_dim: size of the feature dimension (odd sizes are supported).
        max_len: number of positions to precompute.
    """

    def __init__(self, embed_dim, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-torch.log(torch.tensor(10000.0)) / embed_dim))
        angles = position * div_term  # (max_len, ceil(embed_dim / 2))
        pe[:, 0::2] = torch.sin(angles)
        # With an odd embed_dim there is one fewer odd column than even ones,
        # so drop the last frequency; for even embed_dim the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, :embed_dim // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, D)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        ``x`` is indexed as ``(batch, seq_len, embed_dim)``; ``seq_len`` is
        expected to be at most ``max_len``.
        """
        return x + self.pe[:, :x.size(1)]

# Multi-head Attention (self and cross)
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention with ``num_heads`` parallel heads.

    The same module serves self-attention (query, key and value are the same
    tensor) and cross-attention (key/value come from another sequence).

    Args:
        embed_dim: feature size of the inputs and of the output.
        num_heads: number of heads; must divide ``embed_dim`` evenly.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads
        self.num_heads = num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` positions over ``key``/``value`` positions.

        Args:
            query: tensor of shape ``(B, T_q, D)``.
            key: tensor whose second dimension is ``T_k``.
            value: tensor with the same sequence length as ``key``.
            mask: optional tensor broadcastable to ``(B, heads, T_q, T_k)``;
                positions where it equals 0 are excluded from attention.

        Returns:
            Tensor of shape ``(B, T_q, D)``.
        """
        B, T_q, D = query.shape
        T_k = key.size(1)

        # Project, then split the feature dim into heads: (B, heads, T, head_dim)
        Q = self.q_proj(query).reshape(B, T_q, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(key).reshape(B, T_k, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(value).reshape(B, T_k, self.num_heads, self.head_dim).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.head_dim ** 0.5  # (B, heads, T_q, T_k)

        if mask is not None:
            # Fill with the most negative finite value instead of -inf: masked
            # keys still receive zero weight after softmax, but a row whose keys
            # are ALL masked (e.g. an all-padding sequence) no longer turns into
            # NaN (softmax over a row of -inf is 0/0).
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)  # (B, heads, T_q, head_dim)
        # Merge the heads back into a single feature dimension.
        out = out.transpose(1, 2).contiguous().reshape(B, T_q, D)
        return self.out_proj(out)

# Feed Forward Network
class FeedForward(nn.Module):
    """Two-layer MLP applied to the last dimension: Linear -> ReLU -> Linear.

    Args:
        embed_dim: input and output feature size.
        hidden_dim: feature size between the two linear layers.
    """

    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        expand = nn.Linear(embed_dim, hidden_dim)
        activation = nn.ReLU()
        contract = nn.Linear(hidden_dim, embed_dim)
        self.net = nn.Sequential(expand, activation, contract)

    def forward(self, x):
        """Run ``x`` through the MLP; the output has the same shape as ``x``."""
        transformed = self.net(x)
        return transformed

# Encoder Block
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention, then a feed-forward network.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        embed_dim: feature size of the inputs.
        num_heads: number of attention heads.
        ff_hidden_dim: hidden size of the feed-forward network.
        dropout: dropout probability applied to each sub-layer output.
    """

    def __init__(self, embed_dim, num_heads, ff_hidden_dim, dropout=0.1):
        super().__init__()
        # Sub-layers
        self.attn = MultiHeadAttention(embed_dim, num_heads)
        self.ff = FeedForward(embed_dim, ff_hidden_dim)
        # One normalisation per residual connection
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        # A single dropout module is shared by both residual branches
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_mask=None):
        """Apply the layer to ``x``; ``src_mask`` is forwarded to the attention."""
        attended = self.attn(x, x, x, src_mask)
        hidden = self.norm1(x + self.dropout(attended))
        transformed = self.ff(hidden)
        return self.norm2(hidden + self.dropout(transformed))

# Decoder Block
class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, cross-attention, feed-forward.

    Each of the three sub-layers is wrapped as
    ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        embed_dim: feature size of the inputs.
        num_heads: number of attention heads in both attention modules.
        ff_hidden_dim: hidden size of the feed-forward network.
        dropout: dropout probability applied to each sub-layer output.
    """

    def __init__(self, embed_dim, num_heads, ff_hidden_dim, dropout=0.1):
        super().__init__()
        # Sub-layers
        self.self_attn = MultiHeadAttention(embed_dim, num_heads)
        self.cross_attn = MultiHeadAttention(embed_dim, num_heads)
        self.ff = FeedForward(embed_dim, ff_hidden_dim)
        # One normalisation per residual connection
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        # A single dropout module is shared by all residual branches
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, tgt_mask=None, memory_mask=None):
        """Apply the layer to ``x``.

        Args:
            x: decoder-side input.
            enc_out: encoder output, used as key and value of the cross-attention.
            tgt_mask: mask forwarded to the self-attention.
            memory_mask: mask forwarded to the cross-attention.
        """
        self_attended = self.self_attn(x, x, x, tgt_mask)
        hidden = self.norm1(x + self.dropout(self_attended))

        cross_attended = self.cross_attn(hidden, enc_out, enc_out, memory_mask)
        hidden = self.norm2(hidden + self.dropout(cross_attended))

        transformed = self.ff(hidden)
        return self.norm3(hidden + self.dropout(transformed))

# Full Transformer
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, embed_dim, num_heads, ff_hidden_dim, num_layers, max_len=512):
        super().__init__()
        self.src_embed = nn.Embedding(src_vocab_size, embed_dim)
        self.tgt_embed = nn.Embedding(tgt_vocab_size, embed_dim)
        self.pos_enc = PositionalEncoding(embed_dim, max_len)

        self.encoder = nn.ModuleList([
            EncoderBlock(embed_dim, num_heads, ff_hidden_dim) for _ in range(num_layers)
        ])
        self.decoder = nn.ModuleList([
            DecoderBlock(embed_dim, num_heads, ff_hidden_dim) for _ in range(num_layers)
        ])
        self.output_proj = nn.Linear(embed_dim, tgt_vocab_size)

    def make_tgt_mask(self, tgt):
        T = tgt.size(1)
        mask = torch.tril(torch.ones(T, T, device=tgt.device)).unsqueeze(0).unsqueeze(1)
        return mask  # (1, 1, T, T)

    def forward(self, src, tgt, src_mask=None):
        src = self.pos_enc(self.src_embed(src))
        tgt = self.pos_enc(self.tgt_embed(tgt))
        tgt_mask = self.make_tgt_mask(tgt)

        for layer in self.encoder:
            src = layer(src, src_mask)

        for layer in self.decoder:
            tgt = layer(tgt, src, tgt_mask, src_mask)

        return self.output_proj(tgt)  # (B, T, vocab_size)
