# Fine-tuning HaiJava-Surgeon with Dynamic Prompt Templates

This notebook implements fine-tuning for the `Haiintel/HaiJava-Surgeon-Qwen2.5-Coder-7B-SFT-v1` model using the `code_x_glue_cc_code_refinement` dataset from Hugging Face.

## Key Features:
- **Dataset**: CodeXGLUE Code Refinement (small subset)
- **Dynamic Prompts**: 5 different prompt templates randomly selected during preprocessing
- **Training**: 1 epoch with LoRA fine-tuning (`NUM_EPOCHS = 1` in the configuration below)
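- **Format**: Plain prompt followed by the fixed code for training (a chat-template variant is kept as commented-out code); the inference tests use the chat template with system/user roles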

## 1. Setup and Imports

In [1]:
# Install or upgrade required packages (add torch to the list if it is not already installed)
!pip install -U \
  transformers \
  datasets \
  accelerate \
  peft \
  bitsandbytes \
  sentencepiece \
  safetensors
!pip install -U hf_transfer




In [2]:
import os
import random
import torch
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from typing import Dict, List

# Set random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

PyTorch version: 2.8.0+cu128
CUDA available: False


## 2. Configuration

In [3]:
# Model and dataset configuration
MODEL_NAME = "Haiintel/HaiJava-Surgeon-Qwen2.5-Coder-7B-SFT-v1"
DATASET_NAME = "code_x_glue_cc_code_refinement"
DATASET_SUBSET = "small"  # Using SMALL subset (≤50 tokens)
DATASET_SPLIT = "train"   # Using TRAIN split only

# Training configuration
NUM_EPOCHS = 1
BATCH_SIZE = 1
LEARNING_RATE = 2e-4
MAX_LENGTH = 750
OUTPUT_DIR = "./models/haijava_dynamic_prompts"

# LoRA configuration
LORA_R = 64
LORA_ALPHA = 16
LORA_DROPOUT = 0.1
LORA_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

print("Configuration:")
print(f"  Model: {MODEL_NAME}")
print(f"  Dataset: {DATASET_NAME} ({DATASET_SUBSET} subset, {DATASET_SPLIT} split)")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Batch Size: {BATCH_SIZE}")
print(f"  Learning Rate: {LEARNING_RATE}")
print(f"  Output Directory: {OUTPUT_DIR}")

Configuration:
  Model: Haiintel/HaiJava-Surgeon-Qwen2.5-Coder-7B-SFT-v1
  Dataset: code_x_glue_cc_code_refinement (small subset, train split)
  Epochs: 1
  Batch Size: 1
  Learning Rate: 0.0002
  Output Directory: ./models/haijava_dynamic_prompts


## 3. Define 5 Dynamic Prompt Templates

These templates will be randomly selected during preprocessing for each sample.

In [4]:
def get_prompt_templates() -> List[Dict]:
    """
    Returns 5 different prompt templates for dynamic selection.
    Each template follows the chat format with system/user/assistant roles.
    """
    templates = [
        # Template 1: Detailed expert instruction
        {
            "name": "detailed_expert",
            "system": "You are a Java code fixing assistant.",
            "user_template": "You are a Java expert. Fix the following buggy Java code. Correct syntax errors, logic bugs, runtime errors, and code quality issues. Output ONLY the corrected code.\n\n{buggy_code}"
        },
        # Template 2: Simple instruction
        {
            "name": "simple",
            "system": "You are a Java code fixing assistant.",
            "user_template": "Fix the following Java code.\n\n{buggy_code}"
        },
        # Template 3: Minimal (code only)
        {
            "name": "minimal",
            "system": "You are a Java code fixing assistant.",
            "user_template": "{buggy_code}"
        },
        # Template 4: Structured format
        {
            "name": "structured",
            "system": "You are a Java code fixing assistant.",
            "user_template": "### Buggy Code\n{buggy_code}\n\n### Corrected Code"
        },
        # Template 5: Line-by-line format
        {
            "name": "line_by_line",
            "system": "You are a Java code fixing assistant.",
            "user_template": "Analyze and fix the following buggy Java code line by line:\n\n{buggy_code}\n\nProvide the corrected version:"
        }
    ]
    return templates

# Display templates
templates = get_prompt_templates()
print(f"\nDefined {len(templates)} prompt templates:")
for i, template in enumerate(templates, 1):
    print(f"\n{i}. {template['name'].upper()}")
    print(f"   System: {template['system']}")
    print(f"   User: {template['user_template'][:100]}...")


Defined 5 prompt templates:

1. DETAILED_EXPERT
   System: You are a Java code fixing assistant.
   User: You are a Java expert. Fix the following buggy Java code. Correct syntax errors, logic bugs, runtime...

2. SIMPLE
   System: You are a Java code fixing assistant.
   User: Fix the following Java code.

{buggy_code}...

3. MINIMAL
   System: You are a Java code fixing assistant.
   User: {buggy_code}...

4. STRUCTURED
   System: You are a Java code fixing assistant.
   User: ### Buggy Code
{buggy_code}

### Corrected Code...

5. LINE_BY_LINE
   System: You are a Java code fixing assistant.
   User: Analyze and fix the following buggy Java code line by line:

{buggy_code}

Provide the corrected ver...
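
To get a feel for how uniform random selection spreads the five templates over a dataset, here is a minimal standalone sketch. It uses only the template names defined above and a local `random.Random` instance. The helper name `count_template_choices` is hypothetical and not part of the notebook, and the sample count 46,680 is the train-split size reported in Section 4.

```python
import random
from collections import Counter

# Template names as defined in get_prompt_templates() above
TEMPLATE_NAMES = ["detailed_expert", "simple", "minimal", "structured", "line_by_line"]

def count_template_choices(num_samples: int, seed: int = 42) -> Counter:
    """Hypothetical helper: draw one template name per sample and count the draws."""
    rng = random.Random(seed)  # local RNG, so the global random state is untouched
    return Counter(rng.choice(TEMPLATE_NAMES) for _ in range(num_samples))

NUM_SAMPLES = 46_680  # train-split size reported in Section 4
counts = count_template_choices(NUM_SAMPLES)
for name in TEMPLATE_NAMES:
    share = counts[name] / NUM_SAMPLES
    print(f"{name:>16}: {counts[name]:>6} ({share:.1%})")
```

With this many draws each template should receive roughly one fifth of the samples, so no single prompt style dominates training.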


## 4. Load Dataset

Loading the CodeXGLUE Code Refinement dataset (small subset, train split only).

In [5]:
print(f"Loading dataset: {DATASET_NAME} ({DATASET_SUBSET} subset)...")
print(f"Using split: {DATASET_SPLIT}\n")

# Load dataset
dataset = load_dataset(
    DATASET_NAME,
    DATASET_SUBSET,
    split=DATASET_SPLIT
)

print(f"✅ Dataset loaded successfully!")
print(f"   Total samples: {len(dataset)}")
print(f"   Features: {dataset.features}")

# Show a sample
print("\n📋 Sample from dataset:")
sample = dataset[0]
print(f"\nBuggy code (first 200 chars):\n{sample['buggy'][:200]}...")
print(f"\nFixed code (first 200 chars):\n{sample['fixed'][:200]}...")

Loading dataset: code_x_glue_cc_code_refinement (small subset)...
Using split: train

✅ Dataset loaded successfully!
   Total samples: 46680
   Features: {'id': Value('int32'), 'buggy': Value('string'), 'fixed': Value('string')}

📋 Sample from dataset:

Buggy code (first 200 chars):
public java.lang.String METHOD_1 ( ) { return new TYPE_1 ( STRING_1 ) . format ( VAR_1 [ ( ( VAR_1 . length ) - 1 ) ] . getTime ( ) ) ; } 
...

Fixed code (first 200 chars):
public java.lang.String METHOD_1 ( ) { return new TYPE_1 ( STRING_1 ) . format ( VAR_1 [ ( ( type ) - 1 ) ] . getTime ( ) ) ; } 
...
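
Because the dataset stores code as space-separated tokens, a token-level diff makes it easy to see what a fix actually changes. The following is a minimal sketch using the standard-library `difflib` on the sample printed above. `token_diff` is a hypothetical helper name, not something the notebook defines.

```python
import difflib
from typing import List, Tuple

# The first train sample, copied from the output above
buggy = ("public java.lang.String METHOD_1 ( ) { return new TYPE_1 ( STRING_1 ) . "
         "format ( VAR_1 [ ( ( VAR_1 . length ) - 1 ) ] . getTime ( ) ) ; }")
fixed = ("public java.lang.String METHOD_1 ( ) { return new TYPE_1 ( STRING_1 ) . "
         "format ( VAR_1 [ ( ( type ) - 1 ) ] . getTime ( ) ) ; }")

def token_diff(a: str, b: str) -> List[Tuple[str, List[str], List[str]]]:
    """Hypothetical helper: return the non-equal edit operations between two token strings."""
    a_tok, b_tok = a.split(), b.split()
    matcher = difflib.SequenceMatcher(a=a_tok, b=b_tok, autojunk=False)
    return [
        (tag, a_tok[i1:i2], b_tok[j1:j2])
        for tag, i1, i2, j1, j2 in matcher.get_opcodes()
        if tag != "equal"
    ]

changes = token_diff(buggy, fixed)
for tag, old, new in changes:
    print(f"{tag}: {' '.join(old)!r} -> {' '.join(new)!r}")
```

For this sample the only difference is the array index expression, which illustrates how small the edits in this dataset tend to be relative to the full method.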


## 5. Load Model and Tokenizer

In [6]:
print(f"Loading tokenizer: {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True
)

# Set padding token if not set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print(f"Set pad_token to eos_token: {tokenizer.eos_token}")

print(f"✅ Tokenizer loaded successfully!")
print(f"   Vocab size: {len(tokenizer)}")
print(f"   EOS token: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})")
print(f"   PAD token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id})")

Loading tokenizer: Haiintel/HaiJava-Surgeon-Qwen2.5-Coder-7B-SFT-v1...
✅ Tokenizer loaded successfully!
   Vocab size: 151665
   EOS token: <|im_end|> (ID: 151645)
   PAD token: <|endoftext|> (ID: 151643)


In [7]:
print(f"\nLoading model: {MODEL_NAME}...")
print("This may take a few minutes...\n")

# Load the model in float16 (half precision); no 8-bit quantization config is passed here
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True
)

model.gradient_checkpointing_enable()

model.enable_input_require_grads()


print(f"✅ Model loaded successfully!")
print(f"   Model type: {model.config.model_type}")
print(f"   Hidden size: {model.config.hidden_size}")
print(f"   Num layers: {model.config.num_hidden_layers}")
print(f"   Num attention heads: {model.config.num_attention_heads}")

# Calculate model size
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\n   Total parameters: {total_params:,}")
print(f"   Trainable parameters (before LoRA): {trainable_params:,}")


Loading model: Haiintel/HaiJava-Surgeon-Qwen2.5-Coder-7B-SFT-v1...
This may take a few minutes...



`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

✅ Model loaded successfully!
   Model type: qwen2
   Hidden size: 3584
   Num layers: 28
   Num attention heads: 28

   Total parameters: 7,615,616,512
   Trainable parameters (before LoRA): 7,615,616,512


## 6. Apply LoRA Configuration

In [8]:
print("Applying LoRA configuration...\n")

# Configure LoRA
lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=LORA_TARGET_MODULES,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM"
)
# Apply LoRA to the model (wrap it once; a second call would re-wrap the PEFT model)
model = get_peft_model(model, lora_config)

print("✅ LoRA applied successfully!\n")
model.print_trainable_parameters()

# Calculate trainable percentage
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
trainable_percent = 100 * trainable_params / total_params
print(f"\nTrainable parameters: {trainable_params:,} ({trainable_percent:.2f}% of total)")

Applying LoRA configuration...





✅ LoRA applied successfully!

trainable params: 161,480,704 || all params: 7,777,097,216 || trainable%: 2.0764

Trainable parameters: 161,480,704 (2.08% of total)
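
The trainable-parameter count can be cross-checked by hand. LoRA adds two matrices per targeted linear layer, one of shape `r × in_features` and one of shape `out_features × r`, so each layer contributes `r * (in_features + out_features)` parameters. The sketch below redoes that arithmetic. The hidden size, layer count, and LoRA settings come from the notebook; the key/value projection width (512) and MLP intermediate size (18944) are assumptions about the Qwen2.5-7B architecture that the notebook does not print.

```python
# Values taken from the notebook configuration and model output above
HIDDEN_SIZE = 3584
NUM_LAYERS = 28
LORA_R = 64
LORA_ALPHA = 16

# ASSUMPTIONS (not printed in the notebook): Qwen2.5-7B projection sizes
KV_DIM = 512          # assumed width of k_proj / v_proj outputs
INTERMEDIATE = 18944  # assumed MLP intermediate size

# (in_features, out_features) for each module in LORA_TARGET_MODULES
TARGET_SHAPES = {
    "q_proj": (HIDDEN_SIZE, HIDDEN_SIZE),
    "k_proj": (HIDDEN_SIZE, KV_DIM),
    "v_proj": (HIDDEN_SIZE, KV_DIM),
    "o_proj": (HIDDEN_SIZE, HIDDEN_SIZE),
    "gate_proj": (HIDDEN_SIZE, INTERMEDIATE),
    "up_proj": (HIDDEN_SIZE, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN_SIZE),
}

def lora_param_count(in_features: int, out_features: int, r: int) -> int:
    """Parameters added by one LoRA adapter: A is (r x in), B is (out x r)."""
    return r * (in_features + out_features)

per_layer = sum(lora_param_count(i, o, LORA_R) for i, o in TARGET_SHAPES.values())
total_lora_params = per_layer * NUM_LAYERS
scaling = LORA_ALPHA / LORA_R  # factor applied to the LoRA update

print(f"LoRA parameters per layer: {per_layer:,}")
print(f"LoRA parameters total:     {total_lora_params:,}")
print(f"LoRA scaling (alpha / r):  {scaling}")
```

Under these assumptions the total agrees with the 161,480,704 trainable parameters reported by `print_trainable_parameters()`. Note also that `LORA_ALPHA = 16` with `LORA_R = 64` scales the adapter update by 0.25.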


## 7. Dynamic Preprocessing Function

This function randomly selects one of the 5 prompt templates for each sample during preprocessing.

In [9]:
def preprocess_function(examples: Dict) -> Dict:
    """
    Preprocess a batch by applying one of 5 randomly chosen prompt templates.
    
    Workflow:
    1. For each sample, randomly select a template
    2. Format the buggy code using the selected template
    3. Concatenate prompt + newline + fixed code into one causal-LM text
       (no chat roles; the chat-template variant is commented out below)
    4. Tokenize the result
    5. Mask the prompt tokens in the labels with -100 so the loss is
       computed only on the fixed code
    
    Args:
        examples: Batch of examples from dataset
    
    Returns:
        Tokenized inputs with labels, ready for training
    """
    templates = get_prompt_templates()
    
    formatted_texts = []
    prompt_lengths = []
    
    # Process each example in the batch.
    # Earlier chat-template variant, kept for reference but disabled:
    # for buggy_code, fixed_code in zip(examples['buggy'], examples['fixed']):
    #     # Randomly select a template
    #     template = random.choice(templates)
        
    #     # Format the user message with the selected template
    #     user_content = template['user_template'].format(buggy_code=buggy_code)
        
    #     # Create chat messages
    #     messages = [
    #         {"role": "system", "content": template['system']},
    #         {"role": "user", "content": user_content},
    #         {"role": "assistant", "content": fixed_code}
    #     ]
        
    #     # Apply chat template
    #     formatted_text = tokenizer.apply_chat_template(
    #         messages,
    #         tokenize=False,
    #         add_generation_prompt=False
    #     )
        
    #     formatted_texts.append(formatted_text)


    for buggy_code, fixed_code in zip(examples["buggy"], examples["fixed"]):
        template = random.choice(templates)
    
        # Build prompt (NO chat roles)
        prompt = template["user_template"].format(buggy_code=buggy_code)
    
        # Full causal text
        full_text = prompt + "\n" + fixed_code
    
        formatted_texts.append(full_text)
    
        # Track prompt length for label masking
        prompt_ids = tokenizer(
            prompt,
            add_special_tokens=False
        )["input_ids"]
        prompt_lengths.append(len(prompt_ids))

    
    # Tokenize all formatted texts
    tokenized = tokenizer(
        formatted_texts,
        truncation=True,
        max_length=MAX_LENGTH,
        padding=False,  # Will be handled by data collator
        return_tensors=None
    )
    
    # Labels start as a copy of input_ids. The first prompt_len positions are
    # set to -100 so the loss ignores the prompt and covers only the fixed code.
    labels = []

    for input_ids, prompt_len in zip(tokenized["input_ids"], prompt_lengths):
        lbl = input_ids.copy()
        n_masked = min(prompt_len, len(lbl))  # guard against truncated sequences
        lbl[:n_masked] = [-100] * n_masked
        labels.append(lbl)
    
    tokenized["labels"] = labels
    
    return tokenized

print("✅ Preprocessing function defined!")
print("\nThis function will:")
print("  1. Randomly select one of 5 templates for each sample")
print("  2. Format the prompt using the selected template")
print("  3. Concatenate prompt and fixed code (no chat roles)")
print("  4. Tokenize the combined text")
print("  5. Mask prompt tokens in the labels with -100")

✅ Preprocessing function defined!

This function will:
  1. Randomly select one of 5 templates for each sample
  2. Format the prompt using the selected template
  3. Concatenate prompt and fixed code (no chat roles)
  4. Tokenize the combined text
  5. Mask prompt tokens in the labels with -100
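
The label-masking step inside `preprocess_function` can be illustrated in isolation. The sketch below uses made-up token ids instead of a real tokenizer, and `mask_prompt_labels` is a hypothetical helper name. The value -100 is the default `ignore_index` of PyTorch's cross-entropy loss, so masked positions contribute nothing to the loss.

```python
from typing import List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def mask_prompt_labels(input_ids: List[int], prompt_len: int) -> List[int]:
    """Hypothetical helper: copy input_ids and hide the first prompt_len tokens from the loss."""
    labels = list(input_ids)
    n_masked = min(prompt_len, len(labels))  # the sequence may have been truncated
    labels[:n_masked] = [IGNORE_INDEX] * n_masked
    return labels

# Made-up ids: 4 prompt tokens followed by 3 answer tokens
input_ids = [11, 12, 13, 14, 21, 22, 23]
labels = mask_prompt_labels(input_ids, prompt_len=4)
print(labels)  # [-100, -100, -100, -100, 21, 22, 23]

# If truncation cut the sequence inside the prompt, every label is masked
print(mask_prompt_labels([11, 12, 13], prompt_len=5))  # [-100, -100, -100]
```

One caveat worth keeping in mind: the prompt length is measured by tokenizing the prompt on its own, so the mask boundary is exact only when the tokenizer splits the combined text at the same position.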


## 8. Test Preprocessing with Examples

Let's test the preprocessing function to see how different templates are applied.

In [10]:
print("Testing preprocessing with 3 examples...\n")

# Get a small sample
test_sample = dataset.select(range(3))

# Process each sample individually to see different templates
for i in range(3):
    sample = test_sample[i]
    
    # Create a single-item batch
    batch = {
        'buggy': [sample['buggy']],
        'fixed': [sample['fixed']]
    }
    
    # Process
    processed = preprocess_function(batch)
    
    # Decode to see the formatted text
    
    formatted_text = tokenizer.decode(processed['input_ids'][0], skip_special_tokens=False)
    
    print(f"{'='*80}")
    print(f"Example {i+1}")
    print(f"{'='*80}")
    print(f"\nFormatted text (first 500 chars):\n{formatted_text[:500]}...")
    print(f"\nToken count: {len(processed['input_ids'][0])}")
    print()

Testing preprocessing with 3 examples...

Example 1

Formatted text (first 500 chars):
You are a Java expert. Fix the following buggy Java code. Correct syntax errors, logic bugs, runtime errors, and code quality issues. Output ONLY the corrected code.

public java.lang.String METHOD_1 ( ) { return new TYPE_1 ( STRING_1 ) . format ( VAR_1 [ ( ( VAR_1 . length ) - 1 ) ] . getTime ( ) ) ; } 

public java.lang.String METHOD_1 ( ) { return new TYPE_1 ( STRING_1 ) . format ( VAR_1 [ ( ( type ) - 1 ) ] . getTime ( ) ) ; } 
...

Token count: 126

Example 2

Formatted text (first 500 chars):
You are a Java expert. Fix the following buggy Java code. Correct syntax errors, logic bugs, runtime errors, and code quality issues. Output ONLY the corrected code.

public boolean METHOD_1 ( java.lang.String name ) { TYPE_1 VAR_1 = TYPE_1 . METHOD_2 ( VAR_2 ) ; return ( ! ( METHOD_3 ( name ) ) ) && ( VAR_1 . contains ( name ) ) ; } 

public boolean METHOD_1 ( java.lang.String name ) { return ( ! ( METHOD

## 9. Apply Preprocessing to Dataset

Now we'll apply the preprocessing function to the entire dataset. Each sample will get a randomly selected template.

In [11]:
print("Applying preprocessing to dataset...")
print("This will randomly assign templates to each sample.\n")

# Apply preprocessing
tokenized_dataset = dataset.map(
    preprocess_function,
    batched=True,
    batch_size=100,
    remove_columns=dataset.column_names,
    desc="Tokenizing dataset"
)

print(f"\n✅ Preprocessing complete!")
print(f"   Total samples: {len(tokenized_dataset)}")
print(f"   Features: {tokenized_dataset.features}")

# Show token length statistics
token_lengths = [len(x) for x in tokenized_dataset['input_ids']]
print(f"\nToken length statistics:")
print(f"   Min: {min(token_lengths)}")
print(f"   Max: {max(token_lengths)}")
print(f"   Mean: {sum(token_lengths) / len(token_lengths):.1f}")
print(f"   Median: {sorted(token_lengths)[len(token_lengths)//2]}")

Applying preprocessing to dataset...
This will randomly assign templates to each sample.


✅ Preprocessing complete!
   Total samples: 46680
   Features: {'input_ids': List(Value('int32')), 'attention_mask': List(Value('int8')), 'labels': List(Value('int64'))}

Token length statistics:
   Min: 18
   Max: 219
   Mean: 108.4
   Median: 108


## 10. Create Train/Validation Split

In [12]:
print("Creating train/validation split (90/10)...\n")

# Split dataset
split_dataset = tokenized_dataset.train_test_split(
    test_size=0.1,
    seed=SEED
)

train_dataset = split_dataset['train']
eval_dataset = split_dataset['test']

print(f"✅ Split complete!")
print(f"   Training samples: {len(train_dataset)}")
print(f"   Validation samples: {len(eval_dataset)}")
print(f"   Split ratio: {len(train_dataset)/len(tokenized_dataset)*100:.1f}% train, {len(eval_dataset)/len(tokenized_dataset)*100:.1f}% validation")

Creating train/validation split (90/10)...

✅ Split complete!
   Training samples: 42012
   Validation samples: 4668
   Split ratio: 90.0% train, 10.0% validation


## 11. Setup Training Configuration

In [13]:
print("Setting up training configuration...\n")

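# Calculate training steps.
# Note: these counts are micro-batches (forward/backward passes). With
# gradient_accumulation_steps=4 below, the Trainer performs one optimizer
# step per 4 micro-batches, so its own step count is a quarter of these.
num_train_samples = len(train_dataset)
steps_per_epoch = num_train_samples // BATCH_SIZE
total_steps = steps_per_epoch * NUM_EPOCHS
eval_steps = steps_per_epoch // 4  # Printed for reference only; TrainingArguments uses hardcoded values
save_steps = steps_per_epoch // 2  # Printed for reference only; TrainingArguments uses hardcoded values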

print(f"Training configuration:")
print(f"   Total training samples: {num_train_samples}")
print(f"   Steps per epoch: {steps_per_epoch}")
print(f"   Total training steps: {total_steps}")
print(f"   Evaluation steps: {eval_steps}")
print(f"   Save steps: {save_steps}")

# Training arguments
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    warmup_steps=100,
    logging_steps=50,
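    eval_steps=5251,   # Hardcoded: about half of the 10,503 optimizer steps per epoch (42,012 samples / 4 accumulation steps)
    save_steps=10502,  # Must be a multiple of eval_steps because load_best_model_at_end=True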
    eval_strategy="steps",
    save_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    save_total_limit=3,
    fp16=torch.cuda.is_available(),
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    optim="adamw_torch",
    lr_scheduler_type="cosine",
    report_to="none",  # Change to "tensorboard" or "wandb" if desired
    seed=SEED,
)

print(f"\n✅ Training arguments configured!")

Setting up training configuration...

Training configuration:
   Total training samples: 42012
   Steps per epoch: 42012
   Total training steps: 42012
   Evaluation steps: 10503
   Save steps: 21006

✅ Training arguments configured!
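
The step counts printed above are micro-batch counts, while `eval_steps` and `save_steps` in `TrainingArguments` are measured in optimizer steps. The sketch below works through that arithmetic with the values from this notebook. It approximates how the step count is derived and assumes a single training process (`num_devices = 1`), which the notebook does not state.

```python
import math

# Values from the notebook
num_train_samples = 42_012
per_device_batch_size = 1
gradient_accumulation_steps = 4
eval_steps, save_steps = 5251, 10502

# ASSUMPTION: a single training process; with data parallelism this would differ
num_devices = 1

# Samples consumed per optimizer update
effective_batch_size = per_device_batch_size * gradient_accumulation_steps * num_devices

# Forward/backward passes per epoch, then optimizer updates per epoch
micro_batches_per_epoch = math.ceil(num_train_samples / (per_device_batch_size * num_devices))
optimizer_steps_per_epoch = micro_batches_per_epoch // gradient_accumulation_steps

evals_per_epoch = optimizer_steps_per_epoch // eval_steps
saves_per_epoch = optimizer_steps_per_epoch // save_steps

print(f"Effective batch size:      {effective_batch_size}")
print(f"Optimizer steps per epoch: {optimizer_steps_per_epoch}")
print(f"Evaluations per epoch:     {evals_per_epoch}")
print(f"Checkpoints per epoch:     {saves_per_epoch}")
```

So with the hardcoded values, one epoch yields two evaluations and one checkpoint, rather than the four evaluations and two saves suggested by the comments in the step calculation.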


In [14]:
from dataclasses import dataclass
from typing import List, Dict, Any
import torch

@dataclass
class DataCollatorForCausalLMWithPadding:
    tokenizer: Any
    label_pad_token_id: int = -100

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # Extract labels
        labels = [f.pop("labels") for f in features]

        # Pad inputs
        batch = self.tokenizer.pad(
            features,
            padding=True,
            return_tensors="pt"
        )

        # Pad labels manually
        max_len = batch["input_ids"].shape[1]
        padded_labels = [
            lbl + [self.label_pad_token_id] * (max_len - len(lbl))
            for lbl in labels
        ]

        batch["labels"] = torch.tensor(padded_labels, dtype=torch.long)
        return batch
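
To make the collator's behavior concrete, here is a dependency-free sketch of the same padding logic on plain Python lists. It assumes right-side padding, which is also what the label padding in the class above assumes. `pad_batch` is a hypothetical helper, and the pad id 151643 is the `<|endoftext|>` id printed when the tokenizer was loaded.

```python
from typing import Dict, List

def pad_batch(
    features: List[Dict[str, List[int]]],
    pad_token_id: int,
    label_pad_token_id: int = -100,
) -> Dict[str, List[List[int]]]:
    """Hypothetical helper: right-pad input_ids, attention_mask and labels to the longest sample."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch: Dict[str, List[List[int]]] = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_real = len(f["input_ids"])
        n_pad = max_len - n_real
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        batch["attention_mask"].append([1] * n_real + [0] * n_pad)
        # Padded label positions use -100 so they are ignored by the loss
        batch["labels"].append(f["labels"] + [label_pad_token_id] * n_pad)
    return batch

features = [
    {"input_ids": [5, 6, 7, 8], "labels": [-100, -100, 7, 8]},
    {"input_ids": [5, 9], "labels": [-100, 9]},
]
batch = pad_batch(features, pad_token_id=151643)
for key, rows in batch.items():
    print(key, rows)
```

Padding labels with -100 rather than the pad token id matters: otherwise the model would be trained to predict padding tokens.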


## 12. Create Data Collator and Trainer

In [15]:
print("Creating data collator and trainer...\n")
# Use the custom collator defined above, which pads labels with -100.
# DataCollatorForLanguageModeling(mlm=False) is not used here because it
# derives labels from input_ids, which would not preserve the prompt masking
# applied during preprocessing.
data_collator = DataCollatorForCausalLMWithPadding(
    tokenizer=tokenizer
)

# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

print("✅ Trainer created successfully!")
print("\nTraining workflow:")
print("  1. Trainer pulls a batch from train_dataset")
print("  2. Data collator pads the batch")
print("  3. Model processes the batch (forward pass)")
print("  4. Loss is calculated")
print("  5. Gradients are computed (backward pass)")
print("  6. Optimizer updates LoRA parameters")
print("  7. Repeat for all batches and epochs")

The model is already on multiple devices. Skipping the move to device specified in `args`.


Creating data collator and trainer...

✅ Trainer created successfully!

Training workflow:
  1. Trainer pulls a batch from train_dataset
  2. Data collator pads the batch
  3. Model processes the batch (forward pass)
  4. Loss is calculated
  5. Gradients are computed (backward pass)
  6. Optimizer updates LoRA parameters
  7. Repeat for all batches and epochs


## 13. Start Training

**Note**: This will take several hours depending on your hardware. Training runs for `NUM_EPOCHS` epochs (1 in the configuration above).

In [None]:
print("="*80)
print("STARTING TRAINING")
print("="*80)
print(f"\nModel: {MODEL_NAME}")
print(f"Dataset: {DATASET_NAME} ({DATASET_SUBSET} subset, {DATASET_SPLIT} split)")
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(eval_dataset)}")
print(f"Epochs: {NUM_EPOCHS}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Learning rate: {LEARNING_RATE}")
print(f"\nDynamic prompts: 5 templates randomly selected per sample")
print(f"\nEstimated time: Several hours (depends on hardware)")
print(f"\n{'='*80}\n")

# Start training
train_result = trainer.train()

print("\n" + "="*80)
print("TRAINING COMPLETE!")
print("="*80)
print(f"\nTraining metrics:")
print(f"   Final train loss: {train_result.training_loss:.4f}")
print(f"   Total steps: {train_result.global_step}")
print(f"   Training time: {train_result.metrics['train_runtime']:.2f} seconds")
print(f"   Samples per second: {train_result.metrics['train_samples_per_second']:.2f}")

STARTING TRAINING

Model: Haiintel/HaiJava-Surgeon-Qwen2.5-Coder-7B-SFT-v1
Dataset: code_x_glue_cc_code_refinement (small subset, train split)
Training samples: 42012
Validation samples: 4668
Epochs: 1
Batch size: 1
Learning rate: 0.0002

Dynamic prompts: 5 templates randomly selected per sample

Estimated time: Several hours (depends on hardware)




You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


## 14. Evaluate on Validation Set

In [None]:
print("Evaluating on validation set...\n")

eval_results = trainer.evaluate()

print("✅ Evaluation complete!")
print(f"\nValidation metrics:")
for key, value in eval_results.items():
    print(f"   {key}: {value:.4f}")

## 15. Save the Fine-tuned Model

In [None]:
print("Saving fine-tuned model...\n")

# Save the final model
final_model_path = f"{OUTPUT_DIR}/final"
trainer.save_model(final_model_path)
tokenizer.save_pretrained(final_model_path)

print(f"✅ Model saved to: {final_model_path}")
print(f"\nSaved files:")
import os
for file in os.listdir(final_model_path):
    file_path = os.path.join(final_model_path, file)
    if os.path.isfile(file_path):
        size_mb = os.path.getsize(file_path) / (1024 * 1024)
        print(f"   {file}: {size_mb:.2f} MB")

## 16. Test the Fine-tuned Model

Let's test the model with a sample buggy code to see how it performs.

In [None]:
print("Testing the fine-tuned model...\n")

# Get a test sample from the dataset
test_sample = dataset[100]  # Drawn from the train split, so it may have been seen during training
buggy_code = test_sample['buggy']
expected_fixed = test_sample['fixed']

print("="*80)
print("TEST SAMPLE")
print("="*80)
print(f"\n🐛 Buggy Code:\n{buggy_code}")
print(f"\n✅ Expected Fixed Code:\n{expected_fixed}")

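# Create a prompt using Template 1 (detailed expert).
# Note: this prompt is wrapped in the chat template, while preprocess_function
# trains on plain prompt text without chat roles.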
messages = [
    {"role": "system", "content": "You are a Java code fixing assistant."},
    {"role": "user", "content": f"You are a Java expert. Fix the following buggy Java code. Correct syntax errors, logic bugs, runtime errors, and code quality issues. Output ONLY the corrected code.\n\n{buggy_code}"}
]

# Format with chat template
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

# Tokenize
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Generate
print("\n🤖 Generating fix...")
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id
    )

# Decode only the newly generated tokens (everything after the prompt).
# Slicing by prompt length is more robust than splitting on the word
# "assistant", which also appears in the system prompt and may appear in code.
prompt_len = inputs["input_ids"].shape[1]
model_output = tokenizer.decode(
    outputs[0][prompt_len:],
    skip_special_tokens=True
).strip()

print(f"\n🔧 Model's Fixed Code:\n{model_output}")
print("\n" + "="*80)
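
Printing the model output next to the expected fix is a manual check. One simple way to turn it into a number is whitespace-normalized exact match, sketched below. This metric and the helper names (`normalize_code`, `exact_match`, `exact_match_rate`) are illustrative assumptions, not something the notebook computes.

```python
from typing import List

def normalize_code(code: str) -> str:
    """Collapse all whitespace runs so formatting differences do not count as errors."""
    return " ".join(code.split())

def exact_match(prediction: str, reference: str) -> bool:
    """True when prediction and reference are identical after whitespace normalization."""
    return normalize_code(prediction) == normalize_code(reference)

def exact_match_rate(predictions: List[str], references: List[str]) -> float:
    """Fraction of predictions that exactly match their reference."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    if not predictions:
        return 0.0
    hits = sum(exact_match(p, r) for p, r in zip(predictions, references))
    return hits / len(predictions)

# Toy strings in the dataset's space-separated token style
references = [
    "public int METHOD_1 ( ) { return VAR_1 ; }",
    "public void METHOD_2 ( ) { VAR_2 . clear ( ) ; }",
]
predictions = [
    "public int METHOD_1 ( ) {  return VAR_1 ; }\n",   # differs only in whitespace
    "public void METHOD_2 ( ) { VAR_2 . reset ( ) ; }",  # wrong method call
]
print(f"Exact match rate: {exact_match_rate(predictions, references):.2f}")
```

Exact match is strict: a semantically correct fix that differs by one token counts as a miss, so it is best read as a lower bound on fix quality.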

## 17. Test with Multiple Templates

Let's test the same buggy code with different prompt templates to see how the model responds.

In [None]:
print("Testing with all 5 prompt templates...\n")

templates = get_prompt_templates()
test_sample = dataset[150]
buggy_code = test_sample['buggy']

print("="*80)
print(f"Buggy Code:\n{buggy_code}")
print("="*80)

for i, template in enumerate(templates, 1):
    print(f"\n{'='*80}")
    print(f"Template {i}: {template['name'].upper()}")
    print(f"{'='*80}")
    
    # Create messages
    user_content = template['user_template'].format(buggy_code=buggy_code)
    messages = [
        {"role": "system", "content": template['system']},
        {"role": "user", "content": user_content}
    ]
    
    # Format and tokenize
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    # Generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    
    # Decode only the newly generated tokens (everything after the prompt)
    prompt_len = inputs["input_ids"].shape[1]
    model_output = tokenizer.decode(
        outputs[0][prompt_len:],
        skip_special_tokens=True
    ).strip()
    
    print(f"\nModel Output (first 300 chars):\n{model_output[:300]}...")

## 18. Summary and Next Steps

In [None]:
print("="*80)
print("TRAINING SUMMARY")
print("="*80)
print(f"\n✅ Model: {MODEL_NAME}")
print(f"✅ Dataset: {DATASET_NAME} ({DATASET_SUBSET} subset, {DATASET_SPLIT} split)")
print(f"✅ Training samples: {len(train_dataset)}")
print(f"✅ Validation samples: {len(eval_dataset)}")
print(f"✅ Epochs completed: {NUM_EPOCHS}")
print(f"✅ Dynamic prompt templates: 5 templates randomly selected")
print(f"✅ Model saved to: {final_model_path}")

print(f"\n{'='*80}")
print("NEXT STEPS")
print("="*80)
print("""
1. **Evaluate on Test Set**: Load the test split and evaluate model performance

2. **Merge LoRA Adapters**: Merge the LoRA adapters with the base model for deployment
   ```python
   from peft import PeftModel
   base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
   model = PeftModel.from_pretrained(base_model, final_model_path)
   merged_model = model.merge_and_unload()
   merged_model.save_pretrained("./models/haijava_merged")
   ```

3. **Upload to Hugging Face Hub**: Share your fine-tuned model
   ```python
   merged_model.push_to_hub("your-username/haijava-surgeon-finetuned")
   tokenizer.push_to_hub("your-username/haijava-surgeon-finetuned")
   ```

4. **Run Comprehensive Evaluation**: Test on various Java bug-fixing benchmarks

5. **Analyze Template Performance**: Check which templates led to better results
""")

print("="*80)
print("TRAINING COMPLETE! 🎉")
print("="*80)