
Error while running Whisper model quantization with Intel Neural Compressor #2056

Open
@Shivani-k16

Description

OS: Ubuntu

Hardware: CPU Intel(R) Xeon(R) Platinum 8468V

I installed the dependencies listed in the GitHub repository. Because no specific versions were pinned, I used the latest releases:

transformers 4.44.2
datasets 2.21.0
evaluate 0.4.3
jiwer 3.0.4
librosa 0.10.2.post1
torch 2.1.0.post3+cxx11.abi
optimum-intel 1.19.0
neural_compressor 3.0.2
Python 3.10.12
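To make this environment easier to compare against, a small script along these lines could print the same version list. It is a sketch: the PyPI distribution names in `PACKAGES` are an assumption about how each package was installed and may need adjusting.

```python
import platform
from importlib.metadata import PackageNotFoundError, version

# Distribution names as assumed to be published on PyPI; adjust if yours differ.
PACKAGES = [
    "transformers", "datasets", "evaluate", "jiwer", "librosa",
    "torch", "optimum-intel", "neural-compressor",
]


def collect_versions(packages=PACKAGES):
    """Return {distribution name: installed version, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in collect_versions().items():
        print(f"{name} {ver if ver is not None else '(not installed)'}")
    print(f"Python {platform.python_version()}")
```

Packages that are missing are reported as "(not installed)" instead of raising, so the script can run in a partially set up environment.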

Reproduction steps:

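  1. Cloned the neural-compressor repository.

  2. Installed the dependencies with pip install -r requirements.txt.

  3. Installed Intel Neural Compressor with pip install neural-compressor.

  4. Ran run_whisper_large.py from examples/pytorch/speech_recognition/whisper_large/quantization/ptq_dynamic/fx, using the command line below.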

Command line:

python run_whisper_large.py --tune --int8 --batch_size 1 --output_dir output_model --cache_dir cache

Error:

2024-10-16 07:18:04 [INFO] Start auto tuning.
2024-10-16 07:18:04 [INFO] Execute the tuning process due to detect the evaluation function.
2024-10-16 07:18:04 [INFO] Adaptor has 5 recipes.
2024-10-16 07:18:04 [INFO] 0 recipes specified by user.
2024-10-16 07:18:04 [INFO] 3 recipes require future tuning.
2024-10-16 07:18:04 [INFO] *** Initialize auto tuning
2024-10-16 07:18:04 [INFO] {
2024-10-16 07:18:04 [INFO]     'PostTrainingQuantConfig': {
2024-10-16 07:18:04 [INFO]         'AccuracyCriterion': {
2024-10-16 07:18:04 [INFO]             'criterion': 'relative',
2024-10-16 07:18:04 [INFO]             'higher_is_better': True,
2024-10-16 07:18:04 [INFO]             'tolerable_loss': 0.01,
2024-10-16 07:18:04 [INFO]             'absolute': None,
2024-10-16 07:18:04 [INFO]             'keys': <bound method AccuracyCriterion.keys of <neural_compressor.config.AccuracyCriterion object at 0x154f25fca260>>,
2024-10-16 07:18:04 [INFO]             'relative': 0.01
2024-10-16 07:18:04 [INFO]         },
2024-10-16 07:18:04 [INFO]         'approach': 'post_training_dynamic_quant',
2024-10-16 07:18:04 [INFO]         'backend': 'default',
2024-10-16 07:18:04 [INFO]         'calibration_sampling_size': [
2024-10-16 07:18:04 [INFO]             100
2024-10-16 07:18:04 [INFO]         ],
2024-10-16 07:18:04 [INFO]         'device': 'cpu',
2024-10-16 07:18:04 [INFO]         'domain': 'auto',
2024-10-16 07:18:04 [INFO]         'example_inputs': 'Not printed here due to large size tensors...',
2024-10-16 07:18:04 [INFO]         'excluded_precisions': [
2024-10-16 07:18:04 [INFO]         ],
2024-10-16 07:18:04 [INFO]         'framework': 'pytorch_fx',
2024-10-16 07:18:04 [INFO]         'inputs': [
2024-10-16 07:18:04 [INFO]         ],
2024-10-16 07:18:04 [INFO]         'model_name': '',
2024-10-16 07:18:04 [INFO]         'op_name_dict': None,
2024-10-16 07:18:04 [INFO]         'op_type_dict': {
2024-10-16 07:18:04 [INFO]             'Embedding': {
2024-10-16 07:18:04 [INFO]                 'weight': {
2024-10-16 07:18:04 [INFO]                     'dtype': [
2024-10-16 07:18:04 [INFO]                         'fp32'
2024-10-16 07:18:04 [INFO]                     ]
2024-10-16 07:18:04 [INFO]                 },
2024-10-16 07:18:04 [INFO]                 'activation': {
2024-10-16 07:18:04 [INFO]                     'dtype': [
2024-10-16 07:18:04 [INFO]                         'fp32'
2024-10-16 07:18:04 [INFO]                     ]
2024-10-16 07:18:04 [INFO]                 }
2024-10-16 07:18:04 [INFO]             }
2024-10-16 07:18:04 [INFO]         },
2024-10-16 07:18:04 [INFO]         'outputs': [
2024-10-16 07:18:04 [INFO]         ],
2024-10-16 07:18:04 [INFO]         'quant_format': 'default',
2024-10-16 07:18:04 [INFO]         'quant_level': 'auto',
2024-10-16 07:18:04 [INFO]         'recipes': {
2024-10-16 07:18:04 [INFO]             'smooth_quant': False,
2024-10-16 07:18:04 [INFO]             'smooth_quant_args': {
2024-10-16 07:18:04 [INFO]             },
2024-10-16 07:18:04 [INFO]             'layer_wise_quant': False,
2024-10-16 07:18:04 [INFO]             'layer_wise_quant_args': {
2024-10-16 07:18:04 [INFO]             },
2024-10-16 07:18:04 [INFO]             'fast_bias_correction': False,
2024-10-16 07:18:04 [INFO]             'weight_correction': False,
2024-10-16 07:18:04 [INFO]             'gemm_to_matmul': True,
2024-10-16 07:18:04 [INFO]             'graph_optimization_level': None,
2024-10-16 07:18:04 [INFO]             'first_conv_or_matmul_quantization': True,
2024-10-16 07:18:04 [INFO]             'last_conv_or_matmul_quantization': True,
2024-10-16 07:18:04 [INFO]             'pre_post_process_quantization': True,
2024-10-16 07:18:04 [INFO]             'add_qdq_pair_to_weight': False,
2024-10-16 07:18:04 [INFO]             'optypes_to_exclude_output_quant': [
2024-10-16 07:18:04 [INFO]             ],
2024-10-16 07:18:04 [INFO]             'dedicated_qdq_pair': False,
2024-10-16 07:18:04 [INFO]             'rtn_args': {
2024-10-16 07:18:04 [INFO]             },
2024-10-16 07:18:04 [INFO]             'awq_args': {
2024-10-16 07:18:04 [INFO]             },
2024-10-16 07:18:04 [INFO]             'gptq_args': {
2024-10-16 07:18:04 [INFO]             },
2024-10-16 07:18:04 [INFO]             'teq_args': {
2024-10-16 07:18:04 [INFO]             },
2024-10-16 07:18:04 [INFO]             'autoround_args': {
2024-10-16 07:18:04 [INFO]             }
2024-10-16 07:18:04 [INFO]         },
2024-10-16 07:18:04 [INFO]         'reduce_range': None,
2024-10-16 07:18:04 [INFO]         'TuningCriterion': {
2024-10-16 07:18:04 [INFO]             'max_trials': 100,
2024-10-16 07:18:04 [INFO]             'objective': [
2024-10-16 07:18:04 [INFO]                 'performance'
2024-10-16 07:18:04 [INFO]             ],
2024-10-16 07:18:04 [INFO]             'strategy': 'basic',
2024-10-16 07:18:04 [INFO]             'strategy_kwargs': None,
2024-10-16 07:18:04 [INFO]             'timeout': 0
2024-10-16 07:18:04 [INFO]         },
2024-10-16 07:18:04 [INFO]         'use_bf16': True,
2024-10-16 07:18:04 [INFO]         'ni_workload_name': 'quantization'
2024-10-16 07:18:04 [INFO]     }
2024-10-16 07:18:04 [INFO] }
2024-10-16 07:18:04 [WARNING] [Strategy] Please install `mpi4py` correctly if using distributed tuning; otherwise, ignore this warning.
/home/u8bd311b633876ba392b704069aeab3e/env/xpu/lib/python3.10/site-packages/torch/overrides.py:110: UserWarning: 'has_cuda' is deprecated, please use 'torch.backends.cuda.is_built()'
  torch.has_cuda,
/home/u8bd311b633876ba392b704069aeab3e/env/xpu/lib/python3.10/site-packages/torch/overrides.py:111: UserWarning: 'has_cudnn' is deprecated, please use 'torch.backends.cudnn.is_available()'
  torch.has_cudnn,
/home/u8bd311b633876ba392b704069aeab3e/env/xpu/lib/python3.10/site-packages/torch/overrides.py:117: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'
  torch.has_mps,
/home/u8bd311b633876ba392b704069aeab3e/env/xpu/lib/python3.10/site-packages/torch/overrides.py:118: UserWarning: 'has_mkldnn' is deprecated, please use 'torch.backends.mkldnn.is_available()'
  torch.has_mkldnn,
2024-10-16 07:18:08 [INFO]  Found 12 blocks
2024-10-16 07:18:08 [INFO] Attention Blocks: 12
2024-10-16 07:18:08 [INFO] FFN Blocks: 12
2024-10-16 07:18:08 [INFO] Pass query framework capability elapsed time: 3381.72 ms
2024-10-16 07:18:08 [INFO] Get FP32 model baseline.
/home/u8bd311b633876ba392b704069aeab3e/env/xpu/lib/python3.10/site-packages/transformers/models/whisper/tokenization_whisper.py:501: UserWarning: The private method `_normalize` is deprecated and will be removed in v5 of Transformers.You can normalize an input string using the Whisper English normalizer using the `normalize` method.
  warnings.warn(
Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English.This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`.
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Result wer: 6.159110350727117
Accuracy: 0.93841
2024-10-16 08:19:12 [INFO] Save tuning history to /home/u8bd311b633876ba392b704069aeab3e/neural-compressor/examples/pytorch/speech_recognition/whisper_large/quantization/ptq_dynamic/fx/nc_workspace/2024-10-16_07-18-02/./history.snapshot.
2024-10-16 08:19:12 [INFO] FP32 baseline is: [Accuracy: 0.9384, Duration (seconds): 3663.6608]
2024-10-16 08:19:12 [INFO] Quantize the model with default config.
2024-10-16 08:19:14 [INFO] Convert operators to bfloat16
2024-10-16 08:19:14 [INFO] Fx trace of the entire model failed, We will conduct auto quantization
2024-10-16 08:19:25 [INFO] |*******Mixed Precision Statistics*******|
2024-10-16 08:19:25 [INFO] +-----------+-------+------+------+------+
2024-10-16 08:19:25 [INFO] |  Op Type  | Total | INT8 | BF16 | FP32 |
2024-10-16 08:19:25 [INFO] +-----------+-------+------+------+------+
2024-10-16 08:19:25 [INFO] |   Linear  |  193  | 193  |  0   |  0   |
2024-10-16 08:19:25 [INFO] |   Conv1d  |   2   |  0   |  2   |  0   |
2024-10-16 08:19:25 [INFO] | Embedding |   2   |  0   |  0   |  2   |
2024-10-16 08:19:25 [INFO] +-----------+-------+------+------+------+
2024-10-16 08:19:25 [INFO] Pass quantize model elapsed time: 13181.62 ms
2024-10-16 08:19:25 [ERROR] Unexpected exception AttributeError("'GraphModule' object has no attribute 'stride'") happened during tuning.
Traceback (most recent call last):
  File "/home/u8bd311b633876ba392b704069aeab3e/env/xpu/lib/python3.10/site-packages/neural_compressor/quantization.py", line 220, in fit
    strategy.traverse()
  File "/home/u8bd311b633876ba392b704069aeab3e/env/xpu/lib/python3.10/site-packages/neural_compressor/strategy/auto.py", line 140, in traverse
    super().traverse()
  File "/home/u8bd311b633876ba392b704069aeab3e/env/xpu/lib/python3.10/site-packages/neural_compressor/strategy/strategy.py", line 519, in traverse
    self.last_tune_result = self._evaluate(self.last_qmodel)
  File "/home/u8bd311b633876ba392b704069aeab3e/env/xpu/lib/python3.10/site-packages/neural_compressor/strategy/strategy.py", line 1693, in _evaluate
    val = self.objectives.evaluate(self.eval_func, model.model)
  File "/home/u8bd311b633876ba392b704069aeab3e/env/xpu/lib/python3.10/site-packages/neural_compressor/objective.py", line 443, in evaluate
    acc = eval_func(model)
  File "/home/u8bd311b633876ba392b704069aeab3e/neural-compressor/examples/pytorch/speech_recognition/whisper_large/quantization/ptq_dynamic/fx/run_whisper_large.py", line 47, in eval_func
    predicted_ids = model.generate(input_features)[0]
  File "/home/u8bd311b633876ba392b704069aeab3e/env/xpu/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py", line 505, in generate
    input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
  File "/home/u8bd311b633876ba392b704069aeab3e/env/xpu/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1695, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'GraphModule' object has no attribute 'stride'
2024-10-16 08:19:25 [ERROR] Specified timeout or max trials is reached! Not found any quantized model which meet accuracy goal. Exit.
Traceback (most recent call last):
  File "/home/u8bd311b633876ba392b704069aeab3e/neural-compressor/examples/pytorch/speech_recognition/whisper_large/quantization/ptq_dynamic/fx/run_whisper_large.py", line 66, in <module>
    q_model.save(args.output_dir)
AttributeError: 'NoneType' object has no attribute 'save'
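The first traceback ends in transformers' Whisper `generate()`, which reads `self.model.encoder.conv1.stride[0]`. The mixed-precision table just above it shows both Conv1d layers converted to BF16 immediately before the failed evaluation. One plausible reading, not confirmed by the log, is that this conversion replaced each Conv1d with a `torch.fx` `GraphModule`, which keeps the layer's weights but not plain Python attributes such as `stride`. The second traceback then follows directly: tuning found no usable model, so the returned `q_model` is `None` and `q_model.save(...)` fails. The minimal sketch below uses only PyTorch and a stand-in Conv1d rather than the real Whisper model to show that behaviour. The last lines are a hypothetical stopgap, not a confirmed fix.

```python
import torch
from torch import nn
from torch.fx import symbolic_trace

# Stand-in for WhisperEncoder.conv1; channel, kernel and stride values are
# illustrative only, not Whisper's real configuration.
conv = nn.Conv1d(in_channels=4, out_channels=4, kernel_size=3, stride=2, padding=1)

# symbolic_trace records forward() as a graph. The weight and bias are copied
# onto the GraphModule, but stride is baked into the F.conv1d call as a constant.
traced = symbolic_trace(conv)

x = torch.randn(1, 4, 16)
# The traced module computes the same result as the original layer...
assert torch.allclose(conv(x), traced(x))

# ...but no longer exposes the attribute that Whisper's generate() reads.
print(hasattr(conv, "stride"))    # True
print(hasattr(traced, "stride"))  # False

# Hypothetical stopgap, not an Intel Neural Compressor or transformers API:
# copy the plain tuple back from the original layer so the lookup succeeds.
traced.stride = conv.stride
print(traced.stride)  # (2,)
```

In the actual run, the stopgap would mean copying `stride` from the FP32 model's `model.encoder.conv1` and `model.encoder.conv2` onto the quantized model before `eval_func` calls `generate()`. This has not been verified to be sufficient, and it does not address the underlying conversion. Treat it as a diagnostic aid rather than a fix.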
