diff --git a/test.py b/test.py index a326c31f30..97e9caaf86 100644 --- a/test.py +++ b/test.py @@ -120,6 +120,10 @@ def _load_tests(): devices.append('cuda') for path in _list_model_paths(): + # TODO: skipping quantized tests for now due to BC-breaking changes for prepare + # api, enable after PyTorch 1.13 release + if "quantized" in path: + continue for device in devices: _load_test(path, device) diff --git a/test_bench.py b/test_bench.py index 646f213d37..38fefb6cc0 100644 --- a/test_bench.py +++ b/test_bench.py @@ -51,6 +51,10 @@ def test_train(self, model_path, device, compiler, benchmark): if skip_by_metadata(test="train", device=device, jit=(compiler == 'jit'), \ extra_args=[], metadata=get_metadata_from_yaml(model_path)): raise NotImplementedError("Test skipped by its metadata.") + # TODO: skipping quantized tests for now due to BC-breaking changes for prepare + # api, enable after PyTorch 1.13 release + if "quantized" in model_path: + return task = ModelTask(model_path) if not task.model_details.exists: return # Model is not supported. @@ -67,6 +71,10 @@ def test_eval(self, model_path, device, compiler, benchmark, pytestconfig): if skip_by_metadata(test="eval", device=device, jit=(compiler == 'jit'), \ extra_args=[], metadata=get_metadata_from_yaml(model_path)): raise NotImplementedError("Test skipped by its metadata.") + # TODO: skipping quantized tests for now due to BC-breaking changes for prepare + # api, enable after PyTorch 1.13 release + if "quantized" in model_path: + return task = ModelTask(model_path) if not task.model_details.exists: return # Model is not supported. diff --git a/torchbenchmark/models/mobilenet_v2_quantized_qat/__init__.py b/torchbenchmark/models/mobilenet_v2_quantized_qat/__init__.py index d3637a9bf5..21f2cc96ab 100644 --- a/torchbenchmark/models/mobilenet_v2_quantized_qat/__init__.py +++ b/torchbenchmark/models/mobilenet_v2_quantized_qat/__init__.py @@ -32,7 +32,7 @@ def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]): def prep_qat_train(self): qconfig_dict = {"": torch.quantization.get_default_qat_qconfig('fbgemm')} self.model.train() - self.model = quantize_fx.prepare_qat_fx(self.model, qconfig_dict) + self.model = quantize_fx.prepare_qat_fx(self.model, qconfig_dict, self.example_inputs) def train(self, niter=3): optimizer = optim.Adam(self.model.parameters()) diff --git a/torchbenchmark/models/resnet50_quantized_qat/__init__.py b/torchbenchmark/models/resnet50_quantized_qat/__init__.py index 2b31e42b55..381faf281a 100644 --- a/torchbenchmark/models/resnet50_quantized_qat/__init__.py +++ b/torchbenchmark/models/resnet50_quantized_qat/__init__.py @@ -32,7 +32,7 @@ def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]): def prep_qat_train(self): qconfig_dict = {"": torch.quantization.get_default_qat_qconfig('fbgemm')} self.model.train() - self.model = quantize_fx.prepare_qat_fx(self.model, qconfig_dict) + self.model = quantize_fx.prepare_qat_fx(self.model, qconfig_dict, self.example_inputs) def get_module(self): return self.model, self.example_inputs