### `torchdata.nodes` Basics

All `torchdata.nodes.BaseNode` implementations are iterators that adhere to the following API:
```Python
class BaseNode(Iterator[T]):
    def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
        """Resets the node to its initial state or a specified state."""
        ...
    def __next__(self) -> T:
        """Returns the next value in the sequence."""
        ...
    def get_state(self) -> Dict[str, Any]:
        """Returns a dictionary representing the current state of the node."""
        ...
```
Because every node exposes the same interface, nodes can be chained: each node consumes the output of the node before it, which keeps data processing pipelines composable and easy to checkpoint via `get_state` and `reset`.
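
To make the contract concrete, here is a toy, pure-Python sketch of a node-like class. `CountingNode` and its `"next_value"` state key are hypothetical names invented for illustration; the class does not subclass the real `BaseNode` and only mirrors the three methods listed above.

```python
from typing import Any, Dict, Iterator, Optional


class CountingNode(Iterator[int]):
    """Hypothetical toy node that yields 0..limit-1 and can checkpoint itself."""

    STATE_KEY = "next_value"  # hypothetical key name, not part of torchdata

    def __init__(self, limit: int) -> None:
        self.limit = limit
        self._next_value = 0

    def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
        # Start from scratch, or resume from a previously saved state.
        if initial_state is None:
            self._next_value = 0
        else:
            self._next_value = initial_state[self.STATE_KEY]

    def __next__(self) -> int:
        if self._next_value >= self.limit:
            raise StopIteration
        value = self._next_value
        self._next_value += 1
        return value

    def get_state(self) -> Dict[str, Any]:
        # Everything needed to resume iteration later.
        return {self.STATE_KEY: self._next_value}


node = CountingNode(5)
first_two = [next(node), next(node)]  # consume two items
checkpoint = node.get_state()         # snapshot mid-iteration
rest = list(node)                     # drain the remaining items
node.reset(checkpoint)                # rewind to the snapshot
resumed = list(node)                  # same remaining items again
print(first_two, checkpoint, rest, resumed)
```

The point of the sketch is the round trip: a state captured with `get_state()` can be handed back to `reset()` to resume from the same position.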

Let's walk through the functionality of `torchdata.nodes` with a very simple example.

#### IterableWrapper

In [1]:
from torchdata.nodes import IterableWrapper
# IterableWrapper converts any Iterable into a BaseNode.

dataset = range(10)  # a very simple dataset
source = IterableWrapper(dataset)  # wrap the dataset into a BaseNode

In [2]:
# Let's take a look at the items in the node
for item in source:
    print(item)

0
1
2
3
4
5
6
7
8
9


#### Integrating with torch.utils.data Dataloaders and Samplers

We can also take `torch.utils.data`-style map datasets and samplers and wrap them into nodes.
Please refer to this [migration guide](https://pytorch.org/data/main/migrate_to_nodes_from_utils.html) to migrate from `torch.utils.data` to `torchdata.nodes`.


In [3]:
from torchdata.nodes import MapStyleWrapper
from torch.utils.data import RandomSampler

sampler = RandomSampler(dataset)
node = MapStyleWrapper(map_dataset=dataset, sampler=sampler)

for item in node:
    print(item)

6
9
2
1
7
5
4
3
8
0


#### Map

We can use the `Mapper` class to apply a transformation, defined by `map_fn`, to every item.

In [4]:
from torchdata.nodes import Mapper
node = Mapper(source, map_fn=lambda x: x**2)
for item in node:
    print(item)

0
1
4
9
16
25
36
49
64
81


The mapping can also run in parallel, using multiple threads or multiple processes depending on the chosen `method`.

In [5]:
from torchdata.nodes import ParallelMapper
mapper = ParallelMapper(source, map_fn=lambda x: x**2, num_workers=2, method="thread")
for item in mapper:
    print(item)

0
1
4
9
16
25
36
49
64
81
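
The output above matches the sequential `Mapper`: results arrive in source order even though the work is spread over two workers. As a rough standard-library analogue (a sketch using `concurrent.futures`, not how `ParallelMapper` is implemented), `Executor.map` shows the same ordered, thread-based mapping. The helper `square` is a name introduced here for illustration.

```python
from concurrent.futures import ThreadPoolExecutor


def square(x: int) -> int:
    """Same transformation as the lambda used in the cells above."""
    return x ** 2


# Two worker threads, mirroring num_workers=2 and method="thread".
with ThreadPoolExecutor(max_workers=2) as pool:
    # Executor.map returns results in input order, whichever thread finishes first.
    results = list(pool.map(square, range(10)))

print(results)
```

Threads tend to pay off when `map_fn` spends its time waiting on I/O (file reads, network calls); CPU-bound pure-Python functions are usually a better fit for process-based parallelism.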


#### Batch

A `BaseNode` can be passed into a `Batcher` to get batches of size `batch_size`.
By default, `drop_last` is True, so a final batch smaller than `batch_size` is not produced.

In [6]:
from torchdata.nodes import Batcher
batcher = Batcher(source, batch_size=4)
for batch in batcher:
    print(batch)

[0, 1, 2, 3]
[4, 5, 6, 7]


We can set `drop_last=False` to produce the last, smaller batch as well.

In [7]:
batcher = Batcher(source, batch_size=4, drop_last=False)
for batch in batcher:
    print(batch)

[0, 1, 2, 3]
[4, 5, 6, 7]
[8, 9]
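
The `drop_last` behaviour can be reproduced with a small pure-Python generator. This is a minimal sketch of the batching logic, assuming nothing about `Batcher`'s internals; `batch_items` is a hypothetical helper name.

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batch_items(
    items: Iterable[T], batch_size: int, drop_last: bool = True
) -> Iterator[List[T]]:
    """Group items into lists of batch_size; optionally keep a short final batch."""
    batch: List[T] = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # Leftover items form an incomplete batch; emit it only if requested.
    if batch and not drop_last:
        yield batch


dropped = list(batch_items(range(10), batch_size=4))
kept = list(batch_items(range(10), batch_size=4, drop_last=False))
print(dropped)  # two full batches
print(kept)     # two full batches plus the short one
```

With 10 items and `batch_size=4`, the leftover `[8, 9]` is the only batch affected by the flag, which matches the two outputs shown above.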


If we use this batcher over multiple epochs, we need to reset it after every epoch.

In [8]:
batcher = Batcher(source, batch_size=10)
num_epochs = 2

for epoch in range(num_epochs):
    for batch in batcher:
        print(f"Epoch = {epoch}", f" Batch = {batch}")
    batcher.reset()
    
# The explicit reset is one step more than a traditional DataLoader needs.
# Wrapping the batcher in a Loader removes it; let's look at Loader in the next cell.

Epoch = 0  Batch = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Epoch = 1  Batch = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


#### Loader
Wrapping the root node in a `Loader` removes the manual reset. As the output below shows, we get a batch in every epoch without calling `reset()` ourselves.

In [9]:
from torchdata.nodes import Loader
batcher = Batcher(source, batch_size=10)
loader = Loader(batcher)

for epoch in range(num_epochs):
    for batch in loader:
        print(f"Epoch = {epoch}", f" Batch = {batch}")

Epoch = 0  Batch = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Epoch = 1  Batch = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
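
The idea behind this convenience can be sketched in pure Python: a wrapper whose `__iter__` resets the wrapped node before handing it out. `ToyNode` and `ToyLoader` are hypothetical classes written for illustration; they are not the real `BaseNode` or `Loader`, and the real classes may differ in detail.

```python
from typing import Iterator, List


class ToyNode:
    """Hypothetical single-pass node: stays exhausted until reset() is called."""

    def __init__(self, items: List[int]) -> None:
        self._items = items
        self._pos = 0

    def reset(self) -> None:
        self._pos = 0

    def __iter__(self) -> "ToyNode":
        return self

    def __next__(self) -> int:
        if self._pos >= len(self._items):
            raise StopIteration
        value = self._items[self._pos]
        self._pos += 1
        return value


class ToyLoader:
    """Hypothetical wrapper: every new iteration starts from a freshly reset node."""

    def __init__(self, node: ToyNode) -> None:
        self._node = node

    def __iter__(self) -> Iterator[int]:
        self._node.reset()  # the step the caller no longer has to remember
        return self._node


node = ToyNode([0, 1, 2])
without_reset = [list(node), list(node)]  # second pass finds nothing left

loader = ToyLoader(ToyNode([0, 1, 2]))
with_loader = [list(loader), list(loader)]  # both passes see all items
print(without_reset, with_loader)
```

In this toy version the second pass over the bare node is empty, while the wrapper yields the full sequence on every pass.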


#### SamplerWrapper

As mentioned earlier, we can use `torch.utils.data` samplers with `MapStyleWrapper`.
Alternatively, we can use `SamplerWrapper`, which converts a `Sampler` into a `BaseNode`. `SamplerWrapper` differs from `IterableWrapper` in that it tracks the number of epochs and calls the sampler's `set_epoch` method if one is implemented.

In [10]:
from torchdata.nodes import SamplerWrapper

sampler = RandomSampler(dataset)
node = SamplerWrapper(sampler)
batcher = Batcher(node, batch_size=10)
loader = Loader(batcher)
for epoch in range(num_epochs):
    for batch in loader:
        print(f"Epoch = {node.epoch}", f" Batch = {batch}")

Epoch = 0  Batch = [2, 9, 0, 5, 8, 6, 7, 3, 1, 4]
Epoch = 1  Batch = [2, 3, 6, 8, 5, 0, 1, 9, 7, 4]
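
To see why a `set_epoch` hook is useful, consider a sampler whose shuffle order is derived from the epoch number, a pattern also used by `torch.utils.data.DistributedSampler`. The sketch below is pure Python; `EpochSeededSampler` is a hypothetical class, and the manual `set_epoch` call in the loop stands in for what the text above says `SamplerWrapper` does on each epoch.

```python
import random
from typing import Iterator, List


class EpochSeededSampler:
    """Hypothetical sampler: shuffles indices with a seed derived from the epoch."""

    def __init__(self, size: int, seed: int = 0) -> None:
        self.size = size
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch

    def __iter__(self) -> Iterator[int]:
        # A fresh generator per pass makes each epoch's order reproducible.
        rng = random.Random(self.seed + self.epoch)
        indices = list(range(self.size))
        rng.shuffle(indices)
        return iter(indices)

    def __len__(self) -> int:
        return self.size


sampler = EpochSeededSampler(10)
orders: List[List[int]] = []
for epoch in range(2):
    sampler.set_epoch(epoch)  # stand-in for the call the wrapper makes
    orders.append(list(sampler))

# Replaying epoch 0 reproduces its order exactly.
sampler.set_epoch(0)
replay = list(sampler)
print(orders, replay == orders[0])
```

Tying the seed to the epoch makes every epoch's order reproducible, which is useful when a run resumes from a checkpoint or when several workers must agree on the same shuffle.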


`torchdata` nodes are composable, so many `BaseNode`s can be chained together to build the desired transformations.

#### Chaining multiple operations together

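Because each node is an iterator that takes another node as its source, nodes can be chained to build more complex data-loading graphs. The example below samples from the dataset, cubes each item, batches the results, and wraps the chain in a `Loader`.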

In [11]:
sampler = RandomSampler(dataset)
node = MapStyleWrapper(map_dataset=dataset, sampler=sampler)
node = Mapper(node, map_fn=lambda x: x**3)
node = Batcher(node, batch_size=4, drop_last=False)
loader = Loader(node)

In [12]:
for batch in loader:
    print(batch)

[64, 216, 27, 729]
[0, 8, 343, 1]
[512, 125]
