# Data Generation Tutorial: Data-Free (Zero-Shot) Quantization in PyTorch with the Model Compression Toolkit (MCT)
[Run this tutorial in Google Colab](https://colab.research.google.com/github/sony/model_optimization/blob/main/tutorials/notebooks/mct_features_notebooks/pytorch/example_pytorch_data_generation.ipynb)

## Overview
In this tutorial, we explore how to generate synthetic images using the Model Compression Toolkit (MCT) and its Data Generation Library. The generated images are based on the statistics stored in the model's batch normalization layers and can be useful for various compression tasks, such as quantization and pruning. We will use the generated images as a representative dataset to quantize our model to 8-bit using MCT's Post Training Quantization (PTQ).

## Summary
We will cover the following steps:
1. **Setup**: Install and import the necessary libraries and load a pre-trained model.
2. **Configuration**: Define the data generation configuration.
3. **Data Generation**: Generate synthetic images.
4. **Visualization**: Visualize the generated images.
5. **Quantization**: Quantize our model to 8-bit using PTQ with the generated images as a representative dataset. This is called **"Data-Free Quantization"** since no real data is used in the quantization process.

## Step 1: Setup
Install the necessary packages:

In [None]:
!pip install -q torch torchvision

In [None]:
import importlib.util  # find_spec lives in the importlib.util submodule
if not importlib.util.find_spec('model_compression_toolkit'):
    !pip install model_compression_toolkit

In [None]:
import torch
from torchvision.models import resnet18, ResNet18_Weights
from torchvision.datasets import ImageNet
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

Load the model from the torchvision library:

In [None]:
# Load a pre-trained model (e.g., ResNet18)
weights = ResNet18_Weights.DEFAULT
float_model = resnet18(weights=weights)

## Step 2: Define a Data Generation Configuration
Next, we need to specify the configuration for data generation using `get_pytorch_data_generation_config`. This configuration includes parameters such as the number of iterations, optimizer, batch size, and more. Customize these parameters according to your needs.

In [None]:
import model_compression_toolkit as mct

data_gen_config = mct.data_generation.get_pytorch_data_generation_config(
    n_iter=500,                      # Number of iterations
    optimizer=torch.optim.RAdam,     # Optimizer
    data_gen_batch_size=128,          # Batch size for data generation
    initial_lr=16,                   # Initial learning rate
    output_loss_multiplier=1e-6,     # Multiplier for output loss
    extra_pixels=32,                 # Extra pixels added to the input image size
    # ... (customize other parameters)
)

## Step 3: Generate Synthetic Images

Now, let's generate synthetic images using the `pytorch_data_generation_experimental` function. Specify the number of images you want to generate and the output image size.

In [None]:
n_images = 256              # Number of images to generate
output_image_size = 224     # Size of output images

generated_images = mct.data_generation.pytorch_data_generation_experimental(
    model=float_model,
    n_images=n_images,
    output_image_size=output_image_size,
    data_generation_config=data_gen_config
)
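As noted in the overview, the generated images are driven by the statistics stored in the model's batch normalization layers. The sketch below illustrates the general idea with a simplified loss that compares per-channel batch statistics against stored running statistics. It is only an illustration in plain Python: the function name `bn_statistics_loss` and the toy values are hypothetical, and MCT's actual objective may combine several terms and is not reproduced here.

```python
from statistics import fmean, pvariance


def bn_statistics_loss(channel_activations, running_mean, running_var):
    """Squared distance between batch statistics and stored BN statistics.

    channel_activations holds one flat list of activation values per channel,
    collected over the whole batch.
    """
    loss = 0.0
    for values, mu, var in zip(channel_activations, running_mean, running_var):
        batch_mean = fmean(values)
        batch_var = pvariance(values)  # population (biased) variance of the batch
        loss += (batch_mean - mu) ** 2 + (batch_var - var) ** 2
    return loss


# Hypothetical two-channel example (values chosen for illustration only)
activations = [[0.0, 2.0, 4.0], [1.0, 1.0, 1.0]]
matching = bn_statistics_loss(activations, running_mean=[2.0, 1.0], running_var=[8.0 / 3.0, 0.0])
mismatched = bn_statistics_loss(activations, running_mean=[0.0, 0.0], running_var=[1.0, 1.0])
print(matching, mismatched)
```

An image optimizer would update the input pixels to push such a loss toward zero, so that the activations produced by the synthetic images resemble those the model saw during training.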

## Step 4: Visualization
Let's define a function to display the generated images:

In [None]:
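def plot_image(image, reverse_preprocess=False, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    # Take the first image of the batch and convert CHW -> HWC for matplotlib
    image = image.detach().cpu().numpy()[0]
    image = image.transpose(1, 2, 0)
    if reverse_preprocess:
        # Undo the ImageNet normalization and map back to 8-bit pixel values
        image = (image.astype(np.float32) * np.array(std) + np.array(mean)) * 255
        image = np.clip(np.round(image), 0, 255).astype(np.uint8)
    plt.imshow(image)
    plt.show()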

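Now, let's visualize the generated images by selecting an index to plot. Each entry of `generated_images` has a leading batch dimension, and `plot_image` displays the first image of the selected entry. You can modify the index value to experiment with different images.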

In [None]:
img_index_to_plot = 0
plot_image(generated_images[img_index_to_plot], reverse_preprocess=True)
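To make the reverse-preprocessing arithmetic concrete, here is a small worked example for a single channel value, using the red-channel mean and standard deviation from `plot_image`. The helper `denormalize_pixel` is hypothetical, and the clipping to [0, 255] is an assumption added for this sketch, since synthetic images are not guaranteed to stay within the range of real images.

```python
def denormalize_pixel(value, mean, std):
    """Map one normalized channel value back to an 8-bit intensity."""
    restored = (value * std + mean) * 255  # undo (x - mean) / std, then scale to [0, 255]
    return int(min(max(round(restored), 0), 255))  # clip values outside the displayable range


red_mean, red_std = 0.485, 0.229
print(denormalize_pixel(0.0, red_mean, red_std))   # 124: a normalized 0 maps to the channel mean
print(denormalize_pixel(3.0, red_mean, red_std))   # 255: clipped from above
print(denormalize_pixel(-3.0, red_mean, red_std))  # 0: clipped from below
```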

## Step 5: Post Training Quantization
To evaluate our generated images, we will use them to quantize the model with MCT's PTQ. This is referred to as **"Zero-Shot Quantization (ZSQ)"** or **"Data-Free Quantization"** because no real data is used in the quantization process. Next, we define the configurations for MCT's PTQ.

### Target Platform Capabilities (TPC)
MCT optimizes the model for dedicated hardware platforms. This is done using TPC (for more details, please visit our [documentation](https://sonysemiconductorsolutions.github.io/mct-model-optimization/api/api_docs/modules/target_platform_capabilities.html)). Here, we use the default PyTorch TPC:

In [None]:
target_platform_cap = mct.get_target_platform_capabilities("pytorch", "default")

### Representative Dataset
For quantization with MCT, we need to define a representative dataset, which is required by the PTQ algorithm. This dataset is a generator that yields a list of input batches on each iteration. We will use our generated images as the representative dataset.

In [None]:
batch_size = 64
n_iter = 10

# Stack all generated images, then regroup them into batches of `batch_size`
image_shape = generated_images[0].shape[1:]
generated_images = np.concatenate(generated_images, axis=0).reshape(-1, batch_size, *image_shape)

def representative_data_gen():
    for nn in range(n_iter):
        nn_mod = nn % generated_images.shape[0]
        yield [generated_images[nn_mod]]
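The batching logic above can be checked without any images. With 256 generated images and a batch size of 64 there are 4 batches, and the modulo index makes the generator wrap around when `n_iter` exceeds the number of batches. The sketch below uses integer indices as stand-ins for image batches; `cycle_batches` is a hypothetical helper that mirrors `representative_data_gen`.

```python
def cycle_batches(batches, n_iter):
    """Yield n_iter single-input lists, wrapping around when batches run out."""
    for step in range(n_iter):
        yield [batches[step % len(batches)]]


n_images, batch_size, n_iter = 256, 64, 10
n_batches = n_images // batch_size

# Stand-in for image batches: each "batch" is represented by its index.
visited = [item[0] for item in cycle_batches(list(range(n_batches)), n_iter)]
print(n_batches, visited)  # 4 [0, 1, 2, 3, 0, 1, 2, 3, 0, 1]
```

Note that the reshape in the cell above only works when the total number of generated images is a multiple of `batch_size`.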

### Quantization with our generated images
Now, we are ready to use MCT to quantize the model.

In [None]:
# run post training quantization on the model to get the quantized model output
quantized_model_generated_data, quantization_info = mct.ptq.pytorch_post_training_quantization(
    in_module=float_model,
    representative_data_gen=representative_data_gen,
    target_platform_capabilities=target_platform_cap
)
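To give some intuition for why PTQ needs a representative dataset, the sketch below calibrates a clipping threshold from a few batches of values and then applies symmetric 8-bit quantization. This is a generic illustration, not MCT's algorithm: the helper names and the toy values are hypothetical, and MCT selects its quantization parameters with its own methods.

```python
def calibrate_threshold(batches):
    """Largest absolute value seen across all calibration batches."""
    return max(abs(v) for batch in batches for v in batch)


def fake_quantize(value, threshold, n_bits=8):
    """Round to a signed n-bit grid covering [-threshold, threshold], then map back to float."""
    q_max = 2 ** (n_bits - 1) - 1  # 127 for 8 bits
    scale = threshold / q_max
    q = max(-q_max, min(q_max, round(value / scale)))  # round, then clip to the grid
    return q * scale


calibration_batches = [[-0.5, 0.25, 1.0], [2.0, -1.5, 0.75]]  # hypothetical activations
threshold = calibrate_threshold(calibration_batches)
print(threshold, fake_quantize(0.3, threshold), fake_quantize(5.0, threshold))
```

Values inside the calibrated range are reproduced with an error of at most half a quantization step, while values beyond the threshold are clipped. This is why the calibration data should produce activations similar to those of real inputs.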

## Setup for evaluation on the ImageNet dataset
### Download ImageNet validation set
Download only the validation split of the ImageNet dataset. This step may take several minutes.

In [None]:
import os

if not os.path.isdir('imagenet'):
    !mkdir imagenet
    !wget -P imagenet https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz
    !wget -P imagenet https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar

# Extract ImageNet validation dataset using torchvision "datasets" module.
dataset = ImageNet(root='./imagenet', split='val', transform=weights.transforms())
val_dataloader = DataLoader(dataset, batch_size=50, shuffle=False, num_workers=16, pin_memory=True)

Here we define functions for evaluation:

In [None]:
from tqdm import tqdm


def evaluate(model, testloader):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0
    with torch.no_grad():
        for data in tqdm(testloader):
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    val_acc = 100 * correct / total
    print('Accuracy: %.2f%%' % val_acc)
    return val_acc
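The top-1 computation inside `evaluate` can be illustrated on a tiny example without a model or a GPU. The sketch below uses plain Python lists in place of tensors; `top1_accuracy` and the toy logits are hypothetical.

```python
def top1_accuracy(logits, labels):
    """Percentage of samples whose highest-scoring class equals the label."""
    correct = 0
    for scores, label in zip(logits, labels):
        predicted = max(range(len(scores)), key=scores.__getitem__)  # argmax over classes
        correct += int(predicted == label)
    return 100 * correct / len(labels)


logits = [[0.1, 2.3, -1.0], [1.5, 0.2, 0.3], [0.0, 0.1, 0.9], [0.4, 0.3, 0.2]]
labels = [1, 0, 2, 1]
print('Accuracy: %.2f%%' % top1_accuracy(logits, labels))  # Accuracy: 75.00%
```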

### Evaluation of the quantized model's performance
Here we evaluate our model's top-1 classification accuracy after PTQ on the ImageNet validation dataset.
Let's start with the floating-point model evaluation.

In [None]:
evaluate(float_model, val_dataloader)

Finally, let's evaluate the quantized model:

In [None]:
evaluate(quantized_model_generated_data, val_dataloader)

## Conclusion
In this tutorial, we demonstrated how to generate synthetic images from a trained model and use them as a representative dataset for quantization. Storing the weights in 8 bits instead of 32-bit floats makes the quantized model roughly 4x smaller than the original float model, while its accuracy is expected to stay close to the float result (compare the two evaluation outputs above). Notably, no real data was required in the quantization process.
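The 4x figure follows from the bit widths alone. As a rough back-of-the-envelope check, the sketch below compares weight storage at 32 and 8 bits, assuming the parameter count of torchvision's ResNet18 (about 11.7 million). It ignores quantization parameters and any tensors that may be kept at higher precision, so the real ratio can differ slightly.

```python
def weights_size_mb(n_params, bits_per_weight):
    """Storage needed for the weights alone, in megabytes (10**6 bytes)."""
    return n_params * bits_per_weight / 8 / 1e6


n_params = 11_689_512  # assumed parameter count of torchvision's ResNet18
float_mb = weights_size_mb(n_params, 32)
int8_mb = weights_size_mb(n_params, 8)
print(f"float32: {float_mb:.1f} MB, int8: {int8_mb:.1f} MB, ratio: {float_mb / int8_mb:.1f}x")
```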

## Copyrights
Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.