Skip to content

Commit

Permalink
Add type informations to torch.cuda (#47134)
Browse files Browse the repository at this point in the history
Summary:
Fixes #47133

Pull Request resolved: #47134

Reviewed By: smessmer

Differential Revision: D24955031

Pulled By: ezyang

fbshipit-source-id: 87f4623643715baa6ac0627383f009956f80cd46
  • Loading branch information
guilhermeleobas authored and facebook-github-bot committed Nov 14, 2020
1 parent 2eb1e86 commit 4f9d075
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 21 deletions.
6 changes: 0 additions & 6 deletions mypy.ini
Expand Up @@ -128,9 +128,6 @@ ignore_errors = True
[mypy-torch.nn.quantized.modules.conv]
ignore_errors = True

[mypy-torch.cuda]
ignore_errors = True

[mypy-torch._lobpcg]
ignore_errors = True

Expand All @@ -140,9 +137,6 @@ ignore_errors = True
[mypy-torch._utils]
ignore_errors = True

[mypy-torch._overrides]
ignore_errors = True

[mypy-torch.utils.tensorboard._caffe2_graph]
ignore_errors = True

Expand Down
11 changes: 7 additions & 4 deletions tools/pyi/gen_pyi.py
Expand Up @@ -667,11 +667,14 @@ def gen_pyi(declarations_path, out):

# TODO: These are deprecated, maybe we shouldn't type hint them
legacy_storage_base_hints = []
for c in ('Double', 'Float', 'Long', 'Int',
'Short', 'Char', 'Byte', 'Bool',
'Half', 'BFloat16', 'ComplexDouble',
'ComplexFloat', 'QUInt8', 'QInt8', 'QInt32', 'QUInt4x2'):
dt = ('Double', 'Float', 'Long', 'Int',
'Short', 'Char', 'Byte', 'Bool',
'Half', 'BFloat16', 'ComplexDouble',
'ComplexFloat', 'QUInt8', 'QInt8', 'QInt32', 'QUInt4x2')
for c in dt:
legacy_storage_base_hints.append('class {}StorageBase(object): ...'.format(c))
for c in dt:
legacy_storage_base_hints.append('class Cuda{}StorageBase(object): ...'.format(c))

legacy_class_hints = []
for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
Expand Down
7 changes: 7 additions & 0 deletions torch/_C/__init__.pyi.in
Expand Up @@ -561,6 +561,13 @@ def _cuda_getCurrentStream(device: _int) -> _int: ...
def _cuda_getDefaultStream(device: _int) -> _int: ...
def _cuda_getCurrentBlasHandle() -> _int: ...
def _cuda_setDevice(device: _int) -> None: ...
def _cuda_getDevice() -> _int: ...
def _cuda_getDeviceCount() -> _int: ...
def _cuda_sleep(cycles: _int) -> None: ...
def _cuda_synchronize() -> None: ...
def _cuda_ipc_collect() -> None: ...
def _cuda_getArchFlags() -> Optional[str]: ...
def _cuda_init() -> None: ...
def _cuda_setStream(cuda_stream: _int) -> None: ...
def _cuda_getCompiledVersion() -> _int: ...
def _cuda_cudaHostAllocator() -> _int: ...
Expand Down
23 changes: 13 additions & 10 deletions torch/cuda/__init__.py
Expand Up @@ -21,7 +21,7 @@
import torch._C

try:
from torch._C import _cudart
from torch._C import _cudart # type: ignore
except ImportError:
_cudart = None

Expand All @@ -30,18 +30,18 @@
_initialization_lock = threading.Lock()
_queued_calls = [] # don't invoke these until initialization occurs
_is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False)
_device_t = Union[_device, str, int]
_device_t = Union[_device, str, int, None]

# Define dummy _CudaDeviceProperties type if PyTorch was compiled without CUDA
if hasattr(torch._C, '_CudaDeviceProperties'):
_CudaDeviceProperties = torch._C._CudaDeviceProperties
else:
_CudaDeviceProperties = _dummy_type('_CudaDeviceProperties')
_CudaDeviceProperties = _dummy_type('_CudaDeviceProperties') # type: ignore

# Global variables dynamically populated by native code
has_magma: bool = False
has_half: bool = False
default_generators: Tuple[torch._C.Generator] = ()
default_generators: Tuple[torch._C.Generator] = () # type: ignore[assignment]

def is_available() -> bool:
r"""Returns a bool indicating if CUDA is currently available."""
Expand Down Expand Up @@ -297,7 +297,7 @@ def get_device_properties(device: _device_t) -> _CudaDeviceProperties:
device = _get_device_index(device, optional=True)
if device < 0 or device >= device_count():
raise AssertionError("Invalid device id")
return _get_device_properties(device)
return _get_device_properties(device) # type: ignore[name-defined]


@contextlib.contextmanager
Expand Down Expand Up @@ -356,8 +356,8 @@ def get_gencode_flags() -> str:
arch_list = get_arch_list()
if len(arch_list) == 0:
return ""
arch_list = [arch.split("_") for arch in arch_list]
return " ".join([f"-gencode compute=compute_{arch},code={kind}_{arch}" for (kind, arch) in arch_list])
arch_list_ = [arch.split("_") for arch in arch_list]
return " ".join([f"-gencode compute=compute_{arch},code={kind}_{arch}" for (kind, arch) in arch_list_])



Expand Down Expand Up @@ -454,7 +454,7 @@ def current_blas_handle():
torch._C.__dict__['_CudaEventBase'] = _dummy_type('CudaEventBase')


@staticmethod
@staticmethod # type: ignore[misc]
def _lazy_new(cls, *args, **kwargs):
_lazy_init()
# We may need to call lazy init again if we are a forked child
Expand All @@ -467,8 +467,11 @@ class _CudaBase(object):
is_sparse = False

def type(self, *args, **kwargs):
with device(self.get_device()):
return super(_CudaBase, self).type(*args, **kwargs)
# We could use a Protocol here to tell mypy that self has `get_device` method
# but it is only available in the typing module on Python >= 3.8
# or on typing_extensions module on Python >= 3.6
with device(self.get_device()): # type: ignore
return super(_CudaBase, self).type(*args, **kwargs) # type: ignore[misc]

__new__ = _lazy_new

Expand Down
2 changes: 1 addition & 1 deletion torch/cuda/_utils.py
Expand Up @@ -5,7 +5,7 @@
from torch._utils import _get_device_index as _torch_get_device_index


def _get_device_index(device: Union[Device, str, int], optional: bool = False,
def _get_device_index(device: Union[Device, str, int, None], optional: bool = False,
allow_cpu: bool = False) -> int:
r"""Gets the device index from :attr:`device`, which can be a torch.device
object, a Python integer, or ``None``.
Expand Down

0 comments on commit 4f9d075

Please sign in to comment.