# Phase 2.2: Initialize New Token Embeddings

Initialize the new Korean token embeddings using a hybrid method:
- EEVE: Subword decomposition (primary)
- WECHSEL: Bilingual dictionary alignment (for medical terms)

## Contents
1. Load Resized Model and Tokenizers (original tokenizer is used for subword decomposition)
2. Load Bilingual Dictionary
3. Get Embedding Layers
4. Implement Initialization Methods
5. Initialize All New Embeddings
6. Verify Initialization
7. Save Initialized Model

In [None]:
# Setup
import sys
import os
sys.path.append("..")

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
from tqdm import tqdm

# GPU setup
from config.gpu_utils import setup_gpu, print_memory_usage, clear_memory

# Directories
# All paths are relative, so they resolve against the notebook's working directory.
# RESIZED_MODEL_DIR is read from (model, tokenizer, token_mapping.json),
# INITIALIZED_MODEL_DIR is written to, and BILINGUAL_DICT_DIR is where the
# optional bilingual_medical_dict.json is looked up.
RESIZED_MODEL_DIR = "../models/resized_model"
INITIALIZED_MODEL_DIR = "../models/initialized_model"
BILINGUAL_DICT_DIR = "../data/bilingual_dict"

# Create the output directory up front; exist_ok=True makes re-running this cell safe.
os.makedirs(INITIALIZED_MODEL_DIR, exist_ok=True)

print(f"Output directory: {INITIALIZED_MODEL_DIR}")

---
## 1. Load Resized Model and Tokenizers

In [None]:
# Read token_mapping.json from the resized-model directory. It supplies the
# base model name, both vocabulary sizes and the list of added tokens that
# every later cell relies on.
mapping_path = f"{RESIZED_MODEL_DIR}/token_mapping.json"
with open(mapping_path, "r", encoding="utf-8") as mapping_file:
    token_mapping = json.load(mapping_file)

BASE_MODEL = token_mapping["base_model"]
original_vocab_size = token_mapping["original_vocab_size"]
new_vocab_size = token_mapping["new_vocab_size"]
new_tokens = token_mapping["new_tokens"]

# Echo what was loaded so a wrong mapping file is noticed immediately.
mapping_summary = (
    ("Base model", BASE_MODEL),
    ("Original vocab", original_vocab_size),
    ("New vocab", new_vocab_size),
    ("New tokens", len(new_tokens)),
)
for label, value in mapping_summary:
    print(f"{label}: {value}")

In [None]:
# Load the resized checkpoint (vocabulary already extended) for in-place
# embedding edits.
print("\nLoading resized model...")

model_load_options = {
    "torch_dtype": torch.bfloat16,
    "device_map": "cpu",  # CPU for embedding manipulation
    # NOTE: trust_remote_code allows model code shipped with the checkpoint to run.
    "trust_remote_code": True,
}
model = AutoModelForCausalLM.from_pretrained(RESIZED_MODEL_DIR, **model_load_options)

print("Model loaded!")
print(f"Vocab size: {model.config.vocab_size}")

In [None]:
# Load new tokenizer (with Korean tokens)
# It is read from the same directory as the resized model, so its ids are the
# ones used to index the resized embedding matrices.
new_tokenizer = AutoTokenizer.from_pretrained(RESIZED_MODEL_DIR)
print(f"New tokenizer vocab size: {len(new_tokenizer)}")

# Load original tokenizer (for subword decomposition)
# BASE_MODEL comes from token_mapping.json; new tokens are re-tokenized with
# this tokenizer to find the original subwords they decompose into.
# NOTE(review): presumably BASE_MODEL is a hub id or local path - confirm it is reachable.
old_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
print(f"Original tokenizer vocab size: {len(old_tokenizer)}")

---
## 2. Load Bilingual Dictionary

In [None]:
# Load the optional bilingual medical dictionary. When the file is missing
# both mappings stay empty and initialization falls back to subword
# decomposition only.
dict_path = f"{BILINGUAL_DICT_DIR}/bilingual_medical_dict.json"

if not os.path.exists(dict_path):
    print(f"Bilingual dictionary not found at {dict_path}")
    print("Will use subword decomposition only.")
    bilingual_dict = {}
    ko_to_en = {}
else:
    with open(dict_path, "r", encoding="utf-8") as dict_file:
        bilingual_dict = json.load(dict_file)
    print(f"Loaded bilingual dictionary: {len(bilingual_dict)} entries")

    # Invert the mapping so it can be queried by Korean term. If several keys
    # share one value, the entry that appears last in the file wins.
    ko_to_en = {value: key for key, value in bilingual_dict.items()}
    print(f"Korean to English mapping: {len(ko_to_en)} entries")

---
## 3. Get Embedding Layers

In [None]:
# Grab the raw weight tensors of the input embedding and the output
# (LM head) layer. Later cells write rows into these tensors in place and
# then save `model`, so they must stay views of the model's own weights.
input_embeds = model.get_input_embeddings().weight.data
output_embeds = model.get_output_embeddings().weight.data
embedding_dim = input_embeds.shape[1]

print(f"Input embeddings shape: {input_embeds.shape}")
print(f"Output embeddings shape: {output_embeds.shape}")
print(f"Embedding dimension: {embedding_dim}")

In [None]:
# Baseline: statistics of the rows that belong to the newly added tokens
# (everything from original_vocab_size onwards) before they are overwritten.
new_input_embeds = input_embeds[original_vocab_size:]
new_output_embeds = output_embeds[original_vocab_size:]

print("\nNew embeddings before initialization:")
for label, rows in (("Input", new_input_embeds), ("Output", new_output_embeds)):
    print(f"  {label} - Mean: {rows.mean():.6f}, Std: {rows.std():.6f}")

---
## 4. Implement Initialization Methods

In [None]:
class HybridEmbeddingInitializer:
    """
    Initialize the embeddings of newly added tokens with a hybrid strategy.

    For every new token the following methods are tried in order:

    1. WECHSEL-style bilingual alignment: find an English equivalent in the
       Korean -> English dictionary and copy from its subword embeddings.
    2. EEVE-style subword decomposition: tokenize the token with the original
       tokenizer and copy from those subword embeddings.
    3. Fallback: the mean of all original embeddings.

    For methods 1 and 2 the input embedding becomes the average over all
    source subwords, while the output embedding copies the first subword only.

    The embedding tensors passed in are modified in place; nothing is copied.
    """

    def __init__(
        self,
        input_embeds,
        output_embeds,
        new_tokenizer,
        old_tokenizer,
        original_vocab_size,
        bilingual_dict=None,
    ):
        """
        Args:
            input_embeds: Input embedding weights, indexed by token id along
                dim 0. Written to in place.
            output_embeds: Output (LM head) weights, indexed the same way.
                Written to in place.
            new_tokenizer: Tokenizer that knows the added tokens; used to map
                a token string to its row index.
            old_tokenizer: Original tokenizer; used to decompose strings into
                original subword ids.
            original_vocab_size: Number of rows that belong to the original
                vocabulary. Rows at or above this index are the new tokens.
            bilingual_dict: Optional mapping that is inverted here into a
                Korean -> English lookup (so its values must be hashable).
                If several keys share a value, the last one wins.
        """
        self.input_embeds = input_embeds
        self.output_embeds = output_embeds
        self.new_tokenizer = new_tokenizer
        self.old_tokenizer = old_tokenizer
        self.original_vocab_size = original_vocab_size

        # Create Korean -> English mapping
        self.ko_to_en = {}
        if bilingual_dict:
            self.ko_to_en = {v: k for k, v in bilingual_dict.items()}

        # Per-method counters; "total" counts every token that was initialized.
        self.stats = {
            "subword_average": 0,
            "bilingual_aligned": 0,
            "mean_fallback": 0,
            "total": 0,
        }

        # Lazily computed (input_mean, output_mean) over the original rows.
        self._fallback_means = None

    def get_subword_ids(self, token):
        """
        Tokenize ``token`` with the original tokenizer.

        Returns:
            List of subword ids, restricted to ids below
            ``original_vocab_size`` so they always address original rows.
        """
        ids = self.old_tokenizer.encode(token, add_special_tokens=False)
        return [i for i in ids if i < self.original_vocab_size]

    def find_english_equivalent(self, korean_token):
        """
        Look up an English equivalent for ``korean_token``.

        An exact dictionary hit is preferred. Otherwise the first dictionary
        entry (in insertion order) whose Korean term contains the token, or is
        contained in it, is used.

        Returns:
            The English term, or ``None`` if nothing matches or the token is
            empty.
        """
        # An empty string is a substring of every string, so it would
        # "match" the first dictionary entry regardless of meaning.
        if not korean_token:
            return None

        # Direct lookup
        if korean_token in self.ko_to_en:
            return self.ko_to_en[korean_token]

        # Partial match in either direction.
        # NOTE(review): this is a loose heuristic - a very short token matches
        # any term that merely contains it. Confirm this is intended.
        for ko, en in self.ko_to_en.items():
            if not ko:
                # Same reasoning as above: an empty key matches everything.
                continue
            if korean_token in ko or ko in korean_token:
                return en

        return None

    def initialize_with_subword_average(self, token, token_idx):
        """
        EEVE method: initialize row ``token_idx`` from the token's subwords.

        - Input embedding: average of all subword input embeddings.
        - Output embedding: output embedding of the first subword only.

        Returns:
            ``True`` if the row was written, ``False`` if the token has no
            usable subword ids.
        """
        subword_ids = self.get_subword_ids(token)

        if not subword_ids:
            return False

        # Input: Average of all subwords
        self.input_embeds[token_idx] = self.input_embeds[subword_ids].mean(dim=0)

        # Output: First subword only (EEVE finding)
        # NOTE(review): if input_embeds and output_embeds share storage (tied
        # weights), this write replaces the average written just above.
        self.output_embeds[token_idx] = self.output_embeds[subword_ids[0]]

        self.stats["subword_average"] += 1
        return True

    def initialize_with_bilingual_alignment(self, token, token_idx):
        """
        WECHSEL method: initialize row ``token_idx`` from the subwords of the
        token's English dictionary equivalent.

        Returns:
            ``True`` if the row was written, ``False`` if there is no English
            equivalent or it yields no usable subword ids.
        """
        english_equiv = self.find_english_equivalent(token)

        if english_equiv is None:
            return False

        # Get English token embeddings
        en_ids = self.get_subword_ids(english_equiv)

        if not en_ids:
            return False

        # Same input/output split as the subword method: average for the
        # input row, first subword for the output row.
        self.input_embeds[token_idx] = self.input_embeds[en_ids].mean(dim=0)
        self.output_embeds[token_idx] = self.output_embeds[en_ids[0]]

        self.stats["bilingual_aligned"] += 1
        return True

    def _original_means(self):
        """Return (input_mean, output_mean) over the original rows, computed once."""
        if self._fallback_means is None:
            n_original = self.original_vocab_size
            self._fallback_means = (
                self.input_embeds[:n_original].mean(dim=0),
                self.output_embeds[:n_original].mean(dim=0),
            )
        return self._fallback_means

    def initialize_with_mean_fallback(self, token_idx):
        """
        Fallback: set row ``token_idx`` to the mean of all original embeddings.

        The means are cached after the first call because this class never
        writes to original rows through ``initialize_token``, so they do not
        change while new tokens are being initialized.

        Returns:
            Always ``True``.
        """
        mean_input, mean_output = self._original_means()
        self.input_embeds[token_idx] = mean_input
        self.output_embeds[token_idx] = mean_output

        self.stats["mean_fallback"] += 1
        return True

    def initialize_token(self, token, prefer_bilingual=True):
        """
        Initialize a single token embedding using the hybrid method.

        Args:
            token: Token string as known to the new tokenizer.
            prefer_bilingual: Try dictionary alignment before subword
                decomposition (only effective if a dictionary was given).

        Returns:
            One of ``"unk"``, ``"original"``, ``"bilingual_aligned"``,
            ``"subword_average"`` or ``"mean_fallback"``. For ``"unk"`` and
            ``"original"`` nothing is written and no counter changes.
        """
        # Get token ID
        token_idx = self.new_tokenizer.convert_tokens_to_ids(token)

        # Unknown tokens must be detected first: the UNK id normally lies
        # inside the original vocabulary, and a tokenizer without an UNK
        # token reports unknown tokens as None (not comparable to an int).
        if token_idx is None or token_idx == self.new_tokenizer.unk_token_id:
            return "unk"

        # Skip if it's an original token
        if token_idx < self.original_vocab_size:
            return "original"

        self.stats["total"] += 1

        # Try bilingual alignment first (for medical terms)
        if prefer_bilingual and self.ko_to_en:
            if self.initialize_with_bilingual_alignment(token, token_idx):
                return "bilingual_aligned"

        # Try subword decomposition
        if self.initialize_with_subword_average(token, token_idx):
            return "subword_average"

        # Fallback to mean
        self.initialize_with_mean_fallback(token_idx)
        return "mean_fallback"

    def initialize_all(self, tokens, prefer_bilingual=True):
        """
        Initialize every token in ``tokens``, showing a progress bar.

        Returns:
            List of ``(token, method)`` pairs in input order, where ``method``
            is the value returned by ``initialize_token``.
        """
        return [
            (token, self.initialize_token(token, prefer_bilingual))
            for token in tqdm(tokens, desc="Initializing embeddings")
        ]

    def get_stats(self):
        """Return the live statistics dictionary (not a copy)."""
        return self.stats

---
## 5. Initialize All New Embeddings

In [None]:
# Build the initializer around the model's own weight tensors, so every
# write it performs lands directly in the model.
initializer = HybridEmbeddingInitializer(
    input_embeds=input_embeds,
    output_embeds=output_embeds,
    new_tokenizer=new_tokenizer,
    old_tokenizer=old_tokenizer,
    original_vocab_size=original_vocab_size,
    # An empty dictionary is passed as None (no bilingual alignment).
    bilingual_dict=bilingual_dict or None,
)

print("Initializer created")
print(f"Bilingual dictionary entries: {len(initializer.ko_to_en)}")

In [None]:
# Run the hybrid initialization over every added token, then report how
# many tokens each method handled.
print(f"\nInitializing {len(new_tokens)} new token embeddings...")

results = initializer.initialize_all(new_tokens, prefer_bilingual=True)

stats = initializer.get_stats()
print("\nInitialization Statistics:")
for method, count in stats.items():
    if count <= 0:
        continue  # unused methods are left out of the report
    # Share of all initialized tokens; guarded against an empty run.
    pct = count / stats["total"] * 100 if stats["total"] > 0 else 0
    print(f"  {method}: {count} ({pct:.1f}%)")

In [None]:
# Print up to five example tokens for each of the two main methods:
# dictionary-aligned tokens first, then subword-averaged ones.
print("\nSample initialization results:")
print(f"{'Token':<20} | {'Method':<20}")
print("-" * 45)

bilingual_samples = [pair for pair in results if pair[1] == "bilingual_aligned"][:5]
subword_samples = [pair for pair in results if pair[1] == "subword_average"][:5]

for token, method in [*bilingual_samples, *subword_samples]:
    print(f"{token:<20} | {method:<20}")

---
## 6. Verify Initialization

In [None]:
# Compare the freshly initialized rows with the original vocabulary rows;
# the mean and std of the new rows can be compared against the original ones.
new_input_embeds_after = input_embeds[original_vocab_size:]
new_output_embeds_after = output_embeds[original_vocab_size:]
orig_input_embeds = input_embeds[:original_vocab_size]

print("Embedding statistics after initialization:")
print("\nOriginal embeddings:")
print(f"  Input - Mean: {orig_input_embeds.mean():.6f}, Std: {orig_input_embeds.std():.6f}")

print("\nNew embeddings (after initialization):")
for label, rows in (("Input", new_input_embeds_after), ("Output", new_output_embeds_after)):
    print(f"  {label} - Mean: {rows.mean():.6f}, Std: {rows.std():.6f}")

In [None]:
# Spot-check a few medical terms: if a term is one of the newly added tokens,
# print the L2 norm of its input embedding.
medical_test_terms = ["의사", "환자", "병원", "치료", "진단"]

print("\nMedical term embedding check:")
for term in medical_test_terms:
    token_id = new_tokenizer.convert_tokens_to_ids(term)

    # convert_tokens_to_ids can return None for a token the tokenizer does not
    # know when no UNK token is defined; None cannot be compared with an int,
    # so treat it like any other term that is not a new token.
    if token_id is not None and token_id >= original_vocab_size:
        embed = input_embeds[token_id]
        norm = torch.norm(embed).item()
        print(f"  {term} (id={token_id}): norm={norm:.4f}")
    else:
        print(f"  {term}: not in new tokens (may be subword)")

In [None]:
# Test model forward pass with initialized embeddings
# This is a smoke test only: it checks that the model runs end to end on a
# Korean sentence and prints the logits' shape; output quality is not judged.
print("\nTesting forward pass...")

test_text = "환자가 발열 증상을 호소합니다."
# Tokenize with the new tokenizer so the added Korean token ids are exercised.
inputs = new_tokenizer(test_text, return_tensors="pt")

# No gradients are needed for a pure inference check.
with torch.no_grad():
    outputs = model(**inputs)

print(f"Input: {test_text}")
print(f"Token IDs: {inputs['input_ids'].tolist()}")
print(f"Output logits shape: {outputs.logits.shape}")
print(f"Forward pass successful!")

---
## 7. Save Initialized Model

In [None]:
# Persist the model (with its edited embedding rows) and the matching
# tokenizer side by side in the output directory.
print(f"\nSaving initialized model to {INITIALIZED_MODEL_DIR}...")

for artifact in (model, new_tokenizer):
    artifact.save_pretrained(INITIALIZED_MODEL_DIR)

print("Model and tokenizer saved!")

# Report every entry in the output directory with its size in MB.
print("\nSaved files:")
for filename in os.listdir(INITIALIZED_MODEL_DIR):
    size_bytes = os.path.getsize(os.path.join(INITIALIZED_MODEL_DIR, filename))
    print(f"  {filename}: {size_bytes / (1024**2):.1f} MB")

In [None]:
# Write a small JSON record of this run (sizes, method, per-method counts)
# next to the saved model.
info_path = f"{INITIALIZED_MODEL_DIR}/initialization_info.json"

init_info = dict(
    base_model=BASE_MODEL,
    original_vocab_size=original_vocab_size,
    new_vocab_size=new_vocab_size,
    new_tokens_count=len(new_tokens),
    initialization_method="hybrid (EEVE + WECHSEL)",
    statistics=stats,
    # An empty or missing dictionary is recorded as size 0.
    bilingual_dict_size=len(bilingual_dict or ()),
)

with open(info_path, "w", encoding="utf-8") as info_file:
    json.dump(init_info, info_file, indent=2)

print(f"\nInitialization info saved to {info_path}")

In [None]:
# Copy token mapping
# token_mapping.json is copied unchanged from the resized-model directory into
# the initialized-model directory, so the output directory also holds the
# mapping.
import shutil

src_mapping = f"{RESIZED_MODEL_DIR}/token_mapping.json"
dst_mapping = f"{INITIALIZED_MODEL_DIR}/token_mapping.json"
# shutil.copy overwrites an existing destination file, so re-running is safe.
shutil.copy(src_mapping, dst_mapping)

print(f"Copied token mapping")

In [None]:
# Closing banner: where the model was saved, which methods were used, and
# which notebook to run next.
banner = "=" * 60
print("\n" + banner)
print("Embedding Initialization Complete!")
print(banner)

print(f"\nInitialized model saved to: {INITIALIZED_MODEL_DIR}")

print("\nInitialization breakdown:")
for method, count in stats.items():
    if count <= 0:
        continue  # only methods that were actually used are listed
    print(f"  {method}: {count}")

closing_lines = (
    "\nPhase 2 Complete!",
    "\nNext steps:",
    "  1. Move to Phase 3: Staged Training",
    "  2. Run phase3_staged_training/01_stage1_new_input_embeds.ipynb",
)
for line in closing_lines:
    print(line)