In [12]:
import math
import torch
import random
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Problem 1 part a

class PositionalEncoding(nn.Module):
    """
    Classic Sin/Cos Positional Encoding
    """
    def __init__(self, d_model, max_len=10):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0,d_model,2).float()*(-math.log(1000.0)/d_model))
        pe[:, 0::2] = torch.sin(position * div)
        pe[:, 1::2] = torch.cos(position * div)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x is (seq_len, batch_size, embed_dim)
        x = x + self.pe[:x.size(0), :]
        return x



In [13]:
# Problem 1 part b

def scaled_dot_product_attention(Q, K, V, mask=None):

    d_k = Q.size(-1)

    Q_transposed = Q.transpose(0,1)
    K_transposed = K.transpose(0,1).transpose(-2,-1)

    print(f"Shape of Q.transpose(0,1) (tensor a): {Q_transposed.shape}")
    print(f"Shape of K.transpose(0,1).transpose(-2,-1) (tensor b): {K_transposed.shape}")

    scores = torch.matmul(Q_transposed, K_transposed)
    scores = scores / math.sqrt(d_k)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    attn = torch.softmax(scores, dim=-1)    # (seq_len, batch, seq_len)
    output = torch.matmul(attn, V.transpose(0,1)).transpose(0,1)
    return output, attn

In [14]:
# Problem 1 part c

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_k = d_model // num_heads
        self.num_heads = num_heads

        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.Wo = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        # x: (seq, batch, d_model)
        seq_len, batch, d_model = x.size()
        x = x.view(seq_len, batch, self.num_heads, self.d_k)
        x = x.permute(2, 0, 1, 3)
        return x

    def combine_heads(self, x):
        heads, seq, batch, d_k = x.size()
        x = x.permute(1, 2, 0, 3).contiguous()
        return x.view(seq, batch, heads * d_k)

    def forward(self, Q, K, V, mask=None):
        Q = self.Wq(Q)
        K = self.Wk(K)
        V = self.Wv(V)

        Q = self.split_heads(Q)
        K = self.split_heads(K)
        V = self.split_heads(V)

        outputs = []
        attn_weights = []

        for h in range(self.num_heads):
            out, att = scaled_dot_product_attention(Q[h], K[h], V[h], mask)
            outputs.append(out)
            attn_weights.append(att)

        out = torch.stack(outputs, dim=0)
        out = self.combine_heads(out)
        return self.Wo(out)

In [15]:
# Problem 1 part d

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.linear2(torch.relu(self.linear1(x)))


In [16]:
# Problem 1 part e

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        attn = self.self_attn(x, x, x, mask)
        x = self.norm1(x + attn)
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)
        return x


In [17]:
# Problem 1 part f
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_out, tgt_mask=None, src_mask=None):
        # masked self-attention
        x2 = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + x2)

        # cross-attention
        x2 = self.cross_attn(x, enc_out, enc_out, src_mask)
        x = self.norm2(x + x2)

        # FFN
        x2 = self.ffn(x)
        x = self.norm3(x + x2)
        return x

In [18]:
# Problem 1 part g
class MiniTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, max_len):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos = PositionalEncoding(d_model, max_len)

        self.encoder = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])

        self.decoder = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])

        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # (seq, batch)
        src = self.pos(self.embed(src))
        tgt = self.pos(self.embed(tgt))

        # encoder
        for layer in self.encoder:
            src = layer(src, src_mask)

        # decoder
        out = tgt
        for layer in self.decoder:
            out = layer(out, src, tgt_mask, src_mask)

        return self.fc_out(out)

In [19]:
# Problem 1 part h
SOS = 11
EOS = 12
PAD = 10

def generate_example():
    digits = random.sample(range(0, 10), 8)
    sorted_digits = sorted(digits)

    src = [SOS] + digits + [EOS]
    tgt = [SOS] + sorted_digits + [EOS]

    return src, tgt

def generate_batch(batch_size, device):
    """
    Returns:
      src: LongTensor of shape (seq_len, batch_size)
      tgt: LongTensor of shape (seq_len, batch_size)
    """
    src_batch = []
    tgt_batch = []
    for _ in range(batch_size):
        src, tgt = generate_example()
        src_batch.append(src)
        tgt_batch.append(tgt)

    # convert to tensors and transpose to (seq_len, batch)
    src = torch.tensor(src_batch, dtype=torch.long, device=device).transpose(0, 1)
    tgt = torch.tensor(tgt_batch, dtype=torch.long, device=device).transpose(0, 1)
    return src, tgt

## Test function generate_batch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

src, tgt = generate_batch(batch_size=2, device=device)
print("src[:,0] =", src[:,0].tolist())
print("tgt[:,0] =", tgt[:,0].tolist())

src[:,0] = [11, 1, 9, 6, 8, 4, 2, 5, 7, 12]
tgt[:,0] = [11, 1, 2, 4, 5, 6, 7, 8, 9, 12]


In [20]:
# Problem 1 part i
## Mask for decoder

def generate_square_subsequent_mask(sz, device):

    mask = torch.tril(torch.ones(sz, sz, dtype=torch.uint8, device=device))
    return mask

## Training
VOCAB_SIZE = 13 # 0-9 for numbers, 10=PAD, 11=SOS, 12=EOS
EMBED_DIM = 16 # Embedding dimension
NUM_HEADS = 2 # Number of attention heads
NUM_LAYERS = 2 # Number of Encoder/Decoder layers
SEQ_LENGTH = 10 # (8 numbers + SOS + EOS)
D_FF = 64
BATCH_SIZE = 64
EPOCHS = 20
LEARNING_RATE = 0.001

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MiniTransformer(
    vocab_size=VOCAB_SIZE,
    d_model=EMBED_DIM,
    num_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    d_ff=D_FF,
    max_len=SEQ_LENGTH,
).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=PAD)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

NUM_BATCHES_PER_EPOCH = 200

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0

    for _ in range(NUM_BATCHES_PER_EPOCH):
        src, tgt_full = generate_batch(BATCH_SIZE, device)

        # decoder input (tgt_in) and labels (tgt_out) for next-token prediction
        # shape: (seq_len-1, batch)
        tgt_in  = tgt_full[:-1, :]
        tgt_out = tgt_full[1:, :]

        # decoder look-ahead mask: size = seq_len-1
        tgt_mask = generate_square_subsequent_mask(tgt_in.size(0), device)

        # forward pass
        logits = model(src, tgt_in, src_mask=None, tgt_mask=tgt_mask)
        # logits: (seq_len-1, batch, vocab_size)

        # reshape for CrossEntropyLoss: (N, C) vs (N,)
        logits_flat = logits.reshape(-1, VOCAB_SIZE)
        tgt_out_flat = tgt_out.reshape(-1)

        loss = criterion(logits_flat, tgt_out_flat)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / NUM_BATCHES_PER_EPOCH
    print(f"Epoch {epoch+1:02d} | Loss = {avg_loss:.4f}")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Shape of Q.transpose(0,1) (tensor a): torch.Size([64, 9, 8])
Shape of K.transpose(0,1).transpose(-2,-1) (tensor b): torch.Size([64, 8, 9])
Shape of Q.transpose(0,1) (tensor a): torch.Size([64, 9, 8])
Shape of K.transpose(0,1).transpose(-2,-1) (tensor b): torch.Size([64, 8, 10])
Shape of Q.transpose(0,1) (tensor a): torch.Size([64, 9, 8])
Shape of K.transpose(0,1).transpose(-2,-1) (tensor b): torch.Size([64, 8, 10])
Shape of Q.transpose(0,1) (tensor a): torch.Size([64, 10, 8])
Shape of K.transpose(0,1).transpose(-2,-1) (tensor b): torch.Size([64, 8, 10])
Shape of Q.transpose(0,1) (tensor a): torch.Size([64, 10, 8])
Shape of K.transpose(0,1).transpose(-2,-1) (tensor b): torch.Size([64, 8, 10])
Shape of Q.transpose(0,1) (tensor a): torch.Size([64, 10, 8])
Shape of K.transpose(0,1).transpose(-2,-1) (tensor b): torch.Size([64, 8, 10])
Shape of Q.transpose(0,1) (tensor a): torch.Size([64, 10, 8])
Shape of K.transpose(0,1).trans

In [21]:
# Problem 1 part j
## Greedy decode generate tokens one by one, and at each step pick the token with the highest probability
def greedy_decode(model, src_digits, device, max_len=SEQ_LENGTH):
    """
    src_digits: list of 8 ints, e.g. [4,3,0,5,1,2,8,7]
    Returns: list of predicted sorted digits without SOS/EOS.
    """
    model.eval()
    with torch.no_grad():
        # build source sequence
        src_tokens = [SOS] + src_digits + [EOS]   # length 10
        src = torch.tensor(src_tokens, dtype=torch.long, device=device).unsqueeze(1)  # (seq_len, 1)

        # decoder starts with SOS
        tgt_tokens = [SOS]

        for _ in range(max_len - 1):
            tgt = torch.tensor(tgt_tokens, dtype=torch.long, device=device).unsqueeze(1)  # (cur_len, 1)
            tgt_mask = generate_square_subsequent_mask(tgt.size(0), device)

            # forward pass
            logits = model(src, tgt, src_mask=None, tgt_mask=tgt_mask)
            # take the last step's logits
            next_token = logits[-1, 0, :].argmax(dim=-1).item()

            tgt_tokens.append(next_token)

            if next_token == EOS:
                break

        # remove SOS/EOS to get just digits
        pred_digits = [t for t in tgt_tokens[1:] if t not in (SOS, EOS)]
        return pred_digits

# Example after training:
test_seq = [4, 3, 0, 5, 1, 2, 8, 7]
pred = greedy_decode(model, test_seq, device)
print("Input digits :", test_seq)
print("Predicted     :", pred)
# ideally: [0, 1, 2, 3, 4, 5, 7, 8]

Shape of Q.transpose(0,1) (tensor a): torch.Size([1, 10, 8])
Shape of K.transpose(0,1).transpose(-2,-1) (tensor b): torch.Size([1, 8, 10])
Shape of Q.transpose(0,1) (tensor a): torch.Size([1, 10, 8])
Shape of K.transpose(0,1).transpose(-2,-1) (tensor b): torch.Size([1, 8, 10])
Shape of Q.transpose(0,1) (tensor a): torch.Size([1, 10, 8])
Shape of K.transpose(0,1).transpose(-2,-1) (tensor b): torch.Size([1, 8, 10])
Shape of Q.transpose(0,1) (tensor a): torch.Size([1, 10, 8])
Shape of K.transpose(0,1).transpose(-2,-1) (tensor b): torch.Size([1, 8, 10])
Shape of Q.transpose(0,1) (tensor a): torch.Size([1, 1, 8])
Shape of K.transpose(0,1).transpose(-2,-1) (tensor b): torch.Size([1, 8, 1])
Shape of Q.transpose(0,1) (tensor a): torch.Size([1, 1, 8])
Shape of K.transpose(0,1).transpose(-2,-1) (tensor b): torch.Size([1, 8, 1])
Shape of Q.transpose(0,1) (tensor a): torch.Size([1, 1, 8])
Shape of K.transpose(0,1).transpose(-2,-1) (tensor b): torch.Size([1, 8, 10])
Shape of Q.transpose(0,1) (tenso