
Commit 1330440

shangerxin authored and guangyey committed

Port 4 dynamo test files for the Intel XPU (#160953)
# Description

Fixes #114850: port Dynamo test files to Intel GPU. The tests are enabled for Intel GPU with the following methods, keeping the original code style as much as possible (a minimal sketch of the resulting pattern follows below):

# Changes

1. Get the device type from the accelerator method.
2. Replace the CUDA-only `requires_cuda_and_triton` decorators with `requires_gpu`.
3. Bring `HAS_XPU_AND_TRITON` into scope.
4. Add several wrapper methods from the `cuda` module to the accelerator module.

Pull Request resolved: #160953
Approved by: https://github.com/EikanWang, https://github.com/guangyey, https://github.com/jansel

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
1 parent 8e48d1b commit 1330440
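The core of the port is the device-agnostic detection pattern repeated in the diffs below. A minimal sketch of that pattern (the `device_type` expression is taken verbatim from the ported files; the surrounding usage is illustrative):

```python
import torch

# Resolve the active accelerator type ("cuda", "xpu", "hpu", ...) once at
# import time. current_accelerator(True) also checks availability and
# returns None when no accelerator is usable, so we fall back to "cpu".
device_type = (
    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)

# Tests then target the resolved device instead of a hard-coded "cuda":
x = torch.randn(10, 10, device=device_type)
print(x.device)  # e.g. "xpu:0" on an Intel GPU machine, "cpu" otherwise
```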

File tree

4 files changed: +30 −13 lines changed

- test/dynamo/test_backends.py
- test/dynamo/test_callback.py
- test/dynamo/test_functions.py
- test/dynamo/test_misc.py

test/dynamo/test_backends.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -386,7 +386,7 @@ def forward(self, x):
         self.assertTrue(backend_run)


-devices = ["cpu", "cuda", "hpu"]
+devices = ["cpu", "cuda", "hpu", "xpu"]
 instantiate_device_type_tests(TestOptimizations, globals(), only_for=devices)

 if __name__ == "__main__":
```

test/dynamo/test_callback.py

Lines changed: 13 additions & 4 deletions

```diff
@@ -8,7 +8,12 @@
 from torch._dynamo.test_case import run_tests, TestCase
 from torch._guards import CompileId
 from torch.testing._internal.common_utils import TEST_WITH_ROCM
-from torch.testing._internal.triton_utils import requires_cuda_and_triton
+from torch.testing._internal.triton_utils import HAS_CUDA_AND_TRITON, requires_gpu
+
+
+device_type = (
+    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
+)


 class CallbackTests(TestCase):
@@ -61,7 +66,7 @@ def test_counter_assertion(self) -> None:
     @unittest.skipIf(
         TEST_WITH_ROCM, "ROCm outputs a different number of autotuning logs"
     )
-    @requires_cuda_and_triton
+    @requires_gpu
     @torch._inductor.config.patch(force_disable_caches=True)
     def test_triggers(self) -> None:
         torch._dynamo.reset()
@@ -91,9 +96,9 @@ def forward(self, x):
                 torch._dynamo.graph_break()
                 return self.fc2(temp)

-        model = TinyModel().to("cuda")
+        model = TinyModel().to(device_type)
         compiled_model = torch.compile(model, mode="max-autotune")
-        x = torch.randn(10, 10, device="cuda")
+        x = torch.randn(10, 10, device=device_type)

         loss = compiled_model(x).sum()
         loss.backward()
@@ -111,9 +116,13 @@ def forward(self, x):
         )
         order.clear()

+        if not HAS_CUDA_AND_TRITON:
+            return
+
         compiled_model.zero_grad()
         loss = compiled_model(x).sum()
         loss.backward()
+
         self.assertExpectedInline(
             "\n".join(order),
             """\
```

test/dynamo/test_functions.py

Lines changed: 15 additions & 7 deletions

```diff
@@ -40,11 +40,16 @@
     instantiate_parametrized_tests,
     parametrize,
 )
+from torch.testing._internal.inductor_utils import HAS_GPU

 # Defines all the kernels for tests
 from torch.testing._internal.triton_utils import *  # noqa: F403


+device_type = (
+    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
+)
+
 T = TypeVar("T")

 d = torch.ones(10, 10)
@@ -1150,10 +1155,10 @@ def test_tensor_type(a, b):
         m = a.to(torch.float16)
         return b.type(m.type())

-    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
+    @unittest.skipIf(not HAS_GPU, "requires gpu")
     @make_test
     def test_tensor_type2(a, b):
-        m = a.to("cuda")
+        m = a.to(device_type)
         return m + b.type(m.type())

     @make_test
@@ -4040,7 +4045,7 @@ def test_torch_get_device_module(self):
         def f1():
             mod1 = torch.get_device_module()
             mod2 = torch.get_device_module("cpu")
-            mod3 = torch.get_device_module(torch.device("cuda"))
+            mod3 = torch.get_device_module(torch.device(device_type))
             return mod1, mod2, mod3

         self.assertEqual(f1(), torch.compile(f1, backend="eager", fullgraph=True)())
@@ -4075,6 +4080,7 @@ def f5():
         new_device = (
             "cpu" if torch._C._get_accelerator() == torch.device("cuda") else "cuda"
         )
+
         old_get_device_module = torch.get_device_module

         def new_get_device_module(device=None):
@@ -4721,22 +4727,24 @@ def fn(x, ys, zs):
         opt_fn(x, ys, zs[:1])

     @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
-    def test_cuda_current_device(self):
+    def test_gpu_current_device(self):
         def fn(x):
             y = torch.empty(
-                (2, 3), dtype=torch.float32, device=torch.cuda.current_device()
+                (2, 3),
+                dtype=torch.float32,
+                device=torch.accelerator.current_device_index(),
             )
             y.copy_(x)
             return torch.sin(y + y.device.index)

         counter = torch._dynamo.testing.CompileCounter()
         opt_fn = torch.compile(backend=counter, fullgraph=True)(fn)

-        with torch.cuda.device(0):
+        with torch.accelerator.device_index(0):
             x = torch.randn(2, 3)
             self.assertEqual(opt_fn(x), fn(x))
             self.assertEqual(counter.frame_count, 1)
-        with torch.cuda.device(1):
+        with torch.accelerator.device_index(1):
             self.assertEqual(opt_fn(x), fn(x))
             self.assertEqual(counter.frame_count, 2)

```
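The last hunk swaps the CUDA-specific device controls for their generic `torch.accelerator` counterparts: `current_device_index()` replaces `torch.cuda.current_device()`, and the `device_index(...)` context manager replaces `torch.cuda.device(...)`. A minimal usage sketch (the availability guard is illustrative, not part of the diff):

```python
import torch

if torch.accelerator.is_available():
    # Index of the currently selected device on the active backend,
    # analogous to torch.cuda.current_device() but backend-agnostic.
    idx = torch.accelerator.current_device_index()

    # Temporarily make device 0 current, analogous to torch.cuda.device(0).
    with torch.accelerator.device_index(0):
        y = torch.empty((2, 3), device=torch.accelerator.current_device_index())

    print(idx, y.device)  # e.g. "0 xpu:0" on an Intel GPU machine
```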

test/dynamo/test_misc.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -13293,7 +13293,7 @@ def f(rank):
         self.assertEqual(out, opt_out)

     @unittest.skipIf(not TEST_MULTIGPU, "need multiple GPU")
-    def test_cuda_set_device(self, device):
+    def test_gpu_set_device(self, device):
         def fn():
             a = torch.ones(2, device=device)
             torch.get_device_module(device).set_device(1)
```