### In this notebook, we will explore the usage of MultiNodeWeightedSampler in torchdata.nodes

#### MultiNodeWeightedSampler lets us sample from multiple datasets, choosing each dataset with a probability given by its weight
#### We will create three datasets and then see how the composition of each batch depends on the weights passed to MultiNodeWeightedSampler
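Before using the library, it helps to picture what weighted sampling across datasets means. Below is a minimal conceptual sketch in plain Python, assuming each draw independently picks a dataset with probability proportional to its weight and then takes that dataset's next item. `weighted_interleave` is a hypothetical helper built on `random.choices`. It is not torchdata's implementation and ignores what happens when a dataset runs out.

```python
import random
from itertools import repeat
from typing import Dict, Iterator, List


def weighted_interleave(
    sources: Dict[str, Iterator], weights: Dict[str, float], n: int, seed: int = 0
) -> List:
    """Hypothetical sketch: draw n items from several iterators by weight.

    Conceptual only (not torchdata's code); source exhaustion is not handled.
    """
    rng = random.Random(seed)
    names = list(sources)
    probs = [weights[name] for name in names]
    out = []
    for _ in range(n):
        # Pick a source name with probability proportional to its weight ...
        name = rng.choices(names, weights=probs, k=1)[0]
        # ... then take the next item from that source
        out.append(next(sources[name]))
    return out


# Three "datasets" that each repeat a single value, mirroring this notebook
sources = {f"ds{i}": repeat(i) for i in range(3)}
items = weighted_interleave(sources, {"ds0": 0.5, "ds1": 0.25, "ds2": 0.25}, n=1000)
print({f"ds{i}": items.count(i) / len(items) for i in range(3)})
```

Because the draws are random, the printed fractions land near, but usually not exactly on, 0.5 / 0.25 / 0.25.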

In [1]:
import torch
from torch.utils.data import SequentialSampler

from torchdata.nodes import MapStyleWrapper, Batcher, Loader, Mapper, MultiNodeWeightedSampler

# Define a simple map_fn as a placeholder transformation
def map_fn(item):
    return {"x":item}

# This function creates a dictionary of datasets that can be passed to MultiNodeWeightedSampler.
# Each dataset contains a single value repeated `length` times.

def get_datasets(num_datasets, length=100000):
    """
    Create a dictionary of datasets with simple transformations.
    Args:
        num_datasets (int): Number of datasets to create.
        length (int, optional): Length of each dataset. Defaults to 100000.
    Returns:
        dict: Dictionary of datasets with simple transformations.
    """
    datasets = {}
    for i in range(num_datasets):
        data = [i] * length  # First create a list in which every element has the value i
        sampler = SequentialSampler(data) # We can use a SequentialSampler or a RandomSampler if we want to shuffle the dataset
        # Next we create a BaseNode, by passing our dataset, and sampler
        node = MapStyleWrapper(map_dataset=data, sampler=sampler) 
        # Next, apply our simple transformation, which wraps each element in a dict {"x": element}
        datasets[f"ds{i}"] = Mapper(node, map_fn=map_fn)
    return datasets
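To see what each entry of the returned dictionary produces, here is a plain-Python analog of `Mapper(MapStyleWrapper(map_dataset=data, sampler=sampler), map_fn=map_fn)`: the sampler yields indices, each index is looked up in the list, and `map_fn` transforms the element. `map_style_pipeline` is a hypothetical helper for illustration, not part of torchdata; the only library fact relied on is that `SequentialSampler` yields indices `0, 1, ..., len(data) - 1` in order.

```python
from typing import Callable, Iterator, Sequence


def map_style_pipeline(data: Sequence, indices: Iterator[int], map_fn: Callable) -> Iterator:
    """Hypothetical analog of Mapper over MapStyleWrapper: look up each sampled
    index in `data`, then apply `map_fn` to the element."""
    for idx in indices:
        yield map_fn(data[idx])


def map_fn(item):
    return {"x": item}


data = [2] * 5                       # like the "ds2" dataset, but with length=5
sequential = iter(range(len(data)))  # the order a SequentialSampler produces
print(list(map_style_pipeline(data, sequential, map_fn)))
# [{'x': 2}, {'x': 2}, {'x': 2}, {'x': 2}, {'x': 2}]
```

Passing a shuffled index order (as a `RandomSampler` would) changes only the order in which elements come out, not the transformation applied to them.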

    

In [2]:
# First, we create a dictionary of three datasets
num_datasets = 3
datasets = get_datasets(num_datasets)

# Next, we have to define weights for sampling from a particular dataset
# Make sure that the weights dict has the same keys as the datasets
weights = {"ds0":0.5, "ds1":0.25, "ds2":0.25}

# Finally, we instantiate the MultiNodeWeightedSampler to sample from our datasets
multi_node_sampler = MultiNodeWeightedSampler(datasets, weights)

# We can use the Batcher functionality to create batches of `batch_size`
multi_node_batcher = Batcher(multi_node_sampler, batch_size=1000)

# Since nodes are iterators, they need to be manually .reset() between epochs.
# We can wrap the root node in Loader to convert it to a more conventional Iterable.
train_loader = Loader(multi_node_batcher)

# This train loader can be used to provide batches during training epochs
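The comments above about `.reset()` and `Loader` can be illustrated with a small pure-Python analog: a node that is a one-shot iterator until it is reset, a batching generator, and a wrapper whose `__iter__` resets the root so every `for` loop starts a fresh epoch. `RepeatNode`, `batched`, and `SimpleLoader` are hypothetical names, and the `drop_last=True` default here is our own choice for the sketch. This mirrors the behavior described in the comments, not torchdata's actual code.

```python
from typing import Iterator, List


class RepeatNode:
    """Hypothetical node: an iterator that must be reset() before reuse."""

    def __init__(self, values: List[int]):
        self.values = values
        self.reset()

    def reset(self) -> None:
        self._pos = 0

    def __iter__(self):
        return self

    def __next__(self) -> int:
        if self._pos >= len(self.values):
            raise StopIteration
        value = self.values[self._pos]
        self._pos += 1
        return value


def batched(source: Iterator, batch_size: int, drop_last: bool = True) -> Iterator[List]:
    """Group consecutive items into lists of batch_size; optionally drop a short final batch."""
    batch = []
    for item in source:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch and not drop_last:
        yield batch


class SimpleLoader:
    """Hypothetical Loader analog: resets the root node on every __iter__ call."""

    def __init__(self, root: RepeatNode, batch_size: int):
        self.root = root
        self.batch_size = batch_size

    def __iter__(self):
        self.root.reset()  # start a new epoch from the beginning
        return batched(self.root, self.batch_size)


loader = SimpleLoader(RepeatNode(list(range(7))), batch_size=3)
epoch1 = list(loader)  # [[0, 1, 2], [3, 4, 5]]; the leftover 6 is dropped
epoch2 = list(loader)  # the same batches again, because __iter__ reset the node
```

Iterating a bare `RepeatNode` a second time without calling `reset()` yields nothing, which is the manual bookkeeping the wrapper takes care of.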

In [3]:
# This helper computes the composition of a batch,
# so we can check whether it follows the given weights
from collections import Counter
def compute_dataset_fraction(num_datasets, batch):
    """Print the fraction of items in `batch` that came from each dataset."""
    total_length = len(batch)
    # Each item is {"x": i}, where i identifies the dataset it came from
    counts = Counter(item["x"] for item in batch)
    for i in range(num_datasets):
        # Counter returns 0 for a dataset that contributed no items, avoiding a KeyError
        print(f"The fraction of ds{i} is = ", counts[i] / total_length)
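For automated checks it can be handier to return the fractions instead of printing them. A minimal sketch with a hypothetical `dataset_fractions` helper, assuming the same `{"x": i}` item format produced by `map_fn`:

```python
from collections import Counter
from typing import Dict, List


def dataset_fractions(batch: List[dict], num_datasets: int) -> Dict[str, float]:
    """Hypothetical helper: return {"ds<i>": fraction of the batch from dataset i}."""
    counts = Counter(item["x"] for item in batch)
    total = len(batch)
    # Counter returns 0 for missing keys, so unsampled datasets get 0.0
    return {f"ds{i}": counts[i] / total for i in range(num_datasets)}


batch = [{"x": 0}, {"x": 0}, {"x": 1}, {"x": 0}]
print(dataset_fractions(batch, num_datasets=3))
# {'ds0': 0.75, 'ds1': 0.25, 'ds2': 0.0}
```

The returned dict has the same keys as `weights`, so the two can be compared entry by entry.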
        

In [4]:
# Let's go through the batches and compute the fraction of each dataset in a batch
for batch in train_loader:

    compute_dataset_fraction(num_datasets, batch)
    print("The original weights were", weights)
    break

The fraction of ds0 is =  0.511
The fraction of ds1 is =  0.244
The fraction of ds2 is =  0.245
The original weights were {'ds0': 0.5, 'ds1': 0.25, 'ds2': 0.25}


#### Since picking items from the datasets according to the weights is a stochastic process, the fraction of each dataset in a batch is close to the provided weights, but not exactly equal.
#### With a larger batch_size, the random fluctuations shrink, so the fractions in each batch tend to lie closer to the provided weights.
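We can put a rough number on "closer". Assuming each of the `batch_size` draws independently picks dataset *i* with probability equal to its weight (an idealization; we have not checked that this is exactly how MultiNodeWeightedSampler draws), the count from dataset *i* is binomial, so its share of the batch has standard deviation sqrt(w(1 - w) / n). `fraction_std` is a hypothetical helper for this arithmetic.

```python
import math


def fraction_std(weight: float, batch_size: int) -> float:
    """Std. dev. of a dataset's share of a batch, assuming independent draws."""
    return math.sqrt(weight * (1.0 - weight) / batch_size)


for batch_size in (1000, 100000):
    for weight in (0.5, 0.25):
        print(batch_size, weight, round(fraction_std(weight, batch_size), 5))
# 1000 0.5 0.01581
# 1000 0.25 0.01369
# 100000 0.5 0.00158
# 100000 0.25 0.00137
```

A 100x larger batch gives a 10x smaller spread. The deviations printed in this notebook (about 0.01 at batch_size 1000 and about 0.0015 at batch_size 100000) are on the order of one standard deviation in both runs.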

In [5]:
# Let's try a bigger batch size
multi_node_batcher = Batcher(multi_node_sampler, batch_size=100000)

train_loader = Loader(multi_node_batcher)

# This time we get fractions much closer to our provided weights
for batch in train_loader:

    compute_dataset_fraction(num_datasets, batch)
    print("The original weights were", weights)
    break

The fraction of ds0 is =  0.50148
The fraction of ds1 is =  0.24844
The fraction of ds2 is =  0.25008
The original weights were {'ds0': 0.5, 'ds1': 0.25, 'ds2': 0.25}
