From 3aed456541ea553852acb7d348bea7dcda6491be Mon Sep 17 00:00:00 2001 From: atalman Date: Fri, 3 Mar 2023 09:51:39 -0800 Subject: [PATCH 1/3] Add guard to execute compile only on linux --- test/smoke_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/smoke_test.py b/test/smoke_test.py index 35e079c31d1..860f2c01ce0 100644 --- a/test/smoke_test.py +++ b/test/smoke_test.py @@ -2,6 +2,7 @@ import os from pathlib import Path +from sys import platform import torch import torch.nn as nn @@ -69,7 +70,8 @@ def main() -> None: smoke_test_torchvision_resnet50_classify() if torch.cuda.is_available(): smoke_test_torchvision_resnet50_classify("cuda") - smoke_test_compile() + if platform == "linux" or platform == "linux2": + smoke_test_compile() if torch.backends.mps.is_available(): smoke_test_torchvision_resnet50_classify("mps") From 38009f7a2b71d174b5d932a09896790613958ab7 Mon Sep 17 00:00:00 2001 From: atalman Date: Fri, 3 Mar 2023 10:03:33 -0800 Subject: [PATCH 2/3] Fix smoke test to check for exception --- test/smoke_test.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/test/smoke_test.py b/test/smoke_test.py index 860f2c01ce0..3662f1f4936 100644 --- a/test/smoke_test.py +++ b/test/smoke_test.py @@ -30,12 +30,17 @@ def smoke_test_torchvision_read_decode() -> None: def smoke_test_compile() -> None: - model = resnet50().cuda() - model = torch.compile(model) - x = torch.randn(1, 3, 224, 224, device="cuda") - out = model(x) - print(f"torch.compile model output: {out.shape}") - + try: + model = resnet50().cuda() + model = torch.compile(model) + x = torch.randn(1, 3, 224, 224, device="cuda") + out = model(x) + print(f"torch.compile model output: {out.shape}") + except RuntimeError: + if platform == "win32": + print(f"Successfully caught torch.compile RuntimeError on win") + else: + raise def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None: img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device) @@ -70,8 +75,7 @@ def main() -> None: smoke_test_torchvision_resnet50_classify() if torch.cuda.is_available(): smoke_test_torchvision_resnet50_classify("cuda") - if platform == "linux" or platform == "linux2": - smoke_test_compile() + smoke_test_compile() if torch.backends.mps.is_available(): smoke_test_torchvision_resnet50_classify("mps") From 9c6e994b280ff4b0997bb2500afca93414370a41 Mon Sep 17 00:00:00 2001 From: atalman Date: Fri, 3 Mar 2023 10:18:22 -0800 Subject: [PATCH 3/3] fix lint --- test/smoke_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/smoke_test.py b/test/smoke_test.py index 3662f1f4936..63b35d04bed 100644 --- a/test/smoke_test.py +++ b/test/smoke_test.py @@ -38,10 +38,11 @@ def smoke_test_compile() -> None: print(f"torch.compile model output: {out.shape}") except RuntimeError: if platform == "win32": - print(f"Successfully caught torch.compile RuntimeError on win") + print("Successfully caught torch.compile RuntimeError on win") else: raise + def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None: img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device)