In [1]:
import os
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from datasets import load_dataset
from tqdm import tqdm
from model import Transformer, ModelArgs
from tokenizer import Tokenizer

class WikipediaDataset(Dataset):
    def __init__(self, texts, tokenizer, seq_len):
        self.texts = texts
        self.tokenizer = tokenizer
        self.seq_len = seq_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        tokens = self.tokenizer.encode(text, bos=True, eos=True)
        tokens = tokens[:self.seq_len] + [self.tokenizer.pad_id] * (self.seq_len - len(tokens))
        return torch.tensor(tokens)

def run_layer_step(layer, x, is_forward=True, next_error=None, optimizer=None, scaler=None, loss_fn=None):
    if is_forward:
        with torch.cuda.amp.autocast():
            x = layer(x)
        return x
    else:
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            x = layer(x)
            logits = x.view(-1, x.size(-1))
            loss = loss_fn(logits, next_error.view(-1), ignore_index=tokenizer.pad_id)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        return loss.item()

def main():
    torch.autograd.set_detect_anomaly(True)

    # Load the dataset
    dataset = load_dataset("wikipedia", language="en", date="20240401", split='train[:5%]', trust_remote_code=True)
    texts = dataset['text']

    # Initialize the tokenizer
    tokenizer = Tokenizer(encoding_name='cl100k_base')

    # Prepare the dataset and dataloader
    seq_len = 2048
    wiki_dataset = WikipediaDataset(texts, tokenizer, seq_len)
    dataloader = DataLoader(wiki_dataset, batch_size=1, shuffle=True, num_workers=2)

    # Initialize the model
    model_args = ModelArgs(
        vocab_size=tokenizer.get_vocab_size(),
        dim=512,
        n_layers=6,
        n_heads=8,
        ffn_dim_multiplier=4
    )

    model = Transformer(model_args).cuda()

    # Define optimizer and learning rate scheduler
    optimizer = AdamW(model.parameters(), lr=1e-4)
    num_epochs = 3
    total_steps = len(dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

    # Mixed precision scaler
    scaler = torch.cuda.amp.GradScaler()

    # Training loop with gradient accumulation and mixed precision
    gradient_accumulation_steps = 4
    model.train()

    for epoch in range(num_epochs):
        epoch_loss = 0
        optimizer.zero_grad()

        for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}")):
            batch = batch.cuda()
            inputs = model.embedding(batch)
            inputs = inputs.permute(1, 0, 2)  # (seq_len, batch, dim)

            # Forward pass
            for layer in model.layers:
                inputs = run_layer_step(layer, inputs, is_forward=True)

            logits = model.fc(inputs.permute(1, 0, 2))

            # Backward pass
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), batch.view(-1), ignore_index=tokenizer.pad_id)
            loss = loss / gradient_accumulation_steps
            scaler.scale(loss).backward()

            if (step + 1) % gradient_accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            epoch_loss += loss.item() * gradient_accumulation_steps

        print(f"Epoch {epoch + 1} Loss: {epoch_loss / len(dataloader)}")

if __name__ == "__main__":
    main()


Epoch 1/3:   0%|          | 0/340203 [00:00<?, ?it/s]../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [0,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [1,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [2,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [3,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [6,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [8,0,0] Assertion `t >= 0 && t < n_classes` failed.

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
