### Torchdata.nodes basics

#### All torchdata.nodes.BaseNode implementations are Iterators, adhering to the following API:
```Python
class BaseNode(Iterator[T]):
    def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
        """Resets the node to its initial state or a specified state."""
        ...
    def __next__(self) -> T:
        """Returns the next value in the sequence."""
        ...
    def get_state(self) -> Dict[str, Any]:
        """Returns a dictionary representing the current state of the node."""
        ...
```
#### This standardized interface enables seamless chaining of iterators, allowing for flexible, efficient, and composable data processing pipelines.
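
To make the contract above concrete, here is a minimal sketch of a node-like class in plain Python. `CounterNode` and its `"next_value"` state key are hypothetical names chosen for illustration. The class deliberately does not subclass torchdata's `BaseNode`, so the snippet runs without torchdata installed; it only mirrors the three methods listed above.

```python
from typing import Any, Dict, Iterator, Optional


class CounterNode(Iterator[int]):
    """Hypothetical node yielding 0, 1, ..., limit - 1 (not part of torchdata)."""

    STATE_KEY = "next_value"  # assumed key name, chosen for this sketch

    def __init__(self, limit: int) -> None:
        self.limit = limit
        self._next_value = 0

    def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
        # Start from scratch, or resume from a previously saved state.
        if initial_state is None:
            self._next_value = 0
        else:
            self._next_value = initial_state[self.STATE_KEY]

    def __next__(self) -> int:
        if self._next_value >= self.limit:
            raise StopIteration
        value = self._next_value
        self._next_value += 1
        return value

    def get_state(self) -> Dict[str, Any]:
        return {self.STATE_KEY: self._next_value}


node = CounterNode(limit=5)
first_two = [next(node), next(node)]  # consume 0 and 1
checkpoint = node.get_state()         # snapshot taken mid-iteration

resumed = CounterNode(limit=5)
resumed.reset(checkpoint)             # resume where the snapshot was taken
remaining = list(resumed)
print(first_two, remaining)  # [0, 1] [2, 3, 4]
```

Passing the saved dictionary back into `reset` is what lets a pipeline resume mid-epoch instead of starting over.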

#### Let's explore the functionality of torchdata.nodes with the help of a very simple example

#### BaseNode

In [1]:
from torchdata.nodes import IterableWrapper
# IterableWrapper converts any Iterable into a BaseNode.

dataset = range(10) # creating a very simple dataset, and then converting it into a BaseNode
base_node = IterableWrapper(dataset)

In [2]:
# Let's take a look at the items in the node
for item in base_node:
    print(item)

0
1
2
3
4
5
6
7
8
9


#### Integrating with torch.utils.data Datasets and Samplers

In [3]:
# We can also take torch.utils.data style map-style datasets and samplers, and wrap them into nodes
from torchdata.nodes import MapStyleWrapper
from torch.utils.data import RandomSampler

sampler = RandomSampler(dataset)
node = MapStyleWrapper(map_dataset=dataset, sampler=sampler)

for item in node:
    print(item)

8
4
7
6
0
3
5
9
2
1
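
Conceptually, pairing a sampler with a map-style dataset means the sampler supplies indices and the dataset is looked up at each one. The sketch below illustrates that idea in plain Python. `iterate_with_sampler` is a hypothetical helper, and `random.Random.sample` stands in for a sampler; this is not how `MapStyleWrapper` is implemented.

```python
import random
from typing import Iterable, Iterator, Sequence, TypeVar

T = TypeVar("T")


def iterate_with_sampler(map_dataset: Sequence[T], indices: Iterable[int]) -> Iterator[T]:
    """Hypothetical helper: yield dataset items in the order given by `indices`."""
    for index in indices:
        yield map_dataset[index]


dataset = range(10)
rng = random.Random(0)  # seeded so the shuffle is reproducible
shuffled_indices = rng.sample(range(len(dataset)), k=len(dataset))
items = list(iterate_with_sampler(dataset, shuffled_indices))
print(items)  # every item appears exactly once, in shuffled order
```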


#### Map

In [4]:
# We can use the Mapper class to apply a transformation defined by `map_fn`
from torchdata.nodes import Mapper
mapped_dataset = Mapper(base_node, map_fn = lambda x : x**2)
for item in mapped_dataset:
    print(item)

0
1
4
9
16
25
36
49
64
81


In [5]:
# The mapping can also run in parallel, using multithreading or multiprocessing depending on the `method` argument
from torchdata.nodes import ParallelMapper
mapped_dataset = ParallelMapper(base_node, map_fn = lambda x : x**2, num_workers =2, method = "thread")
for item in mapped_dataset:
    print(item)

0
1
4
9
16
25
36
49
64
81
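
For intuition about thread-based parallel mapping, the following sketch uses only the standard library's `concurrent.futures`. It is an analogy under stated assumptions, not torchdata's implementation: `square` is a hypothetical helper, and `ThreadPoolExecutor.map` is used because it returns results in input order.

```python
from concurrent.futures import ThreadPoolExecutor


def square(x: int) -> int:
    return x ** 2


with ThreadPoolExecutor(max_workers=2) as pool:
    # Executor.map yields results in input order, even if tasks finish out of order.
    squares = list(pool.map(square, range(10)))

print(squares)  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
```

Threads pay off mainly when the mapped function waits on I/O or releases the GIL; a tiny pure-Python function like `square` is used here only to keep the example short.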


#### Batch

In [6]:
# A BaseNode can be passed into a Batcher, to get batches of size `batch_size`.
# By default, drop_last is True, meaning if the last batch has a size smaller than the `batch_size`, it is not produced.
from torchdata.nodes import Batcher
batched_dataset = Batcher(base_node, batch_size = 4)
for batch in batched_dataset:
    print(batch)

[0, 1, 2, 3]
[4, 5, 6, 7]


In [7]:
# Setting `drop_last = False` also produces the final, smaller batch
batched_dataset = Batcher(base_node, batch_size = 4, drop_last = False)
for batch in batched_dataset:
    print(batch)

[0, 1, 2, 3]
[4, 5, 6, 7]
[8, 9]
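
The effect of `drop_last` can be sketched with a small generator in plain Python. `batch_items` is a hypothetical helper written for this illustration, not torchdata's `Batcher`; it reproduces the two outputs shown above.

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batch_items(source: Iterable[T], batch_size: int, drop_last: bool = True) -> Iterator[List[T]]:
    """Hypothetical helper: group items into lists of `batch_size`."""
    batch: List[T] = []
    for item in source:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # A partial batch is left over when the source length is not a multiple of batch_size.
    if batch and not drop_last:
        yield batch


dropped = list(batch_items(range(10), batch_size=4))
kept = list(batch_items(range(10), batch_size=4, drop_last=False))
print(dropped)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(kept)     # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```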


In [8]:
batched_dataset = Batcher(base_node, batch_size = 10)

# If we try to use this batcher over multiple epochs, we will need to reset it after every epoch
num_epochs = 2

for epoch in range(num_epochs):
    print(f"On epoch = {epoch}")
    for batch in batched_dataset:
        print(batch)
    batched_dataset.reset()
    
# This is one more step than a traditional dataloader needs. We can avoid it by wrapping the batcher in a Loader.
# Let's look at Loader in the next cell

On epoch = 0
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
On epoch = 1
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


#### Loader

In [9]:
from torchdata.nodes import Loader
batched_dataset = Batcher(base_node, batch_size = 10)
loader = Loader(batched_dataset)

for epoch in range(num_epochs):
    print(f"On epoch = {epoch}")
    for batch in loader:
        print(batch)

# As you can see, we get a batch in every epoch without needing to reset the loader ourselves.

On epoch = 0
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
On epoch = 1
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
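
One way such a wrapper could work is to reset the wrapped node every time a new iteration starts. The sketch below shows that idea in plain Python. `OneShotNode` and `ResettingLoader` are hypothetical classes made up for this illustration; the real `Loader` is not implemented this way line for line.

```python
from typing import Iterable, Iterator, List


class OneShotNode:
    """Hypothetical node that is exhausted after one pass until reset() is called."""

    def __init__(self, items: Iterable[int]) -> None:
        self._items: List[int] = list(items)
        self._position = 0

    def reset(self) -> None:
        self._position = 0

    def __iter__(self) -> "OneShotNode":
        return self

    def __next__(self) -> int:
        if self._position >= len(self._items):
            raise StopIteration
        item = self._items[self._position]
        self._position += 1
        return item


class ResettingLoader:
    """Hypothetical wrapper that resets the wrapped node whenever iteration starts."""

    def __init__(self, node: OneShotNode) -> None:
        self._node = node

    def __iter__(self) -> Iterator[int]:
        self._node.reset()
        return self._node


node = OneShotNode(range(3))
first_pass = list(node)   # [0, 1, 2]
second_pass = list(node)  # [] because the node was never reset

loader = ResettingLoader(OneShotNode(range(3)))
epochs = [list(loader) for _ in range(2)]
print(first_pass, second_pass, epochs)
```

Moving the reset into `__iter__` is what lets the training loop use a plain `for batch in loader:` in every epoch.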


In [10]:
# torchdata nodes are composable, so many BaseNode instances can be chained together to build the desired transformation

#### Chaining multiple operations together

In [None]:
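sampler = RandomSampler(dataset)
node = MapStyleWrapper(map_dataset=dataset, sampler=sampler)
# Feed each node into the next one: sample -> square -> batch
mapped_dataset = Mapper(node, map_fn = lambda x : x**2)
# drop_last = False keeps the final, smaller batch
batched_dataset = Batcher(mapped_dataset, batch_size = 4, drop_last = False)
for batch in batched_dataset:
    print(batch)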