## Prerequisites: set up the CPU node pool and get Kubernetes credentials
1. Create a CPU node pool in GKE (update the environment variables below to match your setup)

```
export PROJECT_ID=cloud-tpu-multipod-dev
export CLUSTER_NAME=mlperf-v5p
export ZONE=europe-west4
export CPU_POOL_NAME="tsbao-cpu-pool"
export MACHINE_TYPE="n2-standard-8"
export NUM_NODES=1

gcloud container node-pools create ${CPU_POOL_NAME}   --cluster=${CLUSTER_NAME}   --zone=${ZONE}   --project=${PROJECT_ID}    --machine-type=${MACHINE_TYPE}   --num-nodes=${NUM_NODES}   --enable-autoscaling --min-nodes=1 --max-nodes=5  --node-labels="cloud.google.com/gke-nodepool=${CPU_POOL_NAME}"
```

2. Fetch the Kubernetes credentials (this adds them to your local ~/.kube/config)

```
 gcloud container clusters get-credentials ${CLUSTER_NAME} --zone ${ZONE} --project ${PROJECT_ID}
```

In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell execution so edits to local source files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Make the local checkouts importable ahead of any installed copies.
# Each path is pushed to the front of sys.path, so pathways-utils ends up
# ahead of rllm. The membership check keeps the cell idempotent: re-running it
# does not keep stacking duplicate entries onto sys.path.
for _checkout in ('/scratch/git/rllm', '/scratch/git/pathways-utils'):
    if _checkout not in sys.path:
        sys.path.insert(0, _checkout)

In [3]:
import os
from datasets import load_dataset
# Upper bound on how many leading dataset rows are scanned for tasks.
TASKS_TO_PROCESS = 100
# Cache directory handed to `load_dataset`; override via the DATASET_CACHE
# environment variable.
DATASET_CACHE = os.environ.get('DATASET_CACHE', '/tmp/dataset_cache')

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Load the R2E-Gym training split, cached under DATASET_CACHE.
dataset = load_dataset("R2E-Gym/R2E-Gym-V1", split="train", cache_dir=DATASET_CACHE, num_proc=32)

# Scan at most the first TASKS_TO_PROCESS rows and keep those that carry a
# "docker_image" key. The limit is checked *before* a row is processed, so a
# limit of 0 yields no entries at all.
entries = []
unique_images = set()
for i, entry in enumerate(dataset):
  if i >= TASKS_TO_PROCESS:
    break
  if "docker_image" in entry:
    unique_images.add(entry["docker_image"])
    entries.append(entry)
unique_images = list(unique_images)
print(f"Found {len(unique_images)} unique Docker images to download")
# One synthetic id per kept entry: "task-0", "task-1", ...
IDS = [f"task-{i}" for i in range(len(entries))]

Found 100 unique Docker images to download


In [5]:
import os

# Absolute path of the kubeconfig file. "~" is expanded here because a literal
# tilde stored in an environment variable is not expanded by consumers that
# read KUBECONFIG directly (e.g. kubectl launched as a subprocess).
KUBECONFIG_PATH = os.path.expanduser("~/.kube/config")
os.environ["KUBECONFIG"] = KUBECONFIG_PATH

from kubernetes import client, config
# Pass the path explicitly rather than relying on the KUBECONFIG variable
# alone: the client library captures its default location at import time, so
# setting the variable has no effect if `kubernetes` was already imported.
config.load_kube_config(config_file=KUBECONFIG_PATH)
k8s_client = client.CoreV1Api()
# k8s_client.list_namespace(timeout_seconds=5)

In [6]:
# import r2egym

# print(r2egym.__file__)
# from r2egym.agenthub.runtime.docker import DockerRuntime
# from r2egym.agenthub.utils.log import get_logger
# from r2egym.agenthub.environment.env import EnvArgs, RepoEnv

# env_args = EnvArgs(ds=entries[0])
# env = RepoEnv(env_args, backend="kubernetes")


In [7]:
# from rllm.environments.swe.swe import R2EGYM_COMMAND_FILES

# env.add_commands(cmd_files=R2EGYM_COMMAND_FILES)

In [8]:
from transformers import AutoTokenizer
from tunix.rl.agentic.parser.chat_template_parser import parser

# Filesystem location of the model checkpoint.
MODEL_PATH = "/tmp/models/DeepSeek-R1-Distill-Qwen-1.5B/DeepSeek-R1-Distill-Qwen-1.5B/"
# Model identifier used to fetch the tokenizer.
MODEL_VERSION = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

# Tokenizer is resolved from MODEL_VERSION (not from MODEL_PATH).
tokenizer = AutoTokenizer.from_pretrained(MODEL_VERSION)

# Chat-template parser wrapping the tokenizer.
chat_parser = parser.QwenChatTemplateParser(tokenizer)

In [9]:
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from tunix.models.qwen2 import params as params_lib
from tunix.models.qwen2 import model as model_lib
from tunix.sft import utils as sft_utils

# Split the visible accelerators between a rollout mesh and a training mesh.
devices = jax.devices()
split = len(devices) // 2
# NOTE(review): rollout takes only the first `split - 2` devices, leaving two
# devices of the lower half unused; with fewer than 6 devices this slice is
# either empty or wraps around through negative indexing -- confirm intended.
rollout_devices = np.array(devices[:split-2]).reshape(-1, 1)
# Training takes everything from `split` onwards. Reshaping with -1 (rather
# than a fixed `split` rows) keeps this valid when the device count is odd
# and the upper half therefore holds `split + 1` devices.
train_devices = np.array(devices[split:]).reshape(-1, 1)
rollout_mesh = Mesh(rollout_devices, axis_names=('fsdp', 'tp'))
train_mesh = Mesh(train_devices, axis_names=('fsdp', 'tp'))

# NOTE: this rebinds `config`, which an earlier cell imported from
# `kubernetes`; that cell re-imports the module each time it runs.
config = model_lib.ModelConfig.deepseek_r1_distill_qwen_1p5b()
# Actor and reference are two independent loads of the same checkpoint, both
# placed on the training mesh in float32.
qwen2_actor = params_lib.create_model_from_safe_tensors(MODEL_PATH, config, train_mesh, dtype=jnp.float32)
qwen2_ref = params_lib.create_model_from_safe_tensors(MODEL_PATH, config, train_mesh, dtype=jnp.float32)
sft_utils.show_hbm_usage()

In [10]:
from tunix.generate import sampler

# KV-cache settings for the sampler.
# NOTE(review): layer/head values are hard-coded; presumably they mirror the
# loaded model's configuration -- verify if the model changes.
cache_config = sampler.CacheConfig(
    cache_size=16384,
    num_layers=28,
    num_kv_heads=2,
    head_dim=128,
)
# Rebinds `sampler` from the imported module to the Sampler instance.
sampler = sampler.Sampler(qwen2_actor, tokenizer, cache_config)

In [11]:
# # ====== Data ======
# TRAIN_FRACTION = 1.0

# # ====== Reproducibility ======
# SEED = 42

# # ====== LoRA ======
# RANK = 64
# ALPHA = 64.0
# TRAIN_WITH_LORA = False

# # ====== Sharding ======
# MESH = [(2, 4), ("fsdp", "tp")]

# # ====== GRPO ======
# # === Generation during GRPO training ===
# MAX_PROMPT_LENGTH = 2048
# TOTAL_GENERATION_STEPS = 512
# # Important to keep a high-ish temperature for varied, diverse responses during
# # training.
# TEMPERATURE = 0.6
# TOP_P = 0.95
# TOP_K = 50
# # The number of times the policy generates multiple responses for a given prompt
# # within a single training step. This corresponds to `G` in Algorithm 1 in the
# # paper. The "group" in GRPO comes from here.
# NUM_GENERATIONS = 2

# # === other GRPO configs ===
# # The number of iterations per batch (μ in GRPO algo 1).
# NUM_ITERATIONS = 1
# # The coefficient for the KL divergence penalty (β) in the GRPO loss function.
# # Important to keep a high enough value for this, otherwise, the KL divergence
# # can increase unchecked.
# BETA = 0.001
# # Epsilon value for clipping (ε in GRPO loss in paper). Similar to PPO, for
# # stable updates.
# EPSILON = 0.2

# # ====== Training ======
# BATCH_SIZE = 16
# MINI_BATCH_SIZE = 16
# # ROLLOUT_MICRO_BATCH_SIZE = 8
# # LOGPS_MICRO_BATCH_SIZE = 8
# NUM_BATCHES = 100
# # Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be
# # increased to a max. of 330 (if batch size is 4).
# NUM_TEST_BATCHES = 50

# EVAL_EVERY_N_STEPS = 1000  # this doesn't matter if `TRAIN_FRACTION = 1.0`.
# NUM_EPOCHS = 100 # can potentially train for more epochs

# # Number of training steps.
# MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)

# # === AdamW, warmup, cosine scheduler ===
# LEARNING_RATE = 1e-6
# B1 = 0.9  # Adam beta1
# B2 = 0.99  # Adam beta2
# WEIGHT_DECAY = 0.1
# # == Cosine decay with warmup scheduler ==
# # Linearly increase learning rate from 0. to LEARNING_RATE in the first 10% training
# # steps, and then gradually decrease the learning rate to 0 using cosine
# # scheduler.
# WARMUP_STEPS = int(0.1 * MAX_STEPS)
# # == Grad clipping ==
# # Grad clipping to prevent large gradients. Found this
# # important to keep KL divergence in check.
# MAX_GRAD_NORM = 0.1

# # ====== Checkpoint saving ======
# SAVE_INTERVAL_STEPS = 500
# MAX_TO_KEEP = 4
# DO_MEM_PROFILING = False

# # ====== Inference ======
# GENERATION_CONFIGS = {
#     # greedy search
#     "greedy": {"temperature": 1e-4, "top_k": 1, "top_p": 1.0},
#     # some randomness
#     "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
#     # liberal
#     "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
# }
# # ====== Rollout ======
# ROLLOUT_ENGINE = "sglang_jax" # one of "vanilla", "vllm" or "sglang_jax"

# CKPT_DIR = os.path.join("/tmp/cp", "deepscaler_ckpt/01")

In [12]:
# from tunix.rl import rl_cluster as rl_cluster_lib
# import optax
# from tunix.sft import metrics_logger
# from orbax import checkpoint as ocp
# from tunix.rl.rollout import base_rollout

# checkpointing_options = ocp.CheckpointManagerOptions(
#     save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP
# )
# metrics_logging_options = metrics_logger.MetricsLoggerOptions(
#     log_dir="/tmp/tensorboard/grpo", flush_every_n_steps=20
# )

# optimizer = optax.adamw(
#     learning_rate=optax.schedules.warmup_cosine_decay_schedule(
#         init_value=0.0,
#         peak_value=LEARNING_RATE,
#         warmup_steps=WARMUP_STEPS,
#         decay_steps=MAX_STEPS,
#         end_value=0.0,
#     ),
#     b1=B1,
#     b2=B2,
#     weight_decay=WEIGHT_DECAY,
# )

# cluster_config = rl_cluster_lib.ClusterConfig(
#     role_to_mesh={
#         rl_cluster_lib.Role.ACTOR: train_mesh,
#         rl_cluster_lib.Role.REFERENCE: train_mesh,
#         rl_cluster_lib.Role.ROLLOUT: rollout_mesh,
#     },
#     rollout_engine=ROLLOUT_ENGINE,
#     offload_to_cpu=False,
#     training_config=rl_cluster_lib.RLTrainingConfig(
#         actor_optimizer=optimizer,
#         eval_every_n_steps=EVAL_EVERY_N_STEPS,
#         max_steps=20,
#         mini_batch_size=MINI_BATCH_SIZE,
#         train_micro_batch_size = 1,  # larger than 1 will cause OOM on HBM
#         # metrics logging
#         metrics_logging_options=metrics_logging_options,
#         # checkpoint saving
#         checkpoint_root_directory=CKPT_DIR,
#         checkpointing_options=checkpointing_options,
#     ),
#     rollout_config=base_rollout.RolloutConfig(
#         max_tokens_to_generate=TOTAL_GENERATION_STEPS,
#         max_prompt_length=MAX_PROMPT_LENGTH,
#         kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
#         temperature=TEMPERATURE,
#         top_p=TOP_P,
#         top_k=TOP_K,
#         eos_tokens=[tokenizer.encode("<|im_end|>")[0]],
#         # sglang-jax specific configs
#         rollout_sglang_jax_model_version="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
#         rollout_sglang_jax_mem_fraction_static=0.2,
#         rollout_sglang_jax_init_with_random_weights=True,
#         rollout_sglang_jax_disable_radix_cache=True,
#         rollout_sglang_jax_enable_deterministic_sampling=False,
#         rollout_sglang_jax_precompile_bs_paddings=[1, 2],
#         rollout_sglang_jax_precompile_token_paddings=[2048, 4096, 8192],
#         rollout_sglang_jax_chunked_prefill_size=2048,
#         rollout_sglang_jax_page_size=64,
#     ),
# )

# rl_cluster = rl_cluster_lib.RLCluster(
#     actor=qwen2_actor,
#     reference=qwen2_ref,
#     tokenizer=tokenizer,
#     cluster_config=cluster_config,
# )

In [13]:
from swe_agent import SWEAgent
from swe_env import SWEEnv
from tunix.rl.agentic.trajectory import trajectory_collect_engine
from tunix.rl.agentic.parser.chat_template_parser.parser import QwenChatTemplateParser
from tunix.rl.agentic.rewards.reward_types import RewardOutput

# Build the chat-template parser from the tokenizer. This rebinds the
# `chat_parser` name that an earlier cell already assigned.
chat_parser = QwenChatTemplateParser(tokenizer)

# def model_call(chat_lists, rl_cluster):
#     result = rl_cluster.generate(
#         prompts=chat_lists,
#         apply_chat_template=True,
#         mode=rl_cluster_lib.Mode.TRAIN,
#     )
#     return result.text[0]

def model_call(chat_completions, _):
    """Generate a single model reply for a chat history.

    Args:
        chat_completions: Chat messages; rendered into a prompt by the
            module-level `chat_parser`.
        _: Second positional argument; ignored.

    Returns:
        The first element of the sampler output's `text`.
    """
    prompt = chat_parser.parse(chat_completions)
    generation = sampler(prompt, max_generation_steps=128, echo=False)
    return generation.text[0]

def _zero_reward(x, y):
    """Placeholder final reward: always 0 with empty metadata."""
    return RewardOutput(reward=0, metadata={})


agent = SWEAgent()
env = SWEEnv(entry=entries[0])

# Show the agent's current chat history rendered through the chat template.
rendered_prompt = chat_parser.parse(agent.chat_completions)
print(rendered_prompt)

# Wire agent, environment and model together for trajectory collection.
engine = trajectory_collect_engine.TrajectoryCollectEngine(
    agent=agent,
    env=env,
    model_call=model_call,
    final_reward_fn=_zero_reward,
    max_steps=5,
    gamma=0.9,
)


# res = await engine.collect(mode="Trajectory")

Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.


<|im_start|>system
You are a programming agent who is provided a github issue and repository bash environment and is tasked to solve certain tasks (e.g., file localization, testcase generation, code repair and editing etc) to resolve the issue.

We have access to the following functions:

–– BEGIN FUNCTION #1: file_editor ––
Description:
Custom editing tool for viewing, creating and editing files
  •	State is persistent across command calls and discussions with the user
  •	If path is a file, view displays the result of applying cat -n. If path is a directory, view lists non-hidden files and directories up to 2 levels deep
  •	The create command cannot be used if the specified path already exists as a file
  •	If a command generates a long output, it will be truncated and marked with <response clipped>
  •	The undo_edit command will revert the last edit made to the file at path

Notes for using the str_replace command:
  •	The old_str parameter should match EXACTLY 

In [14]:
import logging
import sys

# Route log records at INFO and above to the notebook cell output.
# `force=True` removes *and closes* any handlers already attached to the root
# logger before installing the new one, so re-running this cell neither
# duplicates log lines nor leaves the old handlers open.
logging.basicConfig(
    stream=sys.stdout,  # Direct logs to standard output (notebook cell)
    level=logging.INFO,  # Minimum level that is emitted
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    force=True,
)

# Run the collection coroutine to completion; top-level `await` is supported
# in notebook cells.
res = await engine.collect(mode='Trajectory')

created env
inital obs: 
**Title:** Context migration fails to remove incompatible contexts, causing initialization errors

**Description:**
When initializing the `ContextHandler` with a mix of compatible and incompatible contexts, the migration process does not remove the incompatible contexts as expected. Instead, it raises an `IncompatibleContext` error, preventing successful initialization.

**Example Code:**
```python
handler = ContextHandler()
handler.bind(SimpleWidget)

widget = SimpleWidget()
contexts = [Context(foo=i) for i in (13, 13, 0, 1, 13, 2, 13)]

def migrate_context(context, _):
    if context.foo == 13:
        raise IncompatibleContext()

handler.initialize(widget, dict(context_settings=contexts))
# Expected: Incompatible contexts with foo=13 should be removed
# Actual: IncompatibleContext error is raised, and contexts are not removed
```

**Expected Behavior:**
During initialization, contexts that are incompatible (e.g., those that cause `IncompatibleContext` to be 

In [None]:
# Step counter reported by the environment after collection.
print(env.total_steps)
# Raw model output recorded for the first collected step.
first_step = res.steps[0]
first_step.model_response


0


"</im_end>\n\n</think>\n\nTo fix the issue where incompatible contexts are not being removed during context migration, we need to modify the `ContextHandler` class to check for incompatible contexts before initializing. Here's the step-by-step solution:\n\n1. **Modify the `initialize` method**:\n   - Add a check to see if any context in the provided contexts has a `foo` value that matches any context in the handler.\n   - If a match is found, remove that context from the handler.\n\nHere's the code change:\n\n```python\nclass ContextHandler:\n    def __init__(self, context_settings=None):\n        self.context"

# Random stuff for debugging

In [None]:
# runtime = DockerRuntime(ds=entries[0], command=["/bin/bash", "-l"], logger=get_logger(), backend="kubernetes", id=IDS[0])
# runtime.get_task_instruction()

In [None]:
# runtime.run(code="ls -l")
# runtime.stop_container()

In [None]:
DOCKER_PATH = "/root/.venv/bin:/root/.local/bin:/root/.cargo/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
# Creating a pod has cluster-side effects, so it is opt-in: set to True and
# re-run this cell to actually create the debug pod.
CREATE_DEBUG_POD = False


def build_debug_pod_body(pod_name, docker_image, command="/bin/bash", node_pool="tsbao-cpu-pool"):
    """Build the manifest (a plain dict) for a single-container debug pod.

    Args:
        pod_name: Used both as the pod name and as the container name.
        docker_image: Image the container runs.
        command: Passed as the container `args` after `/bin/sh -c`; a string
            is wrapped into a one-element list, a list is used as-is.
        node_pool: Value for the `cloud.google.com/gke-nodepool` node selector.

    Returns:
        A dict suitable as the `body` of `create_namespaced_pod`.
    """
    env_vars = {"PATH": DOCKER_PATH}
    env_spec = [{"name": k, "value": str(v)} for k, v in env_vars.items()]
    return {
        "apiVersion": "v1",
        "kind": "Pod",
        "metadata": {"name": pod_name},
        "spec": {
            "restartPolicy": "Never",
            "containers": [
                {
                    "name": pod_name,
                    "image": docker_image,
                    "command": ["/bin/sh", "-c"],
                    "args": [command] if isinstance(command, str) else command,
                    "stdin": True,
                    "tty": True,
                    "env": env_spec,
                    "resources": {
                        "requests": {"cpu": "1", "memory": "1Gi"},
                    },
                }
            ],
            "imagePullSecrets": [{"name": "dockerhub-pro"}],
            "nodeSelector": {"cloud.google.com/gke-nodepool": node_pool},
            "tolerations": [
                {
                    "key": "node.kubernetes.io/disk-pressure",
                    "operator": "Exists",
                    "effect": "NoExecute",
                    "tolerationSeconds": 10800
                }
            ],
        },
    }


# The manifest is only built and submitted when explicitly enabled, so the
# cell never references an undefined `pod_body` on a fresh Run All.
if CREATE_DEBUG_POD:
    pod_name = "tsbao-test-cpu-pod"
    pod_body = build_debug_pod_body(pod_name, entries[0]["docker_image"])
    pod = k8s_client.create_namespaced_pod(
        namespace="default", body=pod_body, _request_timeout=60,
    )

In [None]:
# k8s_client.list_namespaced_pod(namespace="default")
# Look up the pod by name in the "default" namespace and display its phase.
pod_name = "tsbao-test-pod"
pod = k8s_client.read_namespaced_pod(namespace="default", name=pod_name)
pod.status.phase



'Running'

In [None]:
# from kubernetes.stream import stream

# full_command = ["/bin/sh", "-c", "ls -l"]
# resp = stream(
#     k8s_client.connect_get_namespaced_pod_exec,
#     name=pod_name,
#     namespace="default",
#     command=full_command,
#     stderr=True,
#     stdin=False,
#     stdout=True,
#     tty=False,  # Match docker exec_run settings
#     _preload_content=False,  # Important for streaming
# )
# resp

<kubernetes.stream.ws_client.WSClient at 0x78b19a7390f0>

In [None]:
# combined_chunks = []
# stdout_chunks = []
# stderr_chunks = []
# while resp.is_open():
#     resp.update(timeout=1)  # wait for data
#     if resp.peek_stdout():
#         chunk = resp.read_stdout()
#         stdout_chunks.append(chunk)
#         combined_chunks.append(chunk)
#     if resp.peek_stderr():
#         chunk = resp.read_stderr()
#         stderr_chunks.append(chunk)
#         combined_chunks.append(chunk)
# resp.close()
# exit_code = resp.returncode
# combined_output = "".join(combined_chunks)

In [None]:
# from r2egym.agenthub.agent.commands import ParseCommandBash

# cmd_parser = ParseCommandBash()
# cmds = cmd_parser.parse_command_file("/scratch/git/R2E-Gym/src/r2egym/agenthub/tools/r2egym/file_editor.py")
# cmds[0]

