# Phase 3.7: Stage 7 - Cooldown (LoRA Only)

EEVE Stage 7: Stabilization with LoRA-only training at lower learning rate.

## Purpose
- Stabilize learned representations
- Fine-tune internal layers while keeping embeddings fixed
- Prevent overfitting with lower learning rate

## Contents
1. Stage Configuration
2. Load Model from Stage 6
3. Configure Training (LoRA Only)
4. Load Training Data
5. Training
6. Save Checkpoint
7. Merge LoRA and Save Final Model

In [None]:
# Setup
import sys
import os
sys.path.append("..")

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig,
    TrainerCallback,
)
from peft import PeftModel, LoraConfig, get_peft_model, TaskType
from datasets import load_from_disk
import json


class MetricsCallback(TrainerCallback):
    """Trainer callback that echoes selected metrics each time the Trainer logs.

    Only ``loss``, ``eval_loss`` and ``learning_rate`` are reported, and only
    when they are present in the ``logs`` dict; if none is present nothing is
    printed.
    """

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Print a one-line summary of ``logs`` prefixed with the global step."""
        if logs is None:
            return

        step = state.global_step

        # (key in logs, label to print, format spec), in output order.
        reported = (
            ("loss", "loss", ".4f"),
            ("eval_loss", "eval_loss", ".4f"),
            ("learning_rate", "lr", ".2e"),
        )
        parts = [
            f"{label}={format(logs[key], spec)}"
            for key, label, spec in reported
            if key in logs
        ]
        if parts:
            print(f"[Step {step}] {', '.join(parts)}")


# GPU setup
# setup_gpu / print_memory_usage / clear_memory are project helpers from
# config.gpu_utils (presumably found via the ".." entry appended to sys.path
# above — confirm against the repository layout).
from config.gpu_utils import setup_gpu, print_memory_usage, clear_memory
# NOTE(review): `device` is not referenced again in the visible cells; model
# placement is driven by device_map instead.
device = setup_gpu()

# Memory snapshot before anything is loaded.
print_memory_usage()

In [None]:
# Directories
# Inputs: the Stage 6 checkpoint, the Stage 5 model it is loaded on top of,
# and the processed data.
STAGE6_MODEL_DIR = "../models/staged_training/stage6_qlora_full"
STAGE5_MODEL_DIR = "../models/staged_training/stage5_harmonization"  # Base model
DATA_DIR = "../data/processed"
# Outputs: the cooldown checkpoint and the final merged model.
OUTPUT_DIR = "../models/staged_training/stage7_cooldown"
FINAL_MODEL_DIR = "../models/final/korean_medgemma"

# Make sure both output locations exist before anything is written to them.
for _out_dir in (OUTPUT_DIR, FINAL_MODEL_DIR):
    os.makedirs(_out_dir, exist_ok=True)

for _label, _path in (
    ("Input model", STAGE6_MODEL_DIR),
    ("Output dir", OUTPUT_DIR),
    ("Final model dir", FINAL_MODEL_DIR),
):
    print(f"{_label}: {_path}")

---
## 1. Stage Configuration

In [None]:
# Stage 7 configuration (cooldown)
# Hyperparameters for this stage. learning_rate, num_epochs, warmup_ratio,
# batch_size and gradient_accumulation_steps are read when TrainingArguments
# is built; the whole dict is also written to stage_info.json.
STAGE_CONFIG = {
    "name": "stage7_cooldown",
    "description": "Stabilization with LoRA-only training",
    # The three train_* flags are descriptive only: the visible code never
    # reads them. Embedding freezing is done explicitly in a later cell.
    "train_input_embeddings": False,
    "train_output_embeddings": False,
    "train_lora_layers": True,
    "learning_rate": 5e-5,  # Lower LR for cooldown
    "num_epochs": 1,
    "warmup_ratio": 0.1,  # Higher warmup ratio for stability
    "batch_size": 1,
    "gradient_accumulation_steps": 16,  # 1 x 16 = 16 sequences per optimizer step per device
    # Smaller LoRA for cooldown
    # NOTE(review): the lora_* values below are not read anywhere in the
    # visible cells. The adapters are loaded from the Stage 6 checkpoint with
    # PeftModel.from_pretrained, so the rank/alpha/dropout/targets actually in
    # effect are presumably those stored with that checkpoint. Confirm whether
    # these entries should match Stage 6 or be removed.
    "lora_r": 32,
    "lora_alpha": 64,
    "lora_dropout": 0.05,
    "lora_target_modules": [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
}

# Echo the configuration so it is recorded in the notebook output.
print("Stage 7 Configuration (Cooldown):")
for key, value in STAGE_CONFIG.items():
    print(f"  {key}: {value}")

---
## 2. Load Model from Stage 6

In [None]:
# Load token mapping
# token_mapping.json is read from the Stage 6 directory; it provides the
# vocabulary sizes reported below and stored in the stage/model info files.
mapping_path = f"{STAGE6_MODEL_DIR}/token_mapping.json"
with open(mapping_path, "r", encoding="utf-8") as mapping_file:
    token_mapping = json.loads(mapping_file.read())

original_vocab_size, new_vocab_size = (
    token_mapping["original_vocab_size"],
    token_mapping["new_vocab_size"],
)

print(f"Original vocab: {original_vocab_size}")
print(f"New vocab: {new_vocab_size}")
print(f"New tokens: {new_vocab_size - original_vocab_size}")

In [None]:
# 4-bit quantization config
# Passed to from_pretrained when the base model is loaded for training:
# weights are loaded in 4-bit with the "nf4" quant type and double
# quantization enabled, and bfloat16 is used as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

print("BitsAndBytes config created")

In [None]:
# Load base model
# The base weights come from the Stage 5 directory and are quantized on load
# with bnb_config. The Stage 6 adapters are attached to this model afterwards.
print("\nLoading base model...")

base_model = AutoModelForCausalLM.from_pretrained(
    STAGE5_MODEL_DIR,
    quantization_config=bnb_config,
    device_map="auto",  # device placement is left to the library
    # NOTE(review): trust_remote_code allows custom model code shipped with
    # the checkpoint to execute — confirm it is required for this local path.
    trust_remote_code=True,
)

print(f"Base model loaded!")
print_memory_usage()

In [None]:
# Load LoRA adapters from Stage 6
print("\nLoading LoRA adapters from Stage 6...")

# is_trainable=True so the loaded adapter can continue training in this stage
# rather than being loaded for inference only.
model = PeftModel.from_pretrained(
    base_model,
    STAGE6_MODEL_DIR,
    is_trainable=True,
)

# The tokenizer is also taken from the Stage 6 directory.
tokenizer = AutoTokenizer.from_pretrained(STAGE6_MODEL_DIR)

# Fall back to EOS as the padding token when the tokenizer defines none;
# the tokenization cell pads to max_length and therefore needs a pad token.
# NOTE(review): when pad == eos, a collator that masks pad positions in the
# labels would presumably also mask genuine EOS tokens — confirm this is
# acceptable for this tokenizer.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"PEFT model loaded!")
print_memory_usage()

---
## 3. Configure Training (LoRA Only)

In [None]:
# Freeze embeddings for cooldown
# A parameter is frozen when its name contains any of these markers
# (modules_to_save copies, input embeddings, output head).
frozen_name_markers = ("modules_to_save", "embed_tokens", "lm_head")

for param_name, tensor in model.named_parameters():
    if not any(marker in param_name for marker in frozen_name_markers):
        continue
    tensor.requires_grad = False
    print(f"Frozen: {param_name}")

print("\nEmbeddings frozen for cooldown stage")

In [None]:
# Verify trainable parameters
# Collect (element count, requires_grad) once per parameter, then derive both
# totals from that list. trainable_params / total_params are reused later when
# the stage info file is written.
param_stats = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_stats)
trainable_params = sum(count for count, is_trainable in param_stats if is_trainable)

print(f"\nTrainable parameters: {trainable_params:,} / {total_params:,}")
print(f"Percentage: {100 * trainable_params / total_params:.4f}%")
# This line is informational: it is printed unconditionally and does not
# itself check which parameters are trainable.
print("\nOnly LoRA parameters are trainable (no embeddings)")

---
## 4. Load Training Data

In [None]:
# Load language modeling data
lm_data_path = f"{DATA_DIR}/korean_medical_lm"

# Fail fast when the dataset directory is missing: every later cell needs
# `dataset`, so merely printing a message would only postpone the failure to
# a confusing NameError in the tokenization cell.
if not os.path.exists(lm_data_path):
    raise FileNotFoundError(f"Dataset not found at {lm_data_path}")

dataset = load_from_disk(lm_data_path)
print(f"Loaded dataset: {dataset}")

In [None]:
# Tokenize dataset
def tokenize_function(examples):
    """Tokenize the ``"text"`` field of a batch with the notebook's tokenizer.

    Sequences are truncated to 1024 tokens and padded to that same length,
    so every returned example has a fixed length.
    """
    texts = examples["text"]
    encoded = tokenizer(
        texts,
        truncation=True,
        max_length=1024,
        padding="max_length",
    )
    return encoded

print("Tokenizing dataset...")
# batched=True hands tokenize_function a batch of examples per call. The
# original columns are dropped so only the tokenizer outputs remain, and the
# work is spread over 4 processes.
# NOTE: the column names are taken from dataset["train"], so this assumes the
# loaded dataset has a "train" split.
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset["train"].column_names,
    num_proc=4,
)

print(f"Tokenized dataset: {tokenized_dataset}")

In [None]:
# Data collator
# mlm=False selects causal (next-token) language modeling rather than masked
# language modeling when batches are assembled.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

---
## 5. Training

In [None]:
# Training arguments
# Step-based evaluation is only requested when a validation split exists.
# The Trainer is created with eval_dataset=None when there is no
# "validation" split, and asking for evaluation without an eval dataset is an
# error, so evaluation is switched off in that case.
has_validation_split = "validation" in tokenized_dataset

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    # Schedule and batch shape come from STAGE_CONFIG.
    num_train_epochs=STAGE_CONFIG["num_epochs"],
    per_device_train_batch_size=STAGE_CONFIG["batch_size"],
    per_device_eval_batch_size=STAGE_CONFIG["batch_size"],
    gradient_accumulation_steps=STAGE_CONFIG["gradient_accumulation_steps"],
    learning_rate=STAGE_CONFIG["learning_rate"],
    warmup_ratio=STAGE_CONFIG["warmup_ratio"],
    lr_scheduler_type="cosine",
    bf16=True,
    logging_steps=10,
    # Checkpoint once per epoch and keep at most two checkpoints.
    save_strategy="epoch",
    save_total_limit=2,
    optim="paged_adamw_8bit",
    max_grad_norm=1.0,
    report_to="tensorboard",
    gradient_checkpointing=True,
    dataloader_num_workers=4,
    eval_strategy="steps" if has_validation_split else "no",
    eval_steps=500 if has_validation_split else None,
)

print("Training arguments configured")

In [None]:
# Enable gradient checkpointing
model.gradient_checkpointing_enable()

# The embeddings are frozen in this stage, so the hidden states entering the
# first checkpointed block would not require grad. With (re-entrant)
# activation checkpointing that cuts the graph and the adapter weights get no
# gradients. Making the embedding outputs require grad keeps the graph
# connected; this is the usual companion call for PEFT training with
# gradient checkpointing.
if hasattr(model, "enable_input_require_grads"):
    model.enable_input_require_grads()

print("Gradient checkpointing enabled")

In [None]:
# Create trainer with metrics callback
# Trains on the "train" split. The "validation" split is used for evaluation
# when present, otherwise eval_dataset is None. MetricsCallback prints
# loss / eval_loss / lr each time the Trainer logs.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"] if "validation" in tokenized_dataset else None,
    data_collator=data_collator,
    callbacks=[MetricsCallback()],
)

print("Trainer created")

In [None]:
# Train!
# Memory usage is printed before and after the run so the two snapshots can be
# compared in the notebook output.
print("\n" + "=" * 60)
print("Starting Stage 7 Training: Cooldown")
print("=" * 60)
print_memory_usage()

trainer.train()

print("\nTraining complete!")
print_memory_usage()

---
## 6. Save Checkpoint

In [None]:
# Save cooldown checkpoint
# Writes the trained model (for a PEFT model, presumably the adapter weights
# and config rather than the full base model — confirm) and the tokenizer to
# OUTPUT_DIR. The merge cells later reload the adapters from this directory.
# NOTE: token_mapping.json is not written to OUTPUT_DIR by this cell.
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

print(f"\nCooldown checkpoint saved to {OUTPUT_DIR}")

In [None]:
# Save stage info
# Record what this stage did (configuration, parameter counts, vocabulary
# sizes and the checkpoint it started from) next to the cooldown checkpoint.
info_path = f"{OUTPUT_DIR}/stage_info.json"

stage_info = dict(
    stage=7,
    name=STAGE_CONFIG["name"],
    description=STAGE_CONFIG["description"],
    config=STAGE_CONFIG,
    trainable_params=trainable_params,
    total_params=total_params,
    original_vocab_size=original_vocab_size,
    new_vocab_size=new_vocab_size,
    previous_stage=STAGE6_MODEL_DIR,
)

with open(info_path, "w", encoding="utf-8") as info_file:
    json.dump(stage_info, info_file, indent=2)

print(f"Stage info saved to {info_path}")

---
## 7. Merge LoRA and Save Final Model

In [None]:
# Optional: Merge LoRA weights and save full model
print("\nMerging LoRA weights into base model...")

# Clear GPU memory
# NOTE(review): only `trainer` is deleted here; `model` and `base_model` are
# still bound in the notebook namespace, so whatever they hold is presumably
# not released by clear_memory(). Confirm whether they should be deleted too.
# Re-running this cell will also fail on `del trainer` once it is gone.
del trainer
clear_memory()

# Reload base model in bfloat16 (not quantized) for merging
# No quantization_config is passed this time: the Stage 5 weights are loaded
# as bfloat16 tensors on the CPU.
print("Loading base model for merging...")
base_model_merge = AutoModelForCausalLM.from_pretrained(
    STAGE5_MODEL_DIR,
    torch_dtype=torch.bfloat16,
    device_map="cpu",  # Load on CPU for merging
    trust_remote_code=True,
)

print("Base model loaded for merging")

In [None]:
# Load LoRA adapters
# The adapters are read back from OUTPUT_DIR (the checkpoint saved after
# training) and attached to the unquantized CPU copy of the base model.
print("Loading LoRA adapters...")
merged_model = PeftModel.from_pretrained(base_model_merge, OUTPUT_DIR)

# Merge and unload
# merged_model is rebound to the result of merge_and_unload(), i.e. the model
# with the adapter weights folded in.
print("Merging LoRA weights...")
merged_model = merged_model.merge_and_unload()

print("LoRA weights merged successfully!")

In [None]:
# Save merged model
# The merged weights and the tokenizer are written side by side so that
# FINAL_MODEL_DIR can be loaded on its own.
print(f"\nSaving merged model to {FINAL_MODEL_DIR}...")

merged_model.save_pretrained(FINAL_MODEL_DIR)
tokenizer.save_pretrained(FINAL_MODEL_DIR)

print("Merged model saved!")

In [None]:
# Copy token mapping to final model
import shutil

# token_mapping.json is read from the Stage 6 directory at the top of this
# notebook and is never written to OUTPUT_DIR, so copying from OUTPUT_DIR
# alone fails with FileNotFoundError. Prefer a copy in OUTPUT_DIR if one
# exists, otherwise fall back to the Stage 6 file that was actually loaded.
token_mapping_src = f"{OUTPUT_DIR}/token_mapping.json"
if not os.path.exists(token_mapping_src):
    token_mapping_src = f"{STAGE6_MODEL_DIR}/token_mapping.json"

shutil.copy(token_mapping_src, f"{FINAL_MODEL_DIR}/token_mapping.json")

# Save final model info
# Summary metadata stored next to the merged weights. "base_model" is taken
# from the token mapping when present and recorded as "unknown" otherwise.
final_info = {
    "model_name": "korean_medgemma",
    "base_model": token_mapping.get("base_model", "unknown"),
    "original_vocab_size": original_vocab_size,
    "new_vocab_size": new_vocab_size,
    "korean_tokens_added": new_vocab_size - original_vocab_size,
    "training_stages": 7,
    "lora_merged": True,
}

# Explicit UTF-8, consistent with how stage_info.json is written.
with open(f"{FINAL_MODEL_DIR}/model_info.json", "w", encoding="utf-8") as f:
    json.dump(final_info, f, indent=2)

print("Final model info saved")

In [None]:
# Final summary: where the artifacts of this stage were written and which
# notebooks to run next.
print("\n" + "=" * 60)
print("Phase 3 Complete: All Staged Training Done!")
print("=" * 60)
print(f"\nCooldown checkpoint: {OUTPUT_DIR}")
print(f"Final merged model: {FINAL_MODEL_DIR}")
print("\nNext steps:")
print("  1. Run phase4_instruction_tuning/01_instruction_tuning.ipynb")
print("  2. Or proceed to phase5_evaluation for testing")