In [30]:
import torch

# define a floating point model
class M(torch.nn.Module):
    """Minimal floating point model: one fully connected 4 -> 4 layer."""

    def __init__(self):
        super().__init__()
        # The attribute name `fc` is how later cells reach the layer.
        self.fc = torch.nn.Linear(in_features=4, out_features=4)

    def forward(self, x):
        """Run ``x`` through the single linear layer and return the result."""
        return self.fc(x)

# Seed the global RNG so the randomly initialised layer weights and the
# random input below are identical on every Restart & Run All.
torch.manual_seed(0)

# create a model instance
model_fp32 = M()
# create a quantized model instance
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32,  # the original model
    {torch.nn.Linear},  # a set of layers to dynamically quantize
    dtype=torch.qint8)  # the target dtype for quantized weights

# run the model on a random float32 input; the last dimension (4) matches
# the in_features of the Linear layer
input_fp32 = torch.randn(4, 4, 4, 4)
res = model_int8(input_fp32)

In [31]:
# Display the output tensor produced by the quantized model.
res

tensor([[[[ 2.0240e+00,  1.5521e-01,  7.5095e-02,  4.4543e-01],
          [ 4.3773e-01,  1.7560e+00, -5.5713e-01,  8.9000e-01],
          [-5.7184e-01,  1.9503e+00, -7.9842e-01,  4.2823e-01],
          [ 6.4926e-02, -5.2784e-01, -6.2297e-01, -8.8784e-01]],

         [[ 1.3309e+00,  6.3560e-01, -5.2023e-01, -3.9306e-01],
          [ 1.3031e+00,  1.4272e+00, -1.0994e+00, -1.0564e-01],
          [ 2.7697e-01, -1.3815e-01,  1.0761e+00, -3.3645e-01],
          [-5.0866e-01, -1.0265e-01,  1.2904e-01, -5.5304e-01]],

         [[-3.5448e-01,  5.7023e-01,  4.4133e-01,  5.2111e-01],
          [-2.4876e-01,  1.4459e+00, -1.1869e+00,  6.8687e-01],
          [ 5.9504e-01,  3.0127e-01, -5.6746e-01,  3.2533e-01],
          [ 8.8590e-01,  4.5952e-01,  1.3219e+00,  2.9233e-01]],

         [[ 4.9684e-01,  6.4014e-01, -5.7715e-01, -5.8228e-01],
          [ 5.8816e-01,  1.4389e+00, -1.0417e+00,  1.1222e+00],
          [ 1.7804e+00, -1.3831e-01,  4.3836e-01, -3.4255e-01],
          [ 3.9738e-01,  8.6235e-0

In [11]:
# Check the dtype of the tensor that was fed to the quantized model.
input_fp32.dtype

torch.float32

In [12]:
# Check the dtype of the tensor returned by the quantized model.
res.dtype

torch.float32

In [21]:
# parameters() returns a generator, so printing it shows only the generator
# object's repr rather than the tensors; iterate over it to see the values.
print(model_fp32.parameters())

<generator object Module.parameters at 0x7f11449f7c10>


In [23]:
# Walk the float model's parameters, showing each tensor followed by its
# dtype on the next line.
for param in model_fp32.parameters():
    print(param, param.dtype, sep="\n")

Parameter containing:
tensor([[-0.4070, -0.3779,  0.1892,  0.3390],
        [-0.4217, -0.0338,  0.3939,  0.2279],
        [-0.2858, -0.4249, -0.4494, -0.0818],
        [-0.2770,  0.0165, -0.2295,  0.2479]], requires_grad=True)
torch.float32
Parameter containing:
tensor([-0.0713, -0.3136,  0.0278, -0.3181], requires_grad=True)
torch.float32


In [24]:
# A dynamically quantized Linear keeps its weight and bias in packed form
# instead of registering them as nn.Parameter objects, so iterating over
# model_int8.parameters() yields nothing and prints no output. Read the
# tensors through the quantized module's weight() / bias() accessors instead.
quantized_fc = model_int8.fc

# The weight is a quantized tensor whose dtype is the requested target dtype.
print(quantized_fc.weight())
print(quantized_fc.weight().dtype)

# Show the bias and its dtype alongside the weight for comparison.
print(quantized_fc.bias())
print(quantized_fc.bias().dtype)

References:
* https://huggingface.co/docs/optimum/concept_guides/quantization