In [1]:
# Additional modules required on top of the docker environment.
# Change `if False:` to `if True:` to install them into the current environment.
# NOTE(review): `%pip` targets the running kernel's environment more reliably than `!pip`.
# ------------------------------------------------

if False:
    !pip install datasets torch_optimizer lion_pytorch clang_repl_kernel --break-system-packages
    !pip install --upgrade clang-repl-kernel  --break-system-packages

In [2]:
# Reasoning exercise environment preparation
# ------------------------------------------------

from ClangReplInterface import ClangReplInterface
# Manual smoke-test hook: flip to True to create a standalone clang-repl session.
# (RewardWorkPool below hands ClangReplInterface to ObjectPool, which presumably
# creates its own instances for training.)
if False:
    clang_repl = ClangReplInterface()


In [3]:
#!rm -rf "runs/starcoder2_reasoning"
#!rm -rf "saved_models/reasoning"

In [4]:
# import and setup
# ------------------------------------------------

import os
from datetime import datetime
import copy
import re
import json
import torch
import shutil
import inspect
import gc
from enum import Enum
import bitsandbytes as bnb
import torch.nn as nn
import torch.optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.utils.tensorboard import SummaryWriter
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
)
from transformers.optimization import Adafactor
from datasets import Dataset
import signal
from tqdm import tqdm
import random
import numpy as np

# For hyperparameter optimization
import optuna
import pickle
from Config import SimpleConfig
from ClangReplInterface import ClangReplInterface, ObjectPool
import warnings

# Silence FutureWarning noise to keep the training log readable.
warnings.filterwarnings(
    "ignore",
    category=FutureWarning
)

# Let the CUDA caching allocator use expandable segments to reduce fragmentation.
# NOTE(review): the allocator reads this when CUDA is first initialized; it has
# no effect if an earlier cell already touched the GPU — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# PyTorch random-seed idiom
def set_random_seed(seed=42):
    """Seed every RNG this notebook uses and make cuDNN deterministic.

    Seeds torch (CPU, and every CUDA device when available), NumPy and the
    stdlib ``random`` module, exports PYTHONHASHSEED, pins cuDNN to
    deterministic kernels with autotuning disabled, and prints the seed.

    Note: assigning PYTHONHASHSEED at runtime does not change hash
    randomization of the already-running interpreter; it only affects
    child processes started afterwards.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
    if torch.cuda.is_available():
        # torch.manual_seed already covers CUDA; kept explicit for multi-GPU setups.
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic, cudnn_backend.benchmark = True, False
    print(f"Random seed set as {seed}")

set_random_seed()




Random seed set as 42


In [5]:
# Overall configuration
# ------------------------------------------------


config = SimpleConfig()
max_length = 512  # truncation length used when tokenizing prompts (get_dataloader)
# NOTE(review): reward() tops out at 0.24 (format) + 2.0 (execution) = 2.24;
# confirm what the 4.4 below represents.
MAX_REWARD=4.4 

class TrainLayers(Enum):
    """Which layers of the inner transformer TrainLayerUpdater makes trainable."""
    FULL_LAYER = 1  # every parameter of model.model
    TWO_FRONT_LAYER = 2  # only the first two decoder layers
    ODD_LAYER = 3  # decoder layers with odd index
    EVEN_LAYER = 4  # decoder layers with even index
    SWITCH_PAIR_LAYER = 5  # a different pair of consecutive layers on each update

train_layer = TrainLayers.FULL_LAYER  # layer-selection strategy for TrainLayerUpdater

# Paths and run naming
ref_checkpoint_path = "./saved_models/starcoder2-3b_exact_sample/checkpoint.pt"
test_target_object_file = "manual_data_set/ReasoningTestTarget.json"
learning_name = f"starcoder2-3b_reasoning_{train_layer.name}_"
now = datetime.now().strftime("%d_%H-%M")  # day_hour-minute suffix for the TensorBoard run name
learning_name_time = learning_name + now
# NOTE(review): checkpoint paths use `learning_name` (no timestamp), so successive
# runs with the same train_layer share one directory — confirm this is intended.
last_checkpoint_path = f"./saved_models/{learning_name}/checkpoint.pt"
checkpoint_dir_pre = f"./saved_models/{learning_name}/epoch_"

model_id = "bigcode/starcoder2-3b"
# Load tokenizer from saved directory if it exists; otherwise load from the pretrained hub id.
tokenizer_save_dir = "./saved_models/tokenizer"


# Logging switches read by the print_*/log helpers and reward_correct.
log_content = False
log_step = False
log_memory = False
log_logits = False
log_reward = True
log_tensor = True

# Debug assertions attached to torch.Tensor (check_range / check_shape); off = no-op.
checking_range = False
checking_shape = False

# group_size: presumably responses sampled per prompt (GRPO group);
# RewardWorkPool is sized group_size * batch_size.
group_size = 4
batch_size = 1 # previously 7

category_count_start = 1 # previously 9; number of leading categories trained first
use_reference_model = True
skip_validation_step=True
# True: Optuna hyperparameter search (values presumably come from object_hiper_param).
# False: one long run with the fixed values below.
# NOTE(review): when True none of the names in the block below are defined here,
# and object_hiper_param does not return save_epochs — confirm it is set elsewhere.
is_finding_opt = True
if not is_finding_opt:
    num_epochs = 200
    # NOTE(review): same digits as the best-trial lr recorded below
    # (1.6289610787142993e-06) but 10x smaller — confirm the exponent.
    lr = 1.6289610787142993e-07 #1.1111588431283189e-06
    kl_lambda = 1.0 #0.15842765249477542
    epsilon = 0.2 #0.11144786260484413
    num_grpo = 3
    save_epochs = 10
    warming_up_step= 1
    gradient_accumulation_step = 1
    temperature = 0.9797868214412535  #1.0
# Best trial:
#  Value: 10.131847697589546
#  Params:
#    lr: 1.6289610787142993e-06
#    kl_lambda: 0.0490765147903471
#    epsilon: 0.09582283082922764
#    num_grpo: 1
#    temperature: 0.9797868214412535
# For reference, a known-valid setting:
# num_iterations=1, num_steps=500, batch_size=4, num_generations=4, max_completion_length=128, kl=0.1,
# learning_rate=5e-6, mu=3, epsilon=0.2,
#
# Earlier trials:
# lr: 7.205691481165551e-05 kl_lambda: 0.2654706177039008 epsilon: 0.019437902361559744 num_grpo: 1
# lr: 1.1111588431283189e-06 kl_lambda: 0.15842765249477542 epsilon: 0.11144786260484413 num_grpo: 3

def object_hiper_param(trial):
    """Sample one Optuna trial's hyperparameters for a shortened GRPO run.

    Args:
        trial: an Optuna trial; ``suggest_float``/``suggest_int`` are called
            in a fixed order: lr, kl_lambda, epsilon, num_grpo, temperature.

    Returns:
        Tuple ``(num_epochs, lr, kl_lambda, epsilon, num_grpo,
        warming_up_step, gradient_accumulation_step, temperature,
        category_count_start)``. Only lr, kl_lambda, epsilon, num_grpo and
        temperature are searched; the remaining values are fixed.
    """
    # Searched values (trailing comments record earlier hand-picked settings).
    sampled_lr = trial.suggest_float("lr", 1e-6, 1e-3, log=True)  # was 5e-6
    sampled_kl = trial.suggest_float("kl_lambda", 0.04, 10.0)  # was 0.04
    sampled_eps = trial.suggest_float("epsilon", 0.01, 0.3)  # was 0.1
    sampled_grpo = trial.suggest_int("num_grpo", 1, 2, step=1)  # was 1
    sampled_temp = trial.suggest_float("temperature", 0.5, 1.4)  # was 1.0

    # Fixed during the search; 4 epochs keeps each trial short.
    short_epochs = 4
    fixed_warmup = 1
    fixed_accumulation = 1
    fixed_category_start = 1

    return (short_epochs, sampled_lr, sampled_kl, sampled_eps, sampled_grpo,
            fixed_warmup, fixed_accumulation, sampled_temp, fixed_category_start)

def shrink_dataset(reasoning_dataset, val_reasoning_dataset, max_val=5, max_train=None):
    """Trim the datasets to keep (hyperparameter-search) runs short.

    Args:
        reasoning_dataset: training examples (any sliceable sequence).
        val_reasoning_dataset: validation examples (any sliceable sequence).
        max_val: keep at most this many validation examples (default 5, the
            previously hard-coded value); ``None`` keeps all of them.
        max_train: keep at most this many training examples; ``None``
            (default) returns the training data unchanged.

    Returns:
        ``(train, val)`` after trimming.
    """
    train = reasoning_dataset if max_train is None else reasoning_dataset[:max_train]
    val = val_reasoning_dataset if max_val is None else val_reasoning_dataset[:max_val]
    return train, val


In [6]:
# Define Tokenization
# ------------------------------------------------
# Load the tokenizer from the saved directory if it exists; otherwise fall back
# to the pretrained hub tokenizer. Both paths use left padding (presumably so
# generated tokens directly follow the prompt in batched generation), and both
# now guarantee a pad token: previously only the hub path set it, so a saved
# tokenizer without one would break padding in DataCollatorForLanguageModeling.
_tokenizer_from_saved = os.path.exists(tokenizer_save_dir)
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_save_dir if _tokenizer_from_saved else model_id
)
tokenizer.padding_side = 'left'
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
if _tokenizer_from_saved and log_step:
    print("Loaded tokenizer from saved checkpoint.")

In [7]:
# Reward utility
# ------------------------------------------------

def reward_atag(front, end, response):
    """Partial format reward for one tag pair such as ``<X>...</X>`` or ``[X]...[/X]``.

    Args:
        front: tag opener character, e.g. ``"<"`` or ``"["``.
        end: rest of the tag after the opener, e.g. ``"Test Object>"``.
        response: text to inspect.

    Returns:
        Up to 0.3: +0.1 if the opening tag ``front+end`` occurs, +0.1 if the
        closing tag ``front+'/'+end`` occurs, +0.1 if both occur and the text
        between the first opening tag and the first closing tag has more than
        one non-whitespace-stripped character.
    """
    open_tag = front + end
    close_tag = front + '/' + end
    open_idx = response.find(open_tag)
    close_idx = response.find(close_tag)

    reward = 0
    if open_idx != -1:
        reward += 0.1
    if close_idx != -1:
        reward += 0.1

    # Both tags must be present: previously a missing opening tag (-1) still
    # produced a slice starting at len(open_tag) - 1 and could earn +0.1.
    if open_idx != -1 and close_idx != -1:
        body_start = open_idx + len(open_tag)
        if body_start < close_idx and len(response[body_start:close_idx].strip()) > 1:
            reward += 0.1
    return reward


def remove_comments(code: str):
    """Strip C/C++ comments from `code`.

    Removes ``// ...`` up to (not including) the end of each line and every
    ``/* ... */`` block, including multi-line ones. Purely regex based, so
    comment markers inside string or char literals (e.g. ``"http://x"``) are
    removed as well.
    """
    comment_re = re.compile(r'//.*?$|/\*.*?\*/', re.DOTALL | re.MULTILINE)
    return comment_re.sub('', code)


def find_all_tag_indexes(text, tag):
    """Return the start index of every non-overlapping occurrence of `tag` in `text`.

    Args:
        text: string to scan.
        tag: non-empty substring to look for.

    Returns:
        Ascending list of start indexes (empty when `tag` does not occur).

    Raises:
        ValueError: if `tag` is empty. ``str.find("")`` always matches at the
            current position, so the scan would never advance and would loop forever.
    """
    if not tag:
        raise ValueError("tag must be a non-empty string")
    indexes = []
    idx = text.find(tag)
    while idx != -1:
        indexes.append(idx)
        idx = text.find(tag, idx + len(tag))
    return indexes


def get_tag_start_end(idx, starts, ends, tag, full_text):
    """Return the stripped text between the idx-th opening tag and closing tag.

    `starts`/`ends` hold tag positions (e.g. from find_all_tag_indexes). The
    content begins right after the opening `tag` at ``starts[idx]`` and stops
    at ``ends[idx]``. Negative `idx` counts from the end. IndexError
    propagates if `idx` is out of range.
    """
    content_begin = starts[idx] + len(tag)
    content_end = ends[idx]
    return full_text[content_begin:content_end].strip()
    

def reward_correct(clang_repl, full_text, writer, log_group, step):
    """Run the model's first Clang-repl test against its Test Target and score it.

    The last ``<Test Target>`` section (comments stripped, newlines removed,
    prefixed with ``>>> ``) is prepended to the FIRST ``<Clang-repl Test>``
    section and passed to ``clang_repl.run_verify``. Further
    ``<Clang-repl Test>`` sections are ignored.

    Args:
        clang_repl: object exposing ``run_verify(text) -> (result, response)``
            with ``result`` in ``{'ok', 'fail', 'error'}``.
        full_text: model response to score.
        writer: TensorBoard writer; used when ``log_reward`` is set.
        log_group: tag prefix for logged text.
        step: step value for logging.

    Returns:
        ``(reward, message)``. The reward is 2.0 for ``'ok'``, 1.0 for
        ``'fail'``, and 0.0 for ``'error'`` or malformed/missing tags. The
        message is the repl response or a description of the tag problem.

    Raises:
        AssertionError: if ``run_verify`` returns an unknown result string.
            This is raised explicitly so the check also holds under
            ``python -O``, where the old ``assert False`` was stripped.
    """
    test_target_open = find_all_tag_indexes(full_text, "<Test Target>")
    test_target_close = find_all_tag_indexes(full_text, "</Test Target>")
    clang_repl_open = find_all_tag_indexes(full_text, "<Clang-repl Test>")
    clang_repl_close = find_all_tag_indexes(full_text, "</Clang-repl Test>")

    if not (test_target_open and test_target_close and clang_repl_open and clang_repl_close):
        return 0.0, '<Test Target> or <Clang-repl Test> not found'
    if len(test_target_open) != len(test_target_close) or len(clang_repl_open) != len(clang_repl_close):
        return 0.0, '<Test Target> or <Clang-repl Test> pair not match'
    if not all(x < y for x, y in zip(test_target_open, test_target_close)):
        return 0.0, '<Test Target> not closed properly'
    if not all(x < y for x, y in zip(clang_repl_open, clang_repl_close)):
        return 0.0, '<Clang-repl Test> not closed properly'

    target_text = get_tag_start_end(-1, test_target_open, test_target_close, "<Test Target>", full_text)
    # The repl takes one-line input, so the target is flattened onto a single prompt line.
    target_text = ">>> " + remove_comments(target_text).replace('\n', '')

    # Only the first test is executed. The former loop returned on its first
    # iteration, and its for/else branch was unreachable because the list is
    # known to be non-empty here.
    clang_repl_test = get_tag_start_end(0, clang_repl_open, clang_repl_close, "<Clang-repl Test>", full_text)
    test_case_with_target = target_text + '\n' + clang_repl_test
    result, response = clang_repl.run_verify(test_case_with_target)

    reward_by_result = {'ok': 2.0, 'fail': 1.0, 'error': 0.0}
    if result not in reward_by_result:
        raise AssertionError(f"unexpected run_verify result: {result!r}")
    score = reward_by_result[result]

    if log_reward and score < 1.9:
        writer.add_text(f"{log_group}/reward_correct_context",
                        f"Verify: {test_case_with_target}\nResponse: {response}", step)
    return score, response

def reward(clang_repl, response_full, writer, log_group, step):
    """Total reward for one response: format score (max 0.24) + execution score.

    Text from the first ``'<Need More Test'`` onward is ignored. The format
    score sums reward_atag() over the expected tags (REASON/ANSWER weighted
    x2) and divides by 10. Only a perfectly formatted response (0.24) is
    executed via reward_correct(). Both components are logged to `writer`
    under `log_group` at `step`.

    Returns:
        ``(total_reward, message)``, where message is the repl response, a
        tag-problem description, or ``"not formatted"``.
    """
    # https://blog.gopenai.com/coding-grpo-from-scratch-a-guide-to-distributed-implementation-with-qwen2-5-1-5b-instruct-59b34227edacabs
    need_more_test_idx = response_full.find('<Need More Test')
    response = response_full if need_more_test_idx == -1 else response_full[:need_more_test_idx]

    score = 0.0
    score += reward_atag("<", "Test Object>", response)  # up to 0.3
    score += reward_atag("<", "Input Data>", response)  # up to 0.3
    score += reward_atag("<", "Expected Output>", response)  # up to 0.3
    score += reward_atag("<", "Clang-repl Test>", response)  # up to 0.3
    score += reward_atag("[", "REASON]", response) * 2  # up to 0.3*2
    score += reward_atag("[", "ANSWER]", response) * 2  # up to 0.3*2
    score = score / 10  # at most 2.4 / 10 = 0.24

    # The perfect score is a floating-point sum of 0.1 steps; compare with a
    # small tolerance instead of relying on rounding landing on the >= side.
    if score >= 0.24 - 1e-9:
        correct_reward, response = reward_correct(clang_repl, response, writer, log_group, step)
    else:
        correct_reward = 0.0
        response = "not formatted"

    writer.add_scalar(f"{log_group}/format_correct_sample", score, step)
    writer.add_scalar(f"{log_group}/reward_correct_sample", correct_reward, step)
    return score + correct_reward, response

class RewardWorkPool(ObjectPool):
    """Pool of ClangReplInterface workers that score responses with `reward`.

    `group_size` is the worker count handed to ObjectPool. train_and_evaluate
    passes ``group_size * batch_size``, i.e. the number of responses scored
    per global step.
    """

    def __init__(self, group_size):
        super().__init__(ClangReplInterface, reward, group_size)
        # Logging-step stride per global step. This uses the pool size (responses
        # per step) instead of the module-level `group_size`, so per-sample
        # TensorBoard steps of consecutive global steps do not overlap when
        # batch_size > 1. It equals the old stride when batch_size == 1.
        self._step_stride = group_size

    def reward(self, codes, writer=None, log_group=None, global_step=None):
        """Start one scoring task per response in `codes` (results via take_result()).

        Each task receives ``[code, writer, log_group, sample_step]`` with
        ``sample_step = global_step * stride + index``. The ClangReplInterface
        argument of `reward` is presumably supplied by ObjectPool.
        """
        args = [[code, writer, log_group, global_step * self._step_stride + idx]
                for idx, code in enumerate(codes)]
        self.start_tasks(args)

    def take_result(self):
        """Return results of the tasks started by reward() (ObjectPool.get_results)."""
        return self.get_results()




In [8]:
# Dataset and dataloader
# ------------------------------------------------
def load_qa_dataset(json_file):
    """Turn a JSON list of {"Q": ..., "A": ...} records into LM training examples.

    Missing "Q"/"A" keys become empty strings. Each example is a dict with a
    single "content" key: the instruction/response template filled with the
    question and answer, followed by the "<|endoftext|>" marker.
    """
    with open(json_file, "r", encoding="utf-8") as handle:
        records = json.load(handle)

    def _to_example(record):
        question, answer = record.get("Q", ""), record.get("A", "")
        prompt = f"### Instruction\n\n{question}\n### Response\n\n{answer}\n"
        return {"content": prompt + "<|endoftext|>"}

    return [_to_example(record) for record in records]


def load_sample_dataset(pk_file=None):
    """Load pickled text samples as LM examples ``{"content": sample + "<|endoftext|>"}``.

    Args:
        pk_file: path to a pickle holding an iterable of strings. Previously
            this argument was ignored and ``config.dataset_file`` was always
            read. It is now honoured; ``None`` (the new default) keeps the old
            behaviour of reading ``config.dataset_file``.

    Security: ``pickle.load`` can execute arbitrary code, so only load files
    you produced yourself.
    """
    path = config.dataset_file if pk_file is None else pk_file
    with open(path, "rb") as f:
        global_samples = pickle.load(f)
    return [{"content": sample + "<|endoftext|>"} for sample in global_samples]

def get_test_target_content(full_text):
    """Return the stripped body of the last ``<Test Target>...</Test Target>`` section.

    The body starts after the LAST opening tag and ends at the LAST closing
    tag. Neither tag can overlap itself, so ``str.rfind`` gives exactly the
    last positions find_all_tag_indexes would report.

    Raises:
        IndexError: if either tag is missing. This is the same exception type
            as before, now with a descriptive message.
    """
    open_tag, close_tag = "<Test Target>", "</Test Target>"
    start = full_text.rfind(open_tag)
    end = full_text.rfind(close_tag)
    if start == -1 or end == -1:
        raise IndexError("full_text has no <Test Target>...</Test Target> section")
    return full_text[start + len(open_tag):end].strip()

def load_reasoning_dataset(test_target_object_file, seltected_categories, train_per_category=10):
    """Build train/val generation prompts from the test-target JSON file.

    Args:
        test_target_object_file: JSON list of ``{"category": ..., "content": ...}``
            records; each ``content`` must contain a <Test Target> section.
        seltected_categories: categories to include, in order (name kept for
            backward compatibility). An unknown category raises KeyError.
        train_per_category: the first N items of each category go to train,
            the rest to validation (default 10, the previously hard-coded split).

    Returns:
        ``(train, val)`` lists of ``{"content": prompt}`` dicts. Prompts hold
        only the instruction and end at ``"### Response\\n"`` so the model
        completes the response.
    """
    with open(test_target_object_file, 'r', encoding='utf-8') as file:
        data = json.load(file)

    # Group contents by category, preserving file order within each category.
    data_dic = {}
    for item in data:
        data_dic.setdefault(item['category'], []).append(item['content'])

    train, val = [], []
    for cat in seltected_categories:
        for idx, item in enumerate(data_dic[cat]):
            # Fixed prompt typos: a stray "n" after the instruction header
            # ("\n\nn<Test Target>") and "Wrtie".
            content = (
                "### Instruction\n\n"
                f"<Test Target>\n{get_test_target_content(item)}\n</Test Target>\n"
                "Write a Clang-repl Test\n"
                "### Response\n"
            )
            (train if idx < train_per_category else val).append({"content": content})

    return train, val


def get_all_categories(json_file=None):
    """Return the distinct categories of the test-target JSON file in first-seen order.

    Args:
        json_file: path to read. Defaults to the module-level
            ``test_target_object_file`` (the previous, only behaviour).

    Returns:
        List of category values with duplicates removed, file order preserved.
    """
    path = test_target_object_file if json_file is None else json_file
    with open(path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    # dict preserves insertion order, so fromkeys() de-duplicates while keeping order.
    return list(dict.fromkeys(item['category'] for item in data))

        
def get_dataloader(categories):
    """Build (train, val) DataLoaders of tokenized reasoning prompts for `categories`.

    Prompts come from load_reasoning_dataset(test_target_object_file, ...)
    and are trimmed by shrink_dataset(). They are tokenized with truncation to
    ``max_length``, without padding, and batched by
    DataCollatorForLanguageModeling(mlm=False), which pads each batch.

    The training loader shuffles. The validation loader no longer does:
    shuffling made evaluation order non-reproducible and drew from the global
    torch RNG, perturbing later training randomness.
    """
    reasoning_dataset, val_reasoning_dataset = load_reasoning_dataset(test_target_object_file, categories)
    reasoning_dataset, val_reasoning_dataset = shrink_dataset(reasoning_dataset, val_reasoning_dataset)

    if log_step: print("eos: ", tokenizer.eos_token, tokenizer.eos_token_id)

    def tokenize_function(examples):
        return tokenizer(
            examples["content"],
            truncation=True,
            max_length=max_length,
        )

    def _to_tokenized(examples):
        """Convert a list of {"content": str} into a torch-formatted tokenized Dataset."""
        dataset = Dataset.from_list(examples).map(tokenize_function, batched=True)
        dataset = dataset.remove_columns(["content"])
        dataset.set_format("torch")
        return dataset

    tokenized_dataset = _to_tokenized(reasoning_dataset)
    val_tokenized_dataset = _to_tokenized(val_reasoning_dataset)

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    dataloader = DataLoader(
        tokenized_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=data_collator
    )
    val_dataloader = DataLoader(
        val_tokenized_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=data_collator
    )
    return dataloader, val_dataloader

def get_train_dataloader(categories):
    """Return only the training DataLoader for `categories` (see get_dataloader)."""
    train_loader, _unused_val_loader = get_dataloader(categories)
    return train_loader


def get_val_dataloader(categories):
    """Return only the validation DataLoader for `categories` (see get_dataloader)."""
    _unused_train_loader, val_loader = get_dataloader(categories)
    return val_loader



In [9]:
# Logging or Print
# ------------------------------------------------

previous_tensor_info = {}  # id(tensor) -> (shape, device) snapshot from the previous print_memory call
def print_memory(tag, show_summary=False, show_diff=False):
    """Print CUDA memory usage and live CUDA tensor counts when `log_memory` is set.

    Args:
        tag: label printed first.
        show_summary: also print torch.cuda.memory_summary() (was a dead
            ``if False:`` block).
        show_diff: also list individual CUDA tensors created/deleted since the
            previous call. The former ``if False:`` block for this referenced
            undefined names and would have raised NameError if enabled.

    Side effects: replaces the module-level ``previous_tensor_info`` snapshot.
    """
    global previous_tensor_info
    if not log_memory:
        return
    print(tag)
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        allocated = torch.cuda.memory_allocated(device) / (1024 ** 2)
        reserved = torch.cuda.memory_reserved(device) / (1024 ** 2)
        print(f"Memory allocated: {allocated:.2f} MB")
        print(f"Memory reserved: {reserved:.2f} MB")
        if show_summary:
            print(torch.cuda.memory_summary(device=None, abbreviated=False))
    else:
        # The old code queried CUDA memory stats for the CPU device, which is not supported.
        print("Memory allocated: n/a (CUDA not available)")

    current_tensor_info = {id(t): (t.shape, t.device) for t in list_cuda_tensors()}
    current_ids = set(current_tensor_info)
    previous_ids = set(previous_tensor_info)
    print("Total Tensors:", len(current_ids), ", Changes:", len(current_ids) - len(previous_ids))

    if show_diff:
        _print_tensor_diff("New CUDA tensors created since the last call:",
                           current_ids - previous_ids, current_tensor_info)
        _print_tensor_diff("Deleted CUDA tensors since the last call:",
                           previous_ids - current_ids, previous_tensor_info)

    previous_tensor_info = current_tensor_info


def _print_tensor_diff(header, tensor_ids, info):
    """Print `header` and one line per id in `tensor_ids` (details from `info`); silent when empty."""
    if not tensor_ids:
        return
    print(header)
    for tid in tensor_ids:
        shape, dev = info[tid]
        print(f"Tensor id: {tid} | Shape: {shape} | Device: {dev}")
        
def cur_memory_ids():
    """Return the set of id()s of all live CUDA tensors (as found by list_cuda_tensors)."""
    return {id(cuda_tensor) for cuda_tensor in list_cuda_tensors()}

def compare_memory_ids(previous_tensor_ids):
    """Print and return the CUDA tensor ids created/deleted since a snapshot.

    Args:
        previous_tensor_ids: set of ids from an earlier cur_memory_ids() call.

    Returns:
        ``(new_ids, deleted_ids)`` as sets. This previously returned None;
        callers that ignore the result are unaffected.
    """
    current_tensor_ids = cur_memory_ids()
    new_tensor_ids = current_tensor_ids - previous_tensor_ids
    deleted_tensor_ids = previous_tensor_ids - current_tensor_ids
    print("New tensor:", new_tensor_ids)
    print("Deleted tensor:", deleted_tensor_ids)  # fixed "Delted" typo
    return new_tensor_ids, deleted_tensor_ids
    
def list_cuda_tensors():
    """Return every GC-tracked torch tensor that lives on a CUDA device.

    Objects whose inspection raises an exception are skipped.
    """
    found = []
    for candidate in gc.get_objects():
        try:
            is_cuda_tensor = torch.is_tensor(candidate) and candidate.is_cuda
        except Exception:
            continue  # some objects raise on attribute access; ignore them
        if is_cuda_tensor:
            found.append(candidate)
    return found
    
def print_step(tag, main_step=False):
    """Report progress for `tag`.

    With `log_memory` set, memory stats are printed via print_memory().
    Otherwise the tag is printed when `log_step` is set or `main_step` is True.
    """
    if log_memory:
        print_memory(tag)
    elif log_step or main_step:
        print(tag)

def check_optimizer_duplicates(optimizer):
    """Return parameters that appear more than once across ``optimizer.param_groups``.

    Identity (``id()``) is compared, not tensor values. Every repeated
    occurrence after the first is reported, in encounter order.
    """
    seen = set()
    repeated = []
    every_param = (p for group in optimizer.param_groups for p in group['params'])
    for param in every_param:
        key = id(param)
        if key in seen:
            repeated.append(param)
        seen.add(key)
    return repeated

def print_logits_ids(tag, logits, ids):
    """Debug helper: compare the argmax decoding of `logits` with `ids`.

    When `log_logits` is set, prints both lengths along dim 1, the first five
    token ids of the first row of each, and the first 100 decoded characters
    of the first sequence of each. Both tensors are always forwarded to
    print_tensor(), which prints only when `log_tensor` is set.
    """
    if log_logits:
        logits_len, ids_len = logits.shape[1], ids.shape[1]
        logits_ids = torch.argmax(logits, dim=-1)

        def _decode_rows(batch):
            return [tokenizer.decode(batch[row], skip_special_tokens=True) for row in range(batch.size(0))]

        ids_text = _decode_rows(ids)
        logits_text = _decode_rows(logits_ids)

        print('##### ', tag, '( logits_len:', logits_len,', ids_len:', ids_len, ' )')
        print('First five logits_ids:', logits_ids[0][:5].tolist(), ', First five ids:', ids[0][:5].tolist())
        print('###### logit text:',logits_text[0][:100])
        print('###### ids_text:',ids_text[0][:100])
    print_tensor(logits, name=tag+'(logits)')
    print_tensor(ids, name=tag+'(ids)')

def print_tensor(tensor, name=None):
    """Print the shape and min/mean/max of `tensor` when `log_tensor` is set.

    Args:
        tensor: any torch tensor; non-floating tensors are cast to float for the mean.
        name: label to print. When omitted, the caller's local variables are
            searched for one bound to this very object, falling back to "tensor".

    Returns:
        `tensor` itself, so calls can be chained.
    """
    if not log_tensor:
        return tensor
    if name is None:
        frame = inspect.currentframe().f_back
        try:
            names = [var_name for var_name, var_val in frame.f_locals.items() if var_val is tensor]
        finally:
            # Drop the frame reference to avoid a reference cycle (see the `inspect` docs).
            del frame
        name = names[0] if names else "tensor"
    if tensor.numel() == 0:
        # min()/max() raise on empty tensors; report the shape only.
        print(name, tensor.shape, '(empty)')
        return tensor
    mean_source = tensor if torch.is_floating_point(tensor) else tensor.float()
    mean_val = mean_source.mean().item()
    print(name, tensor.shape, '(min=',tensor.min().item(), ', avg=', mean_val, ', max=',tensor.max().item(), ')')
    return tensor

def match_shape(actual, expected):
    """True if `actual` matches `expected` dimension by dimension.

    Both must have the same length. An expected entry of None or -1 acts as a wildcard.
    """
    return len(actual) == len(expected) and all(
        want is None or want == -1 or want == got
        for got, want in zip(actual, expected)
    )

def check_shape(self, expected_shape):
    """Tensor method: validate ``self.shape`` against `expected_shape` and return self.

    The check runs only when `checking_shape` is set (see match_shape for
    wildcards). A mismatch raises ValueError.
    """
    if checking_shape and not match_shape(self.shape, expected_shape):
        raise ValueError(f"Shape mismatch! Got {self.shape}, expected {expected_shape}")
    return self

def check_range(self, min_val, max_val, not_values=None):
    """Tensor method: assert all elements are in [min_val, max_val] and return self.

    The check runs only when `checking_range` is set. Elements equal to any
    of `not_values` also fail. A failure raises ValueError.
    """
    if not checking_range:
        return self

    valid = (self >= min_val) & (self <= max_val)
    excluded_values = not_values if not_values is not None else ()
    for excluded in excluded_values:
        valid &= self != excluded

    if not torch.all(valid):
        raise ValueError(f"Tensor check_range failed: values not in range [{min_val}, {max_val}] or contain excluded {not_values}")
    return self

# Attach the debug helpers as Tensor methods so they can be chained inline,
# e.g. `x.check_shape((None, 512))`.
# WARNING(review): assigning `torch.Tensor.log` REPLACES PyTorch's built-in
# element-wise natural-logarithm method for the whole process. Every `t.log()`
# call, including ones inside third-party libraries, now returns `t` unchanged
# (and prints stats when log_tensor is set) instead of log(t). `torch.log(t)`
# is unaffected. Consider a non-clashing name such as `log_stats` once all
# callers in this notebook are migrated.
torch.Tensor.log = print_tensor
torch.Tensor.check_shape = check_shape
torch.Tensor.check_range = check_range
            
def samping(model, tokenizer, device, epoch, writer, sample_prompt, expected):
    """Generate up to 20 new tokens for `sample_prompt` and log the text to TensorBoard.

    The prompt is wrapped in the "### Instruction ... ### Response" template
    and passed to model.generate with an explicit attention mask and
    ``pad_token_id = eos_token_id``. The decoded, stripped text is written
    under the "Sample Output" tag at step `epoch` (labelled epoch + 1). When
    `log_content` is set, it is also printed alongside `expected`. The caller
    in train_and_evaluate wraps this in model.eval() / torch.no_grad().
    """
    templated = f"### Instruction\n\n{sample_prompt}\n\n### Response"
    encoded = tokenizer(templated, return_tensors="pt", return_attention_mask=True)
    ids_on_device = encoded.input_ids.to(device)
    mask_on_device = encoded.attention_mask.to(device)

    generated = model.generate(
        ids_on_device,
        attention_mask=mask_on_device,
        max_new_tokens=20,
        pad_token_id=tokenizer.eos_token_id,
    )
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True).strip()

    shown_epoch = epoch + 1
    if log_content:
        print(f"Sample Output (Epoch {shown_epoch}): {decoded}")
        print("Expected:", expected)
    writer.add_text("Sample Output", f"Epoch {shown_epoch}: {decoded}", epoch)

def write_time_file(folder):
    """Create `folder` if needed and drop a timestamp-named marker file in it.

    The file is named ``<YYYY-mm-dd_HH-MM-SS>.txt``, contains a fixed dummy
    line, and its path is printed.
    """
    stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    os.makedirs(folder, exist_ok=True)
    marker_path = os.path.join(folder, f"{stamp}.txt")
    with open(marker_path, "w") as marker:
        marker.write("This is a dummy file.\n")
    print(f"Dummy file written: {marker_path}")

In [10]:
# Optimization
# ------------------------------------------------
def write_weight_state(model, writer, step, log_group):
    """Log the mean and std of every trainable parameter to TensorBoard.

    Tags are ``<log_group>/<index>_<name>/mean`` and ``.../std``, where index
    is the parameter's position in ``model.named_parameters()``. Frozen
    parameters count toward the index but are not logged.
    """
    for position, (param_name, tensor) in enumerate(model.named_parameters()):
        if not tensor.requires_grad:
            continue
        prefix = f"{log_group}/{position}_{param_name}"
        values = tensor.data
        mean_value = values.mean().item()
        std_value = values.std().item()
        writer.add_scalar(f"{prefix}/mean", mean_value, step)
        writer.add_scalar(f"{prefix}/std", std_value, step)

def change_grad(model, layer_start, layer_end, multiple=0.01):
    """Scale gradients in place by `multiple` for parameters indexed in [layer_start, layer_end).

    The index is the position in ``model.named_parameters()`` and counts every
    parameter. Parameters whose ``.grad`` is None are skipped.
    """
    for position, (_name, param) in enumerate(model.named_parameters()):
        if param.grad is not None and layer_start <= position < layer_end:
            param.grad.mul_(multiple)
            
class TrainLayerUpdater:
    """Choose which transformer parameters are trainable and point the optimizer at them.

    Parameters are taken from ``model.model`` (the inner transformer, which
    must expose ``.layers``). Parameters outside it (e.g. ``model.lm_head``)
    are not touched.

    Args:
        model: causal-LM wrapper with a ``.model`` attribute.
        train_layer: a TrainLayers member selecting the strategy.
        exclude_front_layers: when True, parameters of the first two layers
            are removed from the selection (not applied to TWO_FRONT_LAYER).
            The default False matches what the old code actually did: its
            "remove first two layer" filter compared Parameters with layer
            Modules and therefore never removed anything.
    """

    def __init__(self, model, train_layer, exclude_front_layers=False):
        self.model = model.model
        self.train_layer = train_layer
        self.exclude_front_layers = exclude_front_layers
        # Start index of the layer pair used by SWITCH_PAIR_LAYER. The old code
        # read/wrote `self.config.current_layer_index`, but `self.config` was
        # never set, so that strategy raised AttributeError.
        self.current_layer_index = 0

    def get_layer_params(self):
        """Return the parameters to train for the configured strategy.

        For SWITCH_PAIR_LAYER each call returns the next pair of consecutive
        layers and advances ``current_layer_index`` by two, wrapping around.

        Raises:
            ValueError: for an unknown ``train_layer``.
        """
        layers = self.model.layers
        if self.train_layer == TrainLayers.FULL_LAYER:
            params = list(self.model.parameters())
        elif self.train_layer == TrainLayers.TWO_FRONT_LAYER:
            return [p for layer in layers[:2] for p in layer.parameters()]
        elif self.train_layer == TrainLayers.ODD_LAYER:
            params = [p for idx, layer in enumerate(layers) if idx % 2 == 1 for p in layer.parameters()]
        elif self.train_layer == TrainLayers.EVEN_LAYER:
            params = [p for idx, layer in enumerate(layers) if idx % 2 == 0 for p in layer.parameters()]
        elif self.train_layer == TrainLayers.SWITCH_PAIR_LAYER:
            idx = self.current_layer_index
            # Slicing clips at the end, matching the old explicit bounds checks.
            params = [p for layer in layers[idx:idx + 2] for p in layer.parameters()]
            self.current_layer_index = (idx + 2) % len(layers)
        else:
            raise ValueError("Invalid train_layer configuration")

        if self.exclude_front_layers:
            # Compare by identity; Parameter equality is element-wise.
            front_ids = {id(p) for layer in layers[:2] for p in layer.parameters()}
            params = [p for p in params if id(p) not in front_ids]
        return params

    def update_optimizer_and_requires_grad(self, optimizer):
        """Freeze all of ``self.model`` except the selected parameters and make
        them the optimizer's first param group.

        Assumes a single param group; any other groups are left unchanged.
        """
        new_params = self.get_layer_params()
        new_params_set = set(new_params)

        for param in self.model.parameters():
            param.requires_grad = param in new_params_set

        optimizer.param_groups[0]['params'] = list(new_params)

In [11]:
# Tensor utility
# ------------------------------------------------


def pad_to_match(tensor_a, tensor_b, padding_value=0):
    """Conform `tensor_b` to `tensor_a`'s length along dim 1.

    Despite the name, only `tensor_b` is adjusted:
    - When shorter than `tensor_a`, it is right-padded along dim 1 with `padding_value`.
    - Otherwise it is truncated to ``tensor_a.size(1)``.

    `tensor_a` is returned unchanged (the same object).

    Returns:
        ``(tensor_a, adjusted tensor_b)``
    """
    target_len = tensor_a.size(1)
    shortfall = target_len - tensor_b.size(1)
    if shortfall <= 0:
        return tensor_a, tensor_b[:, :target_len]
    # F.pad lists (left, right) pairs starting from the LAST dim, so every dim
    # after dim 1 gets (0, 0) and dim 1 gets (0, shortfall) on its right.
    pad_spec = (0, 0) * (tensor_b.dim() - 2) + (0, shortfall)
    return tensor_a, F.pad(tensor_b, pad_spec, value=padding_value)



def selective_log_softmax(logits, input_ids, tokenizer):
    """Per-token log-probabilities of `input_ids` under `logits`.

    Args:
        logits: assumed (batch, seq, vocab) — TODO confirm at call sites.
        input_ids: assumed (batch, seq'). Moved to the logits' device, and
            truncated to the logits' seq length when longer.
        tokenizer: used only to decode debug output when `log_content` is set.

    Returns:
        log_softmax(logits)[b, t, input_ids[b, t]] with shape (batch, min(seq, seq')).
    """
    ids = input_ids if input_ids.device == logits.device else input_ids.to(logits.device)

    log_probs = nn.functional.log_softmax(logits, dim=-1)
    seq_len = log_probs.size(1)
    if ids.size(1) > seq_len:
        ids = ids[:, :seq_len]

    token_log_probs = torch.gather(log_probs, -1, ids.unsqueeze(-1)).squeeze(-1)

    if log_content:
        decoded_inputs = tokenizer.batch_decode(ids, skip_special_tokens=True)
        print("Input Texts:")
        for line in decoded_inputs:
            print(line)
        decoded_predictions = tokenizer.batch_decode(logits.argmax(dim=-1), skip_special_tokens=True)
        print("\nLogit Texts:")
        for line in decoded_predictions:
            print(line)

    return token_log_probs

def add_front_transformer_block(self, copy_weights: bool = True, layer_index: int = 1):
    """Duplicate a decoder block near the front of ``self.model.layers``.

    The block at `layer_index` (default 1, as before) is duplicated and the
    copy is inserted at that position, shifting the original and every later
    block back by one. ``self.config.num_hidden_layers`` is incremented.
    Written to be bound as a model method: `self` is the causal-LM wrapper
    exposing ``.model.layers`` and ``.config``.

    Args:
        copy_weights: deep-copy the existing block, weights included. When
            False, the block's class is instantiated with no arguments, which
            only works for block types with a no-argument constructor.
        layer_index: position of the block to duplicate and insert at.

    After insertion, any ``self_attn.layer_idx`` attributes are renumbered to
    match the new positions. NOTE(review): HF attention modules use
    ``layer_idx`` to address the KV cache. Without renumbering, the copy would
    carry a duplicate index and all later indices would be off by one —
    confirm for this model class.
    """
    layers = self.model.layers
    source_block = layers[layer_index]

    new_block = copy.deepcopy(source_block) if copy_weights else type(source_block)()
    layers.insert(layer_index, new_block)

    for position, block in enumerate(layers):
        attention = getattr(block, "self_attn", None)
        if attention is not None and hasattr(attention, "layer_idx"):
            attention.layer_idx = position

    self.config.num_hidden_layers += 1

def cut_tensors_by_min(a: torch.Tensor, b: torch.Tensor, dim: int):
    """Trim `a` and `b` along `dim` to their common (shorter) length.

    Returns narrowed views (no copy). AssertionError is raised if `dim` is
    not a valid non-negative dimension of both tensors' rank.
    """
    assert a.dim() > dim and b.dim() > dim, "Specified dim exceeds tensor rank"
    common = min(a.size(dim), b.size(dim))
    return a.narrow(dim, 0, common), b.narrow(dim, 0, common)

def cut_ids_on_eos_tensor(full_ids, eos_token_id):
    """For each row of `full_ids`, drop the first `eos_token_id` and everything after it.

    Rows without EOS are kept whole. Returns a list of 1-D tensors, whose
    lengths may differ.
    """
    def _truncate(row):
        eos_hits = torch.nonzero(row == eos_token_id, as_tuple=True)[0]
        return row[:eos_hits[0].item()] if eos_hits.numel() > 0 else row

    return [_truncate(row) for row in full_ids]
    
def cut_ids_on_eos(generated_ids, eos_token_id):
    """Sequence version of cut_ids_on_eos_tensor: truncate each sequence before its first EOS.

    Sequences without `eos_token_id` are returned unchanged (the same object).
    """
    return [
        seq[:seq.index(eos_token_id)] if eos_token_id in seq else seq
        for seq in generated_ids
    ]


def shift_ids_with_logits(ids, shift_logits):
    """Shift `ids` left by one and append the argmax token of the last logit step.

    Assumes `ids` is (batch, seq) and `shift_logits` is (batch, seq', vocab) —
    TODO confirm at call sites. The result has the same shape as `ids`.
    """
    next_token = shift_logits[:, -1, :].argmax(dim=-1, keepdim=True)
    return torch.cat((ids[:, 1:], next_token), dim=1)

In [12]:
# Train scheduling
# ------------------------------------------------

def train_and_evaluate(model, ref_model, lr, optimizer, device, num_epochs, group_size,
                       num_grpo, epsilon, kl_lambda, scaler, save_epochs, start_epoch,
                       scheduler, gradient_accumulation_step, skip_validation_step=False, 
                       temperature=1.0,
                       train_layer_updater=None, category_count_start=1, is_finding_opt=False):
    """Run the epoch loop: category curriculum, GRPO passes, logging and checkpoints.

    Each epoch does the following:
    * optionally advances the curriculum by one category, once the previous pass's mean
      reward exceeds 2.0 (never while ``is_finding_opt``);
    * snapshots the current model as the frozen "old" policy (fp16, no grad);
    * runs validation unless ``skip_validation_step``;
    * runs ``num_grpo`` training passes through ``run``;
    * saves the latest checkpoint, plus a numbered one when ``save_epochs`` matches.

    Returns:
        Mean reward of the last completed epoch, averaged over its ``num_grpo`` passes
        (0.0 if no epoch ran). ``main`` creates the Optuna study with
        ``direction="maximize"`` on this value.
    """
    global learning_name_time
    global batch_size
    lr_str = f"{lr:.5g}"
    kl_lambda_str = f"{kl_lambda:.5g}"
    epsilon_str = f"{epsilon:.5g}"
    temperature_str = f"{temperature:.5g}"
    log_dir = f"runs/{learning_name_time}_{lr_str}_{kl_lambda_str}_{epsilon_str}_{num_grpo}_{temperature_str}" 
    write_time_file(log_dir)
    writer = SummaryWriter(log_dir=log_dir)
    reward_work = RewardWorkPool(group_size*batch_size)

    try:
        # --- Generate sample output text once, before training starts ---
        model.eval()  # Set to eval mode for generation.
        with torch.no_grad():
            samping(model, tokenizer, device, 0, writer,
                    "In Custom Clang-repl, What is the prompt in Custom Clang-repl?",
                    "```\n>>> (prompt)\n```")
            samping(model, tokenizer, device, 0, writer,
                    "In Custom Clang-repl, Do we allow multiline comments or backslash-extended lines in Custom Clang-repl Test?",
                    "Custom Clang-repl takes only one line input.")
        model.train()  # Switch back to training mode.

        global_step = 0
        mean_reward = 0        # mean reward of the most recent training pass
        avg_reward = 0.0       # returned metric; stays 0.0 if no epoch runs
        categories = get_all_categories()
        category_size = len(categories)
        cur_category_count = category_count_start
        last_category_count = cur_category_count
        switch_pair_layer = 0

        print("Total categories:", category_size, categories)

        dataloader = get_train_dataloader(categories[:cur_category_count])
        print("Data counts:", len(dataloader))
        val_dataloader = get_val_dataloader(categories[:cur_category_count])

        for epoch in range(start_epoch, num_epochs):
            # Per-epoch accumulators (previously the reward sum leaked across epochs).
            running_loss = 0.0
            epoch_reward_sum = 0.0
            print_step(f"Epoch {epoch+1}/{num_epochs} - Validation", main_step=True)

            uses_switch_pair = (train_layer_updater is not None
                                and train_layer_updater.train_layer == TrainLayers.SWITCH_PAIR_LAYER)
            if uses_switch_pair:
                # Call the updater every epoch (presumably switching which layer pair is
                # trainable). Only advance the curriculum after at least 4 such epochs.
                if switch_pair_layer >= 4 and mean_reward > 2.0 and cur_category_count < category_size:
                    switch_pair_layer = 0
                    if not is_finding_opt:
                        cur_category_count += 1
                else:
                    train_layer_updater.update_optimizer_and_requires_grad(optimizer)
                    switch_pair_layer += 1
            elif mean_reward > 2.0 and cur_category_count < category_size and not is_finding_opt:
                cur_category_count += 1

            # Rebuild both loaders when the curriculum grows so validation covers it too.
            if last_category_count != cur_category_count:
                last_category_count = cur_category_count
                dataloader = get_train_dataloader(categories[:cur_category_count])
                print("Data counts:", len(dataloader))
                val_dataloader = get_val_dataloader(categories[:cur_category_count])

            # Frozen snapshot of the current policy, used as the PPO "old" model this epoch.
            print_memory(20)
            old_model = copy.deepcopy(model).half()
            old_model.eval()
            for param in old_model.parameters():
                param.requires_grad = False
            print_memory(21)

            if not skip_validation_step:
                print("Validation Start ....")
                # Run validation (no parameter updates).
                run(model, old_model, ref_model, val_dataloader, optimizer, device, tokenizer,
                    group_size, epsilon, kl_lambda, scaler, writer, global_step, reward_work=reward_work,
                    log_group="validation", scheduler=None,
                    gradient_accumulation_step=gradient_accumulation_step, is_validation=True,
                    temperature=temperature)
                print("Validation End ....")

            print("Training Start ....")
            for grpo_idx in range(num_grpo):
                print_step(f"Epoch {epoch+1}/{num_epochs} - Training Gradient Group {grpo_idx+1}/{num_grpo}, category include ={categories[cur_category_count-1]}", main_step=True)

                loss, mean_reward, global_step = run(model, old_model, ref_model, dataloader, optimizer, device, tokenizer,
                                                     group_size, epsilon, kl_lambda, scaler, writer, global_step, reward_work=reward_work,
                                                     log_group="Training", scheduler=scheduler,
                                                     gradient_accumulation_step=gradient_accumulation_step, is_validation=False,
                                                     temperature=temperature)

                running_loss += loss
                epoch_reward_sum += mean_reward
            old_model = None
            print_step("7. End Epoch")
            # ``run`` returns the summed loss over its batches, so normalise by all batches
            # of all passes. ``run`` already returns a per-sample mean reward, so only
            # average it over the passes.
            avg_loss = running_loss / max(1, num_grpo * len(dataloader))
            avg_reward = epoch_reward_sum / max(1, num_grpo)
            if log_step: print(f"Epoch {epoch + 1} completed. Average Loss: {avg_loss:.4f}")
            writer.add_scalar("Epoch/Average_Loss", avg_loss, epoch + 1)
            writer.add_scalar("Epoch/Average_Reward", avg_reward, epoch + 1)
            print("Training End ....")

            # Save latest checkpoint.
            checkpoint = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch + 1
            }
            os.makedirs(os.path.dirname(last_checkpoint_path), exist_ok=True)
            torch.save(checkpoint, last_checkpoint_path)

            # Optionally save a numbered checkpoint. NOTE(review): this fires for epoch
            # indices 0, save_epochs, 2*save_epochs, ... while the directory is named
            # epoch + 1; confirm that offset is intended.
            if save_epochs and epoch % save_epochs == 0:
                checkpoint_dir = checkpoint_dir_pre + str(epoch + 1)
                os.makedirs(checkpoint_dir, exist_ok=True)
                checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.pt")
                torch.save(checkpoint, checkpoint_path)
                if log_step: print(f"Checkpoint saved at epoch {epoch + 1} to {checkpoint_path}")

                # Save tokenizer once if not already saved.
                tokenizer_save_dir = "./saved_models/tokenizer"
                if not os.path.exists(tokenizer_save_dir):
                    os.makedirs(tokenizer_save_dir, exist_ok=True)
                    tokenizer.save_pretrained(tokenizer_save_dir)
                    if log_step: print("Tokenizer saved.")
    finally:
        # Flush TensorBoard events even if training raises (e.g. OOM inside an Optuna trial).
        writer.close()
    return avg_reward

def _build_model_from_checkpoint(checkpoint_path, device, lr):
    """Rebuild the extended model from ``checkpoint_path`` and create its optimizer.

    The base config of ``model_id`` gets two extra hidden layers and 512 positions
    before the saved state dict is loaded. The load is strict (the default
    ``load_state_dict``), so this must match the architecture the checkpoint came from.
    Returns ``(model, optimizer, train_layer_updater, checkpoint)``.
    """
    print_memory(1)
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    # NOTE(review): this loads the full base weights only to read the config;
    # AutoConfig.from_pretrained(model_id) would avoid that cost. Confirm the configs match.
    _model = AutoModelForCausalLM.from_pretrained(model_id)
    config = copy.deepcopy(_model.config)
    del _model
    print_memory(2)
    config.num_hidden_layers += 2
    config.max_position_embeddings = 512
    model = AutoModelForCausalLM.from_config(config)
    print_memory(3)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    print_memory(4)
    train_layer_updater = TrainLayerUpdater(model, train_layer)
    # The optimizer state stored in the checkpoint is intentionally not restored.
    optimizer = torch.optim.AdamW(train_layer_updater.get_layer_params(), lr=lr)
    train_layer_updater.update_optimizer_and_requires_grad(optimizer)
    return model, optimizer, train_layer_updater, checkpoint


def train(
        num_epochs,
        lr,
        kl_lambda,
        epsilon,
        num_grpo,
        group_size,
        warming_up_step,
        gradient_accumulation_step,
        save_epochs=None,
        skip_validation_step=False,
        is_finding_opt=False,
        temperature=1.0,
        category_count_start=1):
    """Load a starting model, set up optimizer/scheduler/reference model, and train.

    The latest checkpoint (``last_checkpoint_path``) is resumed when it exists, unless
    searching hyperparameters or ``use_reference_model`` is set. Otherwise training
    starts from ``ref_checkpoint_path`` at epoch 0.

    Returns:
        The value returned by ``train_and_evaluate`` (used as the Optuna objective).

    Raises:
        FileNotFoundError: if neither checkpoint can be used.
    """
    global use_reference_model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if os.path.exists(last_checkpoint_path) and not is_finding_opt and not use_reference_model:
        print("==USING CHECK POINT MODEL==")
        model, optimizer, train_layer_updater, checkpoint = _build_model_from_checkpoint(
            last_checkpoint_path, device, lr)
        start_epoch = checkpoint.get('epoch', 0)
        if log_step: print(f"Loaded checkpoint from last {last_checkpoint_path} at epoch {start_epoch}")
    elif os.path.exists(ref_checkpoint_path):
        print("==USING REFERENCE MODEL==")
        model, optimizer, train_layer_updater, checkpoint = _build_model_from_checkpoint(
            ref_checkpoint_path, device, lr)
        start_epoch = 0
        if log_step: print(f"Loaded checkpoint from reference {ref_checkpoint_path} at epoch {start_epoch}")
    else:
        raise FileNotFoundError(
            f"No usable checkpoint: reference checkpoint {ref_checkpoint_path!r} must exist")
    # The CPU copy of the full state dict is no longer needed; don't keep it alive
    # for the whole training run.
    del checkpoint
    print_memory(7)

    dups = check_optimizer_duplicates(optimizer)
    if dups:
        print("Warning: The optimizer contains duplicate parameters!")
        print(f"Duplicate parameter count: {len(dups)}")
    else:
        print("No duplicate parameters found in the optimizer.")

    # Clear cached memory that is no longer used
    torch.cuda.empty_cache()
    gc.collect()
    print_memory(9)

    # Frozen fp16 reference model for the KL term.
    ref_model = copy.deepcopy(model).half().eval()
    for param in ref_model.parameters():
        param.requires_grad = False
    print_memory(10)

    # AMP GradScaler
    scaler = torch.cuda.amp.GradScaler()

    def lr_schedule(step):
        # Linear warm-up from 0 to 1.0 over ``warming_up_step`` scheduler steps, then constant.
        if warming_up_step <= 0:
            return 1.0
        return min(1.0, step / warming_up_step)

    scheduler = LambdaLR(optimizer, lr_lambda=lr_schedule)

    # Train & get final metric
    final_avg_reward = train_and_evaluate(
        model=model,
        ref_model=ref_model,
        lr=lr,
        optimizer=optimizer,
        device=device,
        num_epochs=num_epochs,
        group_size=group_size,
        num_grpo=num_grpo,
        epsilon=epsilon,
        kl_lambda=kl_lambda,
        scaler=scaler,
        save_epochs=save_epochs,
        start_epoch=start_epoch,
        scheduler=scheduler,
        gradient_accumulation_step=gradient_accumulation_step,
        skip_validation_step=skip_validation_step,
        temperature=temperature,
        train_layer_updater=train_layer_updater,
        category_count_start=category_count_start,
        is_finding_opt=is_finding_opt
    )

    # Return the final metric to Optuna
    return final_avg_reward


def objective(trial):
    """Optuna objective: run one training with hyperparameters sampled for ``trial``.

    The values come from ``object_hiper_param(trial)``; ``group_size`` is the
    module-level setting. Validation is skipped and ``is_finding_opt=True`` is passed
    through to ``train``. Returns whatever ``train`` returns.
    """
    sampled = object_hiper_param(trial)
    (num_epochs, lr, kl_lambda, epsilon, num_grpo, warming_up_step,
     gradient_accumulation_step, temperature, category_count_start) = sampled

    print(
        f"[Optuna] Trial hyperparameters -> lr: {lr}, kl_lambda: {kl_lambda}, epsilon: {epsilon}, num_grpo: {num_grpo}, warming_up_step: {warming_up_step}, gradient_accumulation_step: {gradient_accumulation_step}, temperature: {temperature}, category_count_start: {category_count_start}")

    trial_settings = dict(
        num_epochs=num_epochs,
        lr=lr,
        kl_lambda=kl_lambda,
        epsilon=epsilon,
        num_grpo=num_grpo,
        group_size=group_size,
        warming_up_step=warming_up_step,
        gradient_accumulation_step=gradient_accumulation_step,
        skip_validation_step=True,
        is_finding_opt=True,
        temperature=temperature,
        category_count_start=category_count_start,
    )
    return train(**trial_settings)


def main(
        skip_validation_step,
        objective,
        is_finding_opt=False,
        category_count_start=1):
    """Entry point for a single training run or an Optuna hyperparameter search.

    Without ``is_finding_opt``, trains once using the module-level hyperparameters.
    With it, runs 5 Optuna trials of ``objective`` (maximizing the returned score),
    prints the best trial and writes its params to ``hiper_param.json``.
    """
    if not is_finding_opt:
        # Hyperparameters are read from module-level globals.
        train(
            num_epochs=num_epochs,
            lr=lr,
            kl_lambda=kl_lambda,
            epsilon=epsilon,
            num_grpo=num_grpo,
            group_size=group_size,
            save_epochs=save_epochs,
            warming_up_step=warming_up_step,
            gradient_accumulation_step=gradient_accumulation_step,
            skip_validation_step=skip_validation_step,
            temperature=temperature,
            category_count_start=category_count_start,
        )
        return

    # Search for the hyperparameters that maximize the objective's returned score.
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=5)  # You can increase n_trials

    print("Study completed!")
    print("Best trial:")
    best_trial = study.best_trial
    print(f"  Value: {best_trial.value}")
    print("  Params: ")
    best_params = dict(best_trial.params)
    for key, value in best_params.items():
        print(f"#    {key}: {value}")
    with open("hiper_param.json", "w") as f:
        json.dump(best_params, f, indent=4)

In [None]:
# Core logic
# ------------------------------------------------

def generate_ids(model, batch, tokenizer, temperature):
    """Sample one completion per row of ``batch`` and split it into prompt/response parts.

    Args:
        model: causal LM exposing ``generate``; it is called with sampling enabled and
            ``return_dict_in_generate=True``.
        batch: dict with ``'input_ids'`` and ``'attention_mask'`` tensors
            (presumably shaped (batch, seq) — TODO confirm against the collator).
        tokenizer: supplies ``eos_token_id`` / ``pad_token_id``.
        temperature: sampling temperature forwarded to ``generate``.

    Returns:
        full_ids: ``output.sequences`` from ``generate`` (prompt + completion), detached.
        truncated_ids: each row cut before its first EOS, right-padded with ``pad_token_id``.
        respone_ids: each truncated row from its prompt length onward, right-padded.
        prompt_lengths: list of per-row prompt lengths used for that split.
    """
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    eos_token_id = tokenizer.eos_token_id

    # Determine prompt length for each example in the batch based on the first occurrence of EOS.
    # NOTE(review): if a prompt itself contains EOS, cut_ids_on_eos_tensor below stops at
    # that same EOS. truncated_ids is then shorter than the prompt length and the response
    # slice comes out empty. The split only works when prompts carry no EOS, so the
    # fallback (full input length, padding included) applies. Confirm against the dataset.
    prompt_lengths = []
    for i in range(input_ids.size(0)):
        seq = input_ids[i]
        # Find indices where the token equals the eos_token_id.
        eos_positions = (seq == eos_token_id).nonzero(as_tuple=True)[0]
        # If there's at least one occurrence, use its index + 1 (if you want to include the EOS in the prompt).
        # Otherwise, fallback to the full sequence length.
        if eos_positions.numel() > 0:
            first_eos = eos_positions[0].item() + 1
        else:
            first_eos = seq.size(0)
        prompt_lengths.append(first_eos)
    
    print_memory("Prompt lengths per batch element: " + str(prompt_lengths))

    # NOTE(review): output_scores=True keeps per-step score tensors that are never read
    # here; dropping it may save memory if no caller needs them.
    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=max_length,  # assuming max_length is defined globally
        temperature=temperature,
        do_sample=True,
        eos_token_id=eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        output_scores=True,
        return_dict_in_generate=True
    )
    full_ids = output.sequences.detach()
    # List of 1-D rows, each ending just before its first EOS.
    truncated_ids = cut_ids_on_eos_tensor(full_ids, tokenizer.eos_token_id)    
    # Response = everything after the prompt in each truncated row, re-padded into one tensor.
    respone_ids =  pad_sequence([truncated_ids[idx][p_len:] for idx, p_len in enumerate(prompt_lengths)],
                                               batch_first=True, padding_value=tokenizer.pad_token_id)
    truncated_ids = pad_sequence(truncated_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    # full_ids is already rectangular; pad_sequence just re-stacks its rows (no extra padding).
    full_ids = pad_sequence(full_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    output = None  # release the generate() output (and its scores)
    print_memory("full_ids.shape[-1]: " + str(full_ids.shape[-1]))
    return full_ids, truncated_ids, respone_ids, prompt_lengths


def compute_logits(model, full_ids, prompt_lengths, respone_ids, tokenizer, detach_out=False):
    """Score id sequences with ``model`` and cut out each row's response slice.

    Args:
        model: called as ``model(input_ids=..., attention_mask=..., early_stop=False)``;
            the result must expose ``.logits``.
        full_ids: a 2-D id tensor or a sequence of 1-D id tensors. Rows are right-padded
            with ``tokenizer.pad_token_id`` and cast to int32 before the forward pass.
        prompt_lengths: per-row start index of the response, clamped to the row's
            non-pad token count.
        respone_ids: only ``respone_ids.shape[1]`` is used, as the maximum kept length.
        tokenizer: provides ``pad_token_id`` (used as padding for both ids and logits).
        detach_out: detach the returned response logits from the autograd graph.

    Returns:
        (logits, response_logits, response_ids): the model's full logits, plus the
        response-aligned logits and ids, each right-padded with ``pad_token_id``.
    """
    pad_id = tokenizer.pad_token_id
    padded_ids = pad_sequence(full_ids, batch_first=True, padding_value=pad_id).to(dtype=torch.int32)
    # 1 marks a real token, 0 marks padding.
    padded_mask = (padded_ids != pad_id).to(dtype=torch.int32, device=padded_ids.device)

    logits = model(input_ids=padded_ids, attention_mask=padded_mask, early_stop=False).logits

    keep_len = respone_ids.shape[1]
    # Non-pad token count per row; slicing by it assumes padding only at the row end.
    row_lengths = padded_mask.sum(dim=1).tolist()

    id_slices = []
    logit_slices = []
    for row in range(padded_ids.size(0)):
        end = row_lengths[row]
        start = min(prompt_lengths[row], end)

        row_ids = padded_ids[row, start:end].detach()
        # One-step offset: response token t is paired with the logits of the position before it.
        row_logits = logits[row, start - 1:end - 1, :]
        if detach_out:
            row_logits = row_logits.detach()

        id_slices.append(row_ids[:keep_len])
        logit_slices.append(row_logits[:keep_len, :])

    response_logits = pad_sequence(logit_slices, batch_first=True, padding_value=pad_id)
    response_ids = pad_sequence(id_slices, batch_first=True, padding_value=pad_id)
    return logits, response_logits, response_ids



# ------------------------------------------------
# Define Training Function
# ------------------------------------------------
def run(model, old_model, ref_model, dataloader, optimizer, device, tokenizer,
        group_size, epsilon, kl_lambda, scaler, writer, global_step, log_group,
        scheduler, gradient_accumulation_step, reward_work, is_validation=False, temperature=1.0):
    """Run one GRPO pass over ``dataloader`` (training or validation).

    Per batch:
    * every prompt is repeated ``group_size`` times and completions are sampled from ``model``;
    * rewards come from ``reward_work``;
    * the loss is a clipped PPO objective times the group-normalised advantage, plus
      ``kl_lambda`` times a KL term against ``ref_model``.

    In training mode the loss is back-propagated through ``scaler``. The optimizer
    (and ``scheduler``, if given) steps every ``gradient_accumulation_step`` batches
    and on the last batch. In validation mode nothing is updated and no autograd
    graph is built.

    Returns:
        (running_loss, mean_reward, global_step):
        * the sum of per-batch combined losses;
        * the mean reward per generated sample over this pass (0.0 if nothing ran);
        * the advanced global step counter.
    """
    running_loss = 0.0
    mean_reward = 0.0
    sum_reward = 0.0
    num_rewarded_samples = 0
    # Batches per optimizer step; 0/None mean "step every batch".
    accumulation_steps = max(1, gradient_accumulation_step or 1)
    num_batches = len(dataloader)
    print_memory("_.1. run() enter")
    # Start from clean gradients; they accumulate until the next optimizer step.
    if not is_validation:
        optimizer.zero_grad()

    for step, batch in enumerate(tqdm(dataloader, total=num_batches), start=1):
        print_step(f"Processing batch {step}/{num_batches}: Start Loop")
        # Move batch to device and repeat every prompt ``group_size`` times (one GRPO group each).
        batch = {k: v.to(device=device, dtype=torch.int32) for k, v in batch.items()}
        input_ids = batch['input_ids'].repeat_interleave(group_size, dim=0)
        attention_mask = batch['attention_mask'].repeat_interleave(group_size, dim=0)
        batch_size = input_ids.size(0)
        batch['input_ids'] = input_ids
        batch['attention_mask'] = attention_mask

        start_tensor_ids = cur_memory_ids()  # snapshot for the compare_memory_ids() debug hook

        # Sample magnitudes observed in earlier runs (for sanity-checking the logs):
        # LOGIT sample (min= -34.09375 , avg= 0.364013671875 , max= 42.25 )
        # LOG_LOGIT sample: (min= -18.188087463378906 , avg= -0.3128775656223297 , max= 0.0 )
        # LOG_PROBE sample: (min= -22.589847564697266 , avg= -0.4811277985572815 , max= 0.0 )
        # kl_div sample: (min= 0.07530781626701355 , avg= 0.0904245376586914 , max= 0.10227474570274353 )

        # std_rewards: 0.5049999952316284 
        # advantages: (min= 1.2400000095367432 , avg= 1.4900000095367432 , max= 2.240000009536743 )
        # A_hat: (min= -1.4997029304504395 , avg= 1.4901161193847656e-08 , max= 0.4999009966850281 )
        # unclipped_objective: (min= -483673.78125 , avg= -212.85939025878906 , max= 221674.734375 )
        # clipped_objective: (min= 0.7576461434364319 , avg= 0.7710149884223938 , max= 1.2423537969589233 )
        # ppo_loss: -0.5694103240966797 , kl_div: 0.0904245376586914

        # std_rewards: 0.0 Rewards: [1.24, 1.24, 1.24, 1.24]
        # advantages: (min= 1.2400000095367432 , avg= 1.2400000095367432 , max= 1.2400000095367432 )
        # A_hat: (min= 0.0 , avg= 0.0 , max= 0.0 )
        # unclipped_objective: (min= 0.0 , avg= 0.0 , max= 0.0 )
        # clipped_objective: (min= 0.7576461434364319 , avg= 0.7576462030410767 , max= 0.7576461434364319 )
        # ppo_loss: 0.0 , kl_div: 0.0454302616417408

        # Validation never calls backward(), so don't build an autograd graph there.
        with torch.set_grad_enabled(not is_validation), torch.amp.autocast(device_type='cuda', dtype=torch.float16):
            # 1. Sample completions (no grad), queue reward computation, then re-score with grad.
            print_step("1. Model train")
            with torch.no_grad():
                full_ids, truncated_ids, respone_ids, prompt_lengths = generate_ids(model, batch, tokenizer, temperature)
                full_ids.log() # FULL_IDS
                truncated_ids.log() # TRUNCATED_IDS
                respone_ids.log() # RESPONSE_IDS
                full_text_lists = tokenizer.batch_decode(truncated_ids, skip_special_tokens=True)
                reward_work.reward(full_text_lists, writer, log_group, global_step)
                # Release unused tensors from generation.
                full_text_lists = None
            _full_shift_logits, response_truncated_logits, _ = compute_logits(model, full_ids, prompt_lengths, respone_ids, tokenizer)
            _full_shift_logits.log() # FULL_LOGITS
            response_truncated_logits.log() # RESPONSE_LOGITS

            full_shift_ids = shift_ids_with_logits(full_ids, _full_shift_logits)
            full_shift_logits = pad_sequence(_full_shift_logits, batch_first=True, padding_value=tokenizer.pad_token_id)
            print_logits_ids("model full", full_shift_logits, full_shift_ids)  # full_shift_logits: FULL_LOGITS, full_shift_ids: FULL_IDS

            FULL_IDS = full_ids.shape  # [batch, full_ids_len], sample: batch=4, full_ids=512
            TRUNCATED_IDS = truncated_ids.shape # [batch, truncated_ids_len], FULL_IDS with the part after EOS cut off
            RESPONSE_IDS = respone_ids.shape  # [batch, respone_ids_len], sample: 466 = truncated_ids_len - prompt_length
            FULL_LOGITS = full_shift_logits.shape # [batch, full_ids_len, vocab_len], sample: vocab_len=49152
            RESPONSE_LOGITS = response_truncated_logits.shape #  [batch, respone_ids_len, vocab_len]

            # 2. Score the same sequences with the frozen old policy and the reference model.
            print_step("2. Legacy Models Run")
            with torch.no_grad():
                _, old_response_truncated_logits, _ = compute_logits(old_model, truncated_ids, prompt_lengths, respone_ids, tokenizer, detach_out=True)
                ref_full_shift_logits, _, _ = compute_logits(ref_model, full_ids, prompt_lengths, respone_ids, tokenizer, detach_out=True)
                print_logits_ids("ref model full", ref_full_shift_logits, full_shift_ids) # ref_full_shift_logits: FULL_LOGITS
            truncated_ids = None
            prompt_lengths = None
            full_ids = None

            # The policy log-probs must keep their autograd graph: the PPO loss is
            # differentiated through ``probability_ratio``. Only the old-model side is
            # computed without grad.
            with torch.no_grad():
                print_logits_ids("model response", response_truncated_logits, respone_ids) # RESPONSE_LOGITS / RESPONSE_IDS
            model_log_logits = selective_log_softmax(response_truncated_logits, respone_ids, tokenizer).check_shape(RESPONSE_IDS)
            model_log_logits.log()
            with torch.no_grad():
                print_logits_ids("old model response", old_response_truncated_logits, respone_ids) # RESPONSE_LOGITS / RESPONSE_IDS
                old_model_log_logits = selective_log_softmax(old_response_truncated_logits, respone_ids, tokenizer).check_shape(RESPONSE_IDS)
                old_model_log_logits.log()
            probability_ratio = torch.exp(model_log_logits - old_model_log_logits).check_shape(RESPONSE_IDS)
            probability_ratio.log()

            # Remove intermediates that are no longer needed.
            response_truncated_logits = None
            old_response_truncated_logits = None
            model_log_logits = None
            old_model_log_logits = None

            # 3. kl_div Loss Calc
            print_step("3. kl_div Loss Calc")
            # Token-level log probabilities of the shifted ids under model and reference.
            model_log_probs = selective_log_softmax(full_shift_logits, full_shift_ids, tokenizer)
            model_log_probs.log() # FULL_IDS
            ref_log_probs = selective_log_softmax(ref_full_shift_logits, full_shift_ids, tokenizer)
            ref_log_probs.log() # FULL_IDS

            # With log_target=True, F.kl_div(input, target) = exp(target) * (target - input)
            # elementwise, evaluated here on the selected tokens' log-probs (ids-shaped).
            token_kl_div = F.kl_div(model_log_probs, ref_log_probs, reduction='none', log_target=True).check_shape(FULL_IDS)
            token_kl_div.log()
            kl_div = token_kl_div.mean(dim=-1).check_shape([batch_size])
            kl_div.log() # per-sequence average over tokens; sample: kl_div=0.09

            # Mask of real (non-padding) response tokens.
            completion_mask = (respone_ids != tokenizer.pad_token_id).to(dtype=torch.float32, device=respone_ids.device).check_shape(RESPONSE_IDS)
            completion_mask.log()

            # Save scalar values for logging before clearing.
            kl_div_val = kl_div.mean().item()

            # Remove now-unused intermediate tensors.
            ref_log_probs = None
            ref_full_shift_logits = None
            full_shift_logits = None
            full_shift_ids = None
            model_log_probs = None
            token_kl_div = None

            # 4. Calculate rewards.
            print_step("4. Reward calc")

            reward_work_result = reward_work.take_result()
            if not reward_work_result:  # reward list is empty
                response_texts = tokenizer.batch_decode(respone_ids, skip_special_tokens=True)
                writer.add_text(f"{log_group}/reward_empty_response", str(response_texts), global_step=global_step)
                # One zero reward per generated sample, so the [batch_size] shapes below hold.
                rewards = [0.0] * batch_size
            else:
                rewards, _responses = reward_work_result
            sum_reward += sum(rewards)
            num_rewarded_samples += len(rewards)
            # rewards list[batch]

            if all(reward > 2.0 for reward in rewards):
                # perfect no loss in grouped_ppo
                grouped_ppo_loss = -(torch.ones(len(rewards), dtype=torch.float32, device=device) + epsilon).check_shape([batch_size])
                grouped_ppo_loss.log()
            else:
                # Rewards as [grouped_batch, group_size] so statistics are taken per group.
                grouped_batch_size = len(rewards) // group_size
                advantages = torch.tensor(rewards, dtype=torch.float32, device=device).view(grouped_batch_size, group_size).check_range(0, 2.24)
                advantages.log()

                # Per-group mean and std, repeated back to one value per sample.
                mean_rewards = advantages.mean(dim=1).repeat_interleave(group_size).check_shape([batch_size]).check_range(0, 2.24)
                mean_rewards.log()
                std_rewards = advantages.std(dim=1).repeat_interleave(group_size).check_shape([batch_size]).check_range(0, float('inf'))
                std_rewards.log()
                print(">>>>>>>>>>>>>>>>>>>> mean_rewards:", mean_rewards[0].item(), "std_rewards:", std_rewards[0].item(), "Rewards:", rewards)

                # Back to one value per sample, then group-normalise.
                advantages = advantages.view(-1)
                advantages.check_shape([batch_size]).log("advantages before A_hat")
                A_hat = ((advantages - mean_rewards) / (std_rewards + 1e-4)).unsqueeze(1).check_shape([batch_size, 1]) # range (-infinite, infinite)
                A_hat.log()

                # Clear rewards intermediates.
                advantages = None
                # 5. grouped_ppo Loss Calc
                print_step("5. Grouped ppo Loss Calc")
                # PPO objective calculations.
                unclipped_objective = probability_ratio
                unclipped_objective.check_shape(RESPONSE_IDS).log()
                epsilon_high = torch.full_like(unclipped_objective, 1 + epsilon).check_shape(RESPONSE_IDS)
                epsilon_low  = torch.full_like(unclipped_objective, 1 - epsilon).check_shape(RESPONSE_IDS)

                # if advantage > 0: take the minimum (to avoid too large an update)
                # if advantage < 0: take the maximum
                _grouped_ppo_loss = - torch.where(A_hat >= 0 -1e-4,
                            torch.minimum(unclipped_objective, epsilon_high), # do not use torch.max()
                            torch.maximum(unclipped_objective, epsilon_low))  # do not use torch.min()
                _grouped_ppo_loss.log("before A_hat multiply")
                _grouped_ppo_loss = _grouped_ppo_loss * A_hat
                _grouped_ppo_loss.check_shape(RESPONSE_IDS).log()
                # Average only over real response tokens; padded positions are excluded.
                valid_tokens = completion_mask.sum(dim=-1).clamp(min=1.0)
                grouped_ppo_loss = ((_grouped_ppo_loss * completion_mask).sum(dim=-1) / valid_tokens).check_shape([batch_size])
                grouped_ppo_loss.log() # sample epsilon=0.2

                # Remove now-unused intermediate tensors.
                A_hat = None
                unclipped_objective = None
                epsilon_high = None
                epsilon_low = None
                _grouped_ppo_loss = None

            # kl_lambda scales the KL term relative to the PPO term.
            _combined_loss = grouped_ppo_loss + kl_lambda * kl_div
            _combined_loss.check_shape([batch_size]).log()
            combined_loss = _combined_loss.mean()
            combined_loss.log() # []

            # Save scalar values for logging before clearing.
            ppo_loss_val = grouped_ppo_loss.mean().item()
            combined_loss_val = combined_loss.item()

            # Remove now-unused intermediate tensors.
            respone_ids = None
            grouped_ppo_loss = None
            kl_div = None
            probability_ratio = None
            completion_mask = None
            _combined_loss = None

            print("Final Loss:", combined_loss_val, ", ppo_loss:", ppo_loss_val, ", kl_div:", kl_div_val)

            # 6. Backpropagation and parameter update (only if not in validation mode).
            print_step("6. Backpropagation and parameter update")
            is_param_updated = False

        if not is_validation:
            # Scale so accumulated gradients average (rather than sum) over the window.
            scaler.scale(combined_loss / accumulation_steps).backward()
            if step % accumulation_steps == 0 or step == num_batches:
                scaler.step(optimizer)
                write_weight_state(model, writer, step, log_group+'_weights')
                scaler.update()
                optimizer.zero_grad()
                if scheduler is not None:
                    scheduler.step()
                is_param_updated = True
                current_lr = optimizer.param_groups[0]['lr']
                writer.add_scalar(f"{log_group}/lr", current_lr, global_step)
        combined_loss = None

        running_loss += combined_loss_val

        # 7. Logging with dynamic log group.
        print_step("7. Logging")
        writer.add_scalar(f"{log_group}/combined_loss", combined_loss_val, global_step)
        writer.add_scalar(f"{log_group}/ppo_loss", ppo_loss_val, global_step)
        writer.add_scalar(f"{log_group}/kl_div", kl_div_val, global_step)
        writer.add_scalar(f"{log_group}/mean_reward", sum(rewards) / len(rewards), global_step)
        if is_param_updated:
            writer.add_scalar(f"{log_group}/model_update_combined_loss", combined_loss_val, global_step)

        print_step("8. End Loop")
        global_step += 1

    # Mean reward per generated sample; robust to a short last batch and an empty loader.
    if num_rewarded_samples:
        mean_reward = sum_reward / num_rewarded_samples
    writer.add_scalar(f"{log_group}_epoch/mean_reward", mean_reward, global_step)

    print_step("_.1. run() exit")
    return running_loss, mean_reward, global_step
    
if __name__ == "__main__":
    # Entry point for the reasoning/GRPO training run. When this cell runs
    # inside a Jupyter/IPython kernel, __name__ is "__main__", so the guard
    # fires every time the cell is executed.
    print("Start reasoning logic....")
    # NOTE(review): main, skip_validation_step, objective, is_finding_opt and
    # category_count_start are defined elsewhere in this notebook (not visible
    # in this chunk); the first three are passed positionally, the last by keyword.
    main(skip_validation_step, objective, is_finding_opt,
         category_count_start=category_count_start)

[I 2025-03-30 09:59:10,037] A new study created in memory with name: no-name-4dae1035-7529-4793-ab91-f30acf00d3b4


Start reasoning logic....
[Optuna] Trial hyperparameters -> lr: 0.0002581746143279661, kl_lambda: 0.11056613361305812, epsilon: 0.09572867840913253, num_grpo: 1, warming_up_step: 1, gradient_accumulation_step: 1, temperature: 0.5881472509842769, category_count_start: 1
==USING REFERENCE MODEL==
No duplicate parameters found in the optimizer.
Dummy file written: runs/starcoder2-3b_reasoning_FULL_LAYER_30_09-59_0.00025817_0.11057_0.095729_1_0.58815/2025-03-30_10-00-02.txt
Total categories: 9 ['simple arithmetic', 'simple if', 'simple loop', 'loop and if', 'simple state', 'recursive function', 'pointer manipulation', 'string manipulation', 'sort algorithm']


Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

Data counts: 10


Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

Epoch 1/4 - Validation
Training Start ....
Epoch 1/4 - Training Gradient Group 1/1, category include =simple arithmetic


  0%|                                                                                                                                                                                                                                 | 0/10 [00:00<?, ?it/s]

full_ids torch.Size([4, 512]) (min= 45 , avg= 3565.04248046875 , max= 48567 )
truncated_ids torch.Size([4, 512]) (min= 45 , avg= 3565.04248046875 , max= 48567 )
respone_ids torch.Size([4, 466]) (min= 45 , avg= 3667.2548828125 , max= 48567 )
_full_shift_logits torch.Size([4, 512, 49152]) (min= -32.46875 , avg= 0.1549072265625 , max= 42.5 )
response_truncated_logits torch.Size([4, 466, 49152]) (min= -32.46875 , avg= 0.227783203125 , max= 42.5 )
model full(logits) torch.Size([4, 512, 49152]) (min= -32.46875 , avg= 0.1549072265625 , max= 42.5 )
model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3564.5419921875 , max= 48567 )
ref model full(logits) torch.Size([4, 512, 49152]) (min= -25.125 , avg= 0.315185546875 , max= 45.34375 )
ref model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3564.5419921875 , max= 48567 )
model response(logits) torch.Size([4, 466, 49152]) (min= -32.46875 , avg= 0.227783203125 , max= 42.5 )
model response(ids) torch.Size([4, 466]) (min= 45 , avg= 3667.2548828125 

 10%|█████████████████████▋                                                                                                                                                                                                   | 1/10 [00:19<02:57, 19.72s/it]

full_ids torch.Size([4, 512]) (min= 45 , avg= 3403.22412109375 , max= 41623 )
truncated_ids torch.Size([4, 512]) (min= 45 , avg= 3403.22412109375 , max= 41623 )
respone_ids torch.Size([4, 449]) (min= 45 , avg= 3503.42041015625 , max= 41623 )
_full_shift_logits torch.Size([4, 512, 49152]) (min= -35.25 , avg= 0.673828125 , max= 45.0 )
response_truncated_logits torch.Size([4, 449, 49152]) (min= -35.25 , avg= 0.76220703125 , max= 45.0 )
model full(logits) torch.Size([4, 512, 49152]) (min= -35.25 , avg= 0.673828125 , max= 45.0 )
model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3400.77978515625 , max= 41623 )
ref model full(logits) torch.Size([4, 512, 49152]) (min= -27.765625 , avg= 0.51025390625 , max= 46.34375 )
ref model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3400.77978515625 , max= 41623 )
model response(logits) torch.Size([4, 449, 49152]) (min= -35.25 , avg= 0.76220703125 , max= 45.0 )
model response(ids) torch.Size([4, 449]) (min= 45 , avg= 3503.42041015625 , max= 41623 )
m

 20%|███████████████████████████████████████████▍                                                                                                                                                                             | 2/10 [00:36<02:25, 18.20s/it]

full_ids torch.Size([4, 512]) (min= 45 , avg= 3189.853515625 , max= 44260 )
truncated_ids torch.Size([4, 512]) (min= 45 , avg= 3189.853515625 , max= 44260 )
respone_ids torch.Size([4, 466]) (min= 45 , avg= 3267.33056640625 , max= 44260 )
_full_shift_logits torch.Size([4, 512, 49152]) (min= -32.875 , avg= 0.09521484375 , max= 42.53125 )
response_truncated_logits torch.Size([4, 466, 49152]) (min= -32.875 , avg= 0.1793212890625 , max= 42.53125 )
model full(logits) torch.Size([4, 512, 49152]) (min= -32.875 , avg= 0.09521484375 , max= 42.53125 )
model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3187.681640625 , max= 44260 )
ref model full(logits) torch.Size([4, 512, 49152]) (min= -25.015625 , avg= 0.313232421875 , max= 45.875 )
ref model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3187.681640625 , max= 44260 )
model response(logits) torch.Size([4, 466, 49152]) (min= -32.875 , avg= 0.1793212890625 , max= 42.53125 )
model response(ids) torch.Size([4, 466]) (min= 45 , avg= 3267.330566406

 30%|█████████████████████████████████████████████████████████████████                                                                                                                                                        | 3/10 [00:54<02:05, 17.91s/it]

full_ids torch.Size([4, 512]) (min= 45 , avg= 3725.986328125 , max= 48567 )
truncated_ids torch.Size([4, 512]) (min= 45 , avg= 3725.986328125 , max= 48567 )
respone_ids torch.Size([4, 468]) (min= 45 , avg= 3804.085693359375 , max= 48567 )
_full_shift_logits torch.Size([4, 512, 49152]) (min= -33.6875 , avg= 0.431884765625 , max= 42.9375 )
response_truncated_logits torch.Size([4, 468, 49152]) (min= -33.6875 , avg= 0.493408203125 , max= 42.9375 )
model full(logits) torch.Size([4, 512, 49152]) (min= -33.6875 , avg= 0.431884765625 , max= 42.9375 )
model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3723.96240234375 , max= 48567 )
ref model full(logits) torch.Size([4, 512, 49152]) (min= -27.109375 , avg= 0.385498046875 , max= 47.5 )
ref model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3723.96240234375 , max= 48567 )
model response(logits) torch.Size([4, 468, 49152]) (min= -33.6875 , avg= 0.493408203125 , max= 42.9375 )
model response(ids) torch.Size([4, 468]) (min= 45 , avg= 3804.085693

 40%|██████████████████████████████████████████████████████████████████████████████████████▊                                                                                                                                  | 4/10 [01:12<01:47, 17.92s/it]

full_ids torch.Size([4, 512]) (min= 45 , avg= 2647.8427734375 , max= 36923 )
truncated_ids torch.Size([4, 512]) (min= 45 , avg= 2647.8427734375 , max= 36923 )
respone_ids torch.Size([4, 466]) (min= 45 , avg= 2717.0654296875 , max= 36923 )
_full_shift_logits torch.Size([4, 512, 49152]) (min= -39.03125 , avg= 0.3935546875 , max= 42.28125 )
response_truncated_logits torch.Size([4, 466, 49152]) (min= -39.03125 , avg= 0.47021484375 , max= 42.28125 )
model full(logits) torch.Size([4, 512, 49152]) (min= -39.03125 , avg= 0.3935546875 , max= 42.28125 )
model full(ids) torch.Size([4, 512]) (min= 45 , avg= 2645.486328125 , max= 36923 )
ref model full(logits) torch.Size([4, 512, 49152]) (min= -26.359375 , avg= 0.409912109375 , max= 45.6875 )
ref model full(ids) torch.Size([4, 512]) (min= 45 , avg= 2645.486328125 , max= 36923 )
model response(logits) torch.Size([4, 466, 49152]) (min= -39.03125 , avg= 0.47021484375 , max= 42.28125 )
model response(ids) torch.Size([4, 466]) (min= 45 , avg= 2717.06542

 50%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                                            | 5/10 [01:30<01:29, 17.84s/it]

full_ids torch.Size([4, 512]) (min= 45 , avg= 3237.087890625 , max= 38475 )
truncated_ids torch.Size([4, 512]) (min= 45 , avg= 3237.087890625 , max= 38475 )
respone_ids torch.Size([4, 449]) (min= 45 , avg= 3351.951171875 , max= 38475 )
_full_shift_logits torch.Size([4, 512, 49152]) (min= -32.9375 , avg= 0.501953125 , max= 42.3125 )
response_truncated_logits torch.Size([4, 449, 49152]) (min= -32.9375 , avg= 0.5927734375 , max= 42.3125 )
model full(logits) torch.Size([4, 512, 49152]) (min= -32.9375 , avg= 0.501953125 , max= 42.3125 )
model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3236.294921875 , max= 38475 )
ref model full(logits) torch.Size([4, 512, 49152]) (min= -27.46875 , avg= 0.337890625 , max= 44.65625 )
ref model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3236.294921875 , max= 38475 )
model response(logits) torch.Size([4, 449, 49152]) (min= -32.9375 , avg= 0.5927734375 , max= 42.3125 )
model response(ids) torch.Size([4, 449]) (min= 45 , avg= 3351.951171875 , max= 38475 

 60%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                                      | 6/10 [01:47<01:10, 17.61s/it]

full_ids torch.Size([4, 512]) (min= 45 , avg= 2984.0947265625 , max= 36923 )
truncated_ids torch.Size([4, 512]) (min= 45 , avg= 2984.0947265625 , max= 36923 )
respone_ids torch.Size([4, 461]) (min= 45 , avg= 3101.977294921875 , max= 36923 )
_full_shift_logits torch.Size([4, 512, 49152]) (min= -37.5625 , avg= 0.5458984375 , max= 43.59375 )
response_truncated_logits torch.Size([4, 461, 49152]) (min= -37.5625 , avg= 0.63427734375 , max= 43.59375 )
model full(logits) torch.Size([4, 512, 49152]) (min= -37.5625 , avg= 0.5458984375 , max= 43.59375 )
model full(ids) torch.Size([4, 512]) (min= 45 , avg= 2982.1337890625 , max= 36923 )
ref model full(logits) torch.Size([4, 512, 49152]) (min= -25.578125 , avg= 0.31591796875 , max= 47.0625 )
ref model full(ids) torch.Size([4, 512]) (min= 45 , avg= 2982.1337890625 , max= 36923 )
model response(logits) torch.Size([4, 461, 49152]) (min= -37.5625 , avg= 0.63427734375 , max= 43.59375 )
model response(ids) torch.Size([4, 461]) (min= 45 , avg= 3101.977294

 70%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                                 | 7/10 [02:04<00:52, 17.49s/it]

full_ids torch.Size([4, 512]) (min= 45 , avg= 3394.58984375 , max= 36923 )
truncated_ids torch.Size([4, 512]) (min= 45 , avg= 3394.58984375 , max= 36923 )
respone_ids torch.Size([4, 469]) (min= 45 , avg= 3488.315673828125 , max= 36923 )
_full_shift_logits torch.Size([4, 512, 49152]) (min= -32.53125 , avg= 0.57080078125 , max= 43.9375 )
response_truncated_logits torch.Size([4, 469, 49152]) (min= -32.53125 , avg= 0.64501953125 , max= 43.9375 )
model full(logits) torch.Size([4, 512, 49152]) (min= -32.53125 , avg= 0.57080078125 , max= 43.9375 )
model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3392.90576171875 , max= 36923 )
ref model full(logits) torch.Size([4, 512, 49152]) (min= -25.578125 , avg= 0.44921875 , max= 48.21875 )
ref model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3392.90576171875 , max= 36923 )
model response(logits) torch.Size([4, 469, 49152]) (min= -32.53125 , avg= 0.64501953125 , max= 43.9375 )
model response(ids) torch.Size([4, 469]) (min= 45 , avg= 3488.31567382

 80%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                           | 8/10 [02:22<00:35, 17.54s/it]

full_ids torch.Size([4, 512]) (min= 45 , avg= 3095.18359375 , max= 36923 )
truncated_ids torch.Size([4, 512]) (min= 45 , avg= 3095.18359375 , max= 36923 )
respone_ids torch.Size([4, 468]) (min= 45 , avg= 3173.94677734375 , max= 36923 )
_full_shift_logits torch.Size([4, 512, 49152]) (min= -37.21875 , avg= 0.384765625 , max= 42.78125 )
response_truncated_logits torch.Size([4, 468, 49152]) (min= -37.21875 , avg= 0.46533203125 , max= 42.78125 )
model full(logits) torch.Size([4, 512, 49152]) (min= -37.21875 , avg= 0.384765625 , max= 42.78125 )
model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3092.4482421875 , max= 36923 )
ref model full(logits) torch.Size([4, 512, 49152]) (min= -26.9375 , avg= 0.384521484375 , max= 47.03125 )
ref model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3092.4482421875 , max= 36923 )
model response(logits) torch.Size([4, 468, 49152]) (min= -37.21875 , avg= 0.46533203125 , max= 42.78125 )
model response(ids) torch.Size([4, 468]) (min= 45 , avg= 3173.946777343

 90%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                     | 9/10 [02:39<00:17, 17.60s/it]

full_ids torch.Size([4, 512]) (min= 45 , avg= 3513.5830078125 , max= 44170 )
truncated_ids torch.Size([4, 512]) (min= 45 , avg= 3513.5830078125 , max= 44170 )
respone_ids torch.Size([4, 467]) (min= 45 , avg= 3615.887451171875 , max= 44170 )
_full_shift_logits torch.Size([4, 512, 49152]) (min= -35.34375 , avg= 0.74462890625 , max= 42.59375 )
response_truncated_logits torch.Size([4, 467, 49152]) (min= -35.34375 , avg= 0.8349609375 , max= 42.59375 )
model full(logits) torch.Size([4, 512, 49152]) (min= -35.34375 , avg= 0.74462890625 , max= 42.59375 )
model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3511.99462890625 , max= 44170 )
ref model full(logits) torch.Size([4, 512, 49152]) (min= -26.484375 , avg= 0.77197265625 , max= 47.125 )
ref model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3511.99462890625 , max= 44170 )
model response(logits) torch.Size([4, 467, 49152]) (min= -35.34375 , avg= 0.8349609375 , max= 42.59375 )
model response(ids) torch.Size([4, 467]) (min= 45 , avg= 3615.8

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [02:57<00:00, 17.74s/it]


Training End ....
Epoch 2/4 - Validation
Training Start ....
Epoch 2/4 - Training Gradient Group 1/1, category include =simple arithmetic


  0%|                                                                                                                                                                                                                                 | 0/10 [00:00<?, ?it/s]

full_ids torch.Size([4, 512]) (min= 39 , avg= 3686.4677734375 , max= 41623 )
truncated_ids torch.Size([4, 512]) (min= 39 , avg= 3686.4677734375 , max= 41623 )
respone_ids torch.Size([4, 449]) (min= 39 , avg= 3826.406494140625 , max= 41623 )
_full_shift_logits torch.Size([4, 512, 49152]) (min= -31.578125 , avg= 0.916015625 , max= 44.65625 )
response_truncated_logits torch.Size([4, 449, 49152]) (min= -31.578125 , avg= 0.99853515625 , max= 44.65625 )
model full(logits) torch.Size([4, 512, 49152]) (min= -31.578125 , avg= 0.916015625 , max= 44.65625 )
model full(ids) torch.Size([4, 512]) (min= 39 , avg= 3683.81591796875 , max= 41623 )
ref model full(logits) torch.Size([4, 512, 49152]) (min= -27.515625 , avg= 0.4560546875 , max= 45.21875 )
ref model full(ids) torch.Size([4, 512]) (min= 39 , avg= 3683.81591796875 , max= 41623 )
model response(logits) torch.Size([4, 449, 49152]) (min= -31.578125 , avg= 0.99853515625 , max= 44.65625 )
model response(ids) torch.Size([4, 449]) (min= 39 , avg= 382

 10%|█████████████████████▌                                                                                                                                                                                                  | 1/10 [03:16<29:24, 196.11s/it]

full_ids torch.Size([4, 512]) (min= 45 , avg= 3309.22265625 , max= 44170 )
truncated_ids torch.Size([4, 512]) (min= 45 , avg= 3309.22265625 , max= 44170 )
respone_ids torch.Size([4, 467]) (min= 45 , avg= 3391.8349609375 , max= 44170 )
_full_shift_logits torch.Size([4, 512, 49152]) (min= -36.875 , avg= 0.5703125 , max= 41.3125 )
response_truncated_logits torch.Size([4, 467, 49152]) (min= -36.875 , avg= 0.630859375 , max= 41.3125 )
model full(logits) torch.Size([4, 512, 49152]) (min= -36.875 , avg= 0.5703125 , max= 41.3125 )
model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3308.701171875 , max= 44170 )
ref model full(logits) torch.Size([4, 512, 49152]) (min= -24.9375 , avg= 0.38232421875 , max= 46.5 )
ref model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3308.701171875 , max= 44170 )
model response(logits) torch.Size([4, 467, 49152]) (min= -36.875 , avg= 0.630859375 , max= 41.3125 )
model response(ids) torch.Size([4, 467]) (min= 45 , avg= 3391.8349609375 , max= 44170 )
model_log_l

 20%|███████████████████████████████████████████▍                                                                                                                                                                             | 2/10 [03:33<12:09, 91.23s/it]

full_ids torch.Size([4, 512]) (min= 45 , avg= 2907.6640625 , max= 36923 )
truncated_ids torch.Size([4, 512]) (min= 45 , avg= 2907.6640625 , max= 36923 )
respone_ids torch.Size([4, 461]) (min= 45 , avg= 3017.091064453125 , max= 36923 )
_full_shift_logits torch.Size([4, 512, 49152]) (min= -38.4375 , avg= 0.5048828125 , max= 43.25 )
response_truncated_logits torch.Size([4, 461, 49152]) (min= -38.4375 , avg= 0.5634765625 , max= 43.25 )
model full(logits) torch.Size([4, 512, 49152]) (min= -38.4375 , avg= 0.5048828125 , max= 43.25 )
model full(ids) torch.Size([4, 512]) (min= 45 , avg= 2905.92529296875 , max= 36923 )
ref model full(logits) torch.Size([4, 512, 49152]) (min= -25.34375 , avg= 0.30810546875 , max= 47.28125 )
ref model full(ids) torch.Size([4, 512]) (min= 45 , avg= 2905.92529296875 , max= 36923 )
model response(logits) torch.Size([4, 461, 49152]) (min= -38.4375 , avg= 0.5634765625 , max= 43.25 )
model response(ids) torch.Size([4, 461]) (min= 45 , avg= 3017.091064453125 , max= 3692

 30%|█████████████████████████████████████████████████████████████████                                                                                                                                                        | 3/10 [03:51<06:43, 57.66s/it]

full_ids torch.Size([4, 512]) (min= 45 , avg= 3108.8828125 , max= 36923 )
truncated_ids torch.Size([4, 512]) (min= 45 , avg= 3108.8828125 , max= 36923 )
respone_ids torch.Size([4, 449]) (min= 45 , avg= 3205.75732421875 , max= 36923 )
_full_shift_logits torch.Size([4, 512, 49152]) (min= -36.875 , avg= 0.91015625 , max= 41.71875 )
response_truncated_logits torch.Size([4, 449, 49152]) (min= -36.875 , avg= 1.0087890625 , max= 41.71875 )
model full(logits) torch.Size([4, 512, 49152]) (min= -36.875 , avg= 0.91015625 , max= 41.71875 )
model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3106.91552734375 , max= 36923 )
ref model full(logits) torch.Size([4, 512, 49152]) (min= -28.046875 , avg= 0.25927734375 , max= 45.0625 )
ref model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3106.91552734375 , max= 36923 )
model response(logits) torch.Size([4, 449, 49152]) (min= -36.875 , avg= 1.0087890625 , max= 41.71875 )
model response(ids) torch.Size([4, 449]) (min= 45 , avg= 3205.75732421875 , max= 36

 40%|██████████████████████████████████████████████████████████████████████████████████████▊                                                                                                                                  | 4/10 [04:09<04:10, 41.76s/it]

full_ids torch.Size([4, 512]) (min= 45 , avg= 3069.67626953125 , max= 36923 )
truncated_ids torch.Size([4, 512]) (min= 45 , avg= 3069.67626953125 , max= 36923 )
respone_ids torch.Size([4, 468]) (min= 45 , avg= 3146.041259765625 , max= 36923 )
_full_shift_logits torch.Size([4, 512, 49152]) (min= -35.875 , avg= 0.728515625 , max= 45.25 )
response_truncated_logits torch.Size([4, 468, 49152]) (min= -35.875 , avg= 0.80029296875 , max= 45.25 )
model full(logits) torch.Size([4, 512, 49152]) (min= -35.875 , avg= 0.728515625 , max= 45.25 )
model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3074.02099609375 , max= 36923 )
ref model full(logits) torch.Size([4, 512, 49152]) (min= -26.9375 , avg= 0.35693359375 , max= 47.65625 )
ref model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3074.02099609375 , max= 36923 )
model response(logits) torch.Size([4, 468, 49152]) (min= -35.875 , avg= 0.80029296875 , max= 45.25 )
model response(ids) torch.Size([4, 468]) (min= 45 , avg= 3146.041259765625 , max= 3

 50%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                                            | 5/10 [04:26<02:45, 33.18s/it]

full_ids torch.Size([4, 512]) (min= 45 , avg= 3124.46484375 , max= 46716 )
truncated_ids torch.Size([4, 512]) (min= 45 , avg= 3124.46484375 , max= 46716 )
respone_ids torch.Size([4, 469]) (min= 45 , avg= 3193.42431640625 , max= 46716 )
_full_shift_logits torch.Size([4, 512, 49152]) (min= -34.53125 , avg= 0.57421875 , max= 41.625 )
response_truncated_logits torch.Size([4, 469, 49152]) (min= -34.53125 , avg= 0.64208984375 , max= 41.625 )
model full(logits) torch.Size([4, 512, 49152]) (min= -34.53125 , avg= 0.57421875 , max= 41.625 )
model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3123.87646484375 , max= 46716 )
ref model full(logits) torch.Size([4, 512, 49152]) (min= -25.65625 , avg= 0.402099609375 , max= 46.75 )
ref model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3123.87646484375 , max= 46716 )
model response(logits) torch.Size([4, 469, 49152]) (min= -34.53125 , avg= 0.64208984375 , max= 41.625 )
model response(ids) torch.Size([4, 469]) (min= 45 , avg= 3193.42431640625 , max= 

 60%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                                      | 6/10 [04:44<01:52, 28.00s/it]

full_ids torch.Size([4, 512]) (min= 45 , avg= 2590.80126953125 , max= 36923 )
truncated_ids torch.Size([4, 512]) (min= 45 , avg= 2590.80126953125 , max= 36923 )
respone_ids torch.Size([4, 466]) (min= 45 , avg= 2654.393310546875 , max= 36923 )
_full_shift_logits torch.Size([4, 512, 49152]) (min= -37.5625 , avg= 0.5732421875 , max= 43.28125 )
response_truncated_logits torch.Size([4, 466, 49152]) (min= -37.5625 , avg= 0.64404296875 , max= 43.28125 )
model full(logits) torch.Size([4, 512, 49152]) (min= -37.5625 , avg= 0.5732421875 , max= 43.28125 )
model full(ids) torch.Size([4, 512]) (min= 45 , avg= 2588.38623046875 , max= 36923 )
ref model full(logits) torch.Size([4, 512, 49152]) (min= -25.15625 , avg= 0.413818359375 , max= 46.84375 )
ref model full(ids) torch.Size([4, 512]) (min= 45 , avg= 2588.38623046875 , max= 36923 )
model response(logits) torch.Size([4, 466, 49152]) (min= -37.5625 , avg= 0.64404296875 , max= 43.28125 )
model response(ids) torch.Size([4, 466]) (min= 45 , avg= 2654.3

 70%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                                 | 7/10 [05:02<01:14, 24.70s/it]

full_ids torch.Size([4, 512]) (min= 45 , avg= 3740.3076171875 , max= 47274 )
truncated_ids torch.Size([4, 512]) (min= 45 , avg= 3740.3076171875 , max= 47274 )
respone_ids torch.Size([4, 468]) (min= 45 , avg= 3819.75341796875 , max= 47274 )
_full_shift_logits torch.Size([4, 512, 49152]) (min= -33.875 , avg= 0.560546875 , max= 44.0 )
response_truncated_logits torch.Size([4, 468, 49152]) (min= -33.875 , avg= 0.66064453125 , max= 44.0 )
model full(logits) torch.Size([4, 512, 49152]) (min= -33.875 , avg= 0.560546875 , max= 44.0 )
model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3739.15283203125 , max= 47274 )
ref model full(logits) torch.Size([4, 512, 49152]) (min= -27.109375 , avg= 0.388427734375 , max= 46.25 )
ref model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3739.15283203125 , max= 47274 )
model response(logits) torch.Size([4, 468, 49152]) (min= -33.875 , avg= 0.66064453125 , max= 44.0 )
model response(ids) torch.Size([4, 468]) (min= 45 , avg= 3819.75341796875 , max= 47274 )
m

 80%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                           | 8/10 [05:20<00:45, 22.52s/it]

full_ids torch.Size([4, 512]) (min= 45 , avg= 3198.75927734375 , max= 48567 )
truncated_ids torch.Size([4, 512]) (min= 45 , avg= 3198.75927734375 , max= 48567 )
respone_ids torch.Size([4, 466]) (min= 45 , avg= 3264.81494140625 , max= 48567 )
_full_shift_logits torch.Size([4, 512, 49152]) (min= -35.28125 , avg= 0.412841796875 , max= 43.875 )
response_truncated_logits torch.Size([4, 466, 49152]) (min= -35.28125 , avg= 0.4619140625 , max= 43.875 )
model full(logits) torch.Size([4, 512, 49152]) (min= -35.28125 , avg= 0.412841796875 , max= 43.875 )
model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3209.6171875 , max= 48567 )
ref model full(logits) torch.Size([4, 512, 49152]) (min= -26.265625 , avg= 0.312255859375 , max= 45.84375 )
ref model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3209.6171875 , max= 48567 )
model response(logits) torch.Size([4, 466, 49152]) (min= -35.28125 , avg= 0.4619140625 , max= 43.875 )
model response(ids) torch.Size([4, 466]) (min= 45 , avg= 3264.81494140625

 90%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                     | 9/10 [05:38<00:21, 21.03s/it]

full_ids torch.Size([4, 512]) (min= 45 , avg= 3300.8203125 , max= 44260 )
truncated_ids torch.Size([4, 512]) (min= 45 , avg= 3300.8203125 , max= 44260 )
respone_ids torch.Size([4, 466]) (min= 45 , avg= 3389.2509765625 , max= 44260 )
_full_shift_logits torch.Size([4, 512, 49152]) (min= -34.34375 , avg= 0.802734375 , max= 44.40625 )
response_truncated_logits torch.Size([4, 466, 49152]) (min= -34.34375 , avg= 0.92333984375 , max= 44.40625 )
model full(logits) torch.Size([4, 512, 49152]) (min= -34.34375 , avg= 0.802734375 , max= 44.40625 )
model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3306.560546875 , max= 44260 )
ref model full(logits) torch.Size([4, 512, 49152]) (min= -24.578125 , avg= 0.435546875 , max= 46.5 )
ref model full(ids) torch.Size([4, 512]) (min= 45 , avg= 3306.560546875 , max= 44260 )
model response(logits) torch.Size([4, 466, 49152]) (min= -34.34375 , avg= 0.92333984375 , max= 44.40625 )
model response(ids) torch.Size([4, 466]) (min= 45 , avg= 3389.2509765625 , max= 4

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [05:56<00:00, 35.63s/it]


Training End ....
Epoch 3/4 - Validation
Training Start ....
Epoch 3/4 - Training Gradient Group 1/1, category include =simple arithmetic


  0%|                                                                                                                                                                                                                                 | 0/10 [00:00<?, ?it/s]