# Tunix hackathon notebook submission template

This template notebook is part of the [Google Tunix hackathon](https://www.kaggle.com/competitions/google-tunix-hackathon/overview) and is used to simplify the submission process for competition participants and streamline the evaluation process for the judges.

## Your overall training and evaluation strategy

Please discuss your overall approach to this hackathon: how you allocated your compute, how you set up the reward for RL, how you evaluated the model, which techniques you used, what your training pipeline looks like, which ablation studies you ran, etc.

Think of this section as a mini-technical report.

## How your finetuning dataset is created

Constructing a good data mixture is a key challenge for this hackathon. Please discuss this topic in detail in this section. There is no need to include the relevant code here, since the judges will not try to reproduce your dataset.

Also make sure the dataset is publicly accessible for evaluation.

## Your Tunix finetuning code

Use instruction-tuned Gemma2 2B or Gemma3 1B (other models are not allowed).

In [None]:
# Your prompt
PROMPT_TEMPLATE = "your awesome prompt with a placeholder {question}"

# Training parameters; feel free to change
TEMPERATURE=0.7
TOP_K=50
TOP_P=0.9

# You may also change this parameter for inference if you want
MAX_GENERATION_STEPS=768


# DO NOT CHANGE BELOW

# Use these standard output tags so that your model's output follows this format in plain text (no JSON/XML):
# <reasoning>model_reasoning_trace</reasoning>
# <answer>model_final_answer</answer>

REASONING_START = "<reasoning>"
REASONING_END = "</reasoning>"
SOLUTION_START = "<answer>"
SOLUTION_END = "</answer>"

# Use these parameters for greedy decoding; used in competition evaluation
INF_TEMPERATURE=0
INF_TOP_K=1
INF_TOP_P=None
SEED=42
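
The tag format above can be checked mechanically, which is useful both for a format reward during RL and for a quick local sanity check before submitting. The following is a minimal sketch, not the judges' scoring code: `extract_answer`, `format_reward`, and the example prompt text are hypothetical names and values introduced for illustration, and the tag constants are repeated only so that the snippet runs on its own.

```python
import re
from typing import Optional

# Tag constants repeated from the cell above so this sketch is self-contained.
REASONING_START = "<reasoning>"
REASONING_END = "</reasoning>"
SOLUTION_START = "<answer>"
SOLUTION_END = "</answer>"

# Hypothetical prompt; the only required part is the {question} placeholder.
PROMPT_TEMPLATE = (
    "Think step by step inside <reasoning></reasoning>, then give the final "
    "answer inside <answer></answer>.\n\nQuestion: {question}"
)

# A reasoning block followed by an answer block, with only whitespace
# before, between, and after them.
_FORMAT_RE = re.compile(
    rf"^\s*{re.escape(REASONING_START)}(?P<reasoning>.+?){re.escape(REASONING_END)}\s*"
    rf"{re.escape(SOLUTION_START)}(?P<answer>.+?){re.escape(SOLUTION_END)}\s*$",
    re.DOTALL,
)


def extract_answer(completion: str) -> Optional[str]:
    """Return the stripped answer text, or None if the format is not followed."""
    match = _FORMAT_RE.match(completion)
    return match.group("answer").strip() if match else None


def format_reward(completion: str) -> float:
    """Assumed reward shape: 1.0 for a well-formed completion, 0.0 otherwise."""
    return 1.0 if extract_answer(completion) is not None else 0.0


prompt = PROMPT_TEMPLATE.format(question="What is 2 + 2?")
completion = "<reasoning>2 + 2 = 4</reasoning>\n<answer> 4 </answer>"
print(extract_answer(completion), format_reward(completion))
```

A binary format reward like this is usually combined with a separate correctness reward; how the two are weighted is a design choice for your training strategy section.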

Here is your finetuning code:

In [None]:
# Your awesome finetuning code. Example:
# 
# with mesh:
#   grpo_trainer.train(train_dataset)
#
# Make sure at least one checkpoint is saved during a 9-hour run.
# The very last checkpoint will be used for evaluation.

## [Optional 15pts] unrestricted mode

If you would like to participate in the unrestricted mode, please write down the Kaggle model ID for evaluation. Make sure that:
1. you have published the relevant files (in Flax format) to Kaggle (see the Kaggle documentation on how to upload model files; one way is to use `kagglehub.model_upload()`)
2. you have set the model visibility to 'Public'
3. the model files are loadable by Tunix modelling code (see the section below)
4. the model uses the same base model as in your single-session run (e.g., if you use Gemma2 2B for single-session, then the model for this mode must also come from Gemma2 2B, NOT Gemma3 1B).
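
As a rough sketch of step 1, the upload could look like the snippet below. It assumes the four-part `owner/model/framework/variation` handle that `kagglehub` uses, with `jax` and `size` taken from the evaluation cell later in this notebook as placeholders. `build_model_handle` and `upload_checkpoint` are hypothetical helpers, and the local directory is a stand-in for wherever your checkpoint files live.

```python
def build_model_handle(owner: str, model: str, framework: str, variation: str) -> str:
    """Join the four handle components with '/' after basic validation."""
    parts = [owner, model, framework, variation]
    for part in parts:
        if not part or "/" in part:
            raise ValueError(f"invalid handle component: {part!r}")
    return "/".join(parts)


def upload_checkpoint(local_dir: str, handle: str) -> None:
    """Upload the files in local_dir as a new version of the model at handle."""
    # Imported lazily so the helper above works without kagglehub installed.
    import kagglehub

    # Requires Kaggle credentials to be configured in the environment.
    kagglehub.model_upload(handle, local_dir, version_notes="Tunix hackathon checkpoint")


# Placeholder values; replace them with your own user name, model name, and variation.
handle = build_model_handle("user_name", "model_name", "jax", "size")
print(handle)
# upload_checkpoint("/path/to/checkpoint_dir", handle)  # uncomment to actually upload
```

Uploading does not make the model public by itself, so the visibility setting in step 2 still needs to be checked on the Kaggle model page.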

In [None]:
# Example: 'windmaple/gpt2' in https://www.kaggle.com/models/windmaple/gpt2

unrestricted_kaggle_model = "user_name/model_name"  

## Other things you want the judges to know

Additional topics worth discussing, for example:
- what you learned in this hackathon
- challenges you faced during this hackathon (e.g., compute, Tunix product issues and feature requests)
- how you think this hackathon could be better
- if you made a PR to Tunix for this hackathon, make sure to include it here

# Competition evaluation

This section is for reference only and does not need to be included in your submission.

Google judges will first run your code above to reproduce your model and then run the code below for evaluation. The evaluation runtime does not count against your 9-hour budget, so you don't need to worry about it.

## Load the final checkpoint for single-session mode evaluation
There is no need to manually merge the LoRA adapter; Tunix will take care of it at inference time.

In [None]:
# Define the checkpoint folder
CKPT_DIR = 'your ckpt folder'

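import os
import re

# jax, nnx, ocp, wandb, and lora_policy are assumed to come from the finetuning code above.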

# Find the latest checkpoint by listing directories in CKPT_DIR/actor
actor_ckpt_dir = os.path.join(CKPT_DIR, "actor")

latest_step = -1
if os.path.exists(actor_ckpt_dir):
  for item in os.listdir(actor_ckpt_dir):
    if os.path.isdir(os.path.join(actor_ckpt_dir, item)) and re.match(r'^\d+$', item):
      step = int(item)
      if step > latest_step:
        latest_step = step

if latest_step == -1:
  raise FileNotFoundError(f"No checkpoints found in {actor_ckpt_dir}")

print(f"Latest checkpoint step: {latest_step}")

wandb.init(project='tunix-eval')  # logging bug workaround

trained_ckpt_path = os.path.join(
    CKPT_DIR, "actor", str(latest_step), "model_params"
)

abs_params = jax.tree.map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
    nnx.state(lora_policy, nnx.LoRAParam),
)
checkpointer = ocp.StandardCheckpointer()
trained_lora_params = checkpointer.restore(trained_ckpt_path, target=abs_params)

nnx.update(
    lora_policy,
    jax.tree.map(
        lambda a, b: b,
        nnx.state(lora_policy, nnx.LoRAParam),
        trained_lora_params,
    ),
)

## Create the sampler for finetuned model

In [None]:
sampler = sampler_lib.Sampler(
    transformer=lora_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + MAX_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

## Evaluate
Illustrative code for AI-based evaluation.

In [None]:
class TunixHackathonJudge:
    questions = ['question1', 'question2', ...]
    judge = "ai"
    
    def __init__(self, temperature, top_k, top_p, max_generation_steps, seed):
        ...

    def evaluate(self, sampler, prompt):
        ...

result = TunixHackathonJudge(
    INF_TEMPERATURE, INF_TOP_K, INF_TOP_P, MAX_GENERATION_STEPS, SEED
).evaluate(sampler, PROMPT_TEMPLATE)

# Unrestricted mode

If the participant includes an `unrestricted_kaggle_model`, Google will load the uploaded checkpoint from Kaggle as shown below and use the `TunixHackathonJudge` above for another evaluation.

In [None]:
# The path components "actor", str(latest_step), "model_params" may also need to be
# appended to the downloaded path (e.g., with os.path.join).
trained_ckpt_path = kagglehub.model_download(unrestricted_kaggle_model + "/jax/size")

abs_params = jax.tree.map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
    nnx.state(lora_policy, nnx.LoRAParam),
)
checkpointer = ocp.StandardCheckpointer()
trained_lora_params = checkpointer.restore(trained_ckpt_path, target=abs_params)

nnx.update(
    lora_policy,
    jax.tree.map(
        lambda a, b: b,
        nnx.state(lora_policy, nnx.LoRAParam),
        trained_lora_params,
    ),
)

# Evaluation code is pretty much the same as single-session mode above