

## What is optimization needed for?
Optimization is measured in terms of two components:
1. Latency: how much time it takes one person to cross a bridge (or one computation to complete).
2. Throughput: how many such queries can be executed in a given amount of time (or how many people can cross the bridge in that time, e.g. several walking side by side).
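The two metrics can be made concrete with a minimal timing sketch. It assumes CPU execution; on a GPU you would also need `torch.cuda.synchronize()` before reading the clock. The `measure` helper and the toy model are hypothetical and exist only for illustration.

```python
import time

import torch
import torch.nn as nn


def measure(model, batch, n_iters=50):
    """Hypothetical helper: return (latency in s per batch, throughput in samples/s)."""
    model.eval()
    with torch.no_grad():
        model(batch)  # warm-up run, excluded from timing
        start = time.perf_counter()
        for _ in range(n_iters):
            model(batch)
        elapsed = time.perf_counter() - start
    latency = elapsed / n_iters             # time to push one batch through
    throughput = batch.shape[0] / latency   # samples completed per second
    return latency, throughput


# Toy model for illustration only
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10))
for bs in (1, 64):
    lat, thr = measure(model, torch.randn(bs, 256))
    print(f"batch={bs}: latency={lat * 1e3:.3f} ms, throughput={thr:.0f} samples/s")
```

Comparing batch size 1 with batch size 64 typically shows throughput rising with batch size while per-batch latency also rises. The exact numbers depend on the machine.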

## [PyTorch Optimization tricks](https://pytorch.org/blog/introduction-to-quantization-on-pytorch/)
Quantization leverages 8-bit integer (int8) instructions to reduce model size and run inference faster.

1. PyTorch has data types corresponding to quantized tensors (e.g. `torch.quint8`, `torch.qint8`).
2. Kernels can be written with quantized tensors, and quantized modules for common operations live in the `torch.nn.quantized` and `torch.nn.quantized.dynamic` namespaces.
3. Quantized models are traceable and scriptable (TorchScript).
4. The mapping from floating-point tensors to quantized tensors is customizable with user-defined observer/fake-quantization blocks (so the model trains in floating point but knows it will ultimately be converted to int8).
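Here is a minimal sketch of a quantized tensor. It assumes the simple affine scheme that splits the observed range [min, max] into 256 levels. `torch.quantize_per_tensor` stores the values as `uint8` together with a `scale` and a `zero_point`. The input values are arbitrary.

```python
import torch

x = torch.tensor([-1.0, 0.0, 0.5, 2.0])

# Affine mapping: split the observed range [min, max] into 256 levels (quint8 covers 0..255)
x_min, x_max = x.min().item(), x.max().item()
scale = (x_max - x_min) / 255
zero_point = int(round(-x_min / scale))  # the integer that represents 0.0

q = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=torch.quint8)
print(q.dtype)         # torch.quint8
print(q.int_repr())    # the underlying uint8 storage
print(q.dequantize())  # back to float; each value is within scale / 2 of x
```

`torch.qint8` is the signed counterpart and is commonly used for weights.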


### Dynamic quantization
1. Convert weights to int8 ahead of time (weights only). Good for LSTMs, Transformers, and MLPs with small batch sizes.
    1. 2x faster compute, 4x less memory.
2. Convert activations to int8 on the fly, just before the computation. "Activations" here means the tensors fed into each quantized layer, not activation functions such as tanh or ReLU.
    1. Computations are performed with efficient int8 matrix multiplication, so compute is faster. Activations are in int8 only while the computation runs.
    2. Activations are read from and written to memory in floating point.
3. Use `torch.quantization.quantize_dynamic` to convert a model to a dynamically quantized model.
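Here is a minimal sketch of dynamic quantization on a toy MLP. It assumes a recent PyTorch release where these APIs live under `torch.ao.quantization`; older releases expose the same function as `torch.quantization.quantize_dynamic`. The model and its layer sizes are hypothetical.

```python
import torch
import torch.ao.quantization as tq
import torch.nn as nn

torch.manual_seed(0)

# Toy float MLP; sizes are arbitrary
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10)).eval()

# Swap every nn.Linear for a dynamically quantized version:
# weights are stored as int8 now, activations are quantized on the fly at run time
qmodel = tq.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

x = torch.randn(4, 64)
with torch.no_grad():
    y_float = model(x)
    y_int8 = qmodel(x)

print(qmodel)                           # the Linear layers are now dynamic quantized modules
print((y_float - y_int8).abs().max())   # small quantization error
```

The output `y_int8` is an ordinary float tensor. This is consistent with activations being read and written in floating point between layers.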


### Post-training static quantization
1. Improves performance (latency) by converting both weights and activations to 8-bit integers.
    1. Works best for CNNs, with good accuracy. Requires calibration on a representative dataset.
    2. Workflow: tweak the model, calibrate on data, convert.
    3. Quantizes weights and activations for the entire model or for selected submodules.
    4. 1.5-2x faster compute, 4x less memory.
2. Calibration first feeds batches of data through the network and computes the resulting distributions of the different activations. Observer modules inserted at different points record these distributions.
3. This information determines how each activation should be quantized at inference time.
    1. A simple technique is to divide the entire observed range of an activation into 256 levels, but more sophisticated methods are also supported.
    2. This lets quantized values be passed directly between operations, instead of converting them to floats and back to ints between every operation.
4. Options for tuning static quantization:
    1. Observers: customize observers to specify how statistics are collected before quantization, e.g. to try out more advanced quantization methods.
        1. Observers are inserted with `torch.quantization.prepare`.
    2. Operator fusion: fuse multiple ops into a single op. This saves memory accesses and also improves numerical accuracy.
        1. Use `torch.quantization.fuse_modules`.
    3. Per-channel quantization: quantize the weights of each output channel in a convolution/linear layer independently. This gives higher accuracy at almost the same speed.
    4. The conversion itself is done with `torch.quantization.convert`.
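The eager-mode static workflow (fuse, prepare with observers, calibrate, convert) can be sketched as follows. The sketch makes three assumptions:

- a recent PyTorch with `torch.ao.quantization`;
- a toy model (`SmallCNN` is hypothetical);
- random tensors standing in for a real calibration dataset.

`QuantStub` and `DeQuantStub` mark where tensors enter and leave int8.

```python
import torch
import torch.ao.quantization as tq
import torch.nn as nn

torch.manual_seed(0)


class SmallCNN(nn.Module):
    """Hypothetical toy CNN used only to illustrate the workflow."""

    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()      # float -> int8 at the model input
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 10)
        self.dequant = tq.DeQuantStub()  # int8 -> float at the model output

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.bn(self.conv(x)))
        x = self.pool(x).flatten(1)
        x = self.fc(x)
        return self.dequant(x)


model = SmallCNN().eval()  # post-training quantization starts from an eval-mode model

# 1. Operator fusion: conv + bn + relu become one module (BN is folded into the conv)
model = tq.fuse_modules(model, [["conv", "bn", "relu"]])

# 2. Observers: the qconfig chooses how activation/weight statistics are collected.
#    The backend's default qconfig is used here; custom observers could be plugged in instead.
model.qconfig = tq.get_default_qconfig(torch.backends.quantized.engine)
prepared = tq.prepare(model)  # returns a copy with observer modules inserted

# 3. Calibration: random data stands in for representative inputs
with torch.no_grad():
    for _ in range(10):
        prepared(torch.randn(8, 3, 16, 16))

# 4. Convert: swap modules for int8 versions using the recorded statistics
quantized = tq.convert(prepared)

with torch.no_grad():
    out = quantized(torch.randn(2, 3, 16, 16))
print(quantized)
```

As far as I know, the default qconfig for the `fbgemm`/`x86` backends already uses per-channel weight observers, while `qnnpack` defaults to per-tensor.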

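Observers and per-channel quantization can also be tried in isolation. This sketch again assumes `torch.ao.quantization` from a recent release. It feeds a hypothetical conv weight, whose output channels have very different ranges, through a per-tensor observer and a per-channel min/max observer. It then compares the reconstruction error on the smallest-range channel.

```python
import torch
import torch.ao.quantization as tq

torch.manual_seed(0)

# Hypothetical conv weight: 4 output channels with very different value ranges
w = torch.randn(4, 3, 3, 3) * torch.tensor([0.01, 0.1, 1.0, 10.0]).view(4, 1, 1, 1)

# Per-tensor: one scale for the whole weight, dictated by the widest channel
per_tensor = tq.MinMaxObserver(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
per_tensor(w)  # calling an observer records min/max statistics
scale_t, zp_t = per_tensor.calculate_qparams()

# Per-channel: one scale per output channel (axis 0)
per_channel = tq.PerChannelMinMaxObserver(
    ch_axis=0, dtype=torch.qint8, qscheme=torch.per_channel_symmetric
)
per_channel(w)
scales_c, zps_c = per_channel.calculate_qparams()

wq_t = torch.quantize_per_tensor(w, scale_t.item(), int(zp_t.item()), torch.qint8)
wq_c = torch.quantize_per_channel(
    w, scales_c.to(torch.double), zps_c.to(torch.int64), 0, torch.qint8
)

# Max reconstruction error on channel 0, the smallest-range channel
err_t = (wq_t.dequantize() - w)[0].abs().max()
err_c = (wq_c.dequantize() - w)[0].abs().max()
print(f"channel 0 max error: per-tensor {err_t:.2e}, per-channel {err_c:.2e}")
```

With one shared scale, the small-range channel is dwarfed by the largest channel and mostly rounds to zero. Per-channel scales avoid this, which is where the accuracy gain comes from.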
### Quantization Aware Training 
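1. Quantizes weights and activations; used during training or fine-tuning.
    1. Works for all models and gives the best accuracy vs. performance trade-off.
    2. The steps are almost identical to post-training static quantization.
    3. Specify a different (QAT) qconfig and use `prepare_qat` instead of `prepare`.
    4. Train instead of calibrating: weights and activations are fake-quantized during training itself, so the model learns to compensate for quantization error.
2. Under the hood:
    1. Fake quantization mimics quantization in the forward pass (customizable).
    2. A straight-through estimator is used in the backward pass. The non-differentiable rounding step is treated as the identity, so gradients pass through it (and are zeroed for values clipped outside the quantization range).
    3. Batch norm gets special handling:
        1. Batch norm is folded into the preceding convolution during training, to mimic how it is folded at inference time.
        2. Batch norm statistics updates are frozen to improve accuracy during quantization-aware training (similar to FastAI?).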
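Here is a matching QAT sketch under the same assumptions as the static workflow (recent PyTorch, a hypothetical `SmallCNN`, random data instead of a real fine-tuning set). It differs from the static workflow in four ways:

- it fuses with `fuse_modules_qat`;
- it uses a QAT qconfig and `prepare_qat`;
- it trains instead of calibrating;
- partway through training, it freezes batch norm statistics (`freeze_bn_stats`) and observer updates (`disable_observer`).

The freeze step (step 10) is an arbitrary choice.

```python
import torch
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.quantization as tq
import torch.nn as nn

torch.manual_seed(0)


class SmallCNN(nn.Module):
    """Hypothetical toy CNN used only to illustrate the workflow."""

    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 10)
        self.dequant = tq.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.bn(self.conv(x)))
        x = self.pool(x).flatten(1)
        return self.dequant(self.fc(x))


model = SmallCNN().train()  # prepare_qat expects a model in training mode

# QAT fusion keeps BN inside a ConvBnReLU2d module so it can be folded and later frozen
model = tq.fuse_modules_qat(model, [["conv", "bn", "relu"]])

# A QAT qconfig inserts FakeQuantize modules (fake quantization in the forward pass)
model.qconfig = tq.get_default_qat_qconfig(torch.backends.quantized.engine)
qat_model = tq.prepare_qat(model)

opt = torch.optim.SGD(qat_model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
for step in range(20):
    # Random inputs/labels stand in for a real fine-tuning dataset
    x, y = torch.randn(8, 3, 16, 16), torch.randint(0, 10, (8,))
    opt.zero_grad()
    loss = loss_fn(qat_model(x), y)
    loss.backward()  # gradients flow through fake quantization via the STE
    opt.step()
    if step == 10:
        qat_model.apply(nniqat.freeze_bn_stats)  # stop updating BN running stats
        qat_model.apply(tq.disable_observer)     # stop updating scales/zero-points

qat_model.eval()
quantized = tq.convert(qat_model)  # real int8 modules for inference

with torch.no_grad():
    out = quantized(torch.randn(2, 3, 16, 16))
```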


# Pruning a module 
import torch
from torch.nn.utils.prune import *