torch._export can't export resnet50 model #124595

Open
ecilay opened this issue Apr 22, 2024 · 5 comments
@ecilay

ecilay commented Apr 22, 2024

🐛 Describe the bug

I followed the documentation at https://pytorch.org/docs/stable/quantization.html to try the three available quantization methods. Both the FX graph mode and eager mode approaches work, but the export approach does not.
My questions are:

  1. Based on the code below, am I missing anything that would make the export approach work?
  2. The speed gain is pt_time: 0.04296875 vs quantized_time: 0.035390625 for a single batch-1 input, which is not that significant considering conv layers dominate ResNet-50, and the model size only shrank from 98M to 92M. Can it be made faster and smaller?
  3. I used the same eager mode quantization approach to quantize an autoencoder model (https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/model.py#L368), but it does not speed up at all and it produces exactly the same results, which suggests no optimization happened. I think this is because the attention layers are not accelerated; is there a way to quantize attention layers with PyTorch quantization?

Code to reproduce:

import torch, time
from torchvision.models import resnet50, ResNet50_Weights
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
# model.cuda()
model.eval()


def eager_quantize(model):
    # pt_time:  0.0418
    # qt_time:  0.0392
    model_int8 = torch.ao.quantization.quantize_dynamic(
        model,  # the original model
        {torch.nn.Linear, torch.nn.Conv2d},  # a set of layers to dynamically quantize
        dtype=torch.qint8)  # the target dtype for quantized weights
    return model_int8
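
# Note: quantize_dynamic typically only swaps a small set of module types
# (nn.Linear and the RNN/Embedding variants); nn.Conv2d is not dynamically
# quantized, so most of ResNet-50's conv compute stays in fp32, which would be
# consistent with the modest speedup and size reduction in question 2.
# A small helper (a sketch, not part of the original report) to check which
# modules were actually converted:
def print_quantized_modules(model_int8):
    for name, module in model_int8.named_modules():
        if "quantized" in type(module).__module__:
            print(name, "->", type(module).__name__)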


def export_quantize(m):
    # Issue: bn expected 4d, got 3
    from torch._export import capture_pre_autograd_graph
    example_inputs = torch.randn(1, 3, 224, 224)
    m = capture_pre_autograd_graph(m, example_inputs)
    # we get a model with aten ops


    # Step 2. quantization
    from torch.ao.quantization.quantize_pt2e import (
        prepare_pt2e,
        convert_pt2e,
    )

    from torch.ao.quantization.quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )
    # backend developer will write their own Quantizer and expose methods to allow
    # users to express how they
    # want the model to be quantized
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    m = prepare_pt2e(m, quantizer)

    # calibration omitted

    m = convert_pt2e(m)
    # we have a model with aten ops doing integer computations when possible

    return m

def fx_quantize(float_model):
    # pt_time:  0.04296875
    # qt_time:  0.035390625
    from torch.ao.quantization import default_dynamic_qconfig, QConfigMapping
    from torch.quantization.quantize_fx import prepare_fx, convert_fx
    # qconfig = get_default_qconfig("x86")
    qconfig_mapping = QConfigMapping().set_global(default_dynamic_qconfig)
    example_inputs = torch.randn(1, 3, 224, 224)
    prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs)  # fuse modules and insert observers
    # no calibration is required for dynamic quantization
    quantized_model = convert_fx(prepared_model)  # convert the model to a dynamically quantized model
    return quantized_model



model_int8 = eager_quantize(model)
torch.save(model_int8.state_dict(), "eager_quantize.pt")
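
# The pt_time / qt_time numbers in the comments above are the reporter's; below is a
# minimal timing sketch (an assumption, not the original benchmark code) for
# reproducing such a CPU comparison ("time" is imported at the top of this script):
def benchmark(m, n_iters=50):
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        m(x)  # warm-up
        start = time.time()
        for _ in range(n_iters):
            m(x)
    return (time.time() - start) / n_iters

# e.g. print("pt_time:", benchmark(model), "qt_time:", benchmark(model_int8))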

Error logs

File "/home/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py", line 416, in _check_input_dim
    raise ValueError(f"expected 4D input (got {input.dim()}D input)")
torch._dynamo.exc.TorchRuntimeError: Failed running call_module L__self___bn1(*(FakeTensor(..., size=(64, 112, 112), grad_fn=<SqueezeBackward1>),), **{}):
expected 4D input (got 3D input)

Minified repro

No response

Versions

Collecting environment information...
PyTorch version: 2.1.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.13 | packaged by conda-forge | (main, Oct 26 2023, 18:07:37) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-1015-aws-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-SXM4-80GB
Nvidia driver version: 525.105.17
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   46 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          96
On-line CPU(s) list:             0-95
Vendor ID:                       GenuineIntel
Model name:                      Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
CPU family:                      6
Model:                           85
Thread(s) per core:              2
Core(s) per socket:              24
Socket(s):                       2
Stepping:                        7
BogoMIPS:                        5999.99
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke
Hypervisor vendor:               KVM
Virtualization type:             full
L1d cache:                       1.5 MiB (48 instances)
L1i cache:                       1.5 MiB (48 instances)
L2 cache:                        48 MiB (48 instances)
L3 cache:                        71.5 MiB (2 instances)
NUMA node(s):                    2
NUMA node0 CPU(s):               0-23,48-71
NUMA node1 CPU(s):               24-47,72-95
Vulnerability Itlb multihit:     KVM: Mitigation: VMX unsupported
Vulnerability L1tf:              Mitigation; PTE Inversion
Vulnerability Mds:               Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown:          Mitigation; PTI
Vulnerability Mmio stale data:   Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Retpolines, STIBP disabled, RSB filling
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected

Versions of relevant libraries:
[pip3] aitemplate==0.3.dev0+fa2torch2.1cu12.1
[pip3] mypy==1.9.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.4
[pip3] onnx==1.15.0
[pip3] onnx2torch==1.5.13
[pip3] onnxruntime==1.15.1
[pip3] open-clip-torch==2.24.0
[pip3] pytorch-lightning==2.0.2
[pip3] pytorch_quantization==2.2.0
[pip3] pytorch-wavelets==1.3.0
[pip3] rotary-embedding-torch==0.5.3
[pip3] torch==2.1.0
[pip3] torch-utils==0.1.2
[pip3] torchaudio==2.1.0
[pip3] torchmetrics==1.3.1
[pip3] torchvision==0.16.0
[pip3] triton==2.1.0
[conda] aitemplate                0.3.dev0+fa2torch2.1cu12.1          pypi_0    pypi
[conda] numpy                     1.24.4                   pypi_0    pypi
[conda] onnx2torch                1.5.13                   pypi_0    pypi
[conda] open-clip-torch           2.24.0                   pypi_0    pypi
[conda] pytorch-lightning         2.0.2                    pypi_0    pypi
[conda] pytorch-quantization      2.2.0                    pypi_0    pypi
[conda] pytorch-wavelets          1.3.0                    pypi_0    pypi
[conda] rotary-embedding-torch    0.5.3                    pypi_0    pypi
[conda] torch                     2.1.0                    pypi_0    pypi
[conda] torch-utils               0.1.2                    pypi_0    pypi
[conda] torchaudio                2.1.0                    pypi_0    pypi
[conda] torchmetrics              1.3.1                    pypi_0    pypi
[conda] torchvision               0.16.0                   pypi_0    pypi
[conda] triton                    2.1.0                    pypi_0    pypi

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

@BoyuanFeng
Contributor

Hi @ecilay, please try the following code to make export work. The changes are (a) providing a tuple input to capture_pre_autograd_graph, and (b) fixing the import path for XNNPACKQuantizer.

def export_quantize(m):
    # Issue: bn expected 4d, got 3
    from torch._export import capture_pre_autograd_graph
    example_inputs = (torch.randn(1, 3, 224, 224),) # Note: input should be a tuple
    # breakpoint()
    m = capture_pre_autograd_graph(m, example_inputs)
    # we get a model with aten ops


    # Step 2. quantization
    from torch.ao.quantization.quantize_pt2e import (
        prepare_pt2e,
        convert_pt2e,
    )

    from torch.ao.quantization.quantizer.xnnpack_quantizer import ( # Note: Updated import path
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )
    # backend developer will write their own Quantizer and expose methods to allow
    # users to express how they
    # want the model to be quantized
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    m = prepare_pt2e(m, quantizer)

    # calibration omitted

    m = convert_pt2e(m)
    # we have a model with aten ops doing integer computations when possible

    return m
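
One caveat: the `# calibration omitted` step matters for this static PT2E flow, since prepare_pt2e inserts observers that need to see representative inputs before convert_pt2e can compute scales and zero points; skipping calibration (or calibrating on unrepresentative data) can badly degrade accuracy. A minimal sketch of that step (the random tensors here are only placeholders, an assumption; real calibration images should be used):

# inside export_quantize, in place of "# calibration omitted":
with torch.no_grad():
    for _ in range(20):
        m(torch.randn(1, 3, 224, 224))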

@ecilay
Author

ecilay commented Apr 22, 2024

Thanks, your fix works. However, the ResNet classification results are now totally wrong; you can reproduce this with the inference code below. Also, the runtime of the model exported this way is roughly two to three times slower: pt_time: 0.0398 vs quantize_time: 0.0918.

from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights

img = read_image("crane.jpg")

# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
# model = resnet50(weights=weights)
# model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()



def run_inference(model):
    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    print(f"{category_name}: {100 * score:.1f}%")

@ecilay
Author

ecilay commented Apr 22, 2024

Besides, if I use the FX approach, the outputs are numerically very different from the original model, even though the softmax classification results are the same. Is this expected?
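
A minimal sketch (an assumption, not the reporter's code) for quantifying that difference by comparing raw logits on the same preprocessed input, reusing `fx_quantize`, `preprocess`, and `img` from the earlier snippets:

x = preprocess(img).unsqueeze(0)
float_model = resnet50(weights=weights).eval()
with torch.no_grad():
    ref = float_model(x)
    out = fx_quantize(float_model)(x)
print("max abs logit diff:", (ref - out).abs().max().item())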

@BoyuanFeng
Contributor

Do you mind sharing "crane.jpg" and the log? Thanks!

@ecilay
Author

ecilay commented Apr 22, 2024

(attached image: crane.jpg)
The outputs from the original model and the export-quantized model are below:

geyser: 5.9%
iron: 99.6%

@angelayi added the oncall: quantization (Quantization support in PyTorch) label on Apr 23, 2024