In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self, dk):
        super().__init__()
        self.dk = dk

    def forward(self, Q, K, V, mask=None):
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.dk)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        output = torch.matmul(attn, V)
        return output, attn

In [3]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0

        self.dk = d_model // num_heads
        self.num_heads = num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.attn = ScaledDotProductAttention(self.dk)

    def forward(self, Q, K, V, mask=None):
        B, tgt_len, _ = Q.size()
        B, src_len, _ = K.size()

        Q = self.W_q(Q).view(B, tgt_len, self.num_heads, self.dk).transpose(1, 2)
        K = self.W_k(K).view(B, src_len, self.num_heads, self.dk).transpose(1, 2)
        V = self.W_v(V).view(B, src_len, self.num_heads, self.dk).transpose(1, 2)


        output, attn = self.attn(Q, K, V, mask)

        output = output.transpose(1, 2).contiguous().view(B, tgt_len, self.num_heads * self.dk)
        return self.W_o(output)

In [4]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)  # even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # odd indices
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]

        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :].to(x.device)
        return x

In [5]:
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.linear2(F.relu(self.linear1(x)))

In [6]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attn_out = self.attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.dropout(ff_out))
        return x

In [7]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, d_model, N, num_heads, d_ff, max_len=512, dropout=0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(N)
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, mask=None):
        x = self.embed(src)
        x = self.pos_enc(x)
        x = self.dropout(x)

        for layer in self.layers:
            x = layer(x, mask)
        return x

In [8]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.enc_attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, tgt_mask=None, src_mask=None):
        # Masked self-attn
        _x = x
        x = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(_x + self.dropout(x))

        # Encoder-Decoder attn
        _x = x
        x = self.enc_attn(x, enc_out, enc_out, src_mask)
        x = self.norm2(_x + self.dropout(x))

        # FeedForward
        _x = x
        x = self.ff(x)
        x = self.norm3(_x + self.dropout(x))
        return x

In [9]:
class Decoder(nn.Module):
    def __init__(self, vocab_size, d_model, N, num_heads, d_ff, max_len=512, dropout=0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(N)
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt, enc_out, tgt_mask=None, src_mask=None):
        x = self.embed(tgt)
        x = self.pos_enc(x)
        x = self.dropout(x)

        for layer in self.layers:
            x = layer(x, enc_out, tgt_mask, src_mask)
        return x

In [10]:
def generate_subsequent_mask(size):
    mask = torch.tril(torch.ones(size, size)).unsqueeze(0).unsqueeze(0)
    return mask  # shape: [1, 1, size, size]

In [11]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, N=6,
                 num_heads=8, d_ff=2048, max_len=512, dropout=0.1):
        super().__init__()

        self.encoder = Encoder(src_vocab_size, d_model, N, num_heads, d_ff, max_len, dropout)
        self.decoder = Decoder(tgt_vocab_size, d_model, N, num_heads, d_ff, max_len, dropout)
        self.out_proj = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        enc_out = self.encoder(src, src_mask)
        dec_out = self.decoder(tgt, enc_out, tgt_mask, src_mask)
        logits = self.out_proj(dec_out)
        return logits

In [12]:
# Sample dummy input
src = torch.randint(0, 100, (2, 10))  # batch_size=2, seq_len=10
tgt = torch.randint(0, 100, (2, 9))   # shifted target for training

model = Transformer(src_vocab_size=100, tgt_vocab_size=100)

# Masks
src_mask = None
tgt_mask = generate_subsequent_mask(tgt.size(1)).to(tgt.device)

logits = model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
print("Logits shape:", logits.shape)  # Expect: [batch, tgt_seq_len, vocab_size]

Logits shape: torch.Size([2, 9, 100])


In [13]:
def greedy_decode(model, src, src_mask, max_len, start_symbol, device):
    src = src.to(device)
    memory = model.encoder(src, src_mask)

    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)

    for i in range(max_len - 1):
        tgt_mask = generate_subsequent_mask(ys.size(1)).to(device)
        out = model.decoder(ys, memory, tgt_mask, src_mask)
        prob = model.out_proj(out[:, -1])
        next_word = torch.argmax(prob, dim=-1).unsqueeze(0)
        ys = torch.cat([ys, next_word], dim=1)
        if next_word.item() == end_token:
            break

    return ys

In [14]:
src_sentence = torch.tensor([[1, 5, 6, 7, 8, 9, 0, 0]])  # padded
src_mask = None
start_token = 1
end_token = 2
max_len = 20

output = greedy_decode(model, src_sentence, src_mask, max_len, start_token, device='cpu')
print("Generated:", output)

Generated: tensor([[ 1, 35, 24, 78,  7, 78, 65, 78,  7, 78,  7, 78, 10, 22,  7, 78, 65, 78,
         41, 97]])


In [30]:
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0

    for src, tgt in dataloader:
        src = src.to(device)
        tgt_input = tgt[:, :-1].to(device)
        tgt_output = tgt[:, 1:].to(device)

        tgt_mask = generate_subsequent_mask(tgt_input.size(1)).to(device)

        optimizer.zero_grad()
        logits = model(src, tgt_input, tgt_mask=tgt_mask)
        loss = criterion(logits.view(-1, logits.size(-1)), tgt_output.reshape(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

In [32]:
index_to_word = {
    1: "<sos>",
    2: "<eos>",
    7: "I",
    10: "am",
    22: "hungry",
    24: "you",
    35: "Hello",
    41: "now",
    65: "world",
    78: ",",
    97: "."
}

In [34]:
decoded = [index_to_word[token.item()] for token in output[0]]
print(" ".join(decoded))

<sos> Hello you , I , world , I , I , am hungry I , world , now .


In [36]:
toy_data = [
    ("i am happy", "😊"),
    ("i am sad", "😢"),
    ("you are amazing", "💯"),
    ("i love pizza", "🍕"),
    ("i am angry", "😡"),
    ("good night", "🌙"),
]

In [38]:
word2idx = {
    "<pad>": 0,
    "<sos>": 1,
    "<eos>": 2,
    "i": 3,
    "am": 4,
    "happy": 5,
    "sad": 6,
    "you": 7,
    "are": 8,
    "amazing": 9,
    "love": 10,
    "pizza": 11,
    "angry": 12,
    "good": 13,
    "night": 14,
    "😊": 15,
    "😢": 16,
    "💯": 17,
    "🍕": 18,
    "😡": 19,
    "🌙": 20
}

idx2word = {i: w for w, i in word2idx.items()}
vocab_size = len(word2idx)

In [40]:
def encode(sentence):
    return [word2idx["<sos>"]] + [word2idx[word] for word in sentence.split()] + [word2idx["<eos>"]]

dataset = [(encode(src), encode(tgt)) for src, tgt in toy_data]

In [42]:
def pad_sequence(seq, max_len):
    return seq + [word2idx["<pad>"]] * (max_len - len(seq))

def collate_batch(batch):
    src_seqs, tgt_seqs = zip(*batch)
    max_src = max(len(s) for s in src_seqs)
    max_tgt = max(len(t) for t in tgt_seqs)

    src_batch = torch.tensor([pad_sequence(s, max_src) for s in src_seqs])
    tgt_batch = torch.tensor([pad_sequence(t, max_tgt) for t in tgt_seqs])
    return src_batch, tgt_batch

In [44]:
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim

dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_batch)
model = Transformer(src_vocab_size=vocab_size, tgt_vocab_size=vocab_size)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=word2idx["<pad>"])

for epoch in range(20):
    model.train()
    total_loss = 0

    for src, tgt in dataloader:
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]
        tgt_mask = generate_subsequent_mask(tgt_input.size(1))

        logits = model(src, tgt_input, tgt_mask=tgt_mask)
        loss = criterion(logits.view(-1, logits.size(-1)), tgt_output.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

Epoch 1, Loss: 8.6247
Epoch 2, Loss: 4.9630
Epoch 3, Loss: 3.8968
Epoch 4, Loss: 3.3496
Epoch 5, Loss: 2.6970
Epoch 6, Loss: 2.0072
Epoch 7, Loss: 1.1908
Epoch 8, Loss: 0.8020
Epoch 9, Loss: 0.3298
Epoch 10, Loss: 0.1751
Epoch 11, Loss: 0.0795
Epoch 12, Loss: 0.0418
Epoch 13, Loss: 0.0245
Epoch 14, Loss: 0.0210
Epoch 15, Loss: 0.0163
Epoch 16, Loss: 0.0158
Epoch 17, Loss: 0.0121
Epoch 18, Loss: 0.0097
Epoch 19, Loss: 0.0094
Epoch 20, Loss: 0.0089


In [46]:
def beam_search(model, src, src_mask, max_len, start_symbol, k=3):
    src = src.to(model.out_proj.weight.device)
    memory = model.encoder(src, src_mask)
    beams = [(torch.tensor([[start_symbol]]), 0)]

    for _ in range(max_len):
        new_beams = []
        for ys, score in beams:
            tgt_mask = generate_subsequent_mask(ys.size(1)).to(src.device)
            out = model.decoder(ys, memory, tgt_mask, src_mask)
            logits = model.out_proj(out[:, -1])
            probs = torch.log_softmax(logits, dim=-1)

            topk_probs, topk_indices = probs.topk(k)
            for i in range(k):
                new_ys = torch.cat([ys, topk_indices[:, i].unsqueeze(0)], dim=1)
                new_score = score + topk_probs[0, i].item()
                new_beams.append((new_ys, new_score))

        beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:k]

    return beams[0][0]


In [48]:
import torch.nn.functional as F

def top_k_sampling(model, src, src_mask, max_len, start_symbol, k=5):
    src = src.to(model.out_proj.weight.device)
    memory = model.encoder(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).long().to(src.device)

    for _ in range(max_len - 1):
        tgt_mask = generate_subsequent_mask(ys.size(1)).to(src.device)
        out = model.decoder(ys, memory, tgt_mask, src_mask)
        logits = model.out_proj(out[:, -1])
        probs = F.softmax(logits, dim=-1)
        topk_probs, topk_indices = probs.topk(k)
        chosen = torch.multinomial(topk_probs, 1)
        next_word = topk_indices.gather(-1, chosen)
        ys = torch.cat([ys, next_word], dim=1)

    return ys
