# Phase 6.1: AWQ Quantization for Deployment

Quantize the model to 4-bit AWQ format for efficient inference.

## Contents
0. Setup
1. AWQ Quantization Configuration
2. Load Model for Quantization
3. Prepare Calibration Data
4. Apply AWQ Quantization
5. Verify Quantized Model
6. Save Quantized Model

In [None]:
# Setup
import sys
import os
sys.path.append("..")

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AwqConfig
from awq import AutoAWQForCausalLM
import json

# GPU setup
# Project-local helpers from config/gpu_utils; the import resolves because ".."
# was appended to sys.path above. Their implementations are not part of this
# notebook. clear_memory is imported but not called in the cells shown here.
from config.gpu_utils import setup_gpu, print_memory_usage, clear_memory
# NOTE(review): `device` is not referenced again in the cells shown here;
# presumably setup_gpu() also configures the GPU as a side effect -- confirm.
device = setup_gpu()

# Called before any model is loaded, so its output serves as a baseline
# for the later print_memory_usage() calls.
print_memory_usage()

In [None]:
# Model locations
# Default input: the instruction-tuned checkpoint (built on the expanded model).
MODEL_DIR = "../models/instruction_tuned"

# To quantize the expanded model without instruction tuning, use instead:
# MODEL_DIR = "../models/final/korean_medgemma_expanded"

# To quantize the legacy (non-expanded) model, use instead:
# MODEL_DIR = "../models/final/korean_medgemma"

# Destination for the AWQ-quantized checkpoint; created up front if missing.
OUTPUT_DIR = "../models/korean_medgemma_awq"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Echo both paths so the cell output records which model was quantized.
print(
    f"Input model: {MODEL_DIR}",
    f"Output dir: {OUTPUT_DIR}",
    sep="\n",
)

---
## 1. AWQ Quantization Configuration

In [None]:
# AWQ quantization settings.
# These four values are forwarded to model.quantize() and are also written to
# quantization_info.json by later cells.
AWQ_CONFIG = {
    "w_bit": 4,  # quantize weights to 4 bits
    "q_group_size": 128,  # group size used for quantization
    "zero_point": True,  # enable zero-point quantization
    "version": "GEMM",  # AWQ kernel/packing variant
}

# Print one indented "key: value" line per setting.
print("AWQ Configuration:")
print("\n".join(f"  {key}: {value}" for key, value in AWQ_CONFIG.items()))

---
## 2. Load Model for Quantization

In [None]:
# Load model using AutoAWQ
# The checkpoint is loaded through AutoAWQ's wrapper class rather than plain
# transformers because later cells call the wrapper's .quantize() and
# .save_quantized() methods on `model`.
print("Loading model for AWQ quantization...")

model = AutoAWQForCausalLM.from_pretrained(
    MODEL_DIR,
    # NOTE(review): trust_remote_code=True permits Python code shipped with the
    # checkpoint to execute; confirm it is actually needed for this local model.
    trust_remote_code=True,
    # Presumably restricts loading to .safetensors weight files -- confirm
    # that MODEL_DIR was saved in that format.
    safetensors=True,
)

# The tokenizer is read from the same directory as the model weights.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

print("Model loaded!")
print_memory_usage()

---
## 3. Prepare Calibration Data

In [None]:
# Prepare calibration data for AWQ
# AWQ needs sample data to calibrate quantization.
#
# The list holds five Korean and five English passages covering the same five
# conditions (hypertension, diabetes, pneumonia, myocardial infarction, stroke),
# so both languages the model serves are represented.
#
# NOTE(review): ten short passages is a very small calibration set. Depending
# on how the installed AutoAWQ version concatenates and chunks calibration text,
# this may yield very few (or even zero) calibration samples -- verify, and
# consider supplying a larger in-domain corpus.

calibration_texts = [
    # Korean medical texts
    "고혈압은 혈압이 정상보다 높은 상태를 말합니다. 수축기 혈압이 140mmHg 이상이거나 이완기 혈압이 90mmHg 이상인 경우를 고혈압으로 정의합니다.",
    "당뇨병은 인슐린 분비나 작용에 문제가 생겨 혈당이 높아지는 대사 질환입니다. 제1형 당뇨병과 제2형 당뇨병으로 분류됩니다.",
    "폐렴은 폐에 염증이 생기는 질환으로, 세균, 바이러스, 곰팡이 등이 원인이 될 수 있습니다. 기침, 발열, 호흡곤란이 주요 증상입니다.",
    "심근경색은 심장 근육에 혈액 공급이 차단되어 발생하는 응급 상황입니다. 가슴 통증, 호흡곤란, 식은땀이 주요 증상입니다.",
    "뇌졸중은 뇌혈관이 막히거나 터져서 발생하는 질환입니다. 갑작스러운 마비, 언어 장애, 두통이 나타날 수 있습니다.",
    
    # English medical texts
    "Hypertension is defined as blood pressure consistently above 140/90 mmHg. It is a major risk factor for heart disease, stroke, and kidney disease.",
    "Diabetes mellitus is a metabolic disorder characterized by elevated blood glucose levels. Type 2 diabetes is the most common form.",
    "Pneumonia is an infection that inflames the air sacs in one or both lungs. Symptoms include cough, fever, and difficulty breathing.",
    "Myocardial infarction occurs when blood flow to the heart muscle is blocked. Prompt treatment is essential to minimize heart damage.",
    "Stroke is a medical emergency caused by disrupted blood supply to the brain. Symptoms include sudden weakness and speech difficulties.",
]

# The count printed here is the same value later stored as "calibration_samples".
print(f"Prepared {len(calibration_texts)} calibration texts")

---
## 4. Apply AWQ Quantization

In [None]:
# Run AWQ calibration and quantization on the loaded model
print("\nApplying AWQ quantization...")
print("This may take 10-30 minutes depending on model size.")

# Hand AutoAWQ exactly the four settings defined in AWQ_CONFIG, in a fixed
# key order, as a separate dict so AWQ_CONFIG itself is never passed along.
quant_config = {
    name: AWQ_CONFIG[name]
    for name in ("zero_point", "q_group_size", "w_bit", "version")
}

# The return value is not used; later cells keep working with `model`.
model.quantize(tokenizer, quant_config=quant_config, calib_data=calibration_texts)

print("\nQuantization complete!")
print_memory_usage()

---
## 5. Verify Quantized Model

In [None]:
# Test quantized model
# Smoke test: generate a short sampled answer for one Korean and one English
# prompt and print the first 200 characters of each answer.
print("\nTesting quantized model...")

test_prompts = [
    "고혈압의 증상과 치료법은 무엇인가요?",
    "What are the symptoms of diabetes?",
]

for prompt in test_prompts:
    # NOTE(review): the ChatML-style markers are hard-coded here; confirm they
    # match the template the model was instruction-tuned with
    # (tokenizer.apply_chat_template may be the safer choice).
    formatted_prompt = f"""<|im_start|>system
당신은 의료 AI 어시스턴트입니다.
<|im_end|>
<|im_start|>user
{prompt}
<|im_end|>
<|im_start|>assistant
"""

    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
    # Number of prompt tokens, used below to separate prompt from answer.
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            temperature=0.7,
            do_sample=True,
        )

    # For decoder-only models generate() returns the prompt tokens followed by
    # the new tokens. Decode only the new tokens: slicing the decoded string by
    # len(formatted_prompt) is unreliable because skip_special_tokens=True (and
    # tokenizer normalization) changes the decoded prompt's length.
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    print(f"\nQ: {prompt}")
    print(f"A: {response[:200]}...")

---
## 6. Save Quantized Model

In [None]:
# Save quantized model
# The model and the tokenizer are both saved into OUTPUT_DIR, so that one
# directory holds everything written by this notebook for the checkpoint.
print(f"\nSaving quantized model to {OUTPUT_DIR}...")

model.save_quantized(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

print("Quantized model saved!")

In [None]:
# Record how this checkpoint was produced, next to the quantized weights:
# the source model path, the method name, the AWQ settings, and the number
# of calibration texts.
quant_info = {
    "source_model": MODEL_DIR,
    "quantization_method": "AWQ",
    "config": AWQ_CONFIG,
    "calibration_samples": len(calibration_texts),
}

# Written as indented JSON to quantization_info.json inside OUTPUT_DIR.
with open(f"{OUTPUT_DIR}/quantization_info.json", "w") as info_file:
    info_file.write(json.dumps(quant_info, indent=2))

print("Quantization info saved")

In [None]:
# Check model size
import os

def get_folder_size(folder) -> int:
    """Return the total size in bytes of all files under *folder*.

    The directory tree is traversed recursively with ``os.walk``. Entries
    whose size cannot be read (for example a dangling symlink, or a file
    removed while the walk is in progress) are skipped instead of aborting
    the whole measurement.

    Args:
        folder: Path of the directory to measure.

    Returns:
        Sum of ``os.path.getsize`` over every readable file in the tree;
        0 if the folder is empty or does not exist.
    """
    total = 0
    for root, _dirs, filenames in os.walk(folder):
        for filename in filenames:
            try:
                total += os.path.getsize(os.path.join(root, filename))
            except OSError:
                # Unreadable entry (e.g. broken symlink): nothing to count.
                continue
    return total

# Compare on-disk footprint of the source and quantized model directories.
# Byte totals are divided by 1024**3, i.e. the values are GiB even though the
# printed label reads "GB".
# NOTE(review): both totals count every file in the directory; OUTPUT_DIR also
# holds the tokenizer files and quantization_info.json written by earlier cells.
original_size = get_folder_size(MODEL_DIR) / (1024**3)
quantized_size = get_folder_size(OUTPUT_DIR) / (1024**3)

print(f"\nModel size comparison:")
print(f"  Original: {original_size:.2f} GB")
print(f"  Quantized: {quantized_size:.2f} GB")
# NOTE(review): raises ZeroDivisionError if OUTPUT_DIR contains no files
# (quantized_size == 0), e.g. when the save cell was not run.
print(f"  Compression: {original_size / quantized_size:.1f}x")

In [None]:
print("\n" + "=" * 60)
print("AWQ Quantization Complete!")
print("=" * 60)
print(f"\nQuantized model saved to: {OUTPUT_DIR}")
print(f"Compression ratio: {original_size / quantized_size:.1f}x")
print("\nNext steps:")
print("  Run 02_deploy_vllm.ipynb to deploy with vLLM")