In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Token -> id lookup tables. Ids 0-2 are reserved for the special tokens
# shared by both languages; content words follow in the order listed.
_special_tokens = ["<pad>", "<sos>", "<eos>"]
eng_vocab = {tok: idx for idx, tok in enumerate(_special_tokens + ["cat", "sat", "mat"])}
zh_vocab = {tok: idx for idx, tok in enumerate(_special_tokens + ["猫", "坐", "在", "垫子上"])}

# Reverse lookup (id -> token) for the Chinese vocabulary.
inv_zh_vocab = dict((idx, tok) for tok, idx in zh_vocab.items())

# Parallel corpus: one (English tokens, Chinese tokens) tuple per sentence.
pairs = [(["cat", "sat", "mat"], ["猫", "坐", "在", "垫子上"])]

# encoding func
def encode(tokens, vocab, add_special=True):
    """Convert a sequence of tokens into integer ids via ``vocab``.

    Args:
        tokens: Iterable of token strings; each must be a key of ``vocab``.
        vocab: Mapping from token string to integer id.
        add_special: When true, the result is wrapped with ``vocab["<sos>"]``
            in front and ``vocab["<eos>"]`` at the end.

    Returns:
        A new list of integer ids.

    Raises:
        KeyError: If a token (or a required special token) is missing
            from ``vocab``.
    """
    token_ids = []
    for token in tokens:
        token_ids.append(vocab[token])
    if not add_special:
        return token_ids
    return [vocab["<sos>"], *token_ids, vocab["<eos>"]]

# Fixed length that every encoded sequence (including <sos>/<eos>) is
# right-padded to before being stacked into a batch tensor.
max_len = 10  

def pad(seq, max_len, pad_id=0):
    """Right-pad ``seq`` with ``pad_id`` until it holds ``max_len`` items.

    Args:
        seq: Iterable of token ids. It is copied, never modified in place.
        max_len: Target length of the returned list.
        pad_id: Id appended as padding. Defaults to 0, the ``<pad>`` id
            used by the vocabularies in this notebook.

    Returns:
        A new list. Input that is already ``max_len`` items or longer is
        returned as an unchanged copy; this function never truncates.
    """
    items = list(seq)
    # max(0, ...) makes the "already long enough" case explicit instead of
    # relying on ``[x] * negative`` evaluating to an empty list.
    missing = max(0, max_len - len(items))
    return items + [pad_id] * missing

# Encode both sides of every sentence pair and right-pad them to max_len.
data = [
    (pad(encode(src, eng_vocab), max_len), pad(encode(tgt, zh_vocab), max_len))
    for src, tgt in pairs
]

# Stack the id lists into tensors.
src_batch = torch.tensor([src_row for src_row, _ in data])  # shape (batch, src_len)
tgt_batch = torch.tensor([tgt_row for _, tgt_row in data])  # shape (batch, tgt_len)

# real transformer model
class TinyTransformer(nn.Module):
    """Small encoder-decoder Transformer over integer token ids.

    Source and target ids are embedded separately, scaled by
    ``sqrt(d_model)``, passed through ``nn.Transformer`` (batch-first) and
    projected to target-vocabulary logits.

    NOTE(review): no positional encoding is added to the embeddings; consider
    adding one before using this on anything beyond a toy example.

    Args:
        src_vocab_size: Number of entries in the source vocabulary.
        tgt_vocab_size: Number of entries in the target vocabulary.
        d_model: Embedding / model width.
        nhead: Number of attention heads.
        num_layers: Number of encoder layers and of decoder layers.
        pad_idx: Token id treated as padding when building attention masks.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=32, nhead=4, num_layers=2, pad_idx=0):
        super().__init__()
        self.d_model = d_model
        self.pad_idx = pad_idx
        self.src_embed = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab_size, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=nhead, num_encoder_layers=num_layers,
            num_decoder_layers=num_layers, batch_first=True
        )
        self.output_proj = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt):
        """Compute next-token logits for every target position.

        Args:
            src: Integer tensor of source ids, shape (batch, src_len).
            tgt: Integer tensor of decoder-input ids, shape (batch, tgt_len).

        Returns:
            Float tensor of shape (batch, tgt_len, tgt_vocab_size).
        """
        scale = self.d_model ** 0.5
        src_emb = self.src_embed(src) * scale
        tgt_emb = self.tgt_embed(tgt) * scale

        # True marks padding positions that attention must ignore.
        src_key_padding_mask = (src == self.pad_idx)
        tgt_key_padding_mask = (tgt == self.pad_idx)

        # Causal mask: True above the diagonal blocks attention to future
        # positions. Built on the input's device so the model also works
        # after being moved off the CPU.
        tgt_seq_len = tgt.shape[1]
        causal_mask = torch.triu(
            torch.ones(tgt_seq_len, tgt_seq_len, device=tgt.device), diagonal=1
        ).bool()

        out = self.transformer(
            src_emb, tgt_emb,
            tgt_mask=causal_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            # The encoder output keeps the source layout, so the same mask
            # stops decoder cross-attention from reading padded positions.
            memory_key_padding_mask=src_key_padding_mask,
        )
        return self.output_proj(out)

# training
# Seed the global RNG so weight initialisation and dropout are repeatable
# across Restart & Run All.
torch.manual_seed(0)

model = TinyTransformer(len(eng_vocab), len(zh_vocab))
# Padding positions carry no target, so they are excluded from the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=zh_vocab["<pad>"])
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# shift (teacher forcing): the decoder is fed every token but the last and
# is trained to predict the sequence moved one step to the left.
tgt_input = tgt_batch[:, :-1]
tgt_output = tgt_batch[:, 1:]

# train the model
model.train()  # explicit: dropout must be active during optimisation
for epoch in range(100):
    optimizer.zero_grad()
    logits = model(src_batch, tgt_input)  # shape: (batch, tgt_len-1, vocab_size)
    # CrossEntropyLoss expects (N, classes) vs (N,), so flatten batch and time.
    loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_output.reshape(-1))
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f"Epoch {epoch} | Loss: {loss.item():.4f}")


Epoch 0 | Loss: 2.2175
Epoch 10 | Loss: 0.4325
Epoch 20 | Loss: 0.2532
Epoch 30 | Loss: 0.1809
Epoch 40 | Loss: 0.1407
Epoch 50 | Loss: 0.1168
Epoch 60 | Loss: 0.0930
Epoch 70 | Loss: 0.0837
Epoch 80 | Loss: 0.0671
Epoch 90 | Loss: 0.0581


In [None]:
# a translate function to demo
def translate(model, src_sentence, max_len=10):
    """Greedy-decode a Chinese translation for a tokenised English sentence.

    Relies on the module-level ``eng_vocab``, ``zh_vocab`` and
    ``inv_zh_vocab`` tables and on the ``encode`` / ``pad`` helpers.

    Args:
        model: Callable taking ``(src, tgt)`` id tensors and returning
            logits of shape (batch, tgt_len, vocab). It is switched to
            eval mode and left in eval mode.
        src_sentence: List of English tokens; each must be in ``eng_vocab``.
        max_len: Length the source is padded to, and the maximum number of
            tokens to generate.

    Returns:
        The generated tokens joined by single spaces, without ``<sos>`` /
        ``<eos>``. Ids missing from ``inv_zh_vocab`` are rendered as "?".
    """
    model.eval()

    # Create inputs on the model's device so this also works off the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    sos_id = zh_vocab["<sos>"]
    eos_id = zh_vocab["<eos>"]

    # Step 1: encode and pad the English input (same layout as in training).
    src_ids = pad(encode(src_sentence, eng_vocab), max_len)
    src_tensor = torch.tensor([src_ids], device=device)

    # Step 2: start the decoder input with <sos>.
    tgt_ids = [sos_id]

    with torch.no_grad():
        for _ in range(max_len):
            # Step 3: run the model on what has been generated so far. The
            # causal mask means trailing padding could not affect the last
            # real position, so the decoder input is fed unpadded.
            tgt_tensor = torch.tensor([tgt_ids], device=device)
            output = model(src_tensor, tgt_tensor)

            # Step 4: greedy choice at the last time step.
            next_token_id = output[0, -1].argmax().item()

            # Stop at <eos> without storing it. Blindly slicing off the
            # final element would drop a real token whenever decoding ends
            # by reaching max_len instead.
            if next_token_id == eos_id:
                break

            # Step 5: extend the decoder input.
            tgt_ids.append(next_token_id)

    # Step 6: convert ids (skipping the leading <sos>) back to words.
    return " ".join(inv_zh_vocab.get(tok, "?") for tok in tgt_ids[1:])


In [None]:
# demo: translate the training sentence with the model defined above and
# print the source tokens next to the decoded output.
test_sentence = ["cat", "sat", "mat"]
output = translate(model, test_sentence)
print(f'Input: {" ".join(test_sentence)}')
print(f'Translation: {output}')

Input: cat sat mat
Translation: 猫 坐 在 垫子上


  output = torch._nested_tensor_from_mask(
