# Phase 3.4: Stage 4 - Train All Output Embeddings

EEVE Stage 4: Train ALL output embeddings (including original tokens).

## Purpose
- Integrate new Korean tokens with existing vocabulary in output space
- Allow model to adjust generation probabilities for all tokens

## Contents
1. Stage Configuration
2. Load Model from Stage 3
3. Configure Parameter Freezing
4. Load Training Data
5. Training
6. Save Checkpoint

In [None]:
# Setup
import sys
import os
sys.path.append("..")

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    TrainerCallback,
)
from datasets import load_from_disk
import json


class MetricsCallback(TrainerCallback):
    """Trainer callback that echoes selected metrics on every log event.

    For each log record, whichever of ``loss``, ``eval_loss`` and
    ``learning_rate`` are present get formatted and printed on one line,
    prefixed with the current global step.
    """

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Print the known metrics found in ``logs``; print nothing if none are present."""
        if logs is None:
            return

        # (key looked up in ``logs``, label shown, format spec), in print order.
        known_metrics = (
            ("loss", "loss", ".4f"),
            ("eval_loss", "eval_loss", ".4f"),
            ("learning_rate", "lr", ".2e"),
        )
        parts = [
            f"{label}={logs[key]:{spec}}"
            for key, label, spec in known_metrics
            if key in logs
        ]
        if parts:
            print(f"[Step {state.global_step}] {', '.join(parts)}")


# GPU setup
# NOTE: `config` is a project-local package; this import relies on the
# `sys.path.append("..")` executed earlier in this cell.
from config.gpu_utils import setup_gpu, print_memory_usage, clear_memory
# Run the project's GPU initialisation helper and keep its return value
# (presumably the selected torch device - confirm in config/gpu_utils.py).
device = setup_gpu()

# Report memory usage before any model is loaded.
print_memory_usage()

In [None]:
# Filesystem layout for this stage. The paths are relative, so they are
# resolved against the current working directory.
DATA_DIR = "../data/processed"
STAGE3_MODEL_DIR = "../models/staged_training/stage3_both_new_embeds"
OUTPUT_DIR = "../models/staged_training/stage4_all_output_embeds"

# Make sure the checkpoint directory exists before anything is written to it.
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"Input model: {STAGE3_MODEL_DIR}\nOutput dir: {OUTPUT_DIR}")

---
## 1. Stage Configuration

In [None]:
# Hyper-parameters and freezing policy for Stage 4.
STAGE_CONFIG = dict(
    name="stage4_all_output_embeds",
    description="Train ALL output embeddings (including original)",
    train_input_embeddings=False,
    train_output_embeddings=True,
    train_lora_layers=False,
    # The defining difference of this stage: rows of the original vocabulary
    # are trained together with the new ones.
    freeze_old_embeddings=False,
    # Kept low because the whole output vocabulary is being updated.
    learning_rate=2e-5,
    num_epochs=1,
    warmup_ratio=0.03,
    batch_size=1,
    gradient_accumulation_steps=16,
)

# Echo the configuration so that it is recorded in the notebook output.
print("Stage 4 Configuration:")
print("\n".join(f"  {option}: {setting}" for option, setting in STAGE_CONFIG.items()))

---
## 2. Load Model from Stage 3

In [None]:
# Read the vocabulary-size metadata stored alongside the Stage 3 checkpoint.
mapping_path = f"{STAGE3_MODEL_DIR}/token_mapping.json"
with open(mapping_path, "r", encoding="utf-8") as mapping_file:
    token_mapping = json.load(mapping_file)

original_vocab_size = token_mapping["original_vocab_size"]
new_vocab_size = token_mapping["new_vocab_size"]
added_token_count = new_vocab_size - original_vocab_size

print(
    f"Original vocab: {original_vocab_size}\n"
    f"New vocab: {new_vocab_size}\n"
    f"New tokens: {added_token_count}"
)

In [None]:
# Restore the Stage 3 checkpoint (weights and tokenizer) as the starting point.
print("\nLoading model from Stage 3...")

model_load_options = {
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
    "trust_remote_code": True,
}
model = AutoModelForCausalLM.from_pretrained(STAGE3_MODEL_DIR, **model_load_options)
tokenizer = AutoTokenizer.from_pretrained(STAGE3_MODEL_DIR)

# Fall back to EOS as the padding token when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Model loaded!")
print_memory_usage()

---
## 3. Configure Parameter Freezing

In [None]:
# Start from a fully frozen model; later cells re-enable only what this stage trains.
model.requires_grad_(False)

print("Froze all parameters")

In [None]:
# Enable ALL output embeddings (lm_head) - no freezing of old tokens
output_embeddings = model.get_output_embeddings()
if output_embeddings is None:
    # Without an output projection there is nothing for this stage to train.
    raise RuntimeError(
        "Model does not expose output embeddings; Stage 4 has nothing to train."
    )
output_embeddings.weight.requires_grad = True

# If the checkpoint ties lm_head to the input embedding matrix, both modules
# share one parameter. Enabling gradients here would then also update the
# input embeddings, although this stage sets train_input_embeddings=False.
input_embeddings = model.get_input_embeddings()
if input_embeddings is not None and input_embeddings.weight is output_embeddings.weight:
    print(
        "WARNING: input and output embeddings are tied; "
        "training the output embeddings will also modify the input embeddings."
    )

print("Enabled ALL output embeddings training")
print(f"Output embedding shape: {output_embeddings.weight.shape}")
print(f"\nNote: No gradient hook - training ALL {new_vocab_size} output embeddings")

In [None]:
# Sanity check: list every parameter that will receive gradients and report
# what share of the model it represents.
named_params = list(model.named_parameters())
trainable_named = [(n, p) for n, p in named_params if p.requires_grad]

for param_name, tensor in trainable_named:
    print(f"Trainable: {param_name} - {tensor.shape}")

total_params = sum(p.numel() for _, p in named_params)
trainable_params = sum(p.numel() for _, p in trainable_named)

print(f"\nTrainable parameters: {trainable_params:,} / {total_params:,}")
print(f"Percentage: {100 * trainable_params / total_params:.4f}%")

---
## 4. Load Training Data

In [None]:
# Load language modeling data
lm_data_path = f"{DATA_DIR}/korean_medical_lm"

# Fail fast: the cells below need `dataset`. Printing a message and continuing
# would surface later as a NameError, or silently reuse a stale `dataset` left
# in the kernel by an earlier run.
if not os.path.exists(lm_data_path):
    raise FileNotFoundError(
        f"Dataset not found at {lm_data_path}; "
        "create the processed dataset before running Stage 4."
    )

dataset = load_from_disk(lm_data_path)
print(f"Loaded dataset: {dataset}")

In [None]:
# Tokenize dataset
def tokenize_function(examples):
    """Tokenize a batch of raw examples for causal-LM training.

    Args:
        examples: Batch mapping supplied by ``dataset.map(batched=True)``;
            only its ``"text"`` column is read.

    Returns:
        The tokenizer output for the batch, truncated to at most 1024 tokens
        per example and left unpadded.
    """
    # No padding here: the data collator pads each batch to its longest
    # sequence, so short texts are not stored and processed as 1024-token rows.
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=1024,
    )

print("Tokenizing dataset...")
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    # Drop the raw columns so that only tokenizer outputs reach the collator.
    remove_columns=dataset["train"].column_names,
    num_proc=4,
)

print(f"Tokenized dataset: {tokenized_dataset}")

In [None]:
# Batch builder for causal language modeling (mlm=False disables masked-LM masking).
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

---
## 5. Training

In [None]:
# Training arguments
# Periodic evaluation is only possible when a validation split exists. Asking
# for step-wise evaluation without an eval dataset makes Trainer fail, so the
# strategy falls back to "no" in that case.
has_validation_split = "validation" in tokenized_dataset
if not has_validation_split:
    print("No validation split found - periodic evaluation disabled")

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=STAGE_CONFIG["num_epochs"],
    per_device_train_batch_size=STAGE_CONFIG["batch_size"],
    per_device_eval_batch_size=STAGE_CONFIG["batch_size"],
    gradient_accumulation_steps=STAGE_CONFIG["gradient_accumulation_steps"],
    learning_rate=STAGE_CONFIG["learning_rate"],
    warmup_ratio=STAGE_CONFIG["warmup_ratio"],
    lr_scheduler_type="cosine",
    bf16=True,
    logging_steps=10,
    save_strategy="epoch",
    save_total_limit=2,
    optim="adamw_torch",
    max_grad_norm=1.0,
    report_to="tensorboard",
    dataloader_num_workers=4,
    eval_strategy="steps" if has_validation_split else "no",
    eval_steps=500,
)

print("Training arguments configured")

In [None]:
# Build the Trainer. The validation split is optional and is passed as None
# when the tokenized dataset does not contain one.
if "validation" in tokenized_dataset:
    eval_split = tokenized_dataset["validation"]
else:
    eval_split = None

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=eval_split,
    data_collator=data_collator,
    # Print loss / eval_loss / learning rate on every logging event.
    callbacks=[MetricsCallback()],
)

print("Trainer created")

In [None]:
# Run Stage 4 training, reporting memory usage before and after.
banner_rule = "=" * 60
print("\n" + banner_rule)
print("Starting Stage 4 Training: All Output Embeddings")
print(banner_rule)
print_memory_usage()

trainer.train()

print("\nTraining complete!")
print_memory_usage()

---
## 6. Save Checkpoint

In [None]:
# Persist the trained weights and the tokenizer side by side in the stage
# output directory.
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

print(f"\nModel saved to {OUTPUT_DIR}")

In [None]:
# Record what this stage did (configuration, parameter counts, vocabulary
# sizes, predecessor checkpoint) next to the saved model.
info_path = f"{OUTPUT_DIR}/stage_info.json"

stage_info = dict(
    stage=4,
    name=STAGE_CONFIG["name"],
    description=STAGE_CONFIG["description"],
    config=STAGE_CONFIG,
    trainable_params=trainable_params,
    total_params=total_params,
    original_vocab_size=original_vocab_size,
    new_vocab_size=new_vocab_size,
    previous_stage=STAGE3_MODEL_DIR,
)

with open(info_path, "w", encoding="utf-8") as info_file:
    json.dump(stage_info, info_file, indent=2)

print(f"Stage info saved to {info_path}")

In [None]:
# Carry the token mapping forward so the Stage 4 checkpoint directory also
# contains it.
import shutil

mapping_source = f"{STAGE3_MODEL_DIR}/token_mapping.json"
mapping_destination = f"{OUTPUT_DIR}/token_mapping.json"
shutil.copy(mapping_source, mapping_destination)
print("Copied token mapping")

In [None]:
# Closing banner: where the checkpoint was written and which notebook comes next.
summary_rule = "=" * 60
summary_lines = [
    "",
    summary_rule,
    "Stage 4 Complete: All Output Embeddings Trained!",
    summary_rule,
    "",
    f"Checkpoint saved to: {OUTPUT_DIR}",
    "",
    "Next steps:",
    "  Run 05_stage5_harmonization.ipynb",
]
print("\n".join(summary_lines))