## About the Notebook
* Training a 2B model on a verifiable coding dataset
* Updated reward functions and code sandbox

In [1]:
# Install the notebook's runtime dependencies quietly (-q).
# NOTE(review): only google-tunix, flax and wandb are pinned; the remaining
# packages resolve to whatever is latest at run time, so a re-run may not
# reproduce this environment.
!pip install -q kagglehub
!pip install -q ipywidgets
!pip install -q tensorflow
!pip install -q tensorflow_datasets
!pip install -q tensorboardX
!pip install -q transformers
!pip install -q grain
!pip install -q "google-tunix[prod]==0.1.3"
# Remove whichever flax is currently installed and install 0.12.0, the
# version the set_metadata compatibility patch below is written for.
!pip uninstall -q -y flax
!pip install -q flax==0.12.0
!pip install -q datasets wandb==0.22.0

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip instal

In [2]:
import wandb, os
from kaggle_secrets import UserSecretsClient
# Expose the W&B API key (stored as a Kaggle secret) through the environment.
# This stays best-effort, but only ordinary exceptions are trapped so Ctrl-C
# still interrupts the cell, and the underlying cause is reported instead of
# being hidden behind a fixed message.
try:
    os.environ['WANDB_API_KEY'] = UserSecretsClient().get_secret("WANDB_API_KEY")
except Exception as exc:
    print(
        "Warning: WANDB_API_KEY not found. WandB logging may fail. "
        f"({type(exc).__name__}: {exc})"
    )





In [3]:
from pprint import pprint
from typing import List, Dict, Any
import csv
import shutil
import functools, gc, os, re, asyncio, subprocess, tempfile, resource, signal, itertools, sys, time, datetime, random, contextlib, io, traceback
from concurrent.futures import ThreadPoolExecutor, as_completed

from flax import nnx
import grain
import humanize
import jax
import jax.numpy as jnp
import kagglehub
import optax
from orbax import checkpoint as ocp
from pathlib import Path
import qwix
import tensorflow_datasets as tfds
from tqdm.auto import tqdm
from tunix.generate import sampler as sampler_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
# from tunix.models.gemma3 import model as gemma_lib
# from tunix.models.gemma3 import params as params_lib
from tunix.models.gemma import model as gemma_lib
from tunix.models.gemma import params as params_lib
from tunix.models.gemma3 import params
from tunix.models.gemma3 import model
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.sft import metrics_logger
from datasets import load_dataset
from huggingface_hub import login
import datasets.utils.logging



In [4]:
def drop_names(namespace, names):
    """Delete every entry of ``names`` that exists in ``namespace``.

    Each name is handled independently, so a missing name never prevents the
    later ones from being released (a single ``try: del a; del b; ...`` block
    stops at the first NameError).

    Args:
        namespace: Mutable mapping to remove entries from (e.g. ``globals()``).
        names: Iterable of variable names to remove.

    Returns:
        The list of names that were actually present and removed, in order.
    """
    removed = []
    for name in names:
        if name in namespace:
            del namespace[name]
            removed.append(name)
    return removed


# Release objects left over from a previous run of this notebook so their
# device/host memory can be reclaimed before the models are rebuilt.
drop_names(
    globals(),
    ("grpo_trainer", "rl_cluster", "lora_policy", "ref_model", "optimizer", "sampler"),
)

gc.collect()

jax.clear_caches()

In [5]:
from flax import nnx

# Compatibility shim for qwix on flax 0.12.0: allow
# nnx.Variable.set_metadata to be called positionally as
# set_metadata(name, value) by forwarding such calls as name=value.
# The `_patched` marker keeps a re-run of this cell from wrapping twice.
if not hasattr(nnx.Variable.set_metadata, "_patched"):
    _orig_set_metadata = nnx.Variable.set_metadata

    def _patched_set_metadata(self, *args, **kwargs):
        """Forward to the original method, mapping (name, value) to name=value."""
        if len(args) != 2:
            # Any other call shape is passed through untouched.
            return _orig_set_metadata(self, *args, **kwargs)
        meta_name, meta_value = args
        kwargs[meta_name] = meta_value
        return _orig_set_metadata(self, **kwargs)

    setattr(_patched_set_metadata, "_patched", True)
    nnx.Variable.set_metadata = _patched_set_metadata
    print("‚úì Applied Flax 0.12.0 compatibility patch")

‚úì Applied Flax 0.12.0 compatibility patch


## Hyperparameters

In [6]:
# Dataset: DATASET_NAME and DATASET_SPLIT are passed, in that order, to
# load_dataset(); the "train" split of the result is then streamed and the
# first TOTAL_SAMPLES_TO_LOAD rows are kept.
DATASET_NAME = "openCoder-LLM/opc-sft-stage2"
DATASET_SPLIT = "educational_instruct"
TOTAL_SAMPLES_TO_LOAD = 200  # Load slightly more to ensure good splits
# Fraction of the loaded rows used for training; the remainder is validation.
TRAIN_SPLIT_RATIO = 0.8

# LoRA
# Passed to qwix.LoraProvider as rank / alpha.
RANK, ALPHA = 64, 64.0

# Sharding
# Unpacked into jax.make_mesh(*MESH): axis sizes first, axis names second.
MESH = [(1, 4), ("fsdp", "tp")]

# GRPO
# Prompt / generation budgets; the KV cache is sized as
# MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256 where it is configured.
MAX_PROMPT_LENGTH = 512
TOTAL_GENERATION_STEPS = 1024
# Sampling settings used for training rollouts.
TEMPERATURE, TOP_P, TOP_K = 0.7, 0.95, 50
# Forwarded to GRPOConfig (num_generations, num_iterations, beta, epsilon).
NUM_GENERATIONS = 4
NUM_ITERATIONS, BETA, EPSILON = 1, 0.08, 0.2

# Training
# Used both as the GRPO mini batch size and the train micro batch size, and
# as the generation batch size during evaluation.
TRAIN_MICRO_BATCH_SIZE = 2
# Num batches calculated later based on train split size
NUM_EPOCHS = 1
EVAL_EVERY_N_STEPS = 10

# Optimizer
# AdamW settings; LEARNING_RATE is the peak of the warmup-cosine schedule.
LEARNING_RATE, B1, B2, WEIGHT_DECAY = 3e-6, 0.9, 0.99, 0.1
# Global-norm gradient clipping threshold; a falsy value disables clipping.
MAX_GRAD_NORM = 0.1

# Checkpointing
# Both directories are wiped before the base checkpoint is converted.
INTERMEDIATE_CKPT_DIR = "/tmp/content/intermediate_ckpt/"
CKPT_DIR = "/tmp/content/ckpts/"
SAVE_INTERVAL_STEPS, MAX_TO_KEEP = 10, 2

# Execution
# Default per-sample timeout for running generated code against its tests.
EXECUTION_TIMEOUT = 3 # Seconds

## Logs setup

In [7]:
VERBOSE = True
# Use absolute path to be 100% sure where it is
LOG_FILE_PATH = "/kaggle/working/training_logs.txt"

def log_rich_interaction(
    source_stage: str,
    instruction: str,
    full_response: str,
    extracted_code: str,
    test_case: Any,
    exec_result: Dict,
    rewards: Dict[str, float] = None,
    full_details: bool = False
):
    """Append one formatted question/response/verification record to the log.

    The record is appended to ``LOG_FILE_PATH`` (flushed and fsynced so it
    survives a kernel crash) and, when ``VERBOSE`` is true, also printed.

    Args:
        source_stage: Label for the record header; upper-cased in the output.
        instruction: Prompt text. Truncated to 150 chars unless ``full_details``.
        full_response: Raw model output. Truncated to 200 chars (newlines
            flattened to spaces) unless ``full_details``.
        extracted_code: Code pulled from the response; only its truthiness is
            reported ("Yes" / "(None)"). May be None.
        test_case: Shown as-is when ``exec_result`` carries no ``test_log``.
        exec_result: Result dict from the executor; the ``error`` and
            ``test_log`` keys are read. ``None`` is treated as an empty dict.
        rewards: Optional mapping of reward name to numeric score; a total is
            appended when provided and non-empty.
        full_details: When true, log instruction and response untruncated.

    Returns:
        None. File-write failures are reported on stdout, never raised.
    """
    timestamp = datetime.datetime.now().strftime("%H:%M:%S")

    # Logging must never take the training loop down: tolerate missing inputs.
    instruction = "" if instruction is None else str(instruction)
    full_response = "" if full_response is None else str(full_response)
    exec_result = exec_result or {}

    # --- FORMATTING ---
    if full_details:
        disp_inst = instruction
        disp_resp = full_response
    else:
        disp_inst = instruction[:150] + "..." if len(instruction) > 150 else instruction
        disp_resp = full_response[:200].replace('\n', ' ')
        # Only mark the response as truncated when something was cut off,
        # mirroring how the instruction is handled.
        if len(full_response) > 200:
            disp_resp += "..."

    if exec_result.get('test_log'):
        test_display = "\n         ".join(exec_result['test_log'])
    else:
        test_display = str(test_case)

    reward_str = "N/A"
    if rewards:
        reward_str = " | ".join([f"{k}: {v:+.1f}" for k, v in rewards.items()])
        total_score = sum(rewards.values())
        reward_str += f" | üèÜ TOTAL: {total_score:+.1f}"

    # --- CONSTRUCT LOG ENTRY ---
    log_entry = (
        f"\n{'='*60}\n"
        f"üöÄ [{source_stage.upper()}] @ {timestamp}\n"
        f"{'='*60}\n"
        f"‚ùì QUESTION:\n"
        f"{disp_inst}\n\n"
        f"üß† MODEL OUTPUT:\n"
        f"{disp_resp}\n\n"
        f"üß™ VERIFICATION:\n"
        f"   ‚îú‚îÄ‚îÄ Code Extracted: {'(None)' if not extracted_code else 'Yes'}\n"
        f"   ‚îú‚îÄ‚îÄ Result:         {exec_result.get('error') if exec_result.get('error') else '‚úÖ Passed'}\n"
        f"   ‚îî‚îÄ‚îÄ Test Details:\n"
        f"         {test_display}\n\n"
        f"üèÖ GRADING: {reward_str}\n"
        f"{'='*60}\n"
    )

    # --- FORCE WRITE TO DISK ---
    try:
        with open(LOG_FILE_PATH, "a", encoding="utf-8") as f:
            f.write(log_entry)
            f.flush()
            os.fsync(f.fileno())
    except Exception as e:
        print(f"‚ö†Ô∏è LOGGING ERROR: Could not write to file: {e}")

    # Print to console
    if VERBOSE:
        print(log_entry)

print(f"‚úì Visualization Engine Ready. Logs forcing to: {LOG_FILE_PATH}")

‚úì Visualization Engine Ready. Logs forcing to: /kaggle/working/training_logs.txt


In [8]:
# NOTE(review): this rebinds LOG_FILE_PATH to a path relative to the current
# working directory, replacing the absolute path assigned in the logging cell
# above; log_rich_interaction reads the global at call time, so it writes here.
LOG_FILE_PATH = "training_logs.txt"

# Start each run from a fresh log file containing only a timestamped header.
Path(LOG_FILE_PATH).write_text(
    f"=== TRAINING LOG STARTED AT {datetime.datetime.now()} ===\n\n",
    encoding="utf-8",
)

print(f"‚úì Logging initialized at: {LOG_FILE_PATH}")

‚úì Logging initialized at: training_logs.txt


## Prompt

In [9]:
# Tags the model is asked to wrap its reasoning and its final code in.
reasoning_start = "<reasoning>"
reasoning_end = "</reasoning>"
answer_start = "<answer>"
answer_end = "</answer>"

# Single-paragraph system instruction, built from adjacent string literals.
SYSTEM_PROMPT = (
    "You are given a coding problem. Think about the problem and "
    f"provide your reasoning. Place it between {reasoning_start} and {reasoning_end}. "
    f"Then, provide the complete Python code solution between {answer_start} and {answer_end}."
)

# Chat-turn wrapper; filled with str.format(system_prompt=..., instruction=...).
TEMPLATE = (
    "<start_of_turn>user\n"
    "{system_prompt}\n"
    "\n"
    "{instruction}<end_of_turn>\n"
    "<start_of_turn>model"
)

## Dataset

In [10]:
# Silence Hugging Face Hub progress bars and disable the Xet transfer backend
# before any download is started.
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
os.environ["HF_HUB_DISABLE_XET"] = "1"

# Best-effort login with the HF token stored as a Kaggle secret. Only ordinary
# exceptions are trapped (Ctrl-C still works) and the real cause is printed,
# since a failed login() would otherwise also be reported as a missing token.
try:
    hf_token = UserSecretsClient().get_secret("HF_TOKEN")
    login(token=hf_token)
    print("Logged in to Hugging Face successfully.")
except Exception as exc:
    print(f"Warning: HF_TOKEN not found. ({type(exc).__name__}: {exc})")



In [11]:
datasets.utils.logging.disable_progress_bar()

print(f"üåä Streaming data...")
try:
    # Load dataset streaming
    dataset = load_dataset(DATASET_NAME, DATASET_SPLIT, streaming=True)
    iterable_ds = dataset['train'].take(TOTAL_SAMPLES_TO_LOAD)

    full_data_list = []
    for item in tqdm(iterable_ds, total=TOTAL_SAMPLES_TO_LOAD):
        full_data_list.append({
            # The entry point is appended so generated code defines the
            # function name the test cases call.
            "instruction": f"{item['instruction']} using the Function name {item['entry_point']}",
            "code": item["code"],
            "test_case": item["testcase"],
        })

    # FACTORY: Perform Train/Val Split (fixed seed => same split every run)
    random.seed(42)
    random.shuffle(full_data_list)

    split_idx = int(len(full_data_list) * TRAIN_SPLIT_RATIO)
    train_data = full_data_list[:split_idx]
    val_data = full_data_list[split_idx:]

except Exception as e:
    print(f"‚ùå Error loading dataset: {e}")
    # Fallback for testing if dataset fails
    full_data_list = []
    train_data = []
    val_data = []

# Derived step counts are computed outside the try block so they are defined
# on the fallback path too; previously a load failure left NUM_BATCHES,
# MAX_STEPS and WARMUP_STEPS unbound and later cells failed with NameError.
NUM_BATCHES = len(train_data) // TRAIN_MICRO_BATCH_SIZE
MAX_STEPS = NUM_BATCHES * NUM_ITERATIONS * NUM_EPOCHS
WARMUP_STEPS = int(0.1 * MAX_STEPS)

if full_data_list:
    print(f"‚úì Data Loaded. Total: {len(full_data_list)}")
    print(f"  - Train: {len(train_data)} samples ({NUM_BATCHES} batches)")
    print(f"  - Val:   {len(val_data)} samples")

üåä Streaming data...


README.md: 0.00B [00:00, ?B/s]

  0%|          | 0/200 [00:00<?, ?it/s]

‚úì Data Loaded. Total: 200
  - Train: 160 samples (80 batches)
  - Val:   40 samples


In [12]:
def get_grain_dataset_iterator(data_source, batch_size):
    """Build a fresh, batched grain pipeline over ``data_source``.

    Call this again whenever a new pass over the data is needed (another
    epoch, or a separate train/eval run).

    Args:
        data_source: Sequence of dicts with "instruction", "code" and
            "test_case" entries.
        batch_size: Number of examples per batch; a trailing partial batch is
            dropped.

    Returns:
        The batched dataset whose elements carry "prompts", "instruction",
        "code" and "test_case".
    """
    def _join_if_list(value):
        # Lists become one newline-separated string; anything else is kept.
        return "\n".join(value) if isinstance(value, list) else value

    def _preprocess(x):
        # Flatten every field to a plain string so batching sees uniform types.
        raw_tests = x["test_case"]
        if isinstance(raw_tests, list):
            tests_text = "\n".join(raw_tests)
        elif isinstance(raw_tests, str):
            tests_text = raw_tests
        else:
            tests_text = str(raw_tests)

        code_text = _join_if_list(x["code"])
        instruction_text = _join_if_list(x["instruction"])

        prompt = TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT,
            instruction=instruction_text,
        )
        return {
            "prompts": prompt,
            "instruction": instruction_text,
            "code": code_text,
            "test_case": tests_text,
        }

    shuffled = grain.MapDataset.source(data_source).shuffle(seed=42)
    prepared = shuffled.map(_preprocess)
    return prepared.to_iter_dataset().batch(batch_size, drop_remainder=True)

# Verify
train_ds_iter = get_grain_dataset_iterator(train_data, TRAIN_MICRO_BATCH_SIZE)
print(f"‚úì Dataset iterator factory ready.")

‚úì Dataset iterator factory ready.


## Code executor

In [13]:
@contextlib.contextmanager
def time_limit(seconds):
    """Raise ``TimeoutError`` in the calling code after ``seconds`` elapse.

    Implemented with SIGALRM, so it only works on Unix and only when entered
    from the main thread (``signal.signal`` rejects other threads).

    Args:
        seconds: Time budget. May be fractional; 0 arms no timer.

    Raises:
        TimeoutError: Inside the ``with`` body once the budget is exhausted.
    """
    def signal_handler(signum, frame):
        raise TimeoutError("Timed out!")

    # Remember whatever handler was installed so it can be put back: the
    # original code left this handler in place for the rest of the process.
    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    try:
        # setitimer (unlike alarm) accepts fractional seconds.
        signal.setitimer(signal.ITIMER_REAL, seconds)
        yield
    finally:
        # Always cancel the pending timer before restoring the old handler.
        signal.setitimer(signal.ITIMER_REAL, 0)
        # signal.signal returns None when the previous handler was not
        # installed from Python; fall back to the default action then.
        signal.signal(
            signal.SIGALRM,
            previous_handler if previous_handler is not None else signal.SIG_DFL,
        )

def execute_code_safe(code: str, test_case: Any, timeout: int = EXECUTION_TIMEOUT) -> Dict[str, Any]:
    """Run ``code`` and then each test line against it, under one time limit.

    SECURITY NOTE(review): despite the name this is NOT a sandbox. The code is
    exec'd inside the notebook process with the full set of builtins, so it
    can import modules, touch the filesystem, or trap the timeout exception
    itself. Only the wall-clock limit and exception containment are provided.

    Args:
        code: Python source to execute (typically model-generated).
        test_case: Either a list of test statements or a newline-separated
            string. Blank lines and lines starting with ``#`` are skipped.
            Each remaining line is exec'd on its own, so a statement spanning
            several lines is reported as an error rather than run.
        timeout: Budget in seconds for the code plus all tests together.

    Returns:
        Dict with ``success`` (all tests passed and at least one test ran),
        ``passed_tests``, ``total_tests``, ``test_log`` (one line per test, or
        a single line describing a timeout/syntax/runtime failure) and
        ``error`` (None on success, otherwise a short description).
    """
    output_buffer = io.StringIO()

    def _captured_print(*args, **kwargs):
        # Redirect to the buffer by default, but keep an explicit file=
        # argument working (the old lambda raised TypeError on it).
        kwargs.setdefault("file", output_buffer)
        print(*args, **kwargs)

    safe_globals = {
        "__builtins__": __builtins__,
        "print": _captured_print,
    }

    # Robust Input Handling
    if isinstance(test_case, list): raw_lines = test_case
    elif isinstance(test_case, str): raw_lines = test_case.split('\n')
    else: raw_lines = []

    test_lines = [line.strip() for line in raw_lines if line.strip() and not line.strip().startswith('#')]

    total_tests = len(test_lines)
    passed_tests = 0
    test_details = []  # Per-test outcome lines returned as "test_log"

    try:
        with time_limit(timeout):
            # Execute User Code
            exec(code, safe_globals)

            # Execute Tests Individually
            for test_line in test_lines:
                try:
                    exec(test_line, safe_globals)
                except AssertionError:
                    test_details.append(f"‚ùå FAIL: {test_line}")
                except TimeoutError:
                    # The alarm fires only once. If it were swallowed here as
                    # an ordinary test error, the remaining tests would run
                    # with no time limit at all, so let it reach the outer
                    # handler.
                    raise
                except Exception as e:
                    test_details.append(f"‚ö†Ô∏è ERR:  {test_line} ({type(e).__name__})")
                else:
                    passed_tests += 1
                    test_details.append(f"‚úÖ PASS: {test_line}")

            is_success = (passed_tests == total_tests) and (total_tests > 0)

            return {
                "success": is_success,
                "passed_tests": passed_tests,
                "total_tests": total_tests,
                "test_log": test_details, # Sending back the full log
                "error": None if is_success else "Tests Failed"
            }

    except TimeoutError:
        return {
            "success": False, "passed_tests": 0, "total_tests": max(1, total_tests),
            "test_log": ["‚è≥ TIMEOUT during execution"], "error": f"Timeout ({timeout}s)"
        }
    except SyntaxError as e:
        return {
            "success": False, "passed_tests": 0, "total_tests": max(1, total_tests),
            "test_log": [f"üö´ SYNTAX ERROR: {e}"], "error": f"Syntax Error: {e}"
        }
    except (Exception, SystemExit) as e:
        # SystemExit is included so generated code calling sys.exit()/exit()
        # is scored as a failure instead of aborting the calling cell.
        return {
            "success": False, "passed_tests": 0, "total_tests": max(1, total_tests),
            "test_log": [f"üí• SYSTEM ERROR: {type(e).__name__}"], "error": f"Runtime Error"
        }

def execute_batch_serial(codes: List[str], test_cases: List[Any]) -> List[Dict[str, Any]]:
    """Execute each (code, tests) pair one after another.

    Pairs are formed with ``zip``, so the shorter input determines how many
    executions happen. Results are returned in input order.
    """
    return [
        execute_code_safe(snippet, tests)
        for snippet, tests in zip(codes, test_cases)
    ]

print("‚úì Execution engine upgraded: Detailed Test Logging.")

‚úì Execution engine upgraded: Detailed Test Logging.


## Reward Functions

In [14]:
def clean_markdown_code(code_string):
    """Return the code inside the first fenced markdown block, if any.

    A fence is three backticks, an optional language tag, optional whitespace
    and a newline. When no such fence is found the whole input is used. The
    result is stripped of surrounding whitespace; falsy input yields "".
    """
    if not code_string:
        return ""
    fenced = re.search(r"```(?:\w+)?\s*\n(.*?)\s*```", code_string, re.DOTALL)
    body = fenced.group(1) if fenced else code_string
    return body.strip()

# Pattern for a well-formed completion: a reasoning block followed by an
# answer block, with group(1) capturing the text between the answer tags.
# DOTALL lets `.` span newlines. Because MULTILINE is set and callers use
# .search(), `^` and `$` anchor to line boundaries rather than to the whole
# string, so extra text on lines before the reasoning tag or after the answer
# tag does not prevent a match.
match_format = re.compile(
    rf"^\s{{0,}}{reasoning_start}.+?{reasoning_end}.*?{answer_start}(.+?){answer_end}\s{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

# 1. Format (Prerequisite)
def match_format_exactly(prompts, completions, **kwargs):
    """Score 0.5 for each completion that matches ``match_format``, else 0.0."""
    scores = []
    for completion in completions:
        well_formed = match_format.search(completion)
        scores.append(0.5 if well_formed else 0.0)
    return scores

# 2. Format Partial: Tiny hints
def match_format_approximately(prompts, completions, **kwargs):
    """Give 0.1 for each opening tag (reasoning, answer) present in a completion."""
    opening_tags = (reasoning_start, answer_start)
    return [
        sum((0.1 for tag in opening_tags if tag in text), 0.0)
        for text in completions
    ]

# 3. Compilation: High Penalty for Syntax Errors
def check_code_compilation(prompts, completions, **kwargs):
    """Reward completions whose extracted answer compiles as Python.

    Returns one score per completion: 0.0 when the completion does not match
    ``match_format``, 0.5 when the extracted code compiles, -1.0 when
    ``compile`` rejects it.
    """
    scores = []
    for c in completions:
        match = match_format.search(c)
        if not match:
            scores.append(0.0)
            continue
        clean_code = clean_markdown_code(match.group(1))
        try:
            compile(clean_code, '<string>', 'exec')
        except Exception:
            # compile() reports bad source via SyntaxError and a few other
            # Exception subclasses. A bare `except:` would also trap
            # KeyboardInterrupt and make the training cell impossible to stop.
            scores.append(-1.0) # Big penalty for invalid code (Hallucination prevention)
        else:
            scores.append(0.5) # Small reward for valid syntax
    return scores

# 4. Reasoning & Code Quality
def check_quality_metrics(prompts, completions, **kwargs):
    """Small style bonuses for detailed reasoning and documented code.

    Per completion: +0.2 when the stripped reasoning block is longer than 100
    characters, +0.1 when the extracted answer contains ``#`` and +0.1 when it
    contains a triple double-quote.
    """
    reasoning_pattern = rf"{reasoning_start}(.+?){reasoning_end}"
    answer_pattern = rf"{answer_start}(.+?){answer_end}"

    def _score(text):
        total = 0.0

        # A. Reasoning Quality
        thought = re.search(reasoning_pattern, text, re.DOTALL)
        if thought and len(thought.group(1).strip()) > 100:
            total += 0.2  # Reward detailed thought

        # B. Code Quality
        answer = re.search(answer_pattern, text, re.DOTALL)
        if answer:
            code = clean_markdown_code(answer.group(1))
            if "#" in code:
                total += 0.1  # Has comments
            if '"""' in code:
                total += 0.1  # Has docstrings
        return total

    return [_score(c) for c in completions]

# 5. Test Execution: Granular Partial Scoring
def check_test_execution(prompts, completions, test_case, **kwargs):
    """Run each well-formed completion against its tests and score the result.

    Args:
        prompts: Unused; present for the reward-function signature.
        completions: Generated texts, one per rollout.
        test_case: Test suites for the batch. Either one entry per completion
            (already repeated per generation) or one entry per prompt, in
            which case consecutive groups of ``NUM_GENERATIONS`` completions
            share an entry.

    Returns:
        One float per completion: 0.0 when no code could be extracted,
        otherwise +1.0 per passed test, +2.0 when every test passed, and
        -0.5 when the run was not fully successful (see the review note on
        that condition below).
    """
    if not isinstance(test_case, list): test_case = list(test_case)
    n_completions = len(completions)
    n_tests = len(test_case)
    if n_tests == 0: return [0.0] * n_completions

    codes = []
    for c in completions:
        match = match_format.search(c)
        codes.append(clean_markdown_code(match.group(1)) if match else None)

    valid_indices = [i for i, c in enumerate(codes) if c]
    if not valid_indices: return [0.0] * n_completions

    # When there is exactly one test entry per completion the two lists are
    # already aligned, so entry i belongs to completion i. Dividing by
    # NUM_GENERATIONS in that layout graded later prompts' completions against
    # the first prompts' tests. The grouped mapping is kept for the
    # one-entry-per-prompt layout.
    tests_are_aligned = (n_tests == n_completions)

    def _test_index(completion_idx):
        if tests_are_aligned:
            return completion_idx
        group_idx = completion_idx // NUM_GENERATIONS
        if group_idx >= n_tests:
            group_idx = completion_idx % n_tests
        return group_idx

    valid_codes = [codes[i] for i in valid_indices]
    valid_tests = [test_case[_test_index(i)] for i in valid_indices]

    results = execute_batch_serial(valid_codes, valid_tests)

    # --- LOGGING EVERY GENERATION ---
    for i, (code_idx, res, code_text) in enumerate(zip(valid_indices, results, valid_codes)):
        try:
            # Label only: which prompt group this generation belongs to.
            batch_idx = code_idx // NUM_GENERATIONS

            # Log every single generation to file (and console if VERBOSE=True)
            log_rich_interaction(
                source_stage=f"TRAIN_GEN_{code_idx}_(Batch_{batch_idx})",
                instruction="<Instruction hidden in Training Loop>", # We don't have raw instruction easily here
                full_response=completions[code_idx],
                extracted_code=code_text,
                test_case=valid_tests[i],
                exec_result=res,
                rewards={"Passed": res['passed_tests']},
                full_details=True
            )
        except Exception as log_err:
            # Logging is best-effort and must not affect rewards, but do not
            # hide the failure (or trap KeyboardInterrupt) the way a bare
            # `except: pass` would.
            print(f"Warning: could not log generation {code_idx}: {type(log_err).__name__}: {log_err}")
    # -------------------------------

    scores = [0.0] * n_completions
    for i, r in zip(valid_indices, results):
        score = 0.0
        passed = r['passed_tests']
        score += (passed * 1.0)
        if r['success']: score += 2.0
        # NOTE(review): none of the executor's error strings ("Tests Failed",
        # "Timeout (...)", "Syntax Error: ...", "Runtime Error") contain
        # "Assertion", so this penalty currently applies to every unsuccessful
        # run, including plain wrong answers. Left unchanged because altering
        # it would change reward magnitudes; confirm the intended behaviour.
        if not r['success'] and "Assertion" not in str(r.get('error', '')): score -= 0.5
        scores[i] = score
    return scores

print("‚úì Reward functions ready")

‚úì Robust, Balanced, and Partial-Credit Reward functions ready


## Model Prep

In [15]:
model_path = "google/gemma-2/flax/gemma2-2b-it"
print(f"Downloading {model_path}...")
kaggle_ckpt_path = kagglehub.model_download(model_path)

# Start from empty checkpoint directories. shutil.rmtree replaces the former
# `!rm -rf {..}` shell call: no shell interpolation of the paths, and a missing
# directory is ignored.
for stale_dir in (INTERMEDIATE_CKPT_DIR, CKPT_DIR):
    shutil.rmtree(stale_dir, ignore_errors=True)

print("Converting checkpoint format...")
# Named `gemma_params` (not `params`) so the `params` module imported from
# tunix.models.gemma3 is neither shadowed nor removed by the `del` below.
gemma_params = params_lib.load_and_format_params(os.path.join(kaggle_ckpt_path, "gemma2-2b-it"))
gemma = gemma_lib.Transformer.from_params(gemma_params, version="2-2b-it")
checkpointer = ocp.StandardCheckpointer()
# Only the state half of the split is written; it is restored later with a
# sharded target.
_, state = nnx.split(gemma)
checkpointer.save(os.path.join(INTERMEDIATE_CKPT_DIR, "state"), state)
checkpointer.wait_until_finished()

# Drop the unsharded copy before the sharded reference model is built.
del gemma_params, gemma, state
gc.collect()
print("‚úì Checkpoint converted")

Downloading google/gemma-2/flax/gemma2-2b-it...
Converting checkpoint format...


E0000 00:00:1764886502.986734      12 common_lib.cc:648] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:238
ERROR:asyncio:Exception in callback Task.__step()
handle: <Handle Task.__step()>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
RuntimeError: cannot enter context: <_contextvars.Context object at 0x79bc504b60c0> is already entered
ERROR:asyncio:Exception in callback Task.__step()
handle: <Handle Task.__step()>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
RuntimeError: cannot enter context: <_contextvars.Context object at 0x79bc504b60c0> is already entered
ERROR:asyncio:Exception in c

‚úì Checkpoint converted


In [16]:
def get_gemma_ref_model(ckpt_path):
    """Restore the Gemma2-2B reference model from ``ckpt_path`` onto a mesh.

    The model is first built abstractly (shapes only), its state is described
    as bfloat16 arrays carrying the mesh's named shardings, and the checkpoint
    is restored into that description.

    Args:
        ckpt_path: Location of the state saved by the conversion cell.

    Returns:
        Tuple ``(model, mesh, model_config)``.
    """
    mesh = jax.make_mesh(*MESH)
    model_config = gemma_lib.ModelConfig.gemma2_2b()

    def _build_abstract_model():
        return gemma_lib.Transformer(model_config, rngs=nnx.Rngs(params=0))

    abstract_model = nnx.eval_shape(_build_abstract_model)
    abstract_state = nnx.state(abstract_model)
    named_shardings = nnx.get_named_sharding(abstract_state, mesh)

    def _as_bf16_spec(leaf, sharding):
        # Keep each leaf's shape, force bfloat16, attach its sharding.
        return jax.ShapeDtypeStruct(leaf.shape, jnp.bfloat16, sharding=sharding)

    restore_target = jax.tree.map(_as_bf16_spec, abstract_state, named_shardings)

    restored_state = ocp.StandardCheckpointer().restore(ckpt_path, target=restore_target)
    graph_def, _ = nnx.split(abstract_model)
    return nnx.merge(graph_def, restored_state), mesh, model_config

def get_lora_model(base_model, mesh):
    """Wrap ``base_model`` with LoRA adapters and shard the resulting state.

    Args:
        base_model: Model to adapt; must provide ``get_model_input()``.
        mesh: Mesh context under which the sharding constraint is applied.

    Returns:
        The LoRA-augmented model, updated in place with its sharded state.
    """
    provider = qwix.LoraProvider(
        # Regex selecting the modules that receive adapters.
        module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum",
        rank=RANK,
        alpha=ALPHA,
    )
    example_inputs = base_model.get_model_input()
    lora_model = qwix.apply_lora_to_model(base_model, provider, **example_inputs)

    with mesh:
        lora_state = nnx.state(lora_model)
        partition_specs = nnx.get_partition_spec(lora_state)
        nnx.update(
            lora_model,
            jax.lax.with_sharding_constraint(lora_state, partition_specs),
        )
    return lora_model

In [17]:
# Restore the base model from the converted checkpoint; it serves as the
# reference model handed to the RL cluster later on.
ref_model, mesh, model_config = get_gemma_ref_model(
    ckpt_path=os.path.join(INTERMEDIATE_CKPT_DIR, "state")
)

# The trainable policy is the reference model wrapped with LoRA adapters.
lora_policy = get_lora_model(ref_model, mesh=mesh)

# The tokenizer model file is read from the downloaded Kaggle checkpoint.
tokenizer = tokenizer_lib.Tokenizer(
    tokenizer_path=os.path.join(kaggle_ckpt_path, "tokenizer.model")
)
print("‚úì Models and Tokenizer ready.")

ERROR:asyncio:Exception in callback Task.__step()
handle: <Handle Task.__step()>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
RuntimeError: cannot enter context: <_contextvars.Context object at 0x79bc504b60c0> is already entered
ERROR:asyncio:Task was destroyed but it is pending!
task: <Task pending name='Task-1554' coro=<_async_in_context.<locals>.run_in_context() done, defined at /usr/local/lib/python3.12/site-packages/ipykernel/utils.py:57> wait_for=<Task pending name='Task-1555' coro=<Kernel.shell_main() running at /usr/local/lib/python3.12/site-packages/ipykernel/kernelbase.py:590> cb=[Task.__wakeup()]> cb=[ZMQStream._run_callback.<locals>._log_error() at /usr/local/lib/python3.12/site-packages/zmq/eventloop/zmqstream.py:563]>
  async def wait_for_bytes(self, requested_bytes: int):
ERROR:asyncio:Task was destroyed but it is pending!
task: <Task pending name='Task-1555' c

‚úì Models and Tokenizer ready.


In [19]:
def generate(question, sampler, temperature=0.7, top_k=50, top_p=0.95):
    """Generate a completion for one question or for a batch of questions.

    Args:
        question: A single instruction string, or an iterable of them.
        sampler: Callable sampler; invoked with keyword arguments and expected
            to return an object exposing ``.text``.
        temperature, top_k, top_p: Sampling settings forwarded to the sampler.

    Returns:
        One string when ``question`` is a string, otherwise ``out.text`` as
        returned by the sampler.
    """
    is_single = isinstance(question, str)
    questions = [question] if is_single else question

    input_batch = []
    for q in questions:
        input_batch.append(TEMPLATE.format(system_prompt=SYSTEM_PROMPT, instruction=q))

    out_data = sampler(
        input_strings=input_batch,
        max_generation_steps=TOTAL_GENERATION_STEPS,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        echo=False,
    )
    if is_single:
        return out_data.text[0]
    return out_data.text

## Pre-training Performance

In [20]:
def evaluate_model(data_list, sampler, num_samples=10):
    """Generate answers for the first ``num_samples`` items and grade them.

    Each response is checked for the expected format, for compilability and,
    when code could be extracted, executed against the item's tests. Every
    sample is logged in full through ``log_rich_interaction``.

    Args:
        data_list: Sequence of dicts with "instruction" and "test_case".
        sampler: Sampler passed through to ``generate``.
        num_samples: Maximum number of items to evaluate.

    Returns:
        Tuple ``(format_rate, compile_rate, pass_rate)`` as fractions of the
        evaluated samples, or ``(0, 0, 0)`` when there is nothing to evaluate.
    """
    num_samples = min(num_samples, len(data_list))
    if num_samples == 0: return 0,0,0

    eval_subset = data_list[:num_samples]
    print(f"\nüîé STARTING EVALUATION ON {num_samples} SAMPLES...\n")

    total, correct_format, compiles, passes_tests = 0,0,0,0
    batch_size = TRAIN_MICRO_BATCH_SIZE

    for i in range(0, num_samples, batch_size):
        batch = eval_subset[i : i + batch_size]
        instructions = [b["instruction"] for b in batch]
        test_cases = [b["test_case"] for b in batch]

        # Eval = Low Temperature
        responses = generate(instructions, sampler, temperature=0.1)

        for inst, resp, tc in zip(instructions, responses, test_cases):
            total += 1
            code = None
            match = match_format.search(resp)
            if match:
                correct_format += 1
                code = clean_markdown_code(match.group(1))

            is_compiled = False
            if code:
                try:
                    compile(code, '<string>', 'exec')
                except Exception:
                    # Not compilable; counted via is_compiled staying False.
                    # (`except Exception` rather than a bare except so that
                    # KeyboardInterrupt can still stop the evaluation.)
                    pass
                else:
                    is_compiled = True
                    compiles += 1

            result = {"success": False, "passed_tests": 0, "total_tests": 0, "error": "No Code"}
            if code:
                # Use the notebook-wide timeout instead of a hard-coded 3 so
                # evaluation and training grade under the same limit.
                result = execute_code_safe(code, tc, timeout=EXECUTION_TIMEOUT)
                if result['success']: passes_tests += 1

            grades = {
                "Format": 0.5 if match else 0.0,
                "Syntax": 0.5 if is_compiled else (-1.0 if code else 0.0),
                "Logic": (result['passed_tests'] * 1.0) + (2.0 if result['success'] else 0.0)
            }
            # NOTE(review): no error string produced in this notebook contains
            # "Assertion", so this deduction applies to every unsuccessful
            # sample (including "No Code"); confirm that is intended.
            if not result['success'] and "Assertion" not in str(result.get('error','')):
                grades["Logic"] -= 0.5

            # CALL LOGGER WITH full_details=True
            log_rich_interaction(
                source_stage=f"EVAL_SAMPLE_{total}",
                instruction=inst,
                full_response=resp,
                extracted_code=code,
                test_case=tc,
                exec_result=result,
                rewards=grades,
                full_details=True # <--- SHOW EVERYTHING
            )

    print(f"\n{'='*40}\nüèÅ VALIDATION SUMMARY (n={total}):")
    print(f"  Format:  {correct_format/total*100:.1f}%")
    print(f"  Compile: {compiles/total*100:.1f}%")
    print(f"  Passed:  {passes_tests/total*100:.1f}%")
    print(f"{'='*40}\n")
    return correct_format/total, compiles/total, passes_tests/total

In [21]:
# Sampler used for the evaluation pass; it generates with the LoRA policy.
# The KV cache is sized for the longest prompt plus the full generation
# budget plus 256 extra slots, the same expression used for the rollout
# config's kv_cache_size.
sampler = sampler_lib.Sampler(
    transformer=lora_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

In [22]:
# Baseline: grade the policy on held-out data before any GRPO updates.
print("Evaluating Pre-trained model on VALIDATION set...")
with mesh:
    # Use validation data specifically
    evaluate_model(val_data, sampler, num_samples=5)

Evaluating Pre-trained model on VALIDATION set...

üîé STARTING EVALUATION ON 5 SAMPLES...


üöÄ [EVAL_SAMPLE_1] @ 22:19:28
‚ùì QUESTION:
Write a function to find the longest common prefix string amongst an array of strings. using the Function name longest_common_prefix

üß† MODEL OUTPUT:
<reasoning>
To find the longest common prefix string, we can use a sliding window approach. We iterate through the strings in the array, comparing characters at the current window position. If all strings have the same character at a position, we move the window forward. If any string has a different character, we stop and return the current window as the longest common prefix.

Here's a breakdown of the algorithm:

1. **Initialization:**
   - `prefix`: An empty string to store the longest common prefix.
   - `i`: An index to track the current position in the first string.

2. **Iteration:**
   - We iterate through the strings in the array.
   - For each string, we compare the character at the curr

## Setting up GRPO Trainer

In [23]:
# AdamW driven by a warmup + cosine-decay schedule: the learning rate starts
# at 0, reaches LEARNING_RATE after WARMUP_STEPS, and the schedule is defined
# over MAX_STEPS with an end value of 0.
optimizer = optax.adamw(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS, decay_steps=MAX_STEPS, end_value=0.0,
    ),
    b1=B1, b2=B2, weight_decay=WEIGHT_DECAY,
)
# Optional global-norm clipping, applied to gradients before the AdamW update.
if MAX_GRAD_NORM:
    optimizer = optax.chain(optax.clip_by_global_norm(MAX_GRAD_NORM), optimizer)

In [24]:
# Checkpoint cadence and retention for the training run.
checkpointing_options = ocp.CheckpointManagerOptions(
    save_interval_steps=SAVE_INTERVAL_STEPS,
    max_to_keep=MAX_TO_KEEP
)

# Where training metrics are written and how often they are flushed.
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir="/tmp/content/tmp/tensorboard/grpo",
    flush_every_n_steps=20
)

# Actor, reference and rollout roles all share the single mesh created when
# the reference model was restored.
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine='vanilla',
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        max_steps=MAX_STEPS,
        # Mini batch and micro batch are deliberately the same size here.
        mini_batch_size=TRAIN_MICRO_BATCH_SIZE,
        train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE,
        metrics_logging_options=metrics_logging_options,
        checkpoint_root_directory=CKPT_DIR,
        checkpointing_options=checkpointing_options,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=TOTAL_GENERATION_STEPS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        # Same sizing expression as the evaluation sampler's cache_size.
        kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        top_k=TOP_K,
    ),
)

grpo_config = GRPOConfig(
    num_generations=NUM_GENERATIONS,
    num_iterations=NUM_ITERATIONS,
    beta=BETA,
    epsilon=EPSILON,
)

# NOTE(review): the recorded output of this cell shows an interactive wandb
# login prompt, so an unattended run may block here when no WANDB_API_KEY is
# available — confirm which constructor triggers it.
rl_cluster = rl_cluster_lib.RLCluster(
    actor=lora_policy,
    reference=ref_model,
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)

# The per-completion reward is built from the functions below; the trailing
# comments give each function's score range.
grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=[
        match_format_exactly,      # Format (0.5)
        match_format_approximately,# Formatting Hints (0.2)
        check_code_compilation,    # Syntax (+0.5 / -1.0)
        check_quality_metrics,     # Reasoning/Style (0.5 max)
        check_test_execution,      # Logic (+1.0 per test + 2.0 bonus)
    ],
    grpo_config=grpo_config,
)

print("‚úì Training configs ready")

[34m[1mwandb[0m: (1) Private W&B dashboard, no account required
[34m[1mwandb[0m: (2) Use an existing W&B account


[34m[1mwandb[0m: Enter your choice:  1


[34m[1mwandb[0m: You chose 'Private W&B dashboard, no account required'
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33manony-moose-740892564596240858[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


‚úì Training configs ready


In [25]:
# Summary banner: training-set size and the step counts derived from it.
print(f"\n{'='*60}")
print(f"Starting GRPO Training on {len(train_data)} samples")
print(f"  Max steps: {MAX_STEPS}")
print(f"  Train batches per epoch: {NUM_BATCHES}")
print(f"{'='*60}\n")


Starting GRPO Training on 160 samples
  Max steps: 80
  Train batches per epoch: 80



## Model Training

In [26]:
# Build a fresh batched pipeline for this run; the `train_ds_iter` created in
# the dataset cell is not the one consumed here.
train_dataset_iter = get_grain_dataset_iterator(train_data, TRAIN_MICRO_BATCH_SIZE)

# Training runs inside the mesh context, like the evaluation pass above.
with mesh:
    grpo_trainer.train(train_dataset_iter)

Actor Training:   0%|          | 0/80 [00:00<?, ?step/s]

ERROR:asyncio:Exception in callback Task.__step()
handle: <Handle Task.__step()>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
RuntimeError: cannot enter context: <_contextvars.Context object at 0x79bc504b60c0> is already entered
ERROR:asyncio:Exception in callback Task.__step()
handle: <Handle Task.__step()>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
RuntimeError: cannot enter context: <_contextvars.Context object at 0x79bc504b60c0> is already entered
ERROR:asyncio:Task was destroyed but it is pending!
task: <Task pending name='Task-4026' coro=<_async_in_context.<locals>.run_in_context() done, defined at /usr/local/lib/python3.12/site-packages/ipykernel/utils.py:57> wait_for=<Task pending name='Task-4027' coro=<Kernel.shell_main() running at /usr/local/lib/python3


üöÄ [TRAIN_GEN_1_(BATCH_0)] @ 22:23:40
‚ùì QUESTION:
<Instruction hidden in Training Loop>

üß† MODEL OUTPUT:
<reasoning>
The problem asks us to find the longest substring within a given string where no characters repeat.  This is a classic problem often solved using a sliding window approach. We can use a dictionary to keep track of the last seen character index for each character in the string. If a character is encountered again, we can shrink the window by moving the start index until we find a new character not in the dictionary. 

Here's a breakdown of the approach:

1. **Initialization:** 
   - `start`: Represents the starting index of the current window.
   - `end`: Represents the ending index of the current window.
   - `char_map`: A dictionary to store the last seen index of each character.

2. **Iteration:**
   - We iterate through the string, expanding the window using `end` pointer.
   - For each character:
      - If the character is not in the `char_map`, add it to th

ERROR:asyncio:Task was destroyed but it is pending!
task: <Task pending name='Task-2729' coro=<_async_in_context.<locals>.run_in_context() running at /usr/local/lib/python3.12/site-packages/ipykernel/utils.py:60> wait_for=<Task pending name='Task-4025' coro=<Kernel.shell_main() running at /usr/local/lib/python3.12/site-packages/ipykernel/kernelbase.py:590> cb=[Task.__wakeup()]> cb=[ZMQStream._run_callback.<locals>._log_error() at /usr/local/lib/python3.12/site-packages/zmq/eventloop/zmqstream.py:563]>
  def __init__(self):
ERROR:asyncio:Task was destroyed but it is pending!
task: <Task pending name='Task-4025' coro=<Kernel.shell_main() running at /usr/local/lib/python3.12/site-packages/ipykernel/kernelbase.py:590> cb=[Task.__wakeup()]>



üöÄ [TRAIN_GEN_0_(BATCH_0)] @ 22:24:49
‚ùì QUESTION:
<Instruction hidden in Training Loop>

üß† MODEL OUTPUT:
<reasoning>
To find the largest continuous sum in a list, we need to iterate through the list, keeping track of the running sum and the maximum sum encountered so far. We can do this by using a variable `current_sum` to hold the sum of the current sub-array and `max_sum` to hold the maximum sum found so far.  We compare the current sum to the maximum sum and update it if the current sum is larger. 

Additionally, we need to handle the case where the list is empty or contains only one element. If the list is empty, the largest continuous sum is 0. If the list contains only one element, the largest continuous sum is that element itself.  </reasoning>
 
<answer>
```python
def largest_cont_sum(nums):
    if len(nums) == 0:
        return 0
    if len(nums) == 1:
        return nums[0]
    current_sum = nums[0]
    max_sum = nums[0]
    for i in range(1, len(nums)):
        curre




🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:24:58
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
The QuickSelect algorithm is a divide-and-conquer algorithm that efficiently finds the kth smallest element in an unsorted array. It works by repeatedly selecting the pivot element and partitioning the array around the pivot. The algorithm then recursively selects the kth smallest element from the subarray. 

Here's how the algorithm works:

1. **Choose a pivot:** A pivot element is selected from the array. 
2. **Partition:** The array is partitioned into two subarrays: elements smaller than the pivot and elements larger than the pivot.
3. **Recursive selection:** The algorithm recursively calls itself to find the kth smallest element in the smaller subarray and the larger subarray. 
4. **Combine:** The smallest element from the smaller subarray is compared with the pivot to determine the kth smallest element.

To implement QuickSelect in Python, we can use the f




🚀 [TRAIN_GEN_1_(BATCH_0)] @ 22:25:07
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
The factorial of a non-negative integer n, denoted by n!, is the product of all positive integers less than or equal to n. For example, 5! = 5 * 4 * 3 * 2 * 1 = 120. 

A recursive approach to calculate factorial is to define a function that breaks down the problem into smaller subproblems.  It will call itself with a decreasing argument until it reaches a base case (where the argument is 0).  The base case is when the argument is 0, where the factorial is 1. The results are then combined back up the call stack to calculate the final factorial. 

</reasoning>

<answer>
```python
def factorial(n):
  if n == 0:
    return 1
  else:
    return n * factorial(n-1)

# Example usage
number = 5
result = factorial(number)
print(f"The factorial of {number} is {result}") 
```
</answer> 
<end_of_turn>

🧪 VERIFICATION:
   ├── Code Extracted: Yes
   ├── Result:  




🚀 [TRAIN_GEN_1_(BATCH_0)] @ 22:25:15
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
To remove duplicate elements from a list while preserving order, we can utilize a set. Sets inherently store only unique elements, and we can iterate through the list and add each element to the set.  This ensures we only store unique elements, and we can use the set's order to reconstruct the list.

Here's a breakdown of the process:

1. **Initialize an empty set:** We'll use a set to store unique elements.
2. **Iterate through the input list:** We'll use a loop to go through each element in the list.
3. **Check if the element exists in the set:** For each element, we check if it's already present in the set. 
4. **Add to set or skip:** If the element is not in the set, we add it to the set. Otherwise, we skip it.
5. **Create a new list:** We'll create a new list using the set's elements to reconstruct the unique elements in order.

</reasoning>
<answer>
```pyth




🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:25:29
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
Here's how we can determine if a string is a palindrome:

1. **Reverse the String:**  The core logic is to compare the original string with its reversed version. 
2. **Character-by-Character Comparison:** Iterate through the string, comparing each character with its corresponding character from the end.
3. **Handling Spaces and Punctuation:**  Ensure to consider spaces and punctuation marks as separate characters, as they shouldn't affect the palindrome check.

</reasoning>

<answer>
```python
def is_palindrome(text):
    """
    Checks if a string is a palindrome.

    Args:
        text (str): The string to be checked.

    Returns:
        bool: True if the string is a palindrome, False otherwise.
    """
    processed_text = ''.join(c for c in text.lower() if c.isalnum())
    return processed_text == processed_text[::-1]

# Example usage
string1 = "racecar"
s




🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:25:35
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
To find the prime factors of a number, we need to repeatedly divide the number by prime numbers until it becomes 1. We can use a loop to iterate through all possible prime numbers and check if the given number is divisible by them. If it is, we add the prime number to the list of prime factors. We can use the modulo operator (%) to check if a number is divisible by another number. We can also use a list to store the prime factors. </reasoning>
<answer>
```python
def prime_factors(n):
    prime_factors = []
    i = 2
    while i * i <= n:
        if n % i:
            i += 1
        else:
            n //= i
            prime_factors.append(i)
    if n > 1:
        prime_factors.append(n)
    return prime_factors

# Example usage
number = 12
factors = prime_factors(number)
print(f"The prime factors of {number} are: {factors}")
```
</answer> 
<end_of_turn>

🧪 VE




🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:25:46
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
To count the occurrences of a target character in a string, we can use a loop to iterate through each character in the string. For each character, we check if it matches the target character. If it does, we increment a counter. We should also handle case insensitivity by converting the string to lowercase before comparison.  We can then return the final count. 
</reasoning>

<answer>
def count_char(string, target):
  """
  Counts the number of times a target character appears in a string.

  Args:
      string: The string to search.
      target: The character to count.

  Returns:
      The number of times the target character appears in the string.
  """
  count = 0
  string = string.lower()
  for char in string:
    if char == target:
      count += 1
  return count

</answer> 
<end_of_turn>

🧪 VERIFICATION:
   ├── Code Extracted: Yes
   ├── Res




🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:25:58
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
Here's how we can solve this problem:

1. **Sliding Window Approach:** The most common and efficient approach is using a sliding window. 
   - We maintain a window of characters that we're currently considering.
   - We keep track of the last seen index of each character in the window.
   - As we traverse the string, we expand the window until we encounter a repeating character or reach the end of the string.
   - The length of the longest substring without repeating characters is the size of the window.

2. **Two Hash Maps:** We can use a hash map (dictionary) to keep track of the last seen index of each character. 
   - We iterate through the string, adding each character to the hash map and checking if it's present in the window.
   - If a character is encountered that has already been seen, we shrink the window from the right until we find a character that's 




🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:26:08
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
To find the sum of all even numbers in a list, we can iterate through the list and check each number if it's even. We can use the modulo operator (%) to determine if a number is even. If the remainder when dividing by 2 is 0, the number is even.  We can then add each even number to a running total. </reasoning>

<answer>
```python
def sum_even_numbers(numbers):
  """
  This function calculates the sum of all even numbers in a given list.

  Args:
    numbers: A list of integers.

  Returns:
    The sum of all even numbers in the list.
  """
  total = 0
  for number in numbers:
    if number % 2 == 0:
      total += number
  return total

# Example usage
numbers = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
even_sum = sum_even_numbers(numbers)
print(f"The sum of even numbers is: {even_sum}")
```
</answer> 
<end_of_turn>

🧪 VERIFICATION:
   ├── Code Extracted: Yes
   




🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:26:17
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
To calculate the factorial of a number using recursion, we can break down the problem into smaller, similar subproblems. 

1. **Base Case:**  The factorial of 0 is defined as 1. 
2. **Recursive Step:** For any non-zero number, the factorial is the product of that number and the factorial of the number minus 1. This is where recursion comes in - we call the `factorial` function itself with a smaller number. 

This approach ensures that the function eventually reaches the base case and then uses the results of the smaller subproblems to build up the final factorial value. 
</reasoning>

<answer>
```python
def factorial(n):
    if n == 0:
        return 1
    else:
        return n * factorial(n-1)

# Get input from the user
num = int(input("Enter a non-negative integer: "))

# Validate input
if num < 0:
    print("Factorial is not defined for negative numbers.")
el




🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:26:26
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
A perfect number is a positive integer that is equal to the sum of its proper divisors (all divisors excluding the number itself). We can determine if a number is perfect by iterating through all divisors and calculating their sum. 

To optimize the solution, we can utilize a set to store the divisors, and only calculate them once.
</reasoning>

<answer>
```python
def isPerfectNumber(number):
    """
    Checks if a given number is a perfect number.

    Args:
      number: The number to check.

    Returns:
      True if the number is a perfect number, False otherwise.
    """

    if number <= 1:
        return False

    divisors = set()
    for i in range(1, int(number**0.5) + 1):
        if number % i == 0:
            divisors.add(i)
            if i != number // i:
                divisors.add(number // i)
    
    return sum(divisors) == number

# Example




🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:26:36
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
To reverse the vowels in a string, we can iterate through the string, identify vowels, and swap them with the vowels at the end of the string.  We can do this by using a loop and a set to keep track of the vowels in the string.  


</reasoning>
 
 <answer>
```python
def reverse_vowels(s):
    """
    This function reverses the vowels in a string.

    Args:
      s: The input string.

    Returns:
      The string with all vowels reversed.
    """
    vowels = set("aeiouAEIOU")
    left = 0
    right = len(s) - 1
    result = []
    while left <= right:
      if s[left] in vowels:
        result.append(s[left])
        left += 1
      if s[right] in vowels:
        result.append(s[right])
        right -= 1
    return ''.join(result)

# Example usage:
string = "hello world"
reversed_string = reverse_vowels(string)
print(f"Original String: {string}")
print(f"Rever




🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:26:47
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
The problem involves finding the longest substring within a given string that contains only unique characters. We can solve this using a sliding window approach. The idea is to iterate through the string, keeping track of the characters seen in the current window and updating the window's boundaries as we move along. If a repeating character is encountered, we slide the window to exclude it. The maximum length of the substring is then returned.

To efficiently handle the sliding window, we can use a dictionary to store the last seen index of each character. This helps us quickly determine if a character is repeating within the current window. 

</reasoning>

<answer>
```python
def longest_substring(s):
  """
  Finds the longest substring without repeating characters in a given string.

  Args:
    s: The input string.

  Returns:
    The length of the longest sub




🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:26:58
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
To determine if a string is a valid parentheses string, we can use a recursive approach. We can break down the problem into smaller subproblems. We will use a stack to keep track of opening parentheses and closing parentheses. If the string is empty, it is valid. Otherwise, we will iterate through each character and check if it is an opening parenthesis, closing parenthesis, or a concatenation. If the string is empty, we can immediately return True. If the string is a valid parentheses string, we can return True. If the string is not a valid parentheses string, we can return False. </reasoning>

<answer>
```python
def is_valid_parentheses(s):
    stack = []
    parentheses_map = {")": "(", "}": "{", "]": "["}
    for char in s:
        if char in parentheses_map.values():
            stack.append(char)
        elif char in parentheses_map.keys():
            if s




🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:27:07
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
The problem requires us to count how many times a specific binary number appears within a list of integers represented as strings. To achieve this, we need to first convert both the binary number and the integers in the list to their integer representations before comparing them. This ensures that the comparison is performed on the actual numerical values of the numbers, not just their string representations. 

Here's a breakdown of the steps:

1. **Convert binary to integer:** The binary string is converted to its integer representation using the `int(binary_number, 2)` function. This function interprets the binary string as a base-2 number.

2. **Convert integers to integers:**  The integers in the list are also converted to integers using the `int(integer_string)` function. 

3. **Compare and count:** Once both the binary number and the integers are in integer




🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:27:18
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
The problem asks for a function that takes a string as input and returns the length of the longest substring without repeating characters.  To solve this, we can utilize a sliding window approach. We'll maintain a set to track the characters we've already seen, and for each character, we'll check if it's in our set. If it is, we'll shrink the window, moving it to the right to avoid repeating characters.

Here's a breakdown of the logic:

1. **Initialization:** We'll initialize an empty set to track seen characters.
2. **Sliding Window:** We'll use two pointers, left and right, to define the start and end of the window. 
3. **Check for Repetition:** For each character at the right end, we'll check if it's in our set. 
4. **Shrink Window:** If a character is repeated, we'll shrink the window to the right by updating the right pointer.
5. **Longest Substring:** We'l




🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:27:27
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
The problem asks for a function that takes a string as input and returns the length of the longest substring without repeating characters.  To solve this, we can utilize a sliding window approach. We'll maintain a set to track the characters we've already seen in the current window.  We'll iterate through the string, expanding the window until a repeating character is encountered. We'll then shrink the window by one character, keeping track of the current window's length. The maximum length of the substring without repeating characters will be the longest window size.

Here's a breakdown of the logic:

1. **Sliding Window:** The sliding window represents the current substring under consideration. 
2. **Character Tracking:** We use a set to keep track of the characters present in the current window.
3. **Iteration:** We iterate through the string, expanding the wi




🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:27:38
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
The problem asks us to find the smallest positive integer that is not present in a given list of integers. This is a classic problem that can be solved using the concept of **set operations** and a little bit of logic. 

Here's how we can approach it:

1. **Create a Set:** Convert the input list into a set. Sets are efficient for membership checking.
2. **Find Missing Integers:** Iterate through the range of positive integers. If an integer is not present in the set, it is the missing integer. 
3. **Return the Missing Integer:** Return the smallest positive integer that's not in the set.

</reasoning>
<answer>
```python
def find_missing(nums):
  """
  Finds the smallest positive integer that does not occur in a given list of integers.

  Args:
    nums: A list of integers.

  Returns:
    The smallest positive integer that is not in the list.
  """
  missing = 1





🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:27:45
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
The problem involves finding the longest substring within a given string that contains only unique characters. This can be accomplished using a sliding window approach. We need to iterate through the string, keeping track of the last seen index of each character.  We can maintain a dictionary to store the last seen index of each character. When a duplicate character is encountered, we slide the window to the right to avoid including it in the current substring.  The length of the longest substring without repeating characters is then the maximum length of the substring encountered.

</reasoning>
<answer>
```python
def longest_substring(s):
    """
    Finds the length of the longest substring without repeating characters.

    Args:
      s: The input string.

    Returns:
      The length of the longest substring without repeating characters.
    """
    n = len




🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:27:55
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
This problem involves finding the longest palindromic substring within a given string.  Here's a breakdown of how we can approach it:

1. **Understanding Palindromes:** A palindrome is a word, phrase, number, or other sequence of characters that reads the same backward as forward. 

2. **Dynamic Programming Approach:**  A dynamic programming approach is efficient for finding palindromes because we can reuse calculations.  We'll create a table to store the lengths of palindromes of substrings.

3. **Algorithm:**
   -  Initialize a table `dp` of size `(n+1) x (n+1)` where n is the length of the input string. 
   -  For each substring of the string, check if it's a palindrome.
   -  If it is, update the length of the palindrome in the table.

4. **Finding the Longest Palindrome:**  We need to find the maximum length palindrome in the table.

</reasoning>

<answer>
`




🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:28:07
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
The nth Catalan number is a sequence of numbers in combinatorics.  It's defined as the number of ways to choose a set of n objects from a set of n distinct objects, where the order of selection doesn't matter.  We can solve this using dynamic programming by storing the values for previously calculated Catalan numbers in a list. We iterate through the range of n and calculate the Catalan number for each value. </reasoning>
 
<answer>
```python
def catalan(n):
    """
    Calculates the nth Catalan number using dynamic programming.

    Args:
      n: The index of the Catalan number to calculate.

    Returns:
      The nth Catalan number.
    """

    if n <= 1:
        return 1

    catalan_numbers = [0] * (n + 1)
    catalan_numbers[0] = 1
    catalan_numbers[1] = 1

    for i in range(2, n + 1):
        for j in range(i):
            catalan_numbers[i] += catal




🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:28:17
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
To determine if a word is a palindrome, we need to compare the word with its reverse. We can achieve this by:

1. **Reverse the word:**  Use string slicing or a loop to reverse the word.
2. **Compare the original and reversed words:** Check if the original and reversed words are the same. If they are, the word is a palindrome.

</reasoning>

<answer>
```python
def is_palindrome(word):
  """
  Checks if a word is a palindrome.

  Args:
    word: The word to check.

  Returns:
    True if the word is a palindrome, False otherwise.
  """
  word = word.lower().replace(" ", "")
  return word == word[::-1]

# Example usage
word1 = "madam"
word2 = "racecar"
word3 = "apple"

print(f"'{word1}' is a palindrome: {is_palindrome(word1)}")
print(f"'{word2}' is a palindrome: {is_palindrome(word2)}")
print(f"'{word3}' is a palindrome: {is_palindrome(word3)}")
```
</answer> 
<end




🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:28:27
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
To find the mode of a list of integers, we can use the following approach:
1. Count the occurrences of each integer in the list using a dictionary.
2. Identify the integer with the highest count as the mode.
3. If there are multiple modes, return the smallest one.

To handle the case of multiple modes, we can use a dictionary to store the count of each integer and then sort the list of integers and their counts. The smallest integer will be the mode.
</reasoning>
<answer>
```python
def find_mode(nums):
  """
  Finds the mode of a list of integers.

  Args:
      nums: A list of integers.

  Returns:
      The mode of the list, which is the integer that appears most frequently.
  """
  counts = {}
  for num in nums:
    if num in counts:
      counts[num] += 1
    else:
      counts[num] = 1
  
  max_count = 0
  mode = None
  for num, count in counts.items():
    




🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:28:36
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
To find the maximum product of two integers in a list, we can use the following steps:

1. **Initialize two variables:** `max1` and `max2` to store the maximum and second maximum product.
2. **Iterate through the list:** Check each integer in the list.
3. **Compare each integer with the current maximum:** If the current integer is greater than `max1`, update `max1` with the current integer. If the current integer is greater than `max2`, update `max2` with the current integer.
4. **After iterating through the entire list, the maximum product will be stored in `max1` and the second maximum in `max2`**.

We can then return the larger of the two. 


</reasoning>

<answer>
```python
def max_product(nums):
  """
  Finds the maximum product of two integers in a given list of integers.

  Args:
    nums: A list of integers.

  Returns:
    The maximum product of two inte




🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:28:47
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
To reverse a list of strings in Python, we can iterate through the list and use slicing to reverse the order of the strings. We can also use list comprehension for a more concise solution.  

The function should use a loop to iterate through the list of strings, then reverse each string using slicing and create a new list with the reversed strings.
</reasoning>

<answer>
```python
def reverse_strings(strings):
  """
  This function takes a list of strings and returns a new list with each string reversed.

  Args:
    strings: A list of strings.

  Returns:
    A new list of strings with each string reversed.
  """
  reversed_strings = []
  for string in strings:
    reversed_strings.append(string[::-1])
  return reversed_strings

# Example usage
my_list = ["hello", "world", "python"]
reversed_list = reverse_strings(my_list)
print(f"Original list: {my_list}")
prin




🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:28:55
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
To find the smallest positive number divisible by all numbers from 1 to n, we need to find the least common multiple (LCM). The LCM is the smallest number that is a multiple of all given numbers. We can use prime factorization to calculate the LCM.  By finding the prime factorization of each number from 1 to n, we can determine the LCM. 

For example, if we have 12, the prime factorization is 2 x 2 x 3.  We can then multiply these prime factors to get the LCM. 

</reasoning>

<answer>
```python
def smallest_divisible(n):
  """
  Finds the smallest positive number that is evenly divisible by all numbers from 1 to n.

  Args:
      n: The upper limit.

  Returns:
      The smallest positive number divisible by all numbers from 1 to n.
  """
  if n <= 1:
    return 1
  
  prime_factors = []
  for i in range(2, n+1):
    is_prime = True
    for j in range(2, int(i**0




🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:29:07
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
This problem can be solved using a sliding window approach. We can maintain a dictionary to keep track of the last seen index of each character. We iterate through the string, and whenever we encounter a character we have already seen, we shrink the window from the left by removing the character at the start of the window that's no longer within the window. 

The key to solving this problem is to keep track of the maximum length of the substring we find so far. We do this by comparing the current window's length with the maximum length we've seen. </reasoning>

<answer>
```python
def length_of_longest_substring(s):
    """
    Finds the length of the longest substring without repeating characters in a given string.

    Args:
        s: The input string.

    Returns:
        The length of the longest substring without repeating characters.
    """
    n = len(s)




🚀 [TRAIN_GEN_0_(BATCH_0)] @ 22:29:16
❓ QUESTION:
<Instruction hidden in Training Loop>

🧠 MODEL OUTPUT:
<reasoning>
To solve this problem, we can use the concept of finding the median of the array and then iterate through the array and calculate the difference between each element and the median. The minimum number of operations required is the sum of the absolute values of these differences. 

We can use a greedy approach.  We iterate through the array and keep track of the median. The median is used to determine the optimal operations. 

We will iterate through the array and if the element is less than the median, we will increment it by 1. If the element is greater than the median, we will decrement it by 1.
</reasoning>

<answer>
```python
def min_operations(arr):
  """
  Calculates the minimum number of operations required to make all elements of an array equal.

  Args:
    arr: A list of integers.

  Returns:
    The minimum number of operations required.
  """
  n = le

0,1
actor/train/kl,▁▁▁▁▂▂▃▂▄▃▄▄▅█▄▇▆▅▄▆▅▆▆▅▆▅▆▅▄▄▅▇▄▇▃▃▅▄█▆
actor/train/loss,▄▅▂▅█▁▄▄▃▃▂▃▃▅▃▄▃▄▁▃▃▃▃▃▃▃▃▂▅▇▃▄▃▃▃▃▃▃▃▅
actor/train/perplexity,▂▄▄█▄▅▃▂▃▃▂▁▄▂▃▄▂▃▃▃▂▃▃▃▃▃▃▃▁▄▃▂▃▃▃▄▃▃▃▃
actor/train/step_time_sec,▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
actor/train/steps_per_sec,▁▇▇▇██████▃████▇████████████████████████
actor/train/tflops_per_step,▁
jax/core/compile/backend_compile_duration,▁
jax/core/compile/jaxpr_to_mlir_module_duration,▁
jax/core/compile/jaxpr_trace_duration,▁
jax/orbax/write/sharded_array_gb,▁

0,1
actor/train/kl,0.00521
actor/train/loss,0.12099
actor/train/perplexity,1.12861
actor/train/step_time_sec,0.17752
actor/train/steps_per_sec,5.63307
actor/train/tflops_per_step,36.91269
jax/core/compile/backend_compile_duration,1764887021.9409
jax/core/compile/jaxpr_to_mlir_module_duration,1764887020.49653
jax/core/compile/jaxpr_trace_duration,1764887016.70765
jax/orbax/write/sharded_array_gb,0.0011


In [27]:
# List the actor checkpoint directory sorted by modification time (oldest first)
# to see which training steps were actually written to disk.
! ls -ltr /tmp/content/ckpts/actor

  pid, fd = os.forkpty()


total 8
drwxr-xr-x 3 root root 4096 Dec  4 22:28 70
drwxr-xr-x 3 root root 4096 Dec  4 22:29 80


## Post Training - Performance

In [28]:
# Open a W&B run in the evaluation project for the post-training stage.
wandb.init(project='tunix-eval-code')

# Checkpoints live under <CKPT_DIR>/actor/<step>/model_params.
# Prefer the checkpoint written at the final training step; if that step was
# never saved (e.g. training stopped early), fall back to the newest saved step
# instead of failing inside the checkpointer with an opaque error.
actor_ckpt_dir = os.path.join(CKPT_DIR, "actor")
trained_ckpt_path = os.path.join(actor_ckpt_dir, str(MAX_STEPS), "model_params")
if not os.path.exists(trained_ckpt_path) and os.path.isdir(actor_ckpt_dir):
    saved_steps = sorted(
        int(entry)
        for entry in os.listdir(actor_ckpt_dir)
        if entry.isdigit()
        and os.path.exists(os.path.join(actor_ckpt_dir, entry, "model_params"))
    )
    if saved_steps:
        print(
            f"Warning: no checkpoint for step {MAX_STEPS}; "
            f"falling back to latest saved step {saved_steps[-1]}."
        )
        trained_ckpt_path = os.path.join(
            actor_ckpt_dir, str(saved_steps[-1]), "model_params"
        )
print(f"Restoring from {trained_ckpt_path}...")

# Build an abstract (shape/dtype only) copy of the LoRA parameter tree to use
# as the restore target.
abs_params = jax.tree.map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
    nnx.state(lora_policy, nnx.LoRAParam),
)
checkpointer = ocp.StandardCheckpointer()
trained_lora_params = checkpointer.restore(trained_ckpt_path, target=abs_params)

# Overwrite the policy's LoRA parameters in place: the tree structure comes from
# the live LoRA state, every leaf value comes from the restored checkpoint.
nnx.update(
    lora_policy,
    jax.tree.map(lambda a, b: b, nnx.state(lora_policy, nnx.LoRAParam), trained_lora_params),
)

# Recreate sampler
# The KV cache is sized for the prompt plus the generation budget, with 256
# extra slots on top.
sampler = sampler_lib.Sampler(
    transformer=lora_policy, tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers, num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

print("Evaluating trained model on VALIDATION set...")
# Evaluating on Validation Data (Unseen)
# Kept as the cell's last expression so the returned value is displayed.
evaluate_model(val_data, sampler, num_samples=25)

Restoring from /tmp/content/ckpts/actor/80/model_params...




Evaluating trained model on VALIDATION set...

🔎 STARTING EVALUATION ON 25 SAMPLES...


🚀 [EVAL_SAMPLE_1] @ 22:31:39
❓ QUESTION:
Write a function to find the longest common prefix string amongst an array of strings. using the Function name longest_common_prefix

🧠 MODEL OUTPUT:
<reasoning>
To find the longest common prefix string, we can use a sliding window approach. We iterate through the strings in the array, comparing characters at each position. If a mismatch is found, we know the prefix ends at that position. We keep track of the prefix length and return it. 

Here's a breakdown of the logic:

1. **Initialization:**
   - `prefix`: Initialize an empty string to store the common prefix.
   - `i`: Initialize a pointer to the beginning of the first string.

2. **Iteration:**
   - Iterate through the strings in the array.
   - For each string, compare characters at the current position `i` with the corresponding characters in the other strings.
   - If a mismatch is found, 

(1.0, 1.0, 0.72)