# DWQ (Distilled Weight Quantization) with MLX-LM

This notebook demonstrates how to use DWQ (Distilled Weight Quantization) with MLX-LM to reduce model quality loss during quantization.

## What is DWQ?
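DWQ is designed to minimize quality loss when quantizing models to lower bit precision. Rather than rounding the weights once and stopping, it uses the original full-precision model as a teacher and tunes the quantized model's quantization parameters (scales and biases) on calibration samples, so that the quantized model's outputs stay close to the original's. It works best for 2-4 bit models.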

## Requirements
- macOS with Apple Silicon (M1/M2/M3/M4)
- Python 3.9+
- MLX framework
- Sufficient disk space for model storage

## Step 1: Environment Setup

In [None]:
import os
import subprocess
import sys
from pathlib import Path

# Announce the start of the notebook run.
print("Setting up environment for DWQ quantization...")

# Everything this notebook downloads or produces is stored under ./models,
# relative to the directory the kernel was started from.
project_dir = Path.cwd()
models_dir = project_dir.joinpath("models")
models_dir.mkdir(exist_ok=True)

for label, location in (("Project directory", project_dir), ("Models directory", models_dir)):
    print(f"{label}: {location}")

## Step 2: Install MLX and Dependencies

In [None]:
# Install required packages into the interpreter that runs this kernel, one at a
# time, so a single failing package does not abort the rest.
import subprocess
import sys

print("Installing MLX and dependencies...")

packages = [
    "mlx-lm",
    "transformers",
    "torch",
    "huggingface_hub",
    "datasets",
    "accelerate",
    "sentencepiece",
    "protobuf",
]

failed_packages = []
for package in packages:
    print(f"Installing {package}...")
    try:
        subprocess.run(
            [sys.executable, "-m", "pip", "install", package],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        failed_packages.append(package)
        print(f"⚠️ Warning installing {package}: {e}")
        # str(CalledProcessError) only names the command and exit status; pip's
        # actual diagnostics were captured into stderr, so surface them here.
        if e.stderr:
            print(e.stderr.strip())
    else:
        print(f"✅ {package} installed successfully")

# Only claim success when every package actually installed.
if failed_packages:
    print(f"\n⚠️ Installation finished with failures: {', '.join(failed_packages)}")
else:
    print("\n📦 All packages installation completed!")

## Step 3: Test MLX Imports

In [None]:
# Verify that the freshly installed packages import cleanly and that MLX can
# build an array on this machine. The imported names (mx, load, generate,
# login, snapshot_download) are used by later cells.
print("Testing MLX imports...")

try:
    import mlx.core as mx
    from mlx_lm import load, generate
    from huggingface_hub import login, snapshot_download

    print("✅ All imports successful!")

    # Smoke test: allocate a tiny MLX array and print it.
    test_array = mx.array([1, 2, 3])
    print(f"✅ MLX test array: {test_array}")
except ImportError as exc:
    # Packages installed in the same session are sometimes only importable
    # after a kernel restart.
    print(f"❌ Import failed: {exc}")
    print("Please restart kernel and try again.")

## Step 4: Configuration

In [None]:
# DWQ configuration: choose the model and the quantization settings, then derive
# the on-disk locations for the original and quantized weights.
print("=== DWQ Configuration ===\n")

# Hugging Face repo id of the model to quantize (change as needed).
MODEL_NAME = "Qwen/Qwen2.5-0.5B"  # Small model for demonstration

# DWQ parameters.
# NOTE(review): the quantization cell does not pass learning_rate on to
# mlx_lm.dwq, and 0.01 looks much larger than typical DWQ settings; confirm the
# value against the mlx-lm DWQ docs before forwarding it.
DWQ_CONFIG = dict(
    bits=4,              # Quantization precision (2-4 bits work best)
    num_samples=1024,    # Calibration samples (default: 1024)
    batch_size=8,        # Batch size to reduce memory footprint
    group_size=64,       # Group size (smaller can improve results)
    learning_rate=0.01,  # Learning rate (adjust based on precision)
)

print(f"Model: {MODEL_NAME}")
for label, key in (
    ("Target bits", "bits"),
    ("Calibration samples", "num_samples"),
    ("Batch size", "batch_size"),
    ("Group size", "group_size"),
):
    print(f"{label}: {DWQ_CONFIG[key]}")

# Repo ids contain "/", which cannot appear in a single directory name.
safe_model_name = MODEL_NAME.replace("/", "_")
original_model_dir = models_dir / safe_model_name
dwq_model_dir = models_dir / f"{safe_model_name}_DWQ_{DWQ_CONFIG['bits']}bit"

print(f"\nOriginal model dir: {original_model_dir}")
print(f"DWQ model dir: {dwq_model_dir}")

## Step 5: Download Original Model

In [None]:
# Download the original (unquantized) model from the Hugging Face Hub into
# original_model_dir, optionally reusing files from a previous run.
import shutil
from datetime import datetime

print(f"Downloading {MODEL_NAME}...")
print("This may take a while depending on model size and internet connection.")

# Create directories
original_model_dir.mkdir(parents=True, exist_ok=True)

# Treat the model as already present only when config.json exists. A plain
# "directory is non-empty" test is also satisfied by leftovers of an interrupted
# download, which would then be reused as if they were a complete model.
config_file = original_model_dir / "config.json"

if config_file.is_file():
    print(f"Model files found in {original_model_dir}")
    use_existing = input("Use existing model files? (y/n): ").strip().lower()
    if use_existing != 'y':
        shutil.rmtree(original_model_dir)
        original_model_dir.mkdir(parents=True, exist_ok=True)

if not config_file.is_file():
    start_time = datetime.now()
    try:
        # local_dir_use_symlinks is deprecated in huggingface_hub (files in
        # local_dir are real files), so it is no longer passed.
        snapshot_download(
            repo_id=MODEL_NAME,
            local_dir=str(original_model_dir),
        )
    except Exception as e:
        print(f"❌ Download failed: {e}")
        print("Please check the model name and internet connection.")
    else:
        duration = datetime.now() - start_time
        print(f"✅ Model downloaded successfully in {duration}")

# List downloaded files (regular files directly inside the model directory).
print("\nModel files:")
total_size = 0
for file in original_model_dir.glob("*"):
    if file.is_file():
        size_mb = file.stat().st_size / 1024 / 1024
        total_size += size_mb
        print(f"  {file.name} ({size_mb:.2f} MB)")

print(f"\nTotal model size: {total_size:.2f} MB")

## Step 6: DWQ Quantization

In [None]:
# Run DWQ quantization as a subprocess, then report the files it produced and
# how much smaller they are than the original model.
import shutil
import subprocess
import sys
from datetime import datetime

print("Starting DWQ quantization...")
print(f"Source: {original_model_dir}")
print(f"Target: {dwq_model_dir}")
print(f"Configuration: {DWQ_CONFIG}")

# Clean up existing DWQ directory so stale files from an earlier run cannot be
# mistaken for fresh output.
if dwq_model_dir.exists():
    print(f"Removing existing DWQ directory: {dwq_model_dir}")
    shutil.rmtree(dwq_model_dir)

dwq_model_dir.mkdir(parents=True, exist_ok=True)

# Build DWQ command.
# - sys.executable (rather than a bare "python" from PATH) runs the tool with
#   the same interpreter that the install cell used for `pip install mlx-lm`.
# - group_size is forwarded so that editing DWQ_CONFIG actually takes effect.
# NOTE(review): DWQ_CONFIG["learning_rate"] is intentionally not forwarded, so
# mlx_lm.dwq uses its own default; confirm a suitable value before adding
# "--learning-rate".
dwq_cmd = [
    sys.executable, "-m", "mlx_lm.dwq",
    "--model", str(original_model_dir),
    "--mlx-path", str(dwq_model_dir),
    "--bits", str(DWQ_CONFIG["bits"]),
    "--group-size", str(DWQ_CONFIG["group_size"]),
    "--num-samples", str(DWQ_CONFIG["num_samples"]),
    "--batch-size", str(DWQ_CONFIG["batch_size"]),
]

print(f"\nRunning command: {' '.join(dwq_cmd)}")

try:
    start_time = datetime.now()

    # Output is captured and printed at the end, because a subprocess's own
    # stdout would not appear in the notebook.
    result = subprocess.run(
        dwq_cmd,
        capture_output=True,
        text=True,
        cwd=str(project_dir),
    )

    end_time = datetime.now()
    duration = end_time - start_time

    if result.returncode == 0:
        print(f"\n✅ DWQ quantization completed successfully in {duration}!")
        print("STDOUT:", result.stdout)
    else:
        print("\n❌ DWQ quantization failed!")
        print("STDERR:", result.stderr)
        print("STDOUT:", result.stdout)

except Exception as e:
    print(f"❌ Error running DWQ: {e}")

# Check results
if dwq_model_dir.exists() and list(dwq_model_dir.glob("*")):
    print("\nDWQ quantized files:")
    total_size = 0
    for file in dwq_model_dir.glob("*"):
        if file.is_file():
            size_mb = file.stat().st_size / 1024 / 1024
            total_size += size_mb
            print(f"  {file.name} ({size_mb:.2f} MB)")

    print(f"\nTotal DWQ model size: {total_size:.2f} MB")

    # Compare against the original download. Both sides count only regular files
    # directly inside each directory. The previous formula divided the DWQ size
    # by itself and therefore always reported 0.0%.
    original_size = sum(
        f.stat().st_size for f in original_model_dir.glob("*") if f.is_file()
    ) / 1024 / 1024
    if original_size > 0:
        print(f"Size reduction: {(original_size - total_size) / original_size * 100:.1f}%")
    else:
        print("Size reduction: n/a (no original model files found)")

## Step 7: Test DWQ Model

In [None]:
# Smoke-test the DWQ quantized model by loading it and generating short
# completions for a few prompts.
if dwq_model_dir.exists() and list(dwq_model_dir.glob("*")):
    print("Testing DWQ quantized model...")

    try:
        # Imported here so that a missing/old mlx-lm is reported by the
        # except-branch below instead of crashing the cell.
        from mlx_lm.sample_utils import make_sampler

        # Load the DWQ model
        model, tokenizer = load(str(dwq_model_dir))
        print("✅ DWQ model loaded successfully!")

        # Current mlx-lm releases (which are the ones that ship mlx_lm.dwq) no
        # longer accept temp= in generate(); sampling settings are passed as a
        # sampler object instead.
        sampler = make_sampler(temp=0.7)

        # Test generation
        test_prompts = [
            "Hello, how are you?",
            "The weather today is",
            "Artificial intelligence is",
        ]

        print("\n=== DWQ Model Test Results ===")
        for prompt in test_prompts:
            print(f"\nPrompt: '{prompt}'")

            response = generate(
                model,
                tokenizer,
                prompt=prompt,
                max_tokens=50,
                sampler=sampler,
            )

            print(f"Response: {response}")

        print("\n✅ DWQ model is working correctly!")

    except Exception as e:
        print(f"❌ Error testing DWQ model: {e}")
else:
    print("❌ DWQ model not found. Quantization may have failed.")

## Step 8: Evaluate Model Quality

In [None]:
# Optional: evaluate the quantized model with mlx_lm.evaluate.
# mlx_lm.evaluate is built on lm-evaluation-harness (the `lm-eval` package),
# which the install cell does not install; run `pip install lm-eval` first.
import subprocess
import sys

print("=== Model Quality Evaluation ===\n")

# lm-evaluation-harness task names, passed via --tasks.
EVAL_TASKS = ["wikitext"]  # or your preferred task(s)
# Few-shot examples per task (--num-shots). None keeps each task's default.
# Perplexity tasks such as wikitext do not use few-shot prompts; multiple-choice
# tasks (e.g. "arc_easy") are commonly run with 5.
EVAL_NUM_SHOTS = None

evaluate_model = input("Do you want to evaluate model quality? (y/n): ").strip().lower()

if evaluate_model != 'y':
    print("Skipping evaluation.")
elif not (dwq_model_dir.exists() and list(dwq_model_dir.glob("*"))):
    print("❌ DWQ model not found. Quantization may have failed.")
else:
    # The CLI flags are --tasks / --num-shots; the former --dataset / --few-shot
    # are not recognised by mlx_lm.evaluate and made the command fail.
    # sys.executable keeps the run on the kernel's own interpreter.
    eval_cmd = [
        sys.executable, "-m", "mlx_lm.evaluate",
        "--model", str(dwq_model_dir),
        "--tasks", *EVAL_TASKS,
    ]
    if EVAL_NUM_SHOTS is not None:
        eval_cmd += ["--num-shots", str(EVAL_NUM_SHOTS)]

    print(f"Running evaluation: {' '.join(eval_cmd)}")

    try:
        result = subprocess.run(eval_cmd, capture_output=True, text=True)
    except Exception as e:
        print(f"❌ Error running evaluation: {e}")
    else:
        if result.returncode == 0:
            print("\n✅ Evaluation completed!")
            print(result.stdout)
        else:
            print("\n❌ Evaluation failed!")
            print(result.stderr)

## Step 9: Upload to Hugging Face (Optional)

In [None]:
# Optionally publish the DWQ model, plus a generated model card, to a Hugging
# Face Hub repository.
import getpass

from huggingface_hub import HfApi, login, upload_folder

upload_to_hf = input("Do you want to upload the DWQ model to Hugging Face? (y/n): ").strip().lower()

# Check for local output before touching the Hub. Otherwise a failed
# quantization would still create the remote repo, then fail when writing the
# README, leaving an empty repository behind.
dwq_model_ready = dwq_model_dir.exists() and bool(list(dwq_model_dir.glob("*")))

if upload_to_hf != 'y':
    print("Skipping upload.")
elif not dwq_model_ready:
    print("❌ DWQ model not found. Quantization may have failed.")
else:
    # Get Hugging Face credentials (getpass avoids echoing the token).
    print("Please enter your Hugging Face token:")
    hf_token = getpass.getpass("HF Token: ")

    try:
        login(token=hf_token)
        print("✅ Successfully logged in to Hugging Face!")

        # Get repository name
        repo_name = input("Enter repository name (e.g., 'username/model-name-dwq'): ").strip()
        if not repo_name:
            raise ValueError("Repository name must not be empty.")

        # Create repository (no-op if it already exists)
        api = HfApi()
        api.create_repo(repo_id=repo_name, repo_type="model", exist_ok=True)
        print(f"✅ Repository {repo_name} created!")

        # Create model card.
        # NOTE(review): the license is hard-coded to apache-2.0; update it if
        # MODEL_NAME is changed to a model with a different license.
        model_card = f"""---
license: apache-2.0
base_model: {MODEL_NAME}
tags:
- mlx
- dwq
- quantized
- {DWQ_CONFIG['bits']}-bit
---

# {MODEL_NAME.split('/')[-1]} - DWQ {DWQ_CONFIG['bits']}-bit

This is a DWQ (Distilled Weight Quantization) {DWQ_CONFIG['bits']}-bit quantized version of [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}).

## Quantization Details
- Method: DWQ (Distilled Weight Quantization)
- Precision: {DWQ_CONFIG['bits']}-bit
- Calibration samples: {DWQ_CONFIG['num_samples']}
- Group size: {DWQ_CONFIG['group_size']}

## Usage
```python
from mlx_lm import load, generate

model, tokenizer = load("{repo_name}")
response = generate(model, tokenizer, prompt="Hello", max_tokens=100)
```
"""

        # Save model card alongside the weights so it is uploaded with them.
        (dwq_model_dir / "README.md").write_text(model_card, encoding="utf-8")

        # Upload
        print(f"Uploading to {repo_name}...")
        upload_folder(
            folder_path=str(dwq_model_dir),
            repo_id=repo_name,
            repo_type="model",
            commit_message=f"Add DWQ {DWQ_CONFIG['bits']}-bit quantized model",
        )

        print("✅ Model uploaded successfully!")
        print(f"🔗 https://huggingface.co/{repo_name}")

    except Exception as e:
        print(f"❌ Upload failed: {e}")

## Step 10: Summary

In [None]:
# Final summary: recap the configuration, output locations, and on-disk sizes.
print("\n" + "="*60)
print("🎉 DWQ QUANTIZATION SUMMARY")
print("="*60)

print("\n📋 Configuration:")
print(f"   Base Model: {MODEL_NAME}")
print(f"   Target Bits: {DWQ_CONFIG['bits']}")
print(f"   Calibration Samples: {DWQ_CONFIG['num_samples']}")
print(f"   Group Size: {DWQ_CONFIG['group_size']}")

print("\n📁 Directories:")
print(f"   Original: {original_model_dir}")
print(f"   DWQ Model: {dwq_model_dir}")

# Check if quantization was successful
if dwq_model_dir.exists() and list(dwq_model_dir.glob("*")):
    print("\n✅ Status: DWQ quantization completed successfully!")

    # Sizes count only regular files directly inside each directory.
    original_size = sum(f.stat().st_size for f in original_model_dir.glob("*") if f.is_file()) / 1024 / 1024
    dwq_size = sum(f.stat().st_size for f in dwq_model_dir.glob("*") if f.is_file()) / 1024 / 1024

    print(f"   Original size: {original_size:.2f} MB")
    print(f"   DWQ size: {dwq_size:.2f} MB")
    # Guard the division: the original directory can be empty, e.g. when the
    # download step failed or its files were removed.
    if original_size > 0:
        print(f"   Size reduction: {((original_size - dwq_size) / original_size * 100):.1f}%")
    else:
        print("   Size reduction: n/a (no original model files found)")
else:
    print("\n❌ Status: DWQ quantization failed or incomplete")

print("\n💡 Tips for DWQ:")
print("   • Works best for 2-4 bit quantization")
print("   • Decreasing group size can improve results")
print("   • Adjust learning rate based on precision")
print("   • More calibration samples = better quality")

print("\n" + "="*60)
print("Thank you for using DWQ quantization!")
print("="*60)