<a href="https://colab.research.google.com/github/tsherida2012/wget-files/blob/main/pdo_for_creative.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install dotenv
%pip install kagglehub
%pip install safetensors
%pip install tensorflow
%pip install tensorflow_datasets
%pip install tensorboardX
%pip install transformers
%pip install grain
%pip install datasets
%pip install huggingface_hub
%pip install wandb
%pip uninstall flax -y
%pip install git+https://github.com/google/flax
%pip install git+https://github.com/google/tunix
%pip install git+https://github.com/google/qwix

Found existing installation: flax 0.12.1
Uninstalling flax-0.12.1:
  Successfully uninstalled flax-0.12.1
Collecting git+https://github.com/google/flax
  Cloning https://github.com/google/flax to /tmp/pip-req-build-olcovuf1
  Running command git clone --filter=blob:none --quiet https://github.com/google/flax /tmp/pip-req-build-olcovuf1
  Resolved https://github.com/google/flax to commit 697f4e5cda4b110decf862a6f8ab71a0345d0412
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: flax
  Building wheel for flax (pyproject.toml) ... [?25l[?25hdone
  Created wheel for flax: filename=flax-0.12.1-py3-none-any.whl size=488158 sha256=80233a9842d18a1490b6cc2d493d048f58fd3b8892dd0374bf58f31a7bfa757a
  Stored in directory: /tmp/pip-ephem-wheel-cache-i_hhrmhr/wheels/1c/1c/50/70e06d9ee1df89d65ab742227bddb43d8e4eea822db6757377
Successfully bu

In [None]:

import os
import sys

try:
  from google.colab import userdata
  USE_COLAB = True

  %pip uninstall -y wandb -y  # wandb is glitchy with tunix in colab

  os.environ["WANDB_API_KEY"] = userdata.get('WANDB_API_KEY')
  os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")
except:
  USE_COLAB = False

  from dotenv import load_dotenv
  load_dotenv()
  print("Using env vars to login")

  import nest_asyncio
  nest_asyncio.apply()
  print("nest_asyncio applied")

  import wandb
  if "WANDB_API_KEY" in os.environ and os.environ["WANDB_API_KEY"]:
    wandb.login(key=os.environ["WANDB_API_KEY"])
  else:
    print("WANDB_API_KEY not found. Skipping wandb login.")

# Check if HF_TOKEN is set before logging in
if "HF_TOKEN" in os.environ and os.environ["HF_TOKEN"]:
  hf_token = os.environ["HF_TOKEN"]
  !hf auth login --token "$hf_token"
else:
  print("HF_TOKEN not found. Skipping Hugging Face login.")

Found existing installation: wandb 0.23.0
Uninstalling wandb-0.23.0:
  Successfully uninstalled wandb-0.23.0
The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `hf`CLI if you want to set the git credential as well.
Token is valid (permission: fineGrained).
The token `hotblood` has been saved to /root/.cache/huggingface/stored_tokens
Your token has been saved to /root/.cache/huggingface/token
Login successful.
Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [None]:

# Imports

import os
import sys
import json
import shutil

from datasets import concatenate_datasets
from datasets import load_dataset
from flax import nnx
import grain
from huggingface_hub import snapshot_download
import jax
import jax.numpy as jnp
import numpy as np
import optax
from orbax import checkpoint as ocp
import qwix
import safetensors.numpy as safe_np
import tensorflow_datasets as tfds
from tqdm.auto import tqdm
from tunix.generate import sampler as sampler_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
from tunix.models.gemma3 import model as gemma3_model_lib
from tunix.models.gemma3 import params_safetensors as params_safetensors_lib
from tunix.sft import metrics_logger
from tunix.sft.dpo.dpo_trainer import DPOTrainer
from tunix.sft.dpo.dpo_trainer import DPOTrainingConfig
from tunix.sft.utils import show_hbm_usage



In [None]:
# Hyperparameters/Config
# Every tunable constant of the notebook lives in this cell; each name is
# assigned exactly once so the value a reader sees is the value that is used.

model_id = "google/gemma-3-1b-it"  # also supports "google/gemma-3-270m-it"
GEMMA_TOKENIZER_PATH = "gs://gemma-data/tokenizers/tokenizer_gemma3.model"

# ====== Data ======
TRAIN_DATA_DIR = "./data/train"
TEST_DATA_DIR = "./data/test"
TRAIN_FRACTION = 1.0

# ====== LoRA ======
RANK = 32
ALPHA = 16.0

# ====== Sharding ======
# Adjust mesh based on your TPU memory and model size.
NUM_TPUS = len(jax.devices())
if NUM_TPUS == 8:
  MESH_COUNTS = (1, 4)
elif NUM_TPUS == 1:
  MESH_COUNTS = (1, 1)
else:
  raise ValueError(f"Unsupported number of TPUs: {NUM_TPUS}")

# [axis sizes, axis names], unpacked positionally into jax.make_mesh.
MESH = [
    MESH_COUNTS,
    ("fsdp", "tp"),
]

# ====== Sequence lengths / sampling / DPO ======
MAX_PROMPT_LENGTH = 192
MAX_RESPONSE_LENGTH = 192
TEMPERATURE = 0.7
TOP_P = 1.0
TOP_K = 50
BETA = 0.1

# === AdamW, warmup, cosine scheduler ===
LEARNING_RATE = 3e-5
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1

# ====== Training length ======
BATCH_SIZE = 2
NUM_BATCHES = 512
NUM_TEST_BATCHES = 2
EVAL_EVERY_N_STEPS = 1024

NUM_EPOCHS = 2  # can potentially train for more epochs
MAX_STEPS = int(NUM_BATCHES * TRAIN_FRACTION * NUM_EPOCHS)

# == Cosine decay with warmup scheduler ==
# Linearly increase learning rate from 0. to LEARNING_RATE in the first 10%
# training steps, and then gradually decrease the learning rate to 0 using
# cosine scheduler. Truncated to int because the schedule takes a step count.
WARMUP_STEPS = int(0.1 * MAX_STEPS)

# == Grad clipping ==
# Grad clipping to prevent large gradients. Found this
# important to keep KL divergence in check.
MAX_GRAD_NORM = 0.1

# Checkpoint saving
INTERMEDIATE_CKPT_DIR = "/tmp/content/intermediate_ckpt/"
CKPT_DIR = "/tmp/content/ckpts/"
SAVE_INTERVAL_STEPS = 500
MAX_TO_KEEP = 4

# ====== Inference ======
GENERATION_CONFIGS = {
    # greedy search
    "greedy": {"temperature": None, "top_k": 1, "top_p": None},
    # some randomness
    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
    # liberal
    "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
}

In [None]:
# Fetch the Hugging Face snapshot for `model_id`, skipping PyTorch checkpoints.
ignore_patterns = [
    "*.pth",  # Ignore PyTorch .pth weight files
]
print(f"Downloading {model_id} from Hugging Face...")
local_model_path = snapshot_download(
    repo_id=model_id, ignore_patterns=ignore_patterns
)
print(f"Model successfully downloaded to: {local_model_path}")

# Collect end-of-sequence token ids from generation_config.json when present.
# EOS_TOKENS is always a list: later cells test membership and append to it,
# so a bare integer `eos_token_id` is wrapped rather than stored as-is.
EOS_TOKENS = []
generation_config_path = os.path.join(local_model_path, "generation_config.json")
if os.path.exists(generation_config_path):
  with open(generation_config_path, "r", encoding="utf-8") as f:
    generation_configs = json.load(f)
  eos_token_id = generation_configs.get("eos_token_id")
  if eos_token_id is None:
    EOS_TOKENS = []
  elif isinstance(eos_token_id, int):
    EOS_TOKENS = [eos_token_id]
  else:
    EOS_TOKENS = list(eos_token_id)
  print(f"Using EOS token IDs: {EOS_TOKENS}")


Downloading google/gemma-3-1b-it from Hugging Face...


Fetching 10 files:   0%|          | 0/10 [00:00<?, ?it/s]

README.md:   0%|          | 0.00/24.3k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/899 [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

.gitattributes:   0%|          | 0.00/1.68k [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.00G [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

Model successfully downloaded to: /root/.cache/huggingface/hub/models--google--gemma-3-1b-it/snapshots/dcc83ea841ab6100d6b47a070329e1ba4cf78752
Using EOS token IDs: [1, 106]


In [None]:
# Report accelerator memory before any model weights are loaded.
hbm_banner = "\n--- HBM Usage BEFORE Model Load ---"
print(hbm_banner)
show_hbm_usage()


--- HBM Usage BEFORE Model Load ---


In [None]:
# Directory of the downloaded snapshot that holds the safetensors weights.
MODEL_CP_PATH = local_model_path

# Pick the ModelConfig factory whose marker appears in `model_id`; the first
# match wins, and an id matching neither marker is rejected.
for id_marker, factory_name in (
    ("gemma-3-270m", "gemma3_270m"),
    ("gemma-3-1b", "gemma3_1b"),
):
  if id_marker in model_id:
    model_config = getattr(gemma3_model_lib.ModelConfig, factory_name)()
    break
else:
  raise ValueError(f"Unsupported model: {model_id}")

# MESH holds [axis sizes, axis names]; every axis uses the Auto axis type.
mesh_shape, mesh_axis_names = MESH
mesh = jax.make_mesh(
    mesh_shape,
    mesh_axis_names,
    axis_types=(jax.sharding.AxisType.Auto,) * len(mesh_shape),
)
with mesh:
  gemma3 = params_safetensors_lib.create_model_from_safe_tensors(
      MODEL_CP_PATH, model_config, mesh
  )
  nnx.display(gemma3)

In [None]:

# Tokenizer for the model; its EOS id is added to EOS_TOKENS if not yet listed.
gemma_tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=GEMMA_TOKENIZER_PATH)
tokenizer_eos_id = gemma_tokenizer.eos_id()
if tokenizer_eos_id not in EOS_TOKENS:
  EOS_TOKENS.append(tokenizer_eos_id)
  print(f"Using EOS token IDs: {EOS_TOKENS}")

# KV cache sized for prompt + response plus 256 extra positions.
kv_cache_config = sampler_lib.CacheConfig(
    cache_size=MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH + 256,
    num_layers=model_config.num_layers,
    num_kv_heads=model_config.num_kv_heads,
    head_dim=model_config.head_dim,
)

# Sampler that generates text from the `gemma3` model.
sampler = sampler_lib.Sampler(
    transformer=gemma3,
    tokenizer=gemma_tokenizer,
    cache_config=kv_cache_config,
)

In [None]:
# === Build Gemma3 base model + tokenizer ===
# NOTE(review): no later cell visible in this notebook reads `base_model` or
# `tokenizer`; this load appears to hold an additional copy of the weights in
# device memory. Confirm whether this cell is still needed.

print("Setting up device mesh...")
mesh = jax.make_mesh(
    *MESH,
    axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[0])
)
print("Mesh:", mesh)

# Pick the model config that matches `model_id`, the same way the other
# model-loading cells do, instead of assuming the 1B variant.
if "gemma-3-270m" in model_id:
  model_config = gemma3_model_lib.ModelConfig.gemma3_270m()
elif "gemma-3-1b" in model_id:
  model_config = gemma3_model_lib.ModelConfig.gemma3_1b()
else:
  raise ValueError(f"Unsupported model: {model_id}")

print(f"\nLoading {model_id} params from safetensors...")
safe_path = os.path.abspath(local_model_path)
print("Using safetensors path:", safe_path)

with mesh:
  # Unlike the other loads in this notebook, this one requests bfloat16.
  base_model: nnx.Module = params_safetensors_lib.create_model_from_safe_tensors(
      safe_path,
      model_config,
      mesh,
      dtype=jnp.bfloat16,
  )

print("Base Gemma3 model created.")
show_hbm_usage()

print("\nLoading tokenizer...")
tokenizer = tokenizer_lib.Tokenizer(
    tokenizer_type="sentencepiece",
    tokenizer_path=GEMMA_TOKENIZER_PATH,
)
print("Tokenizer ready.")


Setting up device mesh...
Mesh: Mesh('fsdp': 1, 'tp': 1, axis_types=(Auto, Auto))

Loading Gemma3-1B params from safetensors...
Using safetensors path: /root/.cache/huggingface/hub/models--google--gemma-3-1b-it/snapshots/dcc83ea841ab6100d6b47a070329e1ba4cf78752
Base Gemma3 model created.

Loading tokenizer...
Tokenizer ready.


In [None]:
# Load the model selected by `model_id`, reporting accelerator memory on
# either side of the load.
hbm_banner_before = "\n--- HBM Usage BEFORE Model Load ---"
print(hbm_banner_before)
show_hbm_usage()

MODEL_CP_PATH = local_model_path

# First marker found in `model_id` decides which ModelConfig factory is called.
for id_marker, factory_name in (
    ("gemma-3-270m", "gemma3_270m"),
    ("gemma-3-1b", "gemma3_1b"),
):
  if id_marker in model_id:
    model_config = getattr(gemma3_model_lib.ModelConfig, factory_name)()
    break
else:
  raise ValueError(f"Unsupported model: {model_id}")

# MESH holds [axis sizes, axis names]; every axis uses the Auto axis type.
mesh_shape, mesh_axis_names = MESH
mesh = jax.make_mesh(
    mesh_shape,
    mesh_axis_names,
    axis_types=(jax.sharding.AxisType.Auto,) * len(mesh_shape),
)

with mesh:
  gemma3 = params_safetensors_lib.create_model_from_safe_tensors(
      MODEL_CP_PATH, model_config, mesh
  )
  nnx.display(gemma3)

hbm_banner_after = "\n--- HBM Usage AFTER Model Load ---"
print(hbm_banner_after)
show_hbm_usage()



--- HBM Usage BEFORE Model Load ---



--- HBM Usage AFTER Model Load ---


In [None]:
# Recreate the tokenizer and make sure its EOS id is part of EOS_TOKENS.
gemma_tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=GEMMA_TOKENIZER_PATH)
tokenizer_eos_id = gemma_tokenizer.eos_id()
if tokenizer_eos_id not in EOS_TOKENS:
  EOS_TOKENS.append(tokenizer_eos_id)
  print(f"Using EOS token IDs: {EOS_TOKENS}")

# KV cache sized for prompt + response plus 256 extra positions.
kv_cache_config = sampler_lib.CacheConfig(
    cache_size=MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH + 256,
    num_layers=model_config.num_layers,
    num_kv_heads=model_config.num_kv_heads,
    head_dim=model_config.head_dim,
)

# Sampler bound to the `gemma3` model loaded in the previous cell.
sampler = sampler_lib.Sampler(
    transformer=gemma3,
    tokenizer=gemma_tokenizer,
    cache_config=kv_cache_config,
)


In [None]:
# Quick generation check with the "standard" sampling preset.
test_prompt = "user\nExplain DPO in one short sentence.\nmodel\n"

standard_sampling = GENERATION_CONFIGS["standard"]
out = sampler(
    input_strings=[test_prompt],
    max_generation_steps=64,
    **standard_sampling,
    eos_tokens=EOS_TOKENS,
)

first_completion = out.text[0]
print(first_completion)




A data protection officer (DPO) is a legal professional who advises and oversees a company's data protection compliance efforts.

Do you want me to explain another aspect of GDPR or something else?



In [None]:
def get_lora_model(base_model, mesh):
  """Return a LoRA-adapted version of `base_model`, re-sharded under `mesh`.

  Adapters are requested through a qwix LoraProvider for every module whose
  path matches the pattern below, using the notebook-level RANK and ALPHA.

  Args:
    base_model: model to adapt; must provide `get_model_input()`.
    mesh: device mesh used as the context for the sharding constraint.

  Returns:
    The model produced by `qwix.apply_lora_to_model`, with its state updated
    to the sharding-constrained version.
  """
  target_modules = (
      ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
      ".*attn_vec_einsum"
  )
  provider = qwix.LoraProvider(
      module_path=target_modules,
      rank=RANK,
      alpha=ALPHA,
  )

  # The model's own example inputs are forwarded to qwix as keyword arguments.
  example_inputs = base_model.get_model_input()
  adapted_model = qwix.apply_lora_to_model(
      base_model, provider, **example_inputs
  )

  # Constrain the full state to its partition specs, then write it back.
  with mesh:
    adapted_state = nnx.state(adapted_model)
    partition_specs = nnx.get_partition_spec(adapted_state)
    nnx.update(
        adapted_model,
        jax.lax.with_sharding_constraint(adapted_state, partition_specs),
    )

  return adapted_model

# Policy model: the LoRA-adapted model built from `gemma3` on the current
# `mesh`. This is the model passed to the DPO trainer as `model`.
lora_gemma = get_lora_model(gemma3, mesh=mesh)
# Render the adapted module tree for inspection.
nnx.display(lora_gemma)


In [None]:
# ============================================================
# CREATIVE WRITING DPO DATASET + TRAINER SETUP
# ============================================================

print("=== Loading Creative Writing DPO Dataset ===")

def get_creative_dataset() -> grain.MapDataset:
    """Build a creative-writing preference dataset from Anthropic/hh-rlhf.

    Each hh-rlhf record holds two full dialogues, `chosen` and `rejected`.
    Both are split at their final assistant turn: the text before it (taken
    from `chosen`) becomes the prompt and the text after it becomes the
    response. Records whose prompt mentions none of the creative keywords, or
    that lack an assistant turn in either dialogue, are dropped.

    Returns:
        A grain.MapDataset of dicts with the keys `prompts`,
        `chosen_responses` and `rejected_responses`. Prompts are wrapped in
        TEMPLATE, the same layout the inference prompts in this notebook use.
    """
    # 1. Load Anthropic HH-RLHF dataset (columns: chosen / rejected)
    dpo_dataset = load_dataset("Anthropic/hh-rlhf", split="train")
    print("Raw dataset size:", len(dpo_dataset))
    print("Dataset columns:", dpo_dataset.column_names)

    # Prompt layout applied to every training prompt.
    TEMPLATE = "user\n{question}\nmodel\n"
    HUMAN_MARKER = "Human:"
    ASSISTANT_MARKER = "Assistant:"

    # 2. Keywords that mark a prompt as creative-writing style
    creative_keywords = (
        "story", "creative", "narrative", "poem", "plot",
        "fantasy", "describe", "character", "fiction", "imagine",
    )

    def split_dialogue(dialogue):
        """Return (history, final_reply), or None without an assistant turn."""
        # Split at the LAST assistant marker so earlier turns stay in the
        # history instead of leaking into the response.
        history, marker, reply = dialogue.rpartition(ASSISTANT_MARKER)
        if not marker:
            return None
        history = history.strip()
        if history.startswith(HUMAN_MARKER):
            history = history[len(HUMAN_MARKER):].strip()
        return history, reply.strip()

    def parse_pair(x):
        """Return (question, chosen_reply, rejected_reply), or None."""
        chosen_parts = split_dialogue(x["chosen"])
        rejected_parts = split_dialogue(x["rejected"])
        if chosen_parts is None or rejected_parts is None:
            return None
        question, chosen_reply = chosen_parts
        return question, chosen_reply, rejected_parts[1]

    def is_creative(x):
        parsed = parse_pair(x)
        if parsed is None:
            return False
        text = parsed[0].lower()
        return any(kw in text for kw in creative_keywords)

    creative_dataset = dpo_dataset.filter(is_creative)
    print("Creative subset size:", len(creative_dataset))

    # 3. Map into DPOTrainer format
    def to_dpo_record(x):
        # Every record that reaches this point passed `is_creative`, so
        # parse_pair cannot return None here.
        question, chosen_reply, rejected_reply = parse_pair(x)
        return {
            "prompts": TEMPLATE.format(question=question),
            "chosen_responses": chosen_reply,
            "rejected_responses": rejected_reply,
        }

    dataset = grain.MapDataset.source(creative_dataset).map(to_dpo_record)

    return dataset


# ============================================================
# Build TRAIN + VAL datasets
# ============================================================

# Batch the creative subset and keep at most NUM_BATCHES batches.
dataset = get_creative_dataset().batch(BATCH_SIZE)[:NUM_BATCHES]

if TRAIN_FRACTION != 1.0:
    # Leading fraction trains, the remainder validates; both are repeated
    # once per epoch.
    split_at = int(len(dataset) * TRAIN_FRACTION)
    train_dataset = dataset[:split_at].repeat(NUM_EPOCHS)
    val_dataset = dataset[split_at:].repeat(NUM_EPOCHS)
else:
    # Everything is used for training, so there is no validation split.
    train_dataset = dataset.repeat(NUM_EPOCHS)
    val_dataset = None

print("Train dataset batches:", len(train_dataset))
val_summary = "None" if val_dataset is None else len(val_dataset)
print("Val dataset:", val_summary)


# ============================================================
# Optimizer (AdamW + warmup + cosine decay)
# ============================================================

# Learning rate ramps from 0 to LEARNING_RATE over WARMUP_STEPS, then follows
# a cosine decay back to 0 by MAX_STEPS.
lr_schedule = optax.schedules.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    decay_steps=MAX_STEPS,
    end_value=0.0,
)

optimizer = optax.adamw(
    learning_rate=lr_schedule,
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)

# Optionally clip the global gradient norm before the AdamW update.
if MAX_GRAD_NORM is not None:
    grad_clipper = optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM)
    optimizer = optax.chain(grad_clipper, optimizer)


# ============================================================
# DPO Training Config
# ============================================================

# Metrics are written under the TensorBoard log directory below.
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir="/tmp/content/tmp/tensorboard/dpo",
    flush_every_n_steps=20,
)

# Checkpoint cadence and retention, taken from the config cell.
checkpointing_options = ocp.CheckpointManagerOptions(
    save_interval_steps=SAVE_INTERVAL_STEPS,
    max_to_keep=MAX_TO_KEEP,
)

# Gather all DPO training settings into one mapping, then build the config.
dpo_config_kwargs = {
    "beta": BETA,
    "eval_every_n_steps": EVAL_EVERY_N_STEPS,
    "max_steps": MAX_STEPS,
    "max_prompt_length": MAX_PROMPT_LENGTH,
    "max_response_length": MAX_RESPONSE_LENGTH,
    "metrics_logging_options": metrics_logging_options,
    "checkpoint_root_directory": CKPT_DIR,
    "checkpointing_options": checkpointing_options,
}
dpo_config = DPOTrainingConfig(**dpo_config_kwargs)

print("DPO config ready.")


# ============================================================
# DPO Trainer
# ============================================================

# The LoRA model is the trainable policy; `gemma3` is passed as the reference.
trainer_kwargs = dict(
    model=lora_gemma,
    ref_model=gemma3,
    optimizer=optimizer,
    training_config=dpo_config,
    tokenizer=gemma_tokenizer,
)
dpo_trainer = DPOTrainer(**trainer_kwargs)

print("Creative Writing DPO trainer built successfully.")


=== Loading Creative Writing DPO Dataset ===
Raw dataset size: 160800
Dataset columns: ['chosen', 'rejected']


Filter:   0%|          | 0/160800 [00:00<?, ? examples/s]

Creative subset size: 3597
Train dataset batches: 1024
Val dataset: None
DPO config ready.
Creative Writing DPO trainer built successfully.


In [None]:
# Run DPO training inside the device-mesh context. `val_dataset` is None when
# TRAIN_FRACTION == 1.0, as set by the dataset-building cell.
with mesh:
    dpo_trainer.train(train_dataset, val_dataset)


Training:   0%|          | 0/1024 [00:00<?, ?step/s]

In [None]:
test_prompt = """user
Write a short fantasy story about a wandering mage who discovers a forbidden spell.
model
"""

# The existing `sampler` was built around `gemma3`, which the trainer uses as
# the reference model. To see what training changed, sample from the trained
# LoRA policy `lora_gemma` through a sampler of its own.
trained_sampler = sampler_lib.Sampler(
    transformer=lora_gemma,
    tokenizer=gemma_tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

# The "standard" preset holds the same temperature / top_k / top_p values that
# were previously written inline here.
out = trained_sampler(
    input_strings=[test_prompt],
    max_generation_steps=160,
    **GENERATION_CONFIGS["standard"],
    echo=False,
    eos_tokens=EOS_TOKENS,
)

print("\n=== Model Output ===\n")
print(out.text[0])



=== Model Output ===



The wind tasted of iron and regret as Lyra traced a finger along the crumbling stone of the Forgotten Archive. Dust motes, thick as tiny stars, danced in the single shaft of sunlight piercing the gloom. She was a wanderer, a mage of the Silverwood, her spells honed to a precise, almost painful efficiency. She’d spent decades traversing the land, seeking forgotten knowledge, but lately, a quiet, unsettling loneliness clung to her like the damp air.

She wasn’t seeking power, not really. Just… something to fill the echoing spaces within her. The Archive, a place rumored to hold echoes of lost magic, was her solace. Tonight, she was looking for a specific section – a fragment of a spell, a ‘Silencing Bloom,’ said to be


In [None]:
# Merge the trained LoRA adapters into the base checkpoint and export the
# result as a Hugging Face style directory.
output_dir = f"./{model_id}-lora"
if USE_COLAB:
    output_dir = f"/tmp/content/{model_id}-lora"

# Start from an empty output directory.
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
os.makedirs(output_dir)

print(f"Saving model to {output_dir}")
print("\nStep 1: Extracting LoRA weights from lora_model...")

def path_to_str(qwix_path):
    """Join the components of a qwix module path with dots."""
    return '.'.join(str(field) for field in qwix_path)

# MLP projections whose adapters are folded into the exported weights.
# gate_proj is included because it is one of the LoRA targets in
# get_lora_model.
# NOTE(review): the attention adapters (q_einsum, kv_einsum, attn_vec_einsum)
# are also LoRA targets but are NOT merged here; mapping them onto the
# checkpoint's attention weights needs a layout conversion that this cell does
# not implement.
MERGED_MLP_PROJECTIONS = ("down_proj", "up_proj", "gate_proj")

lora_layers = {}
for layer in lora_gemma.layers:
    for proj_name in MERGED_MLP_PROJECTIONS:
        proj = getattr(layer.mlp, proj_name)
        lora_layers[path_to_str(proj.qwix_path)] = (
            proj.kernel_lora_a,
            proj.kernel_lora_b,
        )

print(f"Found {len(lora_layers)} LoRA layers")
print(f"LoRA layer names: {list(lora_layers.keys())[:3]}...")
print("Note: attention LoRA adapters are not merged into the exported weights.")

base_state = safe_np.load_file(os.path.join(local_model_path, "model.safetensors"))
print(f"Starting with {len(base_state)} base model parameters")

# Step 2: Apply LoRA deltas
# The adapter contribution is scaled by alpha / rank, so the same factor is
# applied when folding it into the base weight.
lora_scale = ALPHA / RANK
for lora_name, (lora_a, lora_b) in lora_layers.items():
    state_key = f'model.{lora_name}.weight'
    assert state_key in base_state, f"LoRA layer {lora_name} not found in base model state dict"
    # `variable[...]` replaces the deprecated `.value` accessor.
    lora_a_val = np.asarray(lora_a[...], dtype=np.float32)
    lora_b_val = np.asarray(lora_b[...], dtype=np.float32)

    # The checkpoint stores the transpose of the kernel layout used here.
    delta = (lora_a_val @ lora_b_val).T * lora_scale
    base_weight = base_state[state_key]
    # Accumulate in float32, then cast back so the exported tensor keeps the
    # checkpoint's original dtype.
    merged = base_weight.astype(np.float32) + delta
    base_state[state_key] = merged.astype(base_weight.dtype)

print(f"Merged {len(lora_layers)} LoRA layers")

print("\nStep 3: Saving as safetensors...")
safetensors_path = os.path.join(output_dir, "model.safetensors")
safe_np.save_file(base_state, safetensors_path)
print("Model weights saved")

# Copy everything from the base model repo except original safetensors
for filename in os.listdir(local_model_path):
    if not filename.endswith(".safetensors"):
        src = os.path.join(local_model_path, filename)
        dst = os.path.join(output_dir, filename)
        if os.path.isfile(src):
            shutil.copy(src, dst)
            print(f"Copied {filename}")

print("\n=== Finished Saving Model ===")
print(f"Output directory: {output_dir}")

# Optional Colab download
if USE_COLAB:
    from google.colab import files
    shutil.make_archive(output_dir, 'zip', output_dir)
    files.download(f"{output_dir}.zip")


Saving model to /tmp/content/google/gemma-3-1b-it-lora

Step 1: Extracting LoRA weights from lora_model...
Found 52 LoRA layers
LoRA layer names: ['layers.0.mlp.down_proj', 'layers.0.mlp.up_proj', 'layers.1.mlp.down_proj']...
Starting with 340 base model parameters



  variable[...]

For other Variable types use:

  variable.get_value()

  lora_a_val = jnp.asarray(lora_a.value).astype(np.float32)

  variable[...]

For other Variable types use:

  variable.get_value()

  lora_b_val = jnp.asarray(lora_b.value).astype(np.float32)


Merged 52 LoRA layers

Step 3: Saving as safetensors...
Model weights saved
Copied special_tokens_map.json
Copied config.json
Copied added_tokens.json
Copied tokenizer.json
Copied .gitattributes
Copied README.md
Copied tokenizer.model
Copied generation_config.json
Copied tokenizer_config.json

=== Finished Saving Model ===
Output directory: /tmp/content/google/gemma-3-1b-it-lora


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>