In [1]:
# Additional modules for docker environment
# ------------------------------------------------

from pathlib import Path
FLAG_FILE = Path("tmp/.3.1.Resonsing_Training") 
if not FLAG_FILE.exists():
    !mkdir tmp
    !pip install datasets torch_optimizer lion_pytorch clang_repl_kernel --break-system-packages
    !pip install --upgrade clang-repl-kernel --break-system-packages
    !pip install torch_tb_profiler
    from ClangReplInterface import ClangReplInterface
    clang_repl = ClangReplInterface()
    clang_repl.kernel.my_shell.del_loop()
    clang_repl.kernel.my_shell.process.kill()
    clang_repl = None
    FLAG_FILE.touch()

In [2]:
# Check CPU load
# ------------------------------------------------

import os

# Read the 1-minute load average from /proc/loadavg
try:
    with open("/proc/loadavg", "r") as f:
        load_avg = float(f.read().split()[0])
except Exception as e:
    print("Error reading /proc/loadavg:", e)
else:
    cores = os.cpu_count() or 1
    load_per_core = load_avg / cores
    print(f"Total 1-min load: {load_avg:.2f}")
    print(f"Number of cores: {cores}")
    print(f"Load per core: {load_per_core:.2f}")


Total 1-min load: 1.21
Number of cores: 32
Load per core: 0.04


In [3]:
# Check GPU load
# ------------------------------------------------

!nvidia-smi

Sun Apr  6 10:13:12 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.124.04             Driver Version: 570.124.04     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A6000               Off |   00000000:09:00.0 Off |                  Off |
| 46%   71C    P0            104W /  300W |       4MiB /  49140MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [4]:
# import and setup
# ------------------------------------------------

import copy
import gc
import inspect
import json
import math
import os
import pickle
import random
import re
import shutil
import signal
import warnings
from datetime import datetime
from enum import Enum

import bitsandbytes as bnb
import numpy as np

# For hyperparameter optimization
import optuna
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
from ClangReplInterface import ClangReplInterface, ObjectPool
from Config import SimpleConfig
from datasets import Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import LambdaLR
from torch.profiler import (
    ProfilerActivity,
    profile,
    record_function,
    tensorboard_trace_handler,
)
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
)
from transformers.optimization import Adafactor

# to clean up log
warnings.filterwarnings("ignore", category=FutureWarning)

# for memory
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# huggingface tokenizer deadlock workaround
os.environ["TOKENIZERS_PARALLELISM"] = "false"



In [5]:
# Pytorch random seed idiom
# ------------------------------------------------

def set_random_seed(seed: int = 42):
    # Set the seed for Python's built-in random module
    random.seed(seed)
    # Set the seed for NumPy
    np.random.seed(seed)
    # Set the seed for PyTorch
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Ensure deterministic behavior in cuDNN (may impact performance)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_random_seed()

In [6]:
# Overall configuration
# ------------------------------------------------



config = SimpleConfig()
max_length = 512
MAX_REWARD = 4.4


class TrainLayers(Enum):
    FULL_LAYER = 1
    TWO_FRONT_LAYER = 2
    ODD_LAYER = 3
    EVEN_LAYER = 4
    SWITCH_PAIR_LAYER = 5


log_content = False
log_step = False
log_memory = False
log_logits = False
log_reward = False
log_tensor = False

checking_range = False
checking_shape = False


class HiperParam:
    def __init__(self):
        self.train_layer = TrainLayers.FULL_LAYER
        self.group_size = 4
        self.batch_size = 1  # 7
        self.category_count_start = 1  # 9

        self.num_epochs = 200
        self.lr = 3.131e-05
        self.kl_lambda = 8
        self.epsilon = 0.26207
        self.num_grpo = 1
        self.save_epochs = 10
        self.warming_up_step = 1
        self.temperature = 0.53715
        self.validation_interval = 3
        self.expected_meean_reward = 2.0

        self.try_old_model_update_in_batch_loop_update_divider = 2
        self.try_old_model_update_in_batch_loop_count = 0
        
    def __str__(self):
        return (
            f"lr: {self.lr}, kl_lambda: {self.kl_lambda}, "
            f"epsilon: {self.epsilon}, num_grpo: {self.num_grpo}, "
            f"warming_up_step: {self.warming_up_step}, "
            f"temperature: {self.temperature}, category_count_start: {self.category_count_start}"
        )
    def essence_str(self):
        lr_str = f"{self.lr:.5g}"
        kl_lambda_str = f"{self.kl_lambda:.5g}"
        epsilon_str = f"{self.epsilon:.5g}"
        temperature_str = f"{self.temperature:.5g}"
        return f"{lr_str}_{kl_lambda_str}_{epsilon_str}_{self.num_grpo}_{temperature_str}"
        
    def choose_generator(self, cur, old, ref):
        return cur;

    def try_old_model_update_in_batch_loop(self, model, old_model, batch_step, batch_len, cur_category_count):
        update_count = min(1, int((batch_len//cur_category_count)//self.try_old_model_update_in_batch_loop_update_divider))
        if (self.try_old_model_update_in_batch_loop_count%update_count == 0) or old_model is None:
            old_model = copy_inference_model(model)
        self.try_old_model_update_in_batch_loop_count += 1
        return old_model
            
    def get_optimizer(self, params, lr):
        self.optimizer = Adafactor(params, lr=lr, relative_step=False, scale_parameter=False)
        #self.optimizer = torch.optim.AdamW(params, lr=lr)
        return self.optimizer
        
    def tensor_summary(self):
        return {
            "lr": self.lr,
            "kl_lambda": self.kl_lambda,
            "epsilon": self.epsilon,
            "num_grpo": self.num_grpo,
            "warming_up_step": self.warming_up_step,
            "temperature": self.temperature,
            "category_count_start": self.category_count_start,
            "optimzier": type(self.optimizer).__name__,
        }



class State:
    def __init__(self, prefix, hparam):
        self.global_step = 0
        self.val_global_step = 0

        self.use_reference_model = True
        self.skip_validation_step = False
        self.is_finding_opt = False
        self.is_profile = False
        self.is_original_struct = False
        
        self.log_prefix = prefix
        self.ref_checkpoint_path = "./saved_models/starcoder2-3b_exact_sample/checkpoint.pt"
        self.test_target_object_file = "manual_data_set/ReasoningTestTarget.json"
        self.learning_name = f"starcoder2-3b_reasoning_{prefix}_{hparam.train_layer.name}_"
        self.now = datetime.now().strftime("%d_%H-%M")
        self.last_checkpoint_path = f"./saved_models/{self.learning_name}/checkpoint.pt"
        self.checkpoint_dir_pre = f"./saved_models/{self.learning_name}/epoch_"
        
        self.model_id = "bigcode/starcoder2-3b"
        self.tokenizer_save_dir = "./saved_models/tokenizer"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # open next url to see profile: http://192.168.1.117:6006/#pytorch_profiler
        self.prof = profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=1),
            on_trace_ready=tensorboard_trace_handler("prof_" + log_dir),
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
            ) if self.is_profile else None
        
    def prof_start(self):
        if self.prof is not None:
            self.prof.start()
            
    def prof_step(self):
        if self.prof is not None:
            self.prof.step()            

    def prof_start(self):
        if self.prof is not None:
            self.prof.stop()    

hparam = HiperParam()
state = State("Adafactor", hparam)


def object_hiper_param(trial):
    global hparam
    # Shortened training for demonstration:
    hparam.num_epochs = 1  # 4  # 2   # or 2–3, to save time during hyperparameter search

    # Sample hyperparameters
    hparam.lr = 3.131e-05  # trial.suggest_float("lr", 1e-6, 1e-3, log=True) # 5e-6
    hparam.kl_lambda = trial.suggest_float("kl_lambda", 17, 32)  # 0.04 from
    hparam.epsilon = 0.26207  # trial.suggest_float("epsilon", 0.01, 0.3) # 0.1
    hparam.num_grpo = 0.5  # 1 #trial.suggest_int("num_grpo", 1, 2, step=1) # 1
    hparam.warming_up_step = 1  # trial.suggest_int("warming_up_step", 1, 1, step=1)
    hparam.temperature = trial.suggest_float("temperature", 0.5, 1.4)  # 1.0
    hparam.category_count_start = 1  # trial.suggest_int("category_count_start", 1, 2, step=1)

    return hparam


def shrink_dataset(reasoning_dataset, val_reasoning_dataset):
    reasoning_dataset = reasoning_dataset
    val_reasoning_dataset = val_reasoning_dataset
    return reasoning_dataset, val_reasoning_dataset

In [7]:
# Logging or Print
# ------------------------------------------------

previous_tensor_info = {}


def print_memory(tag):
    global previous_tensor_info, log_memory
    if not log_memory:
        return
    # Make sure you have a GPU device available.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Print current allocated and reserved memory in MB:
    allocated = torch.cuda.memory_allocated(device) / (1024**2)
    reserved = torch.cuda.memory_reserved(device) / (1024**2)
    print(tag)
    print(f"Memory allocated: {allocated:.2f} MB")
    print(f"Memory reserved: {reserved:.2f} MB")

    if False:
        print(torch.cuda.memory_summary(device=None, abbreviated=False))

    if True:
        # List all CUDA tensors and collect their details in a dictionary.
        cuda_tensors = list_cuda_tensors()
        current_tensor_info = {
            id(tensor): (tensor.shape, tensor.device) for tensor in cuda_tensors
        }

        # Determine new and deleted tensor IDs.
        current_tensor_ids = set(current_tensor_info.keys())
        previous_tensor_ids = set(previous_tensor_info.keys())

        print(
            "Total Tensors:",
            len(current_tensor_ids),
            ", Changes:",
            len(current_tensor_ids) - len(previous_tensor_ids),
        )

        if False:
            # Determine new tensors since the last call.
            new_tensor_ids = current_tensor_ids - previous_tensor_ids
            deleted_tensor_ids = previous_tensor_ids - current_tensor_ids
            if new_tensor_ids:
                print("New CUDA tensors created since the last call:")
                for tid in new_tensor_ids:
                    shape, dev = current_tensor_info[tid]
                    print(f"Tensor id: {tid} | Shape: {shape} | Device: {dev}")

            if deleted_tensor_ids:
                print("Deleted CUDA tensors since the last call:")
                for tid in deleted_tensor_ids:
                    shape, dev = previous_tensor_info[tid]
                    print(f"Tensor id: {tid} | Shape: {shape} | Device: {dev}")
        previous_tensor_info = current_tensor_info.copy()


def cur_memory_ids():
    # List all CUDA tensors and collect their details in a dictionary.
    cuda_tensors = list_cuda_tensors()
    current_tensor_info = {
        id(tensor): (tensor.shape, tensor.device) for tensor in cuda_tensors
    }

    # Determine new and deleted tensor IDs.
    current_ids = set(current_tensor_info.keys())
    return current_ids


def compare_memory_ids(previous_tensor_ids):
    current_tensor_ids = cur_memory_ids()
    new_tensor_ids = current_tensor_ids - previous_tensor_ids
    deleted_tensor_ids = previous_tensor_ids - current_tensor_ids
    print("New tensor:", new_tensor_ids)
    print("Delted tensor:", deleted_tensor_ids)


def list_cuda_tensors():
    cuda_tensors = []
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.is_cuda:
                cuda_tensors.append(obj)
        except Exception:
            pass  # Some objects might not have the attributes we need.
    return cuda_tensors


def print_step(tag, main_step=False):
    global log_memory
    if log_memory:
        print_memory(tag)
    else:
        if log_step or main_step:
            print(tag)


def check_optimizer_duplicates(optimizer):
    seen_ids = set()
    duplicates = []
    for group in optimizer.param_groups:
        for param in group["params"]:
            pid = id(param)
            if pid in seen_ids:
                duplicates.append(param)
            else:
                seen_ids.add(pid)
    return duplicates


def print_logits_ids(tag, logits, ids):
    global log_logits
    if log_logits:
        logits_len = logits.shape[1]
        ids_len = ids.shape[1]
        if True:
            logits_ids = torch.argmax(logits, dim=-1)
            ids_text = [
                tokenizer.decode(ids[i], skip_special_tokens=True)
                for i in range(ids.size(0))
            ]
            logits_text = [
                tokenizer.decode(logits_ids[i], skip_special_tokens=True)
                for i in range(logits_ids.size(0))
            ]

            print(
                "##### ", tag, "( logits_len:", logits_len, ", ids_len:", ids_len, " )"
            )
            print(
                "First five logits_ids:",
                logits_ids[0][:5].tolist(),
                ", First five ids:",
                ids[0][:5].tolist(),
            )
            print("###### logit text:", logits_text[0][:100])
            print("###### ids_text:", ids_text[0][:100])
    print_tensor(logits, name=tag + "(logits)")
    print_tensor(ids, name=tag + "(ids)")


def print_tensor(tensor, name=None):
    if log_tensor:
        if name is None:
            # Try to infer the variable name from the caller's local variables.
            frame = inspect.currentframe().f_back
            # Look for local variables that are the same object as tensor.
            names = [
                var_name
                for var_name, var_val in frame.f_locals.items()
                if var_val is tensor
            ]
            name = names[0] if names else "tensor"
        if not torch.is_floating_point(tensor):
            mean_val = tensor.float().mean().item()
        else:
            mean_val = tensor.mean().item()
        print(
            name,
            tensor.shape,
            "(min=",
            tensor.min().item(),
            ", avg=",
            mean_val,
            ", max=",
            tensor.max().item(),
            ")",
        )

    return tensor


def match_shape(actual, expected):
    if len(actual) != len(expected):
        return False
    return all(e == a or e is None or e == -1 for a, e in zip(actual, expected))


def check_shape(self, expected_shape):
    if checking_shape:
        if not match_shape(self.shape, expected_shape):
            raise ValueError(
                f"Shape mismatch! Got {self.shape}, expected {expected_shape}"
            )
    return self


def check_range(self, min_val, max_val, not_values=None):
    if checking_range:
        # Range check
        in_range = (self >= min_val) & (self <= max_val)

        # Optional exclusion check
        if not_values is not None:
            for v in not_values:
                in_range &= self != v

        if not torch.all(in_range):
            raise ValueError(
                f"Tensor check_range failed: values not in range [{min_val}, {max_val}] or contain excluded {not_values}"
            )

    return self


torch.Tensor.log = print_tensor
torch.Tensor.check_shape = check_shape
torch.Tensor.check_range = check_range


def samping(model, tokenizer, device, epoch, writer, sample_prompt, expected):
    # Include attention_mask in the tokenization
    sample_prompt = f"### Instruction\n\n{sample_prompt}\n\n### Response"
    inputs = tokenizer(sample_prompt, return_tensors="pt", return_attention_mask=True)
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)

    # Pass the attention_mask and explicitly set pad_token_id to eos_token_id for reliable generation
    full_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=20,
        pad_token_id=tokenizer.eos_token_id,
    )
    sample_text = tokenizer.decode(full_ids[0], skip_special_tokens=True)

    sample_text = sample_text.strip()
    if log_content:
        print(f"Sample Output (Epoch {epoch + 1}): {sample_text}")
    if log_content:
        print("Expected:", expected)
    writer.add_text("Sample Output", f"Epoch {epoch + 1}: {sample_text}", epoch)


def write_time_file(folder):
    now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    filename = f"{now}.txt"
    os.makedirs(folder, exist_ok=True)

    filepath = os.path.join(folder, filename)
    with open(filepath, "w") as f:
        f.write("This is a dummy file.\n")

    print(f"Dummy file written: {filepath}")

In [8]:
# Define Tokenization
# ------------------------------------------------

if os.path.exists(state.tokenizer_save_dir):
    tokenizer = AutoTokenizer.from_pretrained(state.tokenizer_save_dir)
    tokenizer.padding_side = 'left'
    print_step("Loaded tokenizer from saved checkpoint.")
else:
    tokenizer = AutoTokenizer.from_pretrained(state.model_id)
    tokenizer.padding_side = 'left'
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

In [9]:
# Reward utility
# ------------------------------------------------

def reward_atag(front, end, response):
    tag_len = len(front + end)
    start = response.find(front + end)
    end = response.find(front + '/' + end)
    reward = 0
    if start != -1: reward += 0.1
    if end != -1: reward += 0.1
    
    if start + tag_len < end:
        if len(response[start + tag_len:end].strip()) > 1:
            reward += 0.1
    return reward


def remove_comments(code: str):
    pattern = re.compile(r'//.*?$|/\*.*?\*/', re.DOTALL | re.MULTILINE)
    return re.sub(pattern, '', code)


def find_all_tag_indexes(text, tag):
    """Return a list of starting indexes where the tag occurs in the text."""
    indexes = []
    start = 0
    while True:
        idx = text.find(tag, start)
        if idx == -1:
            break
        indexes.append(idx)
        start = idx + len(tag)
    return indexes


def get_tag_start_end(idx, starts, ends, tag, full_text):
    start = starts[idx]+len(tag)
    end = ends[idx]
    return full_text[start: end].strip()
    

def reward_correct(clang_repl, full_text, writer, log_group, step):
    # handle only first answer
    reward = 0.0
    test_target_open = find_all_tag_indexes(full_text, "<Test Target>")
    test_target_close = find_all_tag_indexes(full_text, "</Test Target>")
    clang_repl_open = find_all_tag_indexes(full_text, "<Clang-repl Test>")
    clang_repl_close = find_all_tag_indexes(full_text, "</Clang-repl Test>")
    if len(test_target_open) == 0 or len(test_target_close) == 0 or len(clang_repl_open) == 0 or len(clang_repl_close) == 0:
        return reward, '<Test Target> or <Clang-repl Test> not found'
    if len(test_target_open) != len(test_target_close) or len(clang_repl_open) != len(clang_repl_close):
        return reward , '<Test Target> or <Clang-repl Test> pair not match'
    if not all(x < y for x, y in zip(test_target_open, test_target_close)):
        return reward, '<Test Target> not closed properly'
    if not all(x < y for x, y in zip(clang_repl_open, clang_repl_close)):
        return reward, '<Clang-repl Test> not closed properly' 
    target_text = get_tag_start_end(-1, test_target_open, test_target_close, "<Test Target>", full_text)
    target_text = remove_comments(target_text)
    target_text = ">>> "+target_text.replace('\n', '')

    for idx in range(len(clang_repl_open)):
        clang_repl_test = get_tag_start_end(idx, clang_repl_open, clang_repl_close, "<Clang-repl Test>", full_text)
        test_case_with_target = target_text+'\n'+clang_repl_test
        #print(test_case_with_target)
        result, response = clang_repl.run_verify(test_case_with_target)
        reward = 0.0
        if result == 'ok':
            reward = 2.0
        elif result == 'fail':
            reward = 1.0
        elif result == 'error':
            reward = 0.0
        else:
            assert False
        if log_reward and reward < 1.9:
            writer.add_text(f"{log_group}/reward_correct_context", f"Verify: {test_case_with_target}\nResponse: {response}", step)
        return reward, response
    else:
        return reward, ''

def reward(clang_repl, response_full, writer, log_group, step):
    # reference:
    # https://blog.gopenai.com/coding-grpo-from-scratch-a-guide-to-distributed-implementation-with-qwen2-5-1-5b-instruct-59b34227edacabs
    need_more_test_idx = response_full.find('<Need More Test')
    response = response_full if need_more_test_idx == -1 else response_full[:need_more_test_idx]
    score = 0.0
    score += reward_atag("<", "Test Object>", response) # 0.3
    score += reward_atag("<", "Input Data>", response) # 0.3
    score += reward_atag("<", "Expected Output>", response) # 0.3
    score += reward_atag("<", "Clang-repl Test>", response) # 0.3
    score += reward_atag("[", "REASON]", response) * 2 # 0.3*2
    score += reward_atag("[", "ANSWER]", response) * 2 # 0.3*2
    score = score / 10 # 0.24 *10 = 2.4
    if score >= 0.24:
        correct_reward, response = reward_correct(clang_repl, response, writer, log_group, step)
    else:
        correct_reward = 0.0
        response = "not formatted"
    reward = score + correct_reward
    writer.add_scalar(f"{log_group}/format_correct_sample", score, step)
    writer.add_scalar(f"{log_group}/reward_correct_sample", correct_reward, step)
    return reward, response

class RewardWorkPool(ObjectPool):     
    def __init__(self, group_size):
        super().__init__(ClangReplInterface, reward, group_size)

    def reward(self, codes, writer=None, log_group=None, global_step=None):
        global hparam
        args = [ [code, writer, log_group, global_step*hparam.group_size + idx] for idx, code in enumerate(codes)]
        self.start_tasks(args)
        
    def take_result(self):
        return self.get_results()




In [10]:
# Dataset and dataloader
# ------------------------------------------------
def load_qa_dataset(json_file):
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    train_examples = []
    for item in data:
        q = item.get("Q", "")
        a = item.get("A", "")
        content = f"### Instruction\n\n{q}\n### Response\n\n{a}\n"
        train_examples.append({"content": content + "<|endoftext|>"})
    return train_examples


def load_sample_dataset(pk_file):
    with open(config.dataset_file, "rb") as f:
        global_samples = pickle.load(f)
        sample_dataset = []
        for sample in global_samples:
            sample_dataset.append({"content": sample + "<|endoftext|>"})
        return sample_dataset

def get_test_target_content(full_text):
    test_target_open = find_all_tag_indexes(full_text, "<Test Target>")
    test_target_close = find_all_tag_indexes(full_text, "</Test Target>")
    target_text = get_tag_start_end(-1, test_target_open, test_target_close, "<Test Target>", full_text)
    return target_text

def load_reasoning_dataset(test_target_object_file, seltected_categories):
    with open(test_target_object_file, 'r', encoding='utf-8') as file:
        data = json.load(file)
        train = []
        val = []
        categories = set()
        data_dic = {}
        for item in data:
            categories.add(item['category'])
        for cat in categories:
            data_dic[cat] = []
        for item in data:
            data_dic[item['category']].append(item['content'])

        for cat in seltected_categories:
            for idx, item in enumerate(data_dic[cat]):
                content = f"### Instruction\n\nn<Test Target>\n{get_test_target_content(item)}\n</Test Target>\nWrtie a Clang-repl Test\n### Response\n"
                if idx >=14:
                    val.append({"content":content})
                else:
                    train.append({"content":content})

        return train, val


def get_all_categories(state):
    with open(state.test_target_object_file, 'r', encoding='utf-8') as file:
        data = json.load(file)
        train = []
        val = []
        categories = set()
        category_list = []
        data_dic = {}
        for item in data:
            # reserve order
            if not item['category'] in categories:
                category_list.append(item['category'])
                categories.add(item['category'])
    return category_list

        
def get_dataloader(categories):
    global state, tokenizer
    reasoning_dataset, val_reasoning_dataset = load_reasoning_dataset(state.test_target_object_file, categories)
    reasoning_dataset, val_reasoning_dataset = shrink_dataset( reasoning_dataset, val_reasoning_dataset)
    
    # Create a Hugging Face Dataset from the list
    train_dataset = Dataset.from_list(reasoning_dataset)
    val_train_dataset = Dataset.from_list(val_reasoning_dataset)

    def tokenize_function(examples):
        global max_length
        return tokenizer(
            examples["content"],
            truncation=True,
            max_length=max_length,
            #padding="max_length"
        )
    
    print_step("eos: "+ str(tokenizer.eos_token) + str(tokenizer.eos_token_id))
    
    tokenized_dataset = train_dataset.map(tokenize_function, batched=True)
    tokenized_dataset = tokenized_dataset.remove_columns(["content"])
    tokenized_dataset.set_format("torch")
    
    val_tokenized_dataset = val_train_dataset.map(tokenize_function, batched=True)
    val_tokenized_dataset = val_tokenized_dataset.remove_columns(["content"])
    val_tokenized_dataset.set_format("torch")
    
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    # DataLoader
    global hparam
    dataloader = DataLoader(
        tokenized_dataset,
        batch_size=hparam.batch_size,
        shuffle=True,
        collate_fn=data_collator
    )
    val_dataloader = DataLoader(
        val_tokenized_dataset,
        batch_size=hparam.batch_size,
        shuffle=False,
        collate_fn=data_collator
    )
    return dataloader, val_dataloader

def get_train_dataloader(categories):
    train, _ = get_dataloader(categories)
    return train


def get_val_dataloader(categories):
    _, val = get_dataloader(categories)
    return val

In [11]:
# Optimization
# ------------------------------------------------
def write_weight_state(model, writer, step, log_group):
    for idx, (name, param) in enumerate(model.named_parameters()):
        if param.requires_grad:
            weight_mean = param.data.mean().item()
            weight_std = param.data.std().item()

            writer.add_scalar(f"{log_group}/{idx}_{name}/mean", weight_mean, step)
            writer.add_scalar(f"{log_group}/{idx}_{name}/std", weight_std, step)

def change_grad(model, layer_start, layer_end, multiple=0.01):
    for idx, val in enumerate(model.named_parameters()):
        name, param = val
        if param.grad is None:
            continue
        if layer_start <= idx < layer_end:
                param.grad.mul_(multiple)
            
class TrainLayerUpdater:
    def __init__(self, model, train_layer):
        self.model = model.model
        self.train_layer = train_layer

    def get_layer_params(self):
        if self.train_layer == TrainLayers.FULL_LAYER:
            params = list(self.model.parameters())
        elif self.train_layer == TrainLayers.TWO_FRONT_LAYER:
            return [p for layer in self.model.layers[:2] for p in layer.parameters()]
        elif self.train_layer == TrainLayers.ODD_LAYER:
            params = [p for idx, layer in enumerate(self.model.layers) if idx % 2 == 1 for p in layer.parameters()]
        elif self.train_layer == TrainLayers.EVEN_LAYER:
            params = [p for idx, layer in enumerate(self.model.layers) if idx % 2 == 0 for p in layer.parameters()]
        elif self.train_layer == TrainLayers.SWITCH_PAIR_LAYER:
            idx = self.config.current_layer_index
            params = []
            if idx < len(self.model.layers):
                params.extend(list(self.model.layers[idx].parameters()))
            if (idx + 1) < len(self.model.layers):
                params.extend(list(self.model.layers[idx + 1].parameters()))
            # Cycle the current_layer_index for the next update
            self.config.current_layer_index = (idx + 2) % len(self.model.layers) 
        else:
            raise ValueError("Invalid train_layer configuration")

        # add first two layer whatever
        #params.extend([p for layer in self.model.layers[:2] for p in layer.parameters()]) # params
        # remove first two layer
        params = [x for x in params if x not in self.model.layers[:2]] # 
        return params

    def update_optimizer_and_requires_grad(self, optimizer):
        # Get the new set of parameters for training.
        new_params = self.get_layer_params()
        new_params_set = set(new_params)
        
        # Update requires_grad flags for all model parameters.
        for param in self.model.parameters():
            param.requires_grad = param in new_params_set

        # Update the optimizer parameter group (assuming a single group).
        optimizer.param_groups[0]['params'] = list(new_params)

In [12]:
# Tensor utility
# ------------------------------------------------

def pad_to_match(tensor_a, tensor_b, padding_value=0):
    # Determine the current sequence lengths
    seq_len_a = tensor_a.size(1)
    seq_len_b = tensor_b.size(1)

    if seq_len_a > seq_len_b:
        max_seq_len = max(seq_len_a, seq_len_b)
    
        # Define padding function
        def pad_tensor(tensor, target_length):
            pad_length = target_length - tensor.size(1)
            if pad_length > 0:
                padding = (0, 0) * (tensor.dim() - 2) + (0, pad_length)
                tensor = F.pad(tensor, padding, value=padding_value)
            return tensor
    
        # Pad both tensors to the maximum sequence length
        tensor_a_padded = pad_tensor(tensor_a, max_seq_len)
        tensor_b_padded = pad_tensor(tensor_b, max_seq_len)
    else:
        tensor_b_padded = tensor_b[:, :seq_len_a]
        tensor_a_padded = tensor_a

    return tensor_a_padded, tensor_b_padded

def selective_log_softmax(logits, input_ids, tokenizer):
    # Ensure input_ids are on the same device as logits
    if input_ids.device != logits.device:
        input_ids = input_ids.to(logits.device)

    log_probs = nn.functional.log_softmax(logits, dim=-1)
    if input_ids.size(1) > log_probs.size(1):
        input_ids = input_ids[:, :log_probs.size(1)]

    # Gather log probabilities corresponding to input_ids
    selected_log_probs = log_probs.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)

    if log_content:
        input_text = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
        print("Input Texts:")
        for text in input_text:
            print(text)
        logits_ids = logits.argmax(dim=-1)
        logit_text = tokenizer.batch_decode(logits_ids, skip_special_tokens=True)
        print("\nLogit Texts:")
        for text in logit_text:
            print(text)

    return selected_log_probs

def add_front_transformer_block(self, copy_weights: bool = True):
    # Retrieve the current first transformer block.
    layer_index = 1
    original_first_block = self.model.layers[layer_index]

    # Create a new block.
    new_block = copy.deepcopy(original_first_block) if copy_weights else type(original_first_block)()

    self.model.layers.insert(layer_index, new_block)

    self.config.num_hidden_layers += 1

def cut_tensors_by_min(a: torch.Tensor, b: torch.Tensor, dim: int):
    assert a.dim() > dim and b.dim() > dim, "Specified dim exceeds tensor rank"

    min_length = min(a.size(dim), b.size(dim))
    a_cut = torch.narrow(a, dim, 0, min_length)
    b_cut = torch.narrow(b, dim, 0, min_length)
    return a_cut, b_cut

def cut_ids_on_eos_tensor(full_ids, eos_token_id):
    processed_ids = []
    for seq in full_ids:
        eos_positions = (seq == eos_token_id).nonzero(as_tuple=True)[0]
        if eos_positions.numel() > 0:
            first_eos_index = eos_positions[0].item()
            processed_ids.append(seq[:first_eos_index])
        else:
            processed_ids.append(seq)
    return processed_ids
    
def cut_ids_on_eos(generated_ids, eos_token_id):
    processed_ids = []
    for seq in generated_ids:
        if eos_token_id in seq:
            # Truncate the sequence at the first occurrence of the EOS token
            first_eos_index = seq.index(eos_token_id)
            processed_ids.append(seq[:first_eos_index])
        else:
            processed_ids.append(seq)
    return processed_ids


def shift_ids_with_logits(ids, shift_logits):
    shift_ids = torch.cat([ids[:, 1:], torch.argmax(shift_logits[:, -1, :], dim=-1).unsqueeze(1)], dim=1)
    return shift_ids

In [13]:
# Train scheduling
# ------------------------------------------------


def copy_inference_model(model):
    old_model = copy.deepcopy(model).half()
    old_model.eval()
    for param in old_model.parameters():
        param.requires_grad = False
    return old_model


def try_to_update_category_count_and_ref_model(
    model,
    ref_model,
    hparam,
    state,
    mean_reward,
    cur_category_count,
    category_size
):
    if mean_reward > hparam.expected_meean_reward:
        print("Reference model updated")
        ref_model = copy_inference_model(model)
        if cur_category_count < category_size:
            if not state.is_finding_opt:
                cur_category_count += 1
    elif ref_model is None:
        ref_model = copy_inference_model(model)

    return cur_category_count, ref_model


_switch_pair_layer_counter = 0


def train_and_evaluate(
        model,
        hparam,
        state,
        optimizer,
        tokenizer,
        scaler,
        scheduler,
        writer):
    reward_work = RewardWorkPool(hparam.group_size * hparam.batch_size)
    ref_model = None

    # --- Generate sample output text after each epoch ---
    model.eval()  # Set to eval mode for generation.
    with torch.no_grad():
        samping(
            model,
            tokenizer,
            state.device,
            0,
            writer,
            "In Custom Clang-repl, What is the prompt in Custom Clang-repl?",
            "```\n>>> (prompt)\n```",
        )
        samping(
            model,
            tokenizer,
            state.device,
            0,
            writer,
            "In Custom Clang-repl, Do we allow multiline comments or backslash-extended lines in Custom Clang-repl Test?",
            "Custom Clang-repl takes only one line input.",
        )
    model.train()  # Switch back to training mode.

    mean_reward = 0
    categories = get_all_categories(state)
    category_size = len(categories)
    state.cur_category_count = hparam.category_count_start
    last_category_count = state.cur_category_count
    switch_pair_layer = 0

    print("Total categories:", category_size, categories)

    dataloader = get_train_dataloader(categories[:state.cur_category_count])
    print("Data counts:", len(dataloader))
    val_dataloader = get_val_dataloader(categories)

    for epoch in range(state.start_epoch, hparam.num_epochs):
        running_loss = 0.0
        print_step(f"Epoch {epoch+1}/{hparam.num_epochs} - Validation", main_step=True)

        if hparam.train_layer == TrainLayers.SWITCH_PAIR_LAYER:
            global _switch_pair_layer_counter
            _switch_pair_layer_counter += 1
            if _switch_pair_layer_counter % 4 == 0:
                state.cur_category_count, ref_model = try_to_update_category_count_and_ref_model(
                        model,
                        ref_model,
                        hparam,
                        state,
                        mean_reward,
                        state.cur_category_count,
                        category_size
                    )
        else:
            state.cur_category_count, ref_model = try_to_update_category_count_and_ref_model(
                model,
                ref_model,
                hparam,
                state,
                mean_reward,
                state.cur_category_count,
                category_size
            )

        print_memory(20)
        old_model = None

        if not state.skip_validation_step and (epoch % hparam.validation_interval) == 0:
            with torch.no_grad():
                print("Validation Start ....")
                # Run validation (no parameter updates).
                val_mean_loss, val_mean_reward, state.val_global_step = run(model, old_model, ref_model,
                    val_dataloader, optimizer, tokenizer,
                    hparam, state,
                    scaler, writer, reward_work, 
                    log_group="Validation",
                    scheduler=None,
                    global_step=state.val_global_step, num_grpo=1, group_size=1, is_validation=True)
                writer.add_scalar("Val_Epoch/mean_reward", val_mean_reward, epoch // hparam.validation_interval)
                writer.add_scalar("Val_Epoch/mean_loss", val_mean_loss, epoch // hparam.validation_interval)
                print("Validation End ....")

        print("Training Start ....")
        # Loop over gradient groups for training.
        print_step(
            f"Epoch {epoch+1}/{hparam.num_epochs} - category include ={categories[state.cur_category_count-1]}",
            main_step=True,
        )

        if last_category_count != state.cur_category_count:
            last_category_count = state.cur_category_count
            dataloader = get_train_dataloader(categories[:state.cur_category_count])
            print("Data counts:", len(dataloader))

        mean_loss, mean_reward, state.global_step = run(model, old_model, ref_model,
            dataloader, optimizer, tokenizer,
            hparam, state,
            scaler, writer, reward_work, 
            log_group="Training",
            scheduler=scheduler,
            global_step=state.global_step, num_grpo=hparam.num_grpo, group_size=hparam.group_size, is_validation=False
        )


        old_model = None
        print_step("7. End Epoch")
        writer.add_scalar("Epoch/mean_reward", mean_reward, epoch + 1)
        writer.add_scalar("Epoch/mean_loss", mean_loss, epoch + 1)
        print("Training End ....")

        # Save latest checkpoint.
        checkpoint = {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch + 1,
            "global_step": state.global_step,
            "val_global_step": state.val_global_step,
        }
        os.makedirs(os.path.dirname(state.last_checkpoint_path), exist_ok=True)
        torch.save(checkpoint, state.last_checkpoint_path)

        # Optionally save checkpoint on specific epochs.
        if epoch % state.save_epochs == 0:
            checkpoint_dir = state.checkpoint_dir_pre + str(epoch + 1)
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.pt")
            torch.save(checkpoint, checkpoint_path)
            print_step(f"Checkpoint saved at epoch {epoch + 1} to {checkpoint_path}")

            # Save tokenizer once if not already saved.
            tokenizer_save_dir = "./saved_models/tokenizer"
            if not os.path.exists(tokenizer_save_dir):
                os.makedirs(tokenizer_save_dir, exist_ok=True)
                tokenizer.save_pretrained(tokenizer_save_dir)
                print_step("Tokenizer saved.")

    state.prof_stop()
    writer.close()
    return mean_reward

def load(checkpoint_path, hparam, state, start_epoch=None):
    print_memory(1)
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    _model = AutoModelForCausalLM.from_pretrained(state.model_id)
    config = copy.deepcopy(_model.config)
    _model = None
    print_memory(2)
    config.num_hidden_layers += 2
    config.max_position_embeddings = 512
    model = AutoModelForCausalLM.from_config(config)
    print_memory(3)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(state.device)
    print_memory(4)
    train_layer_updater = TrainLayerUpdater(model, hparam.train_layer)
    optimizer = hparam.get_optimizer(train_layer_updater.get_layer_params(), lr=hparam.lr)
    train_layer_updater.update_optimizer_and_requires_grad(optimizer)

    print_step( f"Loaded checkpoint from reference {checkpoint_path} at epoch {start_epoch if start_epoch is not None else checkpoint.get("epoch", 0)}")
    print_memory(7)

    return model, optimizer, checkpoint

def train(hparam, state, tokenizer):
    log_dir = f"runs/{state.log_prefix}_{state.learning_name}_{hparam.essence_str()}"
    write_time_file(log_dir)
    writer = SummaryWriter(log_dir=log_dir)

    state.prof_start()

    # Check if a latest checkpoint exists to load model and optimizer states
    if (
        os.path.exists(state.last_checkpoint_path)
        and not state.is_finding_opt
        and not state.use_reference_model
    ):
        print("==USING CHECK POINT MODEL==")
        model, optimizer, checkpoint = load(state.last_checkpoint_path, hparam, state)
        state.start_epoch = checkpoint.get("epoch", 0)
        state.global_step = checkpoint.get("global_step", 0)
        state.val_global_step = checkpoint.get("val_global_step", 0)
    else:
        if os.path.exists(state.ref_checkpoint_path):
            print("==USING REFERENCE MODEL==")
            model, optimizer, checkpoint = load(state.ref_checkpoint_path, hparam, state, start_epoch=0)
            state.start_epoch = 0  
            state.global_step = 0
            state.val_global_step = 0
        else:
            assert False, "prompt_last_checkpoint_path must exist"

    dups = check_optimizer_duplicates(optimizer)
    if dups:
        print("Warning: The optimizer contains duplicate parameters!")
        print(f"Duplicate parameter count: {len(dups)}")
    else:
        print("No duplicate parameters found in the optimizer.")

    # Clear cached memory that is no longer used
    torch.cuda.empty_cache()
    gc.collect()
    print_memory(9)

    # AMP GradScaler
    scaler = torch.cuda.amp.GradScaler()

    def lr_schedule(step):
        # Linear warm-up to 1.0, then constant
        return min(1.0, step / hparam.warming_up_step)

    scheduler = LambdaLR(optimizer, lr_lambda=lr_schedule)

    # Train & get final metric
    final_avg_loss = train_and_evaluate(
        model,
        hparam,
        state,
        optimizer,
        tokenizer,
        scaler=scaler,
        scheduler=scheduler,
        writer=writer
    )

    # Return the final average loss to Optuna
    return final_avg_loss


def objective(trial):
    global state, tokenizer
    hparam = object_hiper_param(trial)
    state.skip_validation_step=True
    state.is_finding_opt=True

    print(
        f"[Optuna] Trial hyperparameters -> {hparam}"
    )
    return train(hparam, state, tokenizer)

def main(hparam, state):
    if state.is_finding_opt:
        # Create study to minimize final loss
        global objective
        study = optuna.create_study(direction="maximize")
        study.optimize(objective, n_trials=5)  # You can increase n_trials

        print("Study completed!")
        print("Best trial:")
        best_trial = study.best_trial
        print(f"  Value: {best_trial.value}")
        print("  Params: ")
        for key, value in best_trial.params.items():
            print(f"#    {key}: {value}")
        with open("runs/hiper_param.json", "w") as f:
            json.dump(dict(best_trial.params.items()), f, indent=4)
    else:
        train(hparam, state, tokenizer)

In [14]:
# Core logic
# ------------------------------------------------

def generate_ids(model, batch, tokenizer, temperature):
    global max_length
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    eos_token_id = tokenizer.eos_token_id

    # Determine prompt length for each example in the batch based on the first occurrence of EOS.
    prompt_lengths = []
    for i in range(input_ids.size(0)):
        seq = input_ids[i]
        # Find indices where the token equals the eos_token_id.
        eos_positions = (seq == eos_token_id).nonzero(as_tuple=True)[0]
        # If there's at least one occurrence, use its index + 1 (if you want to include the EOS in the prompt).
        # Otherwise, fallback to the full sequence length.
        if eos_positions.numel() > 0:
            first_eos = eos_positions[0].item() + 1
        else:
            first_eos = seq.size(0)
        prompt_lengths.append(first_eos)
    
    print_memory("Prompt lengths per batch element: " + str(prompt_lengths))

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=max_length,  
        temperature=temperature,
        do_sample=True,
        eos_token_id=eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        output_scores=True,
        return_dict_in_generate=True
    )
    full_ids = output.sequences.detach()
    truncated_ids = cut_ids_on_eos_tensor(full_ids, tokenizer.eos_token_id)    
    respone_ids =  pad_sequence([truncated_ids[idx][p_len:] for idx, p_len in enumerate(prompt_lengths)],
                                               batch_first=True, padding_value=tokenizer.pad_token_id)
    truncated_ids = pad_sequence(truncated_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    full_ids = pad_sequence(full_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    output = None
    print_memory("full_ids.shape[-1]: " + str(full_ids.shape[-1]))
    return full_ids, truncated_ids, respone_ids, prompt_lengths


def compute_logits(model, full_ids, prompt_lengths, respone_ids, tokenizer, detach_out=False):
    # Pad the list of full_ids to a whole tensor with shape (batch, max_seq_length)
    full_ids = pad_sequence(full_ids, batch_first=True, padding_value=tokenizer.pad_token_id).to(dtype=torch.int32)
    
    # Create an attention mask where non-pad tokens are 1 and pad tokens are 0
    full_ids_mask = (full_ids != tokenizer.pad_token_id).to(dtype=torch.int32, device=full_ids.device)
    
    # Compute logits for the whole padded tensor.
    logits = model(input_ids=full_ids, attention_mask=full_ids_mask, early_stop=False).logits
    
    truncated_response_ids_list = []
    truncated_response_logits_list = []
    batch_size = full_ids.size(0)
    
    for i in range(batch_size):
        p_len = prompt_lengths[i]
        # Determine the true sequence length (ignoring padding) for this batch element.
        actual_length = full_ids_mask[i].sum().item()
        # Ensure prompt length does not exceed actual length.
        if p_len > actual_length:
            p_len = actual_length

        # Extract completion token IDs for this example.
        comp_ids = full_ids[i, p_len:actual_length].detach()
        # For logits, if you want to include the token just before the completion, slice from p_len-1.
        comp_logits = logits[i, p_len-1:actual_length-1, :]
        
        # Optionally, adjust lengths to be consistent (if needed by downstream code)
        #comp_ids, comp_logits = cut_tensors_by_min(comp_ids, comp_logits, 0)
        expected_len = respone_ids.shape[1]
        truncated_response_ids_list.append((comp_ids.detach() if detach_out else comp_ids)[:expected_len])
        truncated_response_logits_list.append((comp_logits.detach() if detach_out else comp_logits)[:expected_len, :])

    truncated_response_logits = pad_sequence(truncated_response_logits_list, batch_first=True, padding_value=tokenizer.pad_token_id)
    truncated_response_ids = pad_sequence(truncated_response_ids_list, batch_first=True, padding_value=tokenizer.pad_token_id)
    return logits, truncated_response_logits, truncated_response_ids



# ------------------------------------------------
# Define Training Function
# ------------------------------------------------
def run(model, old_model, ref_model,
            dataloader, optimizer, tokenizer,
            hparam, state,
            scaler, writer, reward_work, log_group, scheduler, 
            global_step, num_grpo, group_size, is_validation):
    running_loss = 0.0
    mean_reward = 0.0
    sum_reward = 0.0
    mean_reward_list = []
    mean_loss_list = []
    run_start_global_step=global_step
    print_memory("_.1. run() enter")
    # For accumulation mode, ensure gradients are zeroed at the start.
    optimizer.zero_grad()

    for step, batch in enumerate(tqdm(dataloader, total=len(dataloader)), start=1):
        print_step(f"Processing batch {step}/{len(dataloader)}: Start Loop")
        state.prof_step()
        old_model = hparam.try_old_model_update_in_batch_loop(model, old_model, step, len(dataloader), state.cur_category_count)

  
        # Move batch to device and expand the tensors.
        batch = {k: v.to(device=state.device, dtype=torch.int32) for k, v in batch.items()}
        batch_size = len(batch)
        input_ids = batch['input_ids'].repeat_interleave(group_size, dim=0)
        attention_mask = batch['attention_mask'].repeat_interleave(group_size, dim=0)
        batch_size = input_ids.size(0)
        batch['input_ids'] = input_ids
        batch['attention_mask'] = attention_mask

        start_tensor_ids = cur_memory_ids()

        # LOGIT sample (min= -34.09375 , avg= 0.364013671875 , max= 42.25 )
        # LOG_LOGIT sample: (min= -18.188087463378906 , avg= -0.3128775656223297 , max= 0.0 )
        # LOG_PROBE sample: (min= -22.589847564697266 , avg= -0.4811277985572815 , max= 0.0 )
        # kl_div sample: (min= 0.07530781626701355 , avg= 0.0904245376586914 , max= 0.10227474570274353 )

        # std_rewards: 0.5049999952316284 
        # advantages: (min= 1.2400000095367432 , avg= 1.4900000095367432 , max= 2.240000009536743 )
        # A_hat: (min= -1.4997029304504395 , avg= 1.4901161193847656e-08 , max= 0.4999009966850281 )
        # unclipped_objective: (min= -483673.78125 , avg= -212.85939025878906 , max= 221674.734375 )
        # clipped_objective: (min= 0.7576461434364319 , avg= 0.7710149884223938 , max= 1.2423537969589233 )
        # ppo_loss: -0.5694103240966797 , kl_div: 0.0904245376586914

        # std_rewards: 0.0 Rewards: [1.24, 1.24, 1.24, 1.24]
        # advantages: (min= 1.2400000095367432 , avg= 1.2400000095367432 , max= 1.2400000095367432 )
        # A_hat: (min= 0.0 , avg= 0.0 , max= 0.0 )
        # unclipped_objective: (min= 0.0 , avg= 0.0 , max= 0.0 )
        # clipped_objective: (min= 0.7576461434364319 , avg= 0.7576462030410767 , max= 0.7576461434364319 )
        # ppo_loss: 0.0 , kl_div: 0.0454302616417408

        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
            # 1. Model forward pass for generation.
            print_step("1. Model train")
            with torch.no_grad():
                model_for_generating_result = hparam.choose_generator(model, old_model, ref_model)
                full_ids, truncated_ids, respone_ids, prompt_lengths = generate_ids(model_for_generating_result, batch, tokenizer, hparam.temperature)
                full_ids.log() # FULL_IDS
                truncated_ids.log() # TRUNCATED_IDS
                respone_ids.log() # RESPONSE_IDS
                full_text_lists = tokenizer.batch_decode(truncated_ids, skip_special_tokens=True)
                reward_work.reward(full_text_lists, writer, log_group, global_step)
                # Release unused tensors from generation.
                full_text_lists = None
            full_shift_logits, response_truncated_logits, _ = compute_logits(model, full_ids, prompt_lengths, respone_ids, tokenizer) 
            full_shift_logits.log() # FULL_LOGITS
            response_truncated_logits.log() # RESPONSE_LOGITS

            full_shift_ids = shift_ids_with_logits(full_ids, full_shift_logits)
            print_logits_ids("model full", full_shift_logits, full_shift_ids)  # good format confirmed, full_shift_logits: FULL_LOGITS, full_shift_ids: FULL_IDS
            
            FULL_IDS = full_ids.shape  # [batch, full_ids_len], sample: batch=4, full_ids=512
            TRUNCATED_IDS = truncated_ids.shape # [batch, truncated_ids_len],  sample: truncated_ids_len=512 or less. FULL_IDS with cut out end parts after eos
            RESPONSE_IDS = respone_ids.shape  # [batch, respone_ids_len], sample: respone_ids_len = 466 = truncated_ids_len-prompt_length
            FULL_LOGITS = full_shift_logits.shape # [batch, full_ids_len, embedding_len], sample: embedding_len=49152
            RESPONSE_LOGITS = response_truncated_logits.shape #  [batch, respone_ids_len, embedding_len]
            # GROUPED_BATCH  = advantages.shape # [grouped_batch, group_size], example grouped_batch=1, group_size=4 see advantages creation
            

            # 2. Run legacy models (old and reference models).
            print_step("2. Legacy Models Run")
            with torch.no_grad():
                _, old_response_truncated_logits, _ = compute_logits(old_model, truncated_ids, prompt_lengths, respone_ids, tokenizer, detach_out=True)
                ref_full_shift_logits, _, _ = compute_logits(ref_model, full_ids, prompt_lengths, respone_ids, tokenizer, detach_out=True)
                print_logits_ids("ref model full", ref_full_shift_logits, full_shift_ids) # good format confirmed, ref_full_shift_logits: FULL_LOGITS, full_shift_ids: FULL_IDS
            truncated_ids = None
            prompt_lengths = None
            full_ids = None
            
            with torch.no_grad():
                print_logits_ids("model response", response_truncated_logits, respone_ids) # good format confirmed, response_truncated_logits: RESPONSE_LOGITS, respone_ids: RESPONSE_IDS
                #model_log_logits = F.log_softmax(response_truncated_logits, dim=2)
                model_log_logits = selective_log_softmax(response_truncated_logits, respone_ids, tokenizer).check_shape(RESPONSE_IDS)
                model_log_logits.log()
                print_logits_ids("old model response", old_response_truncated_logits, respone_ids) # good format confirmed, old_response_truncated_logits: RESPONSE_LOGITS, respone_ids: RESPONSE_IDS
                #old_model_log_logits = F.log_softmax(old_response_truncated_logits, dim=2)
                old_model_log_logits = selective_log_softmax(old_response_truncated_logits, respone_ids, tokenizer).check_shape(RESPONSE_IDS)
                old_model_log_logits.log() 
                probability_ratio = torch.exp(model_log_logits - old_model_log_logits).check_shape(RESPONSE_IDS) #.mean(dim=2)
                probability_ratio.log() 
                
                # Remove legacy model intermediates (no longer needed)
                full_truncated_full_logits = None
                response_truncated_logits = None
                old_response_truncated_logits = None
                ref_completion_ids = None
                model_log_logits = None
                old_model_log_logits = None


            # 3. kl_div Loss Calc
            print_step("3. kl_div Loss Calc")    
            # Calculate token-level log probabilities.
            model_log_probs = selective_log_softmax(full_shift_logits, full_shift_ids, tokenizer)
            model_log_probs.log() # RESPONSE_IDS
            ref_log_probs = selective_log_softmax(ref_full_shift_logits, full_shift_ids, tokenizer)
            ref_log_probs.log() # RESPONSE_IDS
            
            # Compute token-level KL divergence.
            token_kl_div = F.kl_div(model_log_probs, ref_log_probs, reduction='none', log_target=True).check_shape(FULL_IDS) # it is not an ids but parts of logits content. the shape is just like ids)
            token_kl_div.log()
            kl_div = token_kl_div.mean(dim=-1).check_shape([batch_size])
            kl_div.log() # average over tokens. range (0, infite) but for output of similar model. It is very small. sample: kl_div=0.09
            
            # Save scalar values for logging before clearing.
            kl_div_val=kl_div.mean().item()
            
            # Remove now-unused intermediate tensors.
            ref_log_probs = None
            ref_full_shift_logits = None
            full_shift_logits = None
            full_shift_ids = None
            model_log_probs = None
            token_kl = None 
            
            # 4. Calculate rewards.
            print_step("4. Reward calc")
            
            reward_work_result = reward_work.take_result()
            if not reward_work_result:  # reward list is empty
                response_texts = tokenizer.batch_decode(respone_ids, skip_special_tokens=True)
                writer.add_text(f"{log_group}/reward_empty_response", str(response_texts), global_step=step)
                rewards = [0.0]*group_size
                if False:
                    # Remove now-unused intermediate tensors.
                    kl_div = None
                    respone_ids = None
                    probability_ratio = None
                    continue  # skip to next batch
            else:
                rewards, responses = reward_work_result
            sum_reward += sum(rewards)
            # rewards list[batch]
            
            if all(reward > 2.0 for reward in rewards):
                # perfect no loss in grouped_ppo
                grouped_ppo_loss = -(torch.ones(len(rewards), dtype=torch.float32, device=state.device) + hparam.epsilon).check_shape([batch_size])
                grouped_ppo_loss.log()
            else:
                # Convert rewards to tensor
                grouped_batch_size = len(rewards) // group_size
                advantages = torch.tensor(rewards, dtype=torch.float32, device=state.device).view(grouped_batch_size, group_size).check_range(0, 2.24)
                advantages.log() 
                
                # Calculate mean and std per batch (along dim=1) and repeat to match original size
                mean_rewards = advantages.mean(dim=1).repeat_interleave(group_size).check_shape([batch_size]).check_range(0, 2.24)
                mean_rewards.log()
                std_rewards = advantages.std(dim=1).repeat_interleave(group_size).check_shape([batch_size]).check_range(0, float('inf'))
                std_rewards.log() 
    
                # Reshape back to original form
                advantages = advantages.view(-1)
                advantages.check_shape([batch_size]).log("advantages before A_hat")
                A_hat = ((advantages - mean_rewards) / (std_rewards + 1e-4)).unsqueeze(1).check_shape([batch_size, 1]) 
                A_hat = torch.clamp(A_hat, -5, 5)
                A_hat.log() 
                
                # Clear rewards intermediates.
                advantages = None
                # 5. grouped_ppo Loss Calc
                print_step("5. Grouped ppo Loss Calc")            
                # PPO objective calculations.
                unclipped_objective = probability_ratio
                unclipped_objective.check_shape(RESPONSE_IDS).log()
                epsilon_high = torch.full_like(unclipped_objective, 1 + hparam.epsilon).check_shape(RESPONSE_IDS)
                _grouped_ppo_loss = - torch.minimum(unclipped_objective, epsilon_high)
                _grouped_ppo_loss.check_shape(RESPONSE_IDS).log("before A_hat multiply")
                _grouped_ppo_loss = _grouped_ppo_loss * A_hat
                grouped_ppo_loss = _grouped_ppo_loss.mean(dim=-1).check_shape([batch_size])
                grouped_ppo_loss.log() # sample epsilon=0.2
    
                # Remove now-unused intermediate tensors.
                A_hat = None
                unclipped_objective = None
                clipped_ratio = None
                clipped_objective = None
            
            
            # kl_lambda is a scaling factor for the KL term
            _grpo_loss = grouped_ppo_loss + hparam.kl_lambda * kl_div
            _grpo_loss.check_shape([batch_size]).log() 
            grpo_loss = _grpo_loss.mean()
            grpo_loss.log() # []

            # Save scalar values for logging before clearing.
            ppo_loss_val = grouped_ppo_loss.mean().item()
            grpo_loss_val = grpo_loss.mean().item()

            # Remove now-unused intermediate tensors.
            respone_ids = None
            grouped_ppo_loss = None
            kl_div = None
            probability_ratio = None

            # Remove now-unused intermediate tensors.
            per_token_loss = None
            
            # 6. Backpropagation and parameter update (only if not in validation mode).
            print_step("6. Backpropagation and parameter update") 
            is_param_updated = False

        if not is_validation:
            scaler.scale(grpo_loss).backward()
            scaler.step(optimizer)
            write_weight_state(model, writer, step, log_group+'_weights')
            if False:
                change_grad(model, 0, 2, multiple=0.01)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()
            is_param_updated = True
            # Remove unused variables from the current iteration.
            grpo_loss = None
            #compare_memory_ids(start_tensor_ids)
            current_lr = optimizer.param_groups[0]['lr']
            writer.add_scalar(f"{log_group}/lr", current_lr, step)
        else:
            optimizer.zero_grad()
            grpo_loss = None

        running_loss += grpo_loss_val

        # 7. Logging with dynamic log group.
        print_step("7. Logging")
        mean_reward = sum(rewards) / len(rewards)
        metrics = {
            'mean_reward': mean_reward,
            'loss': grpo_loss_val
        }
        #writer.add_hparams(hparam_dict=hparams, metric_dict=metrics, global_step=global_step)
        writer.add_scalar(f"{log_group}/grpo_loss", grpo_loss_val, global_step)
        writer.add_scalar(f"{log_group}/ppo_loss", ppo_loss_val, global_step)
        writer.add_scalar(f"{log_group}/kl_div", kl_div_val, global_step)
        writer.add_scalar(f"{log_group}/mean_reward", mean_reward, global_step)
        mean_reward_list.append(mean_reward)
        mean_loss_list.append(grpo_loss_val)
        if is_param_updated:
            writer.add_scalar(f"{log_group}/model_update_grpo_loss", grpo_loss_val, global_step)

        print_step("8. End Loop")
        global_step += 1

    epoch_mean_reward =  sum(mean_reward_list) / len(mean_reward_list)
    mean_loss =  sum(mean_loss_list) / len(mean_loss_list)
    print_step("_.1. run() exit")
    return running_loss, epoch_mean_reward, global_step
    
if __name__ == "__main__":
    print("Start reasoning logic....")
    main(hparam, state)

Start reasoning logic....
Dummy file written: runs/Adafactor_starcoder2-3b_reasoning_Adafactor_FULL_LAYER__3.131e-05_8_0.26207_1_0.53715/2025-04-06_10-13-17.txt
==USING REFERENCE MODEL==
No duplicate parameters found in the optimizer.
Total categories: 9 ['simple arithmetic', 'simple if', 'simple loop', 'loop and if', 'simple state', 'recursive function', 'pointer manipulation', 'string manipulation', 'sort algorithm']


Map:   0%|          | 0/14 [00:00<?, ? examples/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Data counts: 14


Map:   0%|          | 0/126 [00:00<?, ? examples/s]

Map:   0%|          | 0/9 [00:00<?, ? examples/s]

Epoch 1/200 - Validation
Validation Start ....


  std_rewards = advantages.std(dim=1).repeat_interleave(group_size).check_shape([batch_size]).check_range(0, float('inf'))
100%|███████████████████████████████████████████████████████████████████████████████████████| 9/9 [01:58<00:00, 13.13s/it]


Validation End ....
Training Start ....
Epoch 1/200 - category include =simple arithmetic


  7%|██████▏                                                                               | 1/14 [00:34<07:26, 34.38s/it]


KeyboardInterrupt: 