Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/xpu_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
test:
# Don't run on forked repos or empty test matrix
# if: github.repository_owner == 'pytorch' && toJSON(fromJSON(inputs.test-matrix).include) != '[]'
timeout-minutes: 60
timeout-minutes: 120
runs-on: linux.idc.xpu
env:
DOCKER_IMAGE: ci-image:pytorch-linux-noble-xpu-n-py3
Expand Down Expand Up @@ -166,7 +166,7 @@ jobs:
GITHUB_RUN_NUMBER: ${{ github.run_number }}
GITHUB_RUN_ATTEMPT: ${{ github.run_attempt }}
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
timeout-minutes: 60
timeout-minutes: 120
run: |
set -x

Expand Down
16 changes: 9 additions & 7 deletions test/quantization/pt2e/test_quantize_pt2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from torch.testing._internal.common_utils import (
TEST_CUDA,
TEST_XPU,
TemporaryFileName,
instantiate_parametrized_tests,
parametrize,
Expand Down Expand Up @@ -68,9 +69,10 @@
QuantizationConfig,
)
from torchao.testing.pt2e.utils import PT2EQuantizationTestCase
from torchao.utils import torch_version_at_least
from torchao.utils import get_current_accelerator_device, torch_version_at_least

DEVICE_LIST = ["cpu"] + (["cuda"] if TEST_CUDA else [])
DEVICE_LIST = ["cpu"] + (["cuda"] if TEST_CUDA else []) + (["xpu"] if TEST_XPU else [])
_DEVICE = get_current_accelerator_device()

if torch_version_at_least("2.7.0"):
from torch.testing._internal.common_utils import (
Expand Down Expand Up @@ -2057,7 +2059,7 @@ def __init__(self) -> None:
def forward(self, x):
return self.bn(x)

if TEST_CUDA or TEST_HPU:
if TEST_CUDA or TEST_HPU or TEST_XPU:
m = M().train().to(device)
example_inputs = (torch.randn((1, 3, 3, 3), device=device),)

Expand Down Expand Up @@ -2132,9 +2134,9 @@ def forward(self, x):
x = self.dropout(x)
return x

if TEST_CUDA:
m = M().train().cuda()
example_inputs = (torch.randn(1, 3, 3, 3).cuda(),)
if TEST_CUDA or TEST_XPU:
m = M().train().to(_DEVICE)
example_inputs = (torch.randn(1, 3, 3, 3).to(_DEVICE),)
else:
m = M().train()
example_inputs = (torch.randn(1, 3, 3, 3),)
Expand All @@ -2146,7 +2148,7 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool):
bn_op = bn_train_op if train else bn_eval_op
bn_node = self._get_node(m, bn_op)
self.assertTrue(bn_node is not None)
if TEST_CUDA:
if TEST_CUDA or TEST_XPU:
self.assertEqual(bn_node.args[5], train)
dropout_node = self._get_node(m, torch.ops.aten.dropout.default)
self.assertEqual(dropout_node.args[2], train)
Expand Down
18 changes: 10 additions & 8 deletions test/quantization/pt2e/test_quantize_pt2e_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
skipIfNoQNNPACK,
)
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.common_utils import TEST_XPU, run_tests

from torchao.quantization.pt2e import (
FusedMovingAvgObsFakeQuantize,
Expand All @@ -52,7 +52,9 @@
XNNPACKQuantizer,
get_symmetric_quantization_config,
)
from torchao.utils import torch_version_at_least
from torchao.utils import get_current_accelerator_device, torch_version_at_least

_DEVICE = get_current_accelerator_device()


class PT2EQATTestCase(QuantizationTestCase):
Expand Down Expand Up @@ -453,10 +455,10 @@ def test_qat_conv_bn_fusion(self):
self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=False)
self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@unittest.skipIf(not TEST_CUDA and not TEST_XPU, "GPU unavailable")
def test_qat_conv_bn_fusion_cuda(self):
m = self._get_conv_bn_model().cuda()
example_inputs = (self.example_inputs[0].cuda(),)
m = self._get_conv_bn_model().to(_DEVICE)
example_inputs = (self.example_inputs[0].to(_DEVICE),)
self._verify_symmetric_xnnpack_qat_graph(
m,
example_inputs,
Expand Down Expand Up @@ -540,10 +542,10 @@ def test_qat_conv_bn_relu_fusion(self):
self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=True)
self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@unittest.skipIf(not TEST_CUDA and not TEST_XPU, "GPU unavailable")
def test_qat_conv_bn_relu_fusion_cuda(self):
m = self._get_conv_bn_model(has_relu=True).cuda()
example_inputs = (self.example_inputs[0].cuda(),)
m = self._get_conv_bn_model(has_relu=True).to(_DEVICE)
example_inputs = (self.example_inputs[0].to(_DEVICE),)
self._verify_symmetric_xnnpack_qat_graph(
m,
example_inputs,
Expand Down
Loading