# Phase 4: Korean Medical Instruction Tuning

Fine-tune the Korean-adapted model from Phase 3 on Korean medical instruction data, using 4-bit quantization with LoRA adapters (QLoRA).

## Purpose
- Train the model on data in Korean medical QA format
- Enable instruction-following capabilities
- Use KorMedMCQA and other instruction data

## Contents
1. Setup and Configuration
2. Load Model from Phase 3
3. Load Instruction Data
4. Training
5. Test Generation
6. Save Instruction-Tuned Model

In [None]:
# Setup
import sys
import os
sys.path.append("..")

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
from trl import SFTTrainer
from datasets import load_from_disk
import json

# GPU setup
# These helpers live in the project-local `config` package, which is importable
# because ".." was appended to sys.path earlier in this cell.
from config.gpu_utils import setup_gpu, print_memory_usage, clear_memory
# NOTE(review): presumably selects/initializes the compute device and returns
# it -- confirm against config/gpu_utils.py.
device = setup_gpu()

# Presumably reports current memory usage as a baseline before the model is
# loaded -- TODO confirm what exactly is printed.
print_memory_usage()

In [None]:
# Input and output locations for this phase.
#
# The default base model is the expanded model from Phase 3 Stage 7
# (hybrid: identity layers + QLoRA). Other bases that can be swapped in:
#   "../models/staged_training/stage7_cooldown"  - Stage 7 cooldown checkpoint
#   "../models/final/korean_medgemma"            - legacy, non-expanded model
BASE_MODEL_DIR = "../models/final/korean_medgemma_expanded"
DATA_DIR = "../data/processed"
OUTPUT_DIR = "../models/instruction_tuned"

# Create the output directory up front so later save steps cannot fail on it.
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"Base model: {BASE_MODEL_DIR}", f"Output dir: {OUTPUT_DIR}", sep="\n")
print("\nNote: Using hybrid expanded model with +2 identity layers")

---
## 1. Configuration

In [None]:
# Hyperparameters for the instruction-tuning run.
# Optimizer/schedule settings come first, followed by the LoRA adapter settings.
CONFIG = dict(
    learning_rate=2e-5,
    num_epochs=3,
    batch_size=1,
    gradient_accumulation_steps=8,
    max_seq_length=2048,
    warmup_ratio=0.03,
    # LoRA adapter settings
    lora_r=32,
    lora_alpha=64,
    lora_dropout=0.05,
)

# Echo the configuration so it is recorded in the notebook output.
print("Instruction Tuning Configuration:")
for setting_name in CONFIG:
    print(f"  {setting_name}: {CONFIG[setting_name]}")

---
## 2. Load Model

In [None]:
# Quantization settings used when loading the base model:
# 4-bit NF4 weights, bfloat16 compute dtype, double quantization enabled.
quantization_options = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": torch.bfloat16,
    "bnb_4bit_use_double_quant": True,
}
bnb_config = BitsAndBytesConfig(**quantization_options)

print("BitsAndBytes config created")

In [None]:
# Load the quantized base model and its tokenizer from BASE_MODEL_DIR
print("\nLoading Korean MedGemma model...")

model_load_options = {
    "quantization_config": bnb_config,
    "device_map": "auto",
    "trust_remote_code": True,
}
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_DIR, **model_load_options)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_DIR)

# Batching needs a padding token; reuse EOS when the tokenizer defines none
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Model loaded!")
print(f"Vocab size: {len(tokenizer)}")
print_memory_usage()

In [None]:
# Prepare for k-bit training
# Runs PEFT's preparation helper on the 4-bit model loaded above, before the
# LoRA adapters are attached in the next cell.
# NOTE(review): the exact adjustments (freezing, dtype casts, checkpointing
# hooks) depend on the installed PEFT version -- see the PEFT documentation.
model = prepare_model_for_kbit_training(model)
print("Model prepared for k-bit training")

In [None]:
# Attach LoRA adapters to the attention and MLP projection layers
attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
mlp_projections = ["gate_proj", "up_proj", "down_proj"]

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=CONFIG["lora_r"],
    lora_alpha=CONFIG["lora_alpha"],
    lora_dropout=CONFIG["lora_dropout"],
    target_modules=attention_projections + mlp_projections,
    bias="none",
)

# Wrap the base model; from here on `model` is the PEFT-wrapped model
model = get_peft_model(model, lora_config)

print("\nLoRA applied for instruction tuning")
model.print_trainable_parameters()

---
## 3. Load Instruction Data

In [None]:
# Load the instruction dataset prepared by the Phase 0 notebooks.
instruction_data_path = f"{DATA_DIR}/korean_medical_instruction"

# Fail fast when the data is missing. Merely printing a message would leave
# `dataset` undefined (or stale from an earlier notebook run), and every later
# cell would then fail with an unrelated error or silently train on old data.
if not os.path.exists(instruction_data_path):
    raise FileNotFoundError(
        f"Dataset not found at {instruction_data_path}. "
        "Run Phase 0 notebooks to prepare data."
    )

dataset = load_from_disk(instruction_data_path)
print(f"Loaded instruction dataset: {dataset}")

In [None]:
# Show the first training example, truncated to its first 1000 characters
print("Sample instruction:")
first_example = dataset["train"][0]
print(first_example["text"][:1000])

---
## 4. Training

In [None]:
# Training arguments
#
# Per-epoch evaluation and best-checkpoint reloading need an evaluation
# dataset. The trainer cell passes eval_dataset=None when the dataset has no
# "validation" split, so both options are enabled only when that split exists.
# Otherwise the run would fail for lack of evaluation data.
has_validation = "validation" in dataset

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=CONFIG["num_epochs"],
    per_device_train_batch_size=CONFIG["batch_size"],
    per_device_eval_batch_size=CONFIG["batch_size"],
    gradient_accumulation_steps=CONFIG["gradient_accumulation_steps"],
    learning_rate=CONFIG["learning_rate"],
    warmup_ratio=CONFIG["warmup_ratio"],
    lr_scheduler_type="cosine",
    bf16=True,
    logging_steps=10,
    save_strategy="epoch",
    # Must match save_strategy whenever load_best_model_at_end is enabled
    evaluation_strategy="epoch" if has_validation else "no",
    save_total_limit=2,
    load_best_model_at_end=has_validation,
    optim="paged_adamw_8bit",
    max_grad_norm=0.3,
    report_to="tensorboard",
    gradient_checkpointing=True,
)

print("Training arguments configured")

In [None]:
# Build the supervised fine-tuning trainer.
# The validation split is optional; without it no eval dataset is passed.
if "validation" in dataset:
    eval_split = dataset["validation"]
else:
    eval_split = None

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=eval_split,
    dataset_text_field="text",
    max_seq_length=CONFIG["max_seq_length"],
)

print("SFT Trainer created")

In [None]:
# Run the fine-tuning loop, reporting memory usage before and after
banner_rule = "=" * 60
print("\n" + banner_rule)
print("Starting Instruction Tuning")
print(banner_rule)
print_memory_usage()

trainer.train()

print("\nTraining complete!")
print_memory_usage()

---
## 5. Test Generation

In [None]:
# Smoke-test the instruction-tuned model on two Korean medical questions.
# The prompts are written out in the <|im_start|>/<|im_end|> chat layout and
# end with an open assistant turn for the model to complete.
test_prompts = [
    """<|im_start|>system
당신은 한국어 의료 전문 AI 어시스턴트입니다.
<|im_end|>
<|im_start|>user
고혈압의 주요 증상과 치료법에 대해 설명해주세요.
<|im_end|>
<|im_start|>assistant
""",
    """<|im_start|>system
당신은 한국어 의료 전문 AI 어시스턴트입니다.
<|im_end|>
<|im_start|>user
당뇨병 환자가 주의해야 할 식이요법은 무엇인가요?
<|im_end|>
<|im_start|>assistant
""",
]

print("Testing instruction-tuned model:")
print("=" * 60)

# trainer.train() leaves the model in training mode, where the LoRA dropout
# configured above stays active. Switch to eval mode so the samples reflect
# the model as it will behave at inference time.
model.eval()

for i, prompt in enumerate(test_prompts):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # No gradients are needed for generation
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
        )

    # Keep special tokens so the chat markers are visible in the output
    response = tokenizer.decode(outputs[0], skip_special_tokens=False)

    print(f"\n--- Test {i+1} ---")
    print(response)
    print()

---
## 6. Save Model

In [None]:
# Persist the trained model first, then the tokenizer, into OUTPUT_DIR
for save_to_dir in (trainer.save_model, tokenizer.save_pretrained):
    save_to_dir(OUTPUT_DIR)

print(f"\nInstruction-tuned model saved to {OUTPUT_DIR}")

In [None]:
# Record run metadata next to the saved model
if "validation" in dataset:
    eval_sample_count = len(dataset["validation"])
else:
    eval_sample_count = 0

training_info = {
    "phase": "instruction_tuning",
    "base_model": BASE_MODEL_DIR,
    "config": CONFIG,
    "train_samples": len(dataset["train"]),
    "eval_samples": eval_sample_count,
}

# Serialize to text first, then write it out in a single call
training_info_json = json.dumps(training_info, indent=2)
with open(f"{OUTPUT_DIR}/training_info.json", "w") as info_file:
    info_file.write(training_info_json)

print("Training info saved")

In [None]:
# Closing banner plus pointers to the evaluation notebooks
closing_rule = "=" * 60
print("\n" + closing_rule)
print("Phase 4 Complete: Instruction Tuning Done!")
print(closing_rule)
print(f"\nModel saved to: {OUTPUT_DIR}")

print("\nNext steps:")
next_notebooks = (
    "phase5_evaluation/01_evaluate_korean.ipynb",
    "phase5_evaluation/02_evaluate_english.ipynb",
)
for step_number, notebook_path in enumerate(next_notebooks, start=1):
    print(f"  {step_number}. Run {notebook_path}")