# Quick Start Guide: Federated Learning for Java Error Classification

This notebook provides a quick introduction to training and evaluating models for Java code error classification using both centralized and federated learning approaches.

## 1. Environment Setup

```python
# Check GPU availability
import torch

print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU count: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    print(f"Current GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
```
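The GPU check reports total memory. To judge whether that is enough, a rough estimate of the memory needed just to hold the model weights helps. The sketch below assumes roughly 4 billion parameters for Qwen3-4B-Base (the model loaded in section 7) and counts weights only. Activations, the KV cache during generation, and optimizer state during training all come on top. The helper name is hypothetical.

```python
def estimate_weight_memory_gb(num_params: float, bytes_per_param: float) -> float:
    """Memory for model weights only, in decimal gigabytes (1e9 bytes), like the cell above."""
    return num_params * bytes_per_param / 1e9


# Assumption: about 4 billion parameters for a "4B" model; the exact count differs slightly.
NUM_PARAMS = 4e9

# fp32 uses 4 bytes per parameter; bf16 (used when loading the model in section 7) uses 2.
for label, nbytes in [("fp32", 4), ("bf16", 2)]:
    print(f"{label}: ~{estimate_weight_memory_gb(NUM_PARAMS, nbytes):.1f} GB for weights")
```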

## 2. Data Exploration

```python
import json

# Load training data (JSON Lines format: one JSON object per line; blank lines are skipped)
with open('../data/train_data.json', 'r', encoding='utf-8') as f:
    train_data = [json.loads(line) for line in f if line.strip()]

print(f"Training samples: {len(train_data)}")
print("\nSample structure:")
print(json.dumps(train_data[0], indent=2)[:500] + "...")
```

```python
# Analyze error types
import re
from collections import Counter

error_types = []
for sample in train_data:
    feedback = sample['feedback']
    # Extract every mention of an error type (Syntax/Runtime/Logical);
    # a sample that mentions the same type twice is counted twice
    types = re.findall(r'(Syntax Error|Runtime Error|Logical Error)', feedback)
    error_types.extend(types)

type_counts = Counter(error_types)
print(f"Error type distribution ({len(error_types)} mentions):")
for error_type, count in type_counts.most_common():
    print(f"  {error_type}: {count} ({count/len(error_types)*100:.1f}%)")
```
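The cell above counts mentions, so one sample with two syntax errors adds two to the syntax count. For class balance it is often more useful to know how many samples contain each type, and which types co-occur. A minimal sketch using the same regex; the helper names and the three demo feedback strings are made up for illustration, and on the real data you would pass `train_data`.

```python
import re
from collections import Counter

ERROR_PATTERN = re.compile(r"(Syntax Error|Runtime Error|Logical Error)")


def samples_per_error_type(samples):
    """Count how many samples mention each error type at least once."""
    counts = Counter()
    for sample in samples:
        counts.update(set(ERROR_PATTERN.findall(sample.get("feedback", ""))))
    return counts


def error_type_combinations(samples):
    """Count each distinct set of error types per sample (sorted tuple as key)."""
    combos = Counter()
    for sample in samples:
        found = tuple(sorted(set(ERROR_PATTERN.findall(sample.get("feedback", "")))))
        combos[found or ("(none)",)] += 1
    return combos


# Hypothetical demo samples; replace with train_data
demo = [
    {"feedback": "Syntax Error: missing ';'. Syntax Error: missing '}'."},
    {"feedback": "Logical Error: wrong loop bound."},
    {"feedback": "Syntax Error: '=' used instead of '=='. Runtime Error: division by zero."},
]
per_sample = samples_per_error_type(demo)
combos = error_type_combinations(demo)
print(per_sample)
print(combos)
```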

## 3. Data Preparation for Federated Learning

```python
# Prepare data for federated clients
!python ../src/utils/data_preparation.py
```
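The notebook does not show how `data_preparation.py` splits the data across clients. Purely as an illustration, a common way to simulate non-IID clients is a Dirichlet label-skew split: for each label, the share each client receives is drawn from Dirichlet(alpha), so a small alpha gives every client a skewed mix and a large alpha approaches an even split. The function name, the use of a single "primary error type" per sample as the label, and the toy labels are all assumptions, not the project's method.

```python
import numpy as np


def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Split sample indices across clients; smaller alpha means stronger label skew."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(num_clients)]
    for label in np.unique(labels):
        idx = np.flatnonzero(labels == label)
        rng.shuffle(idx)
        # Fraction of this label's samples that each client receives
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        # Cut points that split idx according to those fractions
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client, part in enumerate(np.split(idx, cuts)):
            client_indices[client].extend(part.tolist())
    return client_indices


# Hypothetical labels: one primary error type per sample
labels = ["Syntax Error"] * 50 + ["Runtime Error"] * 30 + ["Logical Error"] * 20
parts = dirichlet_partition(labels, num_clients=4, alpha=0.5)
print([len(p) for p in parts])
```

Every sample index is assigned to exactly one client; only the per-client label mix changes with alpha.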

## 4. Model Training

### Option A: Centralized Training (Baseline)

```python
# Centralized fine-tuning with Unsloth, which is typically faster and uses
# less GPU memory than a plain Hugging Face training loop
!cd .. && python src/training/centralized_unsloth.py
```
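The inference cell in section 7 loads the trained weights with `PeftModel.from_pretrained`, so training appears to produce a parameter-efficient adapter rather than full model weights. Assuming that adapter is LoRA, each adapted weight matrix of shape d_out x d_in gains two small factors, B (d_out x r) and A (r x d_in), i.e. r(d_out + d_in) trainable parameters. This is also the size of what a client would send per matrix if a federated setup exchanges only adapter weights (an assumption here). The width 2560 and rank 16 below are hypothetical values, not read from the Qwen3-4B config or the training scripts.

```python
def lora_param_count(d_out: int, d_in: int, rank: int) -> int:
    """Trainable parameters LoRA adds beside a frozen d_out x d_in weight: B (d_out x r) + A (r x d_in)."""
    return rank * (d_out + d_in)


# Hypothetical square projection of width 2560, LoRA rank 16
full = 2560 * 2560
lora = lora_param_count(2560, 2560, 16)
print(f"full: {full:,}  lora: {lora:,}  ratio: {lora / full:.2%}")
```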

### Option B: Federated Learning

```python
# Train with FedAdam
!cd .. && bash scripts/run_federated_fedadam.sh --mode simulation
```
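FedAdam replaces FedAvg's plain averaging on the server with an Adam-style update: the weighted average of client deltas (client model minus current global model) is treated as a pseudo-gradient, and running first- and second-moment estimates rescale the step, as in Reddi et al., "Adaptive Federated Optimization". The sketch below works on one flat NumPy vector. The class name, weighting clients by dataset size, and the hyperparameter values are assumptions, not necessarily what `run_federated_fedadam.sh` uses.

```python
import numpy as np


class FedAdamServer:
    """Minimal FedAdam server: Adam-style update driven by averaged client deltas."""

    def __init__(self, init_params, lr=0.01, beta1=0.9, beta2=0.99, tau=1e-3):
        self.params = np.asarray(init_params, dtype=np.float64).copy()
        self.m = np.zeros_like(self.params)  # first-moment estimate
        self.v = np.zeros_like(self.params)  # second-moment estimate
        self.lr, self.beta1, self.beta2, self.tau = lr, beta1, beta2, tau

    def step(self, client_params, client_sizes):
        weights = np.asarray(client_sizes, dtype=np.float64)
        weights /= weights.sum()
        # Pseudo-gradient: size-weighted average of (client model - global model)
        delta = sum(w * (np.asarray(p) - self.params) for w, p in zip(weights, client_params))
        self.m = self.beta1 * self.m + (1 - self.beta1) * delta
        self.v = self.beta2 * self.v + (1 - self.beta2) * delta ** 2
        # tau keeps the step finite where v is (near) zero
        self.params = self.params + self.lr * self.m / (np.sqrt(self.v) + self.tau)
        return self.params


server = FedAdamServer(np.zeros(3), lr=0.1)
new_params = server.step(
    [np.array([1.0, 0.0, -1.0]), np.array([3.0, 0.0, -1.0])],
    client_sizes=[1, 1],
)
print(new_params)
```

Replacing the last line of `step` with `self.params = self.params + delta` (no moments) recovers FedAvg, which is why the two differ only on the server side.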

## 5. Model Evaluation

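```python
# Set your OpenAI API key (needed by the GPT-based evaluation script)
import os
from getpass import getpass

# Prompt for the key instead of hard-coding it in the notebook;
# shell commands run with `!` inherit this environment variable
if not os.environ.get('OPENAI_API_KEY'):
    os.environ['OPENAI_API_KEY'] = getpass('OpenAI API key: ')
```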

```python
# Evaluate trained models
!cd .. && python src/evaluation/evaluate_with_gpt.py
```
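The methods compared in section 6 differ by only a few tenths of a point on a 10-point scale, so it is worth checking whether a gap is larger than resampling noise. Assuming you can export per-item scores for two methods on the same test items (the evaluation script's output format is not shown here), a paired bootstrap gives a confidence interval for the mean score difference. The function name and the demo scores are hypothetical.

```python
import numpy as np


def paired_bootstrap_ci(scores_a, scores_b, n_resamples=10_000, confidence=0.95, seed=0):
    """Bootstrap CI for mean(scores_a - scores_b) over the same evaluation items."""
    a = np.asarray(scores_a, dtype=float)
    b = np.asarray(scores_b, dtype=float)
    if a.shape != b.shape:
        raise ValueError("scores must be paired: same items in the same order")
    diffs = a - b
    rng = np.random.default_rng(seed)
    # Resample item indices with replacement and recompute the mean difference each time
    idx = rng.integers(0, len(diffs), size=(n_resamples, len(diffs)))
    boot_means = diffs[idx].mean(axis=1)
    tail = (1 - confidence) / 2
    low, high = np.quantile(boot_means, [tail, 1 - tail])
    return diffs.mean(), (low, high)


# Made-up per-item scores for two methods on the same five items
mean_diff, (low, high) = paired_bootstrap_ci([9, 8, 10, 7, 9], [8, 8, 9, 7, 8])
print(f"mean difference {mean_diff:.2f}, 95% CI [{low:.2f}, {high:.2f}]")
```

If the interval excludes 0, the difference is unlikely to be an artifact of which items happened to be in the test set; with few items the interval will be wide.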

## 6. Results Analysis

```python
import matplotlib.pyplot as plt

# Example results (replace with your actual results): overall score out of 10
results = {
    'Few-shot': 8.47,
    'Centralized (HF)': 8.89,
    'Centralized (Unsloth)': 8.95,
    'FedAvg': 8.76,
    'FedProx': 8.81,
    'FedAdam': 8.92
}

# Plot comparison
plt.figure(figsize=(12, 6))
methods = list(results.keys())
scores = list(results.values())
# Gray: few-shot baseline, blue: centralized, red: federated
colors = ['gray' if 'Few' in m else 'skyblue' if 'Central' in m else 'lightcoral' for m in methods]

bars = plt.bar(methods, scores, color=colors, alpha=0.8)
plt.ylabel('Overall Score (out of 10)', fontsize=12)
plt.title('Model Performance Comparison', fontsize=14, fontweight='bold')
# The y-axis starts at 8.0, which visually magnifies small score differences
plt.ylim(8.0, 9.2)
plt.xticks(rotation=45, ha='right')
plt.grid(axis='y', alpha=0.3)

# Add value labels on bars
for bar, score in zip(bars, scores):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
             f'{score:.2f}', ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.savefig('../results_comparison.png', dpi=300, bbox_inches='tight')
plt.show()
```
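The bar chart shows scores but not the gaps between methods. A short helper can compute each method's absolute and relative gap to the best centralized run, so statements about performance loss can be checked against numbers. It reuses the example scores from the plotting cell; the helper name is hypothetical.

```python
# Example scores copied from the plotting cell (replace with your actual results)
results = {
    'Few-shot': 8.47,
    'Centralized (HF)': 8.89,
    'Centralized (Unsloth)': 8.95,
    'FedAvg': 8.76,
    'FedProx': 8.81,
    'FedAdam': 8.92,
}


def gaps_to_best_centralized(results):
    """Absolute and relative score gap of every method to the best centralized run."""
    best = max(score for method, score in results.items() if method.startswith('Centralized'))
    return {method: (best - score, (best - score) / best) for method, score in results.items()}


for method, (gap, rel) in gaps_to_best_centralized(results).items():
    print(f"{method:<22} {results[method]:.2f}  gap {gap:.2f} ({rel:.1%})")
```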

## 7. Inference Example

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Load the base model, then attach the trained adapter (here: the FedAdam run)
model_path = "../java_error_federated_results/fedadam/final_model"
base_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-4B-Base",
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
model = PeftModel.from_pretrained(base_model, model_path)
# The tokenizer is loaded from the same directory as the adapter
tokenizer = AutoTokenizer.from_pretrained(model_path)

print("✓ Model loaded successfully")
```

```python
# Test on example
test_input = """Java code requirement: Return the largest prime factor of n.
student code: 
import java.util.*;
class Solution {
    public int largestPrimeFactor(int n) {
        int largest = 1;
        for (int j = 2; j <= n; j++) {
            if (n % j = 0) {  // Missing second '='
                largest = j;
            }
        }
        return largest;
    }
}
"""

messages = [
    {"role": "system", "content": "Analyze the student's Java code and identify all errors."},
    {"role": "user", "content": test_input}
]

# Generate prediction (sampling; set do_sample=False for deterministic output)
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)

outputs = model.generate(
    **inputs,
    max_new_tokens=512,
    temperature=0.7,
    do_sample=True
)

# Decode only the newly generated tokens, not the prompt
prediction = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
print("\nModel Prediction:")
print(prediction)
```
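The test snippet contains more than the marked syntax error: once `=` is fixed to `==`, the loop still runs up to `n`, and `j = n` always divides `n`, so the method returns `n` itself (its largest divisor) for any `n >= 2` rather than the largest prime factor. A complete answer should therefore also report a logical error. To score responses like this across many samples, the regex from section 2 can extract error-type labels from both the reference feedback and the model output and compare them per type. This assumes the model emits the same "Syntax Error" / "Runtime Error" / "Logical Error" labels as the training feedback; the helper names and demo strings are hypothetical.

```python
import re

ERROR_PATTERN = re.compile(r"(Syntax Error|Runtime Error|Logical Error)")
ERROR_TYPES = ("Syntax Error", "Runtime Error", "Logical Error")


def extract_error_types(text):
    """Set of error-type labels mentioned in a feedback or prediction string."""
    return set(ERROR_PATTERN.findall(text))


def per_type_precision_recall(references, predictions):
    """Per-type (precision, recall); NaN where the denominator is zero."""
    stats = {}
    for etype in ERROR_TYPES:
        tp = fp = fn = 0
        for ref, pred in zip(references, predictions):
            in_ref = etype in extract_error_types(ref)
            in_pred = etype in extract_error_types(pred)
            tp += int(in_ref and in_pred)
            fp += int(in_pred and not in_ref)
            fn += int(in_ref and not in_pred)
        precision = tp / (tp + fp) if tp + fp else float("nan")
        recall = tp / (tp + fn) if tp + fn else float("nan")
        stats[etype] = (precision, recall)
    return stats


# Hypothetical reference feedback and model outputs
refs = ["Syntax Error: '=' should be '=='.", "Logical Error: returns n, not its largest prime factor."]
preds = ["Syntax Error: use '==' for comparison.", "Syntax Error: missing semicolon."]
stats = per_type_precision_recall(refs, preds)
print(stats)
```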

## 8. Summary

In this notebook, we:
1. ✅ Explored the Java error classification dataset
2. ✅ Prepared data for federated learning
3. ✅ Trained models using centralized and federated approaches
4. ✅ Evaluated model performance with GPT-based scoring
5. ✅ Visualized and compared results
6. ✅ Tested inference on new examples

### Key Takeaways
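- Federated learning performs comparably to centralized training: in the example results, FedAdam (8.92) is within 0.03 points of the best centralized run (Unsloth, 8.95)
- FedAdam scores highest among the federated algorithms, ahead of FedProx (8.81) and FedAvg (8.76)
- All fine-tuned models outperform the few-shot baseline (8.47), by 0.29 to 0.48 points
- Raw training data stays on the clients at a cost of at most 0.19 points (about 2.1%, FedAvg vs. centralized Unsloth)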
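FedProx, one of the federated algorithms compared above, differs from FedAvg only on the client side: each client minimizes its task loss plus a proximal term (mu/2)·||w - w_global||², whose gradient mu·(w - w_global) pulls local training back toward the current global model. With mu = 0 this reduces to FedAvg's local training. The sketch shows the resulting gradient steps on a toy quadratic loss 0.5·||w - target||²; the function name, mu, lr, and step count are illustrative, not values from this project's scripts.

```python
import numpy as np


def fedprox_local_update(w_global, grad_fn, mu=0.1, lr=0.1, steps=200):
    """Local gradient descent on task loss + (mu/2) * ||w - w_global||^2."""
    w_global = np.asarray(w_global, dtype=float)
    w = w_global.copy()
    for _ in range(steps):
        # Task gradient plus gradient of the proximal term
        w = w - lr * (grad_fn(w) + mu * (w - w_global))
    return w


# Toy client loss 0.5 * ||w - 1||^2, whose gradient is (w - 1); global model at 0
target_grad = lambda w: w - 1.0
w_prox = fedprox_local_update(np.zeros(1), target_grad, mu=1.0)
w_avg = fedprox_local_update(np.zeros(1), target_grad, mu=0.0)
# With mu=1 the local optimum is (target + mu * w_global) / (1 + mu) = 0.5;
# with mu=0 (FedAvg-style local training) it is the client's own optimum, 1.0
print(w_prox, w_avg)
```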

### Next Steps
- Experiment with different hyperparameters
- Try larger models (Qwen3-8B)
- Test on your own code examples
- Deploy the best model for production use