# DataPipe Typing System

The DataPipe typing system is introduced to make the graph of DataPipes more reliable and to provide type inference for users. It lets users choose the level(s) at which types are enforced, at the risk of false positive errors.

In [1]:
from torch.utils.data import IterDataPipe
from typing import Any, Iterator, List, Tuple, TypeVar, Set, Union

T_co = TypeVar('T_co', covariant=True)

In [2]:
# Hide traceback of Error
import functools
ipython = get_ipython()
ipython.showtraceback = functools.partial(ipython.showtraceback, exception_only=True)

## Compile-time
Compile-time typing is currently enabled by default, and its checks run when a DataPipe class is defined. It generates a `type` attribute for each DataPipe. If no type hint is specified, the DataPipe's `type` defaults to `Any`.
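
To make the idea of a definition-time check concrete, here is a minimal standard-library sketch. It is not the PyTorch implementation: the base class `CheckedIterable` and the classes `Good` and `Bad` are hypothetical names used only for illustration, and the sketch assumes the check is done in `__init_subclass__` by inspecting the return annotation of `__iter__`.

```python
import collections.abc
from typing import Iterator, get_origin, get_type_hints


class CheckedIterable:
    """Hypothetical base class for illustration; NOT the PyTorch implementation."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        method = cls.__dict__.get("__iter__")
        if method is None:
            return  # the subclass does not define __iter__ itself
        hint = get_type_hints(method).get("return")
        if hint is None:
            return  # no return annotation, so there is nothing to check
        # get_origin maps both `Iterator` and `Iterator[int]`
        # to collections.abc.Iterator, and returns None for plain classes.
        if get_origin(hint) is not collections.abc.Iterator:
            raise TypeError(
                f"Expected 'Iterator' as the return annotation for `__iter__` "
                f"of {cls.__name__}, but found {hint}"
            )


class Good(CheckedIterable):
    def __iter__(self) -> Iterator[int]:
        return iter(())


# The error is raised while the class body is processed, before any instance exists.
try:
    class Bad(CheckedIterable):
        def __iter__(self) -> str:
            pass
except TypeError as err:
    error_message = str(err)
else:
    error_message = None

print(error_message)
```

Because the check lives in `__init_subclass__`, a badly annotated class fails as soon as it is defined, which mirrors the behavior of the invalid examples in the next section.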

### Invalid Typing
1. Return type hint of `__iter__` is not `Iterator`

In [3]:
class InvalidDP1(IterDataPipe[int]):
    def __iter__(self) -> str:
        pass

TypeError: Expected 'Iterator' as the return annotation for `__iter__` of InvalidDP1, but found str

2. Return type hint of `__iter__` doesn't match the declared type hint

In [4]:
class InvalidDP2(IterDataPipe[Tuple]):
    def __iter__(self) -> Iterator[Tuple[int, str]]:
        pass

TypeError: Unmatched type annotation for InvalidDP2 (typing.Tuple vs typing.Tuple[int, str])

In [5]:
class InvalidDP3(IterDataPipe):
    def __iter__(self) -> Iterator[int]:
        pass

TypeError: Unmatched type annotation for InvalidDP3 (typing.Any vs int)

### Valid Typing
1. Default Typing (Any) with/without return hint for `__iter__`

In [6]:
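# No return annotation on `__iter__`
class DP(IterDataPipe):
    def __iter__(self):
        pass

# Bare `Iterator` as the return annotation
class DP(IterDataPipe):
    def __iter__(self) -> Iterator:
        pass

# `Iterator` parameterized by a TypeVar
class DP(IterDataPipe):
    def __iter__(self) -> Iterator[T_co]:
        pass

print(DP.type)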

typing.Any


2. Matched type hints (including types that are equivalent but not identical)

In [7]:
class DP(IterDataPipe[Tuple[T_co, str]]):
    def __iter__(self) -> Iterator[Tuple[T_co, str]]:
        pass

T = TypeVar('T', int, str)  # treated as equivalent to Union[int, str]
class DP(IterDataPipe[Tuple[T, str]]):
    def __iter__(self) -> Iterator[Tuple[Union[int, str], str]]:
        pass
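
As a rough illustration of why a constrained `TypeVar` and a `Union` can be regarded as equivalent, the standard-library sketch below compares the constraints of `T` with the members of `Union[int, str]`. It only shows that both carry the same set of types; it makes no claim about how the DataPipe typing system performs the comparison internally. The names `constraints` and `union_members` are illustrative.

```python
from typing import TypeVar, Union, get_args

T = TypeVar('T', int, str)

# Types that the TypeVar is constrained to
constraints = set(T.__constraints__)

# Member types of the corresponding Union
union_members = set(get_args(Union[int, str]))

print(constraints == union_members)
```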

### Attribute `type`
A `type` attribute is added to each DataPipe class.

1. Fixed `type`

    If `type` is a fixed type such as `int`, `str`, or `Tuple[int, str]`, each instance references the `type` object of its class.

In [8]:
class DP(IterDataPipe[Tuple[int, str]]):
    def __iter__(self) -> Iterator[Tuple[int, str]]:
        pass
dp = DP()
print(DP.type, dp.type, id(DP.type) == id(dp.type))

typing.Tuple[int, str] typing.Tuple[int, str] True


In [9]:
class DP(IterDataPipe[Union[int, str]]):
    def __iter__(self) -> Iterator[Union[int, str]]:
        pass
dp = DP()
print(DP.type, dp.type, id(DP.type) == id(dp.type))

typing.Union[int, str] typing.Union[int, str] True


2. Non-fixed `type`

    If `type` is a non-fixed type such as `T_co` or `tuple`, the `type` attribute is copied for each instance.

In [10]:
class DP(IterDataPipe[Any]):
    def __iter__(self) -> Iterator[Any]:
        pass
dp = DP()
print(DP.type, dp.type, id(DP.type) == id(dp.type))

typing.Any typing.Any False


In [11]:
class DP(IterDataPipe[tuple]):
    def __iter__(self) -> Iterator[tuple]:
        pass
dp = DP()
print(DP.type, dp.type, id(DP.type) == id(dp.type))

tuple tuple True
