# Phase 2.1: Resize Model Embeddings

Resize MedGemma embedding layers to accommodate new Korean tokens.

## Contents
1. Setup and Load Base Model
2. Load Merged Tokenizer
3. Resize Embeddings
4. Update Model Config
5. Verify Resized Model
6. Save Resized Model

In [None]:
# Setup: standard library, then third-party, then project-local helpers.
import json
import os
import sys

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Make the parent directory importable so the project-local `config` package
# resolves (presumably the project root; the ../models paths below suggest so).
sys.path.append("..")
from config.gpu_utils import setup_gpu, print_memory_usage

device = setup_gpu()

# Input: merged tokenizer from the previous phase. Output: resized checkpoint.
MERGED_TOKENIZER_DIR = "../models/merged_tokenizer"
RESIZED_MODEL_DIR = "../models/resized_model"

os.makedirs(RESIZED_MODEL_DIR, exist_ok=True)
print(f"Output directory: {RESIZED_MODEL_DIR}")

---
## 1. Load Base Model

In [None]:
# Read the metadata saved next to the merged tokenizer. It names the base
# checkpoint and records the vocabulary size before and after merging.
mapping_path = os.path.join(MERGED_TOKENIZER_DIR, "token_mapping.json")

with open(mapping_path, "r", encoding="utf-8") as fh:
    token_mapping = json.load(fh)

BASE_MODEL = token_mapping["base_model"]
original_vocab_size = token_mapping["original_vocab_size"]
new_vocab_size = token_mapping["new_vocab_size"]

for label, value in (
    ("Base model", BASE_MODEL),
    ("Original vocab size", original_vocab_size),
    ("New vocab size", new_vocab_size),
    ("New tokens", new_vocab_size - original_vocab_size),
):
    print(f"{label}: {value}")

In [None]:
# Load the base checkpoint in bfloat16 on the CPU; the embedding manipulation
# in this notebook does not need an accelerator.
print(f"\nLoading base model: {BASE_MODEL}")
print("This may take a few minutes...")
print_memory_usage()

load_kwargs = dict(
    torch_dtype=torch.bfloat16,
    device_map="cpu",  # CPU is enough for embedding manipulation
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, **load_kwargs)

print("\nModel loaded successfully!")
print_memory_usage()

In [None]:
# Print the architecture hyperparameters exposed by the model config.
# Each attribute is read right before it is printed, so a missing attribute
# fails at the same point as a line-by-line print would.
arch_fields = (
    ("Model type", "model_type"),
    ("Hidden size", "hidden_size"),
    ("Num layers", "num_hidden_layers"),
    ("Num heads", "num_attention_heads"),
    ("Vocab size (config)", "vocab_size"),
)
cfg = model.config
print("\nModel architecture:")
for label, attr in arch_fields:
    print(f"  {label}: {getattr(cfg, attr)}")

In [None]:
# Fetch the token-embedding layer and the output projection (lm_head).
# Later cells reuse both handles.
input_embeddings = model.get_input_embeddings()
output_embeddings = model.get_output_embeddings()  # lm_head

in_shape = input_embeddings.weight.shape
print("\nEmbedding layers:")
print(f"  Input embeddings shape: {in_shape}")
print(f"  Output embeddings shape: {output_embeddings.weight.shape}")
print(f"  Embedding dim: {in_shape[1]}")

---
## 2. Load Merged Tokenizer

In [None]:
# Load the merged (base + Korean) tokenizer produced by the previous phase.
tokenizer = AutoTokenizer.from_pretrained(MERGED_TOKENIZER_DIR)

print("Loaded merged tokenizer")
print(f"Tokenizer vocab size: {len(tokenizer)}")

# The tokenizer on disk must agree with the size recorded in token_mapping.json.
# Use an explicit raise rather than `assert` so the check still runs under
# `python -O`.
if len(tokenizer) != new_vocab_size:
    raise ValueError(
        f"Vocab size mismatch: {len(tokenizer)} vs {new_vocab_size}"
    )
print(f"\nVocab size verified: {len(tokenizer)}")

---
## 3. Resize Embeddings

In [None]:
# Snapshot the pre-resize embedding matrices. The verification cell after the
# resize compares against them to confirm existing rows were kept intact.
original_input_embeds = input_embeddings.weight.data.clone()

# With tied input/output embeddings both layers share one storage. A second
# clone would only duplicate a large matrix in memory, so reuse the first.
embeddings_tied = (
    output_embeddings.weight.shape == input_embeddings.weight.shape
    and output_embeddings.weight.data_ptr() == input_embeddings.weight.data_ptr()
)
if embeddings_tied:
    original_output_embeds = original_input_embeds
else:
    original_output_embeds = output_embeddings.weight.data.clone()

print("Saved original embeddings")
print(f"  Input shape: {original_input_embeds.shape}")
print(f"  Output shape: {original_output_embeds.shape}")
print(f"  Tied input/output embeddings: {embeddings_tied}")

In [None]:
# Grow the embedding matrices so every merged-tokenizer ID has a row.
print(f"\nResizing embeddings: {original_vocab_size} -> {new_vocab_size}")

# resize_token_embeddings removes rows from the end when asked for fewer rows
# than it has. A checkpoint may already reserve more rows than the tokenizer
# uses, and shrinking it would silently drop trained rows. Never go below the
# current row count.
current_rows = model.get_input_embeddings().weight.shape[0]
target_rows = max(new_vocab_size, current_rows)
if target_rows != new_vocab_size:
    print(
        f"NOTE: model already has {current_rows} embedding rows "
        f"(> {new_vocab_size}); keeping them instead of shrinking."
    )

model.resize_token_embeddings(target_rows)

# Re-fetch the layers: resizing may replace the module objects.
new_input_embeddings = model.get_input_embeddings()
new_output_embeddings = model.get_output_embeddings()

print("\nResized embeddings:")
print(f"  Input embeddings shape: {new_input_embeddings.weight.shape}")
print(f"  Output embeddings shape: {new_output_embeddings.weight.shape}")

In [None]:
# Check that every pre-existing embedding row survived the resize unchanged.
def _rows_preserved(resized_weight, original_weight, atol=1e-6):
    """Return True if the leading rows of *resized_weight* match *original_weight*.

    Slices by the snapshot's own row count rather than by ``original_vocab_size``.
    The pre-resize matrix can have more rows than the original tokenizer, and
    then a vocab-size slice makes ``torch.allclose`` raise a shape-mismatch
    error instead of reporting a result. A matrix that shrank fails the check.
    """
    n_rows = original_weight.shape[0]
    if resized_weight.shape[0] < n_rows:
        return False
    return torch.allclose(resized_weight[:n_rows], original_weight, atol=atol)


preserved_input = _rows_preserved(new_input_embeddings.weight.data, original_input_embeds)
preserved_output = _rows_preserved(new_output_embeddings.weight.data, original_output_embeds)

print("\nOriginal embeddings preserved:")
print(f"  Input embeddings: {preserved_input}")
print(f"  Output embeddings: {preserved_output}")

# Stop here rather than letting later cells save a model whose trained rows
# were altered.
if not (preserved_input and preserved_output):
    raise RuntimeError(
        "Resizing altered pre-existing embedding rows; refusing to continue."
    )

In [None]:
# Inspect the rows that the newly added token IDs map to. Their values depend
# on how the installed transformers version initializes resized rows.
# Slice exactly [original_vocab_size, new_vocab_size) so no extra padding or
# reserved rows beyond the tokenizer are included.
new_token_input_embeds = new_input_embeddings.weight.data[original_vocab_size:new_vocab_size]
new_token_output_embeds = new_output_embeddings.weight.data[original_vocab_size:new_vocab_size]


def _row_stats(rows):
    """Return a 'Mean: ..., Std: ...' summary, or a note if *rows* is empty."""
    # mean/std of an empty tensor are NaN, which would be misleading here.
    if rows.numel() == 0:
        return "no new rows"
    # Reduce in float32; bfloat16 reductions over many values lose precision.
    rows = rows.float()
    return f"Mean: {rows.mean():.6f}, Std: {rows.std():.6f}"


print("\nNew token embeddings (before initialization):")
print(f"  Input - {_row_stats(new_token_input_embeds)}")
print(f"  Output - {_row_stats(new_token_output_embeds)}")
print("\nNote: These will be properly initialized in the next notebook.")

---
## 4. Update Model Config

In [None]:
# Keep config.vocab_size in sync with the real embedding matrix, so reloading
# with from_pretrained builds layers of the matching size. Recent transformers
# versions already update this inside resize_token_embeddings. Reading the
# actual row count (instead of assuming new_vocab_size) stays correct if the
# matrix has more rows than the tokenizer.
embedding_rows = model.get_input_embeddings().weight.shape[0]
model.config.vocab_size = embedding_rows
if embedding_rows != new_vocab_size:
    print(
        f"NOTE: embedding rows ({embedding_rows}) differ from "
        f"tokenizer size ({new_vocab_size})"
    )

print("Updated model config:")
print(f"  vocab_size: {model.config.vocab_size}")

---
## 5. Verify Resized Model

In [None]:
# Smoke test: run one forward pass on an English sample, which is expected to
# tokenize into the original part of the vocabulary.
print("Testing forward pass with original tokens...")

test_input = tokenizer("Hello, how are you?", return_tensors="pt")
ids = test_input["input_ids"]
print(f"Input IDs: {ids}")
print(f"Max ID: {ids.max().item()}")

with torch.no_grad():
    outputs = model(**test_input)

print(f"\nOutput logits shape: {outputs.logits.shape}")
print("Forward pass successful!")

In [None]:
# Smoke test: run one forward pass on a Korean sample, which should exercise
# the newly added token IDs.
print("\nTesting forward pass with Korean tokens...")

korean_test = "안녕하세요, 의료 질문이 있습니다."
korean_input = tokenizer(korean_test, return_tensors="pt")

# Bounds-check against the rows the model actually has. An ID at or above this
# would make the embedding lookup fail.
embedding_rows = model.get_input_embeddings().weight.shape[0]
max_id = korean_input["input_ids"].max().item()

print(f"Korean text: {korean_test}")
print(f"Input IDs: {korean_input['input_ids']}")
print(f"Max ID: {max_id}")
print(f"Vocab size: {new_vocab_size}")
print(f"Embedding rows: {embedding_rows}")

if max_id >= embedding_rows:
    print(f"WARNING: Token ID {max_id} exceeds embedding rows {embedding_rows}!")
else:
    # If every ID is below original_vocab_size, this sample does not actually
    # exercise the resized rows.
    if max_id < original_vocab_size:
        print("NOTE: no added token IDs appear in this sample; new rows are not exercised.")
    with torch.no_grad():
        outputs = model(**korean_input)
    print(f"\nOutput logits shape: {outputs.logits.shape}")
    print("Forward pass successful!")

---
## 6. Save Resized Model

In [None]:
# Save the resized weights and the merged tokenizer into one directory so they
# can be loaded together.
print(f"\nSaving resized model to {RESIZED_MODEL_DIR}...")

for artifact in (model, tokenizer):
    artifact.save_pretrained(RESIZED_MODEL_DIR)

print("Model and tokenizer saved!")

# List what was written, with sizes in MiB.
print("\nSaved files:")
for name in os.listdir(RESIZED_MODEL_DIR):
    path = os.path.join(RESIZED_MODEL_DIR, name)
    size_mb = os.path.getsize(path) / 1024 ** 2
    print(f"  {name}: {size_mb:.1f} MB")

In [None]:
# Write a small JSON record describing the vocabulary expansion.
resize_info = dict(
    base_model=BASE_MODEL,
    original_vocab_size=original_vocab_size,
    new_vocab_size=new_vocab_size,
    new_tokens_added=new_vocab_size - original_vocab_size,
    embedding_dim=model.config.hidden_size,
    embedding_initialization="random (needs proper initialization)",
)

info_path = os.path.join(RESIZED_MODEL_DIR, "resize_info.json")
with open(info_path, "w", encoding="utf-8") as out:
    json.dump(resize_info, out, indent=2)

print(f"\nResize info saved to {info_path}")

In [None]:
# Copy token_mapping.json next to the resized model so the embedding
# initialization notebook can find it there.
import shutil

src_mapping = os.path.join(MERGED_TOKENIZER_DIR, "token_mapping.json")
dst_mapping = os.path.join(RESIZED_MODEL_DIR, "token_mapping.json")
shutil.copy(src_mapping, dst_mapping)

print(f"Copied token mapping to {dst_mapping}")

In [None]:
# Final summary and pointer to the next notebook in the pipeline.
banner = "=" * 60
summary_lines = [
    "\n" + banner,
    "Model Embedding Resize Complete!",
    banner,
    f"\nResized model saved to: {RESIZED_MODEL_DIR}",
    f"Vocabulary: {original_vocab_size} -> {new_vocab_size} (+{new_vocab_size - original_vocab_size})",
    "\nIMPORTANT: New token embeddings are randomly initialized.",
    "Run 02_initialize_embeddings.ipynb to properly initialize them.",
    "\nNext steps:",
    "  1. Run 02_initialize_embeddings.ipynb for EEVE/WECHSEL initialization",
]
for line in summary_lines:
    print(line)