diff --git a/torchbenchmark/models/mobilenet_v2/__init__.py b/torchbenchmark/models/mobilenet_v2/__init__.py index d7ffbec5b1..1ffa6f6422 100644 --- a/torchbenchmark/models/mobilenet_v2/__init__.py +++ b/torchbenchmark/models/mobilenet_v2/__init__.py @@ -14,6 +14,7 @@ ####################################################### class Model(BenchmarkModel): task = COMPUTER_VISION.CLASSIFICATION + def __init__(self, device=None, jit=False): super().__init__() self.device = device diff --git a/torchbenchmark/models/mobilenet_v2_quantized_qat/__init__.py b/torchbenchmark/models/mobilenet_v2_quantized_qat/__init__.py index 72d7075090..ef6d90d7de 100644 --- a/torchbenchmark/models/mobilenet_v2_quantized_qat/__init__.py +++ b/torchbenchmark/models/mobilenet_v2_quantized_qat/__init__.py @@ -3,10 +3,13 @@ import torch.optim as optim import torchvision.models as models from torch.quantization import quantize_fx +from torchbenchmark.tasks import COMPUTER_VISION from ...util.model import BenchmarkModel class Model(BenchmarkModel): + task = COMPUTER_VISION.CLASSIFICATION + def __init__(self, device=None, jit=False): super().__init__() self.device = device diff --git a/torchbenchmark/models/pyhpc_equation_of_state/__init__.py b/torchbenchmark/models/pyhpc_equation_of_state/__init__.py index 4aa4d1c8e3..62573104be 100644 --- a/torchbenchmark/models/pyhpc_equation_of_state/__init__.py +++ b/torchbenchmark/models/pyhpc_equation_of_state/__init__.py @@ -1,5 +1,6 @@ import torch from . import eos_pytorch +from torchbenchmark.tasks import OTHER from ...util.model import BenchmarkModel def _generate_inputs(size): @@ -26,6 +27,8 @@ def forward(self, s, t, p): return eos_pytorch.gsw_dHdT(s, t, p) class Model(BenchmarkModel): + task = OTHER.OTHER_TASKS + def __init__(self, device=None, jit=False): super().__init__() self.device = device diff --git a/torchbenchmark/models/pyhpc_isoneutral_mixing/__init__.py b/torchbenchmark/models/pyhpc_isoneutral_mixing/__init__.py index 998f73108b..99bacfdde9 100644 --- a/torchbenchmark/models/pyhpc_isoneutral_mixing/__init__.py +++ b/torchbenchmark/models/pyhpc_isoneutral_mixing/__init__.py @@ -1,5 +1,6 @@ import torch from . import isoneutral_pytorch +from torchbenchmark.tasks import OTHER from ...util.model import BenchmarkModel @@ -122,6 +123,8 @@ def forward( class Model(BenchmarkModel): + task = OTHER.OTHER_TASKS + def __init__(self, device=None, jit=False): super().__init__() self.device = device