In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
from torch.utils.data import Dataset, DataLoader

In [12]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')
print(f'Using device: {device}')

Using device: cpu


## Transformer Architecture

In [13]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention over batch-first inputs.

  The query, key and value inputs are linearly projected, split into
  ``n_heads`` heads of width ``d_k = d_model // n_heads``, attended
  independently, concatenated back to ``d_model`` and passed through a
  final output projection.

  Args:
    d_model: Feature width of the inputs and of the output.
    n_heads: Number of attention heads; must divide ``d_model`` evenly.
    dropout: Dropout probability applied to the attention weights.
  """

  def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
    super().__init__()
    assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

    self.d_model = d_model
    self.n_heads = n_heads
    self.d_k = d_model // n_heads

    # d_k * n_heads == d_model (guaranteed by the assertion above), and
    # d_v == d_k, so all four projections are d_model -> d_model.
    self.q_linear = nn.Linear(d_model, d_model)
    self.k_linear = nn.Linear(d_model, d_model)
    self.v_linear = nn.Linear(d_model, d_model)
    self.out = nn.Linear(d_model, d_model)

    self.dropout = nn.Dropout(dropout)

  def forward(self, query, key, value, mask = None):
    """Attend from ``query`` positions to ``key``/``value`` positions.

    Args:
      query: Tensor of shape (batch_size, q_len, d_model).
      key: Tensor of shape (batch_size, k_len, d_model).
      value: Tensor of shape (batch_size, k_len, d_model).
      mask: Optional tensor broadcastable to
        (batch_size, n_heads, q_len, k_len). Positions where ``mask == 0``
        are excluded from attention.

    Returns:
      A tuple ``(output, attention_weights)`` where ``output`` has shape
      (batch_size, q_len, d_model) and ``attention_weights`` has shape
      (batch_size, n_heads, q_len, k_len). The returned weights are the
      ones after dropout has been applied.
    """
    batch_size = query.size(0)

    # Project, then split the feature axis into heads:
    # (batch_size, seq_len, d_model) -> (batch_size, n_heads, seq_len, d_k)
    Q = self.q_linear(query).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
    K = self.k_linear(key).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
    V = self.v_linear(value).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

    # Scaled dot-product scores: (batch_size, n_heads, q_len, k_len)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

    if mask is not None:
      # Use the most negative finite value of the score dtype instead of a
      # hard-coded -1e9: -1e9 is not representable in float16, so the fill
      # would raise an overflow error under half precision / autocast.
      scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

    attention_weights = F.softmax(scores, dim = -1) # batch_size, n_heads, q_len, k_len
    attention_weights = self.dropout(attention_weights)

    # Weighted sum of the values per head: (batch_size, n_heads, q_len, d_k)
    context = torch.matmul(attention_weights, V)

    # Concatenate heads and put through final linear layer
    context = context.transpose(1, 2).contiguous().view(
        batch_size, -1, self.d_model
    )

    output = self.out(context) # batch_size, q_len, d_model

    return output, attention_weights



In [14]:
class PositionalEncoding(nn.Module):
  """Fixed sinusoidal positional encoding for batch-first inputs.

  Precomputes, for every position ``pos < max_len`` and even feature index
  ``2i``::

      PE[pos, 2i]     = sin(pos / 10000 ** (2i / d_model))
      PE[pos, 2i + 1] = cos(pos / 10000 ** (2i / d_model))

  and adds the table to the input in ``forward``. The table is stored as a
  non-trainable buffer named ``pe`` of shape (max_len, 1, d_model).

  Args:
    d_model: Feature width of the inputs (odd values are supported).
    max_len: Longest sequence length that can be encoded.
  """

  def __init__(self, d_model: int, max_len: int):
    super().__init__()

    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # max_len, 1

    # 10000 ** (-2i / d_model), computed in log space; one entry per even index.
    div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                         (-math.log(10000.0) / d_model))

    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    # For odd d_model there is one fewer odd column than even columns, so
    # trim div_term to match; for even d_model the slice is a no-op.
    pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

    # Stored as (max_len, 1, d_model).
    self.register_buffer('pe', pe.unsqueeze(1))

  def forward(self, x):
    """Add positional encodings to ``x``.

    Args:
      x: Tensor of shape (batch_size, seq_len, d_model).

    Returns:
      Tensor of the same shape as ``x`` with the encoding of position ``t``
      added to ``x[:, t, :]`` for every batch element.

    Raises:
      ValueError: If ``seq_len`` exceeds the ``max_len`` given at
        construction time.
    """
    seq_len = x.size(1)
    max_len = self.pe.size(0)
    if seq_len > max_len:
      raise ValueError(
          f'sequence length {seq_len} exceeds max_len {max_len}'
      )

    # Index the table by the sequence axis (dim 1), not the batch axis, and
    # reshape (seq_len, 1, d_model) -> (1, seq_len, d_model) so it broadcasts
    # over the batch.
    return x + self.pe[:seq_len].transpose(0, 1)

In [15]:
class FeedForward(nn.Module):
  """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

  Args:
    d_model: Width of the input and output features.
    d_ff: Width of the hidden layer.
    dropout: Dropout probability applied to the hidden activations.
  """

  def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
    super().__init__()

    # Expand to the hidden width, then project back to the model width.
    self.linear1 = nn.Linear(in_features=d_model, out_features=d_ff)
    self.linear2 = nn.Linear(in_features=d_ff, out_features=d_model)
    self.dropout = nn.Dropout(p=dropout)

  def forward(self, x):
    """Apply the two-layer MLP to the last dimension of ``x``."""
    hidden = F.relu(self.linear1(x))
    hidden = self.dropout(hidden)
    projected = self.linear2(hidden)
    return projected

In [16]:
class TransformerBlock(nn.Module):
  """One transformer layer: self-attention followed by a feed-forward network.

  Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``,
  i.e. the normalisation is applied after the residual addition.

  Args:
    d_model: Feature width of the layer.
    n_heads: Number of attention heads.
    d_ff: Hidden width of the feed-forward sub-layer.
    dropout: Dropout probability shared by both sub-layers.
  """

  def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
    super().__init__()

    # Sub-layers
    self.attention = MultiHeadAttention(d_model, n_heads, dropout)
    self.feed_forward = FeedForward(d_model, d_ff, dropout)

    # One LayerNorm per residual branch, plus the dropout applied on each branch
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, mask = None):
    """Run the layer on ``x``.

    Args:
      x: Input tensor; it is used as query, key and value of the attention.
      mask: Optional attention mask forwarded unchanged to the attention
        module.

    Returns:
      A tuple ``(x, attention_weights)``: the transformed tensor and the
      attention weights returned by the attention module.
    """
    # Self-attention branch with residual connection
    attended, attention_weights = self.attention(x, x, x, mask)
    after_attention = self.norm1(x + self.dropout(attended))

    # Feed-forward branch with residual connection
    transformed = self.feed_forward(after_attention)
    result = self.norm2(after_attention + self.dropout(transformed))

    return result, attention_weights

In [17]:
class TinyTransformer(nn.Module):
  """Small causal transformer language model.

  Token ids are embedded, scaled by ``sqrt(d_model)``, combined with
  positional encodings and passed through ``n_layers`` transformer blocks
  under a lower-triangular (causal) attention mask. A final LayerNorm and a
  linear head produce per-position logits over the vocabulary.

  Args:
    vocab_size: Number of distinct token ids.
    d_model: Embedding / hidden width.
    n_heads: Attention heads per block.
    n_layers: Number of stacked transformer blocks.
    d_ff: Hidden width of each block's feed-forward network.
    max_len: Maximum sequence length passed to the positional encoding.
    dropout: Dropout probability used throughout the model.
  """

  def __init__(self, vocab_size: int, d_model: int, n_heads: int,
               n_layers: int, d_ff: int, max_len: int, dropout: float = 0.1):

    super().__init__()

    self.d_model = d_model
    self.embedding = nn.Embedding(vocab_size, d_model)
    self.pos_encoding = PositionalEncoding(d_model, max_len)

    self.transformer_blocks = nn.ModuleList([
        TransformerBlock(d_model, n_heads, d_ff, dropout)
        for _ in range(n_layers)
    ])

    self.ln_f = nn.LayerNorm(d_model)
    self.head = nn.Linear(d_model, vocab_size)
    self.dropout = nn.Dropout(dropout)

    # Re-initialise every submodule's weights (see _init_weights).
    self.apply(self._init_weights)

  def _init_weights(self, module):
    """Initialise one submodule; called for each module via ``self.apply``.

    Linear and Embedding weights are drawn from N(0, 0.02); Linear biases
    are zeroed; LayerNorm is reset to weight 1 and bias 0.
    """
    if isinstance(module, nn.Linear):
      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
      if module.bias is not None:
        torch.nn.init.zeros_(module.bias)

    elif isinstance(module, nn.Embedding):
      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    elif isinstance(module, nn.LayerNorm):
      torch.nn.init.zeros_(module.bias)
      torch.nn.init.ones_(module.weight)

  def create_causal_mask(self, size, device=None):
    """Build a lower-triangular attention mask.

    Args:
      size: Sequence length.
      device: Optional device on which to allocate the mask. Defaults to
        PyTorch's default device.

    Returns:
      Float tensor of shape (1, 1, size, size) holding 1 where attention is
      allowed (key position <= query position) and 0 elsewhere.
    """
    mask = torch.tril(torch.ones(size, size, device=device))
    return mask.unsqueeze(0).unsqueeze(0)

  def forward(self, x, targets=None):
    """Compute logits and, optionally, the cross-entropy loss.

    Args:
      x: Long tensor of token ids with shape (batch_size, seq_len).
      targets: Optional long tensor of target token ids with the same
        number of elements as ``x``.

    Returns:
      A tuple ``(logits, loss, attention_weights)``: ``logits`` has shape
      (batch_size, seq_len, vocab_size); ``loss`` is the mean cross-entropy
      against ``targets`` or ``None`` when no targets are given;
      ``attention_weights`` is a list with one entry per transformer block.
    """
    batch_size, seq_len = x.size()

    # Allocate the mask directly on the input's device rather than building
    # it on the CPU and copying it over on every forward pass.
    mask = self.create_causal_mask(seq_len, device=x.device)

    # Token embedding + positional encoding
    x = self.embedding(x) * math.sqrt(self.d_model)
    x = self.pos_encoding(x)
    x = self.dropout(x)

    # Pass through transformer blocks
    attention_weights = []
    for block in self.transformer_blocks:
      x, atten = block(x, mask)
      attention_weights.append(atten)

    x = self.ln_f(x)
    logits = self.head(x)

    loss = None
    if targets is not None:
      # reshape (not view) so non-contiguous targets are accepted too.
      loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.reshape(-1))

    return logits, loss, attention_weights

In [18]:
class SimpleTextDataset(Dataset):
  """Character-level next-token prediction dataset built from one string.

  The vocabulary is the sorted set of characters in ``text``. Sample ``i``
  is the pair ``(x, y)`` where ``x`` encodes ``text[i : i + seq_len]`` and
  ``y`` encodes the same window shifted one character to the right.

  Args:
    text: Source text; every distinct character becomes a vocabulary entry.
    seq_len: Number of characters in each input/target window.
  """

  def __init__(self, text: str, seq_len: int):

    self.chars = sorted(set(text))
    self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}
    self.idx_to_char = {i: ch for i, ch in enumerate(self.chars)}

    # The whole text encoded as a list of integer token ids.
    self.data = [self.char_to_idx[ch] for ch in text]
    self.seq_len = seq_len

  def __len__(self):
    # Each sample needs seq_len + 1 characters (input plus shifted target).
    # Clamp at zero so a text shorter than seq_len gives an empty dataset
    # instead of making len() raise on a negative value.
    return max(0, len(self.data) - self.seq_len)

  def __getitem__(self, idx):
    """Return the ``(x, y)`` long tensors, each of length ``seq_len``.

    Negative indices count from the end, as with a list.

    Raises:
      IndexError: If ``idx`` is outside ``[-len(self), len(self))``. This
        also lets plain iteration over the dataset terminate.
    """
    n = len(self)
    if idx < 0:
      idx = idx + n
    if idx < 0 or idx >= n:
      raise IndexError(f'index out of range for dataset of length {n}')

    stop = idx + self.seq_len
    x = torch.tensor(self.data[idx:stop], dtype=torch.long)
    y = torch.tensor(self.data[idx + 1:stop + 1], dtype=torch.long)
    return x, y

  @property
  def vocab_size(self):
    """Number of distinct characters in the source text."""
    return len(self.chars)