# In [None]:
import math
import os

from typing import Iterator, List, Optional

import more_itertools
import numpy as np

import torch
import torch.distributed as dist

from torch.utils.data import Sampler
from torch.utils.data.sampler import T_co


class LengthSortBatchSamplerWithFirstMaxLength(Sampler[T_co]):
    """Batch sampler that groups samples of similar length and yields the longest batch first.

    Sample indices are sorted by ``lengths`` and cut into consecutive chunks of
    ``batch_size``, so every batch holds samples of similar length. The chunk that
    contains the longest samples is always yielded first, so that an out-of-memory
    problem shows up at the very first step; the remaining batches follow in
    ascending-length order or, if ``shuffle`` is set, in a seeded random order.

    In distributed mode the batch list is padded (or truncated when ``drop_last``
    is set) to a multiple of ``num_replicas`` and dealt out round-robin, the same
    way :class:`~torch.utils.data.DistributedSampler` treats sample indices.
    Only rank 0 is guaranteed to receive the longest batch first.

    Args:
        batch_size: Size of a mini-batch. Must be positive.
        lengths: Length of every sample; position ``i`` describes dataset index ``i``.
        is_distributed: Whether to split the batches between replicas.
        num_replicas (int, optional): Number of processes participating in
            distributed training (the former ``world_size``). By default it is
            retrieved from the current distributed group.
        rank (int, optional): Rank of the current process within ``num_replicas``.
            By default it is retrieved from the current distributed group.
        shuffle (bool, optional): If ``True`` (default), the order of all batches
            except the first (longest) one is shuffled.
        seed (int, optional): Random seed used when ``shuffle=True``. It should be
            identical across all processes in the distributed group. Default: ``0``.
        drop_last (bool, optional): If ``True``, the last batch is dropped when it
            is smaller than ``batch_size``; in distributed mode the tail of the
            batch list is additionally dropped to make it evenly divisible across
            the replicas. If ``False`` (default), in distributed mode batches are
            repeated to make the list evenly divisible.

    Raises:
        ValueError: If ``batch_size`` is not positive or ``rank`` is outside
            ``[0, num_replicas - 1]``.
        RuntimeError: If ``num_replicas`` or ``rank`` has to be looked up but the
            distributed package is not available.

    .. warning::
        Call :meth:`set_epoch` at the beginning of each epoch **before** creating
        the :class:`DataLoader` iterator to get a different shuffling every epoch.
        Otherwise the same ordering is always used.
    """

    def __init__(self, batch_size: int, lengths: List[int], is_distributed: bool = False,
                 num_replicas: Optional[int] = None, rank: Optional[int] = None, shuffle: bool = True, seed: int = 0,
                 drop_last: bool = False):

        # Newer PyTorch releases deprecate the ``data_source`` argument of ``Sampler.__init__``
        # while old ones require it, so try the modern signature first.
        try:
            super().__init__()
        except TypeError:
            super().__init__(lengths)

        if batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))

        self.is_distributed = is_distributed
        self.batch_size = batch_size
        self.lengths = lengths
        self.drop_last = drop_last
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0

        if self.is_distributed:
            if num_replicas is None:
                if not dist.is_available():
                    raise RuntimeError("Requires distributed package to be available")
                num_replicas = dist.get_world_size()
            if rank is None:
                if not dist.is_available():
                    raise RuntimeError("Requires distributed package to be available")
                rank = dist.get_rank()
            if rank >= num_replicas or rank < 0:
                raise ValueError(
                    "Invalid rank {}, rank should be in the interval"
                    " [0, {}]".format(rank, num_replicas - 1))

            self.num_replicas = num_replicas
            self.rank = rank

            # Number of batches every replica receives: round down (truncate the tail)
            # with ``drop_last``, round up (pad by repetition) otherwise.
            num_batches = self._num_batches()
            if self.drop_last:
                self.num_samples = num_batches // num_replicas
            else:
                self.num_samples = math.ceil(num_batches / num_replicas)
            self.total_size = self.num_samples * num_replicas

    def _num_batches(self) -> int:
        """Return the number of batches built from ``lengths`` before any split between replicas."""
        num_batches, remainder = divmod(len(self.lengths), self.batch_size)
        if remainder != 0 and not self.drop_last:
            num_batches += 1
        return num_batches

    def __len__(self) -> int:
        """Return the number of batches one pass over this sampler yields (per replica in distributed mode)."""
        if self.is_distributed:
            return self.num_samples
        return self._num_batches()

    def __iter__(self) -> Iterator[T_co]:
        """Return an iterator over batches (lists of ``int`` dataset indices), longest batch first."""

        # A stable sort keeps the order of equal-length samples identical on every
        # replica; ``tolist`` makes every yielded index a plain Python ``int``.
        sorted_indices = np.argsort(self.lengths, kind="stable").tolist()
        batches = [sorted_indices[start:start + self.batch_size]
                   for start in range(0, len(sorted_indices), self.batch_size)]

        if self.drop_last and batches and len(batches[-1]) < self.batch_size:
            batches.pop()

        if not batches:
            return iter([])

        # the batch having the longest items goes first to hit a possible OOM problem immediately
        longest_batch = batches[-1]
        other_batches = batches[:-1]

        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            permutation = torch.randperm(len(other_batches), generator=g).tolist()
            other_batches = [other_batches[position] for position in permutation]

        batches = [longest_batch] + other_batches

        if self.is_distributed:
            padding_size = self.total_size - len(batches)
            if padding_size > 0:
                # repeat batches from the start of the list until it is evenly divisible
                repeats = math.ceil(padding_size / len(batches))
                batches = batches + (batches * repeats)[:padding_size]
            else:
                # drop the tail to make the list evenly divisible
                batches = batches[:self.total_size]

            # subsample: replica ``rank`` takes every ``num_replicas``-th batch
            batches = batches[self.rank:self.total_size:self.num_replicas]

        return iter(batches)

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch number that is added to ``seed`` when shuffling.

        Args:
            epoch: Epoch number.
        """
        self.epoch = epoch


if __name__ == '__main__':
    # Demo: print the batches produced for several batch sizes, with and without
    # shuffling, first in distributed mode (single-process group) and then in plain mode.

    def _print_batches(is_distributed: bool) -> None:
        """Print the batches of a 12-sample toy dataset for every shuffle / batch-size combination."""
        lengths = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
        for shuffle in [True, False]:
            for batch_size in [1, 2, 3, 4, 5]:
                sampler = LengthSortBatchSamplerWithFirstMaxLength(batch_size, lengths=lengths,
                                                                   is_distributed=is_distributed,
                                                                   shuffle=shuffle, seed=2512)
                print(list(sampler))

    print("Dist mode: ON")

    backend = "gloo"  # <SET YOUR AVAILABLE BACKEND>

    # Respect values already present in the environment; 29500 is an unprivileged port
    # (port 22 is reserved for SSH and binding to it normally requires root).
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")  # <SET YOUR AVAILABLE PORT>

    # make the process group only if it's not initialized yet, and clean up only what we created
    group_created_here = not dist.is_initialized()
    if group_created_here:
        dist.init_process_group(backend, rank=0, world_size=1)
    try:
        _print_batches(is_distributed=True)
    finally:
        if group_created_here:
            dist.destroy_process_group()

    print("Dist mode: OFF")
    _print_batches(is_distributed=False)
