In [2]:
import os, sys
sys.path.append("../../../")

from src.core.module import Module
from src.core.losses import CrossEntropy
from src.core.optim import AdamW
from src.core.tensor import Tensor
from src.utils.backend import xp

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math


# Synthetic next-token data: 128 sequences of 16 token ids drawn from 1..15.
# `low=1` keeps id 0 out of the data; the model cells below use 0 as PAD_IDX.
# A seeded generator makes the generated data identical after every
# "Restart Kernel -> Run All" instead of changing on each run.
SEED = 0
rng = np.random.default_rng(SEED)
src = rng.integers(low=1, high=16, size=(128, 16))
x = src[:, :-1]  # inputs: every token except the last
y = src[:, 1:]   # targets: the same sequence shifted left by one position

# Wrap the arrays for the custom framework; they are plain data, so gradient
# tracking is switched off.
x_mine = Tensor(x, requires_grad=False)
y_mine = Tensor(y, requires_grad=False)

In [3]:
# Sanity check: inputs are each sequence without its final token (x = src[:, :-1]).
x.shape

(128, 15)

In [4]:
# Sanity check: targets are each sequence without its first token (y = src[:, 1:]).
y.shape

(128, 15)

In [None]:


class Net(Module):
    """Small transformer over token ids, built with the project's `Module` API.

    Pipeline: embedding lookup -> two `self.transformer(...)` layers applied in
    sequence -> linear projection from `d_model` to `vocab_size`.
    """

    def __init__(self, d_model, n_heads, vocab_size, max_seq_len, pad_idx=0):
        """Create the layers.

        Args:
            d_model: Width passed to the embedding, to both transformer layers
                and to the input side of the projection.
            n_heads: Head count passed to each transformer layer.
            vocab_size: Passed to the embedding and used as the projection's
                output width.
            max_seq_len: Forwarded to the embedding layer.
            pad_idx: Forwarded to the embedding layer (default 0).
        """
        super().__init__()

        self.e = self.embedding(vocab_size, d_model, max_seq_len, pad_idx, name="Embedding")

        self.head1 = self.transformer(d_model=d_model, n_heads=n_heads)
        self.head2 = self.transformer(d_model=d_model, n_heads=n_heads)
        self.project = self.linear(d_model, vocab_size, name="project")

    def forward(self, idx):
        """Run the model on `idx` and return the output of the projection layer.

        Args:
            idx: Token ids, handed unchanged to the embedding's
                `get_sentence_embedding`.

        Returns:
            The result of `self.project` applied to the second transformer
            layer's output.
        """
        x, padding_mask = self.e.get_sentence_embedding(idx)
        # NOTE(review): re-wrapping the raw data in a fresh Tensor with
        # requires_grad=False presumably cuts the autograd graph here, which
        # would mean the embedding never receives gradients -- confirm that
        # this is intentional before relying on the embedding being trained.
        x = Tensor(x.data, requires_grad=False)
        x = self.head1(x, padding_mask)
        x = self.head2(x, padding_mask)
        return self.project(x)

    def train(self, x, y, epochs, optimizer, log_every=1):
        """Full-batch training loop.

        Each epoch runs one forward pass on `x`, computes
        `CrossEntropy(y_hat, y, axis=-1)`, back-propagates, steps the optimizer
        and then clears the gradients.

        Args:
            x: Model input, passed to `forward`.
            y: Targets, passed to `CrossEntropy`.
            epochs: Number of iterations to run.
            optimizer: Object exposing `step()` and `zero_grad()`.
            log_every: Print the loss every `log_every` epochs. The default of
                1 prints every epoch (the previous behaviour); 0 or None
                disables printing.
        """
        for epoch in range(epochs):
            y_hat = self.forward(x)
            loss = CrossEntropy(y_hat, y, axis=-1)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            if log_every and epoch % log_every == 0:
                print(f"Epoch {epoch}, Loss: {loss.data}")
                
if __name__ == "__main__":
    # Hyper-parameters for the custom-framework model.
    D_MODEL = 48
    VOCAB_SIZE = 20
    N_HEADS = 2
    MAX_SEQ_LEN = 32
    PAD_IDX = 0

    model = Net(d_model=D_MODEL, n_heads=N_HEADS, vocab_size=VOCAB_SIZE, max_seq_len=MAX_SEQ_LEN, pad_idx=PAD_IDX)
    # Build from the real input shape rather than repeating (128, 15) by hand,
    # so this cell stays correct if the data cell's sizes are changed.
    model._build(x.shape)
    optimizer = AdamW(model.parameters(), lr=0.001, precision=(xp.float32, xp.float32))

    # Full-batch training on the wrapped inputs/targets; prints the loss each epoch.
    model.train(x_mine, y_mine, epochs=1000, optimizer=optimizer)


    
        

Epoch 0, Loss: 4.8375639759591875
Epoch 1, Loss: 4.609916775187642
Epoch 2, Loss: 4.379700126500226
Epoch 3, Loss: 4.170276920331594
Epoch 4, Loss: 3.9866856348949065
Epoch 5, Loss: 3.827142159395912
Epoch 6, Loss: 3.6867193036659405
Epoch 7, Loss: 3.5604511942208434
Epoch 8, Loss: 3.4454579066798696
Epoch 9, Loss: 3.3411474066583384
Epoch 10, Loss: 3.2481156560884
Epoch 11, Loss: 3.1668597302522548
Epoch 12, Loss: 3.097044012626811
Epoch 13, Loss: 3.037444084325245
Epoch 14, Loss: 2.986288960002013
Epoch 15, Loss: 2.941659449756754
Epoch 16, Loss: 2.901819334305836
Epoch 17, Loss: 2.8654842381530488
Epoch 18, Loss: 2.831986067177748
Epoch 19, Loss: 2.801212228327453
Epoch 20, Loss: 2.7732827113654848
Epoch 21, Loss: 2.74813071932352
Epoch 22, Loss: 2.725293788262215
Epoch 23, Loss: 2.70411176819096
Epoch 24, Loss: 2.68412485158979
Epoch 25, Loss: 2.6652321503229888
Epoch 26, Loss: 2.647522689334891
Epoch 27, Loss: 2.631087278606865


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self, d_model, n_heads, vocab_size, max_seq_len, num_layers=1, pad_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pos_embed = nn.Embedding(max_seq_len, d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, 
            nhead=n_heads, 
            dim_feedforward=d_model * 4, 
            batch_first=True,
            bias=False,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.project = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, idx):
        batch_size, seq_len = idx.size()
        pos = torch.arange(seq_len, device=idx.device).unsqueeze(0).expand(batch_size, seq_len)
        
        x = self.embedding(idx) + self.pos_embed(pos)
        padding_mask = (idx == 0)
        x = self.encoder(x, src_key_padding_mask=padding_mask)
        logits = self.project(x)
        return logits

    def train_model(self, x, y, epochs, optimizer, criterion):
        for epoch in range(epochs):
            optimizer.zero_grad()
            logits = self.forward(x)
            loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
            loss.backward()
            optimizer.step()
            if epoch % 1 == 0:
                print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# --- Config ---
D_MODEL = 48
VOCAB_SIZE = 20
N_HEADS = 2
MAX_SEQ_LEN = 32
PAD_IDX = 0
BATCH_SIZE = 128

# --- Model Training ---
model = Net(D_MODEL, N_HEADS, VOCAB_SIZE, MAX_SEQ_LEN, num_layers=2, pad_idx=PAD_IDX)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

x_pt = torch.tensor(x).long()
y_pt = torch.tensor(y).long()

model.train_model(x_pt, y_pt, epochs=1000, optimizer=optimizer, criterion=criterion)




Epoch 0, Loss: 3.1254
Epoch 1, Loss: 3.0547
Epoch 2, Loss: 3.0008
Epoch 3, Loss: 2.9542
Epoch 4, Loss: 2.9115
Epoch 5, Loss: 2.8769
Epoch 6, Loss: 2.8484
Epoch 7, Loss: 2.8175
Epoch 8, Loss: 2.7971
Epoch 9, Loss: 2.7807
Epoch 10, Loss: 2.7687
Epoch 11, Loss: 2.7529
Epoch 12, Loss: 2.7348
Epoch 13, Loss: 2.7284
Epoch 14, Loss: 2.7227
Epoch 15, Loss: 2.7134
Epoch 16, Loss: 2.7118
Epoch 17, Loss: 2.7034
Epoch 18, Loss: 2.6993
Epoch 19, Loss: 2.6959
Epoch 20, Loss: 2.6907
Epoch 21, Loss: 2.6801
Epoch 22, Loss: 2.6674
Epoch 23, Loss: 2.6697
Epoch 24, Loss: 2.6641
Epoch 25, Loss: 2.6567
Epoch 26, Loss: 2.6440
Epoch 27, Loss: 2.6346
Epoch 28, Loss: 2.6254
Epoch 29, Loss: 2.6153
Epoch 30, Loss: 2.6045
Epoch 31, Loss: 2.5874
Epoch 32, Loss: 2.5663
Epoch 33, Loss: 2.5516
Epoch 34, Loss: 2.5404
Epoch 35, Loss: 2.5161
Epoch 36, Loss: 2.5000
Epoch 37, Loss: 2.4775
Epoch 38, Loss: 2.4485
Epoch 39, Loss: 2.4219
Epoch 40, Loss: 2.3925
Epoch 41, Loss: 2.3673
Epoch 42, Loss: 2.3326
Epoch 43, Loss: 2.287

KeyboardInterrupt: 