### In this notebook, we will explore the usage of MultiNodeWeightedSampler in torchdata.nodes

#### MultiNodeWeightedSampler allows us to sample from multiple datasets, picking each dataset with a given probability
#### We will create three datasets, and then see how the composition of the output depends on the weights passed to MultiNodeWeightedSampler

In [1]:
from torchdata.nodes import Mapper, MultiNodeWeightedSampler, IterableWrapper, Loader
import collections

# Define a simple map_fn as a placeholder example (it is not used in the rest of this notebook)
def map_fn(item):
    return {"x":item}


def constant_stream(value: int):
    # Infinite generator that always yields the same value
    while True:
        yield value

# First, we create a dictionary of three datasets, with each dataset converted into a BaseNode using IterableWrapper
num_datasets = 3
datasets = {
    "ds0": IterableWrapper(constant_stream(0)),
    "ds1": IterableWrapper(constant_stream(1)),
    "ds2": IterableWrapper(constant_stream(2)),
}

# Next, we define the sampling weight for each dataset
weights = {"ds0": 0.5, "ds1": 0.25, "ds2": 0.25}

# Finally, we instantiate the MultiNodeWeightedSampler to sample from our datasets
multi_node_sampler = MultiNodeWeightedSampler(datasets, weights)

# Since nodes are iterators, .reset() must be called on them manually between epochs.
# Wrapping the root node in Loader converts it into a more conventional Iterable.
train_loader = Loader(multi_node_sampler)

In [2]:
# Let's draw 1000 items, compute the fraction that came from each dataset,
# and see whether the composition follows the weights we provided
n = 1000
it = iter(train_loader)
samples = [next(it) for _ in range(n)]
fractions = {k: v/len(samples) for k, v in collections.Counter(samples).items()}
print(f"fractions = {fractions}")
print(f"The original weights were = {weights}")

fractions = {0: 0.511, 2: 0.245, 1: 0.244}
The original weights were = {'ds0': 0.5, 'ds1': 0.25, 'ds2': 0.25}


#### Since items are picked from the datasets at random according to the weights, the fraction of each dataset among the fetched items is close to the provided weights, but not exactly equal.
#### If we fetch more items, the fractions will tend to move closer to the provided weights.

In [3]:
# Let's increase `n` to ten thousand
n = 10000
it = iter(train_loader)
samples = [next(it) for _ in range(n)]
fractions = {k: v/len(samples) for k, v in collections.Counter(samples).items()}
print(f"fractions = {fractions}")
print(f"The original weights were = {weights}")

fractions = {0: 0.5097, 2: 0.244, 1: 0.2463}
The original weights were = {'ds0': 0.5, 'ds1': 0.25, 'ds2': 0.25}
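
How close the fractions should be to the weights can be estimated with the binomial standard error, sqrt(p * (1 - p) / n), where p is the weight and n the number of items fetched. The sketch below is not part of torchdata; `standard_error` is a hypothetical helper, and the calculation assumes that every draw is independent. It compares the fractions printed above with that expected spread. For a weight of 0.5 the standard error is roughly 0.016 at n = 1000 and 0.005 at n = 10000, so the typical deviation shrinks by a factor of sqrt(10) rather than 10.

```python
import math


def standard_error(p: float, n: int) -> float:
    """Binomial standard error of an observed fraction.

    Hypothetical helper for illustration; assumes n independent draws,
    each hitting the dataset with probability p.
    """
    return math.sqrt(p * (1.0 - p) / n)


weights = {"ds0": 0.5, "ds1": 0.25, "ds2": 0.25}

# Fractions copied from the two runs above (stream value k belongs to "ds{k}")
observed = {
    1000: {"ds0": 0.511, "ds1": 0.244, "ds2": 0.245},
    10000: {"ds0": 0.5097, "ds1": 0.2463, "ds2": 0.244},
}

for n, fractions in observed.items():
    for name, p in weights.items():
        se = standard_error(p, n)
        # Deviation from the weight, measured in standard errors
        z = (fractions[name] - p) / se
        print(f"n={n:>5} {name}: observed={fractions[name]:.4f} "
              f"weight={p:.2f} std_err={se:.4f} z={z:+.2f}")
```

Under the independence assumption, deviations within about two standard errors are ordinary sampling noise, which is why a single run at n = 10000 is not guaranteed to land closer to the weights than a single run at n = 1000.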
