#### Standard flow control and data processing torchdata.nodes

#### All torchdata.nodes.BaseNode implementations are Iterators, adhering to the following API:
```Python
class BaseNode(Iterator[T]):
    def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
        """Resets the node to its initial state or a specified state."""
        ...
    def __next__(self) -> T:
        """Returns the next value in the sequence."""
        ...
    def get_state(self) -> Dict[str, Any]:
        """Returns a dictionary representing the current state of the node."""
        ...
```
#### This standardized interface enables seamless chaining of iterators, allowing for flexible, efficient, and composable data processing pipelines.

In [78]:
dataset = range(10)

In [79]:
from torchdata.nodes import IterableWrapper
#Thin Wrapper that converts any Iterable in to a BaseNode.

In [80]:
base_node = IterableWrapper(dataset)


In [81]:
for item in node:
    print(item)

In [82]:
# For people accustomed to torch.data.utils style dataloders and samplers
sampler = RandomSampler(dataset)
node = MapStyleWrapper(map_dataset=dataset, sampler=sampler)

for item in node:
    print(item)

2
1
9
5
0
6
7
4
8
3


In [83]:
from torchdata.nodes import Mapper, IterableWrapper, Loader, Batcher, ParallelMapper
# Now we can set up some torchdata.nodes to create our pre-proc pipeline
from torchdata.nodes import MapStyleWrapper, ParallelMapper, Batcher, PinMemory, Loader
from torch.utils.data import default_collate, RandomSampler, SequentialSampler

In [84]:
# what's a basenode, print BaseNode API definition
# All torchdata.nodes.BaseNode implementations are Iterators.
```Python
class BaseNode(Iterator[T]):
    def reset(self, initial_state: Optional[Dict[str, Any]] = None): ...
    def next(self): ...
    def get_state(self) -> Dict[str, Any]: ...
```
# Base node is 
# All nodes adhere to this API, you can chain iterators

SyntaxError: invalid syntax (2768665799.py, line 3)

In [85]:
while True:
    try:
        print(next(node))
    except StopIteration:
        break

In [86]:
mapped_dataset = Mapper(base_node, map_fn = lambda x : x**2)
for item in mapped_dataset:
    print(item)

0
1
4
9
16
25
36
49
64
81


In [90]:
mapped_dataset = ParallelMapper(base_node, map_fn = lambda x : x**2, num_workers =2, method = "thread")
for item in mapped_dataset:
    print(item)

0
1
4
9
16
25
36
49
64
81


In [92]:
batched_dataset = Batcher(base_node, batch_size = 4)
for batch in batched_dataset:
    print(batch)

[0, 1, 2, 3]
[4, 5, 6, 7]


In [93]:
batched_dataset = Batcher(base_node, batch_size = 4, drop_last = False)
for batch in batched_dataset:
    print(batch)

[0, 1, 2, 3]
[4, 5, 6, 7]
[8, 9]
