### L4-B - Building your own Quantizer: Replace PyTorch layers with Quantized Layers
In this lesson, you will learn about the quantization pipeline using your own 8-bit quantizer.

#### Step 1. Import everything

In [1]:
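import torch
import torch.nn as nn  # the dummy model below refers to nn.Linear
# W8A16LinearLayer is assumed to be provided by helpers
from helpers import *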

#### Step 2. Quantization Pipeline
- Replace all of the torch.nn.Linear layers with the W8A16LinearLayer layer.
- Call quantize on the linear layers using the original weights.
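
The functions in this lesson construct the target class as target_class(in_features, out_features, has_bias, dtype), call quantize(old_weight) on it, and assign to its bias. W8A16LinearLayer itself is defined outside this section, so the following is only a minimal sketch of a layer with that interface, assuming symmetric per-row int8 quantization. The class name SketchW8A16Linear and the buffer names int8_weights and scales are illustrative assumptions, not the helpers implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchW8A16Linear(nn.Module):
    """Hypothetical stand-in for W8A16LinearLayer (not the helpers code)."""

    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()
        # Buffers rather than Parameters: int8 tensors cannot require gradients.
        self.register_buffer(
            "int8_weights",
            torch.zeros((out_features, in_features), dtype=torch.int8),
        )
        self.register_buffer("scales", torch.ones(out_features, dtype=dtype))
        if bias:
            self.register_buffer("bias", torch.zeros(out_features, dtype=dtype))
        else:
            self.bias = None

    def quantize(self, weights):
        # Symmetric per-output-row quantization: scale = max|w| / 127.
        w_fp32 = weights.detach().to(torch.float32)
        scales = w_fp32.abs().max(dim=-1).values / 127
        scales = scales.clamp(min=1e-8)  # guard against all-zero rows
        self.int8_weights = torch.round(w_fp32 / scales.unsqueeze(1)).to(torch.int8)
        self.scales = scales.to(weights.dtype)

    def forward(self, x):
        # Cast the int8 weights up to the activation dtype, then rescale per output channel.
        out = F.linear(x, self.int8_weights.to(x.dtype)) * self.scales
        if self.bias is not None:
            out = out + self.bias
        return out


# Usage: quantize one float layer and compare outputs on random input.
torch.manual_seed(0)
fp_layer = nn.Linear(4, 3)
q_layer = SketchW8A16Linear(4, 3, bias=True, dtype=fp_layer.weight.dtype)
q_layer.quantize(fp_layer.weight)
q_layer.bias.data = fp_layer.bias.data
x = torch.randn(2, 4)
max_err = (q_layer(x) - fp_layer(x)).abs().max().item()
print(q_layer.int8_weights.dtype, max_err)
```

Storing the weights and scales as buffers keeps them in the state dict while leaving them out of the optimizer's parameter list, which suits an inference-only layer.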
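##### 2.1 - Model In-place Linear Layer Replacement
Implement replace_linear_with_target, which swaps each torch.nn.Linear child for the target class and skips any names listed for exclusion.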

In [3]:
# loop over torch.nn.Module children and replace them with a new module
def replace_linear_with_target(module, target_class, module_names_to_exclude):
  '''
  Replace all linear layers in a module with a target class.

  Parameters:
  - module: The module containing the linear layers to be replaced.
  - target_class: The target class to replace the linear layers with.
  - module_names_to_exclude: A list of module names to exclude from replacement.

  Returns:
  - None. The module is modified in place.
  '''
  for name, child in module.named_children():
    if isinstance(child, torch.nn.Linear) and name not in module_names_to_exclude:
      # get old module bias
      old_bias = child.bias

      # create new module 
      new_module = target_class(child.in_features, child.out_features, old_bias is not None, child.weight.dtype)

      # replace current module name with new module
      setattr(module, name, new_module)
      if old_bias is not None:
        # if old module had bias, set the bias of new module to the old bias
        getattr(module, name).bias.data = old_bias
    else:
      # Recursively apply the function to the child module for nested modules
      replace_linear_with_target(child, target_class, module_names_to_exclude)




In [4]:
# Create a dummy model
class DummyModel(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.emb = torch.nn.Embedding(1,1)
    # Linear layer with bias
    self.linear_1 = torch.nn.Linear(1, 1)
    # Linear layer without bias
    self.linear_2 = torch.nn.Linear(1, 1, bias=False)
    # Language model prediction head
    self.lm_head = torch.nn.Linear(1, 1, bias=False)
    

In [10]:
# Create two dummy models: one that keeps lm_head and one that replaces every linear layer
model_1 = DummyModel() # lm_head will be excluded
model_2 = DummyModel() # replace all layers
print("Model 1 before replacing it with the quantizable module: ", model_1)
print("Model 2 before replacing it with the quantizable module: ", model_2)

Model 1 before replacing it with the quantizable module:  DummyModel(
  (emb): Embedding(1, 1)
  (linear_1): Linear(in_features=1, out_features=1, bias=True)
  (linear_2): Linear(in_features=1, out_features=1, bias=False)
  (lm_head): Linear(in_features=1, out_features=1, bias=False)
)
Model 2 before replacing it with the quantizable module:  DummyModel(
  (emb): Embedding(1, 1)
  (linear_1): Linear(in_features=1, out_features=1, bias=True)
  (linear_2): Linear(in_features=1, out_features=1, bias=False)
  (lm_head): Linear(in_features=1, out_features=1, bias=False)
)


##### Try to replace the linear layers in model_1, excluding lm_head

In [9]:
replace_linear_with_target(model_1, W8A16LinearLayer, ["lm_head"])
print("Model 1 after replacing it with the quantizable module: ", model_1)

Model 1 after replacing it with the quantizable module:  DummyModel(
  (emb): Embedding(1, 1)
  (linear_1): W8A16LinearLayer()
  (linear_2): W8A16LinearLayer()
  (lm_head): Linear(in_features=1, out_features=1, bias=False)
)


##### Try to replace all the linear layers

In [11]:
replace_linear_with_target(model_2, W8A16LinearLayer, [])
print("Model 2 after replacing it with the quantizable module: ", model_2)

Model 2 after replacing it with the quantizable module:  DummyModel(
  (emb): Embedding(1, 1)
  (linear_1): W8A16LinearLayer()
  (linear_2): W8A16LinearLayer()
  (lm_head): W8A16LinearLayer()
)
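
The dummy model above is flat, so it does not exercise the recursive branch. As a rough sketch of how the function behaves on nested modules, the example below uses a hypothetical placeholder class, MarkerLinear, and hypothetical Block and NestedModel classes; the replacement function is a compact copy of the one defined earlier. Note that exclusion compares against the immediate child name rather than a dotted path, so every submodule named lm_head is skipped, wherever it sits in the tree.

```python
import torch
import torch.nn as nn


class MarkerLinear(nn.Module):
    """Hypothetical placeholder target; it only mirrors the constructor signature."""

    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.bias = nn.Parameter(torch.zeros(out_features, dtype=dtype)) if bias else None


def replace_linear_with_target(module, target_class, module_names_to_exclude):
    # Compact copy of the lesson's function.
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and name not in module_names_to_exclude:
            new_module = target_class(child.in_features, child.out_features,
                                      child.bias is not None, child.weight.dtype)
            setattr(module, name, new_module)
            if child.bias is not None:
                new_module.bias.data = child.bias.data
        else:
            # Containers (and excluded layers) are traversed recursively.
            replace_linear_with_target(child, target_class, module_names_to_exclude)


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(2, 2)
        self.lm_head = nn.Linear(2, 2)  # nested layer that shares the excluded name


class NestedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Block()
        self.lm_head = nn.Linear(2, 2, bias=False)


nested = NestedModel()
replace_linear_with_target(nested, MarkerLinear, ["lm_head"])
print(type(nested.block.proj).__name__)     # replaced inside the nested block
print(type(nested.block.lm_head).__name__)  # skipped: name matches the exclusion list
print(type(nested.lm_head).__name__)        # skipped at the top level
```

If only the top-level head should be excluded, one option is to match on the dotted path from named_modules instead, at the cost of a slightly more involved traversal.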


#### Step 3. Quantize each new module after it replaces the old one
Let's refine the function above so that it also quantizes every replaced layer using the original weights.

In [14]:
# loop over torch.nn.Module children and replace them with a new module
def replace_linear_with_target_and_quantize(module, target_class, module_names_to_exclude):
  '''
  Replace all linear layers in a module with a target class and quantize the original weights.

  Parameters:
  - module: The module containing the linear layers to be replaced.
  - target_class: The target class to replace the linear layers with.
  - module_names_to_exclude: A list of module names to exclude from replacement.

  Returns:
  - None. The module is modified in place.
  '''
  for name, child in module.named_children():
    if isinstance(child, torch.nn.Linear) and name not in module_names_to_exclude:
      # get old module bias
      old_bias = child.bias
      # retrieve the old weight
      old_weight = child.weight

      # create new module 
      new_module = target_class(child.in_features, child.out_features, old_bias is not None, child.weight.dtype)

      # replace current module name with new module
      setattr(module, name, new_module)

      # The old module has been replaced above, so fetch the new module
      # and quantize the old weight into it
      getattr(module, name).quantize(old_weight)

      if old_bias is not None:
        # if old module had bias, set the bias of new module to the old bias
        getattr(module, name).bias.data = old_bias
    else:
      # Recursively apply the function to the child module for nested modules
      replace_linear_with_target_and_quantize(child, target_class, module_names_to_exclude)

In [15]:
# Try it out on another dummy model
model_3 = DummyModel() # replace and quantize all layers
print("Model 3 before replacing it with the quantizable module: ", model_3)
replace_linear_with_target_and_quantize(model_3, W8A16LinearLayer, [])
print("Model 3 after replacing it with the quantizable module: ", model_3)

Model 3 before replacing it with the quantizable module:  DummyModel(
  (emb): Embedding(1, 1)
  (linear_1): Linear(in_features=1, out_features=1, bias=True)
  (linear_2): Linear(in_features=1, out_features=1, bias=False)
  (lm_head): Linear(in_features=1, out_features=1, bias=False)
)
Model 3 after replacing it with the quantizable module:  DummyModel(
  (emb): Embedding(1, 1)
  (linear_1): W8A16LinearLayer()
  (linear_2): W8A16LinearLayer()
  (lm_head): W8A16LinearLayer()
)
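
The printed structure looks the same whether or not quantize was called, so it does not confirm that the weights were actually converted. One way to check is to compare outputs against an unmodified copy of the model. The sketch below assumes a hypothetical minimal target class, TinyW8A16Linear, with symmetric per-row int8 quantization, and a hypothetical TwoLayer model; the helpers class may store its weights under different names.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyW8A16Linear(nn.Module):
    """Hypothetical minimal target class; the helpers version may differ."""

    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()
        self.register_buffer("int8_weights",
                             torch.zeros((out_features, in_features), dtype=torch.int8))
        self.register_buffer("scales", torch.ones(out_features, dtype=dtype))
        self.register_buffer("bias", torch.zeros(out_features, dtype=dtype) if bias else None)

    def quantize(self, weights):
        w = weights.detach().to(torch.float32)
        scales = (w.abs().max(dim=-1).values / 127).clamp(min=1e-8)
        self.int8_weights = torch.round(w / scales.unsqueeze(1)).to(torch.int8)
        self.scales = scales.to(weights.dtype)

    def forward(self, x):
        out = F.linear(x, self.int8_weights.to(x.dtype)) * self.scales
        return out if self.bias is None else out + self.bias


def replace_linear_with_target_and_quantize(module, target_class, module_names_to_exclude):
    # Compact copy of the lesson's function.
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and name not in module_names_to_exclude:
            new_module = target_class(child.in_features, child.out_features,
                                      child.bias is not None, child.weight.dtype)
            setattr(module, name, new_module)
            new_module.quantize(child.weight)
            if child.bias is not None:
                new_module.bias.data = child.bias.data
        else:
            replace_linear_with_target_and_quantize(child, target_class, module_names_to_exclude)


class TwoLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 16)
        self.fc2 = nn.Linear(16, 4, bias=False)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


torch.manual_seed(0)
reference = TwoLayer()
quantized = copy.deepcopy(reference)  # keep the float model for comparison
replace_linear_with_target_and_quantize(quantized, TinyW8A16Linear, [])

x = torch.randn(5, 8)
with torch.no_grad():
    max_abs_diff = (reference(x) - quantized(x)).abs().max().item()
print(quantized.fc1.int8_weights.dtype, max_abs_diff)
```

A small but nonzero difference is the expected result: the stored weights are int8, and the outputs differ from the float model only by rounding error.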
