In [1]:
import torch
import tiktoken
from src.shraygpt import ShrayGPT
torch.set_num_threads(1) # Prevents deadlocks with DataLoader and multiple workers

tokenizer = tiktoken.get_encoding("r50k_base")

def get_total_param_count(module):
    return sum(p.numel() for p in module.parameters())

d_model = 32*32
n_head = 32
d_head = 32
n_layers = 32
num_experts = 8
num_experts_per_tok = 1
block_size = 8192
batch_size = 1
lr = 1e-5

model = ShrayGPT(
    vocab_size=tokenizer.n_vocab, 
    block_size=block_size, 
    d_model=d_model,
    n_head=n_head, 
    d_head=d_head, 
    n_layers=n_layers, 
    num_experts=num_experts, 
    num_experts_per_tok=num_experts_per_tok
)
model.hparams.learning_rate = lr
model.hparams.aux_loss_weight = 1e-2
# model.compile(backend="inductor", dynamic=True, mode="reduce-overhead")

params = get_total_param_count(model)
print(f"Total parameters: {params/1e9:.2f}B")

Total parameters: 2.67B


In [2]:
from torch.utils.data import IterableDataset, DataLoader
from datasets import load_dataset
import torch.distributed as dist

class IterableTextDataset(IterableDataset):
    def __init__(self, tokenizer, hf_dataset, block_size):
        self.tokenizer = tokenizer
        self.hf_dataset = hf_dataset
        self.block_size = block_size

    def _rank_world(self):
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank(), dist.get_world_size()
        return 0, 1

    def __iter__(self):
        rank, world = self._rank_world()

        # Shard the HF streaming dataset so each rank reads a disjoint slice
        ds = self.hf_dataset
        if hasattr(ds, "shard"):
            ds = ds.shard(num_shards=world, index=rank, contiguous=True)

        buffer = []
        for item in ds:
            if 'text' in item:
                tokenized = self.tokenizer.encode(item['text']) + [self.tokenizer.eot_token]
                buffer.extend(tokenized)
                while len(buffer) >= self.block_size + 1:
                    x = torch.tensor(buffer[:self.block_size], dtype=torch.long)
                    y = torch.tensor(buffer[1:self.block_size+1], dtype=torch.long)
                    yield x, y
                    buffer = buffer[self.block_size:]


full_train_stream = load_dataset('HuggingFaceFW/fineweb-edu', name='sample-350BT', split='train', streaming=True)
num_val_samples = 10000  # Let's reserve 10,000 samples for validation.

val_stream_full = full_train_stream.take(num_val_samples)
train_stream_full = full_train_stream.skip(num_val_samples)

train_dataset_full = IterableTextDataset(tokenizer, train_stream_full, block_size)
val_dataset_full = IterableTextDataset(tokenizer, val_stream_full, block_size)

train_loader = DataLoader(
    train_dataset_full, 
    batch_size=batch_size, 
    num_workers=2,
    prefetch_factor=2,  
    pin_memory=True # Helps speed up data transfer to the GPU
)
val_loader = DataLoader(
    val_dataset_full, 
    batch_size=batch_size, 
    num_workers=2,
    prefetch_factor=2,  
    pin_memory=True
)

Resolving data files:   0%|          | 0/2410 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/472 [00:00<?, ?it/s]

In [3]:
import torch
import evaluate
from datasets import load_dataset
from tqdm import tqdm

print("Loading HellaSwag validation set and accuracy metric...")
hellaswag_val = load_dataset("hellaswag", split="validation")
accuracy_metric = evaluate.load("accuracy")

# For a quick demonstration, let's use a small subset of the validation set.
# A full evaluation would run on the entire set.
subset_size = 100
hellaswag_subset = hellaswag_val.select(range(subset_size))

# 2. Create an evaluation function
def evaluate_on_hellaswag(model, tokenizer, dataset):
    """
    Evaluates a model on the HellaSwag dataset.
    """
    model.eval()
    
    predictions = []
    references = []
    
    # The core idea is to calculate the loss for the context concatenated with each ending.
    # The ending that results in the lowest loss is the model's prediction.
    
    for example in tqdm(dataset, desc="Evaluating HellaSwag"):
        context = example['ctx']
        endings = example['endings']
        correct_label = int(example['label'])
        
        context_tokens = tokenizer.encode(context)
        
        losses = []
        with torch.no_grad():
            for ending in endings:
                # Create the full input by combining context and the current ending
                full_text_tokens = context_tokens + tokenizer.encode(ending)
                
                # Prepare input and target tensors
                x = torch.tensor([full_text_tokens[:-1]], dtype=torch.long, device=model.device)
                y = torch.tensor([full_text_tokens[1:]], dtype=torch.long, device=model.device)
                
                # Get the loss for this specific continuation
                logits, _, aux_loss_ = model(x)
                total_loss, main_loss, aux_loss = model._calculate_loss(logits, y.to(model.device), aux_loss_)
                losses.append(total_loss.item())
        
        # The prediction is the index of the ending with the minimum loss
        prediction = torch.argmin(torch.tensor(losses)).item()
        
        predictions.append(prediction)
        references.append(correct_label)
        
    # 3. Compute the final score
    print("Computing final accuracy...")
    results = accuracy_metric.compute(predictions=predictions, references=references)
    return results

# hellaswag_results = evaluate_on_hellaswag(model, tokenizer, hellaswag_subset)
# hellaswag_results

Loading HellaSwag validation set and accuracy metric...


In [1]:
import lightning as L
from lightning.pytorch.strategies import DDPStrategy
torch._dynamo.config.capture_scalar_outputs = True
torch.set_float32_matmul_precision('medium')
from huggingface_hub import login

login(token='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') 

class GenerateTextCallback(L.Callback):
    """A PyTorch Lightning callback to generate text samples at the end of each validation epoch."""
    def __init__(self, prompts, tokenizer, every_n_steps=100):
        super().__init__()
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.every_n_steps = every_n_steps

    def on_validation_epoch_end(self, trainer, pl_module):
        if trainer.global_step == 0 or trainer.global_step % self.every_n_steps != 0:
            return

        if not trainer.is_global_zero:
            return  # only rank 0 prints/logs text
        pl_module.print(f"\n\n--- Generating text at step {trainer.global_step + 1} ---")
        tb = getattr(trainer.logger, "experiment", None)
        
        for prompt in self.prompts:
            start_tokens = self.tokenizer.encode(prompt)
            context = torch.tensor(start_tokens, dtype=torch.long, device=pl_module.device).unsqueeze(0)
            generated_tokens = pl_module.generate(context, max_new_tokens=100, temperature=0.8, top_k=20)
            generated_text = self.tokenizer.decode(generated_tokens[0].tolist())
            pl_module.print(f"PROMPT: '{prompt}'")
            pl_module.print(f"GENERATED: {generated_text}\n")
            if tb is not None and hasattr(tb, "add_text"):
                tb.add_text(f"samples/prompt_{i}", f"**Prompt:** {prompt}\n\n**Generated:** {text}",
                            global_step=trainer.global_step)
        # optional: flush
        if tb is not None and hasattr(tb, "flush"):
            tb.flush()

class EvaluateHellaSwag(L.Callback):
    """A PyTorch Lightning callback to evaluate the LLM."""
    def __init__(self, every_n_steps=100):
        super().__init__()
        self.every_n_steps = every_n_steps

    def on_validation_epoch_end(self, trainer, pl_module):
        if trainer.global_step == 0 or trainer.global_step % self.every_n_steps != 0:
            return

        # do heavy eval only on rank 0
        if not trainer.is_global_zero:
            return
        pl_module.print(f"\n\n--- Evaluating at step {trainer.global_step} ---")
        
        hellaswag_results = evaluate_on_hellaswag(model, tokenizer, hellaswag_subset)
        acc = hellaswag_results['accuracy']
        pl_module.print(f"\n\n--- Accuracy: {acc} at step {trainer.global_step} ---")
        pl_module.log("hellaswag/accuracy", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        tb = getattr(trainer.logger, "experiment", None)
        if tb is not None and hasattr(tb, "add_scalar"):
            tb.add_scalar("hellaswag/accuracy", acc, global_step=trainer.global_step)
            if hasattr(tb, "flush"):
                tb.flush()


callback = GenerateTextCallback(prompts=["The verdict was", "In a shocking turn of events", "The jury decided to"], 
    tokenizer=tokenizer, every_n_steps=100)
evalcallback = EvaluateHellaSwag(every_n_steps=100)

trainer = L.Trainer(max_steps=20_000, accelerator='auto', devices=8, precision='16-mixed', strategy='auto', 
                    num_sanity_val_steps=0, val_check_interval=100, check_val_every_n_epoch=None, limit_train_batches=100, limit_val_batches=100,
                    callbacks=[callback, L.pytorch.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=3), evalcallback],
                    logger=L.pytorch.loggers.TensorBoardLogger("logs/"), log_every_n_steps=1) 

model.automatic_optimization = False
trainer.fit(model, train_loader, val_loader)

NameError: name 'torch' is not defined

: 