In [1]:
import torch
import torch.nn as nn
import math

# RMSNorm is a normalization technique that normalizes the input by dividing by the square root of the variance plus a small number to prevent division by zero
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5): # the number of features/dimensions/embeddings in the input, eps is a small number to prevent division by zero
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size)) # weight is a learnable parameter that scales the input
        self.eps = eps

    def forward(self, x):
        norm = x.pow(2).mean(-1, keepdim=True).sqrt() + self.eps # compute the norm of the input
        return x / norm * self.weight # normalize the input by dividing by the norm and scale it by the weight parameter


# RotaryEmbedding is a technique that rotates the input by a learnable angle
class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.dim = dim

    def forward(self, seq_len, device):
        # Create position embeddings
        t = torch.arange(seq_len, device=device)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)

        # Create rotation matrices [seq_len, dim/2]
        emb = torch.cat([freqs, freqs], dim=-1)

        # [1, seq_len, 1, dim]
        cos = emb.cos().view(1, seq_len, 1, self.dim)
        sin = emb.sin().view(1, seq_len, 1, self.dim)

        return cos, sin

    def rotate_half(self, x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)

    def apply_rotary_pos_emb(self, x, cos, sin):
        # Expand cos/sin to match batch size and heads
        cos = cos.expand(x.shape[0], -1, x.shape[2], -1)
        sin = sin.expand(x.shape[0], -1, x.shape[2], -1)

        return (x * cos) + (self.rotate_half(x) * sin)

# This code is commented as new LlamaMLP to be created as per Mixture of Experts implementation
# class LlamaMLP(nn.Module):
#     def __init__(self, dim, hidden_dim):
#         super().__init__()
#         self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) # create the gate projection layer with the input dimension and the hidden dimension
#         self.up_proj = nn.Linear(dim, hidden_dim, bias=False) # create the up projection layer with the input dimension and the hidden dimension
#         self.down_proj = nn.Linear(hidden_dim, dim, bias=False) # create the down projection layer with the hidden dimension and the output dimension
#         self.act_fn = nn.SiLU() # create the activation function

#     def forward(self, x):
#         gated = self.gate_proj(x) # apply the gate projection to the input
#         hidden = self.up_proj(x) # apply the up projection to the input
#         return self.down_proj(self.act_fn(gated * hidden)) # apply the activation function to the gated and hidden values and then apply the down projection

class LlamaMLP(nn.Module):
    def __init__(self, dim, hidden_dim, num_experts, num_shared_experts, top_k):
        super().__init__()
        self.moe = DeepSeekMoE(
            dim=dim,
            hidden_dim=hidden_dim,
            num_experts=num_experts,
            num_shared_experts=num_shared_experts,
            top_k=top_k
        )
    def forward(self, x):
        return self.moe(x)

class DeepSeekExpertLayer(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim,bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

class DeepSeekMoE(nn.Module):
    def __init__(self, dim, hidden_dim, num_experts, num_shared_experts, top_k):
        super().__init__()
        self.num_experts = num_experts
        self.num_shared_experts = num_shared_experts
        self.num_routed_experts = num_experts - num_shared_experts
        self.top_k = top_k
        self.dim = dim

        # Shared experts
        self.shared_experts = nn.ModuleList([
            DeepSeekExpertLayer(dim, hidden_dim)
            for _ in range(self.num_shared_experts)
        ])

        # Routed Experts
        self.routed_experts = nn.ModuleList([
            DeepSeekExpertLayer(dim, hidden_dim)
            for _ in range(self.num_routed_experts)
        ])

        # Routed Components
        self.router = nn.Linear(dim, self.num_routed_experts, bias=False)
        self.routing_bias = nn.Parameter(torch.zeros(self.num_routed_experts))

    def forward(self, x):
        batch_size, seq_len, dim = x.shape

        shared_output = sum(expert(x) for expert in self.shared_experts)
        if self.num_shared_experts > 1:
            shared_output = shared_output / self.num_shared_experts

        # calculating routing scores
        routing_logits = self.router(x) * self.routing_bias

        # get top-k experts per token
        routing_probs = torch.sigmoid(routing_logits)
        scores, indices = torch.topk(routing_probs, self.top_k, dim=-1)

        #normalize the top k scores
        scores = scores / scores.sum(dim=-1, keepdim=True)

        # Process through selected experts
        combined_output = torch.zeros_like(x)
        for k in range(self.top_k):
            expert_indices = indices[..., k]
            expert_scores = scores[..., k:k+1]

            #process each expert
            for i in range(self.num_routed_experts):
                mask = (expert_indices == i)
                if mask.any():
                    expert_input = x[mask]
                    expert_output = self.routed_experts[i](expert_input)
                    combined_output[mask] += expert_output * expert_scores[mask]

        #combine shared and routed outputs
        final_output = shared_output + combined_output
        return final_output

    def update_bias_terms(self, expert_load):
        target_load = 1.0 / self.num_routed_experts
        load_diff = expert_load - target_load

        update_rate = 0.1 * torch.abs(load_diff)

        self.routing_bias.data -= update_rate * load_diff

class LlamaAttention(nn.Module):
    def __init__(self, dim, num_heads, compress_ratio):
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        self.latent_dim = dim // compress_ratio

        # Decomposed projections for latent attention
        self.q_proj_d = nn.Linear(dim, self.latent_dim, bias=False)  # Down projection for Q
        self.kv_proj_d = nn.Linear(dim, self.latent_dim, bias=False)  # Down projection for K,V

        half_head_dim = self.head_dim // 2
        # Up projections from latent space
        self.q_proj_u = nn.Linear(self.latent_dim, num_heads * half_head_dim, bias=False)
        self.k_proj_u = nn.Linear(self.latent_dim, num_heads * half_head_dim, bias=False)
        self.v_proj_u = nn.Linear(self.latent_dim, dim, bias=False)

        # Rotary components
        self.rotary_emb = LlamaRotaryEmbedding(dim=half_head_dim)

        # Output projection
        self.o_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        half_head_dim = self.head_dim // 2

        # Project to latent space
        q_latent = self.q_proj_d(x)
        kv_latent = self.kv_proj_d(x)

        # Project up from latent space
        q = self.q_proj_u(q_latent)
        k = self.k_proj_u(kv_latent)
        v = self.v_proj_u(kv_latent)

        # Reshape for attention
        q = q.view(batch_size, seq_len, self.num_heads, half_head_dim)
        k = k.view(batch_size, seq_len, self.num_heads, half_head_dim)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim)

        # Apply rotary embeddings
        cos, sin = self.rotary_emb(seq_len, x.device)
        q = self.rotary_emb.apply_rotary_pos_emb(q, cos, sin)
        k = self.rotary_emb.apply_rotary_pos_emb(k, cos, sin)

        # Prepare for attention
        q = q.transpose(1, 2)  # [batch, heads, seq, half_head_dim]
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Compute attention with scaled dot product
        scale = 1 / math.sqrt(half_head_dim)
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale
        attn = torch.softmax(attn, dim=-1)

        # Apply attention to values
        out = torch.matmul(attn, v)

        # Reshape and project output
        out = out.transpose(1, 2).contiguous()  # [batch, seq, heads, head_dim]
        out = out.reshape(batch_size, seq_len, self.dim)

        return self.o_proj(out)


        # previous working code
        # q = self.q_proj(x)
        # k = self.k_proj(x)
        # v = self.v_proj(x)

        # # Split heads
        # q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # [batch_size, num_heads, seq_len, head_dim]
        # k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # # Scaled dot-product attention
        # scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # attention = torch.softmax(scores, dim=-1)
        # context = torch.matmul(attention, v)

        # # Combine heads
        # context = context.transpose(1, 2).reshape(batch_size, seq_len, dim)
        # return self.o_proj(context)

class LlamaDecoderLayer(nn.Module):
    def __init__(self, dim, hidden_dim, num_heads, compress_ratio, num_experts, num_shared_experts, top_k):
        super().__init__()
        self.self_attn = LlamaAttention(dim, num_heads, compress_ratio=3)
        self.mlp = LlamaMLP(
            dim=dim,
            hidden_dim=hidden_dim,
            num_experts=num_experts,
            num_shared_experts=num_shared_experts,
            top_k=top_k)
        self.input_layernorm = LlamaRMSNorm(dim)
        self.post_attention_layernorm = LlamaRMSNorm(dim)

    def forward(self, x):
        residual = x
        x = self.input_layernorm(x)
        x = self.self_attn(x)
        x = x + residual

        residual = x
        x = self.post_attention_layernorm(x)
        x = self.mlp(x)
        x = x + residual
        return x


class LlamaModel(nn.Module):
    def __init__(self, vocab_size, dim, num_layers, hidden_dim, num_heads):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, dim)
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(dim, hidden_dim, num_heads, compress_ratio=3, num_experts=8, num_shared_experts=1, top_k=2) for _ in range(num_layers)
        ])
        self.norm = LlamaRMSNorm(dim)
        #self.rotary_emb = LlamaRotaryEmbedding(dim)

    def forward(self, x):
        x = self.embed_tokens(x)
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)

class LlamaForCausalLM(nn.Module):
    def __init__(self, vocab_size, dim, num_layers, hidden_dim, num_heads):
        super().__init__()
        self.model = LlamaModel(vocab_size, dim, num_layers, hidden_dim, num_heads)
        self.num_heads = num_heads
        # Share weights between embedding and lm_head
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        # Tie weights
        self.lm_head.weight = self.model.embed_tokens.weight

    def forward(self, x):
        x = self.model(x)
        return self.lm_head(x)

def get_model(tokenizer):
    vocab_size = tokenizer.vocab_size  # Use actual tokenizer vocab size
    return LlamaForCausalLM(
        vocab_size=vocab_size,
        dim=576,
        num_layers=30,
        hidden_dim=1536,
        num_heads=9
    )

In [2]:
pip install datasets

Collecting datasets
  Downloading datasets-3.4.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.4.0-py3-none-any.whl (487 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m487.4/487.4 kB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading

In [3]:
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer, get_scheduler
from torch.optim import AdamW
import wandb
import os, sys
import time

wandb.init(project="smollm-training", name="llama-smollm-corpus", mode="offline")


BATCH_SIZE = 4
ACCUMULATION_STEPS = 8

SEQ_LEN = 256
LEARNING_RATE = 1e-4
EPOCHS = 5
WARMUP_STEPS = 1000
GRADIENT_CLIP_VAL = 0.5
CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
DEVICE = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)


def generate_text(
    model, tokenizer, prompt, max_length=50, temperature=0.7, top_k=50, device=DEVICE):
    model.eval()
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids)
            next_token_logits = outputs[:, -1, :] / temperature

            # Apply top-k sampling
            top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
            probs = torch.softmax(top_k_logits, dim=-1)

            # Sample from the filtered distribution
            next_token_idx = torch.multinomial(probs, num_samples=1)
            next_token = top_k_indices[0, next_token_idx[0]]

            if next_token.item() == tokenizer.eos_token_id:
                break

            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    model.train()
    return generated_text


def save_checkpoint(model, optimizer, scheduler, epoch, step, loss, path):
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict() if scheduler else None,
            "loss": loss,
            "step": step,
        },
        path,
    )


def load_checkpoint(path, model, optimizer, scheduler):
    if os.path.exists(path):
        # path = './checkpoints/checkpoint_step_5000.pt'
        # print(f"Loading checkpoint from {path}")
        checkpoint = torch.load(path, weights_only=True)

        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        if scheduler and checkpoint["scheduler_state_dict"]:
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        return checkpoint["epoch"], checkpoint["step"]
    return 0, 0


def count_parameters(model):
    """Count the number of trainable parameters in the model"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
if tokenizer.pad_token is None:
    if tokenizer.eos_token:
        tokenizer.pad_token = tokenizer.eos_token
    else:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        tokenizer.resize_token_embeddings(len(tokenizer))

dataset = load_dataset(
    "HuggingFaceTB/smollm-corpus", "cosmopedia-v2", streaming=True, split="train"
)


def tokenize_function(examples):
    return tokenizer(
        examples["text"], truncation=True, max_length=SEQ_LEN, padding="max_length"
    )


tokenized_dataset = dataset.map(tokenize_function, batched=True)


def collate_fn(batch):
    input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
    attention_mask = torch.tensor(
        [item["attention_mask"] for item in batch], dtype=torch.long
    )
    labels = input_ids.clone()
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


train_loader = DataLoader(
    tokenized_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn
)

# Initialize model, optimizer, and scheduler
model = get_model(tokenizer)
model.to(DEVICE)

# Print model parameters
total_params = count_parameters(model)
print(f"\nModel Statistics:")
print(f"Total Parameters: {total_params:,}")
print(f"Model Size: {total_params * 4 / (1024 * 1024):.2f} MB")  # Assuming float32 (4 bytes)
print(f"Device: {DEVICE}")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Accumulation Steps: {ACCUMULATION_STEPS}")
print(f"Sequence Length: {SEQ_LEN}")
print(f"Learning Rate: {LEARNING_RATE}")
print("-" * 50 + "\n")


optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=LEARNING_RATE,
    total_steps=20000,
    pct_start=0.1,
    anneal_strategy="cos",
    cycle_momentum=False,
)

# Load checkpoint if exists
start_epoch, global_step = load_checkpoint(
    f"{CHECKPOINT_DIR}/latest_checkpoint.pt",
    #f"{CHECKPOINT_DIR}/interrupted_checkpoint.pt",
    model,
    optimizer,
    lr_scheduler
)

# Sample prompts for evaluation
sample_prompts = [
    "Particles in Action. Have you ever imagined being able to see tiny particles that zoom around us at incredible speeds? Welcome to the world of particle physics! ",
    "Developing number sense is a critical aspect of mathematics education that involves helping students understand numbers, their relationships, and operations involving them. ",
    "All parts of the coriander plant are edible - including its leaves, its fruits, its seeds and its roots. However, the fresh leaves and the dried seeds score over the other two, and are the most commonly employed in cooking. ",
    "There are several foods that can help boost your metabolism and promote calorie burning, thanks to their unique nutritional profiles. ",
    "Are you looking for vegan sandwich recipes? We’ve rounded up 21 of our favorite vegan sandwich ideas that you will want to make right now. ",
]

## TODO: The BELOW code needs to inluded
# expert_load = torch.zeros(model.model.layers[0].mlp.moe.num_routed_experts, device=DEVICE)
# for k in range(model.model.layers[0].mlp.moe.top_k):
#     routing_logits = model.model.layers[0].mlp.moe.router(input_ids) + model.model.layers[0].mpl.moe.routing_bias
#     routing_probs = torch.sigmoid(routing_logits)
#     _, indices = torch.topk(routing_probs, model.model.layers[0].mpl.moe.top_k, dim=-1)
#     for i in range(model.model.layers[0].mlp.moe.num_routed_experts):
#         expert_load[i] += (indices[..., k] == i).sum()

# expert_load = expert_load / (input_ids.size(0) * input_ids.size(1) * model.model.layers[0].mlp.moe.top_k)

# model.model.layers[0].mlp.moe.update_bias_terms(expert_load)

## TODO: The ABOVE code need to be include

model.train()
try:
    for epoch in range(start_epoch, EPOCHS):
        print(f"Epoch {epoch + 1}/{EPOCHS}")
        optimizer.zero_grad()  # Zero gradients at start of epoch

        for step, batch in enumerate(train_loader, start=global_step):
            # Move batch to device
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)

            # Calculate expert load and update routing bias every 100 steps
            if step % 100 == 0:
                with torch.no_grad():
                    # Get initial hidden states
                    hidden_states = model.model.embed_tokens(input_ids)

                    # Update expert load for each layer
                    for layer_idx, layer in enumerate(model.model.layers):
                        # Initialize expert load tensor
                        expert_load = torch.zeros(layer.mlp.moe.num_routed_experts, device=DEVICE)

                        # Get routing probabilities
                        routing_logits = layer.mlp.moe.router(hidden_states)
                        routing_probs = torch.sigmoid(routing_logits + layer.mlp.moe.routing_bias)

                        # Get top-k expert indices
                        _, indices = torch.topk(routing_probs, k=layer.mlp.moe.top_k, dim=-1)

                        # Calculate load for each expert
                        for i in range(layer.mlp.moe.num_routed_experts):
                            expert_load[i] = (indices == i).sum().float()

                        # Normalize the expert load
                        total_tokens = input_ids.size(0) * input_ids.size(1)
                        expert_load = expert_load / (total_tokens * layer.mlp.moe.top_k)

                        # Update routing bias
                        layer.mlp.moe.update_bias_terms(expert_load)

                        # Log expert utilization
                        for i, load in enumerate(expert_load):
                            wandb.log({
                                f"layer_{layer_idx}_expert_{i}_load": load.item(),
                                "step": step
                            })

                        # Process hidden states through the layer for next iteration
                        hidden_states = layer.input_layernorm(hidden_states)
                        hidden_states = layer.self_attn(hidden_states)
                        hidden_states = layer.post_attention_layernorm(hidden_states)

            # Forward pass
            outputs = model(input_ids)
            logits = outputs.view(-1, tokenizer.vocab_size)

            # Calculate loss
            loss = torch.nn.functional.cross_entropy(
                logits, labels.view(-1), label_smoothing=0.1
            )

            # Scale loss by accumulation steps
            scaled_loss = loss / ACCUMULATION_STEPS

            # Backward pass
            scaled_loss.backward()

            current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

            # Logging (show unscaled loss for monitoring)
            if step % 10 == 0:
                print(
                    f"Step {step}, Loss: {loss.item():.4f}, "  # Original loss
                    f"Scaled Loss: {scaled_loss.item():.4f}, "  # Scaled loss
                    f"LR: {lr_scheduler.get_last_lr()[0]:.2e}, "
                    f"Accumulation Step: {(step + 1) % ACCUMULATION_STEPS}/{ACCUMULATION_STEPS}, "
                    f"Current Time: {current_time} "
                )
                wandb.log({
                    "loss": loss.item(),  # Log original loss
                    "scaled_loss": scaled_loss.item(),  # Log scaled loss
                    "lr": lr_scheduler.get_last_lr()[0],
                    "step": step,
                    "epoch": epoch,
                })

            # Update weights after accumulation steps
            if (step + 1) % ACCUMULATION_STEPS == 0:
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP_VAL)

                # Optimizer step
                optimizer.step()

                # Update learning rate
                lr_scheduler.step()

                # Zero gradients
                optimizer.zero_grad()

                # print(f"\nCompleted gradient accumulation step {step + 1}")

            # Save checkpoint every 100 steps (actual updates)
            if step % 100 == 0 and step != 0:
                save_checkpoint(
                    model,
                    optimizer,
                    lr_scheduler,
                    epoch,
                    step,
                    loss.item(),  # Save original loss
                    f"{CHECKPOINT_DIR}/latest_checkpoint.pt",
                )

            # Generate sample text every 500 steps
            if step != 0 and step % 500 == 0:
                model.eval()
                print("\n=== Generating Sample Texts ===")

                # Save model state for generation
                generation_state = model.state_dict()

                for temp in [1.0]:  # Added back temperature variation
                    for prompt in sample_prompts:
                        generated = generate_text(
                            model,
                            tokenizer,
                            prompt,
                            temperature=temp,
                            max_length=100,
                        )
                        print(f"\nPrompt: {prompt}")
                        print(f"Temperature: {temp}")
                        print(f"Generated: {generated}")
                        wandb.log({
                            f"generated_text_temp_{temp}_{prompt[:20]}": wandb.Html(generated),
                            "step": step
                        })

                print("\n=== End of Samples ===\n")
                model.train()

        # Save epoch checkpoint
        save_checkpoint(
            model,
            optimizer,
            lr_scheduler,
            epoch,
            step,
            loss.item(),  # Save original loss
            f"{CHECKPOINT_DIR}/checkpoint_epoch_{epoch+1}.pt",
        )

except KeyboardInterrupt:
    print("\nTraining interrupted! Saving checkpoint...")
    save_checkpoint(
        model,
        optimizer,
        lr_scheduler,
        epoch,
        step,
        loss.item(),  # Save original loss
        f"{CHECKPOINT_DIR}/interrupted_checkpoint.pt",
    )

print("Training complete!")
wandb.finish()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/3.91k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/801k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/466k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.10M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/489 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/7.05k [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/104 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/104 [00:00<?, ?it/s]


Model Statistics:
Total Parameters: 688,702,098
Model Size: 2627.19 MB
Device: cuda
Batch Size: 4
Accumulation Steps: 8
Sequence Length: 256
Learning Rate: 0.0001
--------------------------------------------------

Epoch 1/5
Step 0, Loss: 46.9087, Scaled Loss: 5.8636, LR: 4.00e-06, Accumulation Step: 1/8, Current Time: 2025-03-15 17:18:56 


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 16.12 MiB is free. Process 6808 has 14.72 GiB memory in use. Of the allocated memory 12.75 GiB is allocated by PyTorch, and 1.84 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)