In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from datasets import load_dataset
from torch.optim import AdamW
from torch.cuda.amp import GradScaler, autocast
from tokenizer import Tokenizer  # Assuming your tokenizer script is named tokenizer.py

# Load the dataset
NUM_PROC = 24
BATCH_SIZE = 2  # Reduced batch size to handle memory issues
MAX_SEQ_LEN = 128  # Reduced sequence length to handle memory issues

dataset = load_dataset("wikipedia", language="en", date="20240401", split='train[:1%]', num_proc=NUM_PROC, trust_remote_code=True)
tokenizer_path = 'cl100k_base.tiktoken'
tokenizer = Tokenizer(tokenizer_path)

# Tokenization and data preparation
def tokenize_function(examples):
    input_ids = [tokenizer.encode(text, bos=True, eos=True) for text in examples['text']]
    return {'input_ids': input_ids}

tokenized_datasets = dataset.map(tokenize_function, batched=True, num_proc=NUM_PROC)
tokenized_datasets.set_format('torch', columns=['input_ids'])

# Data loader setup
def collate_batch(batch):
    input_ids_list = [item['input_ids'].clone().detach().to(torch.long) for item in batch]
    padded_input_ids = [
        ids[:MAX_SEQ_LEN] if len(ids) > MAX_SEQ_LEN else F.pad(ids, (0, MAX_SEQ_LEN - len(ids)), value=tokenizer.pad_id)
        for ids in input_ids_list
    ]
    return {'input_ids': pad_sequence(padded_input_ids, batch_first=True, padding_value=tokenizer.pad_id)}

train_dataloader = DataLoader(tokenized_datasets, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)

# Model definition
class ModelArgs:
    def __init__(self, vocab_size, dim, n_layers, n_heads, ffn_dim_multiplier):
        self.vocab_size = vocab_size
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.ffn_dim_multiplier = ffn_dim_multiplier

class Transformer(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.embedding = nn.Embedding(args.vocab_size, args.dim)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=args.dim, nhead=args.n_heads, dim_feedforward=int(args.dim * args.ffn_dim_multiplier))
            for _ in range(args.n_layers)
        ])
        self.linear = nn.Linear(args.dim, args.vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x)
        x = self.linear(x)
        return x

# Model arguments
model_args = ModelArgs(
    vocab_size=tokenizer.get_vocab_size(),
    dim=128,  # Reduced dimensions for memory efficiency
    n_layers=2,  # Reduced layers
    n_heads=4,  # Reduced heads
    ffn_dim_multiplier=2  # Reduced feed-forward size
)

# Initialize model
model = Transformer(model_args)

# Initialize weights
def init_weights(m):
    if isinstance(m, (nn.Linear, nn.Embedding)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, nn.LayerNorm):
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.constant_(m.bias, 0)

model.apply(init_weights)

# Setup optimizer and scaler
optimizer = AdamW(model.parameters(), lr=1e-5)  # Reduced learning rate
scaler = GradScaler()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()

# Gradient clipping function
def clip_gradients(model, max_norm=1.0):
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

# Enable CUDA debugging
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'

for epoch in range(1):  # Example: one epoch
    for i, batch in enumerate(train_dataloader):
        input_ids = batch['input_ids'].to(device)
        labels = input_ids.clone()  # Assuming labels are the shifted input_ids

        optimizer.zero_grad()  # Reset gradients
        with autocast():
            outputs = model(input_ids)

            # Clamp logits to avoid NaNs
            outputs = torch.clamp(outputs, min=-1e9, max=1e9)

            shift_logits = outputs[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Debug: Print shapes and some statistics
            print(f"Batch {i}, shift_logits.shape: {shift_logits.shape}, shift_labels.shape: {shift_labels.shape}")
            print(f"shift_logits stats: min={shift_logits.min()}, max={shift_logits.max()}, mean={shift_logits.mean()}")
            print(f"shift_labels stats: min={shift_labels.min()}, max={shift_labels.max()}, mean={shift_labels.mean()}")

            # Mask out padding from the loss calculation
            mask = shift_labels != tokenizer.pad_id
            masked_logits = shift_logits[mask]
            masked_labels = shift_labels[mask]

            # Debug: Check masked logits and labels
            print(f"masked_logits.shape: {masked_logits.shape}, masked_labels.shape: {masked_labels.shape}")
            print(f"masked_logits stats: min={masked_logits.min()}, max={masked_logits.max()}, mean={masked_logits.mean()}")
            print(f"masked_labels stats: min={masked_labels.min()}, max={masked_labels.max()}, mean={masked_labels.mean()}")

            if masked_logits.numel() > 0:
                if torch.isnan(masked_logits).any():
                    print(f"NaN detected in logits at Batch {i}")
                    continue

                if torch.isnan(masked_labels).any():
                    print(f"NaN detected in labels at Batch {i}")
                    continue

                loss = F.cross_entropy(masked_logits.view(-1, masked_logits.size(-1)), masked_labels.view(-1))
                if torch.isnan(loss):
                    print(f"NaN detected in loss at Batch {i}")
                    continue
                
                # Scale loss and perform backward
                scaler.scale(loss).backward()
                clip_gradients(model)  # Gradient clipping
            else:
                # Skip the backward pass if there's no valid data
                print("Skipping backward as no valid data is present in this batch.")
                continue

        # Perform optimization step and clear gradients at defined accumulation steps
        if (i + 1) % BATCH_SIZE == 0 or i == len(train_dataloader) - 1:  # ensure last batch is used
            scaler.step(optimizer)
            scaler.update()

        print(f"Epoch {epoch}, Batch {i}, Loss: {loss.item() if 'loss' in locals() else 'No Loss Computed'}")

torch.save(model.state_dict(), 'llm_model.pth')


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Batch 0, shift_logits.shape: torch.Size([2, 127, 100512]), shift_labels.shape: torch.Size([2, 127])
shift_logits: tensor([[[-2.1561e-02,  3.4155e-01,  2.0312e-01,  ..., -1.8567e-01,
           1.7285e-01,  1.7609e-02],
         [ 1.6138e-01,  5.6915e-02,  2.1277e-01,  ..., -4.0356e-01,
           9.9304e-02,  1.4961e-04],
         [ 2.9922e-02,  1.9971e-01,  1.5552e-01,  ..., -3.2532e-02,
          -2.7490e-01,  2.5439e-01],
         ...,
         [ 2.2131e-01, -1.6388e-02,  9.1309e-02,  ..., -3.2715e-01,
           4.0253e-02,  1.2610e-01],
         [-1.2219e-01,  4.9744e-02,  5.0049e-01,  ..., -1.3269e-01,
           7.0923e-02,  1.5308e-01],
         [-2.5586e-01, -1.2891e-01,  1.8494e-01,  ..., -1.7651e-01,
          -3.0231e-03,  1.4355e-01]],

        [[-1.7548e-02,  3.1543e-01,  2.2375e-01,  ..., -2.1545e-01,
           1.5479e-01,  1.8204e-02],
         [-7.4646e-02,  1.5649e-01,  1.9312e-01,  ...,  4.6704e-01,
           1.3660e-01,  2.8662e-01],
         [-8.2581e-02, -2.9785

../aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [70,0,0], thread: [96,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [70,0,0], thread: [97,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [70,0,0], thread: [98,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [70,0,0], thread: [99,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [70,0,0], thread: [100,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [70,0,0], thread: [101,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [70,0,0],

RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
