diff --git a/mypy.ini b/mypy.ini
index 006b36276727..5307c6a29f09 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -128,9 +128,6 @@ ignore_errors = True
 [mypy-torch.nn.quantized.modules.conv]
 ignore_errors = True
 
-[mypy-torch.cuda]
-ignore_errors = True
-
 [mypy-torch._lobpcg]
 ignore_errors = True
 
@@ -140,9 +137,6 @@ ignore_errors = True
 [mypy-torch._utils]
 ignore_errors = True
 
-[mypy-torch._overrides]
-ignore_errors = True
-
 [mypy-torch.utils.tensorboard._caffe2_graph]
 ignore_errors = True
 
diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py
index 52499efe2eca..8e1cfe58125e 100644
--- a/tools/pyi/gen_pyi.py
+++ b/tools/pyi/gen_pyi.py
@@ -667,11 +667,14 @@ def gen_pyi(declarations_path, out):
 
     # TODO: These are deprecated, maybe we shouldn't type hint them
     legacy_storage_base_hints = []
-    for c in ('Double', 'Float', 'Long', 'Int',
-              'Short', 'Char', 'Byte', 'Bool',
-              'Half', 'BFloat16', 'ComplexDouble',
-              'ComplexFloat', 'QUInt8', 'QInt8', 'QInt32', 'QUInt4x2'):
+    dt = ('Double', 'Float', 'Long', 'Int',
+          'Short', 'Char', 'Byte', 'Bool',
+          'Half', 'BFloat16', 'ComplexDouble',
+          'ComplexFloat', 'QUInt8', 'QInt8', 'QInt32', 'QUInt4x2')
+    for c in dt:
         legacy_storage_base_hints.append('class {}StorageBase(object): ...'.format(c))
+    for c in dt:
+        legacy_storage_base_hints.append('class Cuda{}StorageBase(object): ...'.format(c))
 
     legacy_class_hints = []
     for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index b05d588972f9..c3dd4af34493 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -561,6 +561,13 @@ def _cuda_getCurrentStream(device: _int) -> _int: ...
 def _cuda_getDefaultStream(device: _int) -> _int: ...
 def _cuda_getCurrentBlasHandle() -> _int: ...
 def _cuda_setDevice(device: _int) -> None: ...
+def _cuda_getDevice() -> _int: ...
+def _cuda_getDeviceCount() -> _int: ...
+def _cuda_sleep(cycles: _int) -> None: ...
+def _cuda_synchronize() -> None: ...
+def _cuda_ipc_collect() -> None: ...
+def _cuda_getArchFlags() -> Optional[str]: ...
+def _cuda_init() -> None: ...
 def _cuda_setStream(cuda_stream: _int) -> None: ...
 def _cuda_getCompiledVersion() -> _int: ...
 def _cuda_cudaHostAllocator() -> _int: ...
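For context, a quick sketch of what the reworked gen_pyi.py loops above emit; this is a standalone repro, with the dtype tuple abbreviated for brevity:

    # Mirrors the two loops in the gen_pyi.py hunk; `dt` is truncated here.
    dt = ('Double', 'Float')
    legacy_storage_base_hints = []
    for c in dt:
        legacy_storage_base_hints.append('class {}StorageBase(object): ...'.format(c))
    for c in dt:
        legacy_storage_base_hints.append('class Cuda{}StorageBase(object): ...'.format(c))
    print('\n'.join(legacy_storage_base_hints))
    # class DoubleStorageBase(object): ...
    # class FloatStorageBase(object): ...
    # class CudaDoubleStorageBase(object): ...
    # class CudaFloatStorageBase(object): ...

The second loop is what adds stubs for the Cuda-prefixed legacy storage bases, so the annotated torch/cuda module below can be checked without `ignore_errors`.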
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index 1176c6ee3060..0850b535fe30 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -21,7 +21,7 @@
 import torch._C
 
 try:
-    from torch._C import _cudart
+    from torch._C import _cudart  # type: ignore
 except ImportError:
     _cudart = None
 
@@ -30,18 +30,18 @@
 _initialization_lock = threading.Lock()
 _queued_calls = []  # don't invoke these until initialization occurs
 _is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False)
-_device_t = Union[_device, str, int]
+_device_t = Union[_device, str, int, None]
 
 # Define dummy _CudaDeviceProperties type if PyTorch was compiled without CUDA
 if hasattr(torch._C, '_CudaDeviceProperties'):
     _CudaDeviceProperties = torch._C._CudaDeviceProperties
 else:
-    _CudaDeviceProperties = _dummy_type('_CudaDeviceProperties')
+    _CudaDeviceProperties = _dummy_type('_CudaDeviceProperties')  # type: ignore
 
 # Global variables dynamically populated by native code
 has_magma: bool = False
 has_half: bool = False
-default_generators: Tuple[torch._C.Generator] = ()
+default_generators: Tuple[torch._C.Generator] = ()  # type: ignore[assignment]
 
 def is_available() -> bool:
     r"""Returns a bool indicating if CUDA is currently available."""
@@ -297,7 +297,7 @@ def get_device_properties(device: _device_t) -> _CudaDeviceProperties:
     device = _get_device_index(device, optional=True)
     if device < 0 or device >= device_count():
         raise AssertionError("Invalid device id")
-    return _get_device_properties(device)
+    return _get_device_properties(device)  # type: ignore[name-defined]
 
 
 @contextlib.contextmanager
@@ -356,8 +356,8 @@ def get_gencode_flags() -> str:
     arch_list = get_arch_list()
     if len(arch_list) == 0:
         return ""
-    arch_list = [arch.split("_") for arch in arch_list]
-    return " ".join([f"-gencode compute=compute_{arch},code={kind}_{arch}" for (kind, arch) in arch_list])
+    arch_list_ = [arch.split("_") for arch in arch_list]
+    return " ".join([f"-gencode compute=compute_{arch},code={kind}_{arch}" for (kind, arch) in arch_list_])
 
 
@@ -454,7 +454,7 @@ def current_blas_handle():
     torch._C.__dict__['_CudaEventBase'] = _dummy_type('CudaEventBase')
 
 
-@staticmethod
+@staticmethod  # type: ignore[misc]
 def _lazy_new(cls, *args, **kwargs):
     _lazy_init()
     # We may need to call lazy init again if we are a forked child
@@ -467,8 +467,11 @@ class _CudaBase(object):
     is_sparse = False
 
     def type(self, *args, **kwargs):
-        with device(self.get_device()):
-            return super(_CudaBase, self).type(*args, **kwargs)
+        # We could use a Protocol here to tell mypy that self has `get_device` method
+        # but it is only available in the typing module on Python >= 3.8
+        # or on typing_extensions module on Python >= 3.6
+        with device(self.get_device()):  # type: ignore
+            return super(_CudaBase, self).type(*args, **kwargs)  # type: ignore[misc]
 
     __new__ = _lazy_new
diff --git a/torch/cuda/_utils.py b/torch/cuda/_utils.py
index 1239dff4588f..8f4105623a98 100644
--- a/torch/cuda/_utils.py
+++ b/torch/cuda/_utils.py
@@ -5,7 +5,7 @@
 from torch._utils import _get_device_index as _torch_get_device_index
 
 
-def _get_device_index(device: Union[Device, str, int], optional: bool = False,
+def _get_device_index(device: Union[Device, str, int, None], optional: bool = False,
                       allow_cpu: bool = False) -> int:
     r"""Gets the device index from :attr:`device`, which can be a
     torch.device object, a Python integer, or ``None``.
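The comment added in `_CudaBase.type` alludes to a Protocol-based alternative to the `# type: ignore`. A minimal sketch of that approach, assuming Python >= 3.8 (on 3.6/3.7, import Protocol from typing_extensions instead); the class and protocol names here are hypothetical and not part of the patch:

    from typing import Any, Protocol

    class _SupportsGetDevice(Protocol):
        # Structural type: anything exposing get_device() -> int matches.
        def get_device(self) -> int: ...

    class _CudaBaseSketch(object):
        def type(self: _SupportsGetDevice, *args: Any, **kwargs: Any) -> Any:
            # Annotating `self` with the protocol lets mypy see get_device(),
            # so `with device(self.get_device()):` would not need an ignore.
            return self.get_device()

mypy verifies the protocol at call sites, so only the concrete CUDA storage and tensor classes that mix in `_CudaBase` need to actually provide `get_device()`.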