# Setup

## Package Installation

In [None]:
#%pip install --upgrade pip
#%pip install transformers==4.37.0
#%pip uninstall torch torchvision torchaudio -y
#%pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116 -y
#%pip install torch torchvision torchaudio
#%pip install tqdm
#%pip install numpy==1.24 #probably not needed, leave this commented
#%pip install urllib3==1.26.15
#%pip install accelerate==0.25.0
#%pip install datasets

In [2]:
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM

import os
import gc
import re
import json
import logging
from tqdm import tqdm
from datasets import load_dataset

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

print(torch.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('We are using the device {}.'.format(device))
if torch.cuda.is_available():
    print(f"Device count: {torch.cuda.device_count()}")
    print(f"Device name: {torch.cuda.get_device_name(0)}")

1.13.1+cu116
We are using the device cuda.
Device count: 1
Device name: NVIDIA A100-SXM4-40GB


## Utils

In [3]:
def clear_gpu_memory():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

In [4]:
# Display total GPU memory
print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

# Display currently allocated memory
print(f"Currently allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")

# Display cached memory (reserved by PyTorch but not used)
print(f"Cached: {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")

Total GPU memory: 39.39 GB
Currently allocated: 0.00 GB
Cached: 0.00 GB


In [5]:
# # Check disk space
# !df -h

# Data Preparation

## Prompt Templates

In [None]:
COT_PROMPT_TEMPLATE = """Generate a detailed step-by-step solution for this coding problem.
Break down your thought process into clear, concise steps, explaining your reasoning at each stage.

Problem:
{problem}

Step-by-step solution:"""

In [None]:
COT_PROMPT_TEMPLATE = """Generate a detailed step-by-step solution for this coding problem.
Break down your thought process clearly, explaining your reasoning while considering:
- What are the inputs and outputs of the function?
- What algorithm or data structure is most appropriate?
- Are there any edge cases to handle?
- What's the efficiency of your approach?

Be thorough but concise in your explanation.

Problem:
{problem}

Step-by-step solution:"""

In [None]:
CODER_PROMPT_TEMPLATE = """Generate the python code for this coding problem. Follow the
step-by-step process as a guideline for how to solve the problem. Only return python code.

Step-by-step solution:
{cot_solution}

Python code:"""

In [None]:
DEBUGGER_PROMPT_TEMPLATE = """Check the provided python code for any errors. Then regenerate
the code so that any errors have been debugged.

Python code:
{gen_code}

Debugged Python code:"""

In [None]:
EXPLAINER_PROMPT_TEMPLATE = """Generate an explanation of this code using the step-by-step
solution and the code itself. Keep the explanation concise.

Step-by-step solution:
{cot_solution}

Python code:
{gen_code}

Explanation of the code:"""

## Dataset Modules

In [6]:
class CodeCraftDataset(Dataset):
    """
    A generalized dataset for Code Craft agents that works with various prompt templates.

    Args:
        examples: List of dictionaries that hold all agent prompt information.
        tokenizer: Used to tokenize the inputs to the model.
        prompt_template: The prompt template string with placeholders.
        output_field: The name of the field in examples that contains the expected output.
        max_length: The maximum token length of the inputs.
    """
    def __init__(self, examples, tokenizer, prompt_template, output_field, max_length=512):
        self.examples = examples
        self.tokenizer = tokenizer
        self.prompt_template = prompt_template
        self.output_field = output_field
        self.max_length = max_length

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        example = self.examples[idx]
        output = example[self.output_field]

        # Create prompt by formatting template with example data
        # This will use all fields from the example that match placeholders in the template
        try:
            prompt = self.prompt_template.format(**example)
        except KeyError as e:
            missing_key = str(e).strip("'")
            raise KeyError(f"Example at index {idx} is missing required field '{missing_key}' "
                          f"for prompt template: {self.prompt_template}")

        # Combine prompt with expected output
        full_text_with_output = prompt + output

        # Tokenize the combined text
        encoded = self.tokenizer(
            full_text_with_output,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # Create labels (same as input_ids but with -100 for prompt tokens)
        prompt_tokens = self.tokenizer(prompt, return_tensors="pt")["input_ids"][0]
        prompt_length = len(prompt_tokens)

        labels = encoded["input_ids"].clone()
        labels[0, :prompt_length] = -100  # Don't compute loss for prompt tokens

        result = {
            "input_ids": encoded["input_ids"][0],
            "attention_mask": encoded["attention_mask"][0],
            "labels": labels[0]
        }

        return result

In [14]:
def generate_dataset(problem_dataset, task_prompt, solution_field, output_marker,
    model, tokenizer, num_examples=50, max_new_tokens=512, teacher=True, regen=False,
    output_dir="dataset"):
    """
    Generate a dataset by prompting a teacher model to solve problems for distillation.

    Args:
        problem_dataset: List of dictionaries containing problem data
        task_prompt: Prompt template string with placeholders
        solution_field: Field name for the generated solution in output examples
        output_marker: String marker after which the solution starts in the model output
                       (or None if the entire output is the solution)
        model: The model used to generate solutions
        tokenizer: Tokenizer for the model
        num_examples: Number of examples to generate
        max_new_tokens: Maximum token length for generation
        teacher: a flag indicating if the model is teacher (true) or student (f)
        regen: a flag indicating if the data should be regenerated if it already exists
        output_dir: Directory to save the generated examples

    Returns:
        List of dictionaries containing the problems and their solutions
    """
    os.makedirs(output_dir, exist_ok=True)
    
    # Get the model type from the teacher param
    if teacher:
        model_name = "teacher"
    else:
        model_name = "student"

    # If indicated not to regenerate the examples and they exist then return them
    file_name = os.path.join(output_dir, f"{solution_field}_{model_name}_{num_examples}_dataset.json")
    if regen and os.path.exists(file_name):
        with open(file_name, 'r') as examples_file:
            examples = json.load(examples_file)
        return examples

    examples = []
    logger.info(f"Generating {solution_field} with {model_name} for {num_examples} problems...")

    # Take a subset of problems for efficiency
    problems_subset = problem_dataset[:num_examples]

    for i, problem in enumerate(tqdm(problems_subset, desc=f"Generating {solution_field}")):
        try:
            # Format the prompt with the problem data
            prompt = task_prompt.format(**problem)

            # Tokenize the prompt
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

            # Generate the solution from the model
            model.eval()
            with torch.no_grad():
                output = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=0.7,
                    do_sample=True,
                    top_p=0.9,
                    num_return_sequences=1
                )

            # Decode the model output
            generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

            # Extract the solution portion if an output marker is provided
            if output_marker and output_marker in generated_text:
                solution_start_idx = generated_text.find(output_marker) + len(output_marker)
                solution = generated_text[solution_start_idx:].strip()
            else:
                # Use the entire output if no marker is provided or found
                solution = generated_text.replace(prompt, "").strip()

            # Create the example with all original problem fields plus the solution
            example = problem.copy()  # Preserve all original fields
            example[solution_field] = solution  # Add the generated solution
            examples.append(example)

            # Save a few examples for inspection
            if i < 2:
                print(f"\nExample {i+1}:")
                print(f"Problem: {example['problem'][:150]}...")
                print(f"Solution (first 150 chars): {example[solution_field][:150]}...")

            # Log progress details periodically
            if (i + 1) % 10 == 0:
                logger.info(f"Generated {i + 1}/{len(problems_subset)} solutions")

        except Exception as e:
            logger.error(f"Error generating solution for problem {i}: {e}")
            continue

    logger.info(f"Successfully generated {len(examples)} {solution_field} solutions")

    # Save the dataset
    with open(file_name, "w") as f:
        json.dump(examples, f, indent=2)

    logger.info(f"Dataset saved to {file_name}")
    return examples

## Load Dataset Functions

In [None]:
# Load MBPP dataset
def load_mbpp_dataset():
    mbpp = load_dataset("mbpp")

    train_problems = []
    # Extract problems from the MBPP dataset with correct field names
    for item in mbpp["train"]:
        train_problems.append({
            "problem": item["text"],
            "test_case": item["test_list"],
            "solution_code": item["code"]
        })

    test_problems = []
    for item in mbpp["test"]:
        test_problems.append({
            "problem": item["text"],
            "test_case": item["test_list"],
            "solution_code": item["code"]
        })

    print(f"Loaded {len(train_problems)} train problems and {len(test_problems)} evaluation problems from MBPP dataset")
    return train_problems, test_problems

In [None]:
# Load BAAI/TACO dataset
def load_taco_dataset():
    taco = load_dataset("BAAI/TACO")

    train_problems = []
    for item in taco["train"]:
        train_problems.append({
            "problem": item["question"],
            "test_case": item["input_output"],
            "solution_code": item["solutions"][0]
        })

    test_problems = []
    for item in taco["test"]:
        train_problems.append({
            "problem": item["question"],
            "test_case": item["test_cases"],
            "solution_code": item["solutions"][0]
        })

    print(f"Loaded {len(train_problems)} train problems and {len(test_problems)} test problems from TACO dataset")
    return train_problems, test_problems

# Agent Code

## Models

In [None]:
# Load models
def load_models(teacher_model_name, student_model_name):
    logger.info(f"Loading teacher model: {teacher_model_name}")
    teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
    teacher_model = AutoModelForCausalLM.from_pretrained(
        teacher_model_name,
        device_map="auto",
        torch_dtype=torch.float32
    )
    logger.info(f"Teacher model loaded successfully")

    logger.info(f"Loading student model: {student_model_name}")
    student_tokenizer = AutoTokenizer.from_pretrained(student_model_name)
    student_model = AutoModelForCausalLM.from_pretrained(
        student_model_name,
        device_map="auto",
        torch_dtype=torch.float32
    )
    logger.info(f"Student model loaded successfully")

    return teacher_model, teacher_tokenizer, student_model, student_tokenizer

## Training

In [None]:
def fine_tune_student_model(student_model, student_tokenizer, train_data, prompt,
                        output_field, batch_size=8, num_epochs=3, learning_rate=5e-5,
                        max_grad_norm=1.0, warmup_steps=0, max_length=512,
                        output_dir="results"):
    """
    Fine-tune the student model on examples generated by the teacher model.

    Args:
        student_model: The student model to train
        student_tokenizer: Tokenizer for the student model
        train_data: List of data dictionaries for training
        prompt: The prompt containing fields for training
        output: The output data field to train on
        batch_size: Training batch size
        num_epochs: Number of training epochs
        learning_rate: Learning rate for the optimizer
        max_grad_norm: Maximum gradient norm for gradient clipping
        warmup_steps: Linear warmup steps for the learning rate scheduler
        max_length: the maximum number of tokens in the dataset values
        output_dir: Directory to save the trained model
    """
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Starting training the student model for {num_epochs} epochs")

    # Create PyTorch dataset and dataloader
    dataset = CodeCraftDataset(
        examples=train_data,
        tokenizer=student_tokenizer,
        prompt_template=prompt,
        output_field=output_field,
        max_length=max_length
    )
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True
    )

    # Set up optimizer and learning rate scheduler
    optimizer = optim.AdamW(student_model.parameters(), lr=learning_rate)
    total_steps = len(dataloader) * num_epochs
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=learning_rate, total_steps=total_steps,
        pct_start=warmup_steps/total_steps if warmup_steps > 0 else 0.1
    )

    # Set up training tracking
    best_loss = float('inf')
    global_step = 0
    student_model.train()

    # Training loop
    for epoch in range(num_epochs):
        epoch_loss = 0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for batch in progress_bar:
            # Move batch to device
            input_ids = batch["input_ids"].to(student_model.device)
            attention_mask = batch["attention_mask"].to(student_model.device)
            labels = batch["labels"].to(student_model.device)

            # Forward pass - compute student model outputs
            outputs = student_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            loss = outputs.loss

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_grad_norm)

            # Update parameters
            optimizer.step()
            scheduler.step()

            # Track loss
            epoch_loss += loss.item()
            global_step += 1

            # Update progress bar
            progress_bar.set_postfix({"loss": loss.item()})

            # Save checkpoint occasionally
            if global_step % 100 == 0:
                logger.info(f"Step {global_step}: loss = {loss.item():.4f}")

        # Compute average epoch loss
        avg_epoch_loss = epoch_loss / len(dataloader)
        logger.info(f"Epoch {epoch+1}/{num_epochs} - Average loss: {avg_epoch_loss:.4f}")

        # Save checkpoint if it's the best model so far
        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            checkpoint_path = os.path.join(output_dir, f"student_model_{output_field}_epoch_{epoch+1}")
            logger.info(f"Saving best model so far (loss: {best_loss:.4f}) to {checkpoint_path}")
            student_model.save_pretrained(checkpoint_path)
            student_tokenizer.save_pretrained(checkpoint_path)

    # Save final model
    final_model_path = os.path.join(output_dir, f"student_model_{output_field}_final")
    logger.info(f"Training completed. Saving final model to {final_model_path}")
    student_model.save_pretrained(final_model_path)
    student_tokenizer.save_pretrained(final_model_path)

    return student_model, student_tokenizer

In [None]:
def logit_distillation_loss(student_logits, teacher_logits, temperature=1.0, alpha=0.5):
    """
    Calculate the knowledge distillation loss between student and teacher logits.

    Args:
        student_logits: Logits from the student model [batch_size, seq_len, vocab_size]
        teacher_logits: Logits from the teacher model [batch_size, seq_len, vocab_size]
        temperature: Temperature parameter to soften the distributions
        alpha: Weight for the distillation loss (1-alpha for the regular CE loss)

    Returns:
        The distillation loss
    """
    # Apply temperature scaling
    student_logits_scaled = student_logits / temperature
    teacher_logits_scaled = teacher_logits / temperature

    # Convert logits to probabilities
    student_probs = F.softmax(student_logits_scaled, dim=-1)
    teacher_probs = F.softmax(teacher_logits_scaled, dim=-1)

    # Calculate KL divergence loss
    kl_div = F.kl_div(
        F.log_softmax(student_logits, dim=-1),
        F.softmax(teacher_logits, dim=-1, dtype=torch.float32),  # Specify dtype
        reduction='batchmean',
        log_target=False  # Teacher probs are not in log space
    )

    return loss

In [None]:
def train_with_logit_distillation(
    model, train_dataloader, optimizer, scheduler=None,
    num_epochs=3, device="cuda", alpha=0.5, temperature=2.0,
    max_grad_norm=1.0):
    """
    Train a model using logit distillation.

    Args:
        model: The student model to train
        train_dataloader: DataLoader containing training examples with teacher logits
        optimizer: Optimizer for training
        scheduler: Learning rate scheduler (optional)
        num_epochs: Number of training epochs
        device: Device to use for training
        alpha: Weight for distillation loss vs standard cross-entropy loss
        temperature: Temperature for softening logit distributions
        max_grad_norm: Maximum gradient norm for clipping

    Returns:
        Trained model and training losses
    """
    model.train()
    losses = []

    # Create cross entropy loss for regular training
    ce_loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)

    for epoch in range(num_epochs):
        epoch_losses = []
        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for batch in progress_bar:
            # Move batch to device
            model_device = next(model.parameters()).device
            input_ids = batch["input_ids"].to(model_device)
            attention_mask = batch["attention_mask"].to(model_device)
            labels = batch["labels"].to(model_device)

            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
                output_hidden_states=True
            )

            # Standard cross-entropy loss from labels
            ce_loss = outputs.loss

            # Get student logits
            student_logits = outputs.logits

            # Extract teacher logits if available and calculate distillation loss
            total_loss = ce_loss
            if "teacher_logits" in batch:
                teacher_logits = batch["teacher_logits"].to(device)

                # Make sure teacher_logits has the same shape as student_logits
                if teacher_logits.shape != student_logits.shape:
                    # Handle different sequence lengths if needed
                    min_len = min(teacher_logits.shape[1], student_logits.shape[1])
                    teacher_logits = teacher_logits[:, :min_len, :]
                    student_logits = student_logits[:, :min_len, :]

                # Calculate distillation loss
                kd_loss = logit_distillation_loss(
                    student_logits,
                    teacher_logits,
                    temperature=temperature
                )

                # Combine losses
                total_loss = (1 - alpha) * ce_loss + alpha * kd_loss

            # Backward pass
            optimizer.zero_grad()
            total_loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            # Update parameters
            optimizer.step()
            if scheduler is not None:
                scheduler.step()

            # Track loss
            epoch_losses.append(total_loss.item())
            progress_bar.set_postfix({"loss": total_loss.item()})

        # Calculate and report average loss for the epoch
        avg_loss = sum(epoch_losses) / len(epoch_losses)
        logger.info(f"Epoch {epoch+1}/{num_epochs} - Average Loss: {avg_loss:.4f}")
        losses.append(avg_loss)

    return model, losses

## Evaluation

In [None]:
def evaluate_student_model(student_model, student_tokenizer, test_problems, teacher_model=None,
                          batch_size=4, max_length=512, output_dir="results/evaluations"):
    """
    Evaluate the student model on a set of test problems.

    Args:
        student_model: Trained student model
        student_tokenizer: Tokenizer for the student model
        test_problems: List of test problems to evaluate on
        teacher_model: Optional teacher model for comparison
        batch_size: Batch size for evaluation
        max_length: Maximum sequence length for generation
        output_dir: Directory to save evaluation results

    Returns:
        Dictionary with evaluation metrics
    """
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Evaluating student model on {len(test_problems)} test problems")

    # Set models to evaluation mode
    student_model.eval()
    if teacher_model is not None:
        teacher_model.eval()

    results = {
        "total_problems": len(test_problems),
        "student_generations": [],
        "teacher_generations": [] if teacher_model else None,
        "prompts": []
    }

    # Process test problems in batches
    for i in range(0, len(test_problems), batch_size):
        batch_problems = test_problems[i:i+batch_size]
        batch_prompts = []

        for problem in batch_problems:
            prompt = PROMPT_TEMPLATE.format(problem=problem["problem"])
            batch_prompts.append(prompt)
            results["prompts"].append(prompt)

        # Generate solutions with student model
        student_outputs = []
        for prompt in tqdm(batch_prompts, desc="Generating student solutions"):
            inputs = student_tokenizer(prompt, return_tensors="pt").to(student_model.device)

            student_model.eval()
            with torch.no_grad():
                output = student_model.generate(
                    **inputs,
                    max_length=max_length,
                    temperature=0.7,
                    do_sample=True,
                    top_p=0.9,
                    num_return_sequences=1
                )

            decoded_output = student_tokenizer.decode(output[0], skip_special_tokens=True)
            student_outputs.append(decoded_output)

        results["student_generations"].extend(student_outputs)

        # If teacher model is provided, generate solutions for comparison
        if teacher_model:
            teacher_outputs = []
            for prompt in tqdm(batch_prompts, desc="Generating teacher solutions"):
                inputs = student_tokenizer(prompt, return_tensors="pt").to(teacher_model.device)

                with torch.no_grad():
                    output = teacher_model.generate(
                        **inputs,
                        max_length=max_length,
                        temperature=0.7,
                        do_sample=True,
                        top_p=0.9,
                        num_return_sequences=1
                    )

                decoded_output = student_tokenizer.decode(output[0], skip_special_tokens=True)
                teacher_outputs.append(decoded_output)

            results["teacher_generations"].extend(teacher_outputs)

    # Process and extract solutions
    logger.info("Processing generated solutions")
    student_solutions = []
    teacher_solutions = [] if teacher_model else None

    for output in results["student_generations"]:
        solution_start_marker = "Step-by-step solution:"
        solution_start_idx = output.find(solution_start_marker) + len(solution_start_marker)
        solution = output[solution_start_idx:].strip()
        student_solutions.append(solution)

    if teacher_model:
        for output in results["teacher_generations"]:
            solution_start_marker = "Step-by-step solution:"
            solution_start_idx = output.find(solution_start_marker) + len(solution_start_marker)
            solution = output[solution_start_idx:].strip()
            teacher_solutions.append(solution)

    # Calculate some basic metrics
    logger.info("Calculating evaluation metrics")

    # Calculate average solution length
    student_avg_length = sum(len(solution.split()) for solution in student_solutions) / len(student_solutions)
    results["student_avg_word_count"] = student_avg_length

    if teacher_model:
        teacher_avg_length = sum(len(solution.split()) for solution in teacher_solutions) / len(teacher_solutions)
        results["teacher_avg_word_count"] = teacher_avg_length
        results["length_ratio"] = student_avg_length / teacher_avg_length if teacher_avg_length > 0 else 0

    # Check for step-by-step reasoning keywords
    reasoning_keywords = ["first", "second", "third", "next", "then", "finally", "step", "let's", "because", "reason"]
    student_keyword_counts = []

    for solution in student_solutions:
        solution_lower = solution.lower()
        count = sum(1 for keyword in reasoning_keywords if keyword in solution_lower)
        student_keyword_counts.append(count)

    results["student_avg_reasoning_markers"] = sum(student_keyword_counts) / len(student_keyword_counts)

    if teacher_model:
        teacher_keyword_counts = []
        for solution in teacher_solutions:
            solution_lower = solution.lower()
            count = sum(1 for keyword in reasoning_keywords if keyword in solution_lower)
            teacher_keyword_counts.append(count)

        results["teacher_avg_reasoning_markers"] = sum(teacher_keyword_counts) / len(teacher_keyword_counts)
        results["reasoning_marker_ratio"] = (results["student_avg_reasoning_markers"] /
                                           results["teacher_avg_reasoning_markers"]
                                           if results["teacher_avg_reasoning_markers"] > 0 else 0)

    # Save a few example comparisons
    with open(os.path.join(output_dir, "solution_examples.txt"), "w") as f:
        for i in range(min(5, len(student_solutions))):
            f.write(f"Problem {i+1}:\n")
            f.write(f"{results['prompts'][i]}\n\n")
            f.write(f"Student solution:\n{student_solutions[i]}\n\n")
            if teacher_model:
                f.write(f"Teacher solution:\n{teacher_solutions[i]}\n\n")
            f.write("-" * 80 + "\n\n")

    # Save all evaluation results
    with open(os.path.join(output_dir, "evaluation_results.json"), "w") as f:
        # Create a summary version without the full generations for easier reading
        summary_results = {k: v for k, v in results.items()
                         if k not in ["student_generations", "teacher_generations", "prompts"]}
        json.dump(summary_results, f, indent=2)

    # Save the full results separately
    with open(os.path.join(output_dir, "full_results.json"), "w") as f:
        json.dump(results, f, indent=2)

    logger.info(f"Evaluation complete. Results saved to {output_dir}")
    return results

In [None]:
def track_best_model(evaluation_results, best_metrics, model_path, output_dir="results/best_model"):
    """
    Track and save the best student model based on evaluation metrics.

    Args:
        evaluation_results: Results dictionary from evaluate_student_model
        best_metrics: Dictionary with current best metrics
        model_path: Path to the current model
        output_dir: Directory to save the best model

    Returns:
        Updated best_metrics dictionary
    """
    os.makedirs(output_dir, exist_ok=True)

    # Define a scoring function to rank models (higher is better)
    # Here we prioritize reasoning marker ratio and solution length ratio
    current_score = (
        evaluation_results.get("reasoning_marker_ratio", 0) * 0.7 +
        evaluation_results.get("length_ratio", 0) * 0.3
    )

    best_score = (
        best_metrics.get("reasoning_marker_ratio", 0) * 0.7 +
        best_metrics.get("length_ratio", 0) * 0.3
    )

    # Check if current model is better than the best so far
    if current_score > best_score:
        logger.info(f"New best model found! Score: {current_score:.4f} (previous: {best_score:.4f})")

        # Update best metrics
        best_metrics = {
            "score": current_score,
            "model_path": model_path,
            "reasoning_marker_ratio": evaluation_results.get("reasoning_marker_ratio", 0),
            "length_ratio": evaluation_results.get("length_ratio", 0),
            "student_avg_reasoning_markers": evaluation_results.get("student_avg_reasoning_markers", 0),
            "student_avg_word_count": evaluation_results.get("student_avg_word_count", 0)
        }

        # Copy the model to the best model directory
        if os.path.exists(model_path):
            logger.info(f"Copying best model from {model_path} to {output_dir}")

            # Clear previous best model
            if os.path.exists(output_dir):
                for file in os.listdir(output_dir):
                    file_path = os.path.join(output_dir, file)
                    if os.path.isfile(file_path):
                        os.remove(file_path)

            # Copy new best model
            for file in os.listdir(model_path):
                source_file = os.path.join(model_path, file)
                if os.path.isfile(source_file):
                    shutil.copy(source_file, os.path.join(output_dir, file))

        # Save best metrics
        with open(os.path.join(output_dir, "best_metrics.json"), "w") as f:
            json.dump(best_metrics, f, indent=2)

    return best_metrics

# Main

In [1]:
# Global Params
TEACHER_EXAMPLE_LEN = 374 # number of train mbpp problems
STUDENT_EXAMPLE_LEN = 10
GENERATED_TOKEN_LEN = 512

# Training Params
NUM_EPOCHS = 6
LEARNING_RATE = 2e-5
BATCH_SIZE = 10
WARMUP_STEPS = TEACHER_EXAMPLE_LEN * 0.05

In [None]:
clear_gpu_memory()

print("Loading MBPP dataset...")
mbpp_train_problems, mbpp_test_problems = load_mbpp_dataset()

## CoT Agent

In [None]:
clear_gpu_memory()

# CoT Agent Params
SOLUTION_FIELD = "solution_cot"
OUTPUT_MARKER = "Step-by-step solution:"

print("Loading CoT models...")
teacher_model, teacher_tokenizer, student_model, student_tokenizer = load_models("Qwen/Qwen2.5-7B-Instruct", "Qwen/Qwen2.5-0.5B-Instruct")

2025-04-16 20:38:46,835 - INFO - Loading teacher model: Qwen/Qwen2.5-7B-Instruct


Loading CoT models...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
2025-04-16 20:38:50,376 - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

2025-04-16 20:39:04,206 - INFO - Teacher model loaded successfully
2025-04-16 20:39:04,207 - INFO - Loading student model: Qwen/Qwen2.5-0.5B
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
2025-04-16 20:39:05,832 - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
2025-04-16 20:39:09,539 - INFO - Student model loaded successfully


Loading CoT dataset...
Loaded 374 train problems and 500 evaluation problems from MBPP dataset


In [None]:
clear_gpu_memory()

print("Generating CoT examples using Teacher model...")
train_cot_examples = generate_dataset(
    mbpp_train_problems,
    COT_PROMPT_TEMPLATE,
    SOLUTION_FIELD,
    OUTPUT_MARKER,
    teacher_model,
    teacher_tokenizer,
    num_examples=TEACHER_EXAMPLE_LEN,
    max_new_tokens=GENERATED_TOKEN_LEN
)

print("Generating CoT examples using untrained Student model...")
untrained_cot_examples = generate_dataset(
    mbpp_test_problems,
    COT_PROMPT_TEMPLATE,
    SOLUTION_FIELD,
    OUTPUT_MARKER,
    student_model,
    student_tokenizer,
    num_examples=STUDENT_EXAMPLE_LEN,
    max_new_tokens=GENERATED_TOKEN_LEN,
    teacher=False
)

2025-04-16 17:48:58,026 - INFO - Generating solution_cot for 374 problems...


Generating CoT examples using Teacher model...


Generating solution_cot:   0%|          | 1/374 [00:12<1:16:38, 12.33s/it]


Example 1:
Problem: Write a function to find the longest chain which can be formed from the given set of pairs....
Solution (first 150 chars): 1. Define a function called "longest_chain" that takes in a list of tuples as input.
    - This function will take the set of pairs and return the len...


Generating solution_cot:   1%|          | 2/374 [00:24<1:14:11, 11.97s/it]


Example 2:
Problem: Write a python function to find the first repeated character in a given string....
Solution (first 150 chars): 1. Define a function called "first_repeated_char" that takes one parameter, 'input_string'.
2. Create an empty set called 'seen_chars' to store charac...


Generating solution_cot:   2%|▏         | 9/374 [01:46<1:11:17, 11.72s/it]2025-04-16 17:50:55,736 - INFO - Generated 10/374 solutions
Generating solution_cot:   5%|▌         | 19/374 [03:43<1:09:18, 11.72s/it]2025-04-16 17:52:52,856 - INFO - Generated 20/374 solutions
Generating solution_cot:   8%|▊         | 29/374 [05:38<1:06:55, 11.64s/it]2025-04-16 17:54:47,865 - INFO - Generated 30/374 solutions
Generating solution_cot:  10%|█         | 39/374 [07:32<1:03:24, 11.36s/it]2025-04-16 17:56:41,886 - INFO - Generated 40/374 solutions
Generating solution_cot:  13%|█▎        | 49/374 [09:28<1:03:21, 11.70s/it]2025-04-16 17:58:38,566 - INFO - Generated 50/374 solutions
Generating solution_cot:  16%|█▌        | 59/374 [11:23<57:08, 10.88s/it]  2025-04-16 18:00:32,835 - INFO - Generated 60/374 solutions
Generating solution_cot:  18%|█▊        | 69/374 [13:20<59:19, 11.67s/it]2025-04-16 18:02:29,779 - INFO - Generated 70/374 solutions
Generating solution_cot:  21%|██        | 79/374 [15:17<57

Generating CoT examples using Student model...


Generating solution_cot:   0%|          | 0/5 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating solution_cot:  20%|██        | 1/5 [00:05<00:20,  5.02s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.



Example 1:
Problem: Write a function to find the longest chain which can be formed from the given set of pairs....
Solution (first 150 chars): 1. Initialize an empty dictionary to store the longest chains for each pair.
2. Iterate through the given set of pairs.
3. For each pair, find all pos...


Generating solution_cot:  40%|████      | 2/5 [00:09<00:14,  4.97s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.



Example 2:
Problem: Write a python function to find the first repeated character in a given string....
Solution (first 150 chars): 1. Initialize a dictionary called 'char_count' with all characters of the string as keys, and their respective counts as values.
2. Iterate through th...


Generating solution_cot:  60%|██████    | 3/5 [00:14<00:09,  4.84s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating solution_cot:  80%|████████  | 4/5 [00:18<00:04,  4.53s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating solution_cot: 100%|██████████| 5/5 [00:27<00:00,  5.50s/it]
2025-04-16 19:01:30,006 - INFO - Successfully generated 5 solutions with logits
2025-04-16 19:01:30,012 - INFO - Dataset saved to solution_cot_dataset.json


In [None]:
# # in case of disk/memory filling, this reloads the examples from json

# clear_gpu_memory()

# mdpp_examples_file = open(f"{SOLUTION_FIELD}_teacher_{TEACHER_EXAMPLE_LEN}_dataset.json")
# train_cot_examples = json.load(mdpp_examples_file)

# print(train_cot_examples[0])

{'problem': 'Write a function to find the longest chain which can be formed from the given set of pairs.', 'test_case': ['assert max_chain_length([Pair(5, 24), Pair(15, 25),Pair(27, 40), Pair(50, 60)], 4) == 3', 'assert max_chain_length([Pair(1, 2), Pair(3, 4),Pair(5, 6), Pair(7, 8)], 4) == 4', 'assert max_chain_length([Pair(19, 10), Pair(11, 12),Pair(13, 14), Pair(15, 16), Pair(31, 54)], 5) == 5'], 'code': 'class Pair(object): \r\n\tdef __init__(self, a, b): \r\n\t\tself.a = a \r\n\t\tself.b = b \r\ndef max_chain_length(arr, n): \r\n\tmax = 0\r\n\tmcl = [1 for i in range(n)] \r\n\tfor i in range(1, n): \r\n\t\tfor j in range(0, i): \r\n\t\t\tif (arr[i].a > arr[j].b and\r\n\t\t\t\tmcl[i] < mcl[j] + 1): \r\n\t\t\t\tmcl[i] = mcl[j] + 1\r\n\tfor i in range(n): \r\n\t\tif (max < mcl[i]): \r\n\t\t\tmax = mcl[i] \r\n\treturn max', 'solution_cot': '1. Define a function called "longest_chain" that takes in a list of tuples as input.\n    - This function will take the set of pairs and return th

In [None]:
clear_gpu_memory()

# Fine-tune the student model
print("Fine-Tuning CoT on student model...")
trained_student_model, trained_tokenizer = fine_tune_student_model(
    student_model=student_model,
    student_tokenizer=student_tokenizer,
    train_data=train_cot_examples,
    prompt=COT_PROMPT_TEMPLATE,
    output_field=SOLUTION_FIELD,
    batch_size=BATCH_SIZE,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    max_length=GENERATED_TOKEN_LEN
)

2025-04-16 20:40:00,172 - INFO - Starting training of student model for 6 epochs


Fine-Tuning CoT on student model...


Epoch 1/6: 100%|██████████| 38/38 [00:38<00:00,  1.01s/it, loss=0.688]
2025-04-16 20:40:40,037 - INFO - Epoch 1/6 - Average loss: 0.7380
2025-04-16 20:40:40,038 - INFO - Saving best model so far (loss: 0.7380) to results/student_model_epoch_1
Epoch 2/6: 100%|██████████| 38/38 [00:37<00:00,  1.02it/s, loss=0.411]
2025-04-16 20:41:21,161 - INFO - Epoch 2/6 - Average loss: 0.4596
2025-04-16 20:41:21,162 - INFO - Saving best model so far (loss: 0.4596) to results/student_model_epoch_2
Epoch 3/6:  61%|██████    | 23/38 [00:23<00:14,  1.00it/s, loss=0.265]2025-04-16 20:41:49,016 - INFO - Step 100: loss = 0.2654
Epoch 3/6: 100%|██████████| 38/38 [00:37<00:00,  1.02it/s, loss=0.228]
2025-04-16 20:42:02,363 - INFO - Epoch 3/6 - Average loss: 0.2502
2025-04-16 20:42:02,364 - INFO - Saving best model so far (loss: 0.2502) to results/student_model_epoch_3
Epoch 4/6: 100%|██████████| 38/38 [00:54<00:00,  1.44s/it, loss=0.13] 
2025-04-16 20:43:01,014 - INFO - Epoch 4/6 - Average loss: 0.1280
2025-04

SafetensorError: Error while serializing: IoError(Os { code: 122, kind: QuotaExceeded, message: "Disk quota exceeded" })

In [None]:
# # in case of disk/memory filling, this reloads the trained model from files

# trained_student_path = f"results/student_model_{SOLUTION_FIELD}_final"
# trained_student_model = AutoModelForCausalLM.from_pretrained(trained_student_path).to(device)
# trained_tokenizer = AutoTokenizer.from_pretrained(trained_student_path)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
print("Generating CoT examples using Trained Student model...")
trained_cot_examples = generate_dataset(
    mbpp_test_problems,
    COT_PROMPT_TEMPLATE,
    SOLUTION_FIELD,
    OUTPUT_MARKER,
    trained_student_model,
    trained_tokenizer,
    num_examples=STUDENT_EXAMPLE_LEN+1, # add 1 to not overwrite the untrained student data file
    max_new_tokens=GENERATED_TOKEN_LEN,
    teacher=False
)

# print("Evaluating CoT student model...")
# evaluation_results = evaluate_student_model(
#     student_model=student_model,
#     student_tokenizer=student_tokenizer,
#     test_problems=test_problems,
#     teacher_model=teacher_model,
#     batch_size=BATCH_SIZE,
#     max_length=GENERATED_TOKEN_LEN,
#     output_dir="results/evaluations"
# )

2025-04-16 21:06:30,498 - INFO - Generating solution_cot for 10 problems...


Generating CoT examples using Trained Student model...


Generating solution_cot:   0%|          | 0/10 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating solution_cot:  10%|█         | 1/10 [00:09<01:21,  9.02s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.



Example 1:
Problem: Write a function to find the longest chain which can be formed from the given set of pairs....
Solution (first 150 chars): 1. Define the function `longest_chain` that takes in a list of tuples as input.
   - This function will take a set of pairs and return the length of t...


Generating solution_cot:  20%|██        | 2/10 [00:17<01:11,  8.98s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.



Example 2:
Problem: Write a python function to find the first repeated character in a given string....
Solution (first 150 chars): 1. Initialize an empty set called 'seen_chars' to store characters encountered.
2. Iterate through each character in the input string.
   - If the cha...


Generating solution_cot:  30%|███       | 3/10 [00:26<01:02,  8.89s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating solution_cot:  40%|████      | 4/10 [00:34<00:51,  8.56s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating solution_cot:  50%|█████     | 5/10 [00:43<00:43,  8.65s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating solution_cot:  60%|██████    | 6/10 [00:52<00:34,  8.71s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating solution_cot:  70%|███████   | 7/10 [01:01<00:26,  8.75s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating solution_cot:  80%|████████  | 8/10 [01:10<00:17,  8.78s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating solution_cot:  90%|█████████ | 9/10 [01:18<00:08,  8.79s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end gene

## Coder Agent

In [None]:
clear_gpu_memory()

# CoT Agent Params
SOLUTION_FIELD = "gen_code"
OUTPUT_MARKER = "Python code:"

print("Loading Code models...")
teacher_model, teacher_tokenizer, student_model, student_tokenizer = load_models("Qwen/Qwen2.5-Coder-7B-Instruct", "Qwen/Qwen2.5-Coder-0.5B-Instruct")

2025-04-16 20:38:46,835 - INFO - Loading teacher model: Qwen/Qwen2.5-7B-Instruct


Loading CoT models...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
2025-04-16 20:38:50,376 - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

2025-04-16 20:39:04,206 - INFO - Teacher model loaded successfully
2025-04-16 20:39:04,207 - INFO - Loading student model: Qwen/Qwen2.5-0.5B
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
2025-04-16 20:39:05,832 - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
2025-04-16 20:39:09,539 - INFO - Student model loaded successfully


Loading CoT dataset...
Loaded 374 train problems and 500 evaluation problems from MBPP dataset


In [None]:
clear_gpu_memory()

print("Generating Code examples using Teacher model...")
train_code_examples = generate_dataset(
    train_cot_examples,
    CODER_PROMPT_TEMPLATE,
    SOLUTION_FIELD,
    OUTPUT_MARKER,
    teacher_model,
    teacher_tokenizer,
    num_examples=TEACHER_EXAMPLE_LEN,
    max_new_tokens=GENERATED_TOKEN_LEN
)

print("Generating Code examples using untrained Student model...")
untrained_code_examples = generate_dataset(
    trained_cot_examples,
    CODER_PROMPT_TEMPLATE,
    SOLUTION_FIELD,
    OUTPUT_MARKER,
    student_model,
    student_tokenizer,
    num_examples=STUDENT_EXAMPLE_LEN,
    max_new_tokens=GENERATED_TOKEN_LEN,
    teacher=False
)

2025-04-16 17:48:58,026 - INFO - Generating solution_cot for 374 problems...


Generating CoT examples using Teacher model...


Generating solution_cot:   0%|          | 1/374 [00:12<1:16:38, 12.33s/it]


Example 1:
Problem: Write a function to find the longest chain which can be formed from the given set of pairs....
Solution (first 150 chars): 1. Define a function called "longest_chain" that takes in a list of tuples as input.
    - This function will take the set of pairs and return the len...


Generating solution_cot:   1%|          | 2/374 [00:24<1:14:11, 11.97s/it]


Example 2:
Problem: Write a python function to find the first repeated character in a given string....
Solution (first 150 chars): 1. Define a function called "first_repeated_char" that takes one parameter, 'input_string'.
2. Create an empty set called 'seen_chars' to store charac...


Generating solution_cot:   2%|▏         | 9/374 [01:46<1:11:17, 11.72s/it]2025-04-16 17:50:55,736 - INFO - Generated 10/374 solutions
Generating solution_cot:   5%|▌         | 19/374 [03:43<1:09:18, 11.72s/it]2025-04-16 17:52:52,856 - INFO - Generated 20/374 solutions
Generating solution_cot:   8%|▊         | 29/374 [05:38<1:06:55, 11.64s/it]2025-04-16 17:54:47,865 - INFO - Generated 30/374 solutions
Generating solution_cot:  10%|█         | 39/374 [07:32<1:03:24, 11.36s/it]2025-04-16 17:56:41,886 - INFO - Generated 40/374 solutions
Generating solution_cot:  13%|█▎        | 49/374 [09:28<1:03:21, 11.70s/it]2025-04-16 17:58:38,566 - INFO - Generated 50/374 solutions
Generating solution_cot:  16%|█▌        | 59/374 [11:23<57:08, 10.88s/it]  2025-04-16 18:00:32,835 - INFO - Generated 60/374 solutions
Generating solution_cot:  18%|█▊        | 69/374 [13:20<59:19, 11.67s/it]2025-04-16 18:02:29,779 - INFO - Generated 70/374 solutions
Generating solution_cot:  21%|██        | 79/374 [15:17<57

Generating CoT examples using Student model...


Generating solution_cot:   0%|          | 0/5 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating solution_cot:  20%|██        | 1/5 [00:05<00:20,  5.02s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.



Example 1:
Problem: Write a function to find the longest chain which can be formed from the given set of pairs....
Solution (first 150 chars): 1. Initialize an empty dictionary to store the longest chains for each pair.
2. Iterate through the given set of pairs.
3. For each pair, find all pos...


Generating solution_cot:  40%|████      | 2/5 [00:09<00:14,  4.97s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.



Example 2:
Problem: Write a python function to find the first repeated character in a given string....
Solution (first 150 chars): 1. Initialize a dictionary called 'char_count' with all characters of the string as keys, and their respective counts as values.
2. Iterate through th...


Generating solution_cot:  60%|██████    | 3/5 [00:14<00:09,  4.84s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating solution_cot:  80%|████████  | 4/5 [00:18<00:04,  4.53s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating solution_cot: 100%|██████████| 5/5 [00:27<00:00,  5.50s/it]
2025-04-16 19:01:30,006 - INFO - Successfully generated 5 solutions with logits
2025-04-16 19:01:30,012 - INFO - Dataset saved to solution_cot_dataset.json


In [None]:
# # in case of disk/memory filling, this reloads the examples from json

# clear_gpu_memory()

# cot_examples_file = open(f"{SOLUTION_FIELD}_teacher_{TEACHER_EXAMPLE_LEN}_dataset.json")
# train_code_examples = json.load(cot_examples_file)

# print(train_code_examples[0])

{'problem': 'Write a function to find the longest chain which can be formed from the given set of pairs.', 'test_case': ['assert max_chain_length([Pair(5, 24), Pair(15, 25),Pair(27, 40), Pair(50, 60)], 4) == 3', 'assert max_chain_length([Pair(1, 2), Pair(3, 4),Pair(5, 6), Pair(7, 8)], 4) == 4', 'assert max_chain_length([Pair(19, 10), Pair(11, 12),Pair(13, 14), Pair(15, 16), Pair(31, 54)], 5) == 5'], 'code': 'class Pair(object): \r\n\tdef __init__(self, a, b): \r\n\t\tself.a = a \r\n\t\tself.b = b \r\ndef max_chain_length(arr, n): \r\n\tmax = 0\r\n\tmcl = [1 for i in range(n)] \r\n\tfor i in range(1, n): \r\n\t\tfor j in range(0, i): \r\n\t\t\tif (arr[i].a > arr[j].b and\r\n\t\t\t\tmcl[i] < mcl[j] + 1): \r\n\t\t\t\tmcl[i] = mcl[j] + 1\r\n\tfor i in range(n): \r\n\t\tif (max < mcl[i]): \r\n\t\t\tmax = mcl[i] \r\n\treturn max', 'solution_cot': '1. Define a function called "longest_chain" that takes in a list of tuples as input.\n    - This function will take the set of pairs and return th

In [None]:
clear_gpu_memory()

# Fine-tune the student model
print("Fine-Tuning Code Gen on student model...")
trained_student_model, trained_tokenizer = fine_tune_student_model(
    student_model=student_model,
    student_tokenizer=student_tokenizer,
    train_data=train_code_examples,
    prompt=CODER_PROMPT_TEMPLATE,
    output_field=SOLUTION_FIELD,
    batch_size=BATCH_SIZE,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    max_length=GENERATED_TOKEN_LEN
)

2025-04-16 20:40:00,172 - INFO - Starting training of student model for 6 epochs


Fine-Tuning CoT on student model...


Epoch 1/6: 100%|██████████| 38/38 [00:38<00:00,  1.01s/it, loss=0.688]
2025-04-16 20:40:40,037 - INFO - Epoch 1/6 - Average loss: 0.7380
2025-04-16 20:40:40,038 - INFO - Saving best model so far (loss: 0.7380) to results/student_model_epoch_1
Epoch 2/6: 100%|██████████| 38/38 [00:37<00:00,  1.02it/s, loss=0.411]
2025-04-16 20:41:21,161 - INFO - Epoch 2/6 - Average loss: 0.4596
2025-04-16 20:41:21,162 - INFO - Saving best model so far (loss: 0.4596) to results/student_model_epoch_2
Epoch 3/6:  61%|██████    | 23/38 [00:23<00:14,  1.00it/s, loss=0.265]2025-04-16 20:41:49,016 - INFO - Step 100: loss = 0.2654
Epoch 3/6: 100%|██████████| 38/38 [00:37<00:00,  1.02it/s, loss=0.228]
2025-04-16 20:42:02,363 - INFO - Epoch 3/6 - Average loss: 0.2502
2025-04-16 20:42:02,364 - INFO - Saving best model so far (loss: 0.2502) to results/student_model_epoch_3
Epoch 4/6: 100%|██████████| 38/38 [00:54<00:00,  1.44s/it, loss=0.13] 
2025-04-16 20:43:01,014 - INFO - Epoch 4/6 - Average loss: 0.1280
2025-04

SafetensorError: Error while serializing: IoError(Os { code: 122, kind: QuotaExceeded, message: "Disk quota exceeded" })

In [None]:
# # in case of disk/memory filling, this reloads the trained model from files

# trained_student_path = f"results/student_model_{SOLUTION_FIELD}_final"
# trained_student_model = AutoModelForCausalLM.from_pretrained(trained_student_path).to(device)
# trained_tokenizer = AutoTokenizer.from_pretrained(trained_student_path)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
print("Generating CoT examples using Trained Student model...")
trained_code_examples = generate_dataset(
    trained_cot_examples,
    CODER_PROMPT_TEMPLATE,
    SOLUTION_FIELD,
    OUTPUT_MARKER,
    trained_student_model,
    trained_tokenizer,
    num_examples=STUDENT_EXAMPLE_LEN+1, # add 1 to not overwrite the untrained student data file
    max_new_tokens=GENERATED_TOKEN_LEN,
    teacher=False
)

# print("Evaluating CoT student model...")
# evaluation_results = evaluate_student_model(
#     student_model=student_model,
#     student_tokenizer=student_tokenizer,
#     test_problems=test_problems,
#     teacher_model=teacher_model,
#     batch_size=BATCH_SIZE,
#     max_length=GENERATED_TOKEN_LEN,
#     output_dir="results/evaluations"
# )

2025-04-16 21:06:30,498 - INFO - Generating solution_cot for 10 problems...


Generating CoT examples using Trained Student model...


Generating solution_cot:   0%|          | 0/10 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating solution_cot:  10%|█         | 1/10 [00:09<01:21,  9.02s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.



Example 1:
Problem: Write a function to find the longest chain which can be formed from the given set of pairs....
Solution (first 150 chars): 1. Define the function `longest_chain` that takes in a list of tuples as input.
   - This function will take a set of pairs and return the length of t...


Generating solution_cot:  20%|██        | 2/10 [00:17<01:11,  8.98s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.



Example 2:
Problem: Write a python function to find the first repeated character in a given string....
Solution (first 150 chars): 1. Initialize an empty set called 'seen_chars' to store characters encountered.
2. Iterate through each character in the input string.
   - If the cha...


Generating solution_cot:  30%|███       | 3/10 [00:26<01:02,  8.89s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating solution_cot:  40%|████      | 4/10 [00:34<00:51,  8.56s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating solution_cot:  50%|█████     | 5/10 [00:43<00:43,  8.65s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating solution_cot:  60%|██████    | 6/10 [00:52<00:34,  8.71s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating solution_cot:  70%|███████   | 7/10 [01:01<00:26,  8.75s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating solution_cot:  80%|████████  | 8/10 [01:10<00:17,  8.78s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating solution_cot:  90%|█████████ | 9/10 [01:18<00:08,  8.79s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end gene

# Extras

In [54]:
def extract_problem_description(source_code):
    """
    Extracts the problem description from the first docstring in the source code,
    whether it's enclosed in triple double quotes or triple single quotes.
    """
    docstring_pattern = re.compile(r'("""|\'\'\')(.*?)(\1)', re.DOTALL)
    match = docstring_pattern.search(source_code)
    
    if match:
        description = match.group(2)
        # Clean up leading/trailing whitespace on each line
        cleaned_lines = [line.strip() for line in description.strip().splitlines() if line.strip()]
        return ' '.join(cleaned_lines)

    raise Exception(f"Error: Unable to extract problem description. Please check the format of the prompt:\n{source_code}")
    return None

def extract_code_header(source_code):
    """
    Extracts everything from the beginning of the source code up to 
    the first occurrence of either triple single quotes or triple double quotes.
    """
    # Match from start of string to the first triple quotes (single or double)
    docstring_pattern = re.compile(r'^(.*?)(?="""|\'\'\')', re.DOTALL)
    match = docstring_pattern.search(source_code)
    
    if match:
        header = match.group(1)
        # Clean up leading/trailing whitespace on each line
        cleaned_lines = [line.strip() for line in header.strip().splitlines() if line.strip()]
        return ' '.join(cleaned_lines)
    raise Exception(f"Error: Unable to extract code header. Please check the format of the prompt:\n{source_code}")
    return None

def load_human_eval_dataset():
    human_eval = load_dataset("openai_humaneval")

    train_problems = []
    # Extract problems from the MBPP dataset with correct field names
    for item in human_eval["test"]:
        train_problems.append({
            "problem": extract_problem_description(item["prompt"]),
            "code_header": extract_code_header(item["prompt"]),
            "test_case": item["prompt"],
            "solution_code": item["prompt"] + item["canonical_solution"]
        })
    return train_problems

COT_PROMPT_TEMPLATE = """Generate a detailed step-by-step solution for this coding problem.
Break down your thought process clearly, explaining your reasoning while considering:
- What are the inputs and outputs of the function?
- What algorithm or data structure is most appropriate?
- Are there any edge cases to handle?
- What's the efficiency of your approach?

Be concise in your explanation.

Problem:
{problem}

Step-by-step solution:"""

# CODER_PROMPT_TEMPLATE = """Generate only a markdown code block that contains clean, efficient 
# Python code for this coding problem based on the solution approach. The code block must start
# with ```python on its own line, then the code, and end with ``` on its own line.
# Focus on:
# - Implementing the key algorithmic insights
# - Handling edge cases identified in the solution
# - Maintaining readability and efficiency
# Do not include:
# - test cases
# - extra code explanation

# Step-by-step solution:
# {cot_solution}

# Python code:
# {code_header}"""

CODER_PROMPT_TEMPLATE = """Generate only a markdown code block that contains clean, efficient 
Python code for this coding problem based on the solution approach. The code block must start
with ```python on its own line, then the code, and end with ``` on its own line. Do not include
test cases or code explanations.
Focus on:
- Implementing the key algorithmic insights
- Handling edge cases identified in the solution
- Maintaining readability and efficiency

Step-by-step solution:
{cot_solution}

Python code:
{code_header}"""


human_eval = load_human_eval_dataset()
print(human_eval[0])

print("loaded dataset")

trained_cot_student_path = f"results/student_model_cot_solution_final"
trained_cot_student_model = AutoModelForCausalLM.from_pretrained(trained_cot_student_path).to(device)
trained_cot_tokenizer = AutoTokenizer.from_pretrained(trained_cot_student_path)

untrained_coder_model_name = "Qwen/Qwen2.5-Coder-0.5B"
untrained_coder_tokenizer = AutoTokenizer.from_pretrained(untrained_coder_model_name)
untrained_coder_model = AutoModelForCausalLM.from_pretrained(
    untrained_coder_model_name,
    device_map="auto",
    torch_dtype=torch.float32
)

print("Loaded models")

trained_cot_examples = generate_dataset(
    human_eval,
    COT_PROMPT_TEMPLATE,
    "cot_solution",
    "Step-by-step solution:",
    trained_cot_student_model,
    trained_cot_tokenizer,
    num_examples=100,
    max_new_tokens=512,
    teacher=False
)

print("cot examples generated")

code_examples = generate_dataset(
    trained_cot_examples,
    CODER_PROMPT_TEMPLATE,
    "gen_code",
    "Python code:",
    untrained_coder_model,
    untrained_coder_tokenizer,
    num_examples=100,
    max_new_tokens=512,
    teacher=False
)

print("code generated")



{'problem': 'Check if in given list of numbers, are any two numbers closer to each other than given threshold. >>> has_close_elements([1.0, 2.0, 3.0], 0.5) False >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3) True', 'code_header': 'from typing import List def has_close_elements(numbers: List[float], threshold: float) -> bool:', 'test_case': 'from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n    """ Check if in given list of numbers, are any two numbers closer to each other than\n    given threshold.\n    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n    False\n    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n    True\n    """\n', 'solution_code': 'from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n    """ Check if in given list of numbers, are any two numbers closer to each other than\n    given threshold.\n    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


tokenizer_config.json:   0%|          | 0.00/7.23k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/2.78M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/7.03M [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


config.json:   0%|          | 0.00/659 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/988M [00:00<?, ?B/s]

2025-04-17 04:00:18,898 - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


generation_config.json:   0%|          | 0.00/139 [00:00<?, ?B/s]

2025-04-17 04:00:19,962 - INFO - Generating cot_solution with student for 100 problems...


Loaded models


Generating cot_solution:   0%|          | 0/100 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating cot_solution:   1%|          | 1/100 [00:09<15:14,  9.24s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.



Example 1:
Problem: Check if in given list of numbers, are any two numbers closer to each other than given threshold. >>> has_close_elements([1.0, 2.0, 3.0], 0.5) False >...
Solution (first 150 chars): 1. Understand the Problem:
   - We need to write a function that takes two lists of numbers and a threshold value as inputs.
   - The function should ...


Generating cot_solution:   2%|▏         | 2/100 [00:18<15:05,  9.24s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.



Example 2:
Problem: Input to this function is a string containing multiple groups of nested parentheses. Your goal is to separate those group into separate strings and re...
Solution (first 150 chars): 1. Define the function with an input parameter for the string to be processed.
   - This allows the function to be reusable for different strings.
   ...


Generating cot_solution:   3%|▎         | 3/100 [00:27<14:56,  9.24s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating cot_solution:   4%|▍         | 4/100 [00:36<14:48,  9.25s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating cot_solution:   5%|▌         | 5/100 [00:46<14:38,  9.25s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating cot_solution:   6%|▌         | 6/100 [00:55<14:29,  9.25s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating cot_solution:   7%|▋         | 7/100 [01:04<14:20,  9.25s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating cot_solution:   8%|▊         | 8/100 [01:13<14:11,  9.25s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating cot_solution:   9%|▉         | 9/100 [01:23<14:02,  9.25s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-e

cot examples generated


Generating gen_code:   0%|          | 0/100 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating gen_code:   1%|          | 1/100 [00:04<07:15,  4.40s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.



Example 1:
Problem: Check if in given list of numbers, are any two numbers closer to each other than given threshold. >>> has_close_elements([1.0, 2.0, 3.0], 0.5) False >...
Solution (first 150 chars): from typing import List def has_close_elements(numbers: List[float], threshold: float) -> bool: # Check if the length of numbers is equal to threshold...


Generating gen_code:   2%|▏         | 2/100 [00:13<11:50,  7.25s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.



Example 2:
Problem: Input to this function is a string containing multiple groups of nested parentheses. Your goal is to separate those group into separate strings and re...
Solution (first 150 chars): from typing import List def separate_paren_groups(paren_string: str) -> List[str]: stack = [] result = [] for char in paren_string: if char == '(': st...


Generating gen_code:   3%|▎         | 3/100 [00:16<08:24,  5.20s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating gen_code:   4%|▍         | 4/100 [00:20<07:26,  4.65s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating gen_code:   5%|▌         | 5/100 [00:24<07:16,  4.60s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating gen_code:   6%|▌         | 6/100 [00:26<05:28,  3.50s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating gen_code:   7%|▋         | 7/100 [00:29<05:22,  3.47s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating gen_code:   8%|▊         | 8/100 [00:33<05:40,  3.70s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Generating gen_code:   9%|▉         | 9/100 [00:35<04:34,  3.02s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
2025-04-17 04

code generated


In [97]:
i = 22

#print('\nproblem:')
#print(new_code_examples[i]['problem'])
#print('\ncot')
#print(new_code_examples[i]['cot_solution'])
#print('\ngenerated_code')
print(new_code_examples[i]['gen_code'])

def filter_integers(values: List[Any]) -> List[int]:  # Corrected function name
    return [value for value in values if isinstance(value, int)]

# Test cases
print(filter_integers([1, 2, 'a', 3, 4.5, 5]))  # Should print [1, 2, 4, 5]
print(filter_integers(['a', 100, 'b', 200, 300]))  # Should print ['a', 200, 300]
print(filter_integers([]))  # Should print []
print(filter_integers([100, 200, 300]))  # Should print [100, 200, 300]
```

This solution provides a clean, efficient Python function that filters out non-integer values from a given list. The function is tested with various inputs to ensure its correctness. The solution is optimized for performance by using list comprehension, which is generally faster than a loop for filtering large lists.


In [160]:
def extract_before_def(source_code):
    """
    Extracts everything from the beginning of the source code up to 
    but not including the first occurrence of the 'def' keyword.
    Preserves original formatting.
    """
    pattern = re.compile(r'^(.*?)(?=def)', re.DOTALL)
    match = pattern.search(source_code)
    
    if match:
        return match.group(1)
    raise Exception(f"Error: Unable to extract content before 'def'. No 'def' keyword found in:\n{source_code}")
    return None

def extract_until_code_block(source_code):
    """
    Extracts everything from the beginning of the string up to 
    but not including the first occurrence of three backticks (```).
    Preserves original formatting.
    """
    pattern = re.compile(r'^(.*?)(?=```)', re.DOTALL)
    match = pattern.search(source_code)
    
    if match:
        return match.group(1)
    return 'BAD'

solutions = [item['solution_code'] for item in new_code_examples]
generated_codes = [item['gen_code'] for item in new_code_examples]
for i, generated_code in enumerate(generated_codes):
    generated_codes[i] = extract_before_def(solutions[i]) + extract_until_code_block(generated_codes[i])
print(generated_codes[0])

def remove_bad_strings(string_array):
    """
    Removes any strings containing 'BAD' from the given array.
    Also prints the indices of removed strings.
    
    Args:
        string_array: A list of strings to filter
        
    Returns:
        A new list with all strings containing 'BAD' removed
    """
    bad_indices = []
    clean_strings = []
    
    for i, s in enumerate(string_array):
        if 'BAD' in s:
            bad_indices.append(i)
        else:
            clean_strings.append(s)
    
    # Print the indices of bad strings
    if bad_indices:
        print(f"Found 'BAD' in strings at indices: {bad_indices}")
    else:
        print("No strings containing 'BAD' found.")
    
    return clean_strings, bad_indices

edited_codes, bad_indices = remove_bad_strings(generated_codes)
print(len(edited_codes))

from typing import List


BAD
Found 'BAD' in strings at indices: [0, 19, 36, 51, 66, 72, 74]
93


In [167]:
human_eval['test']

TypeError: list indices must be integers or slices, not str

In [163]:
%pip install evaluate

from evaluate import load

# Load evaluation metric
code_eval = load("code_eval")

import os
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

problems = human_eval
test = []
for i, item in enumerate(edited_codes):
    edited_codes[i] = [item]
pred = edited_codes
c = 0

for i, s in enumerate(human_eval[:100]):
    if i not in bad_indices:
        test.append(s)
        c = c+1
        print(c)

pass_at_k = code_eval.compute(
        predictions=pred,
        references=test,
        k=[1]
)
print(pass_at_k)
print(pass_at_k[0]['pass@1']*100)  

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93


ValueError: Predictions and/or references don't match the expected format.
Expected format: {'predictions': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'references': Value(dtype='string', id=None)},
Input predictions: [['from typing import List\n\n\ndef separate_paren_groups(paren_string: str) -> List[str]: \n    result = [] \n    result_set = set() \n    i = 0 \n    while i < len(paren_string): \n        if paren_string[i] == \'(\':\n            result_set.add(paren_string[i]) \n            result.append(paren_string[i]) \n            i += 1 \n        elif paren_string[i] == \')\': \n            if len(result_set) > 0: \n                last_char = result_set.pop() \n                result.append(last_char) \n            i += 1 \n    return result \n\n# Test the function with various inputs \nprint(separate_paren_groups("(()())"))  # Output: [\'()\', \'()\', \'()]\']\nprint(separate_paren_groups("[]")          # Output: [\'[]\']\nprint(separate_paren_groups("((((()))))))")) # Output: [[\'((()))\'], [\'())\']]\nprint(separate_paren_groups("[][][]"))     # Output: [[\'[][]\'], [\'[]\']]\nprint(separate_paren_groups("[][][]")       # Output: [[\'[]\'], [\'[][]\']]\nprint(separate_paren_groups(""))            # Output: []\n'], ['\n\ndef truncate_number(number: float) -> float: \n    """Truncates a number to its integer part and returns the decimal part."""\n    int_part = int(number)\n    decimal_part = number - int_part\n    return decimal_part\n'], ['from typing import List\n\n\ndef below_zero(operations: List[int]) -> bool:    \n    balance = 0\n    \n    for deposit, withdrawal in operations:\n        if withdrawal > deposit:\n            return True\n        \n        balance += deposit\n    \n    return False\n'], ..., ['\ndef multiply(a, b): \n    # Edge case handling\n    if a == 0 or b == 0:\n        return 0\n    \n    # Initialize the product variable\n    product = 0\n    \n    # Iterate through possible pairs of digits\n    for i in range(max(a, b)): \n        # Calculate the product of the unit digits\n        product += (i % 10) * (b // 10)\n    \n    # Return the total product\n    return product\n'], ["\ndef count_upper(s): \n    # Initialize the count of uppercase vowels \n    upper_count = 0 \n  \n    # Loop through the string \n    for i in range(len(s)): \n        # Check if the character is an uppercase vowel \n        if 'a' <= s[i] <= 'z': \n            # Check if the index is even \n            if i % 2 == 0: \n                # Check if the character is an uppercase vowel \n                if s[i] == 'a' or s[i] == 'e' or s[i] == 'i' or s[i] == 'o' or s[i] == 'u': \n                    upper_count += 1 \n  \n    # Return the count of uppercase vowels \n    return upper_count\n\n# Test cases \nprint(count_upper('aBCdEf'))  # Output: 1\nprint(count_upper('abcdefg'))  # Output: 0\nprint(count_upper('dBBE'))  # Output: 0\n"], ['\ndef closest_integer(value): \n    if isinstance(value, str):\n        value = float(value)\n    \n    if value.is_integer():\n        return int(value)\n    \n    distance = abs(value - int(value))\n    rounded_distance = round(distance)\n    \n    return int(rounded_distance)\n']],
Input references: [{'problem': "Input to this function is a string containing multiple groups of nested parentheses. Your goal is to separate those group into separate strings and return the list of those. Separate groups are balanced (each open brace is properly closed) and not nested within each other Ignore any spaces in the input string. >>> separate_paren_groups('( ) (( )) (( )( ))') ['()', '(())', '(()())']", 'code_header': 'from typing import List def separate_paren_groups(paren_string: str) -> List[str]:', 'test_case': 'from typing import List\n\n\ndef separate_paren_groups(paren_string: str) -> List[str]:\n    """ Input to this function is a string containing multiple groups of nested parentheses. Your goal is to\n    separate those group into separate strings and return the list of those.\n    Separate groups are balanced (each open brace is properly closed) and not nested within each other\n    Ignore any spaces in the input string.\n    >>> separate_paren_groups(\'( ) (( )) (( )( ))\')\n    [\'()\', \'(())\', \'(()())\']\n    """\n', 'solution_code': 'from typing import List\n\n\ndef separate_paren_groups(paren_string: str) -> List[str]:\n    """ Input to this function is a string containing multiple groups of nested parentheses. Your goal is to\n    separate those group into separate strings and return the list of those.\n    Separate groups are balanced (each open brace is properly closed) and not nested within each other\n    Ignore any spaces in the input string.\n    >>> separate_paren_groups(\'( ) (( )) (( )( ))\')\n    [\'()\', \'(())\', \'(()())\']\n    """\n    result = []\n    current_string = []\n    current_depth = 0\n\n    for c in paren_string:\n        if c == \'(\':\n            current_depth += 1\n            current_string.append(c)\n        elif c == \')\':\n            current_depth -= 1\n            current_string.append(c)\n\n            if current_depth == 0:\n                result.append(\'\'.join(current_string))\n                current_string.clear()\n\n    return result\n'}, {'problem': 'Given a positive floating point number, it can be decomposed into and integer part (largest integer smaller than given number) and decimals (leftover part always smaller than 1). Return the decimal part of the number. >>> truncate_number(3.5) 0.5', 'code_header': 'def truncate_number(number: float) -> float:', 'test_case': '\n\ndef truncate_number(number: float) -> float:\n    """ Given a positive floating point number, it can be decomposed into\n    and integer part (largest integer smaller than given number) and decimals\n    (leftover part always smaller than 1).\n\n    Return the decimal part of the number.\n    >>> truncate_number(3.5)\n    0.5\n    """\n', 'solution_code': '\n\ndef truncate_number(number: float) -> float:\n    """ Given a positive floating point number, it can be decomposed into\n    and integer part (largest integer smaller than given number) and decimals\n    (leftover part always smaller than 1).\n\n    Return the decimal part of the number.\n    >>> truncate_number(3.5)\n    0.5\n    """\n    return number % 1.0\n'}, {'problem': "You're given a list of deposit and withdrawal operations on a bank account that starts with zero balance. Your task is to detect if at any point the balance of account fallls below zero, and at that point function should return True. Otherwise it should return False. >>> below_zero([1, 2, 3]) False >>> below_zero([1, 2, -4, 5]) True", 'code_header': 'from typing import List def below_zero(operations: List[int]) -> bool:', 'test_case': 'from typing import List\n\n\ndef below_zero(operations: List[int]) -> bool:\n    """ You\'re given a list of deposit and withdrawal operations on a bank account that starts with\n    zero balance. Your task is to detect if at any point the balance of account fallls below zero, and\n    at that point function should return True. Otherwise it should return False.\n    >>> below_zero([1, 2, 3])\n    False\n    >>> below_zero([1, 2, -4, 5])\n    True\n    """\n', 'solution_code': 'from typing import List\n\n\ndef below_zero(operations: List[int]) -> bool:\n    """ You\'re given a list of deposit and withdrawal operations on a bank account that starts with\n    zero balance. Your task is to detect if at any point the balance of account fallls below zero, and\n    at that point function should return True. Otherwise it should return False.\n    >>> below_zero([1, 2, 3])\n    False\n    >>> below_zero([1, 2, -4, 5])\n    True\n    """\n    balance = 0\n\n    for op in operations:\n        balance += op\n        if balance < 0:\n            return True\n\n    return False\n'}, ..., {'problem': 'Complete the function that takes two integers and returns the product of their unit digits. Assume the input is always valid. Examples: multiply(148, 412) should return 16. multiply(19, 28) should return 72. multiply(2020, 1851) should return 0. multiply(14,-15) should return 20.', 'code_header': 'def multiply(a, b):', 'test_case': '\ndef multiply(a, b):\n    """Complete the function that takes two integers and returns \n    the product of their unit digits.\n    Assume the input is always valid.\n    Examples:\n    multiply(148, 412) should return 16.\n    multiply(19, 28) should return 72.\n    multiply(2020, 1851) should return 0.\n    multiply(14,-15) should return 20.\n    """\n', 'solution_code': '\ndef multiply(a, b):\n    """Complete the function that takes two integers and returns \n    the product of their unit digits.\n    Assume the input is always valid.\n    Examples:\n    multiply(148, 412) should return 16.\n    multiply(19, 28) should return 72.\n    multiply(2020, 1851) should return 0.\n    multiply(14,-15) should return 20.\n    """\n    return abs(a % 10) * abs(b % 10)\n'}, {'problem': "Given a string s, count the number of uppercase vowels in even indices. For example: count_upper('aBCdEf') returns 1 count_upper('abcdefg') returns 0 count_upper('dBBE') returns 0", 'code_header': 'def count_upper(s):', 'test_case': '\ndef count_upper(s):\n    """\n    Given a string s, count the number of uppercase vowels in even indices.\n    \n    For example:\n    count_upper(\'aBCdEf\') returns 1\n    count_upper(\'abcdefg\') returns 0\n    count_upper(\'dBBE\') returns 0\n    """\n', 'solution_code': '\ndef count_upper(s):\n    """\n    Given a string s, count the number of uppercase vowels in even indices.\n    \n    For example:\n    count_upper(\'aBCdEf\') returns 1\n    count_upper(\'abcdefg\') returns 0\n    count_upper(\'dBBE\') returns 0\n    """\n    count = 0\n    for i in range(0,len(s),2):\n        if s[i] in "AEIOU":\n            count += 1\n    return count\n'}, {'problem': 'Create a function that takes a value (string) representing a number and returns the closest integer to it. If the number is equidistant from two integers, round it away from zero. Examples >>> closest_integer("10") 10 >>> closest_integer("15.3") 15 Note: Rounding away from zero means that if the given number is equidistant from two integers, the one you should return is the one that is the farthest from zero. For example closest_integer("14.5") should return 15 and closest_integer("-14.5") should return -15.', 'code_header': 'def closest_integer(value):', 'test_case': '\ndef closest_integer(value):\n    \'\'\'\n    Create a function that takes a value (string) representing a number\n    and returns the closest integer to it. If the number is equidistant\n    from two integers, round it away from zero.\n\n    Examples\n    >>> closest_integer("10")\n    10\n    >>> closest_integer("15.3")\n    15\n\n    Note:\n    Rounding away from zero means that if the given number is equidistant\n    from two integers, the one you should return is the one that is the\n    farthest from zero. For example closest_integer("14.5") should\n    return 15 and closest_integer("-14.5") should return -15.\n    \'\'\'\n', 'solution_code': '\ndef closest_integer(value):\n    \'\'\'\n    Create a function that takes a value (string) representing a number\n    and returns the closest integer to it. If the number is equidistant\n    from two integers, round it away from zero.\n\n    Examples\n    >>> closest_integer("10")\n    10\n    >>> closest_integer("15.3")\n    15\n\n    Note:\n    Rounding away from zero means that if the given number is equidistant\n    from two integers, the one you should return is the one that is the\n    farthest from zero. For example closest_integer("14.5") should\n    return 15 and closest_integer("-14.5") should return -15.\n    \'\'\'\n    from math import floor, ceil\n\n    if value.count(\'.\') == 1:\n        # remove trailing zeros\n        while (value[-1] == \'0\'):\n            value = value[:-1]\n\n    num = float(value)\n    if value[-2:] == \'.5\':\n        if num > 0:\n            res = ceil(num)\n        else:\n            res = floor(num)\n    elif len(value) > 0:\n        res = int(round(num))\n    else:\n        res = 0\n\n    return res\n\n'}]

In [None]:
# Generate CoT (Chain of Thought) dataset
cot_examples = generate_dataset(
    problem_dataset=mbpp_problems,
    task_prompt=COT_PROMPT_TEMPLATE,
    solution_field="solution_cot",
    output_marker="Step-by-step solution:",
    teacher_model=teacher_model,
    teacher_tokenizer=teacher_tokenizer,
    num_examples=50,
    output_file="datasets/cot_dataset.json"
)

# Generate code dataset from CoT
code_examples = generate_dataset(
    problem_dataset=cot_examples,  # Use the output from CoT as input
    task_prompt=DEVELOPER_PROMPT_TEMPLATE,
    solution_field="code",
    output_marker="Python code:",
    teacher_model=teacher_model,
    teacher_tokenizer=teacher_tokenizer,
    num_examples=50,
    output_file="datasets/code_dataset.json"
)

# Generate debugged code dataset
debugged_examples = generate_dataset(
    problem_dataset=code_examples,  # Use the code examples as input
    task_prompt=DEBUGGER_PROMPT_TEMPLATE,
    solution_field="debugged_code",
    output_marker="Debugged Python code:",
    teacher_model=teacher_model,
    teacher_tokenizer=teacher_tokenizer,
    num_examples=50,
    output_file="datasets/debugged_code_dataset.json"
)

# Generate code explanations
explanation_examples = generate_dataset(
    problem_dataset=code_examples,  # Use code examples that also have CoT
    task_prompt=EXPLAINER_PROMPT_TEMPLATE,
    solution_field="explanation",
    output_marker="Explanation of the code:",
    teacher_model=teacher_model,
    teacher_tokenizer=teacher_tokenizer,
    num_examples=50,
    output_file="datasets/explanation_dataset.json"
)

In [None]:
for i, example in enumerate(mbpp_problems):
  print(f"Problem number: {i}")
  print(f"Problem: {example['problem']}")
  print("Test cases:")
  print(example['test_case'])
  print("Code Solution:")
  print(example['solution'])