In [4]:
# import and setup
# ------------------------------------------------

import copy
import os
import pickle
import random
import re
import urllib.request
from multiprocessing import Pool, cpu_count

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoTokenizer, AutoModelForCausalLM, logging as hf_logging

import ReasoningChat
from Config import SimpleConfig
from ExactSampleDataset import ExactSampleDataset
from ExactSampleTrainLoop import train_exact

# Suppress warnings from Transformers (only error-level messages are shown).
hf_logging.set_verbosity_error()

import os
# Set before any DataLoader worker processes are started.
# NOTE(review): presumably silences the tokenizers fork/parallelism warning -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [5]:
# Pytorch random seed idiom
# ------------------------------------------------

def set_random_seed(seed: int = 42):
    """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN to deterministic mode.

    Args:
        seed: Value handed to every random number generator.
    """
    # CPU-side generators: stdlib `random`, NumPy, then PyTorch.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # Seed every CUDA device as well when a GPU is available.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels with the autotuner off (may impact performance).
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


# Seed once with the default so the cells below are repeatable.
set_random_seed()

In [6]:
# Configuration and Settings
# ------------------------------------------------
class ExactSampleConfig(SimpleConfig):
    """Hyper-parameters for the exact-sample training run, layered on SimpleConfig."""

    def __init__(self):
        super().__init__()
        # Optimisation settings.
        self.learning_rate = 1e-6
        self.weight_decay = 0.1
        self.num_epochs = 1200
        # NOTE(review): a fractional token budget looks unintended (same value
        # as weight_decay above) -- confirm the intended integer.
        self.max_new_tokens = 0.1
        # Sequence and batching settings.
        self.context_len = 256  # Maximum sequence length
        self.num_batches = 1  # Handed to the DataLoader as its batch_size
        self.prompt_len = 829

    def get_training_layers(self, model):
        """Return every parameter of `model` as a list.

        No layer subset is selected here; restricting training to a slice of
        the parameter list (e.g. the first entries) would be done in this method.
        """
        return [param for param in model.parameters()]

In [7]:
# Logger Stub
# --------------------------
class LogLevel:
    """Numeric log levels understood by the stub logger."""

    # Same numeric value as the standard library's logging.ERROR.
    ERROR = 40


class Logger:
    """Minimal singleton logger stub that only carries a log level."""

    # Lazily created shared instance; populated by get_instance().
    _instance = None

    def __init__(self):
        self.level = LogLevel.ERROR

    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating it on first use.

        The instance is built from `cls` rather than a hard-coded `Logger`,
        so a subclass that declares its own `_instance` slot gets an instance
        of its own type.

        Returns:
            The singleton instance stored on the class.
        """
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

In [8]:
# Instantiate Config and Settings
# ------------------------------------------------

# Build the run configuration; later cells read `config` as a notebook global.
config = ExactSampleConfig()

Padding side: right


In [None]:
# Load Model and Tokenizer
# ------------------------------------------------
# A single config.load call yields the model, its optimizer and the tokenizer.
# NOTE(review): force_new=True presumably bypasses any saved checkpoint -- confirm in SimpleConfig.
model, optimizer, tokenizer = config.load(force_new=True)
# NOTE(review): max_id is not referenced again in the visible cells.
max_id = tokenizer.vocab_size
# Autograd anomaly detection is a debugging aid that adds overhead to every
# backward pass; it stays enabled for the whole training run below.
torch.autograd.set_detect_anomaly(True)
print("Padding side:", tokenizer.padding_side)

In [9]:
# Check config pad text
# ------------------------------------------------

# Bare expression: the notebook renders the configured pad text as the cell output.
config.pad

'[PAD]'

In [10]:
# TensorBoard SummaryWriter
# ------------------------------------------------

# TensorBoard event files are written to <config.model_path>/tensorboard_logs.
tb_log_dir = os.path.join(config.model_path, "tensorboard_logs")
# `writer` is a notebook global; the training cell passes it to train_exact and closes it.
writer = SummaryWriter(log_dir=tb_log_dir)

def collate_fn(batch, max_len=None, pad_id=None):
    """Right-pad a batch of token-id sequences to a common length and stack them.

    Each sample is converted to a tensor on its own, so samples of different
    lengths are accepted (converting the whole ragged batch at once fails).
    Rows are padded to `max_len`, or to the longest sample in the batch when
    that sample is longer than `max_len`, so stacking never sees mismatched
    lengths. Samples are never truncated.

    Args:
        batch: Sequence of 1-D token-id samples (lists, arrays or tensors).
        max_len: Minimum padded length. Defaults to `config.max_context_len`.
        pad_id: Token id used for padding. Defaults to `tokenizer.pad_token_id`.

    Returns:
        A `torch.long` tensor of shape (len(batch), padded_length).

    Raises:
        ValueError: If padding is required but no pad token id is available.
    """
    # Fall back to the notebook globals so DataLoader can keep calling collate_fn(batch).
    if max_len is None:
        max_len = config.max_context_len
    if pad_id is None:
        pad_id = tokenizer.pad_token_id

    rows = [torch.as_tensor(sample, dtype=torch.long) for sample in batch]
    target_len = max([max_len] + [row.shape[0] for row in rows])

    padded = []
    for row in rows:
        missing = target_len - row.shape[0]
        if missing > 0:
            if pad_id is None:
                raise ValueError("collate_fn needs a pad token id to pad short samples")
            filler = torch.full((missing,), pad_id, dtype=torch.long)
            row = torch.cat([row, filler])
        padded.append(row)

    if not padded:
        # torch.stack rejects an empty list; return an empty batch of the right width.
        return torch.empty((0, target_len), dtype=torch.long)
    return torch.stack(padded, dim=0)

# Build the training dataset from the tokenizer and run configuration.
dataset = ExactSampleDataset(tokenizer, config)
print("Dataset length:", len(dataset))


# NOTE(review): `summary` is only referenced by the commented-out call below.
from torchsummary import summary

#summary(model, (2, 256, tokenizer.vocab_size))
#for name, param in model.named_parameters():
    #print(name, param.size())
# Bare expression: displays the configured batch size as the cell output.
config.num_batches

# of tokens in txt: 9844
Dataset length: 1693
Train Loader done


1

In [11]:
# Print model structure
# ------------------------------------------------

# List every named parameter with its shape, one line per parameter
# (the output is long for a large model).
for name, param in model.named_parameters():
    print(name, param.size())

model.embed_tokens.weight torch.Size([49152, 3072])
model.layers.0.self_attn.q_proj.weight torch.Size([3072, 3072])
model.layers.0.self_attn.q_proj.bias torch.Size([3072])
model.layers.0.self_attn.k_proj.weight torch.Size([256, 3072])
model.layers.0.self_attn.k_proj.bias torch.Size([256])
model.layers.0.self_attn.v_proj.weight torch.Size([256, 3072])
model.layers.0.self_attn.v_proj.bias torch.Size([256])
model.layers.0.self_attn.o_proj.weight torch.Size([3072, 3072])
model.layers.0.self_attn.o_proj.bias torch.Size([3072])
model.layers.0.mlp.c_fc.weight torch.Size([12288, 3072])
model.layers.0.mlp.c_fc.bias torch.Size([12288])
model.layers.0.mlp.c_proj.weight torch.Size([3072, 12288])
model.layers.0.mlp.c_proj.bias torch.Size([3072])
model.layers.0.input_layernorm.weight torch.Size([3072])
model.layers.0.input_layernorm.bias torch.Size([3072])
model.layers.0.post_attention_layernorm.weight torch.Size([3072])
model.layers.0.post_attention_layernorm.bias torch.Size([3072])
model.layers.1.

In [12]:
# Log-probability helpers
# Ref: https://blog.gopenai.com/coding-grpo-from-scratch-a-guide-to-distributed-implementation-with-qwen2-5-1-5b-instruct-59b34227edac
# ------------------------------------------------

def selective_log_softmax(logits, input_ids):
    """Return the log-probability assigned to each id in `input_ids`.

    Args:
        logits: Scores whose last dimension is normalised with log-softmax.
        input_ids: Integer indices into that last dimension, one per leading
            position of `logits`.

    Returns:
        Tensor shaped like `input_ids` holding the selected log-probabilities.
    """
    token_index = input_ids.unsqueeze(-1)
    all_log_probs = torch.log_softmax(logits, dim=-1)
    picked = torch.gather(all_log_probs, -1, token_index)
    return picked.squeeze(-1)

def compute_log_probs(model, input_ids, attention_mask, logits_to_keep):
    """Log-probabilities the model assigns to the trailing tokens of `input_ids`.

    Args:
        model: Callable accepting `input_ids` and `attention_mask` keywords and
            returning an object with a `.logits` attribute.
        input_ids: Token ids, indexed as (batch, sequence).
        attention_mask: Mask forwarded to the model unchanged.
        logits_to_keep: Number of trailing positions to score.
            Note: with 0 the `-0:` slice keeps every position rather than none.

    Returns:
        The result of `selective_log_softmax` on the kept logits and ids.
    """
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    # The logits at the final position are dropped before selecting the tail.
    shifted_logits = outputs.logits[:, :-1, :]
    tail_logits = shifted_logits[:, -logits_to_keep:, :]
    tail_ids = input_ids[:, -logits_to_keep:]
    return selective_log_softmax(tail_logits, tail_ids)


In [14]:
# Load the Stack v2 (smol ids) training split
# ------------------------------------------------

from datasets import load_dataset
# Fetches the whole train split; the cell output below lists 64 parquet shards
# of roughly 0.9 GB each.
# NOTE(review): starcoder_dataset is only consumed by the commented-out
# CombinedDataset(...) call in the training cell -- confirm the download is needed.
starcoder_dataset = load_dataset("bigcode/the-stack-v2-train-smol-ids", split="train")

README.md:   0%|          | 0.00/21.0k [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/64 [00:00<?, ?files/s]

train-00000-of-00064.parquet:   0%|          | 0.00/940M [00:00<?, ?B/s]

train-00001-of-00064.parquet:   0%|          | 0.00/918M [00:00<?, ?B/s]

train-00002-of-00064.parquet:   0%|          | 0.00/927M [00:00<?, ?B/s]

train-00003-of-00064.parquet:   0%|          | 0.00/925M [00:00<?, ?B/s]

train-00004-of-00064.parquet:   0%|          | 0.00/927M [00:00<?, ?B/s]

train-00005-of-00064.parquet:   0%|          | 0.00/930M [00:00<?, ?B/s]

train-00006-of-00064.parquet:   0%|          | 0.00/915M [00:00<?, ?B/s]

train-00007-of-00064.parquet:   0%|          | 0.00/914M [00:00<?, ?B/s]

train-00008-of-00064.parquet:   0%|          | 0.00/908M [00:00<?, ?B/s]

train-00009-of-00064.parquet:   0%|          | 0.00/923M [00:00<?, ?B/s]

train-00010-of-00064.parquet:   0%|          | 0.00/994M [00:00<?, ?B/s]

train-00011-of-00064.parquet:   0%|          | 0.00/921M [00:00<?, ?B/s]

train-00012-of-00064.parquet:   0%|          | 0.00/944M [00:00<?, ?B/s]

train-00013-of-00064.parquet:   0%|          | 0.00/920M [00:00<?, ?B/s]

train-00014-of-00064.parquet:   0%|          | 0.00/922M [00:00<?, ?B/s]

train-00015-of-00064.parquet:   0%|          | 0.00/963M [00:00<?, ?B/s]

train-00016-of-00064.parquet:   0%|          | 0.00/916M [00:00<?, ?B/s]

train-00017-of-00064.parquet:   0%|          | 0.00/938M [00:00<?, ?B/s]

train-00018-of-00064.parquet:   0%|          | 0.00/918M [00:00<?, ?B/s]

train-00019-of-00064.parquet:   0%|          | 0.00/926M [00:00<?, ?B/s]

train-00020-of-00064.parquet:   0%|          | 0.00/932M [00:00<?, ?B/s]

train-00021-of-00064.parquet:   0%|          | 0.00/921M [00:00<?, ?B/s]

train-00022-of-00064.parquet:   0%|          | 0.00/926M [00:00<?, ?B/s]

train-00023-of-00064.parquet:   0%|          | 0.00/937M [00:00<?, ?B/s]

train-00024-of-00064.parquet:   0%|          | 0.00/938M [00:00<?, ?B/s]

train-00025-of-00064.parquet:   0%|          | 0.00/917M [00:00<?, ?B/s]

train-00026-of-00064.parquet:   0%|          | 0.00/942M [00:00<?, ?B/s]

train-00027-of-00064.parquet:   0%|          | 0.00/930M [00:00<?, ?B/s]

train-00028-of-00064.parquet:   0%|          | 0.00/913M [00:00<?, ?B/s]

train-00029-of-00064.parquet:   0%|          | 0.00/920M [00:00<?, ?B/s]

train-00030-of-00064.parquet:   0%|          | 0.00/931M [00:00<?, ?B/s]

train-00031-of-00064.parquet:   0%|          | 0.00/914M [00:00<?, ?B/s]

train-00032-of-00064.parquet:   0%|          | 0.00/943M [00:00<?, ?B/s]

train-00033-of-00064.parquet:   0%|          | 0.00/923M [00:00<?, ?B/s]

train-00034-of-00064.parquet:   0%|          | 0.00/920M [00:00<?, ?B/s]

train-00035-of-00064.parquet:   0%|          | 0.00/928M [00:00<?, ?B/s]

train-00036-of-00064.parquet:   0%|          | 0.00/909M [00:00<?, ?B/s]

train-00037-of-00064.parquet:   0%|          | 0.00/919M [00:00<?, ?B/s]

train-00038-of-00064.parquet:   0%|          | 0.00/924M [00:00<?, ?B/s]

train-00039-of-00064.parquet:   0%|          | 0.00/913M [00:00<?, ?B/s]

train-00040-of-00064.parquet:   0%|          | 0.00/926M [00:00<?, ?B/s]

train-00041-of-00064.parquet:   0%|          | 0.00/932M [00:00<?, ?B/s]

train-00042-of-00064.parquet:   0%|          | 0.00/917M [00:00<?, ?B/s]

train-00043-of-00064.parquet:   0%|          | 0.00/910M [00:00<?, ?B/s]

train-00044-of-00064.parquet:   0%|          | 0.00/917M [00:00<?, ?B/s]

train-00045-of-00064.parquet:   0%|          | 0.00/917M [00:00<?, ?B/s]

train-00046-of-00064.parquet:   0%|          | 0.00/925M [00:00<?, ?B/s]

train-00047-of-00064.parquet:   0%|          | 0.00/911M [00:00<?, ?B/s]

train-00048-of-00064.parquet:   0%|          | 0.00/919M [00:00<?, ?B/s]

train-00049-of-00064.parquet:   0%|          | 0.00/930M [00:00<?, ?B/s]

train-00050-of-00064.parquet:   0%|          | 0.00/927M [00:00<?, ?B/s]

train-00051-of-00064.parquet:   0%|          | 0.00/927M [00:00<?, ?B/s]

train-00052-of-00064.parquet:   0%|          | 0.00/911M [00:00<?, ?B/s]

train-00053-of-00064.parquet:   0%|          | 0.00/928M [00:00<?, ?B/s]

train-00054-of-00064.parquet:   0%|          | 0.00/917M [00:00<?, ?B/s]

train-00055-of-00064.parquet:   0%|          | 0.00/929M [00:00<?, ?B/s]

train-00056-of-00064.parquet:   0%|          | 0.00/939M [00:00<?, ?B/s]

train-00057-of-00064.parquet:   0%|          | 0.00/924M [00:00<?, ?B/s]

train-00058-of-00064.parquet:   0%|          | 0.00/919M [00:00<?, ?B/s]

train-00059-of-00064.parquet:   0%|          | 0.00/910M [00:00<?, ?B/s]

train-00060-of-00064.parquet:   0%|          | 0.00/958M [00:00<?, ?B/s]

train-00061-of-00064.parquet:   0%|          | 0.00/934M [00:00<?, ?B/s]

train-00062-of-00064.parquet:   0%|          | 0.00/919M [00:00<?, ?B/s]

train-00063-of-00064.parquet:   0%|          | 0.00/986M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/40138809 [00:00<?, ? examples/s]

Loading dataset shards:   0%|          | 0/185 [00:00<?, ?it/s]

In [10]:
# Train Exact Prompt sample
# ------------------------------------------------

import torch
import os
from transformers import AdamW
import torch.nn.functional as F

class CombinedDataset(Dataset):
    """Pair each sample of `dataset` with a sample of `starcoder_dataset`.

    The combined length is the larger of the two source sizes; whichever
    source is shorter is cycled with modulo indexing.

    Args:
        dataset: Indexable, sized collection providing the first tuple element.
        starcoder_dataset: Indexable, sized collection providing the second.

    Raises:
        ValueError: If either source is empty (modulo indexing would otherwise
            fail with ZeroDivisionError).
    """

    def __init__(self, dataset, starcoder_dataset):
        self.dataset = dataset
        self.starcoder_dataset = starcoder_dataset
        self.dataset_size = len(dataset)
        self.starcoder_size = len(starcoder_dataset)
        if self.dataset_size == 0 or self.starcoder_size == 0:
            raise ValueError("CombinedDataset requires two non-empty datasets")
        self.total_size = max(self.dataset_size, self.starcoder_size)
        # Not used by this class; kept as attributes for existing readers.
        self.multiplier, self.remainder = divmod(self.total_size, self.dataset_size)

    def __len__(self):
        """Return the larger of the two source sizes."""
        return self.total_size

    def __getitem__(self, idx):
        """Return `(dataset_sample, starcoder_sample)` for position `idx`.

        Raises:
            IndexError: If `idx` is outside `[-len(self), len(self))`. Raising
                here also lets plain `for` iteration over the dataset terminate.
        """
        if not -self.total_size <= idx < self.total_size:
            raise IndexError(f"index {idx} out of range for CombinedDataset of size {self.total_size}")
        if idx < 0:
            # Normalise so that ds[-1] is the same pair as ds[len(ds) - 1].
            idx += self.total_size

        dataset_sample = self.dataset[idx % self.dataset_size]
        starcoder_sample = self.starcoder_dataset[idx % self.starcoder_size]
        return dataset_sample, starcoder_sample

# Only the exact-sample dataset is trained on; the CombinedDataset pairing is disabled.
combined_dataset = dataset # CombinedDataset(dataset, starcoder_dataset)
# batch_size comes from config.num_batches; collate_fn pads each batch; shuffling is off.
train_loader = DataLoader(combined_dataset, batch_size=config.num_batches, num_workers=8,collate_fn=collate_fn) #shuffle=True,
print("Train Loader done")

def _next_token_loss(logits, targets, pad_id=None):
    """Mean cross-entropy between next-token logits and their target ids.

    Args:
        logits: Float tensor whose last dimension is the vocabulary and whose
            leading dimensions match `targets`.
        targets: Integer tensor of target token ids.
        pad_id: Token id excluded from the loss. With None only the default
            ignore index (-100) is excluded.

    Returns:
        A scalar loss tensor, or None when every target is excluded (the mean
        would otherwise be 0/0).
    """
    ignore_index = -100 if pad_id is None else pad_id
    if not bool((targets != ignore_index).any()):
        return None
    vocab_size = logits.shape[-1]
    return torch.nn.functional.cross_entropy(
        logits.reshape(-1, vocab_size),
        targets.reshape(-1),
        ignore_index=ignore_index,
    )


def _validation_loss(model, batches, pad_id, max_batches):
    """Average next-token loss over at most `max_batches` batches, without gradients.

    The model is switched to eval mode here; the caller restores train mode.
    Returns None when no batch produced a finite loss.
    """
    model.eval()
    total, counted = 0.0, 0
    with torch.no_grad():
        for idx, qa in enumerate(batches):
            if idx >= max_batches:
                break
            qa = qa.to(model.device)
            cur_in, expected = qa[:, :-1], qa[:, 1:]
            logits = model(input_ids=cur_in, attention_mask=torch.ones_like(cur_in)).logits
            loss = _next_token_loss(logits, expected, pad_id)
            if loss is None or not bool(torch.isfinite(loss)):
                continue
            total += loss.item()
            counted += 1
    return total / counted if counted else None


def train_exact(
        model,           # The model being trained
        optimizer,       # Optimizer for training
        tokenizer,       # Tokenizer; its pad_token_id is excluded from the loss
        config,
        writer,
        chat,                   # Unused; kept for call compatibility
        train_loader,           # Iterable of (batch, seq) token-id tensors
        reward_fn=None,         # Unused; kept for call compatibility
        group_size=16,          # Unused; kept for call compatibility
        learning_rate=1e-6,     # Unused; the optimizer carries its own rate
        validation_batches=10,  # Number of batches to use for validation
        model_path=None,        # Unused; config.save decides where to write
        module_name="reasoning",
        epsilon=0.2,            # Unused; kept for call compatibility
        beta=0.04               # Unused; kept for call compatibility
):
    """Train `model` with next-token cross-entropy on the batches of `train_loader`.

    Each batch is split into inputs (all but the last token) and targets (all
    but the first). The loss is the mean negative log-likelihood of the
    targets, with pad tokens ignored. Gradients are clipped to a norm of 0.1
    before every optimizer step.

    Every 10th epoch (starting with epoch 0) a gradient-free loss is measured
    on the first `validation_batches` batches of `train_loader` (no separate
    validation set is passed in) and logged as "Validation/Loss", after which
    `config.save(epoch, model, optimizer)` is called.

    Args:
        model: Callable returning an object with `.logits`; must expose
            `.device`, `.train()` and `.eval()`.
        optimizer: Optimizer already bound to the parameters to train.
        tokenizer: Object whose `pad_token_id` (if present) marks padding.
        config: Provides `num_epochs`, `get_training_layers(model)` and
            `save(epoch, model, optimizer)`.
        writer: Object with `add_scalar(tag, value, step)`.
        train_loader: Iterable yielding 2-D integer tensors of token ids.
        validation_batches: Upper bound on batches used per validation pass.
        chat, reward_fn, group_size, learning_rate, model_path, module_name,
        epsilon, beta: Accepted for backward compatibility and not used.

    Returns:
        The trained `model` (the same object that was passed in).
    """
    pad_id = getattr(tokenizer, "pad_token_id", None)
    train_layers = config.get_training_layers(model)
    model.train()
    global_step = 0

    for epoch in range(config.num_epochs):
        epoch_loss = 0.0
        trained_batches = 0
        optimizer.zero_grad()

        for qa in train_loader:
            qa = qa.to(model.device)
            cur_in = qa[:, :-1]
            expected = qa[:, 1:]
            cur_mask = torch.ones_like(cur_in)

            logits = model(input_ids=cur_in, attention_mask=cur_mask).logits
            # Negative log-likelihood: minimising it raises the probability of
            # the expected tokens. No logit clamping is needed for stability.
            loss = _next_token_loss(logits, expected, pad_id)
            if loss is None or not bool(torch.isfinite(loss)):
                # Never step on an all-padding batch or a NaN/inf loss.
                print(f"Skipping batch at step {global_step}: no usable loss")
                optimizer.zero_grad()
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(train_layers, max_norm=0.1)
            optimizer.step()
            optimizer.zero_grad()

            batch_loss = loss.item()
            writer.add_scalar("Train/BatchLoss", batch_loss, global_step)
            global_step += 1
            epoch_loss += batch_loss
            trained_batches += 1

        # Average over the batches actually trained on; guards an empty loader.
        avg_epoch_loss = epoch_loss / max(trained_batches, 1)
        writer.add_scalar("Train/EpochLoss", avg_epoch_loss, epoch)
        print(f"Epoch {epoch} average training loss: {avg_epoch_loss:.4f}")

        # --------------------------
        # Validation phase and checkpoint
        # --------------------------
        if epoch % 10 == 0:
            val_loss = _validation_loss(model, train_loader, pad_id, validation_batches)
            if val_loss is not None:
                writer.add_scalar("Validation/Loss", val_loss, epoch)
                print(f"Epoch {epoch} validation loss: {val_loss:.4f}")
            model.train()
            config.save(epoch, model, optimizer)

    return model

# --------------------------
# Training Loop with Reasoning Verification
# --------------------------
# Close the TensorBoard writer even when training raises or is interrupted,
# so the event file handle is always released.
try:
    train_exact(model, optimizer, tokenizer, config, writer, ReasoningChat.chat, train_loader)
finally:
    writer.close()