In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from datasets import load_dataset
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.model_parallel.initialize import get_model_parallel_world_size
from torch.optim import AdamW
from torch.cuda.amp import GradScaler, autocast
import torch.nn.functional as F
import torch.distributed as dist
import os

from model import Transformer, ModelArgs
from tokenizer import Tokenizer  # Assuming your tokenizer script is named tokenizer.py

# Setup for distributed and model parallel environments
def setup_distributed(world_size=1, rank=0):
    """Initialise torch.distributed and fairscale model parallelism.

    Both steps are idempotent: the process group is created only if none
    exists yet, and the model-parallel groups are created only if fairscale
    reports they are missing. This keeps re-running the cell safe and also
    covers the case where the process group was created elsewhere but model
    parallelism was never set up.

    Args:
        world_size: Total number of participating processes; also passed to
            ``initialize_model_parallel`` as the model-parallel size.
        rank: Rank of the current process.
    """
    # Local import so this block is self-contained; it comes from the same
    # fairscale module the file already imports from.
    from fairscale.nn.model_parallel.initialize import model_parallel_is_initialized

    if not dist.is_initialized():
        # NCCL only works with CUDA devices. The rest of the script falls back
        # to CPU when CUDA is unavailable, so fall back to gloo here as well.
        backend = 'nccl' if torch.cuda.is_available() else 'gloo'
        dist.init_process_group(
            backend=backend,
            init_method='tcp://localhost:23456',
            world_size=world_size,
            rank=rank
        )

    if not model_parallel_is_initialized():
        initialize_model_parallel(world_size)

# Print number of parameters
def print_model_params(model):
    """Print the size of every trainable parameter, then the overall total.

    Parameters with ``requires_grad=False`` are neither printed nor counted.

    Args:
        model: Any object exposing ``named_parameters()`` (e.g. an
            ``nn.Module``).
    """
    trainable_sizes = (
        (label, tensor.numel())
        for label, tensor in model.named_parameters()
        if tensor.requires_grad
    )

    running_total = 0
    for label, size in trainable_sizes:
        running_total += size
        print(f"{label}: {size} params")

    print(f"Total number of parameters: {running_total}")

# Number of worker processes passed as `num_proc` to both `load_dataset` and
# `dataset.map` below.
NUM_PROC = 24
# Samples per DataLoader batch. NOTE: the training loop further down also
# uses this value as its optimizer-step interval.
BATCH_SIZE = 2  # Adjusted batch size
# Every sequence is truncated or padded to exactly this many tokens in
# `collate_batch`.
MAX_SEQ_LEN = 128  # Adjusted sequence length

# Load the dataset: the first 1% of the English Wikipedia train split.
# NOTE(review): trust_remote_code=True lets the dataset's loading script run
# locally — confirm the source is trusted.
dataset = load_dataset("wikipedia", language="en", date="20240401", split='train[:1%]', trust_remote_code=True, num_proc=NUM_PROC)
# Tokenizer comes from the project-local `tokenizer` module and is built from
# this vocabulary file path.
tokenizer_path = 'cl100k_base.tiktoken'
tokenizer = Tokenizer(tokenizer_path)

# Tokenization and data preparation
def tokenize_function(examples):
    """Encode a batch of raw texts into token-id sequences.

    Args:
        examples: Mapping with a ``'text'`` entry holding an iterable of
            strings (the batched form passed by ``dataset.map``).

    Returns:
        ``{'input_ids': [...]}`` with one encoded sequence per input text,
        each encoded with ``bos=True`` and ``eos=True``.
    """
    encoded = []
    for document in examples['text']:
        encoded.append(tokenizer.encode(document, bos=True, eos=True))
    return {'input_ids': encoded}

# Tokenize the whole dataset in batches, using NUM_PROC worker processes.
tokenized_datasets = dataset.map(tokenize_function, batched=True, num_proc=NUM_PROC)
# Expose only `input_ids`, formatted for torch; `collate_batch` relies on
# each item['input_ids'] supporting tensor methods such as `.clone()`.
tokenized_datasets.set_format('torch', columns=['input_ids'])

# Data loader setup
def collate_batch(batch):
    """Collate dataset items into one fixed-width batch of token ids.

    Each item's ``'input_ids'`` is copied, cast to ``torch.long`` and then
    either truncated to ``MAX_SEQ_LEN`` or right-padded with
    ``tokenizer.pad_id`` up to ``MAX_SEQ_LEN``.

    Args:
        batch: Sequence of items, each with an ``'input_ids'`` tensor.

    Returns:
        ``{'input_ids': tensor}`` where the tensor is batch-first.
    """
    pad_value = tokenizer.pad_id

    def _fit_to_window(item):
        token_ids = item['input_ids'].clone().detach().to(torch.long)
        shortfall = MAX_SEQ_LEN - len(token_ids)
        if shortfall < 0:
            # Too long: keep only the leading MAX_SEQ_LEN tokens.
            return token_ids[:MAX_SEQ_LEN]
        # Short or exactly full: right-pad (a shortfall of 0 adds nothing).
        return F.pad(token_ids, (0, shortfall), value=pad_value)

    fixed_width = [_fit_to_window(item) for item in batch]
    stacked = pad_sequence(fixed_width, batch_first=True, padding_value=pad_value)
    return {'input_ids': stacked}

# Shuffled loader over the tokenized dataset; `collate_batch` produces the
# fixed-width `input_ids` batches.
train_dataloader = DataLoader(tokenized_datasets, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)

# Setup distributed training (defaults: world_size=1, rank=0). This runs
# before the model is constructed below.
setup_distributed()

# Model and training setup: a deliberately small configuration.
model_args = ModelArgs(
    vocab_size=tokenizer.get_vocab_size(),  # vocabulary size taken from the tokenizer
    dim=128,  # Adjusted model dimensions
    n_layers=2,  # Adjusted number of layers
    n_heads=4,  # Adjusted number of heads
    ffn_dim_multiplier=2  # Adjusted feed-forward network size
)
model = Transformer(model_args)

# Initialize weights
def init_weights(m):
    """Re-initialise one module in place; intended for ``model.apply``.

    * ``nn.Linear`` / ``nn.Embedding``: weight ~ N(0, 0.02). A Linear bias is
      left untouched.
    * ``nn.LayerNorm``: weight ~ N(1, 0.02) and bias set to 0.
    * Any other module type is ignored.

    Parameters that are ``None`` (for example a ``LayerNorm`` built with
    ``elementwise_affine=False``) are skipped instead of raising.

    Args:
        m: The module visited by ``nn.Module.apply``.
    """
    if isinstance(m, (nn.Linear, nn.Embedding)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, nn.LayerNorm):
        # Affine parameters are optional on LayerNorm, so guard each one.
        if m.weight is not None:
            nn.init.normal_(m.weight, mean=1.0, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

# Apply `init_weights` to every submodule, then report the parameter counts.
model.apply(init_weights)
print_model_params(model)

# Assuming model setup and dataloader are correctly initialized
optimizer = AdamW(model.parameters(), lr=1e-5)  # Adjusted learning rate
# Loss scaler used together with `autocast` in the training loop.
scaler = GradScaler()
# Use the GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Put the model in training mode before the loop starts.
model.train()

# Gradient clipping function
def clip_gradients(model, max_norm=0.5):  # Adjusted max_norm
    """Clip the gradients of ``model`` in place by their total norm.

    Args:
        model: Object exposing ``parameters()``.
        max_norm: Upper bound for the total gradient norm.
    """
    trainable_params = model.parameters()
    nn.utils.clip_grad_norm_(trainable_params, max_norm)

# Training loop
# Number of micro-batches whose gradients are accumulated before each
# optimizer step. The original loop used BATCH_SIZE as its step interval, so
# the same cadence is kept here.
accumulation_steps = BATCH_SIZE
num_batches = len(train_dataloader)

for epoch in range(1):  # Example: one epoch
    # Gradients are cleared only after an optimizer step (see below), so that
    # every micro-batch in an accumulation window contributes to the update.
    optimizer.zero_grad()
    pending_micro_batches = 0  # backward passes since the last optimizer step

    for i, batch in enumerate(train_dataloader):
        input_ids = batch['input_ids'].to(device)
        labels = input_ids.clone()  # next-token targets; shifted below

        loss = None  # stays None when this batch has to be skipped
        with autocast():
            outputs = model(input_ids, start_pos=0)

            # Bound extreme logit values. Note that clamp does NOT remove
            # NaNs; those are detected explicitly below.
            outputs = torch.clamp(outputs, min=-1e9, max=1e9)

            # Position t predicts token t + 1.
            shift_logits = outputs[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Mask out padding from the loss calculation
            mask = shift_labels != tokenizer.pad_id
            masked_logits = shift_logits[mask]
            masked_labels = shift_labels[mask]

            # Labels are integer token ids and cannot hold NaN, so only the
            # logits and the loss are checked.
            if masked_logits.numel() == 0:
                print("Skipping backward as no valid data is present in this batch.")
            elif torch.isnan(masked_logits).any():
                print(f"NaN detected in logits at Batch {i}")
            else:
                loss = F.cross_entropy(masked_logits.view(-1, masked_logits.size(-1)), masked_labels.view(-1))
                if torch.isnan(loss):
                    print(f"NaN detected in loss at Batch {i}")
                    loss = None

        if loss is not None:
            # Backward runs outside the autocast region. Dividing by the
            # window size makes the accumulated gradient an average.
            scaler.scale(loss / accumulation_steps).backward()
            pending_micro_batches += 1

        # Step at the end of each accumulation window and on the last batch,
        # even if the current batch was skipped, as long as gradients exist.
        at_step_boundary = (i + 1) % accumulation_steps == 0 or i == num_batches - 1
        if at_step_boundary and pending_micro_batches > 0:
            # Gradients are still multiplied by the loss scale; unscale them
            # first so max_norm applies to the true gradient norm.
            scaler.unscale_(optimizer)
            clip_gradients(model)  # Gradient clipping
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            pending_micro_batches = 0

        if loss is not None:
            print(f"Epoch {epoch}, Batch {i}, Loss: {loss.item()}")

torch.save(model.state_dict(), 'llm_model.pth')


tok_embeddings.weight: 12865536 params
layers.0.attention.wq.weight: 16384 params
layers.0.attention.wk.weight: 16384 params
layers.0.attention.wv.weight: 16384 params
layers.0.attention.wo.weight: 16384 params
layers.0.feed_forward.w1.weight: 98304 params
layers.0.feed_forward.w2.weight: 98304 params
layers.0.feed_forward.w3.weight: 98304 params
layers.0.attention_norm.weight: 128 params
layers.0.ffn_norm.weight: 128 params
layers.1.attention.wq.weight: 16384 params
layers.1.attention.wk.weight: 16384 params
layers.1.attention.wv.weight: 16384 params
layers.1.attention.wo.weight: 16384 params
layers.1.feed_forward.w1.weight: 98304 params
layers.1.feed_forward.w2.weight: 98304 params
layers.1.feed_forward.w3.weight: 98304 params
layers.1.attention_norm.weight: 128 params
layers.1.ffn_norm.weight: 128 params
norm.weight: 128 params
output.weight: 12865536 params
Total number of parameters: 26452608
NaN detected in logits at Batch 0
NaN detected in logits at Batch 1
NaN detected in logit

KeyboardInterrupt: 