# Sarashina2-7B AWQ 4-bit Quantization

This notebook demonstrates how to quantize the `sbintuitions/sarashina2-7b` model using AWQ (Activation-aware Weight Quantization) to 4-bit precision.
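AWQ rests on one observation: a small fraction of weight channels matter most, and they can be found from the magnitude of the activations they see rather than from the weights themselves. Before rounding, each input channel j of a weight matrix is multiplied by a scale s_j, and the matching activation is divided by s_j. This leaves the full-precision output unchanged but lowers the rounding error on channels with large activations. The scales come from a small grid search over s_j = mean|x_j|^α.

The pure-Python sketch below illustrates that search on toy data. It is a simplification, not AutoAWQ's implementation:
- each weight row is treated as a single quantization group;
- the scale normalization and the clipping search that AutoAWQ also performs are omitted;
- all data is random and hypothetical.

```python
import random


def quantize_rtn(w, n_bit=4):
    """Asymmetric min-max round-to-nearest quantization of one group; returns dequantized values."""
    q_max = 2 ** n_bit - 1
    scale = max(max(w) - min(w), 1e-5) / q_max
    zero = min(max(-round(min(w) / scale), 0), q_max)
    return [(min(max(round(x / scale) + zero, 0), q_max) - zero) * scale for x in w]


def output_error(W, X, s, n_bit=4):
    """Squared error of y = W x after quantizing W * diag(s) and feeding x / s.

    Without rounding, (W * diag(s)) (x / s) == W x, so all error comes from quantization.
    """
    Wq = [quantize_rtn([w * sj for w, sj in zip(row, s)], n_bit) for row in W]
    err = 0.0
    for x in X:
        for row, qrow in zip(W, Wq):
            y_ref = sum(w * xj for w, xj in zip(row, x))
            y_q = sum(q * xj / sj for q, xj, sj in zip(qrow, x, s))
            err += (y_ref - y_q) ** 2
    return err


def search_awq_scales(W, X, n_grid=20, n_bit=4):
    """Grid-search alpha in s_j = mean|x_j| ** alpha; alpha = 0 is plain round-to-nearest."""
    n_in = len(X[0])
    act = [sum(abs(x[j]) for x in X) / len(X) for j in range(n_in)]  # per-channel activation size
    best = None
    for i in range(n_grid):
        alpha = i / n_grid
        s = [max(a, 1e-8) ** alpha for a in act]
        err = output_error(W, X, s, n_bit)
        if best is None or err < best[0]:
            best = (err, alpha, s)
    return best


random.seed(0)
n_in = 16
W = [[random.gauss(0.0, 1.0) for _ in range(n_in)] for _ in range(4)]
# Channel 3 receives 10x larger activations, mimicking an "outlier" channel
X = [[random.gauss(0.0, 1.0) * (10.0 if j == 3 else 1.0) for j in range(n_in)] for _ in range(8)]
err_rtn = output_error(W, X, [1.0] * n_in)  # no scaling: plain round-to-nearest
err_awq, alpha, scales = search_awq_scales(W, X)
print(f"RTN error {err_rtn:.4f} -> scaled error {err_awq:.4f} at alpha={alpha:.2f}")
```

Because α = 0 is part of the grid, the searched scales can never do worse than plain rounding on the calibration inputs. This is why AWQ needs calibration data but no gradient-based training.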

## Install Required Packages

In [None]:
!pip install autoawq transformers accelerate datasets huggingface_hub tqdm

## Import Libraries

In [None]:
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
from datasets import load_dataset
from huggingface_hub import HfApi, create_repo
from tqdm.auto import tqdm
import torch
import time
import re
import unicodedata
import os

## Configuration
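The `quant_config` in the cell below sets the following options:
- `w_bit=4`: 4-bit weights.
- `q_group_size=128`: weights are quantized in groups of 128 consecutive input-channel weights.
- `zero_point=True`: each group gets its own scale and zero point, i.e. asymmetric quantization.
- `version="GEMM"`: selects AutoAWQ's kernel and packing variant.

The plain-Python sketch below illustrates only the group-wise rounding step, modeled on min-max round-to-nearest quantization. It leaves out AWQ's activation-aware scaling and the packed on-disk format, and the weight row is hypothetical random data.

```python
import random


def quantize_group(w, w_bit=4):
    """Asymmetric min-max quantization of one group; returns (codes, scale, zero)."""
    q_max = 2 ** w_bit - 1
    scale = max(max(w) - min(w), 1e-5) / q_max
    zero = min(max(-round(min(w) / scale), 0), q_max)
    codes = [min(max(round(x / scale) + zero, 0), q_max) for x in w]
    return codes, scale, zero


def quantize_row(row, w_bit=4, group_size=128):
    """Split one weight row into contiguous groups, each with its own scale and zero point."""
    return [quantize_group(row[i:i + group_size], w_bit) for i in range(0, len(row), group_size)]


def dequantize_row(groups):
    """Map integer codes back to approximate float weights."""
    out = []
    for codes, scale, zero in groups:
        out.extend((c - zero) * scale for c in codes)
    return out


random.seed(0)
row = [random.gauss(0.0, 0.02) for _ in range(512)]  # one hypothetical weight row
groups = quantize_row(row, w_bit=4, group_size=128)
recon = dequantize_row(groups)
print(len(groups), "groups; max abs error:", max(abs(a - b) for a, b in zip(row, recon)))
```

Smaller groups track the local weight range more closely, which lowers rounding error. The cost is more scales and zero points to store.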

In [None]:
model_path = "sbintuitions/sarashina2-7b"
quant_path = "sarashina2-7b-4bit-awq"
hf_model_id = "ronantakizawa/sarashina2-7b-4bit-awq"  # HuggingFace repo name
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

# Device configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

if device == "cpu":
    print("⚠️ Warning: Running on CPU. Quantization will be slow. GPU is recommended.")

## Load Model and Tokenizer

In [None]:
# Load model with proper device placement
print("Loading model...")
model = AutoAWQForCausalLM.from_pretrained(
    model_path,
    device_map="cuda:0" if torch.cuda.is_available() else "cpu",
    safetensors=False,  # This model uses PyTorch .bin files
    **{"low_cpu_mem_usage": True, "use_cache": False}
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Set pad token if not already set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Model loaded successfully on device: {next(model.parameters()).device}")

## Prepare Calibration Data

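We use the `range3/wikipedia-ja-20230101` dataset, a Japanese Wikipedia dump, and take its first 10,000 articles. Wikipedia provides long, well-formed prose across many topics, which suits AWQ calibration. From these articles we extract passages of 200 to 2,000 characters and filter out text that has too little Japanese, too much markup, or too much whitespace. The survivors are split into 512 calibration samples and 100 held-out test samples. If the dataset cannot be loaded, the cell falls back to a small set of hand-written Japanese passages.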
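The filter below keeps only passages in which at least 50% of the characters are kana or kanji. To show what that threshold means in practice, this small sketch applies the same character ranges to a few hypothetical strings. Punctuation such as `。`, Latin letters, and digits do not count, so even fully Japanese sentences score below 1.0.

```python
import re

# Same ranges as the notebook's filter: Hiragana, Katakana, CJK Unified Ideographs
JAPANESE_CHARS = re.compile(r'[\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF]')


def japanese_ratio(text):
    """Fraction of characters that are kana or kanji."""
    return len(JAPANESE_CHARS.findall(text)) / max(len(text), 1)


samples = {
    "ja": "東京は日本の首都です。",   # 10 of 11 chars count; '。' does not
    "mixed": "AWQは4-bit量子化",      # 4 of 12 chars count
    "en": "Python 3.12 released",
}
for name, text in samples.items():
    r = japanese_ratio(text)
    print(f"{name:5s} ratio={r:.2f} passes_50%={r >= 0.5}")
```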

In [None]:
# Load the Japanese dataset for calibration
print("Loading calibration dataset...")
print("📚 Dataset: range3/wikipedia-ja-20230101 (Japanese Wikipedia)")
print("🎯 Target: 512 calibration + 100 test samples\n")

# Wikipedia articles give long, well-formed passages, which make good AWQ calibration text
print("⏳ Loading Japanese Wikipedia dataset...\n")

try:
    # Load Japanese Wikipedia dataset
    dataset = load_dataset(
        "range3/wikipedia-ja-20230101",
        split="train[:10000]"  # Load first 10k articles
    )

    print(f"✅ Dataset loaded! Got {len(dataset):,} Wikipedia articles\n")

    # Debug: Check dataset structure and length distribution
    print("🔍 Examining dataset structure...")
    sample = dataset[0]
    print(f"  Keys: {list(sample.keys())}")

    # Sample 10 random examples to see length distribution
    import random
    print("\n  Checking length distribution (10 random articles):")
    lengths = []
    for _ in range(10):
        idx = random.randint(0, min(1000, len(dataset)-1))
        article = dataset[idx]
        # Wikipedia typically has 'text' or 'content' field
        text = article.get('text', article.get('content', ''))
        length = len(text)
        lengths.append(length)

    print(f"    Lengths: {sorted(lengths)}")
    print(f"    Average: {sum(lengths) // len(lengths):,} chars")
    print(f"    Min: {min(lengths):,}, Max: {max(lengths):,}")

    # Show a sample article preview
    sample_text = dataset[0].get('text', dataset[0].get('content', ''))
    print(f"\n  Sample article preview ({len(sample_text):,} chars):")
    print(f"    {sample_text[:200]}...")

    print("\n" + "="*60 + "\n")

except Exception as e:
    print(f"❌ Error loading dataset: {type(e).__name__}: {e}")
    print("\n💡 Creating high-quality calibration data as fallback...")

    # High-quality fallback samples with longer text
    calibration_data = [
        "人工知能の発展により、私たちの生活は大きく変化しています。機械学習やディープラーニングの技術が、様々な分野で応用されています。この技術革新は、医療、教育、ビジネスなど、あらゆる領域に影響を与えています。特に自然言語処理の分野では、大規模言語モデルの登場により、人間のような文章生成が可能になりました。これらのモデルは、膨大なテキストデータから学習し、文脈を理解して適切な応答を生成することができます。今後も技術の進歩により、より高度なAIシステムが開発されることが期待されています。深層学習の仕組みは、人間の脳の神経回路網を模倣したニューラルネットワークに基づいています。",
        "日本の四季は美しく、春には桜が咲き、夏には祭りが開催され、秋には紅葉が見られ、冬には雪景色が楽しめます。それぞれの季節には独特の魅力があり、日本文化の重要な要素となっています。季節の移り変わりは、日本人の感性や美意識に深く影響を与えてきました。春の花見や夏の花火大会、秋の月見、冬の雪まつりなど、季節ごとの行事も数多く存在し、日本の伝統を今に伝えています。これらの行事は地域ごとに特色があり、その土地の歴史や文化を反映しています。",
        "最新の研究によれば、深層学習モデルは自然言語処理タスクにおいて人間レベルの性能を達成しつつあります。特にトランスフォーマーアーキテクチャの登場により、言語理解の精度が大幅に向上しました。これらのモデルは、翻訳、要約、質問応答など、多様なタスクで活用されています。BERTやGPTなどの事前学習モデルは、大規模なコーパスから言語の統計的パターンを学習し、少量のタスク固有データで高い性能を発揮することができます。注意機構と呼ばれる技術により、モデルは文脈の重要な部分に焦点を当てることができます。",
        "東京は日本の首都であり、世界でも最も人口密度の高い都市の一つです。伝統と最新技術が融合した独特の雰囲気を持ち、多くの観光客を魅了しています。高層ビルと歴史的建造物が共存する景観は、東京ならではの魅力です。浅草の雷門や明治神宮などの伝統的な観光地から、渋谷や新宿などの近代的な繁華街まで、多様な顔を持つ都市として知られています。また、東京は世界有数の経済センターでもあり、多くの国際企業が本社を構えています。",
        "気候変動は現代社会が直面する最も重要な課題の一つです。温室効果ガスの削減や再生可能エネルギーの活用など、持続可能な社会の実現に向けた取り組みが世界中で進められています。環境保護と経済発展のバランスを取ることが、今後の重要な課題となっています。太陽光発電や風力発電などのクリーンエネルギー技術の普及、森林保全活動、プラスチック削減など、様々な取り組みが行われています。国際的な協力も不可欠であり、パリ協定などの枠組みが整備されています。",
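    ] * 123  # 615 samples (5 unique passages repeated)
    test_data = calibration_data[:100]
    calibration_data = calibration_data[100:612]  # 512 samples
    print(f"   ✅ Created {len(calibration_data)} calibration + {len(test_data)} test samples")
    print("   ⚠️  Note: Using repeated samples - quantization quality may be lower")
    print("   ⚠️  Note: Test samples duplicate calibration text, so perplexity is not a held-out measure\n")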
    dataset = None

# Extract text from Wikipedia articles
def extract_text_from_sample(sample):
    """Extract text from Wikipedia article."""
    # Try different possible field names
    text = sample.get('text', sample.get('content', sample.get('article', '')))
    return text.strip() if text else ''

# Only run data collection if dataset loaded successfully
if dataset is not None and len(dataset) > 0:
    calibration_data = []
    test_data = []
    num_calibration_samples = 512
    num_test_samples = 100

    # Extract passages of roughly 600 chars from long articles and keep those of 200-2000 chars
    min_length = 200  # Minimum passage length in characters
    max_length = 2000  # Passages longer than this are truncated
    min_japanese_ratio = 0.50  # Wikipedia should be mostly Japanese

    stats = {
        'processed': 0,
        'too_short': 0,
        'too_long': 0,
        'insufficient_japanese': 0,
        'too_many_special_chars': 0,
        'too_much_whitespace': 0,
        'duplicates': 0,
        'accepted': 0
    }

    def is_valid_japanese_text(text):
        """Validate with stricter criteria for Wikipedia."""
        japanese_chars = len(re.findall(r'[\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF]', text))
        total_chars = len(text)

        # Require at least 50% Japanese characters (Wikipedia should be mostly Japanese)
        japanese_ratio = japanese_chars / max(total_chars, 1)
        if japanese_ratio < min_japanese_ratio:
            stats['insufficient_japanese'] += 1
            return False

        # Check for excessive special characters or markup
        special_char_ratio = len(re.findall(r'[<>{}[\]\\|=]', text)) / max(total_chars, 1)
        if special_char_ratio >= 0.05:  # Stricter for Wikipedia
            stats['too_many_special_chars'] += 1
            return False

        # Check for excessive whitespace
        whitespace_ratio = len(re.findall(r'\s', text)) / max(total_chars, 1)
        if whitespace_ratio >= 0.50:
            stats['too_much_whitespace'] += 1
            return False

        return True

    def normalize_text(text):
        """Normalize unicode and clean text."""
        text = unicodedata.normalize('NFKC', text)
        # Remove excessive whitespace
        text = re.sub(r'\s+', ' ', text)
        # Remove common Wikipedia markup artifacts
        text = re.sub(r'\[\d+\]', '', text)  # Remove citation markers like [1], [2]
        return text.strip()

    def extract_passage(text, target_length=500):
        """Extract a passage of roughly target_length from longer text."""
        if len(text) <= max_length:
            return text

        # Try to extract a coherent passage (between sentences)
        # Pick a random starting point that leaves room for a full passage
        start = random.randint(0, max(0, len(text) - target_length - 100))

        # Find sentence boundary near start
        for i in range(start, min(start + 100, len(text))):
            if text[i] in '。！？\n':
                start = i + 1
                break

        # Extract passage
        end = min(start + target_length, len(text))

        # Find sentence boundary near end
        for i in range(end, min(end + 100, len(text))):
            if text[i] in '。！？\n':
                end = i + 1
                break

        return text[start:end].strip()

    print(f"🔍 Filtering criteria for Wikipedia articles:")
    print(f"  • Length: {min_length}-{max_length} characters (good for AWQ)")
    print(f"  • Japanese characters: ≥{min_japanese_ratio*100:.0f}%")
    print(f"  • Extracting coherent passages from longer articles")
    print(f"  • Target: {num_calibration_samples} diverse samples\n")

    seen_texts = set()
    start_time = time.time()

    print("🚀 Starting data collection...\n")

    for i, sample in enumerate(dataset):
        stats['processed'] += 1

        if len(calibration_data) >= num_calibration_samples and len(test_data) >= num_test_samples:
            break

        # Extract full article text
        full_text = extract_text_from_sample(sample)

        if not full_text or len(full_text) < min_length:
            stats['too_short'] += 1
            continue

        # Extract a passage of appropriate length
        text = extract_passage(full_text, target_length=600)
        text = normalize_text(text)

        # Length validation
        if len(text) < min_length:
            stats['too_short'] += 1
            continue
        if len(text) > max_length:
            stats['too_long'] += 1  # Counted, but truncated rather than rejected
            text = text[:max_length]

        # Duplicate check
        if text in seen_texts:
            stats['duplicates'] += 1
            continue

        # Quality validation
        if not is_valid_japanese_text(text):
            continue

        # Accept sample
        seen_texts.add(text)
        stats['accepted'] += 1

        # Split into calibration and test sets
        if len(calibration_data) < num_calibration_samples:
            calibration_data.append(text)
        elif len(test_data) < num_test_samples:
            test_data.append(text)

        # Progress updates
        if stats['accepted'] % 50 == 0 or (i + 1) % 1000 == 0:
            elapsed = time.time() - start_time
            acceptance_rate = (stats['accepted'] / stats['processed']) * 100 if stats['processed'] > 0 else 0
            print(f"⏳ [{elapsed:.0f}s] Processed {stats['processed']:,} | Accepted: {stats['accepted']} ({acceptance_rate:.1f}%) | Collected: {len(calibration_data)}/{num_calibration_samples} cal + {len(test_data)}/{num_test_samples} test")

    elapsed_total = time.time() - start_time

    print(f"\n{'='*60}")
    print(f"✅ DATA COLLECTION COMPLETE!")
    print(f"{'='*60}")
    print(f"⏱️  Total time: {elapsed_total:.1f}s")
    print(f"📊 Processed: {stats['processed']:,} articles")
    print(f"✅ Accepted: {stats['accepted']} ({(stats['accepted']/stats['processed'])*100:.1f}%)")
    print(f"\n📦 Final datasets:")
    print(f"  • Calibration set: {len(calibration_data)} samples")
    print(f"  • Test set (held-out): {len(test_data)} samples")

    print(f"\n❌ Rejection breakdown:")
    print(f"  • Too short (<{min_length} chars): {stats['too_short']}")
    print(f"  • Truncated to {max_length} chars (kept, not rejected): {stats['too_long']}")
    print(f"  • Insufficient Japanese (<{min_japanese_ratio*100:.0f}%): {stats['insufficient_japanese']}")
    print(f"  • Too many special chars: {stats['too_many_special_chars']}")
    print(f"  • Too much whitespace: {stats['too_much_whitespace']}")
    print(f"  • Duplicates: {stats['duplicates']}")

if len(calibration_data) > 0:
    print(f"\n📈 Calibration set statistics:")
    cal_lengths = [len(s) for s in calibration_data]
    print(f"  • Length (chars): min={min(cal_lengths)}, max={max(cal_lengths)}, avg={sum(cal_lengths) // len(cal_lengths)}, median={sorted(cal_lengths)[len(cal_lengths)//2]}")

    print(f"\n🔢 Tokenization statistics (sample of 50):")
    sample_tokens = [len(tokenizer.encode(s)) for s in calibration_data[:50]]
    print(f"  • Token count: min={min(sample_tokens)}, max={max(sample_tokens)}, avg={sum(sample_tokens) // len(sample_tokens)}")

    first_sample = calibration_data[0]
    jp_chars = len(re.findall(r'[\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF]', first_sample))
    jp_ratio = (jp_chars / len(first_sample)) * 100
    print(f"\n📝 First calibration sample ({len(first_sample)} chars, {jp_ratio:.1f}% Japanese):")
    print(f"{first_sample[:400]}...")

if len(test_data) > 0:
    test_sample = test_data[0]
    jp_chars_test = len(re.findall(r'[\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF]', test_sample))
    jp_ratio_test = (jp_chars_test / len(test_sample)) * 100
    print(f"\n📝 First test sample ({len(test_sample)} chars, {jp_ratio_test:.1f}% Japanese):")
    print(f"{test_sample[:300]}...")

## Compute Baseline Perplexity (Optional but Recommended)

Before quantization, compute perplexity on the original FP16 model for comparison. This helps measure quantization quality loss.
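Perplexity is the exponential of the average negative log-likelihood per predicted token:

PPL = exp(Σ NLL / N_pred)

Hugging Face causal LMs shift the labels by one position internally. For a sequence of n tokens, `outputs.loss` is therefore the mean over n - 1 predictions. To combine per-text losses into one corpus perplexity, each loss should be weighted by n - 1; weighting by n slightly biases the result toward short texts. A minimal sketch of that aggregation, using hypothetical loss values:

```python
import math


def corpus_perplexity(per_sequence):
    """per_sequence: list of (mean_loss, seq_len) pairs.

    mean_loss is the average next-token cross-entropy (in nats) a causal LM reports
    for one sequence of seq_len tokens; such a sequence has seq_len - 1 predictions.
    """
    total_nll = 0.0
    total_predicted = 0
    for mean_loss, seq_len in per_sequence:
        n_pred = seq_len - 1
        if n_pred <= 0:
            continue  # a single-token sequence predicts nothing
        total_nll += mean_loss * n_pred
        total_predicted += n_pred
    return math.exp(total_nll / total_predicted)


# Two hypothetical texts: mean loss 1.0 over 3 tokens, mean loss 2.0 over 5 tokens
print(corpus_perplexity([(1.0, 3), (2.0, 5)]))  # exp((1.0*2 + 2.0*4) / 6)
```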

In [None]:
# Compute baseline perplexity on original model (before quantization)
# This provides a comparison point for evaluating quantization quality

print("Computing baseline perplexity on original FP16 model...")
print("This may take a few minutes...\n")

def compute_perplexity_baseline(model, tokenizer, texts, max_samples=50):
    """Compute perplexity on a set of texts."""
    model.eval()
    total_loss = 0
    total_tokens = 0

    # Get device from model parameters
    device = next(model.parameters()).device

    with torch.no_grad():
        for i, text in enumerate(texts[:max_samples]):
            if (i + 1) % 10 == 0:
                print(f"  Processed {i + 1}/{max_samples} samples...")

            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(device) for k, v in inputs.items()}

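            n_predicted = inputs["input_ids"].size(1) - 1  # labels are shifted by one
            if n_predicted <= 0:
                continue  # a single-token input has nothing to predict

            # outputs.loss is the mean cross-entropy over the n_predicted tokens
            outputs = model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss

            total_loss += loss.item() * n_predicted
            total_tokens += n_predicted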

    avg_loss = total_loss / total_tokens
    perplexity = torch.exp(torch.tensor(avg_loss)).item()
    return perplexity

# Use the held-out test set (NOT calibration data)
original_perplexity = compute_perplexity_baseline(model, tokenizer, test_data, max_samples=50)
print(f"\n✅ Original FP16 model perplexity: {original_perplexity:.2f}")
print(f"   This will be used to compare against quantized model perplexity later.\n")

## Quantize the Model
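As we understand AutoAWQ's calibration loader (details may differ between versions), a list of strings passed as `calib_data` is handled in four steps:
1. Each text is tokenized.
2. Texts longer than the maximum sequence length are skipped, not truncated.
3. Only the first `n_samples` surviving texts are kept.
4. The kept token streams are concatenated and cut into fixed-length blocks.

The sketch below models that behavior in plain Python. `build_calibration_blocks` and `toy_encode` are hypothetical names, and the character-level tokenizer is a stand-in for the real one.

```python
def build_calibration_blocks(texts, encode, n_samples=128, max_seq_len=512):
    """Simplified model of how a list of calibration strings becomes fixed-length blocks.

    encode: any callable mapping a string to a list of token ids.
    """
    token_stream = []
    kept = 0
    for text in texts:
        ids = encode(text)
        if not ids or len(ids) > max_seq_len:
            continue  # empty and over-long texts are skipped, not truncated
        token_stream.extend(ids)
        kept += 1
        if kept == n_samples:
            break  # at most n_samples texts are used
    # Concatenate everything and cut it into full blocks; the remainder is dropped
    n_blocks = len(token_stream) // max_seq_len
    return [token_stream[i * max_seq_len:(i + 1) * max_seq_len] for i in range(n_blocks)]


def toy_encode(text):
    """Hypothetical stand-in for a real tokenizer: one token per character."""
    return [ord(c) for c in text]


texts = ["あ" * 300] * 10 + ["い" * 600]
print(len(build_calibration_blocks(texts, toy_encode, n_samples=4)))    # 4 * 300 tokens -> 2
print(len(build_calibration_blocks(texts, toy_encode, n_samples=100)))  # 10 * 300 tokens -> 5
```

Under this model, a sample cap that is lower than the number of collected texts silently discards the rest. Any passage that tokenizes to more than the sequence limit never reaches calibration.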

In [None]:
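# Quantize the model with the Japanese calibration set.
# AutoAWQ uses at most `max_calib_samples` texts (128 by default) and skips texts
# longer than `max_calib_seq_len` tokens, so pass the limits explicitly to use all
# collected samples. These arguments are available in recent AutoAWQ releases.
model.quantize(
    tokenizer,
    quant_config=quant_config,
    calib_data=calibration_data,
    max_calib_samples=len(calibration_data),
    max_calib_seq_len=512,
)
print("Quantization complete!")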

## Save Quantized Model

In [None]:
# Save the quantized model (quantize() modified `model` in place)
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
print(f"Quantized model saved to {quant_path}")

# Free memory: `model` now holds the in-memory quantized weights, which are
# reloaded from disk in the next step
import gc
print("\nCleaning up memory...")
del model
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"✅ GPU memory cleared. Free GPU memory: {torch.cuda.mem_get_info()[0] / 1024**3:.2f} GB")

## Load and Test Quantized Model

In [None]:
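# Load the quantized model from disk.
# fuse_layers=False keeps the standard (unfused) attention modules. AutoAWQ's fused
# layers are built for generation and manage their own KV cache, which can interfere
# with the independent per-text forward passes used for perplexity below.
# Set fuse_layers=True for faster generation-only use.
quantized_model = AutoAWQForCausalLM.from_quantized(
    quant_path,
    fuse_layers=False,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)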

# Get device from model parameters
device = next(quantized_model.parameters()).device
print(f"Quantized model loaded on device: {device}")

## Quick Test Generation
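The generation calls below sample with `temperature=0.7` and `top_p=0.95`. Temperature divides the logits before the softmax: values below 1 sharpen the distribution. Nucleus (top-p) sampling then keeps only the smallest set of most-likely tokens whose cumulative probability reaches `top_p`, and samples from that set after renormalizing.

This is a conceptual pure-Python sketch with hypothetical logits. `transformers` implements the same idea in its logits processors, but details such as tie handling and a minimum number of kept tokens may differ.

```python
import math
import random


def sample_top_p(logits, temperature=0.7, top_p=0.95, rng=random):
    """Temperature scaling followed by nucleus (top-p) sampling; returns (token, kept_ids)."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]  # subtract max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the smallest set of most-likely tokens whose cumulative probability reaches top_p
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    # Sample from the kept tokens in proportion to their (renormalized) probability
    r = rng.random() * sum(probs[i] for i in kept)
    for i in kept:
        r -= probs[i]
        if r <= 0:
            return i, kept
    return kept[-1], kept


logits = [5.0, 4.0, 0.0, -5.0]  # hypothetical next-token logits
token, kept = sample_top_p(logits, temperature=0.7, top_p=0.95, rng=random.Random(0))
print("candidates kept:", kept, "sampled:", token)
```

With a low temperature, only the two strongest candidates survive the 0.95 cutoff here. Raising the temperature flattens the distribution and lets more tokens into the nucleus.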

In [None]:
prompt = "おはようございます、今日の天気は"

# Place the inputs on the same device as the model
device = next(quantized_model.parameters()).device

inputs = tokenizer(prompt, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

generation_output = quantized_model.generate(
    **inputs,
    max_new_tokens=50,
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    pad_token_id=tokenizer.pad_token_id,
)

generated_text = tokenizer.decode(generation_output[0], skip_special_tokens=True)
print("Quick test generation:")
print(f"  Prompt: {prompt}")
print(f"  Output: {generated_text}")
print(f"\n✅ Model loaded and functional!")

## Quantization Quality Evaluation
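The generation checks in this section flag repetitive output. Japanese is written without spaces between words, so splitting on whitespace typically returns the whole sentence as a single "word" and cannot detect repetition. A character n-gram measure works regardless of script.

The sketch below uses character 3-grams; the function name and the example strings are hypothetical.

```python
def char_ngram_repetition(text, n=3):
    """Fraction of character n-grams that repeat an earlier n-gram (0.0 = no repetition)."""
    grams = [text[i:i + n] for i in range(len(text) - n + 1)]
    if not grams:
        return 0.0
    return 1 - len(set(grams)) / len(grams)


looping = "今日は晴れです。今日は晴れです。今日は晴れです。"
normal = "東京は日本の首都であり世界有数の都市です"
print(len(looping.split()))  # → 1: whitespace splitting sees a single "word"
print(f"{char_ngram_repetition(looping):.2f}", f"{char_ngram_repetition(normal):.2f}")
```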

In [None]:
# Prepare test prompts (diverse Japanese text for evaluation)
test_prompts = [
    "人工知能の発展により",
    "おはようございます、今日は",
    "日本の伝統文化について",
    "最新の科学技術では",
    "私が一番好きな季節は",
]

print("=" * 60)
print("🔍 QUANTIZATION QUALITY EVALUATION")
print("=" * 60)

In [None]:
# 1. Perplexity Comparison: Original vs Quantized
print("\n1️⃣ Computing Perplexity on Held-Out Test Set...")
print(f"   Using up to 50 of the {len(test_data)} held-out samples (not used for calibration)\n")

def compute_perplexity(model, tokenizer, texts, max_samples=None):
    """Compute perplexity on a set of texts."""
    model.eval()
    total_loss = 0
    total_tokens = 0

    # Get device from model parameters
    device = next(model.parameters()).device

    texts_to_use = texts if max_samples is None else texts[:max_samples]

    with torch.no_grad():
        for text in tqdm(texts_to_use, desc="Computing perplexity"):
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(device) for k, v in inputs.items()}

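            n_predicted = inputs["input_ids"].size(1) - 1  # labels are shifted by one
            if n_predicted <= 0:
                continue  # a single-token input has nothing to predict

            # outputs.loss is the mean cross-entropy over the n_predicted tokens
            outputs = model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss

            total_loss += loss.item() * n_predicted
            total_tokens += n_predicted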

    avg_loss = total_loss / total_tokens
    perplexity = torch.exp(torch.tensor(avg_loss)).item()
    return perplexity

# Compute perplexity for quantized model
print("   Computing perplexity for quantized model on test set...")
quantized_perplexity = compute_perplexity(quantized_model, tokenizer, test_data, max_samples=50)

# Compare with original model
print(f"\n   📊 Perplexity Comparison:")
print(f"      • Original FP16 model:  {original_perplexity:.2f}")
print(f"      • Quantized 4-bit model: {quantized_perplexity:.2f}")

# Calculate degradation
perplexity_increase = ((quantized_perplexity - original_perplexity) / original_perplexity) * 100
print(f"      • Perplexity increase: {perplexity_increase:+.1f}%")

print(f"\n   📊 Interpretation:")
print(f"      • Lower perplexity = better language modeling")
print(f"      • The thresholds below are rough heuristics for this notebook, not published benchmarks")
if perplexity_increase < 5:
    print(f"      ✅ Excellent: Minimal quality loss (<5%)")
elif perplexity_increase < 15:
    print(f"      ✅ Good: Moderate increase (5-15%)")
elif perplexity_increase < 25:
    print(f"      ⚠️  Acceptable but higher than typical (15-25%)")
else:
    print(f"      ❌ Poor: Significant quality degradation ({perplexity_increase:.0f}% increase, >25%)")
    print(f"         Consider: More calibration samples or different quantization settings")

In [None]:
# 2. Generation Quality Assessment
print("\n2️⃣ Generation Quality Assessment...")
print("\nGenerating responses for test prompts:\n")

generation_results = []

# Get device from model parameters
device = next(quantized_model.parameters()).device

for prompt in test_prompts:
    print(f"📝 Prompt: '{prompt}'")

    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Generate with quantized model
    start_time = time.time()
    output = quantized_model.generate(
        **inputs,
        max_new_tokens=40,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        pad_token_id=tokenizer.pad_token_id,
    )
    generation_time = time.time() - start_time

    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print(f"   ✅ Output: {generated_text}")
    print(f"   ⏱️  Time: {generation_time:.2f}s\n")

    generation_results.append({
        'prompt': prompt,
        'output': generated_text,
        'time': generation_time
    })

avg_time = sum(r['time'] for r in generation_results) / len(generation_results)
print(f"Average generation time: {avg_time:.2f}s per prompt")

# Check for common quality issues
print("\n🔍 Quality Checks:")
quality_issues = []
for result in generation_results:
    output = result['output']
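    # Check for repetition. Japanese has no spaces between words, so whitespace
    # splitting is ineffective; measure repeated character 3-grams instead.
    ngrams = [output[i:i + 3] for i in range(len(output) - 2)]
    if len(ngrams) > 10:
        repetition_rate = 1 - len(set(ngrams)) / len(ngrams)
        if repetition_rate > 0.3:
            quality_issues.append(f"High repetition in: '{result['prompt'][:30]}...'")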

    # Check for excessive special characters
    special_chars = len([c for c in output if not c.isalnum() and not c.isspace()])
    if special_chars / max(len(output), 1) > 0.2:
        quality_issues.append(f"Excessive special chars in: '{result['prompt'][:30]}...'")

if quality_issues:
    print("   ⚠️  Issues found:")
    for issue in quality_issues:
        print(f"      - {issue}")
else:
    print("   ✅ No major quality issues detected")

In [None]:
# 3. Memory and Size Analysis
print("\n3️⃣ Memory and Size Analysis...")

def get_directory_size(path):
    """Calculate total size of a directory in GB."""
    total_size = 0
    for dirpath, dirnames, filenames in os.walk(path):
        for filename in filenames:
            filepath = os.path.join(dirpath, filename)
            if os.path.exists(filepath):
                total_size += os.path.getsize(filepath)
    return total_size / (1024 ** 3)  # Convert to GB

if os.path.exists(quant_path):
    quantized_size = get_directory_size(quant_path)
    original_size = 14.6  # GB, approximate FP16 checkpoint size (hard-coded, not measured)

    print(f"   📦 Original model size: {original_size:.2f} GB")
    print(f"   📦 Quantized model size: {quantized_size:.2f} GB")
    print(f"   💾 Size reduction: {((original_size - quantized_size) / original_size * 100):.1f}%")
    print(f"   💾 Compression ratio: {original_size / quantized_size:.2f}x")

    # Check if size reduction is as expected
    expected_reduction = 0.70  # 70% reduction expected
    actual_reduction = (original_size - quantized_size) / original_size

    if actual_reduction < expected_reduction - 0.1:
        print(f"\n   ⚠️  Size reduction ({actual_reduction*100:.1f}%) is lower than expected (~70%)")
        print(f"      This might indicate issues with quantization configuration")
    else:
        print(f"\n   ✅ Size reduction meets expectations!")
else:
    print(f"   ⚠️ Quantized model directory not found: {quant_path}")
    quantized_size = 0
    original_size = 14.6
    actual_reduction = 0

# GPU Memory usage
if torch.cuda.is_available():
    print(f"\n   🎮 GPU Memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"   🎮 GPU Memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    print(f"   🎮 Expected weight memory vs FP16: up to ~4x less (16-bit -> 4-bit; unquantized layers reduce this)")
else:
    print(f"\n   💻 Running on CPU (no GPU memory stats available)")

In [None]:
# 4. Summary Report
print("\n" + "=" * 60)
print("📊 QUANTIZATION SUMMARY REPORT")
print("=" * 60)

print(f"""
✅ Quantization Configuration:
   - Method: AWQ (Activation-aware Weight Quantization)
   - Precision: 4-bit
   - Group Size: 128
   - Calibration Samples: {len(calibration_data)}
   - Test Samples (held-out): {len(test_data)}

📈 Quality Metrics:
   - Original FP16 Perplexity: {original_perplexity:.2f}
   - Quantized 4-bit Perplexity: {quantized_perplexity:.2f}
   - Perplexity Degradation: {perplexity_increase:+.1f}%
   - Average Generation Time: {avg_time:.2f}s
   - Quality Issues Detected: {len(quality_issues)}

💾 Storage Efficiency:
   - Original Size: {original_size:.2f} GB
   - Quantized Size: {quantized_size:.2f} GB
   - Size Reduction: {((original_size - quantized_size) / original_size * 100):.1f}%
   - Compression Ratio: {original_size / max(quantized_size, 0.01):.2f}x

🎯 Overall Assessment:
""")

# Provide comprehensive assessment
assessment_score = 0

# Perplexity assessment
if perplexity_increase < 5:
    print("   ✅ Perplexity: Excellent - Minimal quality loss")
    assessment_score += 3
elif perplexity_increase < 15:
    print("   ✅ Perplexity: Good - Within expected range")
    assessment_score += 2
elif perplexity_increase < 25:
    print("   ⚠️  Perplexity: Acceptable but higher than typical")
    assessment_score += 1
else:
    print("   ❌ Perplexity: Poor - Significant quality degradation")
    print("      → Consider: More calibration samples or different quantization settings")

# Size reduction assessment
if actual_reduction >= 0.65:
    print("   ✅ Compression: Excellent - Meets expectations")
    assessment_score += 2
elif actual_reduction >= 0.5:
    print("   ⚠️  Compression: Acceptable but lower than expected")
    assessment_score += 1
else:
    print("   ❌ Compression: Poor - Verify quantization completed successfully")

# Generation quality assessment
if len(quality_issues) == 0:
    print("   ✅ Generation Quality: No issues detected")
    assessment_score += 1
else:
    print(f"   ⚠️  Generation Quality: {len(quality_issues)} issues detected")

print(f"\n🎯 Final Score: {assessment_score}/6")
if assessment_score >= 5:
    print("   ✅ EXCELLENT quantization by these checks - validate on your own tasks before production use")
elif assessment_score >= 3:
    print("   ✅ GOOD quantization - Suitable for most use cases")
else:
    print("   ⚠️  NEEDS IMPROVEMENT - Consider adjusting parameters")

print("\n" + "=" * 60)

In [None]:
# Login to Hugging Face (you'll need to provide your token)
from huggingface_hub import notebook_login
notebook_login()

In [None]:
quant_path = "sarashina2-7b-4bit-awq"  # must match the directory passed to save_quantized()
hf_model_id = "ronantakizawa/sarashina2-7b-4bit-awq"  # HuggingFace repo name
# Upload quantized model to Hugging Face
print(f"Uploading quantized model to {hf_model_id}...")

try:
    # Create repo (will skip if already exists)
    create_repo(hf_model_id, repo_type="model", exist_ok=True)
    print(f"Repository {hf_model_id} is ready")

    # Upload model files
    api = HfApi()
    api.upload_folder(
        folder_path=quant_path,
        repo_id=hf_model_id,
        repo_type="model",
        commit_message="Upload AWQ 4-bit quantized sarashina2-7b model"
    )

    print(f"✅ Model successfully uploaded to https://huggingface.co/{hf_model_id}")
except Exception as e:
    print(f"❌ Error uploading model: {e}")
    print("\nMake sure you:")
    print("1. Have run notebook_login() and provided your token")
    print("2. Have write access to the repository")
    print("3. Have sufficient disk space and internet connection")

## Create Model Card

In [None]:
quant_path = "sarashina2-7b-4bit-awq"  # must match the directory passed to save_quantized()
hf_model_id = "ronantakizawa/sarashina2-7b-4bit-awq"  # HuggingFace repo name
# Create a model card for the repository
model_card_content = f"""---
language:
- ja
- en
license: mit
tags:
- awq
- quantized
- 4-bit
- japanese
- llm
base_model: sbintuitions/sarashina2-7b
---

# Sarashina2-7B AWQ 4-bit Quantized

This is a 4-bit AWQ quantized version of [sbintuitions/sarashina2-7b](https://huggingface.co/sbintuitions/sarashina2-7b).

## Model Description

- **Base Model:** sarashina2-7b (7B parameters)
- **Quantization Method:** AWQ (Activation-aware Weight Quantization)
- **Quantization Precision:** 4-bit
- **Group Size:** 128
- **Calibration Dataset:** [range3/wikipedia-ja-20230101](https://huggingface.co/datasets/range3/wikipedia-ja-20230101) (512 samples)
- **Original Size:** ~14.6 GB
- **Quantized Size:** ~3-4 GB
- **Size Reduction:** ~70-75%

## Usage

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "{hf_model_id}"

# Load quantized model
model = AutoAWQForCausalLM.from_quantized(
    model_path,
    fuse_layers=True,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Generate text
prompt = "おはようございます、今日の天気は"
inputs = tokenizer(prompt, return_tensors="pt")
device = next(model.parameters()).device
inputs = {{k: v.to(device) for k, v in inputs.items()}}

outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

## Installation

```bash
pip install autoawq transformers accelerate
```

## Performance

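- **Memory Usage:** Weight storage reduced by roughly 70-75% compared to the original model
- **Inference Speed:** Depends on hardware, batch size, and kernels; 4-bit weights mainly reduce memory bandwidth
- **Quality:** AWQ protects activation-salient weights to limit accuracy loss; some degradation versus FP16 is still expected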

## Limitations

- Requires GPU for optimal performance
- May have slight quality degradation compared to the full precision model
- Quantization is optimized for the calibration dataset distribution

## License

MIT License (inherited from base model)

## Citation

```bibtex
@misc{{sarashina2-7b-awq,
  author = {{Ronan Takizawa}},
  title = {{Sarashina2-7B AWQ 4-bit Quantized}},
  year = {{2025}},
  publisher = {{Hugging Face}},
  howpublished = {{\\url{{https://huggingface.co/{hf_model_id}}}}}
}}
```

## Base Model Citation

Please refer to the [original model card](https://huggingface.co/sbintuitions/sarashina2-7b) for the base model citation.
"""

# Save model card
import os
readme_path = os.path.join(quant_path, "README.md")
with open(readme_path, "w", encoding="utf-8") as f:
    f.write(model_card_content)

print(f"Model card created at {readme_path}")
print("\nRe-run the upload cell to publish this README.md with your model.")