# Dynamic Quantization with MLX-LM

This notebook demonstrates how to use dynamic quantization in MLX-LM to estimate each quantizable layer's sensitivity to quantization and apply a different precision level to each layer accordingly.

## What is Dynamic Quantization?
Dynamic quantization estimates how sensitive each quantizable layer is to quantization. It then uses higher precision (e.g., 5 or 8 bits) for the most sensitive layers and lower precision for the rest, so that the model as a whole meets a target average bits-per-weight. This balances model size against output quality.

## Requirements
- macOS with Apple Silicon (M1/M2/M3/M4)
- Python 3.9+
- MLX framework
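- Sufficient disk space for both the original and the quantized model
- `lm-eval` (only needed for the optional quality evaluation in Step 11)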

## Step 1: Environment Setup and Dependencies

In [None]:
import os
import sys
import subprocess
from pathlib import Path
import json

# Environment setup
print("Setting up environment for Dynamic quantization...")

# All working directories live next to the notebook:
#   models/        -> downloaded and quantized model files
#   sensitivities/ -> per-model sensitivity JSON files
project_dir = Path.cwd()
models_dir = project_dir / "models"
sensitivity_dir = project_dir / "sensitivities"
for directory in (models_dir, sensitivity_dir):
    directory.mkdir(exist_ok=True)

for label, location in (
    ("Project directory", project_dir),
    ("Models directory", models_dir),
    ("Sensitivity directory", sensitivity_dir),
):
    print(f"{label}: {location}")

## Step 2: Install MLX and Dependencies

In [None]:
# Install required packages into the kernel's own interpreter (sys.executable)
# so they are importable from this notebook.
import subprocess
import sys

print("Installing MLX and dependencies...")

packages = [
    "mlx-lm",
    "transformers",
    "torch", 
    "huggingface_hub",
    "datasets",
    "accelerate",
    "sentencepiece",
    "protobuf"
]

# Collect failures so the final message reflects what actually happened.
failed_packages = []
for package in packages:
    try:
        print(f"Installing {package}...")
        subprocess.run([sys.executable, "-m", "pip", "install", package],
                       check=True, capture_output=True, text=True)
        print(f"✅ {package} installed successfully")
    except subprocess.CalledProcessError as e:
        # str(e) only carries the exit status; pip's actual reason is in stderr.
        print(f"⚠️ Warning installing {package}: {e}")
        if e.stderr:
            print(e.stderr.strip())
        failed_packages.append(package)

if failed_packages:
    print(f"\n⚠️ Installation finished with failures: {', '.join(failed_packages)}")
else:
    print("\n📦 All packages installation completed!")

## Step 3: Test MLX Imports

In [None]:
# Confirm the packages installed in Step 2 are importable from this kernel.
print("Testing MLX imports...")

try:
    import mlx.core as mx
    from mlx_lm import load, generate
    from huggingface_hub import login, snapshot_download

    print("✅ All imports successful!")

    # Quick functional check: build a tiny MLX array and print it.
    smoke_values = [1, 2, 3]
    test_array = mx.array(smoke_values)
    print(f"✅ MLX test array: {test_array}")
except ImportError as e:
    print(f"❌ Import failed: {e}")
    print("Please restart kernel and try again.")

## Step 4: Configuration

In [None]:
# Dynamic Quantization Configuration
print("=== Dynamic Quantization Configuration ===\n")

# Hugging Face repo id of the model to quantize; a small model keeps the demo fast.
MODEL_NAME = "Qwen/Qwen2.5-0.5B"

# Parameters for the mixed-precision quantization run.
# NOTE: only target_bpw, low_bits and high_bits are passed on the
# mlx_lm.dynamic_quant command line; num_samples and group_size are
# reported/recorded by this notebook but not forwarded to that command.
DYNAMIC_CONFIG = dict(
    target_bpw=4.0,   # Target bits-per-weight
    low_bits=2,       # Precision for less sensitive layers
    high_bits=8,      # Precision for sensitive layers
    num_samples=512,  # Samples for sensitivity analysis
    group_size=128,   # Group size for quantization
)

summary_lines = [
    f"Model: {MODEL_NAME}",
    f"Target bits-per-weight: {DYNAMIC_CONFIG['target_bpw']}",
    f"Low precision: {DYNAMIC_CONFIG['low_bits']} bits",
    f"High precision: {DYNAMIC_CONFIG['high_bits']} bits",
    f"Sensitivity samples: {DYNAMIC_CONFIG['num_samples']}",
    f"Group size: {DYNAMIC_CONFIG['group_size']}",
]
print("\n".join(summary_lines))

# Derive filesystem-safe names and output locations from the repo id.
model_safe_name = MODEL_NAME.replace("/", "_")
bpw_tag = DYNAMIC_CONFIG['target_bpw']
original_model_dir = models_dir / model_safe_name
dynamic_model_dir = models_dir / f"{model_safe_name}_Dynamic_{bpw_tag}bpw"
sensitivity_file = sensitivity_dir / f"{model_safe_name}_sensitivity.json"

print(f"\nOriginal model dir: {original_model_dir}")
print(f"Dynamic model dir: {dynamic_model_dir}")
print(f"Sensitivity file: {sensitivity_file}")

## Step 5: Download Original Model

In [None]:
from datetime import datetime
import shutil


def _has_model_files(directory):
    """Return True if *directory* contains any non-hidden entry.

    snapshot_download(local_dir=...) keeps its own bookkeeping in a hidden
    ``.cache`` folder inside local_dir, so hidden entries alone (e.g. left
    behind by an interrupted download) do not mean a model is present.
    """
    return any(not entry.name.startswith(".") for entry in directory.iterdir())


print(f"Downloading {MODEL_NAME}...")
print("This may take a while depending on model size and internet connection.")

# Create directories
original_model_dir.mkdir(parents=True, exist_ok=True)

# Offer to reuse files from an earlier download.
if _has_model_files(original_model_dir):
    print(f"Model files found in {original_model_dir}")
    use_existing = input("Use existing model files? (y/n): ").strip().lower()
    if use_existing != 'y':
        shutil.rmtree(original_model_dir)
        original_model_dir.mkdir(parents=True, exist_ok=True)

if not _has_model_files(original_model_dir):
    try:
        start_time = datetime.now()

        # Downloads into local_dir are real files; the former
        # local_dir_use_symlinks flag is deprecated and ignored.
        snapshot_download(
            repo_id=MODEL_NAME,
            local_dir=str(original_model_dir),
        )

        duration = datetime.now() - start_time
        print(f"✅ Model downloaded successfully in {duration}")

    except Exception as e:
        print(f"❌ Download failed: {e}")
        print("Please check the model name and internet connection.")

# List downloaded files (top level only) and total their size for later steps.
print("\nModel files:")
total_size = 0
for file in original_model_dir.glob("*"):
    if file.is_file():
        size_mb = file.stat().st_size / 1024 / 1024
        total_size += size_mb
        print(f"  {file.name} ({size_mb:.2f} MB)")

print(f"\nTotal model size: {total_size:.2f} MB")

## Step 6: Sensitivity Analysis (Optional)

In [None]:
import json
import subprocess
from datetime import datetime

# Check if we already have a sensitivity file
print("=== Layer Sensitivity Analysis ===\n")

# Default to not running the analysis; each branch that wants it sets this
# explicitly. Previously the "use existing file" branch left the name unbound,
# so the `if run_sensitivity:` check below raised NameError (or silently
# reused a stale value from an earlier run of this cell).
run_sensitivity = False

if sensitivity_file.exists():
    print(f"Sensitivity file found: {sensitivity_file}")
    use_existing = input("Use existing sensitivity analysis? (y/n): ").strip().lower()

    if use_existing == 'y':
        # Load existing sensitivity data
        with open(sensitivity_file, 'r') as f:
            sensitivity_data = json.load(f)
        # The placeholder written below carries a "note" key and no per-layer data.
        if isinstance(sensitivity_data, dict) and "note" in sensitivity_data:
            print("✅ Loaded placeholder sensitivity file (no per-layer data)")
        else:
            print(f"✅ Loaded existing sensitivity data with {len(sensitivity_data)} layers")
    else:
        # Run new sensitivity analysis
        print("Running new sensitivity analysis...")
        run_sensitivity = True
else:
    print("No existing sensitivity file found.")
    run_analysis = input("Run sensitivity analysis first? (recommended, y/n): ").strip().lower()
    run_sensitivity = run_analysis == 'y'

if run_sensitivity:
    print("\n🔄 Running sensitivity analysis...")
    print("This will analyze which layers are most sensitive to quantization.")
    print("This may take some time...")

    # This is a conceptual step: no separate sensitivity command is invoked
    # here. The quantization step runs mlx_lm.dynamic_quant, which is expected
    # to compute sensitivities itself when none are supplied.
    try:
        print("Dynamic quantization will perform sensitivity analysis internally...")

        # Placeholder file for demonstration; the "note" key marks it as
        # not containing real sensitivity data.
        sample_sensitivity = {
            "analysis_date": datetime.now().isoformat(),
            "model": MODEL_NAME,
            "samples_used": DYNAMIC_CONFIG['num_samples'],
            "note": "Sensitivity analysis performed internally by mlx_lm.dynamic_quant"
        }

        with open(sensitivity_file, 'w') as f:
            json.dump(sample_sensitivity, f, indent=2)

        print(f"✅ Sensitivity analysis placeholder created: {sensitivity_file}")

    except Exception as e:
        print(f"❌ Error in sensitivity analysis: {e}")
else:
    print("Skipping sensitivity analysis - will use dynamic quantization defaults.")

## Step 7: Dynamic Quantization

In [None]:
import json
import shlex
import shutil
import subprocess
import sys
from datetime import datetime


def _weights_size_mb(directory):
    """Return the total size in MB of the ``*.safetensors`` files in *directory*."""
    return sum(
        f.stat().st_size for f in directory.glob("*.safetensors") if f.is_file()
    ) / 1024 / 1024


print("Starting Dynamic quantization...")
print(f"Source: {original_model_dir}")
print(f"Target: {dynamic_model_dir}")
print(f"Configuration: {DYNAMIC_CONFIG}")

# Start from an empty output directory so files from a previous run cannot be
# mistaken for fresh results.
if dynamic_model_dir.exists():
    print(f"Removing existing dynamic directory: {dynamic_model_dir}")
    shutil.rmtree(dynamic_model_dir)

dynamic_model_dir.mkdir(parents=True, exist_ok=True)

# Build the Dynamic quantization command. sys.executable is the kernel's own
# interpreter, so the subprocess uses the mlx-lm installed for this notebook
# rather than whatever "python" happens to be first on PATH.
# NOTE: DYNAMIC_CONFIG's group_size and num_samples are not forwarded here.
dynamic_cmd = [
    sys.executable, "-m", "mlx_lm.dynamic_quant",
    "--model", str(original_model_dir),
    "--mlx-path", str(dynamic_model_dir),
    "--target-bpw", str(DYNAMIC_CONFIG["target_bpw"]),
    "--low-bits", str(DYNAMIC_CONFIG["low_bits"]),
    "--high-bits", str(DYNAMIC_CONFIG["high_bits"]),
]

# Pass the sensitivity file only when it holds real data; the demo
# placeholder is marked by a "note" key.
if sensitivity_file.exists():
    with open(sensitivity_file, 'r') as f:
        sens_data = json.load(f)

    if "note" not in sens_data:
        dynamic_cmd.extend(["--sensitivities", str(sensitivity_file)])
        print(f"Using sensitivity file: {sensitivity_file}")

# shlex.join quotes arguments (e.g. interpreter paths containing spaces).
print(f"\nRunning command: {shlex.join(dynamic_cmd)}")

try:
    start_time = datetime.now()

    result = subprocess.run(
        dynamic_cmd,
        capture_output=True,
        text=True,
        cwd=str(project_dir)
    )

    duration = datetime.now() - start_time

    if result.returncode == 0:
        print(f"\n✅ Dynamic quantization completed successfully in {duration}!")
        print("STDOUT:", result.stdout)
    else:
        print(f"\n❌ Dynamic quantization failed!")
        print("STDERR:", result.stderr)
        print("STDOUT:", result.stdout)

except Exception as e:
    print(f"❌ Error running Dynamic quantization: {e}")

# Check results; total_dynamic_size is reused by later steps.
if dynamic_model_dir.exists() and list(dynamic_model_dir.glob("*")):
    print("\nDynamic quantized files:")
    total_dynamic_size = 0
    for file in dynamic_model_dir.glob("*"):
        if file.is_file():
            size_mb = file.stat().st_size / 1024 / 1024
            total_dynamic_size += size_mb
            print(f"  {file.name} ({size_mb:.2f} MB)")

    print(f"\nTotal dynamic model size: {total_dynamic_size:.2f} MB")

    # total_size comes from Step 5; skip the comparison if it is missing or zero.
    original_total = globals().get("total_size", 0)
    if original_total > 0:
        print(f"Size reduction: {((original_total - total_dynamic_size) / original_total * 100):.1f}%")

    # Estimate bits-per-weight from the weight files only (config/tokenizer
    # files would skew the ratio). Assumes the source weights are 16-bit.
    original_weights_mb = _weights_size_mb(original_model_dir)
    if original_weights_mb > 0:
        estimated_bpw = _weights_size_mb(dynamic_model_dir) / original_weights_mb * 16
        print(f"Estimated bits-per-weight (assuming 16-bit source weights): {estimated_bpw:.2f} bits")

## Step 8: Analyze Quantization Results

In [None]:
import json
from collections import Counter


def _print_quant_entry(key, value):
    """Print one quantization-related config entry.

    Scalar values are printed as-is. For dict values, nested dict items
    (presumably per-layer overrides in a mixed-precision config — verify
    against the mlx-lm config format) would be far too long to print
    verbatim, so they are collapsed into a count of layers per bit width.
    """
    if not isinstance(value, dict):
        print(f"   {key}: {value}")
        return
    scalars = {k: v for k, v in value.items() if not isinstance(v, dict)}
    per_layer = {k: v for k, v in value.items() if isinstance(v, dict)}
    print(f"   {key}: {scalars}")
    if per_layer:
        bits_hist = Counter(v.get("bits", "?") for v in per_layer.values())
        summary = ", ".join(
            f"{bits} bits: {count} layers"
            for bits, count in sorted(bits_hist.items(), key=lambda kv: str(kv[0]))
        )
        print(f"      per-layer entries ({len(per_layer)}): {summary}")


# Analyze the quantization results
if dynamic_model_dir.exists() and list(dynamic_model_dir.glob("*")):
    print("=== Dynamic Quantization Analysis ===")

    # Show quantization-related keys from config-like JSON files.
    for info_file in dynamic_model_dir.glob("*.json"):
        lowered_name = info_file.name.lower()
        if "quant" not in lowered_name and "config" not in lowered_name:
            continue
        print(f"\n📊 Quantization info from {info_file.name}:")
        try:
            with open(info_file, 'r') as f:
                info_data = json.load(f)

            for key, value in info_data.items():
                if any(word in key.lower() for word in ['quant', 'bit', 'precision', 'group']):
                    _print_quant_entry(key, value)

        except Exception as e:
            # Best-effort display: an unreadable file should not abort the analysis.
            print(f"   Could not read {info_file.name}: {e}")

    # Sizes come from Steps 5 and 7; they may be missing after a kernel restart.
    original_mb = globals().get("total_size")
    dynamic_mb = globals().get("total_dynamic_size")

    print(f"\n📏 Size Comparison:")
    if original_mb is None or dynamic_mb is None:
        print("   Size data unavailable - run Steps 5 and 7 in this session first.")
    else:
        print(f"   Original: {original_mb:.2f} MB")
        print(f"   Dynamic:  {dynamic_mb:.2f} MB")
        if original_mb > 0:
            print(f"   Reduction: {((original_mb - dynamic_mb) / original_mb * 100):.1f}%")

        compression_ratio = original_mb / dynamic_mb if dynamic_mb > 0 else 0
        print(f"   Compression ratio: {compression_ratio:.2f}x")
else:
    print("❌ Dynamic model not found. Quantization may have failed.")

## Step 9: Test Dynamic Model

In [None]:
# Test the Dynamic quantized model
if dynamic_model_dir.exists() and list(dynamic_model_dir.glob("*")):
    print("Testing Dynamic quantized model...")

    try:
        # Load the Dynamic model; `model` and `tokenizer` are reused in Step 10.
        model, tokenizer = load(str(dynamic_model_dir))
        print("✅ Dynamic model loaded successfully!")

        # Recent mlx-lm releases (the ones that ship dynamic_quant) take
        # sampling settings as a sampler callable; passing temp= straight to
        # generate() is rejected there. Fall back for older releases.
        try:
            from mlx_lm.sample_utils import make_sampler
            generation_kwargs = {"sampler": make_sampler(temp=0.7)}
        except ImportError:
            generation_kwargs = {"temp": 0.7}

        # Test generation with various prompts
        test_prompts = [
            "Hello, how are you today?",
            "The weather forecast shows",
            "Artificial intelligence technology",
            "In the field of machine learning",
            "Dynamic quantization helps"
        ]

        print("\n=== Dynamic Model Test Results ===")
        for i, prompt in enumerate(test_prompts, 1):
            print(f"\n[Test {i}] Prompt: '{prompt}'")

            response = generate(
                model,
                tokenizer,
                prompt=prompt,
                max_tokens=60,
                **generation_kwargs
            )

            print(f"Response: {response}")

        print("\n✅ Dynamic model is working correctly!")

    except Exception as e:
        print(f"❌ Error testing Dynamic model: {e}")
else:
    print("❌ Dynamic model not found. Quantization may have failed.")

## Step 10: Performance Comparison

In [None]:
# Optional: Compare original vs Dynamic model performance
import time


def _average_generation_time(gen_model, gen_tokenizer, prompt, max_tokens, runs, gen_kwargs):
    """Time repeated generations and return ``(last_response, mean_seconds)``.

    One untimed warm-up generation runs first so one-off start-up cost on the
    first call is not folded into the average.
    """
    generate(gen_model, gen_tokenizer, prompt=prompt, max_tokens=max_tokens, **gen_kwargs)
    durations = []
    response = None
    for _ in range(runs):
        started = time.perf_counter()
        response = generate(
            gen_model, gen_tokenizer, prompt=prompt, max_tokens=max_tokens, **gen_kwargs
        )
        durations.append(time.perf_counter() - started)
    return response, sum(durations) / len(durations)


compare_models = input("Do you want to compare original vs Dynamic model performance? (y/n): ").strip().lower()

if compare_models == 'y':
    print("\n=== Performance Comparison ===")

    test_prompt = "Dynamic quantization is a technique that"
    max_tokens = 80
    num_runs = 3  # Multiple runs for average timing

    # Recent mlx-lm releases take sampling settings as a sampler callable;
    # passing temp= directly to generate() is rejected there.
    try:
        from mlx_lm.sample_utils import make_sampler
        perf_gen_kwargs = {"sampler": make_sampler(temp=0.7)}
    except ImportError:
        perf_gen_kwargs = {"temp": 0.7}  # older mlx-lm API

    avg_original_time = None
    avg_dynamic_time = None

    try:
        # Test original model
        print("\n🔄 Testing original model...")
        original_model, original_tokenizer = load(str(original_model_dir))
        original_response, avg_original_time = _average_generation_time(
            original_model, original_tokenizer, test_prompt, max_tokens, num_runs, perf_gen_kwargs
        )
        print(f"Original response: {original_response}")
        print(f"Original avg time: {avg_original_time:.2f}s (over {num_runs} runs)")

    except Exception as e:
        print(f"❌ Error testing original model: {e}")
        avg_original_time = None

    try:
        # Test Dynamic model (`model`/`tokenizer` were loaded in Step 9)
        print("\n🔄 Testing Dynamic model...")
        dynamic_response, avg_dynamic_time = _average_generation_time(
            model, tokenizer, test_prompt, max_tokens, num_runs, perf_gen_kwargs
        )
        print(f"Dynamic response: {dynamic_response}")
        print(f"Dynamic avg time: {avg_dynamic_time:.2f}s (over {num_runs} runs)")

    except Exception as e:
        print(f"❌ Error testing Dynamic model: {e}")

    # Compare performance only when both measurements succeeded.
    if avg_original_time and avg_dynamic_time:
        speedup = avg_original_time / avg_dynamic_time
        print(f"\n📊 Performance Summary:")
        print(f"   Speedup: {speedup:.2f}x")
        print(f"   Time saved per generation: {avg_original_time - avg_dynamic_time:.2f}s")
        # Sizes come from Steps 5 and 7 and may be missing or zero.
        original_mb = globals().get("total_size", 0)
        dynamic_mb = globals().get("total_dynamic_size")
        if original_mb > 0 and dynamic_mb is not None:
            print(f"   Model size reduction: {((original_mb - dynamic_mb) / original_mb * 100):.1f}%")
else:
    print("Skipping performance comparison.")

## Step 11: Evaluate Model Quality

In [None]:
# Optional: Evaluate the quantized model
import shlex
import subprocess
import sys

print("=== Model Quality Evaluation ===\n")

# lm-evaluation-harness task to run, and optional few-shot count
# (None = use the task's default). wikitext is a perplexity task.
EVAL_TASK = "wikitext"
EVAL_NUM_SHOTS = None

evaluate_model = input("Do you want to evaluate model quality? (y/n): ").strip().lower()

if evaluate_model == 'y':
    # mlx_lm.evaluate is built on lm-evaluation-harness (the `lm-eval`
    # package, which is not in the Step 2 install list). It selects
    # benchmarks with --tasks and few-shot count with --num-shots; the old
    # --dataset/--few-shot flags are not accepted. sys.executable keeps the
    # subprocess on the kernel's interpreter.
    eval_cmd = [
        sys.executable, "-m", "mlx_lm.evaluate",
        "--model", str(dynamic_model_dir),
        "--tasks", EVAL_TASK,
    ]
    if EVAL_NUM_SHOTS is not None:
        eval_cmd.extend(["--num-shots", str(EVAL_NUM_SHOTS)])

    print(f"Running evaluation: {shlex.join(eval_cmd)}")

    try:
        result = subprocess.run(eval_cmd, capture_output=True, text=True)

        if result.returncode == 0:
            print("\n✅ Evaluation completed!")
            print(result.stdout)
        else:
            print("\n❌ Evaluation failed!")
            print(result.stderr)

    except Exception as e:
        print(f"❌ Error running evaluation: {e}")
else:
    print("Skipping evaluation.")

## Step 12: Upload to Hugging Face (Optional)

In [None]:
from huggingface_hub import HfApi, upload_folder
import getpass

upload_to_hf = input("Do you want to upload the Dynamic model to Hugging Face? (y/n): ").strip().lower()

if upload_to_hf == 'y':
    # Get Hugging Face credentials (read without echoing).
    print("Please enter your Hugging Face token:")
    hf_token = getpass.getpass("HF Token: ").strip()

    try:
        if not hf_token:
            raise ValueError("Hugging Face token must not be empty")

        login(token=hf_token)
        print("✅ Successfully logged in to Hugging Face!")

        # Get repository name
        repo_name = input("Enter repository name (e.g., 'username/model-name-dynamic'): ").strip()
        if not repo_name:
            raise ValueError("Repository name must not be empty")

        # Create repository (no-op if it already exists)
        api = HfApi(token=hf_token)
        api.create_repo(repo_id=repo_name, repo_type="model", exist_ok=True)
        print(f"✅ Repository {repo_name} created!")

        # Create model card.
        # NOTE(review): the license is hard-coded; confirm it matches the
        # base model's license when MODEL_NAME is changed.
        # Group size is deliberately not listed: the quantization command in
        # this notebook does not pass DYNAMIC_CONFIG['group_size'] to mlx-lm.
        model_card = f"""---
license: apache-2.0
base_model: {MODEL_NAME}
tags:
- mlx
- dynamic-quantization
- quantized
- mixed-precision
---

# {MODEL_NAME.split('/')[-1]} - Dynamic Quantization {DYNAMIC_CONFIG['target_bpw']}bpw

This is a Dynamic Quantization version of [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) with {DYNAMIC_CONFIG['target_bpw']} target bits-per-weight.

## Quantization Details
- Method: Dynamic Quantization (Mixed Precision)
- Target bits-per-weight: {DYNAMIC_CONFIG['target_bpw']}
- Low precision: {DYNAMIC_CONFIG['low_bits']} bits (less sensitive layers)
- High precision: {DYNAMIC_CONFIG['high_bits']} bits (sensitive layers)

## Features
- Automatically estimates layer sensitivity
- Uses different precision for different layers
- Optimized balance between size and quality
- Optimized for Apple Silicon devices

## How Dynamic Quantization Works
Dynamic quantization analyzes the sensitivity of each layer to quantization and applies:
- Higher precision ({DYNAMIC_CONFIG['high_bits']} bits) for sensitive layers
- Lower precision ({DYNAMIC_CONFIG['low_bits']} bits) for less sensitive layers

This approach maintains model quality while achieving significant size reduction.

## Usage
```python
from mlx_lm import load, generate

model, tokenizer = load("{repo_name}")
response = generate(model, tokenizer, prompt="Hello", max_tokens=100)
```
"""

        # Save model card alongside the weights so it is uploaded with them.
        with open(dynamic_model_dir / "README.md", "w", encoding="utf-8") as f:
            f.write(model_card)

        # Upload
        print(f"Uploading to {repo_name}...")
        api.upload_folder(
            folder_path=str(dynamic_model_dir),
            repo_id=repo_name,
            repo_type="model",
            commit_message=f"Add Dynamic quantized model ({DYNAMIC_CONFIG['target_bpw']}bpw)"
        )

        print(f"✅ Model uploaded successfully!")
        print(f"🔗 https://huggingface.co/{repo_name}")

    except Exception as e:
        print(f"❌ Upload failed: {e}")
else:
    print("Skipping upload.")

## Step 13: Summary

In [None]:
# Final summary
print("\n" + "="*60)
print("🎉 DYNAMIC QUANTIZATION SUMMARY")
print("="*60)

print(f"\n📋 Configuration:")
print(f"   Base Model: {MODEL_NAME}")
print(f"   Target BPW: {DYNAMIC_CONFIG['target_bpw']}")
print(f"   Low Precision: {DYNAMIC_CONFIG['low_bits']} bits")
print(f"   High Precision: {DYNAMIC_CONFIG['high_bits']} bits")
print(f"   Group Size: {DYNAMIC_CONFIG['group_size']}")

print(f"\n📁 Directories:")
print(f"   Original: {original_model_dir}")
print(f"   Dynamic Model: {dynamic_model_dir}")
print(f"   Sensitivity File: {sensitivity_file}")

# Check if quantization was successful
if dynamic_model_dir.exists() and list(dynamic_model_dir.glob("*")):
    print(f"\n✅ Status: Dynamic quantization completed successfully!")

    # Size totals come from Steps 5 and 7 and may be missing or zero.
    if 'total_size' in globals() and 'total_dynamic_size' in globals():
        if total_size > 0 and total_dynamic_size > 0:
            size_reduction = ((total_size - total_dynamic_size) / total_size * 100)
            compression_ratio = total_size / total_dynamic_size
            # Whole-directory ratio scaled by 16 bits: a rough estimate that
            # assumes 16-bit source weights and includes non-weight files.
            approx_bpw = (total_dynamic_size / total_size * 16)

            print(f"   Original size: {total_size:.2f} MB")
            print(f"   Dynamic size: {total_dynamic_size:.2f} MB")
            print(f"   Size reduction: {size_reduction:.1f}%")
            print(f"   Compression ratio: {compression_ratio:.2f}x")
            print(f"   Approx. bits-per-weight: {approx_bpw:.2f}")
        else:
            print("   Size metrics unavailable (a model directory reported 0 MB).")
else:
    print(f"\n❌ Status: Dynamic quantization failed or incomplete")

print(f"\n💡 Dynamic Quantization Advantages:")
print(f"   • Adaptive precision based on layer sensitivity")
print(f"   • Better quality preservation than uniform quantization")
print(f"   • Automatic sensitivity analysis")
print(f"   • Optimal balance between size and performance")

print(f"\n🔧 Tuning Tips:")
print(f"   • Lower target-bpw = smaller model, potentially lower quality")
print(f"   • Adjust high-bits/low-bits spread for different trade-offs")
print(f"   • Use sensitivity analysis for fine-tuning")
print(f"   • Test with your specific use case to validate quality")

print("\n" + "="*60)
print("Thank you for using Dynamic quantization!")
print("="*60)