# AWQ (Activation-aware Weight Quantization) with MLX-LM

This notebook demonstrates how to use AWQ (Activation-aware Weight Quantization) with MLX-LM to scale and clip weights before quantization.

## What is AWQ?
AWQ is a quantization method that scales and clips weights before quantization to preserve model quality. It uses calibration samples to determine optimal scaling factors for different weights.

## Requirements
- macOS with Apple Silicon (M1/M2/M3/M4)
- Python 3.9+
- MLX framework
- Sufficient disk space for model storage

## Step 1: Environment Setup and Dependencies

In [None]:
import os
import sys
import subprocess
from pathlib import Path

# Environment setup: announce the run and make sure a local "models" folder exists.
print("Setting up environment for AWQ quantization...")

# Everything the notebook downloads or produces lives under <cwd>/models.
project_dir = Path.cwd()
models_dir = project_dir.joinpath("models")
models_dir.mkdir(exist_ok=True)

for label, location in (("Project", project_dir), ("Models", models_dir)):
    print(f"{label} directory: {location}")

## Step 2: Install MLX and Dependencies

In [None]:
# Install required packages into the interpreter that runs this notebook.
print("Installing MLX and dependencies...")

packages = [
    "mlx-lm",
    "transformers",
    "torch",
    "huggingface_hub",
    "datasets",
    "accelerate",
    "sentencepiece",
    "protobuf",
]

# Packages whose installation failed, so the final message can be honest.
failed_packages = []

for package in packages:
    print(f"Installing {package}...")
    try:
        # sys.executable guarantees pip targets the running kernel's environment.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", package],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        failed_packages.append(package)
        print(f"⚠️ Warning installing {package}: {e}")
        # pip's output is captured, so surface it; otherwise the failure
        # reason is lost and only the exit code is visible.
        if e.stderr:
            print(e.stderr.strip())
    else:
        print(f"✅ {package} installed successfully")

if failed_packages:
    print(f"\n⚠️ Installation finished with failures: {', '.join(failed_packages)}")
else:
    print("\n📦 All packages installation completed!")

## Step 3: Test MLX Imports

In [None]:
# Test imports: verify that the packages installed above can be loaded and
# that MLX can build a trivial array.
print("Testing MLX imports...")

try:
    # Only the imports live in the try body so that an ImportError reported
    # below really is an import problem.
    import mlx.core as mx
    from mlx_lm import load, generate
    from huggingface_hub import login, snapshot_download
except ImportError as e:
    print(f"❌ Import failed: {e}")
    print("Please restart kernel and try again.")
else:
    print("✅ All imports successful!")

    # Test MLX functionality
    test_array = mx.array([1, 2, 3])
    print(f"✅ MLX test array: {test_array}")

## Step 4: Configuration

In [None]:
# AWQ Configuration
print("=== AWQ Configuration ===\n")

# Hugging Face repo id of the model to quantize; swap in any other repo id.
MODEL_NAME = "Qwen/Qwen2.5-0.5B"

# Quantization settings reused by the later cells (command line, model card, summary).
AWQ_CONFIG = dict(
    bits=4,           # quantization precision in bits
    num_samples=32,   # number of calibration samples
    n_grid=10,        # search granularity
    group_size=128,   # quantization group size
)

# Echo the chosen settings in one go.
config_report = [
    f"Model: {MODEL_NAME}",
    f"Target bits: {AWQ_CONFIG['bits']}",
    f"Calibration samples: {AWQ_CONFIG['num_samples']}",
    f"Search grid: {AWQ_CONFIG['n_grid']}",
    f"Group size: {AWQ_CONFIG['group_size']}",
]
print("\n".join(config_report))

# Derive the on-disk locations: "/" in the repo id is replaced by "_" so the
# model name becomes a single folder name under models_dir.
model_folder_name = MODEL_NAME.replace("/", "_")
original_model_dir = models_dir / model_folder_name
awq_model_dir = models_dir / f"{model_folder_name}_AWQ_{AWQ_CONFIG['bits']}bit"

print(f"\nOriginal model dir: {original_model_dir}")
print(f"AWQ model dir: {awq_model_dir}")

## Step 5: Download Original Model

In [None]:
from datetime import datetime

# Download the original checkpoint into original_model_dir, optionally reusing
# files left over from a previous run, then list what is on disk.
print(f"Downloading {MODEL_NAME}...")
print("This may take a while depending on model size and internet connection.")

# Create directories
original_model_dir.mkdir(parents=True, exist_ok=True)

# Check if model already exists (any entry at all counts as "existing").
if any(original_model_dir.iterdir()):
    print(f"Model files found in {original_model_dir}")
    use_existing = input("Use existing model files? (y/n): ").strip().lower()
    if use_existing != 'y':
        import shutil

        # Start from a clean folder so stale files cannot mix with the new download.
        shutil.rmtree(original_model_dir)
        original_model_dir.mkdir(parents=True, exist_ok=True)

if not any(original_model_dir.iterdir()):
    try:
        start_time = datetime.now()

        # local_dir_use_symlinks is deprecated in huggingface_hub and no longer
        # needed: with local_dir set, files are written into that folder.
        downloaded_path = snapshot_download(
            repo_id=MODEL_NAME,
            local_dir=str(original_model_dir),
        )

        end_time = datetime.now()
        duration = end_time - start_time

        print(f"✅ Model downloaded successfully in {duration}")

    except Exception as e:
        # Best-effort cell: report the problem and fall through to the listing.
        print(f"❌ Download failed: {e}")
        print("Please check the model name and internet connection.")

# List downloaded files (top level only) and accumulate their size in MB.
print("\nModel files:")
total_size = 0
for file in original_model_dir.glob("*"):
    if file.is_file():
        size_mb = file.stat().st_size / 1024 / 1024
        total_size += size_mb
        print(f"  {file.name} ({size_mb:.2f} MB)")

print(f"\nTotal model size: {total_size:.2f} MB")

## Step 6: AWQ Quantization

In [None]:
import shutil
import subprocess
import sys
from datetime import datetime

# Run the mlx_lm AWQ command-line tool on the downloaded model and report the
# size of the files it produced.
print("Starting AWQ quantization...")
print(f"Source: {original_model_dir}")
print(f"Target: {awq_model_dir}")
print(f"Configuration: {AWQ_CONFIG}")

# Clean up existing AWQ directory so results of an earlier run cannot be
# mistaken for the output of this one.
if awq_model_dir.exists():
    print(f"Removing existing AWQ directory: {awq_model_dir}")
    shutil.rmtree(awq_model_dir)

awq_model_dir.mkdir(parents=True, exist_ok=True)

# Build AWQ command.
# - sys.executable runs the tool with the same interpreter the packages were
#   installed into, rather than whatever "python" happens to be first on PATH.
# - --group-size forwards the configured group size, which is otherwise only
#   reported (config printout, model card) and never applied.
awq_cmd = [
    sys.executable, "-m", "mlx_lm.awq",
    "--model", str(original_model_dir),
    "--mlx-path", str(awq_model_dir),
    "--bits", str(AWQ_CONFIG["bits"]),
    "--group-size", str(AWQ_CONFIG["group_size"]),
    "--num-samples", str(AWQ_CONFIG["num_samples"]),
    "--n-grid", str(AWQ_CONFIG["n_grid"]),
]

print(f"\nRunning command: {' '.join(awq_cmd)}")

try:
    start_time = datetime.now()

    # Run AWQ quantization; output is captured and printed once it finishes.
    result = subprocess.run(
        awq_cmd,
        capture_output=True,
        text=True,
        cwd=str(project_dir),
    )

    end_time = datetime.now()
    duration = end_time - start_time

    if result.returncode == 0:
        print(f"\n✅ AWQ quantization completed successfully in {duration}!")
        print("STDOUT:", result.stdout)
    else:
        print("\n❌ AWQ quantization failed!")
        print("STDERR:", result.stderr)
        print("STDOUT:", result.stdout)

except OSError as e:
    # Raised when the process cannot be started at all (e.g. bad cwd).
    print(f"❌ Error running AWQ: {e}")

# Check results
if awq_model_dir.exists() and any(awq_model_dir.iterdir()):
    print("\nAWQ quantized files:")
    total_awq_size = 0
    for file in awq_model_dir.glob("*"):
        if file.is_file():
            size_mb = file.stat().st_size / 1024 / 1024
            total_awq_size += size_mb
            print(f"  {file.name} ({size_mb:.2f} MB)")

    print(f"\nTotal AWQ model size: {total_awq_size:.2f} MB")

    # total_size comes from the download cell; tolerate that cell having been
    # skipped instead of failing with a NameError.
    original_total = globals().get("total_size", 0)
    if original_total > 0:
        print(f"Size reduction: {((original_total - total_awq_size) / original_total * 100):.1f}%")

## Step 7: Test AWQ Model

In [None]:
# Test the AWQ quantized model: load it and generate a short completion for a
# handful of prompts.
if awq_model_dir.exists() and any(awq_model_dir.iterdir()):
    print("Testing AWQ quantized model...")

    try:
        # Current mlx-lm takes sampling settings through a sampler object;
        # passing temp= directly to generate() is rejected.
        from mlx_lm.sample_utils import make_sampler

        # Load the AWQ model
        model, tokenizer = load(str(awq_model_dir))
        print("✅ AWQ model loaded successfully!")

        sampler = make_sampler(temp=0.7)

        # Test generation
        test_prompts = [
            "Hello, how are you?",
            "The weather today is",
            "Artificial intelligence is",
            "Machine learning can be used for",
        ]

        print("\n=== AWQ Model Test Results ===")
        for prompt in test_prompts:
            print(f"\nPrompt: '{prompt}'")

            response = generate(
                model,
                tokenizer,
                prompt=prompt,
                max_tokens=50,
                sampler=sampler,
            )

            print(f"Response: {response}")

        print("\n✅ AWQ model is working correctly!")

    except Exception as e:
        # Smoke test only: report any load/generation problem and carry on.
        print(f"❌ Error testing AWQ model: {e}")
else:
    print("❌ AWQ model not found. Quantization may have failed.")

## Step 8: Compare Original vs AWQ Performance

In [None]:
# Optional: Compare original vs AWQ model performance by timing one
# generation of the same prompt with each model.
import time

compare_models = input("Do you want to compare original vs AWQ model performance? (y/n): ").strip().lower()

if compare_models == 'y':
    # Current mlx-lm takes sampling settings through a sampler object;
    # passing temp= directly to generate() is rejected.
    from mlx_lm.sample_utils import make_sampler

    print("\n=== Model Performance Comparison ===")

    test_prompt = "The future of artificial intelligence"
    max_tokens = 100
    sampler = make_sampler(temp=0.7)

    # Stay None when the original model cannot be tested, so the comparison
    # below is skipped instead of dividing by a missing value.
    original_response = None
    original_time = None

    try:
        # Test original model
        print("\n🔄 Testing original model...")
        original_model, original_tokenizer = load(str(original_model_dir))

        # perf_counter is a monotonic clock intended for measuring intervals.
        start_time = time.perf_counter()
        original_response = generate(
            original_model,
            original_tokenizer,
            prompt=test_prompt,
            max_tokens=max_tokens,
            sampler=sampler,
        )
        original_time = time.perf_counter() - start_time

        print(f"Original response: {original_response}")
        print(f"Original generation time: {original_time:.2f}s")

    except Exception as e:
        print(f"❌ Error testing original model: {e}")
        original_response = None
        original_time = None

    try:
        # Test AWQ model: reuse the one loaded by the test cell when it is
        # available, otherwise load it here so this cell works on its own.
        print("\n🔄 Testing AWQ model...")
        if "model" not in globals() or "tokenizer" not in globals():
            model, tokenizer = load(str(awq_model_dir))

        start_time = time.perf_counter()
        awq_response = generate(
            model,
            tokenizer,
            prompt=test_prompt,
            max_tokens=max_tokens,
            sampler=sampler,
        )
        awq_time = time.perf_counter() - start_time

        print(f"AWQ response: {awq_response}")
        print(f"AWQ generation time: {awq_time:.2f}s")

        # Compare performance (only when both timings exist and are usable).
        if original_time is not None and awq_time > 0:
            speedup = original_time / awq_time
            print("\n📊 Performance comparison:")
            print(f"   Speedup: {speedup:.2f}x")
            print(f"   Time saved: {original_time - awq_time:.2f}s")

    except Exception as e:
        print(f"❌ Error testing AWQ model: {e}")
else:
    print("Skipping performance comparison.")

## Step 9: Evaluate Model Quality

In [None]:
# Optional: Evaluate the quantized model with the mlx_lm.evaluate command-line tool.
import subprocess
import sys

print("=== Model Quality Evaluation ===\n")

evaluate_model = input("Do you want to evaluate model quality? (y/n): ").strip().lower()

if evaluate_model == 'y':
    # mlx_lm.evaluate selects benchmarks with --tasks and the few-shot count
    # with --num-shots. sys.executable keeps the run inside the interpreter
    # the packages were installed into.
    # NOTE(review): mlx_lm.evaluate appears to depend on the lm-eval package,
    # which the install cell does not list - confirm and install if needed.
    eval_cmd = [
        sys.executable, "-m", "mlx_lm.evaluate",
        "--model", str(awq_model_dir),
        "--tasks", "wikitext",  # or your preferred task(s)
        "--num-shots", "5",
    ]

    print(f"Running evaluation: {' '.join(eval_cmd)}")

    try:
        result = subprocess.run(eval_cmd, capture_output=True, text=True)

        if result.returncode == 0:
            print("\n✅ Evaluation completed!")
            print(result.stdout)
        else:
            print("\n❌ Evaluation failed!")
            print(result.stderr)

    except OSError as e:
        # Raised when the process cannot be started at all.
        print(f"❌ Error running evaluation: {e}")
else:
    print("Skipping evaluation.")

## Step 10: Upload to Hugging Face (Optional)

In [None]:
import getpass

# login is imported here as well so this cell does not depend on the
# import-test cell having been run in the same session.
from huggingface_hub import HfApi, login, upload_folder

# Optionally publish the quantized model folder, together with a generated
# model card, to a Hugging Face model repository.
upload_to_hf = input("Do you want to upload the AWQ model to Hugging Face? (y/n): ").strip().lower()

if upload_to_hf != 'y':
    print("Skipping upload.")
elif not (awq_model_dir.is_dir() and any(awq_model_dir.iterdir())):
    # Nothing to publish: avoid creating a repo that would only hold a README.
    print("❌ AWQ model not found. Quantization may have failed.")
else:
    # Get Hugging Face credentials (getpass keeps the token off the screen).
    print("Please enter your Hugging Face token:")
    hf_token = getpass.getpass("HF Token: ")

    try:
        login(token=hf_token)
        print("✅ Successfully logged in to Hugging Face!")

        # Get repository name
        repo_name = input("Enter repository name (e.g., 'username/model-name-awq'): ").strip()
        if not repo_name:
            raise ValueError("repository name must not be empty")

        # Create repository (no error if it already exists).
        api = HfApi()
        api.create_repo(repo_id=repo_name, repo_type="model", exist_ok=True)
        print(f"✅ Repository {repo_name} created!")

        # Create model card
        model_card = f"""---
license: apache-2.0
base_model: {MODEL_NAME}
tags:
- mlx
- awq
- quantized
- {AWQ_CONFIG['bits']}-bit
---

# {MODEL_NAME.split('/')[-1]} - AWQ {AWQ_CONFIG['bits']}-bit

This is an AWQ (Activation-aware Weight Quantization) {AWQ_CONFIG['bits']}-bit quantized version of [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}).

## Quantization Details
- Method: AWQ (Activation-aware Weight Quantization)
- Precision: {AWQ_CONFIG['bits']}-bit
- Calibration samples: {AWQ_CONFIG['num_samples']}
- Search grid: {AWQ_CONFIG['n_grid']}
- Group size: {AWQ_CONFIG['group_size']}

## Features
- Scales and clips weights before quantization
- Optimized for Apple Silicon devices
- Maintains model quality through activation-aware scaling

## Usage
```python
from mlx_lm import load, generate

model, tokenizer = load("{repo_name}")
response = generate(model, tokenizer, prompt="Hello", max_tokens=100)
```
"""

        # Save model card; explicit UTF-8 so the result does not depend on
        # the platform's default encoding.
        (awq_model_dir / "README.md").write_text(model_card, encoding="utf-8")

        # Upload
        print(f"Uploading to {repo_name}...")
        upload_folder(
            folder_path=str(awq_model_dir),
            repo_id=repo_name,
            repo_type="model",
            commit_message=f"Add AWQ {AWQ_CONFIG['bits']}-bit quantized model",
        )

        print("✅ Model uploaded successfully!")
        print(f"🔗 https://huggingface.co/{repo_name}")

    except Exception as e:
        # Top-level boundary of an interactive cell: report and continue.
        print(f"❌ Upload failed: {e}")

## Step 11: Summary

In [None]:
# Final summary: recap the configuration, the directories used and, when the
# sizes were measured by earlier cells, the achieved size reduction.
banner = "=" * 60

print("\n" + banner)
print("🎉 AWQ QUANTIZATION SUMMARY")
print(banner)

print("\n📋 Configuration:")
print(f"   Base Model: {MODEL_NAME}")
print(f"   Target Bits: {AWQ_CONFIG['bits']}")
print(f"   Calibration Samples: {AWQ_CONFIG['num_samples']}")
print(f"   Search Grid: {AWQ_CONFIG['n_grid']}")
print(f"   Group Size: {AWQ_CONFIG['group_size']}")

print("\n📁 Directories:")
print(f"   Original: {original_model_dir}")
print(f"   AWQ Model: {awq_model_dir}")

# Check if quantization was successful (a non-empty output folder is the signal).
if awq_model_dir.exists() and any(awq_model_dir.iterdir()):
    print("\n✅ Status: AWQ quantization completed successfully!")

    # Calculate size reduction if possible: both totals come from earlier
    # cells and may be missing if those cells were skipped.
    if 'total_size' in globals() and 'total_awq_size' in globals():
        print(f"   Original size: {total_size:.2f} MB")
        print(f"   AWQ size: {total_awq_size:.2f} MB")
        # Guard the percentage against an empty/zero-sized original model.
        if total_size > 0:
            print(f"   Size reduction: {((total_size - total_awq_size) / total_size * 100):.1f}%")
else:
    print("\n❌ Status: AWQ quantization failed or incomplete")

print("\n💡 AWQ Advantages:")
print("   • Scales and clips weights before quantization")
print("   • Preserves model quality through activation awareness")
print("   • Efficient search for optimal scaling factors")
print("   • Good balance between size and performance")

print("\n🔧 Tuning Tips:")
print("   • Increase num_samples for better quality")
print("   • Increase n_grid for more thorough search")
print("   • Adjust group_size based on model architecture")

print("\n" + banner)
print("Thank you for using AWQ quantization!")
print(banner)