# Phase 1.3: Merge Korean Tokens with MedGemma Tokenizer

Merge filtered Korean tokens into the MedGemma tokenizer.

## Contents
1. Load MedGemma/Gemma Tokenizer
2. Load Filtered Korean Tokens
3. Filter Duplicates
4. Add New Tokens to Tokenizer
5. Verify Merged Tokenizer
6. Save Merged Tokenizer

In [None]:
# Setup
import sys
import os
sys.path.append("..")

from transformers import AutoTokenizer
import json
from tqdm import tqdm

# Input/output locations, relative to the notebook's working directory:
#   MODEL_DIR  - where the filtered Korean token list is read from
#   MERGED_DIR - where the merged tokenizer and its token mapping are written
MODEL_DIR = "../models/tokenizer"
MERGED_DIR = "../models/merged_tokenizer"

# Create the output directory up front; a pre-existing directory is fine.
os.makedirs(MERGED_DIR, exist_ok=True)

print("Merged tokenizer directory:", MERGED_DIR)

---
## 1. Load MedGemma/Gemma Tokenizer

In [None]:
# Load the base tokenizer that the Korean tokens will be merged into.
# Note: Use actual MedGemma when available, or Gemma base for testing
BASE_MODEL = "google/gemma-2b"  # Change to "google/medgemma-4b-it" when available

print(f"Loading base tokenizer: {BASE_MODEL}")

try:
    base_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
except Exception as e:
    print(f"Error loading tokenizer: {e}")
    print("Make sure you have access to the model (may require login)")
    # Re-raise after printing the hint: every later cell needs
    # `base_tokenizer`, so continuing would only fail later with an
    # unrelated NameError that hides the real cause.
    raise
else:
    # Success path kept out of the `try` so only the load itself is guarded.
    print("Loaded successfully!")
    print(f"Original vocabulary size: {len(base_tokenizer)}")

In [None]:
# Report the base tokenizer's basic configuration before it is modified.
print("\nBase tokenizer info:")
for label, value in (
    ("Vocab size", len(base_tokenizer)),
    ("Model max length", base_tokenizer.model_max_length),
    ("Padding side", base_tokenizer.padding_side),
):
    print(f"  {label}: {value}")

# One line per entry of the tokenizer's special-token map.
print("\nSpecial tokens:")
for role, special in base_tokenizer.special_tokens_map.items():
    print(f"  {role}: {special}")

In [None]:
# Baseline: tokenize one Korean sentence with the unmodified base tokenizer
# so its token count can be compared with the merged tokenizer later on.
test_korean = "안녕하세요, 저는 의료 AI 어시스턴트입니다."

base_ids = base_tokenizer.encode(test_korean)
base_tokens = base_tokenizer.tokenize(test_korean)

print(f"Korean text: {test_korean}")
print(f"Base tokenizer tokens ({len(base_tokens)}): {base_tokens}")
print("\nNote: Each Korean character may be split into multiple byte tokens.")

---
## 2. Load Filtered Korean Tokens

In [None]:
# Read the filtered Korean token list: one token per line, with surrounding
# whitespace stripped and blank lines dropped.
filtered_tokens_path = f"{MODEL_DIR}/filtered_korean_tokens.txt"

with open(filtered_tokens_path, "r", encoding="utf-8") as f:
    # filter(None, ...) discards lines that are empty after stripping.
    korean_tokens = list(filter(None, map(str.strip, f)))

print(f"Loaded {len(korean_tokens)} filtered Korean tokens")
print(f"\nSample tokens: {korean_tokens[:20]}")

---
## 3. Filter Duplicates

In [None]:
# Snapshot the base vocabulary as a set for O(1) membership checks.
base_vocab = set(base_tokenizer.get_vocab())
print(f"Base vocabulary size: {len(base_vocab)}")

# Split the Korean tokens into those the base tokenizer already knows
# (duplicates) and those that still have to be added (new).
new_tokens = []
duplicate_tokens = []

for token in korean_tokens:
    # Drop the SentencePiece word-boundary marker before comparing.
    # Note: Different tokenizers may use different prefixes
    clean_token = token.replace("▁", "")

    # New only if neither the raw nor the cleaned form is already present.
    if token not in base_vocab and clean_token not in base_vocab:
        new_tokens.append(clean_token)  # Use clean version for HF tokenizer
    else:
        duplicate_tokens.append(token)

print(f"\nDuplicate tokens (already in base): {len(duplicate_tokens)}")
print(f"New tokens to add: {len(new_tokens)}")

if len(duplicate_tokens) > 0:
    print(f"\nSample duplicates: {duplicate_tokens[:10]}")

In [None]:
# Discard tokens that are empty or consist only of whitespace.
new_tokens = [tok for tok in new_tokens if tok and not tok.isspace()]

# De-duplicate while keeping first-occurrence order: dict keys are unique
# and preserve insertion order.
unique_new_tokens = list(dict.fromkeys(new_tokens))
seen = set(unique_new_tokens)

new_tokens = unique_new_tokens
print(f"Unique new tokens: {len(new_tokens)}")

---
## 4. Add New Tokens to Tokenizer

In [None]:
# Capture the vocabulary size before mutation so later cells can report
# how much the vocabulary grew.
original_vocab_size = len(base_tokenizer)
print(f"Original vocabulary size: {original_vocab_size}")

# Extend the tokenizer in place; add_tokens returns how many were added.
print(f"\nAdding {len(new_tokens)} new tokens...")
num_added = base_tokenizer.add_tokens(new_tokens)

print(
    f"Tokens added: {num_added}",
    f"New vocabulary size: {len(base_tokenizer)}",
    sep="\n",
)

In [None]:
# Verify that the vocabulary grew and spot-check a few medical terms.
new_vocab_size = len(base_tokenizer)

print(f"\nVocabulary expansion:")
print(f"  Before: {original_vocab_size}")
print(f"  After: {new_vocab_size}")
print(f"  Added: {new_vocab_size - original_vocab_size}")

# Check specific tokens
test_tokens = ["의사", "환자", "병원", "치료", "진단"]
print(f"\nTest token IDs:")
for token in test_tokens:
    token_id = base_tokenizer.convert_tokens_to_ids(token)
    if token_id is None or token_id == base_tokenizer.unk_token_id:
        print(f"  {token}: UNK (not found as single token)")
    elif token_id >= original_vocab_size:
        # Newly added tokens get IDs appended after the original vocabulary,
        # so only IDs at or above the recorded size came from this merge.
        print(f"  {token}: {token_id} (added)")
    else:
        # A non-UNK ID below the original size means the base tokenizer
        # already had this token; reporting it as "added" would be wrong.
        print(f"  {token}: {token_id} (already in base vocabulary)")

---
## 5. Verify Merged Tokenizer

In [None]:
# Compare token counts for sample Korean medical sentences between the
# pristine base tokenizer and the merged one.
test_sentences = [
    "안녕하세요, 저는 의료 AI 어시스턴트입니다.",
    "환자가 발열과 기침 증상을 호소합니다.",
    "당뇨병은 혈당 조절에 문제가 생기는 대사 질환입니다.",
    "MRI 검사 결과 뇌에 이상 소견이 발견되었습니다.",
]

print("Tokenization comparison (before vs after):")
print("=" * 80)

# `base_tokenizer` has been extended in place, so load a fresh copy of the
# base model's tokenizer to serve as the "before" reference.
original_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

for sentence in test_sentences:
    original_tokens = original_tokenizer.tokenize(sentence)
    merged_tokens = base_tokenizer.tokenize(sentence)

    # Ratio of original to merged token count; 0 if the merged tokenizer
    # produced no tokens at all (avoids dividing by zero).
    if merged_tokens:
        improvement = len(original_tokens) / len(merged_tokens)
    else:
        improvement = 0

    print(f"\nSentence: {sentence}")
    print(f"Original ({len(original_tokens)} tokens): {original_tokens[:15]}...")
    print(f"Merged ({len(merged_tokens)} tokens): {merged_tokens[:15]}...")
    print(f"Improvement: {improvement:.2f}x")

In [None]:
# Regression check: adding Korean tokens should leave English tokenization
# unchanged, so both tokenizers are run on the same English sentence.
english_test = "The patient presents with symptoms of diabetes mellitus."

merged_en_tokens = base_tokenizer.tokenize(english_test)
original_en_tokens = original_tokenizer.tokenize(english_test)

print("English tokenization check:")
print(f"\nSentence: {english_test}")
for label, tokens in (
    ("Original", original_en_tokens),
    ("Merged", merged_en_tokens),
):
    print(f"{label} ({len(tokens)} tokens): {tokens}")
print(f"\nEnglish tokenization preserved: {original_en_tokens == merged_en_tokens}")

In [None]:
# Test encode/decode roundtrip on mixed Korean / numeric / Latin text.
test_text = "환자의 혈압이 140/90 mmHg로 고혈압 소견입니다."

encoded = base_tokenizer.encode(test_text)
# encode() adds the tokenizer's special tokens (e.g. BOS) by default.
# Skip them when decoding, otherwise the decoded string carries that marker
# and the equality check below reports False even for a lossless round trip.
decoded = base_tokenizer.decode(encoded, skip_special_tokens=True)

print("Encode/Decode roundtrip test:")
print(f"Original: {test_text}")
print(f"Encoded: {encoded}")
print(f"Decoded: {decoded}")
print(f"Match: {test_text == decoded}")

---
## 6. Save Merged Tokenizer

In [None]:
# Persist the extended tokenizer to the output directory.
base_tokenizer.save_pretrained(MERGED_DIR)
print(f"Saved merged tokenizer to {MERGED_DIR}")

# Show what was written, with each entry's size in kilobytes.
print("\nSaved files:")
for filename in os.listdir(MERGED_DIR):
    size_bytes = os.path.getsize(os.path.join(MERGED_DIR, filename))
    print(f"  {filename}: {size_bytes/1024:.1f} KB")

In [None]:
# Save token mapping for embedding initialization.
# Resolve each new token's ID exactly once: the previous version looked up
# every token twice (once to filter, once for the value).
unk_id = base_tokenizer.unk_token_id
new_token_ids = {}
for token in new_tokens:
    token_id = base_tokenizer.convert_tokens_to_ids(token)
    # Tokens that still resolve to UNK are not usable as single tokens,
    # so they are left out of the ID map.
    if token_id != unk_id:
        new_token_ids[token] = token_id

token_mapping = {
    "base_model": BASE_MODEL,
    "original_vocab_size": original_vocab_size,
    "new_vocab_size": new_vocab_size,
    "new_tokens_count": new_vocab_size - original_vocab_size,
    "new_tokens": new_tokens,
    "new_token_ids": new_token_ids,
}

# ensure_ascii=False keeps the Korean tokens human-readable in the JSON file.
mapping_path = f"{MERGED_DIR}/token_mapping.json"
with open(mapping_path, "w", encoding="utf-8") as f:
    json.dump(token_mapping, f, ensure_ascii=False, indent=2)

print(f"Saved token mapping to {mapping_path}")

In [None]:
# Final recap of the merge: base model, vocabulary sizes, output location.
print(
    "\n" + "=" * 60,
    "Tokenizer Merge Summary",
    "=" * 60,
    f"\nBase model: {BASE_MODEL}",
    f"Original vocab size: {original_vocab_size}",
    f"New vocab size: {new_vocab_size}",
    f"Korean tokens added: {new_vocab_size - original_vocab_size}",
    f"\nMerged tokenizer saved to: {MERGED_DIR}",
    sep="\n",
)

In [None]:
# Closing banner for Phase 1, with pointers to the next phase's notebook.
print(
    "\n" + "=" * 60,
    "Phase 1: Tokenizer Preparation Complete!",
    "=" * 60,
    "\nNext steps:",
    "  1. Move to Phase 2: Embedding Initialization",
    "  2. Run phase2_embedding/01_resize_embeddings.ipynb",
    sep="\n",
)