# Quantization and precision analysis

**Objective:** Demonstrate how **TFLite Dynamic Range Quantization** shrinks the model footprint by roughly 4x (weights are stored as 8-bit integers instead of 32-bit floats) while maintaining near-original accuracy.

**Key Concepts:**
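* **Dynamic Range Quantization:** Storing weights as INT8 to save space, while model inputs/outputs stay in FP32. At runtime, ops with dynamic-range kernels quantize activations to 8 bits on the fly; the remaining ops dequantize the weights and compute in FP32.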
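* **Calibration:** Full-integer quantization uses a representative dataset to determine the dynamic range (min/max) of activations. Dynamic range quantization skips this step (weight ranges are read directly from the weights), so the sample images in this notebook serve only as a test batch.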
* **Confidence Drift:** Measuring the divergence in probability scores between the original and quantized models.

**Instructions:**
1.  **Runtime:** The standard **CPU Runtime** is sufficient (no GPU required).
2.  **Run:** Execute all cells to load the sample images, download the pretrained ImageNet weights, convert the model, and generate the degradation report.

In [None]:
# @title 1. Setup
import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
from skimage import data
from skimage.transform import resize
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input, decode_predictions

# Square input resolution expected by MobileNetV2 (224x224).
IMG_SIZE = 224

# Record the TensorFlow version in the notebook output for reproducibility.
print(f"TensorFlow Version: {tf.__version__}")

In [None]:
# @title 2. Data Factory
def _ensure_rgb(img):
    """Return an (H, W, 3) version of a grayscale, RGB or RGBA image array."""
    if img.ndim == 2:
        # Grayscale: replicate the single channel into R, G and B.
        return np.stack([img] * 3, axis=-1)
    if img.ndim == 3 and img.shape[-1] == 3:
        return img
    if img.ndim == 3 and img.shape[-1] == 4:
        # RGBA: drop the alpha channel; MobileNetV2 takes 3 channels.
        return img[..., :3]
    raise ValueError(
        f"Unsupported image shape {img.shape}; expected HxW, HxWx3 or HxWx4"
    )


def get_calibration_data(images=None, img_size=None):
    """
    Load and preprocess a small batch of images for MobileNetV2.

    Note: the name is historical. Dynamic range quantization needs no
    calibration data, so this batch is used to compare FP32 vs INT8
    predictions rather than to calibrate the converter.

    Args:
        images: Optional list of image arrays, each HxW (grayscale),
            HxWx3 (RGB) or HxWx4 (RGBA). Defaults to four skimage sample
            images (cat, astronaut, coffee, rocket).
        img_size: Side length of the square resize target. Defaults to
            the module-level ``IMG_SIZE``.

    Returns:
        raw_images: The input images, unresized, for display (human readable).
        input_batch: Array of shape (N, img_size, img_size, 3) after
            MobileNetV2 ``preprocess_input`` (model ready).

    Raises:
        ValueError: If ``images`` is empty or an image has an unsupported shape.
    """
    if img_size is None:
        img_size = IMG_SIZE
    if images is None:
        images = [
            data.chelsea(),      # Cat
            data.astronaut(),    # Human
            data.coffee(),       # Object
            data.rocket(),       # Vehicle
        ]
    if len(images) == 0:
        # np.vstack([]) would fail with an opaque error further down.
        raise ValueError("images must contain at least one image")

    processed_imgs = []
    for img in images:
        rgb = _ensure_rgb(np.asarray(img))
        # preserve_range keeps pixel values in [0, 255]; without it skimage
        # rescales to [0, 1] and preprocess_input would produce wrong inputs.
        img_resized = resize(rgb, (img_size, img_size), preserve_range=True)
        # Add a batch axis -> (1, img_size, img_size, 3), then apply the
        # MobileNetV2-specific preprocessing (scales to [-1, 1]).
        processed_imgs.append(preprocess_input(img_resized[np.newaxis, ...]))

    # Stack into a single batch: (N, img_size, img_size, 3)
    input_batch = np.vstack(processed_imgs)

    return images, input_batch

# Load and verify
raw_images, test_batch = get_calibration_data()

# Show them to the user
plt.figure(figsize=(12, 3))
for i, img in enumerate(raw_images):
    plt.subplot(1, 4, i+1)
    plt.imshow(img)
    plt.axis('off')
    plt.title(f"Image {i+1}")
plt.show()

print(f"Input Batch Shape: {test_batch.shape}")

In [None]:
# @title 3. Load Baseline Model
print("Loading MobileNetV2 (FP32)...")
# ImageNet-pretrained MobileNetV2 at the notebook's square input resolution.
baseline_model = tf.keras.applications.MobileNetV2(
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
    weights='imagenet',
)

# Sanity check on the astronaut image (index 1 of the batch). Slicing 1:2
# rather than indexing keeps the leading batch axis the model expects.
print("\n--- Baseline Sanity Check (Astronaut) ---")
astro_pred = baseline_model.predict(
    test_batch[1:2],
    verbose=0,
)
print(decode_predictions(astro_pred, top=1)[0])

In [None]:
# @title 4. Quantization Pipeline

# Helper to save TFLite models
def save_tflite(tflite_model, filename):
    """
    Write a serialized TFLite model to disk and report the path.

    Args:
        tflite_model: Bytes returned by ``TFLiteConverter.convert()``.
        filename: Destination path; any existing file is overwritten.
    """
    with open(filename, 'wb') as out_file:
        out_file.write(tflite_model)
    print(f"Saved: {filename}")

# 1. Convert to Standard FP32 TFLite (No Optimization)
converter = tf.lite.TFLiteConverter.from_keras_model(baseline_model)
tflite_fp32 = converter.convert()
save_tflite(tflite_fp32, 'mobilenet_fp32.tflite')

# 2. Convert to Dynamic Range Quantization (INT8 Weights)
# We set the optimization flag. TFLite will quantize weights to INT8
# but keep activations in FP32 during runtime.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant = converter.convert()
save_tflite(tflite_quant, 'mobilenet_int8.tflite')

In [None]:
# @title 5. Size Benchmark
# On-disk sizes in MB (bytes / 1024 / 1024), FP32 first, then INT8.
size_fp32, size_int8 = (
    os.path.getsize(model_file) / (1024 * 1024)
    for model_file in ('mobilenet_fp32.tflite', 'mobilenet_int8.tflite')
)

print(f"FP32 Model Size: {size_fp32:.2f} MB")
print(f"INT8 Model Size: {size_int8:.2f} MB")
# How many times smaller the quantized file is.
print(f"Compression Ratio: {size_fp32 / size_int8:.2f}x")

In [None]:
# @title 6. Interpreter Helper
def _prepare_input(batch, input_detail):
    """Convert a float batch to the dtype (and quantization) the input tensor expects."""
    in_dtype = input_detail['dtype']
    scale, zero_point = input_detail['quantization']
    if np.issubdtype(in_dtype, np.integer) and scale != 0:
        # Integer-quantized input: real = (q - zero_point) * scale, so invert it.
        limits = np.iinfo(in_dtype)
        batch = np.clip(np.round(batch / scale + zero_point), limits.min, limits.max)
    return batch.astype(in_dtype)


def run_tflite_inference(model_path, input_data):
    """
    Run a TFLite model on each sample of a batch and stack the outputs.

    Samples are fed one at a time with a leading batch axis of size 1
    (e.g. (1, 224, 224, 3) for the MobileNetV2 models in this notebook).

    Args:
        model_path: Path to a ``.tflite`` model file.
        input_data: Array-like of shape (N, *sample_shape), already
            preprocessed into the float domain the model was trained on.

    Returns:
        np.ndarray of shape (N, *output_sample_shape) with the model's first
        output for every sample. Inputs are cast to the input tensor's dtype.
        If that tensor is integer-quantized, inputs are quantized with its
        (scale, zero_point). Integer-quantized outputs are dequantized to
        float32. For N == 0, an empty array with the per-sample output shape
        is returned instead of raising.
    """
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()

    input_detail = interpreter.get_input_details()[0]
    output_detail = interpreter.get_output_details()[0]

    batch = _prepare_input(np.asarray(input_data), input_detail)

    out_scale, out_zero_point = output_detail['quantization']
    dequantize_output = (
        np.issubdtype(output_detail['dtype'], np.integer) and out_scale != 0
    )
    out_dtype = np.float32 if dequantize_output else output_detail['dtype']

    if len(batch) == 0:
        # np.vstack([]) would raise; return an empty, correctly shaped result.
        return np.empty((0, *output_detail['shape'][1:]), dtype=out_dtype)

    outputs = []
    for sample in batch:
        # Re-add the batch axis: the input tensor holds exactly one sample.
        interpreter.set_tensor(input_detail['index'], sample[np.newaxis, ...])
        interpreter.invoke()
        # get_tensor returns a copy, so it stays valid across invocations.
        out = interpreter.get_tensor(output_detail['index'])
        if dequantize_output:
            out = (out.astype(np.float32) - out_zero_point) * out_scale
        outputs.append(out)

    return np.vstack(outputs)

In [None]:
# @title 7. Accuracy Evaluation
print("Running Inference on Test Batch...")

# 1. Run both models
# Note: Ensure you ran Cell 6 to define 'run_tflite_inference' first!
preds_fp32 = run_tflite_inference('mobilenet_fp32.tflite', test_batch)
preds_int8 = run_tflite_inference('mobilenet_int8.tflite', test_batch)

# 2. Calculate Degradation
# One result per image; the loop follows the actual batch size.
results = []
num_images = len(test_batch)

for i, (probs_fp32, probs_int8) in enumerate(zip(preds_fp32, preds_int8)):
    # Top-1 prediction and its confidence for each model
    top_fp32 = np.argmax(probs_fp32)
    conf_fp32 = probs_fp32[top_fp32]

    top_int8 = np.argmax(probs_int8)
    conf_int8 = probs_int8[top_int8]

    # Probability the INT8 model assigns to the FP32 model's top class.
    # Drift is measured against this so it stays like-for-like when the
    # top-1 label flips; when labels match it equals conf_int8.
    conf_int8_on_fp32_label = probs_int8[top_fp32]

    # Mean Squared Error between the full probability vectors
    mse = np.mean((probs_fp32 - probs_int8) ** 2)

    results.append({
        "image_idx": i,
        "label_match": top_fp32 == top_int8,
        "conf_delta": conf_fp32 - conf_int8_on_fp32_label,
        "mse": mse,
        "top_fp32": top_fp32,
        "conf_fp32": conf_fp32,
        "conf_int8": conf_int8,
        "conf_int8_on_fp32_label": conf_int8_on_fp32_label,
    })

# Print Summary
print("\n--- Degradation Report ---")
for res in results:
    status = " MATCH" if res['label_match'] else " FLIP "
    print(f"Img {res['image_idx']}: {status} | Conf Drift: {res['conf_delta']:.4f} | MSE: {res['mse']:.6f}")

In [None]:
# @title 8. Visualization
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

# Target thumbnail height (source pixels x zoom) shared by every image, so
# images of different resolutions render at a similar size.
THUMB_HEIGHT = 48

fig = plt.figure(figsize=(14, 6))

# --- Plot 1: Size Comparison ---
ax1 = plt.subplot(1, 2, 1)
bars1 = ax1.bar(['FP32 (Original)', 'INT8 (Quantized)'], [size_fp32, size_int8], color=['#3498db', '#e74c3c'])
ax1.set_title('Model Footprint (Lower is Better)', fontsize=12)
ax1.set_ylabel('Size (MB)')
ax1.grid(axis='y', alpha=0.3)
for bar in bars1:
    height = bar.get_height()
    ax1.text(bar.get_x() + bar.get_width() / 2., height,
             f'{height:.1f} MB', ha='center', va='bottom', fontsize=10)

# --- Plot 2: Confidence Drift with Thumbnails ---
ax2 = plt.subplot(1, 2, 2)
drifts = [r['conf_delta'] * 100 for r in results]  # Convert to %
# Orange = INT8 lost confidence; green = unchanged or gained.
colors = ['green' if d <= 0 else 'orange' for d in drifts]
x_pos = range(1, len(results) + 1)

bars2 = ax2.bar(x_pos, drifts, color=colors)
ax2.axhline(0, color='black', linewidth=0.8)
ax2.set_title('Confidence Loss per Image (Lower is Better)', fontsize=12)
ax2.set_xlabel('Test Image', fontsize=10)
ax2.set_ylabel('Confidence Drop (%)', fontsize=10)
ax2.grid(axis='y', alpha=0.3)
ax2.set_xticks(x_pos)


def create_thumbnail(img_array, zoom=0.12, target_height=None):
    """
    Wrap an image in an OffsetImage for use inside an AnnotationBbox.

    Args:
        img_array: HxW or HxWxC image array.
        zoom: Fixed scale factor, used when ``target_height`` is None.
        target_height: If given, overrides ``zoom`` so the thumbnail is about
            this tall regardless of the source image's resolution.
    """
    img_array = np.asarray(img_array)
    if target_height is not None:
        zoom = target_height / img_array.shape[0]
    # 2-D arrays are grayscale; render them as such instead of with a colour map.
    cmap = 'gray' if img_array.ndim == 2 else None
    return OffsetImage(img_array, zoom=zoom, cmap=cmap)


# Attach each image's thumbnail to the end of its bar.
for bar, img, color in zip(bars2, raw_images, colors):
    height = bar.get_height()
    imagebox = create_thumbnail(img, target_height=THUMB_HEIGHT)

    # x center of bar, y at the top/bottom edge of bar
    xy = (bar.get_x() + bar.get_width() / 2, height)

    # Positive drop (orange): image ABOVE the bar, bottom-aligned.
    # Negative drop (green): image BELOW the bar, top-aligned.
    above = height >= 0
    ab = AnnotationBbox(imagebox, xy,
                        xybox=(0, 5) if above else (0, -5),
                        xycoords='data',
                        boxcoords="offset points",
                        pad=0.1,
                        frameon=True,
                        bboxprops=dict(edgecolor=color, lw=2),  # border matches bar
                        box_alignment=(0.5, 0) if above else (0.5, 1))
    ax2.add_artist(ab)

# Expand the y-limits to leave room for the thumbnails. The padding is
# proportional to the data range, so small drifts (often well under 1%)
# are not flattened by a fixed margin.
y_min, y_max = ax2.get_ylim()
y_span = y_max - y_min
y_pad = 0.35 * y_span if y_span > 0 else 1.0
ax2.set_ylim(y_min - y_pad, y_max + y_pad)

plt.tight_layout()
# Save before plt.show(): notebook backends typically close the figure once shown.
plt.savefig('quantization_results_with_images.png', dpi=100, bbox_inches='tight')
print("Saved quantization_results_with_images.png")
plt.show()