Create a tiny language model from scratch, with only two attention layers. This model is like BERT, but in a very simplified form. The goal here is to understand what happens behind the scenes of self-attention-based language models.

**Installation for TPU usage in Google Colab**

In [1]:
# Select a TPU runtime before running this cell:
# Colab menu: Runtime > Change runtime type > TPU
# `accelerate` is installed here but is not imported by the cells shown below.
!pip install -q -U accelerate
# Per its filename, the torch_xla wheel is version 2.0, built for CPython 3.10
# on x86-64 Linux; it provides the `torch_xla` package imported in the next cell.
!pip install -q -U cloud-tpu-client torch torchvision https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl

**Imports**

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_xla
import torch_xla.core.xla_model as xm

**Create classes**
- PositionalEncoding - For positional encoding
- MultiHeadSelfAttention - For self-attention layers
- FeedForward - For the position-wise feed-forward layer
- TinyLM - The full model flow

In [3]:
class PositionalEncoding(nn.Module):
    """Add a randomly initialised, learnable position embedding to a sequence.

    The table ``pe`` has shape ``(1, max_len, d_model)`` and is broadcast over
    the batch dimension of the input.

    Args:
        d_model: Size of the last (feature) dimension of the inputs.
        max_len: Number of positions in the table. Inputs are expected to have
            at most this many positions along dimension 1.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        # Registered as a parameter (not a bare tensor attribute) so that it
        # follows the module through .to(device) / .double(), is saved in
        # state_dict(), and is updated by the optimizer.
        self.pe = nn.Parameter(torch.randn(1, max_len, d_model))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        return x + self.pe[:, :x.size(1), :]

class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product self-attention over several heads.

    Queries, keys and values are all linear projections of the same input.
    No attention mask is applied, so every position attends to every other.

    Args:
        d_model: Feature size of the input and of the output.
        num_heads: Number of heads; must divide ``d_model`` evenly.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        # Input projections followed by the output projection that mixes heads.
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.fc_out = nn.Linear(d_model, d_model)

    def _split_heads(self, projected, batch_size):
        """Reshape ``(batch, seq, d_model)`` to ``(batch, heads, seq, d_head)``."""
        per_head = projected.view(batch_size, -1, self.num_heads, self.d_head)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor whose first dimension is the batch and whose last
                dimension is ``d_model`` (typically ``(batch, seq_len, d_model)``).

        Returns:
            Tensor of shape ``(batch, seq_len, d_model)``.
        """
        batch_size = x.size(0)
        queries = self._split_heads(self.query(x), batch_size)
        keys = self._split_heads(self.key(x), batch_size)
        values = self._split_heads(self.value(x), batch_size)

        # Pairwise similarity of positions, scaled by sqrt(d_head).
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / self.d_head**0.5
        weights = F.softmax(scores, dim=-1)

        # Weighted sum of values per head, then stitch the heads back together.
        context = torch.matmul(weights, values).transpose(1, 2).contiguous()
        merged = context.view(batch_size, -1, self.num_heads * self.d_head)
        return self.fc_out(merged)

class FeedForward(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Linear.

    Args:
        d_model: Feature size of the input and of the output.
        d_ff: Width of the hidden layer.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)  # expand to the hidden width
        self.fc2 = nn.Linear(d_ff, d_model)  # project back to d_model

    def forward(self, x):
        """Transform the last dimension of ``x``; the output shape equals the input shape."""
        hidden = self.fc1(x)
        activated = F.relu(hidden)
        return self.fc2(activated)

class TinyLM(nn.Module):
    """Minimal encoder: embeddings, two self-attention layers, one feed-forward.

    Each sub-layer is wrapped in a residual connection followed by LayerNorm.
    The model returns hidden states only; there is no projection back to the
    vocabulary.

    Args:
        vocab_size: Number of rows in the token embedding table.
        d_model: Embedding / hidden size.
        num_heads: Heads per attention layer. It must divide ``d_model``
            (MultiHeadSelfAttention asserts this). NOTE: the default pair
            ``d_model=16, num_heads=3`` does not satisfy that, so at least one
            of the two has to be passed explicitly.
        d_ff: Hidden width of the feed-forward block.
        max_len: Number of positions available in the positional encoding.
    """

    def __init__(self, vocab_size, d_model=16, num_heads=3, d_ff=64, max_len=512):
        super(TinyLM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.position_enc = PositionalEncoding(d_model, max_len)
        self.attn1 = MultiHeadSelfAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.attn2 = MultiHeadSelfAttention(d_model, num_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, tokens):
        """Encode a batch of token ids.

        Args:
            tokens: Integer tensor of shape ``(batch, seq_len)`` with values
                in ``[0, vocab_size)``.

        Returns:
            Tensor of shape ``(batch, seq_len, d_model)``.
        """
        embeddings = self.embedding(tokens)
        # position_enc already returns `embeddings + positional encoding`;
        # adding `embeddings` to its result again would count them twice.
        x = self.position_enc(embeddings)
        attn1_out = self.norm1(x + self.attn1(x))
        attn2_out = self.norm2(attn1_out + self.attn2(attn1_out))
        ff_out = self.norm3(attn2_out + self.feed_forward(attn2_out))
        return ff_out

**Example usage**

In [4]:
# Smoke test: run one random batch through the model on the XLA (TPU) device.
device = xm.xla_device()

vocab_size = 10000
batch_size, seq_len = 32, 100  # 32 sentences with 100 tokens each

# Random token ids are drawn on the host, then moved to the device.
tokens = torch.randint(0, vocab_size, (batch_size, seq_len)).to(device)
model = TinyLM(vocab_size=vocab_size, d_model=16, num_heads=2).to(device)

output = model(tokens)
# Merge the batch and sequence axes: one row of d_model features per token.
hidden_size = output.size(-1)
output = output.view(-1, hidden_size)

print(output.shape)

torch.Size([3200, 16])
