# Setting Up unsloth and dependencies

In [None]:
%%capture
import os, importlib.util
!pip install --upgrade -qqq uv
if importlib.util.find_spec("torch") is None or "COLAB_" in "".join(os.environ.keys()):
    try: import numpy; get_numpy = f"numpy=={numpy.__version__}"
    except: get_numpy = "numpy"
    !uv pip install -qqq \
        "torch>=2.8.0" "triton>=3.4.0" {get_numpy} torchvision bitsandbytes "transformers==4.56.2" trackio \
        "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo" \
        "unsloth[base] @ git+https://github.com/unslothai/unsloth" \
        git+https://github.com/triton-lang/triton.git@05b2c186c1b6c9a08375389d5efe9cb4c401c075#subdirectory=python/triton_kernels
elif importlib.util.find_spec("unsloth") is None:
    !uv pip install -qqq unsloth trackio
!uv pip install --upgrade --no-deps transformers==4.56.2 tokenizers trl==0.22.2 unsloth unsloth_zoo

# Install the environment and custom openEnv 

In [1]:
%%capture
!pip install -qqq fastapi uvicorn requests
!git clone https://github.com/yogesh1801/OpenEnv.git > /dev/null 2>&1
%cd OpenEnv
!git checkout yogesh-julia-env

In [2]:
import subprocess, sys, os
from pathlib import Path
sys.path.insert(0, './src')
working_directory = str(Path.cwd().parent.absolute() / "OpenEnv")

# Setting up a Model using unsloth which we will train

In [3]:
from unsloth import FastLanguageModel
import torch
max_seq_length = 1500
lora_rank = 10
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Llama-3.2-3B-Instruct",
    fast_inference = True,
    load_in_4bit = False,
    max_seq_length = max_seq_length,
    gpu_memory_utilization = 0.85,
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
#### Unsloth: `hf_xet==1.1.10` and `ipykernel>6.30.1` breaks progress bars. Disabling for now in XET.
#### Unsloth: To re-enable progress bars, please downgrade to `ipykernel==6.30.1` or wait for a fix to
https://github.com/huggingface/xet-core/issues/526
INFO 10-27 14:00:20 [__init__.py:225] Automatically detected platform rocm.
🦥 Unsloth Zoo will now patch everything to make training faster!
Unsloth: AMD currently is not stable with 4bit bitsandbytes. Disabling for now.


[2025-10-27 14:00:23] INFO vllm_utils.py:752: Unsloth: Patching vLLM


INFO 10-27 14:00:23 [vllm_utils.py:694] Unsloth: Patching vLLM v1 graph capture
Unsloth: Could not patch vLLM V0 graph capture: No module named 'vllm.worker'
==((====))==  Unsloth 2025.10.9: Fast Llama patching. Transformers: 4.56.2. vLLM: 0.11.1rc3.dev39+gf417746ad.rocm700.
   \\   /|    . Num GPUs = 1. Max memory: 191.688 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0a0+git1c57644. ROCm Toolkit: 7.0.51831-a3e329ad8. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/Llama-3.2-3B-Instruct with actual GPU utilization = 84.85%
Unsloth: Your GPU has CUDA compute capability 9.4 with VRAM = 191.69 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 1500. Num Sequences = 400.
Unsloth: vLLM's KV Cache can use up to 156.48 GB. Also swap space = 6 GB.
Unsloth: Not an error,

Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]


INFO 10-27 14:00:30 [default_loader.py:314] Loading weights took 2.55 seconds
INFO 10-27 14:00:30 [punica_selector.py:20] Using PunicaWrapperGPU.
INFO 10-27 14:00:30 [gpu_model_runner.py:2917] Model loading took 7.0195 GiB and 3.054333 seconds
INFO 10-27 14:00:35 [backends.py:609] Using cache directory: /root/.cache/vllm/torch_compile_cache/9f1bce1ea9/rank_0_0/backbone for vLLM's torch.compile
INFO 10-27 14:00:35 [backends.py:623] Dynamo bytecode transform time: 4.72 s
INFO 10-27 14:00:37 [backends.py:207] Directly load the compiled graph(s) for dynamic shape from the cache, took 1.300 s
INFO 10-27 14:00:51 [monitor.py:34] torch.compile takes 6.02 s in total
INFO 10-27 14:00:51 [gpu_worker.py:337] Available KV cache memory: 153.07 GiB
INFO 10-27 14:00:52 [kv_cache_utils.py:1229] GPU KV cache size: 1,433,040 tokens
INFO 10-27 14:00:52 [kv_cache_utils.py:1234] Maximum concurrency for 1,500 tokens per request: 952.82x
INFO 10-27 14:00:52 [vllm_utils.py:699] Unsloth: Running patched vLLM v

Capturing CUDA graphs (mixed prefill-decode, PIECEWISE):   0%|                                                                                                                                                                                         | 0/134 [00:00<?, ?it/s]



Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 134/134 [00:07<00:00, 18.58it/s]
Capturing CUDA graphs (decode, FULL): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 106/106 [00:06<00:00, 17.51it/s]

INFO 10-27 14:01:05 [gpu_model_runner.py:3843] Graph capturing finished in 13 secs, took 0.83 GiB
INFO 10-27 14:01:05 [vllm_utils.py:706] Unsloth: Patched vLLM v1 graph capture finished in 13 secs.





INFO 10-27 14:01:05 [core.py:238] init engine (profile, create kv cache, warmup model) took 35.25 seconds
INFO 10-27 14:01:06 [llm.py:343] Supported tasks: ('generate',)
Unsloth: Just some info: will skip parsing ['post_feedforward_layernorm', 'input_layernorm', 'ffn_norm', 'norm2', 'layer_norm1', 'post_attention_layernorm', 'q_norm', 'norm1', 'post_layernorm', 'attention_norm', 'k_norm', 'layer_norm2', 'pre_feedforward_layernorm']


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Unsloth: Just some info: will skip parsing ['post_feedforward_layernorm', 'input_layernorm', 'ffn_norm', 'norm2', 'layer_norm1', 'post_attention_layernorm', 'cross_attn_input_layernorm', 'q_norm', 'norm1', 'post_layernorm', 'attention_norm', 'k_norm', 'layer_norm2', 'cross_attn_post_attention_layernorm', 'pre_feedforward_layernorm']


In [4]:
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = lora_rank, # *2 speeds up training
    use_gradient_checkpointing = "unsloth", # Reduces memory usage
    random_state = 3407,
)

Unsloth 2025.10.9 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.


# Connecting to julia Env exposed by OpenEnv

In [5]:
from envs.julia_env import JuliaEnv
from envs.julia_env.models import JuliaAction, JuliaObservation

from envs.r_env import REnv
from envs.r_env.models import RAction, RObservation

1. It is not recommended to directly open a subprocess in python instead expose OpenEnv using a docker container because u know that is the whole point of OpenEnv

### First, build the OpenEnv base image (one-time setup):

```bash
# From OpenEnv root directory
docker build -t openenv-base:latest -f src/core/containers/images/Dockerfile .
```

### Build Julia Environment Image

```bash
# From OpenEnv root directory
docker build -t julia-env:latest -f src/envs/julia_env/server/Dockerfile .
```

### Run the Server

```bash
# Run in background
docker run -d -p 8000:8000 --name julia-env-server julia-env:latest

# OR run in foreground (to see logs)
docker run -p 8000:8000 --name julia-env-server julia-env:latest
```


Note this has to be done in system terminal, since this notebook is hosted inside the docker running these commands in terminal exposed by this notebook is illogical as you cannot run docker inside a docker

In [6]:
port = "8000"
localhost = f"http://localhost:{port}"
# openenv_process = subprocess.Popen(
#     [sys.executable, "-m", "uvicorn", "envs.julia_env.server.app:app", "--host", "0.0.0.0", "--port", port],
#     env = {**os.environ,
#          "PYTHONPATH": f"{working_directory}/src",
#          },
#     stdout = subprocess.PIPE,
#     stderr = subprocess.PIPE,
#     text = True,
#     cwd = working_directory,
# )


# used docker images instead

In [7]:
import requests
import time
print(requests.get(f"{localhost}/health", timeout = 2).content)
openenv_process = JuliaEnv(base_url = localhost)

b'{"status":"healthy"}'


In [8]:
result = openenv_process.reset()
current_state = result.observation
current_state

action = JuliaAction(core_code="""
module MyProject

export add

function add(a::Number, b::Number)
    return a + b
end

end""", test_code = """
using Test
using .MyProject

@testset "Add function Tests" begin
    @test MyProject.add(1, 2) == 3
    @test MyProject.add(1, 2) == 3
    @test Project.add(1, 2) == 3
end
""")

result = openenv_process.step(action)
result

StepResult(observation=JuliaObservation(done=False, reward=None, metadata={}, stdout='Add function Tests: Error During Test at /tmp/tmpirgrv23d.jl:19\n  Test threw exception\n  Expression: Project.add(1, 2) == 3\n  UndefVarError: `Project` not defined\n  Stacktrace:\n   [1] macro expansion\n     @ ~/.julia/juliaup/julia-1.10.10+0.x64.linux.gnu/share/julia/stdlib/v1.10/Test/src/Test.jl:669 [inlined]\n   [2] macro expansion\n     @ /tmp/tmpirgrv23d.jl:19 [inlined]\n   [3] macro expansion\n     @ ~/.julia/juliaup/julia-1.10.10+0.x64.linux.gnu/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]\n   [4] top-level scope\n     @ /tmp/tmpirgrv23d.jl:17\nTest Summary:      | Pass  Error  Total  Time\nAdd function Tests |    2      1      3  0.7s\n', stderr='ERROR: LoadError: Some tests did not pass: 2 passed, 0 failed, 1 errored, 0 broken.\nin expression starting at /tmp/tmpirgrv23d.jl:16\n', exit_code=1, tests_passed=2, tests_failed=1, code_compiles=True), reward=7, done=False)

# Training starts here

In [9]:
julia_code_gen_prompt = """
You are a precise and pragmatic Julia programmer.

Write a **single Julia function** that correctly solves the problem described below.

Rules:
- The code must be syntactically correct and runnable as is.
- Do not use arrow functions, ternary operators, or modern syntax that may cause issues.
- Use only the Julia standard library.
- Do **not** wrap the code in a module or add a `main` function.
- Do **not** include any test code in your response.
- Do **not** hardcode specific test cases or outputs — the function must work for general inputs.
- The **function name must exactly match** the one used in the provided tests.
- Respond with **only the Julia function** and nothing else (no explanations, no comments, no extra text)
- The function name must exactly match the one used in the provided tests.
- Return only the Julia function.
- character literal should not contain multiple characters.
- take care of object types and mind that spaces matter in julia so cannot add random spaces

Passing tests and clean, compilable code are rewarded. Hardcoding or failing tests is penalized.

Test Reference (for context only, do not include in the output):
{julia_test}

Code:
""".strip()


In [10]:
import re
def remove_ticks(text):
    text = re.sub(r'^```julia\s*\n?', '', text, flags=re.IGNORECASE)
    text = re.sub(r'\n?```\s*$', '', text)

    return text

In [11]:
from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
    temperature = 1.0,
    learning_rate = 5e-5,
    weight_decay = 0.01,
    warmup_ratio = 0.1,
    lr_scheduler_type = "linear",
    optim = "adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1,
    num_generations = 3,
    max_prompt_length = 2048,
    max_completion_length = 1024,
    num_train_epochs = 1,
    max_steps = 100,
    save_steps = 250,
    report_to = "trackio",
    output_dir = "outputs",

    # For optional training + evaluation
    # fp16_full_eval = True,
    # per_device_eval_batch_size = 4,
    # eval_accumulation_steps = 1,
    # eval_strategy = "steps",
    # eval_steps = 1,
)

Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 3


In [12]:
import pandas as pd
from datasets import Dataset

df = pd.read_parquet("../julia_dataset.parquet")

df = df[["julia_prompt", "julia_test"]]
dataset = Dataset.from_pandas(df)

dataset = dataset.map(lambda x: {
    "prompt": [
        {"role" : "system", "content": julia_code_gen_prompt.format(julia_test=x["julia_test"])},
        {"role" : "user", "content": x["julia_prompt"]},
    ],
})
print(dataset[0])

Map:   0%|          | 0/1247 [00:00<?, ? examples/s]

{'julia_prompt': 'Implement a function `echo_nums(x, y)` that takes two integers, `x` and `y`, and returns a vector of all numerical values within the range from `x` to `y`, inclusive. The function should handle cases where `x` is greater than `y` by returning an empty vector.', 'julia_test': 'using Test\n\n@testset "my test" begin\n@test echo_nums(1, 5) == [1, 2, 3, 4, 5]\n@test echo_nums(10, 1) == Int[]\n@test echo_nums(5, 5) == [5]\nend', 'prompt': [{'content': 'You are a precise and pragmatic Julia programmer.\n\nWrite a **single Julia function** that correctly solves the problem described below.\n\nRules:\n- The code must be syntactically correct and runnable as is.\n- Do not use arrow functions, ternary operators, or modern syntax that may cause issues.\n- Use only the Julia standard library.\n- Do **not** wrap the code in a module or add a `main` function.\n- Do **not** include any test code in your response.\n- Do **not** hardcode specific test cases or outputs — the function m

In [13]:
def julia_env_reward(completions, **kwargs):
    rewards = []
    test_codes_list = kwargs.get('julia_test', [])
    
    for i, completion in enumerate(completions):
        response = completion[0]["content"]
        core_code = remove_ticks(response)
        test_code = test_codes_list[i]
        
        try:
            result = openenv_process.reset()
            action = JuliaAction(core_code=core_code, test_code=test_code)
            result = openenv_process.step(action)

            reward = result.reward if result.reward is not None else 0.0
            
        except Exception as e:
            print(f"Error for completion {i}:", e)
            reward = 0.0
        
        rewards.append(reward)
    
    return rewards

In [14]:
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    remove_unused_columns = False,
    reward_funcs = [julia_env_reward],
    args = training_args,
    train_dataset = dataset,

    # For optional training + evaluation
    # train_dataset = new_dataset["train"],
    # eval_dataset = new_dataset["test"],
)

In [15]:
trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 128004}.
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,247 | Num Epochs = 1 | Total steps = 100
O^O/ \_/ \    Batch size per device = 3 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (3 x 1 x 1) = 3
 "-____-"     Trainable parameters = 15,196,160 of 3,227,945,984 (0.47% trained)


* Trackio project initialized: huggingface
* Trackio metrics logged to: /root/.cache/huggingface/trackio


* Created new run: lucky-marsh-31


Step,Training Loss,reward,reward_std,completions / mean_length,completions / min_length,completions / max_length,completions / clipped_ratio,completions / mean_terminated_length,completions / min_terminated_length,completions / max_terminated_length,sampling / sampling_logp_difference / mean,sampling / sampling_logp_difference / max,sampling / importance_sampling_ratio / min,sampling / importance_sampling_ratio / mean,sampling / importance_sampling_ratio / max,kl,rewards / julia_env_reward / mean,rewards / julia_env_reward / std
1,0.0,-1.333333,0.57735,105.666672,92.0,126.0,0.0,105.666672,92.0,126.0,0,0,0,0,0,0.000316,-1.333333,0.57735
2,0.0,3.0,8.660254,109.333336,105.0,114.0,0.0,109.333336,105.0,114.0,No Log,No Log,No Log,No Log,No Log,0.0,3.0,8.660254
3,0.0,-2.0,0.0,87.0,57.0,108.0,0.0,87.0,57.0,108.0,No Log,No Log,No Log,No Log,No Log,0.000375,-2.0,0.0
4,0.0,1.333333,4.932883,190.666672,147.0,248.0,0.0,190.666672,147.0,248.0,No Log,No Log,No Log,No Log,No Log,0.000318,1.333333,4.932883
5,0.0,-1.333333,0.57735,46.333336,27.0,82.0,0.0,46.333336,27.0,82.0,No Log,No Log,No Log,No Log,No Log,0.000302,-1.333333,0.57735
6,0.0,-1.666667,0.57735,55.333336,36.0,74.0,0.0,55.333336,36.0,74.0,No Log,No Log,No Log,No Log,No Log,0.000811,-1.666667,0.57735
7,0.0,13.0,0.0,27.0,27.0,27.0,0.0,27.0,27.0,27.0,No Log,No Log,No Log,No Log,No Log,5.2e-05,13.0,0.0
8,0.0,-2.0,0.0,173.333344,130.0,228.0,0.0,173.333344,130.0,228.0,No Log,No Log,No Log,No Log,No Log,0.000455,-2.0,0.0
9,0.0,-1.666667,0.57735,122.333336,120.0,126.0,0.0,122.333336,120.0,126.0,No Log,No Log,No Log,No Log,No Log,0.000267,-1.666667,0.57735
10,0.0,-2.0,0.0,148.333344,124.0,175.0,0.0,148.333344,124.0,175.0,No Log,No Log,No Log,No Log,No Log,0.00026,-2.0,0.0


Unsloth: Will smartly offload gradients to save VRAM!
* Run finished. Uploading logs to Trackio (please wait...)


TrainOutput(global_step=100, training_loss=8.766921057841159e-06, metrics={'train_runtime': 602.1391, 'train_samples_per_second': 0.498, 'train_steps_per_second': 0.166, 'total_flos': 0.0, 'train_loss': 8.766921057841159e-06})

# We start R training on same model here

In [16]:
port = "8002"
localhost = f"http://localhost:{port}"
# openenv_process = subprocess.Popen(
#     [sys.executable, "-m", "uvicorn", "envs.r_env.server.app:app", "--host", "0.0.0.0", "--port", port],
#     env = {**os.environ,
#          "PYTHONPATH": f"{working_directory}/src",
#          },
#     stdout = subprocess.PIPE,
#     stderr = subprocess.PIPE,
#     text = True,
#     cwd = working_directory,
# )


# used docker images instead

In [17]:
import requests
import time
print(requests.get(f"{localhost}/health", timeout = 2).content)
openenv_process = REnv(base_url = localhost)

b'{"status":"healthy"}'


In [18]:
action = RAction(
    core_code = """
add <- function(a, b) {
  return(a + b)
}
""",
    test_code = """
library(testthat)

test_that("add works for positive numbers", {
  expect_equal(add(1, 2), 3)
})

test_that("add works for zeros", {
  expect_equal(add(0, 0), 0)
})

test_that("add works for negative numbers", {
  expect_equal(add(-1, -2), -3)
})
"""
)

result = openenv_process.step(action)
result

StepResult(observation=RObservation(done=False, reward=None, metadata={}, stdout='\n══ Testing tmp_mznnjka.R ═══════════════════════════════════════════════════════\n\n[ FAIL 0 | WARN 0 | SKIP 0 | PASS 0 ]\n[ FAIL 0 | WARN 0 | SKIP 0 | PASS 1 ]\n[ FAIL 0 | WARN 0 | SKIP 0 | PASS 2 ]\n[ FAIL 0 | WARN 0 | SKIP 0 | PASS 3 ] Done!\n', stderr='', exit_code=0, tests_passed=3, tests_failed=0, code_compiles=True), reward=13, done=False)

In [19]:
R_code_gen_prompt = """
You are a precise and pragmatic R programmer.
Write a **single R function** that correctly solves the problem described below.
Rules:
- The code must be syntactically correct and runnable as is.
- Use traditional R function syntax with proper braces and return statements.
- Use only base R (no external packages).
- Do **not** wrap the code in additional scaffolding or add a `main` function.
- Do **not** include any test code in your response.
- Do **not** hardcode specific test cases or outputs — the function must work for general inputs.
- The **function name must exactly match** the one used in the provided tests.
- Respond with **only the R function** and nothing else (no explanations, no comments, no extra text)
- The function name must exactly match the one used in the provided tests.
- Return only the R function.
- Use proper R syntax: assignment with <-, proper vectorization, and correct indexing (1-based).
- Pay attention to data types: use as.integer(), as.numeric(), as.character() when needed.
- Remember R is 1-indexed, not 0-indexed.
Passing tests and clean, runnable code are rewarded. Hardcoding or failing tests is penalized.
Test Reference (for context only, do not include in the output):
{r_test}
Code:
""".strip()

In [20]:
import re
def remove_ticks(text):
    text = re.sub(r'^```r\s*\n?', '', text, flags=re.IGNORECASE)
    text = re.sub(r'\n?```\s*$', '', text)

    return text

In [21]:
def r_env_reward(completions, **kwargs):
    rewards = []
    test_codes_list = kwargs.get('r_test', [])
    
    for i, completion in enumerate(completions):
        response = completion[0]["content"]
        core_code = remove_ticks(response)
        test_code = test_codes_list[i]
        try:
            result = openenv_process.reset()
            action = RAction(core_code=core_code, test_code=test_code)
            result = openenv_process.step(action)

            reward = result.reward if result.reward is not None else 0.0
            
        except Exception as e:
            print(f"Error for completion {i}:", e)
            reward = 0.0
        
        rewards.append(reward)
    
    return rewards

In [22]:
import pandas as pd
from datasets import Dataset

df = pd.read_parquet("../r_dataset.parquet")

df = df[["r_prompt", "r_test"]]
dataset = Dataset.from_pandas(df)

dataset = dataset.map(lambda x: {
    "prompt": [
        {"role" : "system", "content": R_code_gen_prompt.format(r_test=x["r_test"])},
        {"role" : "user", "content": x["r_prompt"]},
    ],
})
print(dataset[0])

Map:   0%|          | 0/1224 [00:00<?, ? examples/s]

{'r_prompt': 'Implement a function `echo_nums(x, y)` that takes two integers, `x` and `y`, and returns a vector of all numerical values within the range from `x` to `y`, inclusive. The function should handle cases where `x` is greater than `y` by returning an empty vector.', 'r_test': 'library(testthat)\ntest_that("echo_nums returns correct sequence when x is less than y", {\n  expect_equal(echo_nums(1, 3), c(1, 2, 3))\n})\ntest_that("echo_nums returns correct sequence when x is equal to y", {\n  expect_equal(echo_nums(5, 5), c(5))\n})\ntest_that("echo_nums returns empty vector when x is greater than y", {\n  expect_equal(echo_nums(10, 5), numeric(0))\n})', 'prompt': [{'content': 'You are a precise and pragmatic R programmer.\nWrite a **single R function** that correctly solves the problem described below.\nRules:\n- The code must be syntactically correct and runnable as is.\n- Use traditional R function syntax with proper braces and return statements.\n- Use only base R (no external p

In [23]:
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    remove_unused_columns = False,
    reward_funcs = [r_env_reward],
    args = training_args,
    train_dataset = dataset,

    # For optional training + evaluation
    # train_dataset = new_dataset["train"],
    # eval_dataset = new_dataset["test"],
)

In [24]:
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,224 | Num Epochs = 1 | Total steps = 100
O^O/ \_/ \    Batch size per device = 3 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (3 x 1 x 1) = 3
 "-____-"     Trainable parameters = 15,196,160 of 3,227,945,984 (0.47% trained)


* Trackio project initialized: huggingface
* Trackio metrics logged to: /root/.cache/huggingface/trackio


* Created new run: modest-cliff-32


Step,Training Loss,reward,reward_std,completions / mean_length,completions / min_length,completions / max_length,completions / clipped_ratio,completions / mean_terminated_length,completions / min_terminated_length,completions / max_terminated_length,sampling / sampling_logp_difference / mean,sampling / sampling_logp_difference / max,sampling / importance_sampling_ratio / min,sampling / importance_sampling_ratio / mean,sampling / importance_sampling_ratio / max,kl,rewards / r_env_reward / mean,rewards / r_env_reward / std
1,0.0,28.0,25.980764,90.0,74.0,111.0,0.0,90.0,74.0,111.0,0,0,0,0,0,0.001513,28.0,25.980762
2,0.0,3.333333,8.386497,74.0,18.0,113.0,0.0,74.0,18.0,113.0,No Log,No Log,No Log,No Log,No Log,0.004141,3.333333,8.386497
3,0.0,22.0,0.0,119.0,107.0,127.0,0.0,119.0,107.0,127.0,No Log,No Log,No Log,No Log,No Log,0.00192,22.0,0.0
4,0.0,6.333333,13.576941,90.666672,61.0,107.0,0.0,90.666672,61.0,107.0,No Log,No Log,No Log,No Log,No Log,0.001282,6.333333,13.576941
5,0.0,1.666667,2.309401,52.333336,45.0,67.0,0.0,52.333336,45.0,67.0,No Log,No Log,No Log,No Log,No Log,0.000815,1.666667,2.309401
6,0.0,1.0,5.196153,55.333336,50.0,61.0,0.0,55.333336,50.0,61.0,No Log,No Log,No Log,No Log,No Log,0.014932,1.0,5.196152
7,0.0,0.0,2.645751,155.666672,109.0,213.0,0.0,155.666672,109.0,213.0,No Log,No Log,No Log,No Log,No Log,0.002093,0.0,2.645751
8,0.0,13.666667,12.701706,71.666672,57.0,97.0,0.0,71.666672,57.0,97.0,No Log,No Log,No Log,No Log,No Log,0.006268,13.666667,12.701706
9,0.0,-1.0,0.0,41.666668,29.0,67.0,0.0,41.666668,29.0,67.0,No Log,No Log,No Log,No Log,No Log,0.040002,-1.0,0.0
10,0.0,-1.666667,0.57735,156.333344,123.0,196.0,0.0,156.333344,123.0,196.0,No Log,No Log,No Log,No Log,No Log,0.000937,-1.666667,0.57735


* Run finished. Uploading logs to Trackio (please wait...)


TrainOutput(global_step=100, training_loss=4.5361398246868136e-05, metrics={'train_runtime': 280.266, 'train_samples_per_second': 1.07, 'train_steps_per_second': 0.357, 'total_flos': 0.0, 'train_loss': 4.5361398246868136e-05})

# Post training and inferencing

In [25]:
model.save_lora("grpo_saved_lora_unified")

In [26]:
from safetensors import safe_open

tensors = {}
with safe_open("grpo_saved_lora_unified/adapter_model.safetensors", framework = "pt") as f:
    # Verify both A and B are non zero
    for key in f.keys():
        tensor = f.get_tensor(key)
        n_zeros = (tensor == 0).sum() / tensor.numel()
        assert(n_zeros.item() != tensor.numel())

In [27]:
import pandas as pd
from datasets import Dataset

df = pd.read_parquet("../julia_dataset.parquet")

df = df[["julia_prompt", "julia_test", "task_id"]]
dataset = Dataset.from_pandas(df)

MAX_LEN = 1500

dataset = dataset.map(lambda x: {
    "prompt": [
        {"role": "system", "content": julia_code_gen_prompt.format(julia_test=x["julia_test"])[:MAX_LEN]},
        {"role": "user", "content": x["julia_prompt"][:MAX_LEN]},
    ],
})

Map:   0%|          | 0/1247 [00:00<?, ? examples/s]

In [None]:
from vllm import SamplingParams

sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)

results = []

for i in range(len(dataset)):
    task_id = dataset[i]["task_id"]
    julia_test = dataset[i]["julia_test"]
    message = dataset[i]["prompt"]
    message[1]

    text = tokenizer.apply_chat_template(
        message,
        add_generation_prompt=True,
        tokenize=False,
    )

    response = model.fast_generate(
        text,
        sampling_params=sampling_params,
        lora_request=model.load_lora("grpo_saved_lora_unified"),
    )

    julia_code = response[0].outputs[0].text
    julia_code = remove_ticks(julia_code)

    results.append({
        "task_id": task_id,
        "julia_code": julia_code,
        "julia_test": julia_test
    })

result_df = pd.DataFrame(results)
result_df.to_parquet("julia_code_output_smaller_model.parquet", index=False)

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                                      …

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
import pandas as pd
from datasets import Dataset

df = pd.read_parquet("../r_dataset.parquet")

df = df[["r_prompt", "r_test", "task_id"]]
dataset = Dataset.from_pandas(df)

MAX_LEN = 1500

dataset = dataset.map(lambda x: {
    "prompt": [
        {"role": "system", "content": R_code_gen_prompt.format(r_test=x["r_test"])[:MAX_LEN]},
        {"role": "user", "content": x["r_prompt"][:MAX_LEN]},
    ],
})

In [None]:
from vllm import SamplingParams

sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)

results = []

for i in range(len(dataset)):
    task_id = dataset[i]["task_id"]
    r_test = dataset[i]["r_test"]
    message = dataset[i]["prompt"]
    message[1]

    text = tokenizer.apply_chat_template(
        message,
        add_generation_prompt=True,
        tokenize=False,
    )

    response = model.fast_generate(
        text,
        sampling_params=sampling_params,
        lora_request=model.load_lora("grpo_saved_lora_unified"),
    )

    r_code = response[0].outputs[0].text
    r_code = remove_ticks(r_code)

    results.append({
        "task_id": task_id,
        "r_code": r_code,
        "r_test": r_test
    })

result_df = pd.DataFrame(results)
result_df.to_parquet("r_code_output_smaller_model.parquet", index=False)