diff --git a/test/smoke_test.py b/test/smoke_test.py index 35e079c31d1..63b35d04bed 100644 --- a/test/smoke_test.py +++ b/test/smoke_test.py @@ -2,6 +2,7 @@ import os from pathlib import Path +from sys import platform import torch import torch.nn as nn @@ -29,11 +30,17 @@ def smoke_test_torchvision_read_decode() -> None: def smoke_test_compile() -> None: - model = resnet50().cuda() - model = torch.compile(model) - x = torch.randn(1, 3, 224, 224, device="cuda") - out = model(x) - print(f"torch.compile model output: {out.shape}") + try: + model = resnet50().cuda() + model = torch.compile(model) + x = torch.randn(1, 3, 224, 224, device="cuda") + out = model(x) + print(f"torch.compile model output: {out.shape}") + except RuntimeError: + if platform == "win32": + print("Successfully caught torch.compile RuntimeError on win") + else: + raise def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None: