In [None]:
import matplotlib.pyplot as plt
import torch


import os, sys
sys.path.append("../../../../")

from src.core.module import Module
from src.core.losses import CrossEntropy
from src.core.optim import AdamW, Standard
from src.core.tensor import Tensor
import numpy as np


In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class T(nn.Module):
    def __init__(self, d_model, n_heads, vocab_size, max_seq_len, pad_idx):
        super().__init__()
        self.pad_idx = pad_idx
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pos_embedding = nn.Parameter(torch.randn(1, max_seq_len, d_model))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=0.0,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)

        self.project = nn.Linear(d_model, vocab_size)

    def forward(self, idx):
        # idx: (B, T)
        padding_mask = (idx == self.pad_idx)  # (B, T)
        x = self.embedding(idx) + self.pos_embedding[:, :idx.size(1)]
        x = self.encoder(x, src_key_padding_mask=padding_mask)
        logits = self.project(x)
        return logits


In [20]:
class Net(Module):
    def __init__(self, d_model, n_heads, vocab_size, max_seq_len, pad_idx):
        super().__init__()
        self.e = self.embedding(vocab_size, d_model, max_seq_len, pad_idx, name="Embedding")

        self.head1 = self.transformer(d_model=d_model, n_heads=n_heads)
        self.project = self.linear(d_model, vocab_size, name="project")
    
    def forward(self, idx):
        x, padding_mask = self.e.get_sentence_embedding(idx)
        x = self.head1(x, padding_mask)
        x = self.project(x)
        return x

    def train(self, x, y, epochs, optimizer):
        for epoch in range(epochs):
            y_hat = self.forward(x)
            loss = CrossEntropy(y_hat, y, axis=-1)
    
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            if epoch % 1 == 0:
                print(f"Epoch {epoch}, Loss: {loss.data}")
                

        

In [16]:
def create_dummy_data(batch_size, seq_len, vocab_size):
    src = torch.randint(2, vocab_size - 1, (batch_size, seq_len))
    x = src[:, :-1]
    y = src[:, 1:]
    return x, y

D_MODEL = 16
N_HEADS = 4
VOCAB_SIZE = 64
MAX_SEQ_LEN = 16
PAD_IDX = 0
BATCH_SIZE = 64

model = T(D_MODEL, N_HEADS, VOCAB_SIZE, MAX_SEQ_LEN, PAD_IDX)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

x, y = create_dummy_data(BATCH_SIZE, MAX_SEQ_LEN, VOCAB_SIZE)

for epoch in range(1000):
    logits = model(x)
    loss = loss_fn(logits.view(-1, VOCAB_SIZE), y.reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % 1 == 0:
        print(f"Epoch {epoch} | Loss: {loss.item():.4f}")


Epoch 0 | Loss: 4.3274
Epoch 1 | Loss: 4.2448
Epoch 2 | Loss: 4.1745
Epoch 3 | Loss: 4.1144
Epoch 4 | Loss: 4.0627
Epoch 5 | Loss: 4.0178
Epoch 6 | Loss: 3.9782
Epoch 7 | Loss: 3.9423
Epoch 8 | Loss: 3.9085
Epoch 9 | Loss: 3.8756
Epoch 10 | Loss: 3.8425
Epoch 11 | Loss: 3.8081
Epoch 12 | Loss: 3.7718
Epoch 13 | Loss: 3.7333
Epoch 14 | Loss: 3.6929
Epoch 15 | Loss: 3.6508
Epoch 16 | Loss: 3.6072
Epoch 17 | Loss: 3.5623
Epoch 18 | Loss: 3.5152
Epoch 19 | Loss: 3.4657
Epoch 20 | Loss: 3.4138
Epoch 21 | Loss: 3.3594
Epoch 22 | Loss: 3.3029
Epoch 23 | Loss: 3.2447
Epoch 24 | Loss: 3.1858
Epoch 25 | Loss: 3.1261
Epoch 26 | Loss: 3.0644
Epoch 27 | Loss: 3.0016
Epoch 28 | Loss: 2.9376
Epoch 29 | Loss: 2.8727
Epoch 30 | Loss: 2.8073
Epoch 31 | Loss: 2.7412
Epoch 32 | Loss: 2.6746
Epoch 33 | Loss: 2.6055
Epoch 34 | Loss: 2.5346
Epoch 35 | Loss: 2.4629
Epoch 36 | Loss: 2.3918
Epoch 37 | Loss: 2.3209
Epoch 38 | Loss: 2.2503
Epoch 39 | Loss: 2.1791
Epoch 40 | Loss: 2.1084
Epoch 41 | Loss: 2.0389
Ep

KeyboardInterrupt: 

In [21]:
D_MODEL = 16
N_HEADS = 4
VOCAB_SIZE = 64
MAX_SEQ_LEN = 16
PAD_IDX = 0
BATCH_SIZE = 64

In [25]:
def create_dummy_data():
    src = np.random.randint(2, VOCAB_SIZE - 1, (BATCH_SIZE, MAX_SEQ_LEN))
    x = src[:, :-1]
    y = src[:, -1:]
    return x, y

x, y = create_dummy_data()

In [24]:
net = Net(D_MODEL, N_HEADS, VOCAB_SIZE, MAX_SEQ_LEN, PAD_IDX)
net._build(x.shape)
print(net)
optimizer = Standard(net.parameters(), lr=0.01)


net.train(Tensor(x), Tensor(y), 100, optimizer)

Architecture:
  embedding_0 (embedding):
    Embedding (embedding):
      embedding_1_embedding_1_embed: shape=(64, 64), dtype=float16
      embedding_1_embedding_1_pe: shape=(64, 64), dtype=float16
  transformer_0 (transformer):
    q (linear):
      transformer_1_linear_1_q_weight: shape=(64, 64), dtype=float16
      transformer_1_linear_1_q_bias: shape=(64,), dtype=float16
    k (linear):
      transformer_1_linear_2_k_weight: shape=(64, 64), dtype=float16
      transformer_1_linear_2_k_bias: shape=(64,), dtype=float16
    v (linear):
      transformer_1_linear_3_v_weight: shape=(64, 64), dtype=float16
      transformer_1_linear_3_v_bias: shape=(64,), dtype=float16
    o (linear):
      transformer_1_linear_4_o_weight: shape=(64, 64), dtype=float16
      transformer_1_linear_4_o_bias: shape=(64,), dtype=float16
    proj_up (linear):
      transformer_1_linear_5_proj_up_weight: shape=(64, 256), dtype=float16
      transformer_1_linear_5_proj_up_bias: shape=(256,), dtype=float16
    p

KeyboardInterrupt: 