# Phase 3.5: Stage 5 - Harmonization (New Input + All Output)

EEVE Stage 5: train the embeddings of the newly added input tokens together with all output embeddings; the original input-token embeddings stay frozen.

## Purpose
- Full vocabulary harmonization
- Align new input tokens with the updated output vocabulary

## Contents
(Setup cells for imports, GPU, and directories come first.)
1. Stage Configuration
2. Load Model from Stage 4
3. Configure Parameter Freezing
4. Load Training Data
5. Training
6. Save Checkpoint

In [None]:
# Setup
import sys
import os
sys.path.append("..")

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    TrainerCallback,
)
from datasets import load_from_disk
import json


class MetricsCallback(TrainerCallback):
    """Print a compact one-line summary whenever the Trainer emits logs.

    Only ``loss``, ``eval_loss`` and ``learning_rate`` are reported. A log
    event that carries none of them produces no output.
    """

    # (key in the logs dict, label to print, format spec), in print order.
    _FIELDS = (
        ("loss", "loss", ".4f"),
        ("eval_loss", "eval_loss", ".4f"),
        ("learning_rate", "lr", ".2e"),
    )

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is None:
            return
        parts = [
            f"{label}={format(logs[key], spec)}"
            for key, label, spec in self._FIELDS
            if key in logs
        ]
        if parts:
            print(f"[Step {state.global_step}] {', '.join(parts)}")


# GPU setup: obtain `device` from the project's setup_gpu() helper, then
# print the current memory usage.
from config.gpu_utils import (
    clear_memory,
    print_memory_usage,
    setup_gpu,
)

device = setup_gpu()
print_memory_usage()

In [None]:
# Locations: Stage 4 checkpoint (input), processed data, and Stage 5 output.
STAGE4_MODEL_DIR = "../models/staged_training/stage4_all_output_embeds"
DATA_DIR = "../data/processed"
OUTPUT_DIR = "../models/staged_training/stage5_harmonization"

# Create the output directory up front; a no-op when it already exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)

for label, path in (("Input model", STAGE4_MODEL_DIR), ("Output dir", OUTPUT_DIR)):
    print(f"{label}: {path}")

---
## 1. Stage Configuration

In [None]:
# Stage 5 hyper-parameters and the intended freezing scheme. Only
# "name"/"description" and the optimisation keys are read by later cells;
# the boolean flags document the scheme implemented by the freezing cells.
STAGE_CONFIG = {
    "name": "stage5_harmonization",
    "description": "Train new input embeddings + ALL output embeddings",
    # What is trainable in this stage
    "train_input_embeddings": True,
    "train_output_embeddings": True,
    "train_lora_layers": False,
    # Row-level freezing: old input rows frozen, every output row trained
    "freeze_old_input_embeddings": True,
    "freeze_old_output_embeddings": False,
    # Optimisation
    "learning_rate": 2e-5,
    "num_epochs": 1,
    "warmup_ratio": 0.03,
    "batch_size": 1,
    "gradient_accumulation_steps": 16,
}

print("Stage 5 Configuration:")
print("\n".join(f"  {key}: {value}" for key, value in STAGE_CONFIG.items()))

---
## 2. Load Model from Stage 4

In [None]:
# Read the token mapping stored next to the Stage 4 checkpoint; it records the
# vocabulary size before and after the tokenizer was extended.
mapping_path = f"{STAGE4_MODEL_DIR}/token_mapping.json"
with open(mapping_path, "r", encoding="utf-8") as mapping_file:
    token_mapping = json.load(mapping_file)

original_vocab_size, new_vocab_size = (
    token_mapping["original_vocab_size"],
    token_mapping["new_vocab_size"],
)

for label, value in (
    ("Original vocab", original_vocab_size),
    ("New vocab", new_vocab_size),
    ("New tokens", new_vocab_size - original_vocab_size),
):
    print(f"{label}: {value}")

In [None]:
# Load the Stage 4 checkpoint (model weights and extended tokenizer).
print("\nLoading model from Stage 4...")

model = AutoModelForCausalLM.from_pretrained(
    STAGE4_MODEL_DIR,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(STAGE4_MODEL_DIR)

# If the tokenizer defines no pad token, reuse EOS so sequences can be padded.
# NOTE(review): the LM collator masks pad-token positions out of the labels, so
# with pad == eos genuine EOS tokens are presumably not trained on — confirm intended.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Fail before the expensive training run if the checkpoint was not resized in
# Stage 4: the row-based freezing below assumes one embedding row per token of
# the extended vocabulary.
embedding_rows = model.get_input_embeddings().weight.shape[0]
if embedding_rows < new_vocab_size:
    raise ValueError(
        f"Stage 4 checkpoint has {embedding_rows} input embedding rows, but the "
        f"token mapping expects a vocabulary of {new_vocab_size}"
    )

print("Model loaded!")
print_memory_usage()

---
## 3. Configure Parameter Freezing

In [None]:
# Start from a fully frozen model. The embedding matrices are re-enabled
# selectively in the following cell.
model.requires_grad_(False)

print("Froze all parameters")

In [None]:
# Re-enable gradients on the whole input and output embedding matrices.
# Per-row freezing of the original input tokens is layered on with a gradient hook.
input_embeddings = model.get_input_embeddings()
output_embeddings = model.get_output_embeddings()

for module in (input_embeddings, output_embeddings):
    module.weight.requires_grad = True

for label, module in (("input", input_embeddings), ("output", output_embeddings)):
    print(f"Enabled {label} embeddings training: {module.weight.shape}")

In [None]:
# Create hook to freeze OLD input embeddings only
class FreezeOldInputEmbeddingsHook:
    """Gradient hook that zeroes the gradient rows of the original vocabulary.

    Intended to be registered with ``Tensor.register_hook`` on the input
    embedding weight. The returned gradient has rows ``[:original_vocab_size]``
    set to zero and all later rows (the newly added tokens) unchanged.

    NOTE(review): a zero gradient alone does not stop decoupled weight decay
    from moving the frozen rows; this assumes the optimizer's weight_decay
    stays 0 — verify if the TrainingArguments change.

    Args:
        original_vocab_size: Number of leading rows (original tokens) to freeze.
    """

    def __init__(self, original_vocab_size):
        self.original_vocab_size = original_vocab_size

    def __call__(self, grad):
        # Tensor hooks must not modify their argument in place; mask a copy
        # and return it so autograd uses the masked gradient instead.
        grad = grad.clone()
        grad[: self.original_vocab_size] = 0
        return grad

# Register hook only for INPUT embeddings. The output embeddings get no hook,
# so every output row (original and new tokens) keeps its gradient.
if input_embeddings.weight is output_embeddings.weight:
    # Tied embeddings share a single Parameter: the input hook would also zero
    # the gradient for the original OUTPUT rows, contradicting this stage's
    # "train ALL output embeddings" goal.
    print(
        "WARNING: input and output embeddings are tied; the input hook will also "
        f"freeze the original {original_vocab_size} OUTPUT embedding rows."
    )

freeze_hook_input = FreezeOldInputEmbeddingsHook(original_vocab_size)
hook_handle_input = input_embeddings.weight.register_hook(freeze_hook_input)

print(f"Registered gradient hook for INPUT embeddings (freeze old {original_vocab_size})")
print(f"OUTPUT embeddings: training ALL {new_vocab_size} tokens (no hook)")

In [None]:
# Report which parameters have requires_grad set and their share of the model.
# NOTE: the input embedding matrix is counted in full, although its original
# rows receive a zeroed gradient from the hook registered above.
trainable_params = 0
total_params = 0

for param_name, tensor in model.named_parameters():
    count = tensor.numel()
    total_params += count
    if not tensor.requires_grad:
        continue
    trainable_params += count
    print(f"Trainable: {param_name} - {tensor.shape}")

print(f"\nTrainable parameters: {trainable_params:,} / {total_params:,}")
print(f"Percentage: {100 * trainable_params / total_params:.4f}%")

---
## 4. Load Training Data

In [None]:
# Load the language-modeling dataset from disk.
lm_data_path = f"{DATA_DIR}/korean_medical_lm"

if not os.path.exists(lm_data_path):
    # Fail fast. Every later cell needs `dataset`; just printing here would
    # resurface as a confusing NameError during tokenization.
    raise FileNotFoundError(f"Dataset not found at {lm_data_path}")

dataset = load_from_disk(lm_data_path)
print(f"Loaded dataset: {dataset}")

In [None]:
# Tokenize dataset.
# No padding here. DataCollatorForLanguageModeling pads each batch to its
# longest sequence at collation time, and pad positions are excluded from the
# loss either way. Padding every example to 1024 tokens up front only wastes
# compute, especially with per-device batch size 1.
def tokenize_function(examples):
    """Tokenize a batch of ``text`` examples, truncating each to 1024 tokens."""
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=1024,
    )

print("Tokenizing dataset...")
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset["train"].column_names,
    num_proc=4,
)

print(f"Tokenized dataset: {tokenized_dataset}")

In [None]:
# Collator for causal language modelling: mlm=False disables random token
# masking; batches are padded with the tokenizer at collation time.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

---
## 5. Training

In [None]:
# Training arguments.
# Evaluation is scheduled only when a "validation" split exists. Otherwise the
# trainer cell passes eval_dataset=None, and Trainer rejects a non-"no"
# eval_strategy without an eval dataset.
has_validation_split = "validation" in tokenized_dataset

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=STAGE_CONFIG["num_epochs"],
    per_device_train_batch_size=STAGE_CONFIG["batch_size"],
    per_device_eval_batch_size=STAGE_CONFIG["batch_size"],
    gradient_accumulation_steps=STAGE_CONFIG["gradient_accumulation_steps"],
    learning_rate=STAGE_CONFIG["learning_rate"],
    warmup_ratio=STAGE_CONFIG["warmup_ratio"],
    lr_scheduler_type="cosine",
    bf16=True,
    logging_steps=10,
    save_strategy="epoch",
    save_total_limit=2,
    optim="adamw_torch",
    max_grad_norm=1.0,
    report_to="tensorboard",
    dataloader_num_workers=4,
    eval_strategy="steps" if has_validation_split else "no",
    eval_steps=500,
)

print("Training arguments configured")
if not has_validation_split:
    print("No 'validation' split found: evaluation disabled")

In [None]:
# Build the Trainer. The validation split is optional and passed as None when absent.
validation_split = tokenized_dataset["validation"] if "validation" in tokenized_dataset else None

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=validation_split,
    callbacks=[MetricsCallback()],
)

print("Trainer created")

In [None]:
# Run Stage 5 training, reporting GPU memory before and after.
banner = "=" * 60
print("\n" + banner)
print("Starting Stage 5 Training: Harmonization")
print(banner)
print_memory_usage()

trainer.train()

print("\nTraining complete!")
print_memory_usage()

---
## 6. Save Checkpoint

In [None]:
# Detach the input-embedding gradient hook, then write the model weights and
# the tokenizer files into OUTPUT_DIR.
hook_handle_input.remove()
print("Removed gradient hook")

for save in (trainer.save_model, tokenizer.save_pretrained):
    save(OUTPUT_DIR)

print(f"\nModel saved to {OUTPUT_DIR}")

In [None]:
# Record this stage's configuration, parameter counts, vocab sizes and its
# predecessor checkpoint as JSON next to the saved model.
stage_info = dict(
    stage=5,
    name=STAGE_CONFIG["name"],
    description=STAGE_CONFIG["description"],
    config=STAGE_CONFIG,
    trainable_params=trainable_params,
    total_params=total_params,
    original_vocab_size=original_vocab_size,
    new_vocab_size=new_vocab_size,
    previous_stage=STAGE4_MODEL_DIR,
)

info_path = f"{OUTPUT_DIR}/stage_info.json"
with open(info_path, "w", encoding="utf-8") as info_file:
    json.dump(stage_info, info_file, indent=2)

print(f"Stage info saved to {info_path}")

In [None]:
# Carry the Stage 4 token mapping forward into this stage's output directory.
import shutil

source_mapping = f"{STAGE4_MODEL_DIR}/token_mapping.json"
target_mapping = f"{OUTPUT_DIR}/token_mapping.json"
shutil.copy(source_mapping, target_mapping)
print("Copied token mapping")

In [None]:
# Closing summary and pointer to the next notebook.
separator = "=" * 60
summary_lines = [
    "\n" + separator,
    "Stage 5 Complete: Harmonization Training Done!",
    separator,
    f"\nCheckpoint saved to: {OUTPUT_DIR}",
    "\nEmbedding-only stages (1-5) complete!",
    "\nNext steps:",
    "  Run 06_stage6_qlora_full.ipynb for QLoRA training",
]
for line in summary_lines:
    print(line)