Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions torch_patches/X10-device_test.diff
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py
index a23696c29c..d7eb93373b 100644
index 01973711e7..48e5e202ed 100644
--- a/torch/testing/_internal/common_device_type.py
+++ b/torch/testing/_internal/common_device_type.py
@@ -3,6 +3,7 @@ import threading
from functools import wraps
@@ -4,6 +4,7 @@ from functools import wraps
import unittest
import os
import torch
+import copy
from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM, TEST_MKL, \
skipCUDANonDefaultStreamIf

@@ -191,7 +192,7 @@ class DeviceTypeTestBase(TestCase):
@@ -219,7 +220,7 @@ class DeviceTypeTestBase(TestCase):
# Sets precision and runs test
# Note: precision is reset after the test is run
guard_precision = self.precision
Expand All @@ -19,10 +19,10 @@ index a23696c29c..d7eb93373b 100644
self.precision = self._get_precision_override(test, dtype)
result = test(self, device_arg, dtype)
finally:
@@ -242,10 +243,103 @@ class CUDATestBase(DeviceTypeTestBase):
@@ -270,10 +271,103 @@ class CUDATestBase(DeviceTypeTestBase):
cls.primary_device = 'cuda:{0}'.format(torch.cuda.current_device())


+import torch_xla
+import torch_xla.core.xla_model as xm
+
Expand Down Expand Up @@ -107,28 +107,28 @@ index a23696c29c..d7eb93373b 100644
+ torch_xla._XLAC._xla_set_use_full_mat_mul_precision(use_full_mat_mul_precision=True)
+
+ # Overrides assertEqual to popular custom precision
+ def assertEqual(self, x, y, prec=None, message='', allow_inf=False):
+ def assertEqual(self, x, y, prec=None, message='', allow_inf=False, **kwargs):
+ if prec is None:
+ prec = self.precision
+ else:
+ prec = max(self.precision, prec)
+ return DeviceTypeTestBase.assertEqual(self, x, y, prec, message, allow_inf)
+ return DeviceTypeTestBase.assertEqual(self, x, y, prec, message, allow_inf, **kwargs)
+
+
# Adds available device-type-specific test base classes
device_type_test_bases.append(CPUTestBase)
if torch.cuda.is_available():
device_type_test_bases.append(CUDATestBase)
+device_type_test_bases.append(XLATestBase)


# Adds 'instantiated' device-specific test cases to the given scope.
@@ -289,7 +383,7 @@ def instantiate_device_type_tests(generic_test_class, scope, except_for=None):
PYTORCH_CUDA_MEMCHECK = os.getenv('PYTORCH_CUDA_MEMCHECK', '0') == '1'

@@ -320,7 +414,7 @@ def instantiate_device_type_tests(generic_test_class, scope, except_for=None):
assert inspect.isfunction(test), "Couldn't extract function from '{0}'".format(name)

# Instantiates the device-specific tests
- device_type_test_class.instantiate_test(name, test)
+ device_type_test_class.instantiate_test(name, copy.deepcopy(test))
else: # Ports non-test member
assert not hasattr(device_type_test_class, name), "Redefinition of non-test member {0}".format(name)

assert name not in device_type_test_class.__dict__, "Redefinition of directly defined member {0}".format(name)