In [1]:
BASE_ID = "unsloth/Llama-3.2-3B-Instruct"

In [2]:
from unsloth import FastLanguageModel
import torch
from random import randint
max_seq_length = 2048 # Can increase for longer RL output
lora_rank = 128        # Larger rank = smarter, but slower
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Llama-3.2-3B-Instruct",
    load_in_4bit = False,
    max_seq_length = max_seq_length,
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


  from .autonotebook import tqdm as notebook_tqdm


🦥 Unsloth Zoo will now patch everything to make training faster!
Unsloth: AMD currently is not stable with 4bit bitsandbytes. Disabling for now.
==((====))==  Unsloth 2025.10.9: Fast Llama patching. Transformers: 4.56.2.
   \\   /|    AMD Radeon Graphics. Num GPUs = 1. Max memory: 191.688 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+rocm6.4. ROCm Toolkit: 6.4.43482-0f2d60242. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.32.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards: 100% 2/2 [00:02<00:00,  1.20s/it]


In [3]:
model_policy = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)
model_policy.print_trainable_parameters()

Unsloth 2025.10.9 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.


trainable params: 194,510,848 || all params: 3,407,260,672 || trainable%: 5.7087


In [4]:
import os, sys, subprocess

working_directory = "/shared-docker/OpenEnv"  # repo root containing src/
port = "8030"
keepalive = "1000"

cmd = [
    sys.executable, "-m", "uvicorn",
    "envs.cartpole_env.server.app:app",  # <— our new app path
    "--host", "0.0.0.0",
    "--port", port,
    "--timeout-keep-alive", keepalive,
]

env = {
    **os.environ,
    "PYTHONPATH": f"{working_directory}/src",

    # CartPole knobs (analogous to OPENSPIEL_*):
    "CARTPOLE_ENV_ID": "CartPole-v1",
    "CARTPOLE_SEED": "123",
    "CARTPOLE_MAX_EPISODE_STEPS": "500",
    "CARTPOLE_RENDER_MODE": "none",  # or "rgb_array"
}

#proc = subprocess.Popen(cmd, env=env, cwd=working_directory,
#                        stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

#print("CartPole server starting at http://localhost:8020 … PID:", proc.pid)


In [5]:
import httpx, time
time.sleep(1)
r = httpx.post("http://localhost:8030/reset", json={}); print(r.json())

INFO:httpx: HTTP Request: POST http://localhost:8030/reset "HTTP/1.1 200 OK"


{'observation': {'state': [0.015229926444590092, -0.04562246799468994, -0.047997042536735535, 0.0339212566614151], 'legal_actions': [0, 1], 'episode_length': 0, 'total_reward': 0.0}, 'reward': 0.0, 'done': False}


In [6]:
import requests
import time
time.sleep(5) # Wait 5 seconds for OpenEnv to start!
from envs.cartpole_environment import CartpoleEnv, CartpoleAction
import httpx

# Same shape as: OpenSpielEnv(base_url=..., request_timeout_s=...)
base_url = "http://localhost:8030"
request_timeout_s = 1000  # seconds

openenv_process = openenv_process = CartpoleEnv(
        base_url=base_url,
        request_timeout_s=request_timeout_s
    )
# quick smoke test
# info = openenv_process.info()

# print("CartPole info:", info)

state = openenv_process.reset()
print("reset:", state)

# step with discrete action 0/1
state = openenv_process.step(CartpoleAction(action_id = 0))
print("step:", state)

reset: StepResult(observation=CartpoleObservation(done=False, reward=0.0, metadata={}, state=[-0.037743449211120605, -0.0241886917501688, -0.009422927163541317, 0.04691839590668678], legal_actions=[0, 1], episode_length=0, total_reward=0.0), reward=0.0, done=False)
step: StepResult(observation=CartpoleObservation(done=False, reward=1.0, metadata={}, state=[-0.03822722285985947, -0.21917426586151123, -0.008484559133648872, 0.3366134762763977], legal_actions=[0, 1], episode_length=1, total_reward=1.0), reward=1.0, done=False)


In [7]:
def strategy_simple(state):
    # state = [x, dx, angle, dangle]
    return 0  # left or 1 right


def build_user_prompt():
    return (
        """You are an expert CartPole player and a precise Python code generator.

Context / How this will be used
- Your function will be called every environment step to control the entire episode of CartPole-v1 (OpenAI Gym/Gymnasium style).
- The simulator updates at ~0.02 s per step (≈50 Hz).
- The episode ends early if the pole falls or the cart goes out of bounds; otherwise it caps at the env’s max length.
  - Termination (approx.): |angle| > ~0.209 rad (≈12°) or |x| > 2.4 m.
  - Reward is +1 per step; the goal is to survive as long as possible (ideally to the cap).

Your objective: keep the pole upright and the cart within bounds for the longest possible duration.

What you must write
- A single Python function with this exact signature (no extras):
    def cartpole_strategy(state):
- Input state is a list of 4 floats: [x, dx, angle, dangle]
  - x = cart position (m)
  - dx = cart velocity (m/s)
  - angle = pole angle (rad, 0 is upright; + leans right)
  - dangle = angular velocity (rad/s)
- Output: return an int action — 0 (push left) or 1 (push right).

Design guidance for long-horizon stability
- Prioritize angle correction, then angular velocity damping, and only then center the cart (x, dx) to avoid boundary terminations.
- Use a simple deterministic control law (e.g., a weighted linear rule with a small dead-zone/hysteresis to avoid flapping on noise).
- Keep it short and stateless (no memory): e.g., one or two thresholds or a sign of a weighted sum is fine.
- Avoid overreacting to tiny oscillations; prefer small margins rather than exact limit chasing.
- No stochasticity; identical inputs must produce identical outputs.

Hard constraints
- Do not import, print, read/write files, use globals, randomness, or I/O.
- The output must be exactly one fenced code block in Python, with nothing before or after.
  - The first line inside the block must be: def cartpole_strategy(state):
  - The last line of your entire response must be the closing backticks to clearly end the program. No trailing commentary.

Output format reminder (dummy example — do NOT copy this logic):
```
def cartpole_strategy(state):
    x, dx, ang, dang = state
    # return 0 or 1 using a short, deterministic rule
    return 1 if (0.9*ang + 0.4*dang + 0.05*x + 0.02*dx) > 0 else 0 
        ```
        All helper functions should be inside def cartpole_strategy. Only output the short function `strategy`.
        """.strip())

print(build_user_prompt())

You are an expert CartPole player and a precise Python code generator.

Context / How this will be used
- Your function will be called every environment step to control the entire episode of CartPole-v1 (OpenAI Gym/Gymnasium style).
- The simulator updates at ~0.02 s per step (≈50 Hz).
- The episode ends early if the pole falls or the cart goes out of bounds; otherwise it caps at the env’s max length.
  - Termination (approx.): |angle| > ~0.209 rad (≈12°) or |x| > 2.4 m.
  - Reward is +1 per step; the goal is to survive as long as possible (ideally to the cap).

Your objective: keep the pole upright and the cart within bounds for the longest possible duration.

What you must write
- A single Python function with this exact signature (no extras):
    def cartpole_strategy(state):
- Input state is a list of 4 floats: [x, dx, angle, dangle]
  - x = cart position (m)
  - dx = cart velocity (m/s)
  - angle = pole angle (rad, 0 is upright; + leans right)
  - dangle = angular velocity (rad/s)

In [8]:
def extract_function(text):
    if text.count("```") >= 2:
        first = text.find("```") + 3
        second = text.find("```", first)
        fx = text[first : second].strip()
        fx = fx[fx.find("def"):]
        if fx.startswith("def cartpole_strategy(state):"): return fx
    return None
print(extract_function(build_user_prompt()))

def cartpole_strategy(state):
    x, dx, ang, dang = state
    # return 0 or 1 using a short, deterministic rule
    return 1 if (0.9*ang + 0.4*dang + 0.05*x + 0.02*dx) > 0 else 0


In [9]:
# Minimal safe executor (reuse your nb's create_locked_down_function if available)
from unsloth import create_locked_down_function

from unsloth import check_python_modules

def _safe_compile(func_src: str):
    # Use the notebook's 'create_locked_down_function' if present
    return create_locked_down_function(func_src)



In [10]:
import numpy as np
from random import randint
global _PRINT_COUNTER
_PRINT_COUNTER = 0

    
def execute_strategy(strategy_fn, initial_state, max_steps=500):
    """Run strategy on CartPole server until done or step limit.
       Returns (steps_survived, done_flag)."""
    steps = 0
    done = False
    state = initial_state
    reward = 0
    while not done and steps < max_steps:
        # print("try strategy")
        a = int(strategy_fn(state.observation.state))
        # print(a)
        if a not in (0,1):
            # clamp invalid actions
            a = 0 if a <= 0 else 1
        res = openenv_process.step(CartpoleAction(action_id=a))
        #print(res)
        state = res
        done = bool(res.done)
        steps += 1
        reward += res.reward
    return steps, done


def function_works(completions, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        function = extract_function(response)
        if function is not None:
            ok, info = check_python_modules(function)
        if function is None or "error" in info:
            score = -2.0
        else:
            try:
                new_strategy = _safe_compile(function)
                score = randint(7, 20)
            except:
                function
                score = -0.5
        scores.append(score)
    return scores
    
def strategy_succeeds(completions, **kwargs):
    """completions: list of candidate generations (like the nb).
       Returns list[float] rewards (higher is better)."""
    global _PRINT_COUNTER
    scores = []

    # Reset the env once per candidate to a fresh start
    for completion in completions:
        try:
            # The notebook packs text like completion[0]["content"]
            response = completion[0]["content"]
        except Exception:
            scores.append(0.0); continue

        # Optional: print every 5th for debugging
        if _PRINT_COUNTER % 5 == 0:
            try:
                print(response.splitlines()[0][:120])
            except:
                print("...candidate omitted...")
        _PRINT_COUNTER += 1
        #print(response)
        # Parse the function from the text (reuse nb's helpers if present)
        try:
            func_src = extract_function(response)
            print(func_src)
            strategy_fn = _safe_compile(func_src)
        except Exception as e:
            print("Compile error:", e)
            scores.append(0.0); continue
        print(func_src)
        print(strategy_fn)
        # Rollout on CartPole
        try:
            # Fresh episode
            res0 = openenv_process.reset()
            current_state = res0
            steps, finished = execute_strategy(strategy_fn, current_state, max_steps=10000)

            # Reward = steps survived (CartPole native) with mild shaping:
            # huge bonus if finished by reaching max steps (i.e., perfect 500).
            reward = float(steps)
            if steps >= 1000:
                reward += 1000.0

            # print a brief trace for failing candidates occasionally
            if steps < 20 and (_PRINT_COUNTER % 7 == 0):
                print("Short episode:", steps)

            scores.append(reward)
        except TimeoutError:
            # keep the semantics from the nb
            scores.append(-1.0)
        except Exception as e:
            # print(f"Exception = {str(e)}")
            scores.append(-3.0)

    return scores


In [11]:
toy1 = """
```
def cartpole_strategy(state):
    # state: [x, dx, angle, dangle]
    print(state)
    x, dx, ang, dang = state
    score = ang + 0.1 * dang
    return 1 if score > 0.0 else 0
    ```
"""

toy = """
```
def cartpole_strategy(state):
    return 1
    ```
"""


    
fn1 = extract_function(toy)

fn = _safe_compile(fn1)

s0 = openenv_process.reset();
# print(s0.observation.values[0],s0.observation.values[1],s0.observation.values[2],s0.observation.values[3])
print(fn(s0.observation.state))
done = False

steps, done = execute_strategy(fn, s0)
print("Toy strategy survived steps:", steps)

1
Toy strategy survived steps: 10


In [12]:
from datasets import Dataset
dataset = Dataset.from_list([{"prompt" : [{"role": "user", "content": build_user_prompt().strip()}], "answer" : 0, "reasoning_effort": "low"}]*1000)
maximum_length = len(tokenizer.apply_chat_template([{"role": "user", "content": build_user_prompt().strip()}], add_generation_prompt = True))
print(maximum_length)

643


In [13]:
dataset[0]

{'prompt': [{'content': 'You are an expert CartPole player and a precise Python code generator.\n\nContext / How this will be used\n- Your function will be called every environment step to control the entire episode of CartPole-v1 (OpenAI Gym/Gymnasium style).\n- The simulator updates at ~0.02 s per step (≈50 Hz).\n- The episode ends early if the pole falls or the cart goes out of bounds; otherwise it caps at the env’s max length.\n  - Termination (approx.): |angle| > ~0.209 rad (≈12°) or |x| > 2.4 m.\n  - Reward is +1 per step; the goal is to survive as long as possible (ideally to the cap).\n\nYour objective: keep the pole upright and the cart within bounds for the longest possible duration.\n\nWhat you must write\n- A single Python function with this exact signature (no extras):\n    def cartpole_strategy(state):\n- Input state is a list of 4 floats: [x, dx, angle, dangle]\n  - x = cart position (m)\n  - dx = cart velocity (m/s)\n  - angle = pole angle (rad, 0 is upright; + leans ri

In [18]:
max_prompt_length = maximum_length + 1 # + 1 just in case!
max_completion_length = max_seq_length - max_prompt_length

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    temperature = 1.0,
    learning_rate = 1e-6,
    weight_decay = 0.01,
    warmup_ratio = 0.1,
    lr_scheduler_type = "linear",
    optim = "adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 8, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_completion_length,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 100,
    save_steps = 100,
    report_to = "trackio", # Can use Weights & Biases, TrackIO
    output_dir = "outputs",

    # For optional training + evaluation
    # fp16_full_eval = True,
    # per_device_eval_batch_size = 4,
    # eval_accumulation_steps = 1,
    # eval_strategy = "steps",
    # eval_steps = 1,
)

Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 8


In [19]:
# For optional training + evaluation
# new_dataset = dataset.train_test_split(test_size = 0.01)

trainer = GRPOTrainer(
    model = model_policy,
    processing_class = tokenizer,
    reward_funcs = [
        function_works,
        strategy_succeeds
    ],
    args = training_args,
    train_dataset = dataset,

    # For optional training + evaluation
    # train_dataset = new_dataset["train"],
    # eval_dataset = new_dataset["test"],
)
os.makedirs("adapters", exist_ok=True)
model_policy.save_pretrained("adapters/cartpole-lora-sft")


In [None]:
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,000 | Num Epochs = 1 | Total steps = 100
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 1 x 1) = 8
 "-____-"     Trainable parameters = 194,510,848 of 3,407,260,672 (5.71% trained)


* Trackio project initialized: huggingface
* Trackio metrics logged to: /root/.cache/huggingface/trackio


* Created new run: radiant-wolf-61
def cartpole_strategy(state):
    x, dx, ang, dang = state
    push_right = 1 if (0.9*ang + 0.4*dang + 0.05*x + 0.02*dx) > 0 else 0
    return push_right
def cartpole_strategy(state):
    x, dx, ang, dang = state
    push_right = 1 if (0.9*ang + 0.4*dang + 0.05*x + 0.02*dx) > 0 else 0
    return push_right
<function cartpole_strategy at 0x7879476aba30>
```ócwendungalkerتونnettnett peersIdeervoetàhang CART NobleebecartcartominatedBugnett bulletouv-state ciclo kiệmrema edi
None
Compile error: compile() arg 1 must be a string, bytes or AST object
None
Compile error: compile() arg 1 must be a string, bytes or AST object
None
Compile error: compile() arg 1 must be a string, bytes or AST object
None
Compile error: compile() arg 1 must be a string, bytes or AST object
None
Compile error: compile() arg 1 must be a string, bytes or AST object
```iere Feature yetněmctest.diskidon GreeneiringDirection finde MeanCLSctest�anhStrange diagonalctest.removeAll Gur Tre

Step,Training Loss,reward,reward_std,completions / mean_length,completions / min_length,completions / max_length,completions / clipped_ratio,completions / mean_terminated_length,completions / min_terminated_length,completions / max_terminated_length,sampling / sampling_logp_difference / mean,sampling / sampling_logp_difference / max,sampling / importance_sampling_ratio / min,sampling / importance_sampling_ratio / mean,sampling / importance_sampling_ratio / max,kl,rewards / function_works / mean,rewards / function_works / std,rewards / strategy_succeeds / mean,rewards / strategy_succeeds / std
1,0.0,61.625,179.958679,1236.875,67.0,1404.0,0.875,67.0,67.0,67.0,0,0,0,0,0,0.000616,-0.875,3.181981,62.5,176.776703
2,0.0,12.625,41.365746,1242.375,111.0,1404.0,0.875,111.0,111.0,111.0,No Log,No Log,No Log,No Log,No Log,0.000708,-0.125,5.303301,12.75,36.062447
3,0.0,22.25,68.589355,1239.5,88.0,1404.0,0.875,88.0,88.0,88.0,No Log,No Log,No Log,No Log,No Log,0.000663,-0.625,3.889087,22.875,64.700272
4,0.0,2.5,12.727922,1253.75,202.0,1404.0,0.875,202.0,202.0,202.0,No Log,No Log,No Log,No Log,No Log,0.00062,-0.25,4.949748,2.75,7.778175
5,0.0,0.75,7.778174,1251.125,181.0,1404.0,0.875,181.0,181.0,181.0,No Log,No Log,No Log,No Log,No Log,0.000732,-0.5,4.24264,1.25,3.535534
6,0.0,24.625,75.30687,1242.625,113.0,1404.0,0.875,113.0,113.0,113.0,No Log,No Log,No Log,No Log,No Log,0.001498,-0.125,5.303301,24.75,70.003571
7,0.0,17.625,55.507881,1239.875,91.0,1404.0,0.875,91.0,91.0,91.0,No Log,No Log,No Log,No Log,No Log,0.001089,-0.375,4.596194,18.0,50.91169
8,0.0,31.375,94.398758,1242.75,114.0,1404.0,0.875,114.0,114.0,114.0,No Log,No Log,No Log,No Log,No Log,0.001903,0.25,6.363961,31.125,88.034798
9,0.0,2.0,11.313708,1239.25,86.0,1404.0,0.875,86.0,86.0,86.0,No Log,No Log,No Log,No Log,No Log,0.001504,0.75,7.778175,1.25,3.535534
10,0.0,44.0,130.107651,1241.625,105.0,1404.0,0.875,105.0,105.0,105.0,No Log,No Log,No Log,No Log,No Log,0.00157,-0.25,4.949748,44.25,125.157906


def cartpole_strategy(state):
    x, dx, ang, dang = state
    if abs(ang) > 0.209:
        return 1 if ang > 0 else 0
    elif abs(x) > 2.4:
        return 0 if dx > 0 else 1
    elif abs(dang) > 0.2:
        return 1 if dang > 0 else 0
    else:
        return 0 if dx > 0 else 1
def cartpole_strategy(state):
    x, dx, ang, dang = state
    if abs(ang) > 0.209:
        return 1 if ang > 0 else 0
    elif abs(x) > 2.4:
        return 0 if dx > 0 else 1
    elif abs(dang) > 0.2:
        return 1 if dang > 0 else 0
    else:
        return 0 if dx > 0 else 1
<function cartpole_strategy at 0x7879459c04c0>
None
Compile error: compile() arg 1 must be a string, bytes or AST object
None
Compile error: compile() arg 1 must be a string, bytes or AST object
```yyysink Pastor Kaftsy={[' reinbjerg�laclest centro fancyUSustilrophy именноandon Pastctest-vectoraginiubeкар﻿
None
Compile error: compile() arg 1 must be a string, bytes or AST object
None
Compile error: compile() arg 1 must be a string, 

In [17]:
os.makedirs("adapters", exist_ok=True)
model_policy.save_pretrained("adapters/cartpole-lora-grpo_trained")
tokenizer.save_pretrained("adapters/cartpole-lora-grpo_trained")

('adapters/cartpole-lora-grpo_trained/tokenizer_config.json',
 'adapters/cartpole-lora-grpo_trained/special_tokens_map.json',
 'adapters/cartpole-lora-grpo_trained/chat_template.jinja',
 'adapters/cartpole-lora-grpo_trained/tokenizer.json')