diff --git a/test/smoke_test.py b/test/smoke_test.py index 9ffc9117773..35e079c31d1 100644 --- a/test/smoke_test.py +++ b/test/smoke_test.py @@ -27,6 +27,7 @@ def smoke_test_torchvision_read_decode() -> None: if img_png.ndim != 3 or img_png.numel() < 100: raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}") + def smoke_test_compile() -> None: model = resnet50().cuda() model = torch.compile(model) @@ -34,6 +35,7 @@ def smoke_test_compile() -> None: out = model(x) print(f"torch.compile model output: {out.shape}") + def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None: img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device) @@ -73,6 +75,5 @@ def main() -> None: smoke_test_torchvision_resnet50_classify("mps") - if __name__ == "__main__": main()