diff --git a/torchvision/prototype/models/quantization/mobilenetv3.py b/torchvision/prototype/models/quantization/mobilenetv3.py index 24355b08779..fc9da01940e 100644 --- a/torchvision/prototype/models/quantization/mobilenetv3.py +++ b/torchvision/prototype/models/quantization/mobilenetv3.py @@ -43,14 +43,14 @@ def _mobilenet_v3_model( if quantize: model.fuse_model() - model.qconfig = torch.quantization.get_default_qat_qconfig(backend) - torch.quantization.prepare_qat(model, inplace=True) + model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend) + torch.ao.quantization.prepare_qat(model, inplace=True) if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) if quantize: - torch.quantization.convert(model, inplace=True) + torch.ao.quantization.convert(model, inplace=True) model.eval() return model