In [None]:
import torch
import torch.nn as nn
class Decoder(nn.Module):
    """Decoder-only Transformer language model.

    Token embeddings plus learned absolute position embeddings feed a stack of
    pre-LayerNorm blocks (causal self-attention, then a 4x-wide GELU MLP, each
    added back through a residual connection), followed by a final LayerNorm and
    a bias-free projection to vocabulary logits.

    Args:
        vocab_size: number of token ids; also the size of the output logits.
        d_model: embedding / hidden width (nn.MultiheadAttention requires it to
            be divisible by ``n_heads``).
        n_heads: attention heads per layer.
        n_layers: number of transformer blocks.
        max_seq_len: longest sequence ``forward`` accepts (length of ``pos_emb``).
        dropout: dropout probability for embeddings, attention weights and MLP output.
    """

    def __init__(self, vocab_size, d_model, n_heads, n_layers, max_seq_len, dropout=0.1):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        # Learned absolute positions; leading dim of 1 broadcasts over the batch.
        self.pos_emb = nn.Parameter(torch.zeros(1, max_seq_len, d_model))
        self.dropout = nn.Dropout(dropout)

        self.layers = nn.ModuleList([
            nn.ModuleDict({
                'ln1': nn.LayerNorm(d_model),
                'attn': nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True),
                'ln2': nn.LayerNorm(d_model),
                'mlp': nn.Sequential(
                    nn.Linear(d_model, 4 * d_model),
                    nn.GELU(),
                    nn.Linear(4 * d_model, d_model),
                    nn.Dropout(dropout),
                )
            }) for _ in range(n_layers)
        ])

        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        # The initializer is `_init_weights`; the previous `self.init_weights`
        # did not exist and raised AttributeError at construction time.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, 0.02) and zero Linear/LayerNorm biases."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)
        if isinstance(module, nn.LayerNorm) and module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, x):
        """Compute logits for every position of a batch of token ids.

        Args:
            x: integer token ids of shape (B, T) with T <= max_seq_len.

        Returns:
            Logits of shape (B, T, vocab_size). Position t attends only to
            positions <= t.

        Raises:
            ValueError: if T exceeds the ``max_seq_len`` given at construction.
        """
        _batch, T = x.size()
        max_len = self.pos_emb.size(1)
        if T > max_len:
            raise ValueError(f"sequence length {T} exceeds max_seq_len {max_len}")

        h = self.token_emb(x) + self.pos_emb[:, :T]
        h = self.dropout(h)

        # Boolean attn_mask: True marks a disallowed (future) key position, so each
        # query sees only itself and earlier tokens. With attn_mask=None the model
        # could read the very tokens it is trained to predict.
        causal_mask = torch.triu(
            torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1
        )

        for layer in self.layers:
            h_norm = layer['ln1'](h)
            attn_output, _ = layer['attn'](
                h_norm, h_norm, h_norm,
                attn_mask=causal_mask,
                need_weights=False,
            )
            h = h + attn_output
            h = h + layer['mlp'](layer['ln2'](h))

        return self.head(self.ln_f(h))

In [None]:
from accelerate import Accelerator
from torch.optim import AdamW
from tqdm import tqdm

def train(model, dataloader, vocab_size, epochs=3, lr=3e-4, weight_decay=0.0):
    """Train a token-level language model with cross-entropy using Accelerate.

    Args:
        model: an ``nn.Module`` *instance* (not the class) mapping ``input_ids``
            to logits whose last dimension is ``vocab_size``.
        dataloader: yields dict batches with ``'input_ids'`` and ``'labels'``.
            NOTE(review): assumes ``labels`` are already aligned for next-token
            prediction (shifted by one) — TODO confirm in the dataset.
            Label value -100 is ignored (CrossEntropyLoss default ``ignore_index``).
        vocab_size: size of the logits' last dimension.
        epochs: number of passes over ``dataloader``.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay (applied to every parameter).

    Returns:
        The trained model, unwrapped from any Accelerate/DDP wrapper.

    Raises:
        TypeError: if a class is passed instead of a model instance.
    """
    if isinstance(model, type):
        raise TypeError(
            f"train() expects a model instance, got the class {model.__name__}; instantiate it first"
        )

    accelerator = Accelerator()
    device = accelerator.device

    # Move before building the optimizer so it references on-device parameters.
    model.to(device)
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
    model.train()
    for epoch in range(epochs):
        pbar = tqdm(dataloader, desc=f"Epoch: {epoch+1}", disable=not accelerator.is_local_main_process)
        for batch in pbar:
            input_ids = batch['input_ids']
            targets = batch['labels']

            outputs = model(input_ids)
            # Flatten (B, T, V) -> (B*T, V); reshape (unlike view) tolerates non-contiguous tensors.
            loss = criterion(outputs.reshape(-1, vocab_size), targets.reshape(-1))

            optimizer.zero_grad()
            accelerator.backward(loss)
            optimizer.step()

            # `postfix` is an attribute, not a method; set_postfix updates the bar.
            pbar.set_postfix(loss=loss.item())

    return accelerator.unwrap_model(model)

In [7]:
import torch
from torch.utils.data import Dataset, DataLoader
from src.notebook.dataloader import JPNDataset

# --- Configuration ---------------------------------------------------------
SEED = 42
VOCAB_SIZE = 1024
BATCH_SIZE = 64
EPOCHS = 3
LR = 3e-4
WEIGHT_DECAY = 0.01
# NOTE(review): placeholder model size — tune for the experiment.
D_MODEL = 256
N_HEADS = 8
N_LAYERS = 6

torch.manual_seed(SEED)

# weights_only=False unpickles arbitrary Python objects (a pickled Dataset here),
# which can execute code on load — only use with files you produced yourself.
dataset = torch.load("../data/dataset_1024.pt", weights_only=False)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Size the position table from the data.
# NOTE(review): assumes every sample is a dict holding a fixed-length
# 'input_ids' sequence — TODO confirm.
max_seq_len = len(dataset[0]["input_ids"])

# train() needs a model *instance*; passing the Decoder class itself fails at model.to(device).
model = Decoder(
    vocab_size=VOCAB_SIZE,
    d_model=D_MODEL,
    n_heads=N_HEADS,
    n_layers=N_LAYERS,
    max_seq_len=max_seq_len,
)
train(model, dataloader, vocab_size=VOCAB_SIZE, epochs=EPOCHS, lr=LR, weight_decay=WEIGHT_DECAY)