In [None]:
%%capture
import os

# First install pass: a plain `pip install unsloth` outside Colab, or a pinned
# dependency set (installed with --no-deps) when a COLAB_* environment variable exists.
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl==0.15.2 triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer
    !pip install transformers==4.51.3
    !pip install --no-deps unsloth


# Second install pass: same environment check, this time adding vLLM.
# NOTE(review): this repeats most of the first pass with different constraints
# (trl is unpinned here, datasets gains a <4.0.0 cap, huggingface_hub gains a
# >=0.34.0 floor) — confirm which set is intended and drop the other.
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    !pip install --no-deps unsloth vllm==0.8.5.post1
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    # Skip restarting message in Colab
    import sys, re, requests; modules = list(sys.modules.keys())
    # Drop already-imported PIL / google modules from sys.modules.
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer

    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    # Download vLLM's requirements file from its main branch and strip the
    # transformers / numpy / xformers lines before installing the rest.
    f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
    with open("vllm_requirements.txt", "wb") as file:
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    !pip install -r vllm_requirements.txt

In [None]:
import subprocess
import os

import urllib.request

# Download the Ballerina .deb package (the version is pinned in both the URL and the file name).
ballerina_url = "https://dist.ballerina.io/downloads/2201.12.7/ballerina-2201.12.7-swan-lake-linux-x64.deb"
deb_filename = "ballerina-2201.12.7-swan-lake-linux-x64.deb"

print("Downloading Ballerina...")
urllib.request.urlretrieve(ballerina_url, deb_filename)
print(f"✅ Downloaded {deb_filename}")

try:
    # Install the .deb package
    print("Installing Ballerina...")
    try:
        subprocess.run(["dpkg", "-i", deb_filename], check=True)
        print("✅ Ballerina installed successfully!")
    except subprocess.CalledProcessError as e:
        print(f"Installation failed: {e}")
        print("Trying to fix dependencies...")
        # -y: a notebook cell has no interactive stdin, so apt-get must not
        # stop at its confirmation prompt.
        subprocess.run(["sudo", "apt-get", "-f", "-y", "install"], check=True)

    # Test Ballerina version
    print("Testing Ballerina installation...")
    try:
        result = subprocess.run(["bal", "-v"], capture_output=True, text=True, check=True)
        print("✅ Ballerina version:")
        print(result.stdout)
    except subprocess.CalledProcessError as e:
        print(f"❌ Failed to run 'bal -v': {e}")
    except FileNotFoundError:
        print("❌ 'bal' command not found. Installation may have failed.")
finally:
    # Clean up the downloaded file even when one of the steps above raised.
    if os.path.exists(deb_filename):
        os.remove(deb_filename)
        print(f"🧹 Cleaned up {deb_filename}")

In [None]:
import os

# Read credentials from the environment. The secret values themselves are never
# printed: notebook outputs are stored with the file and are easy to share by accident.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Accept the original lower-case variable name as well as the conventional upper-case one.
WANDB_API_KEY = os.environ.get("wandb_api_key") or os.environ.get("WANDB_API_KEY")

print(f"HF_TOKEN set: {bool(HF_TOKEN)}")
print(f"WANDB_API_KEY set: {bool(WANDB_API_KEY)}")

# os.environ only accepts strings; assigning None would raise TypeError.
if WANDB_API_KEY is not None:
    os.environ["WANDB_API_KEY"] = WANDB_API_KEY

In [None]:
import subprocess
import tempfile
import os
import re
import requests
from pathlib import Path
from uuid import uuid4
from typing import Optional, List

# Central run configuration.
# Within the cells visible in this notebook: "dataset_url" is what load_gist_dataset
# is called with, "run_name" prefixes the names built by generate_model_name, and
# "num_train_epochs" is passed to the SFT TrainingArguments.
# NOTE(review): the remaining keys are not read by the cells visible here —
# presumably they are consumed by a later training stage; confirm.
CONFIG = {
    "dataset_url": "https://gist.githubusercontent.com/xlight05/860d56e432adbbcf5428aca45382c2d1/raw/ff442cbd0d3509a16940b28c4e200e554029d7de/combined.json",  # Replace with actual Gist URL
    # "model_name": "xlight05/bal_coder_full_16bit_vllm",
    "model_name" : "unsloth/Qwen2.5-Coder-7B-Instruct",
    "max_seq_length": 2048,
    "lora_rank": 16,
    "learning_rate": 5e-6,
    "max_steps": 10, # change
    "save_steps": 5, # change
    "num_generations": 2,
    "batch_size": 2,
    "gradient_accumulation_steps": 1,
    "num_train_epochs" : 1,
    "run_name": "base_test_1_"
}

def generate_model_name(training_type: str, format_type: str, bits: str = None) -> str:
    """
    Build the repository id used when exporting a trained model.

    The id has the shape ``xlight05/<run_name><training_type>[_<bits>]_<format_type>``,
    where ``run_name`` is read from ``CONFIG`` at call time.

    Args:
        training_type: "sft" or "grpo"
        format_type: "vllm" or "gguf"
        bits: "4bit", "8bit", "16bit", or None for GGUF

    Returns:
        The formatted model name string.
    """
    run_name = CONFIG["run_name"]

    # A falsy `bits` (None or "") is simply left out of the suffix.
    if bits:
        pieces = [training_type, bits, format_type]
    else:
        pieces = [training_type, format_type]

    return f"xlight05/{run_name}{'_'.join(pieces)}"

# Helper function to update Gist URL
def set_dataset_url(gist_url: str):
    """Store ``gist_url`` under ``CONFIG["dataset_url"]`` and report the new value."""
    CONFIG.update({"dataset_url": gist_url})
    message = f"Dataset URL updated to: {gist_url}"
    print(message)

print("✅ Configuration loaded!")
print(f"Model: {CONFIG['model_name']}")
print(f"Dataset URL: {CONFIG['dataset_url']}")
print(f"Run name: {CONFIG['run_name']}")

# # Test the model name generation
# print("\nExample model names:")
# print(f"SFT 16bit VLLM: {generate_model_name('sft', 'vllm', '16bit')}")
# print(f"SFT 8bit VLLM: {generate_model_name('sft', 'vllm', '8bit')}")
# print(f"SFT GGUF: {generate_model_name('sft', 'gguf')}")
# print(f"GRPO 16bit VLLM: {generate_model_name('grpo', 'vllm', '16bit')}")
# print(f"GRPO GGUF: {generate_model_name('grpo', 'gguf')}")

In [None]:
from unsloth import FastLanguageModel
import torch
# Take the shared hyperparameters from CONFIG so this cell cannot silently drift
# from the values printed by the configuration cell. The module-level names are
# kept because later cells (e.g. the SFT trainer) read `max_seq_length`.
max_seq_length = CONFIG["max_seq_length"]
dtype = None
load_in_4bit = True
lora_rank = CONFIG["lora_rank"]

# Load the base model and tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = CONFIG["model_name"],
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
)

# Attach LoRA adapters to the attention and MLP projection layers.
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = lora_rank * 2,  # alpha is kept at twice the rank (16*2 in the original cell)
    lora_dropout = 0.1,
    bias = "none",
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
    use_rslora = False,
    loftq_config = None,
)


In [None]:
from unsloth.chat_templates import get_chat_template

tokenizer = get_chat_template(
    tokenizer,
    chat_template = "qwen-2.5",
)

def formatting_prompts_func(examples):
    """Render every conversation of a batch to one chat-template string.

    Reads ``examples["conversations"]`` and returns ``{"text": [...]}`` with one
    rendered string per conversation (not tokenized, no generation prompt).
    """
    rendered = []
    for convo in examples["conversations"]:
        rendered.append(
            tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False)
        )
    return {"text": rendered}

In [None]:
import json
import requests

from datasets import Dataset

# System prompt for the test-driven samples: asks for an overview, the solution inside
# a <CODE> block and unit tests inside a <TESTS> block. The reward code later looks
# for exactly these tags, so keep the tag names in sync with the extraction helpers.
TDD_SYSTEM_PROMPT = """You are a pragmatic Ballerina programmer who enjoys test driven development. Given the following question, write a Ballerina function to complete the task and then write the unit tests to validate the function.

1. Make the code simple and easy to understand.
2. Try to limit library usage to the standard library. Be careful with your types, and try to limit yourself to the basic built in types and standard library functions.
3. Before you start writing the function you can think through how to solve the problem and perform reasoning in the comments above the function.
4. Then write unit tests for the function you defined. Make sure to write at least 4 assertions to test the function. The tests should be simple.

Strictly follow the following output format for each response: Make sure to include code inside <CODE> and <TESTS> blocks.

# Overview
Brief overview about the solution.

<CODE>
```ballerina
// Reasoning goes here
// and can be multi-line
function add(int a, int b) returns int {
    return a + b;
}
```
</CODE>

<TESTS>
```ballerina
import ballerina/test;

@test:Config { }
function testAssertEquals() {
    int addResult = add(40, 2);
    test:assertEquals(addResult, 42);

    addResult = add(0, 0);
    test:assertEquals(addResult, 0);

    addResult = add(-1, 1);
    test:assertEquals(addResult, 0);

    addResult = add(-5, -5);
    test:assertEquals(addResult, -10);
}
```
</TESTS>

"""

# System prompt for the code-only samples: the answer goes inside a single <CODE> block.
# (The literal previously opened with four quotes, which put a stray '"' at the
# very start of the prompt text.)
BBE_SYSTEM_PROMPT = """You are a pragmatic Ballerina programmer. Given the following question, write the code to complete the task.
Strictly follow the following output format for each response: Make sure to include code inside <CODE> blocks.

<CODE>
```ballerina
// Reasoning goes here
// and can be multi-line
function add(int a, int b) returns int {
    return a + b;
}
```
</CODE>
"""

# Define the gist URLs with their corresponding system prompts.
# Each entry pairs one raw-gist JSON URL ("url") with the system prompt
# ("system_prompt") that is attached to every example loaded from it.
gist_configs = [
    {
        "url": "https://gist.githubusercontent.com/xlight05/f8e1e94c7b65c2e34dac70bb27f04f0b/raw/2705c986f2e57c44085360ec9bd0258b23347ce6/bbe_train.json",
        "system_prompt": BBE_SYSTEM_PROMPT
    },
    {
        "url": "https://gist.githubusercontent.com/xlight05/67fcc85b8b549b7919772bc43e9c2fc5/raw/e034ae3f29adb58fb10b477d6bfaa1f4575340ba/tdd_train.json",
        "system_prompt": TDD_SYSTEM_PROMPT
    }
]


In [None]:
# Load and combine data from all gists
combined_data = []
for config in gist_configs:
    # Download the gist content. A timeout keeps the cell from hanging forever, and
    # raise_for_status fails fast instead of trying to JSON-decode an HTTP error page
    # (same handling as load_gist_dataset).
    response = requests.get(config["url"], timeout=30)
    response.raise_for_status()

    # Parse the JSON payload (a list of example dictionaries is expected).
    data = response.json()

    # Add system prompt information to each item
    for item in data:
        item["system_prompt"] = config["system_prompt"]

    combined_data.extend(data)

# Create a dataset from the combined list of dictionaries
dataset = Dataset.from_list(combined_data)

# Report the number of combined examples.
print(len(dataset))

# Last expression of the cell: the notebook displays the first example's answer.
dataset[0]['answer']



In [None]:

# Function to create the 'conversations' field
def create_conversations_field(examples):
    """Turn a batch of rows into system/user/assistant chat transcripts.

    ``examples`` is a batched mapping (column name -> list of values) with the
    columns 'prompt', 'answer' and 'system_prompt'. Returns a mapping holding the
    new 'conversations' column: one three-message list per input row.
    """
    rows = zip(examples['prompt'], examples['answer'], examples['system_prompt'])
    return {
        "conversations": [
            [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": answer},
            ]
            for prompt, answer, system_prompt in rows
        ]
    }

# Function to apply the chat template to the conversations field
def formatting_prompts_func(examples):
    """Apply the tokenizer's chat template to the batched 'conversations' column.

    Returns ``{"text": [...]}`` with one rendered (untokenized) string per
    conversation; no generation prompt is appended.
    """
    def render(convo):
        return tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False)

    return {"text": list(map(render, examples["conversations"]))}

# Apply the first mapping to create the 'conversations' field (batched=True is more efficient)
dataset_with_conversations = dataset.map(create_conversations_field, batched=True, )

# Apply the second mapping to create the 'text' field from 'conversations' (batched=True is also efficient here)
dataset = dataset_with_conversations.map(formatting_prompts_func, batched=True, )

# The dataset now has both "conversations" and "text" fields; show the first example once.
print(dataset[0])


In [None]:
"""<a name="Train"></a>
### Train the model
Now let's use Hugging Face TRL's `SFTTrainer`! More docs here: [TRL SFT docs](https://huggingface.co/docs/trl/sft_trainer). This run trains for `CONFIG['num_train_epochs']` epoch(s); for a quick smoke test you can instead set `max_steps` to a small number (e.g. 60). We also support TRL's `DPOTrainer`!

The trainer includes our **gradient accumulation bug fix**. Read more about it here: [Blog post](https://unsloth.ai/blog/gradient)
"""

from trl import SFTTrainer
from transformers import TrainingArguments, DataCollatorForSeq2Seq
from unsloth import is_bfloat16_supported

# Supervised fine-tuning on the chat-template-rendered "text" column of `dataset`.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
    dataset_num_proc = 4,
    # NOTE(review): packing is enabled while train_on_responses_only is applied to
    # this trainer right after construction — confirm the response masking behaves
    # as intended on packed sequences.
    packing = True,
    args = TrainingArguments(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 8,  # 1 x 8 = 8 sequences per optimizer step per device
        # NOTE(review): no max_steps is set here; confirm one epoch gives comfortably
        # more than 50 optimizer steps, otherwise the run never leaves warmup.
        warmup_steps = 50,
        num_train_epochs = CONFIG['num_train_epochs'],
        learning_rate = 1e-4,  # set directly here; CONFIG["learning_rate"] is not read by this cell
        # Use bf16 when the hardware supports it, otherwise fall back to fp16.
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 10,
        optim = "paged_adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "wandb",
    ),
)

"""We also use Unsloth's `train_on_responses_only` method to train only on the assistant outputs and ignore the loss on the user's inputs."""

from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
    trainer,
    instruction_part = "<|im_start|>user\n",
    response_part = "<|im_start|>assistant\n",
)

"""We verify masking is actually done:"""

# Pick the sample to inspect; clamp the index so a very small training set does
# not raise IndexError.
sample_index = min(5, len(trainer.train_dataset) - 1)
sample = trainer.train_dataset[sample_index]

# Print the full sample first. Only the last expression of a cell is displayed,
# so without print() this decode result would be discarded.
print(tokenizer.decode(sample["input_ids"]))

# Then print the labels with every masked position (-100) replaced by a space
# token, so only the tokens that contribute to the loss remain readable.
space = tokenizer(" ", add_special_tokens = False).input_ids[0]
print(tokenizer.decode([space if x == -100 else x for x in sample["labels"]]))


In [None]:
"""Train using SFT"""

trainer_stats = trainer.train()

In [None]:
"""<a name="Inference"></a>
### Inference
Let's run the model! You can change the instruction and input - leave the output blank!



We use `min_p = 0.1` and `temperature = 1.5`. Read this [Tweet](https://x.com/menhguin/status/1826132708508213629) for more information on why.
"""

from unsloth.chat_templates import get_chat_template

tokenizer = get_chat_template(
    tokenizer,
    chat_template = "qwen-2.5",
)
FastLanguageModel.for_inference(model) # Enable native 2x faster inference

messages = [
    {"role": "user", "content": "Continue the fibonnaci sequence: 1, 1, 2, 3, 5, 8,"},
]
# Tokenize the chat with the generation prompt appended and move the ids to the GPU.
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize = True,
    add_generation_prompt = True, # Must add for generation
    return_tensors = "pt",
).to("cuda")

# Generate at most 64 new tokens with the sampling settings described in the
# markdown above (temperature 1.5, min_p 0.1).
outputs = model.generate(input_ids = inputs, max_new_tokens = 64, use_cache = True,
                         temperature = 1.5, min_p = 0.1)
# Last expression of the cell: the notebook displays the decoded sequences.
tokenizer.batch_decode(outputs)

In [None]:
"""<a name="Save"></a>
### Saving, loading finetuned models
To save the final model as LoRA adapters, either use Huggingface's `push_to_hub` for an online save or `save_pretrained` for a local save.

**[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!
"""

# model.save_pretrained("bal_coder_lora2")  # Local saving
# tokenizer.save_pretrained("bal_coder_lora2")
# model.push_to_hub("xlight05/bal_coder_lora2_int", token = HF_TOKEN) # Online saving
# tokenizer.push_to_hub("xlight05/bal_coder_lora2_int", token = HF_TOKEN) # Online saving

"""### Saving to float16 for VLLM

We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens.
"""

# Generate model names based on config
sft_4bit_vllm_name = generate_model_name('sft', 'vllm', '4bit')
sft_8bit_vllm_name = generate_model_name('sft', 'vllm', '8bit')
sft_16bit_vllm_name = generate_model_name('sft', 'vllm', '16bit')
sft_gguf_name = generate_model_name('sft', 'gguf')

# print(f"Generated model names:")
# print(f"4bit VLLM: {sft_4bit_vllm_name}")
# print(f"8bit VLLM: {sft_8bit_vllm_name}")
# print(f"16bit VLLM: {sft_16bit_vllm_name}")
# print(f"GGUF: {sft_gguf_name}")

# Merge to 4bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_4bit",)
if False: model.push_to_hub_merged(sft_4bit_vllm_name, tokenizer, save_method = "merged_4bit", token = HF_TOKEN)

# # Merge to 16bit
# if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
# if False: model.push_to_hub_merged(sft_16bit_vllm_name, tokenizer, save_method = "merged_16bit", token = HF_TOKEN)

# # Just LoRA adapters
# if False: model.save_pretrained_merged("model", tokenizer, save_method = "lora",)
# if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "lora", token = "")

# "merged_8bit" is not among the save methods this notebook documents for
# push_to_hub_merged ("merged_16bit", "merged_4bit", or "lora" as a fallback).
# An unsupported save_method aborts the cell, which would also skip the 16-bit
# and GGUF uploads below. The call is kept, disabled, for reference; the GGUF
# export further down defaults to q8_0 if an 8-bit artifact is needed.
if False: model.push_to_hub_merged(sft_8bit_vllm_name, tokenizer, save_method = "merged_8bit", token = HF_TOKEN)

# Merge the LoRA weights into float16 and upload (the format used for vLLM).
if True: model.push_to_hub_merged(sft_16bit_vllm_name, tokenizer, save_method = "merged_16bit", token = HF_TOKEN)


"""### GGUF / llama.cpp Conversion
To save to `GGUF` / `llama.cpp`, we support it natively now! We clone `llama.cpp` and we default save it to `q8_0`. We allow all methods like `q4_k_m`. Use `save_pretrained_gguf` for local saving and `push_to_hub_gguf` for uploading to HF.

Some supported quant methods (full list on our [Wiki page](https://github.com/unslothai/unsloth/wiki#gguf-quantization-options)):
* `q8_0` - Fast conversion. High resource use, but generally acceptable.
* `q4_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K.
* `q5_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K.

[**NEW**] To finetune and auto export to Ollama, try our [Ollama notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
"""

# Save to 8bit Q8_0
# if True: model.save_pretrained_gguf("model", tokenizer,)
# Remember to go to https://huggingface.co/settings/tokens for a token!
# And change hf to your username!
if True: model.push_to_hub_gguf(sft_gguf_name, tokenizer, token = HF_TOKEN)



In [None]:
torch.cuda.empty_cache()
import gc
gc.collect()

In [None]:
!

In [None]:
# Define the code block delimiters
code_start = "<CODE>"
code_end = "</CODE>"
test_start = "<TESTS>"
test_end = "</TESTS>"
import difflib

import os
import subprocess
from typing import Any, Dict, List, Optional, Tuple


class BallerinaManager:
    """Runs the ``bal`` CLI for one project directory and reports the outcome.

    ``get_build_status`` and ``get_test_status`` never raise: timeouts, a missing
    CLI and any other error are converted into a result dictionary with
    ``success`` False and ``return_code`` -1.
    """

    def __init__(self, project_path: str = "."):
        # Directory used as the working directory for every ``bal`` invocation.
        self.project_path = project_path

    def _run_bal(self, subcommand: str, timeout: int, label: str,
                 with_test_results: bool) -> Dict[str, Any]:
        """Run ``bal <subcommand> --offline`` and build the status dictionary.

        Args:
            subcommand: "build" or "test".
            timeout: Seconds to wait before giving up.
            label: Capitalised name used in timeout messages ("Build" / "Test").
            with_test_results: Whether the result carries a "test_results" entry.
        """
        def failure(stderr: str, error: str) -> Dict[str, Any]:
            # Shared shape of every failure result; a fresh dict is built per call.
            status: Dict[str, Any] = {
                "success": False,
                "return_code": -1,
                "stdout": "",
                "stderr": stderr,
            }
            if with_test_results:
                status["test_results"] = {"passed": 0, "failed": 0, "total": 0}
            status["compilation_errors"] = [error]
            return status

        try:
            result = subprocess.run(
                ["bal", subcommand, "--offline"],
                cwd=self.project_path,
                capture_output=True,
                text=True,
                timeout=timeout
            )

            # Diagnostics may show up on either stream, so both are scanned together.
            combined = result.stdout + result.stderr
            status: Dict[str, Any] = {
                "success": result.returncode == 0,
                "return_code": result.returncode,
                "stdout": result.stdout,
                "stderr": result.stderr,
            }
            if with_test_results:
                status["test_results"] = self._extract_test_results(combined)
            status["compilation_errors"] = self._extract_compilation_errors(combined)
            return status
        except subprocess.TimeoutExpired:
            return failure(
                f"{label} process timed out",
                f"{label} process timed out after {timeout} seconds",
            )
        except FileNotFoundError:
            return failure(
                "bal command not found",
                "Ballerina CLI not found. Please ensure Ballerina is installed and in PATH",
            )
        except Exception as e:
            return failure(str(e), f"Unexpected error: {str(e)}")

    def get_build_status(self) -> Dict[str, Any]:
        """Run ``bal build --offline`` (60 s timeout).

        Returns:
            Dict with "success", "return_code", "stdout", "stderr" and
            "compilation_errors" (list of matching output lines).
        """
        return self._run_bal("build", 60, "Build", with_test_results=False)

    def get_test_status(self) -> Dict[str, Any]:
        """Run ``bal test --offline`` (120 s timeout).

        Returns:
            Same keys as ``get_build_status`` plus "test_results", a dict with
            "passed", "failed" and "total" counts.
        """
        return self._run_bal("test", 120, "Test", with_test_results=True)

    def _extract_compilation_errors(self, output: str) -> List[str]:
        """Return the stripped output lines that look like error reports.

        A line is kept when it starts with 'ERROR' or contains (case-insensitively)
        'error:', 'compilation error', 'build failed' or 'error occurred'.
        """
        markers = ('error:', 'compilation error', 'build failed', 'error occurred')
        errors = []
        for raw_line in output.split('\n'):
            line = raw_line.strip()
            lowered = line.lower()
            if line.startswith('ERROR') or any(marker in lowered for marker in markers):
                errors.append(line)
        return errors

    @staticmethod
    def _count_before(line: str, keyword: str) -> Optional[int]:
        """Return the integer token directly preceding the first ``keyword`` token.

        Returns None when there is no such token or the preceding token is not an
        integer.
        """
        parts = line.split()
        for i, part in enumerate(parts):
            if part == keyword and i > 0:
                try:
                    return int(parts[i - 1])
                except ValueError:
                    return None
        return None

    def _extract_test_results(self, output: str) -> Dict[str, int]:
        """Parse the "<N> passing" / "<N> failing" summary lines of the test output.

        "total" is the sum of the passed and failed counts.
        """
        results = {"passed": 0, "failed": 0, "total": 0}

        for raw_line in output.split('\n'):
            line = raw_line.strip()

            # A line mentioning 'passing' is only ever read as the passing count.
            if 'passing' in line:
                count = self._count_before(line, 'passing')
                if count is not None:
                    results["passed"] = count
            elif 'failing' in line:
                count = self._count_before(line, 'failing')
                if count is not None:
                    results["failed"] = count

        results["total"] = results["passed"] + results["failed"]
        return results
    



In [None]:

from uuid import uuid4
import tempfile
from pathlib import Path

"""
Define functions for setting up and testing Ballerina projects.
"""

def create_ballerina_toml(package_name: str) -> str:
    """Return the contents of ``Ballerina.toml`` for the temporary project.

    The manifest is fixed: ``package_name`` is accepted but not used, and the
    package is always declared as ``test/test_project``.
    """
    manifest_lines = (
        '[package]',
        'org = "test"',
        'name = "test_project"',
        'version = "0.1.0"',
        'distribution = "2201.12.7"',
        '',
        '[build-options]',
        'observabilityIncluded = false',
    )
    return "\n".join(manifest_lines) + "\n"

def create_main_bal(main_content: str) -> str:
    """Return the text written to ``main.bal`` (the content, formatted as a string)."""
    return format(main_content)

def create_test_bal(test_content: str) -> str:
    """Return the text written to ``tests/test.bal`` (the content, formatted as a string)."""
    return format(test_content)

def setup_build_ballerina(main_content: str, test_content: str) -> dict:
    """Write a throw-away Ballerina project and build it.

    The project (Ballerina.toml, main.bal, tests/test.bal) is created inside a
    temporary directory that is removed when the function returns.

    Args:
        main_content: Source written to ``main.bal``.
        test_content: Source written to ``tests/test.bal``.

    Returns:
        Dict with "build_passed", "build_stderr", "compilation_errors" and
        "package_name". Any exception is caught and reported through the same
        keys, with "build_passed" False and "package_name" set to "unknown".
    """
    try:
        # Create temporary directory with random UUID suffix
        package_name = f"test-project-{str(uuid4())[:8]}"

        with tempfile.TemporaryDirectory() as temp_dir:
            project_dir = Path(temp_dir) / package_name
            project_dir.mkdir()
            tests_dir = project_dir / "tests"
            tests_dir.mkdir()

            # Write project files. UTF-8 is forced so generated code containing
            # non-ASCII characters is written the same way under every locale.
            (project_dir / "Ballerina.toml").write_text(create_ballerina_toml(package_name), encoding="utf-8")
            (project_dir / "main.bal").write_text(create_main_bal(main_content), encoding="utf-8")
            (tests_dir / "test.bal").write_text(create_test_bal(test_content), encoding="utf-8")

            # Use BallerinaManager to get build status
            ballerina_manager = BallerinaManager(str(project_dir))
            build_result = ballerina_manager.get_build_status()

            return {
                "build_passed": build_result["success"],
                "build_stderr": build_result["stderr"],
                "compilation_errors": build_result["compilation_errors"],
                "package_name": package_name
            }
    except Exception as e:
        print(f"Error setting up Ballerina project: {e}")
        return {
            "build_passed": False,
            "build_stderr": f"Project setup error: {e}",
            "compilation_errors": [f"Project setup error: {e}"],
            "package_name": "unknown"
        }

def setup_build_test_ballerina(main_content: str, test_content: str) -> dict:
    """Write a throw-away Ballerina project, build it and, if that works, run its tests.

    The project (Ballerina.toml, main.bal, tests/test.bal) is created inside a
    temporary directory that is removed when the function returns. Tests are
    only executed when the build succeeds.

    Args:
        main_content: Source written to ``main.bal``.
        test_content: Source written to ``tests/test.bal``.

    Returns:
        Dict with "build_passed", "build_stderr", "build_compilation_errors",
        "test_passed", "test_stderr", "test_results", "test_compilation_errors"
        and "package_name". Any exception is caught and reported through the
        same keys, with both *_passed flags False and "package_name" "unknown".
    """
    try:
        # Create temporary directory with random UUID suffix
        package_name = f"test-project-{str(uuid4())[:8]}"

        with tempfile.TemporaryDirectory() as temp_dir:
            project_dir = Path(temp_dir) / package_name
            project_dir.mkdir()
            tests_dir = project_dir / "tests"
            tests_dir.mkdir()

            # Write project files. UTF-8 is forced so generated code containing
            # non-ASCII characters is written the same way under every locale.
            (project_dir / "Ballerina.toml").write_text(create_ballerina_toml(package_name), encoding="utf-8")
            (project_dir / "main.bal").write_text(create_main_bal(main_content), encoding="utf-8")
            (tests_dir / "test.bal").write_text(create_test_bal(test_content), encoding="utf-8")

            # Use BallerinaManager to get build and test status
            ballerina_manager = BallerinaManager(str(project_dir))

            # Get build status first
            build_result = ballerina_manager.get_build_status()

            # Get test status only if build succeeds
            if build_result["success"]:
                test_result = ballerina_manager.get_test_status()
            else:
                # Placeholder with the same keys a real test run provides.
                test_result = {
                    "success": False,
                    "stdout": "",
                    "stderr": "Build failed, skipping tests",
                    "test_results": {"passed": 0, "failed": 0, "total": 0},
                    "compilation_errors": []
                }

            return {
                "build_passed": build_result["success"],
                "build_stderr": build_result["stderr"],
                "build_compilation_errors": build_result["compilation_errors"],
                "test_passed": test_result["success"],
                "test_stderr": test_result["stderr"],
                "test_results": test_result["test_results"],
                "test_compilation_errors": test_result["compilation_errors"],
                "package_name": package_name
            }
    except Exception as e:
        print(f"Error setting up Ballerina project with tests: {e}")
        return {
            "build_passed": False,
            "build_stderr": f"Project setup error: {e}",
            "build_compilation_errors": [f"Project setup error: {e}"],
            "test_passed": False,
            "test_stderr": f"Project setup error: {e}",
            "test_results": {"passed": 0, "failed": 0, "total": 0},
            "test_compilation_errors": [f"Project setup error: {e}"],
            "package_name": "unknown"
        }

print("✅ Ballerina project setup functions defined!")




In [None]:

def _extract_fenced_ballerina(response: str, start_tag: str, end_tag: str) -> str:
    """Return the first ```ballerina block found between ``start_tag`` and ``end_tag``.

    Returns "" when the tag pair is missing or contains no ```ballerina fence.
    """
    # Non-greedy match, so only the first start/end pair is considered.
    tagged = re.search(rf"{re.escape(start_tag)}(.*?){re.escape(end_tag)}", response, re.DOTALL)
    if not tagged:
        return ""

    fenced = re.search(r"```ballerina\s*(.*?)\s*```", tagged.group(1).strip(), re.DOTALL)
    if not fenced:
        return ""

    return fenced.group(1).strip()

def extract_ballerina_code(response: str) -> str:
    """Extract Ballerina code from response - extracts content inside ```ballerina blocks within <CODE> tags"""
    return _extract_fenced_ballerina(response, code_start, code_end)

def extract_ballerina_tests(response: str) -> str:
    """Extract Ballerina tests from response - extracts content inside ```ballerina blocks within <TESTS> tags"""
    return _extract_fenced_ballerina(response, test_start, test_end)

def exact_ballerina_main_code(content: str) -> str:
    """Return the Ballerina source found in the <CODE> section of ``content``.

    Thin alias for ``extract_ballerina_code``; the result is "" when nothing
    could be extracted.
    """
    return extract_ballerina_code(content)

def exact_ballerina_test_code(content: str) -> str:
    """Return the Ballerina test source found in the <TESTS> section of ``content``.

    Thin alias for ``extract_ballerina_tests``; the result is "" when nothing
    could be extracted.
    """
    return extract_ballerina_tests(content)



In [None]:

def reward_build_test_content(content: str) -> float:
    """Reward a completion by how far its code gets through build and test.

    Scoring:
        * 0.0 when the code or test block is missing, or the code has no
          ``function `` in it.
        * Build fails: 1.5 / 1.0 / 0.5 for a diagnostic count of 1 / 2 / 3,
          and 0.0 for any other count.
        * Build passes: 3.0, plus 1.5 per passing test (capped at 6.0) when the
          test run itself succeeded.
        * 0.0 if anything raises along the way.

    Args:
        content: Raw completion text containing <CODE> and <TESTS> blocks.

    Returns:
        The reward, between 0.0 and 9.0.
    """
    code = exact_ballerina_main_code(content)
    tests = exact_ballerina_test_code(content)

    if not (code and tests):
        return 0.0
    if "function " not in code:
        return 0.0
    try:
        results = setup_build_test_ballerina(code, tests)
        if results["build_passed"] is False:
            # Default to an empty list (not 0) so len() is always valid.
            diags = results.get("build_compilation_errors", [])
            # One collected line is not counted as a diagnostic — presumably the
            # trailing summary line of the build output; TODO confirm.
            diag_count = len(diags) - 1
            # Fewer diagnostics earn more; zero or more than three earn nothing.
            return {1: 1.5, 2: 1.0, 3: 0.5}.get(diag_count, 0.0)

        score = 3.0
        test_score = min(6.0, 1.5 * results["test_results"]["passed"]) if results["test_passed"] else 0.0
        return score + test_score
    except Exception:
        return 0.0

def reward_code_structure_content(content: str) -> float:
    """Score how well a completion follows the expected <CODE>/<TESTS> layout.

    Points:
        * 0.5 for a non-empty code block, plus 0.5 if it contains ``function ``.
        * 0.5 for a non-empty test block, plus 0.1 each for the
          ``import ballerina/test`` line and an ``@test:Config`` annotation,
          plus 0.2 per ``test:assert`` occurrence (capped at 1.0).
    """
    score = 0.0

    # Code block and, inside it, at least one function.
    code = exact_ballerina_main_code(content)
    if code:
        score += 0.5
        score += 0.5 if "function " in code else 0.0

    # Without a test block there is nothing more to award.
    tests = exact_ballerina_test_code(content)
    if not tests:
        return score

    score += 0.5

    # Test import and test config are worth 0.1 each.
    for marker in ("import ballerina/test", "@test:Config"):
        if marker in tests:
            score += 0.1

    # Assertions: 0.2 each, capped at 1.0.
    assertion_points = min(tests.count("test:assert") * 0.2, 1.0)
    return score + assertion_points

# # Wrapper functions that maintain the original interface for GRPO
# def reward_build(completions, **kwargs) -> list[float]:
#     """Wrapper for build reward function"""
#     return [reward_build_content(completion[0]["content"]) for completion in completions]

def reward_build_test(completions, **kwargs) -> list[float]:
    """Apply the build/test reward to the first message of every completion.

    Args:
        completions: Iterable of completions; each one is indexed with [0]
            and the resulting mapping's "content" entry is scored.
        **kwargs: Accepted and ignored.

    Returns:
        One score per completion, in input order.
    """
    scores = []
    for completion in completions:
        first_message = completion[0]
        scores.append(reward_build_test_content(first_message["content"]))
    return scores

def reward_code_structure(completions, **kwargs) -> list[float]:
    """Apply the code-structure reward to the first message of every completion.

    Args:
        completions: Iterable of completions; each one is indexed with [0]
            and the resulting mapping's "content" entry is scored.
        **kwargs: Accepted and ignored.

    Returns:
        One score per completion, in input order.
    """
    scores = []
    for completion in completions:
        first_message = completion[0]
        scores.append(reward_code_structure_content(first_message["content"]))
    return scores


# Status line printed when this cell runs, i.e. after the reward function
# definitions above have been executed.
print("✅ All reward functions defined!")


In [None]:
"""
Load dataset from Gist URL or use sample data.
"""

def load_gist_dataset(gist_url: str):
    """Fetch a JSON array from ``gist_url`` and turn it into chat-style samples.

    Each array element must be a dict with a "prompt" key; anything else is
    reported and skipped. Every kept element becomes a dict with a two-message
    "prompt" (system prompt + the element's prompt as the user turn) and an
    empty "response".

    Args:
        gist_url: URL that returns the JSON array.

    Returns:
        List of formatted sample dicts.

    Raises:
        ValueError: If the decoded JSON is not a list.
        Whatever ``requests`` raises for a failed request or a non-success
        HTTP status (via ``raise_for_status``).
    """
    print(f"Fetching dataset from: {gist_url}")
    response = requests.get(gist_url, timeout=30)
    response.raise_for_status()

    # The payload has to be a top-level JSON array.
    payload = response.json()
    if not isinstance(payload, list):
        raise ValueError("Expected JSON array format")

    samples = []
    for entry in payload:
        if isinstance(entry, dict) and "prompt" in entry:
            system_turn = {"role": "system", "content": TDD_SYSTEM_PROMPT}
            user_turn = {"role": "user", "content": entry["prompt"]}
            samples.append({
                "prompt": [system_turn, user_turn],
                "response": "",  # Will be generated during training
            })
        else:
            print(f"Skipping invalid item: {entry}")

    print(f"Loaded {len(samples)} samples from Gist")
    return samples

In [None]:
print(f"Loading dataset from: {CONFIG['dataset_url']}")
dataset_samples = load_gist_dataset(CONFIG["dataset_url"])

# Pivot the list of sample dicts into one list per column, which is the
# layout Dataset.from_dict takes.
from datasets import Dataset
dataset_dict = {
    column: [sample[column] for sample in dataset_samples]
    for column in ("prompt", "response")
}
dataset = Dataset.from_dict(dataset_dict)

print(f"✅ Dataset created with {len(dataset)} samples")

In [None]:
"""
Test the chat template with a sample prompt.
"""

print("Testing chat template:")

# Render the first dataset prompt as text (no tokenization) with the
# generation prompt appended, then show it between two separator rules.
test_messages = dataset[0]["prompt"]
formatted = tokenizer.apply_chat_template(
    test_messages,
    add_generation_prompt=True,
    tokenize=False,
)
print(
    "Formatted prompt:",
    "-" * 50,
    formatted,
    "-" * 50,
    "✅ Chat template working correctly!",
    sep="\n",
)

In [None]:
"""
Configure GRPO training parameters and create trainer.
"""

# This cell builds, in order: the vLLM sampling parameters, the GRPOConfig,
# and the GRPOTrainer. It reads `tokenizer`, `model`, `CONFIG`, `dataset`,
# `reward_code_structure` and `reward_build_test`, none of which are defined
# in this cell.
from trl import GRPOConfig, GRPOTrainer
from vllm import SamplingParams

# VLLM sampling parameters
# Generation stops on the tokenizer's EOS token and the stop string is kept
# in the returned text (include_stop_str_in_output=True).
vllm_sampling_params = SamplingParams(
    min_p=0.1,
    top_p=1.0,
    top_k=-1,
    seed=3407,
    stop=[tokenizer.eos_token],
    include_stop_str_in_output=True,
)

# GRPO training configuration
# Learning rate, batch size, gradient accumulation, number of generations,
# step count and save interval all come from CONFIG; the rest is hard-coded.
# NOTE(review): max_prompt_length is 256 while max_completion_length subtracts
# 512 from CONFIG["max_seq_length"] — the two budgets do not add up to
# max_seq_length; confirm which value is intended. Prompts include
# TDD_SYSTEM_PROMPT, so presumably 256 tokens may truncate them — TODO verify.
# NOTE(review): output_dir is named "rust_grpo_outputs" although the rewards
# in this notebook target Ballerina — looks like a leftover name; confirm
# nothing else depends on this path before renaming.
training_args = GRPOConfig(
    vllm_sampling_params=vllm_sampling_params,
    temperature=1.0,
    learning_rate=CONFIG["learning_rate"],
    weight_decay=0.01,
    warmup_ratio=0.1,
    lr_scheduler_type="linear",
    optim="adamw_8bit",
    logging_steps=1,
    per_device_train_batch_size=CONFIG["batch_size"],
    gradient_accumulation_steps=CONFIG["gradient_accumulation_steps"],
    num_generations=CONFIG["num_generations"],
    max_prompt_length=256,
    max_completion_length=CONFIG["max_seq_length"] - 512,
    max_steps=CONFIG["max_steps"],
    save_steps=CONFIG["save_steps"],
    report_to="none",
    output_dir="rust_grpo_outputs",
)

print("Setting up GRPO trainer...")

# Create trainer with all reward functions matching original
# Two reward functions are registered: the structure reward and the
# build/test reward. The tokenizer is passed as the processing class.
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        reward_code_structure,
        reward_build_test,
    ],
    args=training_args,
    train_dataset=dataset,
)

print("✅ GRPO trainer configured successfully!")


In [None]:
"""
Start the GRPO training process.
WARNING: This may take a long time as each generation runs bal tools!
"""

# Announce the run up front, then hand control to the trainer.
print(
    "\n".join(
        [
            "🚀 Starting GRPO training...",
            "Note: This may take some time as each generation runs bal tools...",
            "Watch the reward values - they should increase over time!",
        ]
    )
)

# Kick off the training loop
trainer.train()

print("✅ Training completed!")

In [None]:
"""
Test the trained model with inference.
"""

print("Testing trained model...")

# Same system prompt as training, paired with a fixed user request.
test_prompt = [
    dict(role="system", content=TDD_SYSTEM_PROMPT),
    dict(role="user", content="Write a function that calculates the factorial of a number."),
]

# Render the conversation to plain text with the generation prompt appended.
text = tokenizer.apply_chat_template(test_prompt, tokenize=False, add_generation_prompt=True)

from transformers import TextStreamer
print("Generated response:", "-" * 50, sep="\n")

# Tokens are streamed to stdout as they are produced (prompt excluded);
# the return value of generate is discarded.
_ = model.generate(
    **tokenizer(text, return_tensors="pt").to("cuda"),
    temperature=0.7,
    max_new_tokens=1024,
    streamer=TextStreamer(tokenizer, skip_prompt=True),
)
print("-" * 50, "✅ Inference test completed!", sep="\n")


In [None]:

# Upload the model with save_method "merged_16bit" to the Hub.
# Flip the condition to False to skip this upload.
if True:
    model.push_to_hub_merged(
        generate_model_name('grpo', 'vllm', '16bit'),
        tokenizer,
        save_method="merged_16bit",
        token=HF_TOKEN,
    )


In [None]:
# Upload a GGUF export of the model to the Hub.
# Flip the condition to False to skip this upload.
if True:
    model.push_to_hub_gguf(
        generate_model_name('grpo', 'gguf'),
        tokenizer,
        token=HF_TOKEN,
    )


In [None]:
import os
# Ask runpodctl to stop, then terminate, the pod identified by RUNPOD_POD_ID.
# os.system runs each command through the shell, which expands
# $RUNPOD_POD_ID from the environment; the exit statuses are not checked.
# NOTE(review): presumably the kernel may go away once "stop" takes effect,
# in which case "terminate" would never run — confirm the ordering is intended.
os.system("runpodctl stop pod $RUNPOD_POD_ID")
os.system("runpodctl terminate pod $RUNPOD_POD_ID")