In [None]:
# Install the packages this notebook imports. All installs are unpinned, so
# the resolved versions depend on when the cell is run.
!pip install -q kagglehub

!pip install -q tensorflow
!pip install -q tensorboardX
!pip install -q grain
# tunix and qwix are installed from the head of their GitHub repositories.
!pip install -q git+https://github.com/google/tunix
!pip install -q git+https://github.com/google/qwix

# Replace whatever flax is preinstalled with the GitHub head version.
!pip uninstall -q -y flax
!pip install -q git+https://github.com/google/flax.git

!pip install -q huggingface_hub
!pip install -q datasets

In [None]:
import functools
import os
import re

from flax import nnx
import grain
import humanize
import jax
import optax
import qwix
import tensorflow_datasets as tfds
from tqdm.auto import tqdm

from tunix.examples.data import translation_dataset as data_lib
from tunix.generate import sampler as sampler_lib
from tunix.models.gemma3 import params_safetensors as params_safetensors_lib

from datasets import load_dataset
from tunix.sft.dpo.dpo_trainer import DpoTrainingConfig
from tunix.sft.dpo.dpo_trainer import DpoTrainer
from tunix.sft.dpo.dpo_trainer import TrainingInput
from huggingface_hub import snapshot_download
from tunix.sft.dpo.dpo_trainer import _generate_ids_and_masks
from tunix.models.gemma3 import model as gemma3_model_lib
from datasets import concatenate_datasets
import numpy as np

In [None]:
# ====== Data ======
TRAIN_DATA_DIR = "./data/train"
TEST_DATA_DIR = "./data/test"
# Fraction of NUM_BATCHES that counts toward MAX_STEPS (see below).
TRAIN_FRACTION = 1.0

INTERMEDIATE_CKPT_DIR = "/content/intermediate_ckpt/"
# ====== LoRA ======
RANK = 8
ALPHA = 16.0

# ====== Sharding ======
# (axis_shapes, axis_names) pair that is unpacked into jax.make_mesh.
MESH = [(1, 1), ("fsdp", "tp")]

# ====== Sequence lengths / sampling ======
MAX_PROMPT_LENGTH = 192
TOTAL_GENERATION_STEPS = 192
TEMPERATURE = 0.7
TOP_P = 1.0
TOP_K = 50

# ====== DPO ======
# Passed to DpoTrainingConfig(beta=...).
BETA = 0.1

# === AdamW, warmup, cosine scheduler ===
LEARNING_RATE = 3e-6
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1

# ====== Batching / training length ======
BATCH_SIZE = 1
NUM_BATCHES = 512
NUM_TEST_BATCHES = 100
EVAL_EVERY_N_STEPS = 100

NUM_EPOCHS = 1  # can potentially train for more epochs
MAX_STEPS = int(NUM_BATCHES * TRAIN_FRACTION * NUM_EPOCHS)

# == Cosine decay with warmup scheduler ==
# Linearly increase the learning rate from 0 to LEARNING_RATE over the first
# 10% of the training steps, then decay it to 0 with a cosine schedule.
# Kept as an integer because it is used as a step count by the schedule.
WARMUP_STEPS = int(0.1 * MAX_STEPS)

# == Grad clipping ==
# Grad clipping to prevent large gradients. Found this
# important to keep KL divergence in check.
MAX_GRAD_NORM = 0.1

# ====== Inference ======
# Named sampling presets; each value is unpacked as keyword arguments
# (temperature, top_k, top_p) into evaluate().
GENERATION_CONFIGS = {
    # greedy search
    "greedy": {"temperature": 1e-4, "top_k": 1, "top_p": 1.0},
    # some randomness
    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
    # liberal
    "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
}

# Load reference model and LoRA model

In [None]:
# Log in to the Hugging Face Hub before the model snapshot is downloaded in the
# next cell.
!huggingface-cli login

In [None]:
# Hugging Face repository to fetch the checkpoint from.
model_id = "google/gemma-3-1b-it"
# File patterns excluded from the download: PyTorch .pth weight files.
ignore_patterns = ["*.pth"]

print(f"Downloading {model_id} from Hugging Face...")
local_model_path = snapshot_download(
    repo_id=model_id,
    ignore_patterns=ignore_patterns,
)
print(f"Model successfully downloaded to: {local_model_path}")

In [None]:
def show_hbm_usage():
  """Prints memory usage for every local JAX device.

  For each device returned by `jax.local_devices()` one line is printed with
  the bytes in use, the byte limit and the usage percentage. The
  `device:0:HBM0:*` statistics are preferred when present; otherwise the
  generic `bytes_in_use` / `bytes_limit` entries are used. Devices that report
  no statistics at all are shown with zero usage instead of raising.
  """
  fmt_size = functools.partial(humanize.naturalsize, binary=True)

  print("\n--- TPU HBM Usage ---")
  for i, d in enumerate(jax.local_devices()):
    # memory_stats() can return None on backends that do not expose memory
    # statistics; treat that as "no data" rather than failing on .get().
    stats = d.memory_stats() or {}
    used = stats.get("bytes_in_use", 0)
    limit = stats.get("bytes_limit", 0)

    hbm_used = stats.get("device:0:HBM0:bytes_in_use", used)
    hbm_limit = stats.get("device:0:HBM0:bytes_limit", limit)

    # Fallback if specific HBM stats not available
    if hbm_limit == 0:
      hbm_used = used
      hbm_limit = limit

    # Guard against a zero limit so the percentage never divides by zero.
    percentage = (hbm_used / hbm_limit * 100) if hbm_limit > 0 else 0

    print(
        f"Device {i} ({d.device_kind}): Using {fmt_size(hbm_used)} /"
        f" {fmt_size(hbm_limit)} ({percentage:.2f}%)"
    )

  print("--- End HBM Usage ---")

In [None]:
# Report device memory before the model is created, as a baseline for the
# report printed after loading.
print("\n--- HBM Usage BEFORE Model Load ---")
show_hbm_usage()

In [None]:
# Directory returned by the Hugging Face snapshot download.
MODEL_CP_PATH = local_model_path

# Pick the corresponding config based on the model version.
model_config = gemma3_model_lib.Gemma3Config.gemma3_1b()

# Device mesh: axis shapes followed by axis names.
MESH = [(1, 1), ("fsdp", "tp")]
mesh_axis_shapes, mesh_axis_names = MESH
mesh = jax.make_mesh(mesh_axis_shapes, mesh_axis_names)

# Build the base (reference) model from the safetensors checkpoint inside the
# mesh context, then render its structure.
with mesh:
  gemma3 = params_safetensors_lib.create_model_from_safe_tensors(
      MODEL_CP_PATH,
      model_config,
      mesh,
  )
  nnx.display(gemma3)

In [None]:
# Report device memory again now that the base model has been created.
print("\n--- HBM Usage AFTER Model Load ---")
show_hbm_usage()

In [None]:


# Tokenizer shared by the sampler and by the DPO preprocessing further down.
gemma_tokenizer = data_lib.Gemma3Tokenizer()
# Disabled alternative kept for reference: load the tokenizer through
# transformers instead, i.e.
#   from transformers import AutoTokenizer
#   gemma_tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")

# Sampler wrapping the base `gemma3` model (not the LoRA model).
sampler = sampler_lib.Sampler(
    transformer=gemma3,
    tokenizer=gemma_tokenizer,
    cache_config=sampler_lib.CacheConfig(
        # Prompt budget + generation budget + 256 extra slots.
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

In [None]:
def get_lora_model(base_model, mesh):
  """Wraps `base_model` with LoRA adapters and applies sharding constraints.

  Args:
    base_model: Model to adapt; must provide `get_model_input()`.
    mesh: Mesh used as the context while the state is constrained.

  Returns:
    The model produced by `qwix.apply_lora_to_model`, with its state updated
    from the sharding-constrained copy.
  """
  # Regex selecting the modules that receive adapters.
  lora_target_pattern = (
      ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
      ".*attn_vec_einsum"
  )
  # Quantized options (weight_qtype="nf4", tile_size=4) are left disabled.
  provider = qwix.LoraProvider(
      module_path=lora_target_pattern, rank=RANK, alpha=ALPHA
  )

  example_inputs = base_model.get_model_input()
  adapted_model = qwix.apply_lora_to_model(
      base_model, provider, **example_inputs
  )

  with mesh:
    adapted_state = nnx.state(adapted_model)
    partition_specs = nnx.get_partition_spec(adapted_state)
    nnx.update(
        adapted_model,
        jax.lax.with_sharding_constraint(adapted_state, partition_specs),
    )

  return adapted_model

In [None]:
# Policy model
lora_gemma = get_lora_model(gemma3, mesh=mesh)
nnx.display(lora_gemma)

# Load evaluation data and evaluate the reference model

In [None]:
# First NUM_TEST_BATCHES rows of the GSM8K "main" test split from the Hugging
# Face hub.
# NOTE(review): `eval_dataset` is not referenced by any later cell visible here;
# the evaluation cells use `test_dataset` instead — confirm whether it is needed.
eval_dataset = load_dataset("gsm8k", "main", split="test").select(range(NUM_TEST_BATCHES))

In [None]:
# Tags the model is asked to wrap its reasoning and final answer in; the
# format/answer regexes are built from the same tags.
reasoning_start = "<reasoning>"
reasoning_end = "</reasoning>"
solution_start = "<answer>"
solution_end = "</answer>"


# Instruction text placed at the start of every user turn. The trailing
# backslashes join the lines, so the prompt is a single line of text.
SYSTEM_PROMPT = f"""You are given a problem. Think about the problem and \
provide your reasoning. Place it between {reasoning_start} and \
{reasoning_end}. Then, provide the final answer (i.e., just one numerical \
value) between {solution_start} and {solution_end}."""

# Chat template with `system_prompt` and `question` placeholders; it ends with
# an opened model turn.
TEMPLATE = """<start_of_turn>user
{system_prompt}

{question}<end_of_turn>
<start_of_turn>model"""

def generate(
    question, sampler, temperature=0.7, top_k=50, top_p=0.95, seed=None
):
  """Formats one or more questions with the chat template and samples text.

  Args:
    question: A single question string, or an iterable of question strings.
    sampler: Callable invoked with keyword arguments `input_strings`,
      `max_generation_steps`, `temperature`, `top_k`, `top_p`, `echo` and
      `seed`; its result must expose a `text` sequence.
    temperature: Forwarded to the sampler.
    top_k: Forwarded to the sampler.
    top_p: Forwarded to the sampler.
    seed: Forwarded to the sampler unchanged (None when not given).

  Returns:
    The first generated text when `question` is a string, otherwise the full
    `text` sequence returned by the sampler.
  """
  is_single = isinstance(question, str)
  questions = [question] if is_single else question

  prompts = [
      TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=q)
      for q in questions
  ]

  sampled = sampler(
      input_strings=prompts,
      max_generation_steps=TOTAL_GENERATION_STEPS,
      temperature=temperature,
      top_k=top_k,
      top_p=top_p,
      echo=False,
      seed=seed,
  )

  texts = sampled.text
  return texts[0] if is_single else texts

In [None]:
# Matches a response made of optional leading whitespace, a reasoning section,
# then an answer section, then optional trailing whitespace. Group 1 captures
# the text between the answer tags.
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

# Quick sanity check of the format pattern on a well-formed response.
match_format.search(
    f"{reasoning_start}Let me"
    f" think!{reasoning_end}{solution_start}2{solution_end}",
)

# Captures the first run of digits and dots that follows the opening answer
# tag. Signs and thousands separators are not part of the captured group.
match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL
)
# Quick sanity check of the number pattern; the cell displays the result.
match_numbers.findall(f"{solution_start}  0.34  {solution_end}")

In [None]:
def _score_numeric_guess(extracted, answer):
  """Compares an extracted numeric string with the reference answer string.

  Returns:
    `(exact, close)` booleans, or None when either value cannot be parsed as a
    float (including a missing, i.e. None, answer). `close` means the ratio
    guess/answer lies in [0.9, 1.1]; for a zero answer it equals `exact`.
  """
  try:
    guess_value = float(extracted.strip())
    answer_value = float(answer.strip())
  except (AttributeError, TypeError, ValueError):
    return None

  exact = guess_value == answer_value
  if answer_value == 0:
    # A ratio is undefined for a zero reference; only an exact hit is "close".
    close = exact
  else:
    close = 0.9 <= guess_value / answer_value <= 1.1
  return exact, close


def evaluate(
    dataset,
    sampler,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    num_passes=1,
    corr_lst=False,
    make_lst=False,
):
  """Computes accuracy and percentage of outputs matching the format.

  Args:
    dataset: Iterable of batches; each batch is indexed with "question" and
      "answer" and both entries are iterated in parallel.
    sampler: Passed through to `generate`.
    temperature: Passed through to `generate`.
    top_k: Passed through to `generate`.
    top_p: Passed through to `generate`.
    num_passes: Number of generations per question; pass `p` uses `seed=p`.
    corr_lst: When collecting responses, keep correctly answered questions if
      truthy, otherwise keep incorrectly answered ones.
    make_lst: Whether to collect and return `(question, answer, responses)`.

  Returns:
    `(corr, total, accuracy%, partial_accuracy%, format_accuracy%)`, or that
    tuple plus the collected list when `make_lst` is set. All percentages are
    0.0 when the dataset yields no questions.
  """

  response_lst = []
  corr = 0
  partially_corr = 0
  corr_format = 0
  total = 0

  for batch in tqdm(dataset):
    answers = batch["answer"]
    questions = batch["question"]

    # multiple_call_responses[i] collects one response per pass for question i.
    multiple_call_responses = [[] for _ in range(len(questions))]
    for p in range(num_passes):
      responses = generate(
          questions, sampler, temperature, top_k, top_p, seed=p
      )
      for idx, response in enumerate(responses):
        multiple_call_responses[idx].append(response)

    for question, multiple_call_response, answer in zip(
        questions, multiple_call_responses, answers
    ):
      corr_ctr_per_question = 0
      partially_corr_per_question = 0
      corr_format_per_question = 0
      for response in multiple_call_response:
        # check answer; a sentinel value is used when no number is found
        guess = match_numbers.search(response)
        extracted_response = (
            guess.group(1) if guess is not None else "-1000000"
        )
        score = _score_numeric_guess(extracted_response, answer)
        if score is None:
          print("SKIPPED")
        else:
          exact, close = score
          if exact:
            corr_ctr_per_question += 1
          if close:
            partially_corr_per_question += 1

        # check format
        if match_format.search(response) is not None:
          corr_format_per_question += 1

        # Stop early once every criterion has been met by some response.
        if (
            corr_ctr_per_question > 0
            and partially_corr_per_question > 0
            and corr_format_per_question > 0
        ):
          break

      is_correct = corr_ctr_per_question > 0
      if is_correct:
        corr += 1
      if make_lst and bool(corr_lst) == is_correct:
        response_lst.append((question, answer, multiple_call_response))
      if partially_corr_per_question > 0:
        partially_corr += 1
      if corr_format_per_question > 0:
        corr_format += 1

      total += 1
      if total % 10 == 0:
        print(
            f"===> {corr=}, {total=}, {corr / total * 100=}, "
            f"{partially_corr / total * 100=}, {corr_format / total * 100=}"
        )

  def _percent(count):
    # An empty dataset reports 0.0 instead of dividing by zero.
    return count / total * 100 if total else 0.0

  to_return = (
      corr,
      total,
      _percent(corr),
      _percent(partially_corr),
      _percent(corr_format),
  )
  if make_lst:
    return to_return, response_lst
  return to_return

In [None]:
def extract_hash_answer(text: str) -> str | None:
  if "####" not in text:
    return None
  return text.split("####")[1].strip()

def get_dataset(data_dir, split="train") -> grain.MapDataset:
  """Builds a shuffled GSM8K dataset of prompt/question/answer records.

  Args:
    data_dir: Directory used by TFDS for the data; created if it is missing.
    split: Name of the TFDS split to load.

  Returns:
    A `grain.MapDataset` shuffled with seed 42 whose elements are dicts with
    the keys "prompts", "question" and "answer".
  """
  # Make sure the download directory exists before TFDS writes into it.
  if not os.path.exists(data_dir):
    os.makedirs(data_dir)

  source = tfds.data_source(
      "gsm8k",
      split=split,
      data_dir=data_dir,
      builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD},
      download=True,
  )

  def _to_example(record):
    # Fields arrive as bytes and are decoded as UTF-8.
    question_text = record["question"].decode("utf-8")
    answer_text = record["answer"].decode("utf-8")
    return {
        # passed to model forward pass
        "prompts": TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT,
            question=question_text,
        ),
        # passed to reward functions
        "question": question_text,
        # passed to reward functions
        "answer": extract_hash_answer(answer_text),
    }

  shuffled = grain.MapDataset.source(source).shuffle(seed=42)
  return shuffled.map(_to_example)

# Batch the test split and keep only the first NUM_TEST_BATCHES batches.
test_dataset = get_dataset(TEST_DATA_DIR, "test").batch(BATCH_SIZE)[
    :NUM_TEST_BATCHES
]

# Display the number of batches as the cell output.
len(test_dataset)

In [None]:
# Baseline: evaluate the sampler that wraps the base `gemma3` model, using the
# "standard" sampling preset, and keep the collected responses.
(corr, total, accuracy, partial_accuracy, format_accuracy), responses = evaluate(
    test_dataset,
    sampler,
    **GENERATION_CONFIGS["standard"],
    make_lst=True
)
print(
    f"{corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,"
    f" {format_accuracy=}%"
)

# Load DPO dataset

In [None]:
# Preference-pair dataset; rows flagged `in_gsm8k_train` are always kept.
dpo_dataset = load_dataset("argilla/distilabel-intel-orca-dpo-pairs", split="train")
gsm8k_train_dpo_dataset = dpo_dataset.filter(lambda x: x['in_gsm8k_train'])

# Get the number of samples in the filtered dataset
num_gsm8k_train_samples = len(gsm8k_train_dpo_dataset)
print(f"Number of samples with in_gsm8k_train=True: {num_gsm8k_train_samples}")

# Calculate how many more samples are needed
total_samples_needed = NUM_BATCHES
samples_to_add = total_samples_needed - num_gsm8k_train_samples
print(f"Number of additional random samples needed: {samples_to_add}")

if samples_to_add > 0:
    # Top up from the rows that were NOT already selected above. Sampling from
    # the full dataset could pick flagged rows a second time and put duplicate
    # examples into the combined set.
    remaining_dpo_dataset = dpo_dataset.filter(lambda x: not x['in_gsm8k_train'])
    # Never request more rows than the remaining pool contains.
    num_random_samples = min(samples_to_add, len(remaining_dpo_dataset))
    random_samples = remaining_dpo_dataset.shuffle(seed=42).select(range(num_random_samples))
    print(f"Number of random samples selected: {len(random_samples)}")

    # Combine the filtered dataset and the random samples
    combined_dpo_dataset = concatenate_datasets([gsm8k_train_dpo_dataset, random_samples])
else:
    combined_dpo_dataset = gsm8k_train_dpo_dataset

print(f"Total samples in the combined dataset: {len(combined_dpo_dataset)}")

# Display the combined dataset info
print("\nCombined Dataset Info:")
print(combined_dpo_dataset)

In [None]:
def process_dpo_dataset(dataset, tokenizer, max_prompt_length, total_generation_steps, batch_size):
    """Tokenizes a preference dataset into a list of `TrainingInput` batches.

    Args:
        dataset: Object supporting `len()` and `select(range)`; each selected
            example is indexed with "input", "chosen" and "rejected".
        tokenizer: Passed through to `_generate_ids_and_masks`.
        max_prompt_length: Length passed for prompts (left padded).
        total_generation_steps: Length passed for chosen/rejected responses
            (right padded).
        batch_size: Number of consecutive examples per batch; the last batch
            may be smaller.

    Returns:
        A list with one `TrainingInput` per batch. Ids are int64 and masks are
        float64. Each field is `np.array` over the per-example arrays, so the
        per-example array shape is kept behind the batch axis.
        NOTE(review): if the helper returns arrays shaped (1, length) for a
        single text, each field is (batch, 1, length) — confirm this is the
        layout the trainer expects.
    """
    field_names = (
        "prompt_ids",
        "prompt_mask",
        "chosen_ids",
        "chosen_mask",
        "rejected_ids",
        "rejected_mask",
    )

    def _encode(text, length, pad_left):
        # Tokenize a single text and cast to the dtypes used for training input.
        ids, mask = _generate_ids_and_masks(
            [text], tokenizer, length, left_pad=pad_left
        )
        return ids.astype(np.int64), mask.astype(np.float64)

    num_examples = len(dataset)
    batches = []
    for start in tqdm(range(0, num_examples, batch_size)):
        stop = min(start + batch_size, num_examples)
        columns = {name: [] for name in field_names}

        for row in dataset.select(range(start, stop)):
            # Prompts are left padded; responses are right padded.
            encoded = (
                _encode(row["input"], max_prompt_length, True)
                + _encode(row["chosen"], total_generation_steps, False)
                + _encode(row["rejected"], total_generation_steps, False)
            )
            for name, array in zip(field_names, encoded):
                columns[name].append(array)

        stacked = {name: np.array(columns[name]) for name in field_names}
        batches.append(TrainingInput(**stacked))

    return batches

# Tokenize the combined preference data into a list of TrainingInput batches;
# this list is what the training cell passes to the trainer.
processed_dpo_dataset = process_dpo_dataset(combined_dpo_dataset, gemma_tokenizer, MAX_PROMPT_LENGTH, TOTAL_GENERATION_STEPS, BATCH_SIZE)

In [None]:
class MySource:
  """Minimal random-access source that delegates to a wrapped sequence."""

  def __init__(self, data):
    # Underlying container; it must support indexing and len().
    self._data = data

  def __len__(self):
    """Number of items in the wrapped container."""
    return len(self._data)

  def __getitem__(self, idx):
    """Item of the wrapped container at `idx`."""
    return self._data[idx]
def _dummy_dataset(
    source: MySource,
    prompt_ids: np.ndarray,
    prompt_mask: np.ndarray,
    chosen_ids: np.ndarray,
    chosen_mask: np.ndarray,
    rejected_ids: np.ndarray,
    rejected_mask: np.ndarray,
):
  """Builds a grain dataset that maps source elements to `TrainingInput`.

  Every element produced by `source` is used as an index into each of the six
  containers, and the indexed values are packed into one `TrainingInput`.

  Args:
    source: Random-access source whose elements are valid indices.
    prompt_ids: Indexable container of prompt token ids.
    prompt_mask: Indexable container of prompt masks.
    chosen_ids: Indexable container of chosen-response token ids.
    chosen_mask: Indexable container of chosen-response masks.
    rejected_ids: Indexable container of rejected-response token ids.
    rejected_mask: Indexable container of rejected-response masks.

  Returns:
    The mapped `grain.MapDataset`.
  """

  def _build_input(index):
    return TrainingInput(
        prompt_ids=prompt_ids[index],
        prompt_mask=prompt_mask[index],
        chosen_ids=chosen_ids[index],
        chosen_mask=chosen_mask[index],
        rejected_ids=rejected_ids[index],
        rejected_mask=rejected_mask[index],
    )

  index_dataset = grain.MapDataset.source(source)
  return index_dataset.map(_build_input)
# Grain view over the processed batches: the source is the range of batch
# indices and each field is a list holding that field for every batch.
# NOTE(review): `train_ds_dpo` is not referenced by any later cell visible here;
# the training cell passes `processed_dpo_dataset` directly — confirm which one
# is intended.
train_ds_dpo = _dummy_dataset(
        range(len(processed_dpo_dataset)),
        [processed_dpo_dataset[x].prompt_ids for x in range(len(processed_dpo_dataset))],
        [processed_dpo_dataset[x].prompt_mask for x in range(len(processed_dpo_dataset))],
        [processed_dpo_dataset[x].chosen_ids for x in range(len(processed_dpo_dataset))],
        [processed_dpo_dataset[x].chosen_mask for x in range(len(processed_dpo_dataset))],
        [processed_dpo_dataset[x].rejected_ids for x in range(len(processed_dpo_dataset))],
        [processed_dpo_dataset[x].rejected_mask for x in range(len(processed_dpo_dataset))],
    )

# Define optimizer and DPO Trainer

In [None]:
# Optimizer, learning rate scheduler, gradient clipping
# Optimizer, learning rate scheduler, gradient clipping
# AdamW whose learning rate follows a warmup + cosine decay schedule: it starts
# at 0.0, peaks at LEARNING_RATE after WARMUP_STEPS, and ends at 0.0, with
# MAX_STEPS as the schedule length.
optimizer = optax.adamw(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        decay_steps=MAX_STEPS,
        end_value=0.0,
    ),
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)
# When a clipping threshold is configured, global-norm clipping is placed in
# the chain ahead of AdamW.
if MAX_GRAD_NORM is not None:
  optimizer = optax.chain(
      optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
      optimizer,
  )

In [None]:
# Configure DPO Training (using previously defined config variables)
dpo_config = DpoTrainingConfig(
    beta=BETA,
    eval_every_n_steps=EVAL_EVERY_N_STEPS,
    max_steps=MAX_STEPS,
)

In [None]:
# Display the resulting training config as the cell output.
dpo_config

In [None]:
# The LoRA-wrapped model is the trained policy; the base `gemma3` model is
# passed as the reference model.
dpo_trainer = DpoTrainer(
    model=lora_gemma,
    ref_model=gemma3,
    optimizer=optimizer,
    training_config=dpo_config,
)

# Train and evaluate LoRA model

In [None]:
# Report device memory once both models and the trainer have been created.
show_hbm_usage()

In [None]:
# Start training
print("Starting DPO training...")

dpo_trainer.train(train_ds=processed_dpo_dataset)
print("DPO training finished.")

In [None]:
# The sampler created earlier wraps the base `gemma3` model, so reusing it here
# would measure the reference model again. Build a sampler around the trained
# LoRA policy, with the same tokenizer and cache settings, and evaluate that.
lora_sampler = sampler_lib.Sampler(
    transformer=lora_gemma,
    tokenizer=gemma_tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

# Same dataset and sampling preset as the baseline run, so the numbers are
# directly comparable.
(corr, total, accuracy, partial_accuracy, format_accuracy), responses = evaluate(
    test_dataset,
    lora_sampler,
    **GENERATION_CONFIGS["standard"],
    make_lst=True
)
print(
    f"{corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,"
    f" {format_accuracy=}%"
)

In [None]:
# Colab only: release the runtime once the notebook has finished. The import
# lives here rather than in the import cell because `google.colab` is only
# importable inside Colab.
from google.colab import runtime
runtime.unassign()