In [1]:
# Minimal Paligemma-style GRPO demo (SigLIP + Gemma with LoRA).
# This uses a dummy image stream; replace with your VQA dataset.

# %% [markdown]
# ## Install

# %%
!pip install -q tensorboardX grain datasets pillow
!pip install -q jaxtyping sentencepiece datasets grain tensorboardX tensorflow_datasets

In [2]:
!pip install -q git+https://github.com/google/qwix

In [3]:
import sys
sys.path.insert(0, "/home/grads/tianjiao/tunix")

In [None]:
# ## Imports & paths

# %%
import os, numpy as np, jax, jax.numpy as jnp
from flax import nnx
import optax

from tunix.models.gemma3 import params as gemma3_params
from tunix.models.gemma3 import model as gemma3_model
from tunix.models.gemma import data as gemma_tokenizer_lib
from tunix.models.siglip import params as siglip_params
from tunix.models.siglip import model as siglip_model

from tunix.generate import vlm_sampler as vlm_sampler_lib
from tunix.rl.rollout import vlm_rollout as vlm_rollout_lib
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner
from tunix.rl.rewards import vqa_rewards as R
from tunix.sft import metrics_logger
from tunix.rl.rollout import base_rollout
import qwix

SIGLIP_DIR = "/home/grads/tianjiao/checkpoints/siglip-so400m-patch14-384"    # @title <-- put your safetensors dir
GEMMA3_CKPT = "gs://gemma-data/checkpoints/gemma3-4b-it"  # @title

WORKDIR = "/home/grads/tianjiao/vlm_grpo"
os.makedirs(WORKDIR, exist_ok=True)


In [5]:
import os, jax
print("backend:", jax.default_backend())
print("devices:", jax.devices())
print("local_devices:", jax.local_devices())
print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("JAX_PLATFORM_NAME:", os.environ.get("JAX_PLATFORM_NAME"))

backend: cpu
devices: [CpuDevice(id=0)]
local_devices: [CpuDevice(id=0)]
CUDA_VISIBLE_DEVICES: None
JAX_PLATFORM_NAME: None


In [6]:
# ## Tokenizer, mesh, configs

# %%
tokenizer = gemma_tokenizer_lib.GemmaTokenizer()
mesh = jax.make_mesh((1,1), ("fsdp","tp"))

MAX_PROMPT_LEN = 256
MAX_GEN_STEPS = 128


In [7]:
# ## Load models (text + vision) and apply LoRA

# %%
gcfg = gemma3_model.Gemma3Config.gemma3_4b()
gemma_text = gemma3_params.create_model_from_checkpoint(GEMMA3_CKPT, gcfg, mesh)
ref_text   = gemma3_params.create_model_from_checkpoint(GEMMA3_CKPT, gcfg, mesh)

lora_provider = qwix.LoraProvider(
    module_path=".*q_einsum|.*kv_einsum|.*attn_vec_einsum|.*gate_proj|.*up_proj|.*down_proj",
    rank=32, alpha=32.0,
)
lora_text = qwix.apply_lora_to_model(gemma_text, lora_provider, **gemma_text.get_model_input())

with mesh:
  st = nnx.state(lora_text)
  ps = nnx.get_partition_spec(st)
  nnx.update(lora_text, jax.lax.with_sharding_constraint(st, ps))

E0907 01:26:48.882745  441336 google_auth_provider.cc:188] Could not find the credentials file in the standard gcloud location [/home/grads/tianjiao/.config/gcloud/application_default_credentials.json]. You may specify a credentials file using $GOOGLE_APPLICATION_CREDENTIALS, or to use Google application default credentials, run: gcloud auth application-default login


In [12]:

scfg = siglip_model.SigLIPConfig.so400m_patch14_384()
siglip = siglip_params.create_model_from_safe_tensors(SIGLIP_DIR, scfg, mesh)

ValueError: No safetensors in /home/grads/tianjiao/siglip-so400m-patch14-384

In [None]:

# %%
cache_cfg = vlm_sampler_lib.CacheConfig(
    cache_size=MAX_PROMPT_LEN + MAX_GEN_STEPS + 256,
    num_layers=gcfg.num_layers,
    num_kv_heads=gcfg.num_kv_heads,
    head_dim=gcfg.head_dim,
)
sampler = vlm_sampler_lib.VLMSampler(
    transformer=lora_text,
    vision_encoder=siglip,
    tokenizer=tokenizer,
    cache_config=cache_cfg,
    image_size=384,
)
rollout = vlm_rollout_lib.VLMRollout(
    sampler=sampler,
    pad_id=tokenizer.pad_id(),
    eos_id=tokenizer.eos_id(),
)


In [None]:
# ## Rewards

# %%
reward_fns = [
  R.match_format_exact,
  R.match_format_soft,
  (lambda **kw: R.exact_match(**kw, answer=kw["answer"])),
  (lambda **kw: R.fuzzy_contains(**kw, answer=kw["answer"])),
  (lambda **kw: R.numeric_tolerance(**kw, answer=kw["answer"], rtol=0.02)),
  R.brevity_penalty,
]


In [None]:
# ## RL cluster & GRPO

# %%
ckpt_opts = rl_cluster_lib.ocp.CheckpointManagerOptions(save_interval_steps=200, max_to_keep=4)
mlog_opts = metrics_logger.MetricsLoggerOptions(log_dir=os.path.join(WORKDIR,"tb"), flush_every_n_steps=20)

cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine="vlm",
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optax.adamw(3e-6, b1=0.9, b2=0.99, weight_decay=0.1),
        eval_every_n_steps=50,
        max_steps=300,
        gradient_accumulation_steps=1,
        metrics_logging_options=mlog_opts,
        checkpoint_root_directory=os.path.join(WORKDIR, "ckpts"),
        checkpointing_options=ckpt_opts,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=MAX_GEN_STEPS,
        max_prompt_length=MAX_PROMPT_LEN,
        kv_cache_size=cache_cfg.cache_size,
        temperature=0.9, top_p=1.0, top_k=50,
    ),
)

rl_cluster = rl_cluster_lib.RLCluster(
    actor=lora_text,
    reference=ref_text,
    tokenizer=tokenizer,
    cluster_config=cluster_config,
    rollout=rollout,
)

grpo = GrpoLearner(
    rl_cluster=rl_cluster,
    reward_fns=reward_fns,
    grpo_config=GrpoConfig(num_generations=2, num_iterations=1, beta=0.08, epsilon=0.2),
)

In [None]:
# ## Dummy VQA iterator (replace with your dataset)

# %%
import numpy as np
def dummy_iter(bs):
  H = W = 384
  img = np.ones((H,W,3), dtype=np.uint8) * 255
  while True:
    yield {
      "prompts": ["What is in the image? Answer in <answer>...</answer> format."] * bs,
      "image": np.stack([img]*bs, axis=0),
      "answer": ["a blank white image"] * bs,
      "question": ["What is in the image?"] * bs,
    }

train_ds = dummy_iter(1)


In [None]:
# %%
with mesh:
  grpo.train(train_ds=train_ds, eval_ds=None, skip_jit=False)