## 2024.0509 
研究按照概率进行采样，涉及的方法有: dataset, sampler(每次返回一个index,兼容dataloader层面的batch_size和drop_last), batch_sampler(自定义每次返回一组indices) 和collate_fn

In [1]:
from torch.utils.data import Dataset as Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.utils.data.sampler import Sampler
import torch
import numpy as np

In [2]:
# Mimic AlphaFold-style residue indexing: positions start as 0..N-1 and the
# indices after every chain boundary are shifted by a further 200.
copy_num = 1
input_seqs = ['AABCAA']
input_sequence = input_seqs[0] * copy_num
sequence_features = {'residue_index': np.arange(len(input_sequence))}

idx_res = sequence_features['residue_index']
# Ls: number of residues in each chain
Ls = [len(input_seqs[0]) for _ in range(copy_num)]
L_prev = 0
for L_i in Ls[:-1]:
    # Move to the end of the current chain, then offset everything after it.
    L_prev += L_i
    idx_res[L_prev:] += 200
sequence_features['residue_index'] = idx_res

In [5]:
class toy_dataset(Dataset):
    """Tiny in-memory dataset: 5 random rows of 10 features, each with a 0/1 label."""

    def __init__(self):
        n_items, n_features = 5, 10
        self.data = torch.randn(n_items, n_features)
        self.labels = torch.randint(0, 2, (n_items,))

    def __len__(self):
        """Return the number of stored rows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at ``idx``."""
        features = self.data[idx]
        label = self.labels[idx]
        return features, label


dataset = toy_dataset()

In [6]:
# Index 0 carries almost all of the probability mass; the other four share the rest.
weights = [0.9] + [0.01] * 4
sampler = torch.utils.data.sampler.WeightedRandomSampler(
    weights=weights,
    num_samples=len(dataset),
    replacement=False,
)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=2, drop_last=False)
# Note: WeightedRandomSampler is a plain sampler that yields one index at a time,
# so batch_size and drop_last can be set on the DataLoader. A hand-written
# batch_sampler must yield a whole batch of indices itself, and batch_size
# cannot then be passed at the DataLoader level.

In [7]:
import torch
from torch import Tensor
class WeightedRandomSampler(Sampler[int]):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Unlike the sampler of the same name shipped with PyTorch, this variant yields
    every drawn index wrapped in a one-element list, so it is meant to be passed
    to ``DataLoader(batch_sampler=...)`` and produces batches of size one.

    Args:
        weights (sequence)   : a sequence of weights, not necessary summing up to one.
            It needs one entry per element that may be drawn.
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
        generator (Generator): Generator used in sampling.

    Raises:
        ValueError: if ``num_samples`` is not a positive integer, if
            ``replacement`` is not a bool, or if ``weights`` is not one-dimensional.

    Example:
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [[4], [4], [1], [4], [5]]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [[0], [1], [4], [3], [2]]
    """
    # NOTE(review): the type parameter says ``int`` although ``__iter__`` yields
    # one-element lists; the header is left untouched to avoid a new typing import.

    weights: Tensor
    num_samples: int
    replacement: bool

    def __init__(self, weights, num_samples: int,
                 replacement: bool = True, generator=None) -> None:
        # bool is a subclass of int, so it has to be rejected explicitly.
        if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError(f"num_samples should be a positive integer value, but got num_samples={num_samples}")
        if not isinstance(replacement, bool):
            raise ValueError(f"replacement should be a boolean value, but got replacement={replacement}")

        weights_tensor = torch.as_tensor(weights, dtype=torch.double)
        if len(weights_tensor.shape) != 1:
            raise ValueError("weights should be a 1d sequence but given "
                             f"weights have shape {tuple(weights_tensor.shape)}")

        self.weights = weights_tensor
        self.num_samples = num_samples
        self.replacement = replacement
        self.generator = generator

    def __iter__(self):
        """Yield ``num_samples`` single-index batches, each a ``[int]`` list."""
        # All indices are drawn in a single multinomial call, so the whole
        # visiting order is fixed as soon as iteration starts.
        rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
        # Convert to plain Python ints: iterating the tensor directly would hand
        # out 0-d tensors, which only work as indices for some dataset types.
        for index in rand_tensor.tolist():
            yield [index]

    def __len__(self) -> int:
        """Return the number of batches produced per iteration."""
        return self.num_samples


In [8]:
weights = [0.9] + [0.01] * 4
# Once the sampler is defined, the order in which the DataLoader visits the
# data is decided when the for loop starts iterating it.
sampler = WeightedRandomSampler(weights, num_samples=len(dataset), replacement=False)
dataloader = DataLoader(dataset=dataset, batch_sampler=sampler)


In [9]:
# Print every collated batch the DataLoader produces (one [data, labels] pair
# per iteration).
for i in dataloader:
    print(i)

[tensor([[-0.9735,  0.0124, -1.3516,  0.7601, -0.6293,  0.5954, -0.2301, -0.2935,
         -0.6909, -0.0957]]), tensor([0])]
[tensor([[-0.6128, -1.4545, -0.0640,  1.0201, -0.9541, -0.0118,  0.3479,  0.7788,
         -2.0531,  0.1543]]), tensor([0])]
[tensor([[ 8.8880e-04,  8.8777e-01,  1.7839e+00, -1.1767e+00, -1.9333e-01,
          2.0674e+00,  1.4307e-01,  7.7601e-01, -1.1393e+00, -6.4407e-01]]), tensor([0])]
[tensor([[ 0.2018, -1.4598, -0.2240, -1.3986, -0.6045, -0.1464,  2.2327,  0.8451,
         -0.6582, -0.3910]]), tensor([0])]
[tensor([[ 0.2854, -0.3968, -0.0665,  1.6820, -0.3343, -1.2104, -0.3105,  1.1500,
          0.0115, -0.1329]]), tensor([0])]


In [None]:
class AccedingSequenceLengthSampler(Sampler[int]):  # use as DataLoader(sampler=...)
    """Yield dataset indices ordered from the shortest item to the longest."""

    def __init__(self, data: List[str]) -> None:
        self.data = data

    def __len__(self) -> int:
        """Return the number of indices produced, one per item."""
        return len(self.data)

    def __iter__(self) -> Iterator[int]:
        lengths = torch.tensor([len(item) for item in self.data])
        ascending = torch.argsort(lengths).tolist()
        for index in ascending:
            yield index

class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):  # use as DataLoader(batch_sampler=...)
    """Yield batches of indices, sorted by item length and split into chunks."""

    def __init__(self, data: List[str], batch_size: int) -> None:
        self.data = data
        self.batch_size = batch_size

    def __len__(self) -> int:
        """Return the number of batches: ``len(data) / batch_size`` rounded up."""
        padded_total = len(self.data) + self.batch_size - 1
        return padded_total // self.batch_size

    def __iter__(self) -> Iterator[List[int]]:
        lengths = torch.tensor([len(item) for item in self.data])
        ascending = torch.argsort(lengths)
        # Split the length-sorted indices into ``len(self)`` chunks.
        yield from (chunk.tolist() for chunk in torch.chunk(ascending, len(self)))

In [13]:
# mypy: allow-untyped-defs
from typing import (
    Generic,
    Iterable,
    Iterator,
    List,
    Optional,
    Sequence,
    Sized,
    TypeVar,
    Union,
)

import torch


# Names exported by ``from <module> import *``.
# NOTE(review): this list was copied along with the sampler source; not every
# name in it (e.g. "SubsetRandomSampler") is defined in the visible code.
__all__ = [
    "BatchSampler",
    "RandomSampler",
    "Sampler",
    "SequentialSampler",
    "SubsetRandomSampler",
    "WeightedRandomSampler",
]


_T_co = TypeVar("_T_co", covariant=True)


class Sampler(Generic[_T_co]):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices or lists of indices (batches) of dataset elements,
    and may provide a :meth:`__len__` method that returns the length of the returned iterators.

    Args:
        data_source (Dataset): This argument is not used and will be removed in 2.2.0.
            You may still have custom implementation that utilizes it.
            Passing anything other than ``None`` only triggers a warning.

    Example:
        >>> # xdoctest: +SKIP
        >>> class AccedingSequenceLengthSampler(Sampler[int]):
        >>>     def __init__(self, data: List[str]) -> None:
        >>>         self.data = data
        >>>
        >>>     def __len__(self) -> int:
        >>>         return len(self.data)
        >>>
        >>>     def __iter__(self) -> Iterator[int]:
        >>>         sizes = torch.tensor([len(x) for x in self.data])
        >>>         yield from torch.argsort(sizes).tolist()
        >>>
        >>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
        >>>     def __init__(self, data: List[str], batch_size: int) -> None:
        >>>         self.data = data
        >>>         self.batch_size = batch_size
        >>>
        >>>     def __len__(self) -> int:
        >>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
        >>>
        >>>     def __iter__(self) -> Iterator[List[int]]:
        >>>         sizes = torch.tensor([len(x) for x in self.data])
        >>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
        >>>             yield batch.tolist()

    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """

    def __init__(self, data_source: Optional[Sized] = None) -> None:
        if data_source is not None:
            import warnings

            # The two literals are concatenated, so the first one must end with
            # a space; otherwise the message reads "2.2.0.You may".
            warnings.warn(
                "`data_source` argument is not used and will be removed in 2.2.0. "
                "You may still have custom implementation that utilizes it."
            )

    def __iter__(self) -> Iterator[_T_co]:
        """Subclasses must override this to yield indices or lists of indices."""
        raise NotImplementedError

    # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    #
    # Many times we have an abstract class representing a collection/iterable of
    # data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally
    # implementing a `__len__` method. In such cases, we must make sure to not
    # provide a default implementation, because both straightforward default
    # implementations have their issues:
    #
    #   + `return NotImplemented`:
    #     Calling `len(subclass_instance)` raises:
    #       TypeError: 'NotImplementedType' object cannot be interpreted as an integer
    #
    #   + `raise NotImplementedError`:
    #     This prevents triggering some fallback behavior. E.g., the built-in
    #     `list(X)` tries to call `len(X)` first, and executes a different code
    #     path if the method is not found or `NotImplemented` is returned, while
    #     raising a `NotImplementedError` will propagate and make the call fail
    #     where it could have used `__iter__` to complete the call.
    #
    # Thus, the only two sensible things to do are
    #
    #   + **not** provide a default `__len__`.
    #
    #   + raise a `TypeError` instead, which is what Python uses when users call
    #     a method that is not defined on an object.
    #     (@ssnl verifies that this works on at least Python 3.7.)


class SequentialSampler(Sampler[int]):
    r"""Yield the indices ``0 .. len(data_source) - 1`` in increasing order.

    The order is the same on every iteration.

    Args:
        data_source (Dataset): dataset to sample from
    """

    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __len__(self) -> int:
        return len(self.data_source)

    def __iter__(self) -> Iterator[int]:
        # The size is read at iteration time, not at construction time.
        total = len(self.data_source)
        return iter(range(total))
# mypy: allow-untyped-defs
from typing import (
    Generic,
    Iterable,
    Iterator,
    List,
    Optional,
    Sequence,
    Sized,
    TypeVar,
    Union,
)

import torch


# Names exported by ``from <module> import *``.
# NOTE(review): this list was copied along with the sampler source; not every
# name in it (e.g. "SubsetRandomSampler") is defined in the visible code.
__all__ = [
    "BatchSampler",
    "RandomSampler",
    "Sampler",
    "SequentialSampler",
    "SubsetRandomSampler",
    "WeightedRandomSampler",
]


_T_co = TypeVar("_T_co", covariant=True)


class Sampler(Generic[_T_co]):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices or lists of indices (batches) of dataset elements,
    and may provide a :meth:`__len__` method that returns the length of the returned iterators.

    Args:
        data_source (Dataset): This argument is not used and will be removed in 2.2.0.
            You may still have custom implementation that utilizes it.
            Passing anything other than ``None`` only triggers a warning.

    Example:
        >>> # xdoctest: +SKIP
        >>> class AccedingSequenceLengthSampler(Sampler[int]):
        >>>     def __init__(self, data: List[str]) -> None:
        >>>         self.data = data
        >>>
        >>>     def __len__(self) -> int:
        >>>         return len(self.data)
        >>>
        >>>     def __iter__(self) -> Iterator[int]:
        >>>         sizes = torch.tensor([len(x) for x in self.data])
        >>>         yield from torch.argsort(sizes).tolist()
        >>>
        >>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
        >>>     def __init__(self, data: List[str], batch_size: int) -> None:
        >>>         self.data = data
        >>>         self.batch_size = batch_size
        >>>
        >>>     def __len__(self) -> int:
        >>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
        >>>
        >>>     def __iter__(self) -> Iterator[List[int]]:
        >>>         sizes = torch.tensor([len(x) for x in self.data])
        >>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
        >>>             yield batch.tolist()

    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """

    def __init__(self, data_source: Optional[Sized] = None) -> None:
        if data_source is not None:
            import warnings

            # The two literals are concatenated, so the first one must end with
            # a space; otherwise the message reads "2.2.0.You may".
            warnings.warn(
                "`data_source` argument is not used and will be removed in 2.2.0. "
                "You may still have custom implementation that utilizes it."
            )

    def __iter__(self) -> Iterator[_T_co]:
        """Subclasses must override this to yield indices or lists of indices."""
        raise NotImplementedError

    # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    #
    # Many times we have an abstract class representing a collection/iterable of
    # data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally
    # implementing a `__len__` method. In such cases, we must make sure to not
    # provide a default implementation, because both straightforward default
    # implementations have their issues:
    #
    #   + `return NotImplemented`:
    #     Calling `len(subclass_instance)` raises:
    #       TypeError: 'NotImplementedType' object cannot be interpreted as an integer
    #
    #   + `raise NotImplementedError`:
    #     This prevents triggering some fallback behavior. E.g., the built-in
    #     `list(X)` tries to call `len(X)` first, and executes a different code
    #     path if the method is not found or `NotImplemented` is returned, while
    #     raising a `NotImplementedError` will propagate and make the call fail
    #     where it could have used `__iter__` to complete the call.
    #
    # Thus, the only two sensible things to do are
    #
    #   + **not** provide a default `__len__`.
    #
    #   + raise a `TypeError` instead, which is what Python uses when users call
    #     a method that is not defined on an object.
    #     (@ssnl verifies that this works on at least Python 3.7.)


# If no shuffle
class SequentialSampler(Sampler[int]):
    r"""Yield the indices ``0 .. len(data_source) - 1`` in increasing order.

    The order is the same on every iteration.

    Args:
        data_source (Dataset): dataset to sample from
    """

    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __len__(self) -> int:
        return len(self.data_source)

    def __iter__(self) -> Iterator[int]:
        # The size is read at iteration time, not at construction time.
        total = len(self.data_source)
        return iter(range(total))


# If shuffle
class RandomSampler(Sampler[int]):
    r"""Draw dataset indices in random order.

    Without replacement, indices come from successive random permutations of
    the dataset, so :attr:`num_samples` may exceed the dataset size. With
    replacement, each index is drawn independently from ``[0, len(data_source))``.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`.
        generator (Generator): Generator used in sampling.

    Raises:
        TypeError: if ``replacement`` is not a bool.
        ValueError: if the effective ``num_samples`` is not a positive integer.
    """

    data_source: Sized
    replacement: bool

    def __init__(
        self,
        data_source: Sized,
        replacement: bool = False,
        num_samples: Optional[int] = None,
        generator=None,
    ) -> None:
        # Attributes are stored first because validation reads them back
        # through the ``num_samples`` property.
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

        if not isinstance(self.replacement, bool):
            raise TypeError(
                f"replacement should be a boolean value, but got replacement={self.replacement}"
            )

        requested = self.num_samples
        if not (isinstance(requested, int) and requested > 0):
            raise ValueError(
                f"num_samples should be a positive integer value, but got num_samples={self.num_samples}"
            )

    @property
    def num_samples(self) -> int:
        """Explicit sample count if one was given, else the current dataset size."""
        # dataset size might change at runtime
        if self._num_samples is not None:
            return self._num_samples
        return len(self.data_source)

    def __iter__(self) -> Iterator[int]:
        size = len(self.data_source)
        rng = self.generator
        if rng is None:
            # No generator supplied: seed a private one from the global RNG.
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            rng = torch.Generator()
            rng.manual_seed(seed)

        if self.replacement:
            # Draw in blocks of 32, then one final block for the remainder.
            for _ in range(self.num_samples // 32):
                block = torch.randint(
                    high=size, size=(32,), dtype=torch.int64, generator=rng
                )
                yield from block.tolist()
            tail = torch.randint(
                high=size,
                size=(self.num_samples % 32,),
                dtype=torch.int64,
                generator=rng,
            )
            yield from tail.tolist()
            return

        # Without replacement: whole permutations first, then a prefix of one more.
        for _ in range(self.num_samples // size):
            yield from torch.randperm(size, generator=rng).tolist()
        last_perm = torch.randperm(size, generator=rng).tolist()
        yield from last_perm[: self.num_samples % size]

    def __len__(self) -> int:
        return self.num_samples


In [39]:
# The samplers only ever call len() on the data source, so only the leading
# dimension (100) of this tensor matters to them.
dataset = torch.randn(100, 1000, 1000)
sequential_sampler = SequentialSampler(data_source=dataset)
# num_samples=101 is one more than the dataset size: without replacement this
# yields one full permutation plus the first index of a second one.
random_sampler = RandomSampler(data_source=dataset, replacement=False, num_samples=101)

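Python's built-in cache decorator (`functools.lru_cache`) requires the function's arguments to be hashable, so unhashable inputs must be converted first (e.g. list -> tuple).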

hashable (immutable) : number, string, tuple

unhashable (mutable) : list, set, dict

In [61]:
max_squared_res = 1e6
max_len = 455
# Floor of the budget divided by the squared length of one example.
max_batch_examples = int(max_squared_res // (max_len * max_len))
print(max_batch_examples)

4


In [59]:
x = torch.randn(100)
pad_idx = 0
max_len = 512

# Right-pad only the `pad_idx` axis up to `max_len`; every other axis gets (0, 0).
seq_len = x.size(pad_idx)
pad_amt = max_len - seq_len
pad_widths = [(0, pad_amt) if axis == pad_idx else (0, 0) for axis in range(x.ndim)]
np.pad(x, pad_widths).shape

(512,)

In [None]:
# Floor of max_squared_res divided by max_len squared, as an int.
max_batch_examples = int(max_squared_res // max_len**2)

In [46]:
import functools as fn
import numpy as np

# Memoised with a 3-entry LRU cache, so the argument must be hashable
# (a tuple rather than a list).
@fn.lru_cache(maxsize=3)
def compute_square(n):
    """
    Square every element of the input sequence.

    A message is printed each time the body actually runs, which makes
    cache misses visible.

    Args:
        n (tuple): A tuple of integers.

    Returns:
        list: The squares of the input integers, in the same order.
    """
    print(f"Computing square of {n}")
    squared = np.square(n)
    return squared.tolist()

# Function calls with tuples instead of lists (lists are unhashable and cannot be cache keys)
print(compute_square((2, 3)))  # Computes and caches result
print(compute_square((3, 4)))  # Computes and caches result
print(compute_square((2, 3)))  # Returns cached result; (2, 3) becomes the most recently used entry
print(compute_square((4, 5)))  # Computes and caches result; cache now holds 3 entries (full)
print(compute_square((5, 6)))  # Computes and caches result, evicts the least recently used entry, (3, 4)
print(compute_square((4, 5)))  # Returns cached result


Computing square of (2, 3)
[4, 9]
Computing square of (3, 4)
[9, 16]
[4, 9]
Computing square of (4, 5)
[16, 25]
Computing square of (5, 6)
[25, 36]
[16, 25]


In [3]:
# Create the iterator once; each later next() call continues from where the
# previous one stopped.
sampler_iter = iter(sequential_sampler)

In [11]:
# Pull one batch of indices by hand from the shared sampler iterator.
batch_size = 10
# Use batch_size rather than a repeated literal so the two cannot drift apart.
batch = [next(sampler_iter) for _ in range(batch_size)]
print(batch)

[50, 51, 52, 53, 54, 55, 56, 57, 58, 59]


In [7]:
# Display the current value of `batch` (notebook echo of the last expression).
batch

[10, 11, 12, 13, 14, 15, 16, 17, 18, 19]

In [None]:
class BatchSampler(Sampler[List[int]]):
    r"""Group the indices produced by another sampler into mini-batches.

    Args:
        sampler (Sampler or Iterable): Base sampler. Can be any iterable object
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, a trailing batch with fewer than
            ``batch_size`` indices is discarded

    Raises:
        ValueError: if ``batch_size`` is not a positive integer or
            ``drop_last`` is not a bool.

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(
        self,
        sampler: Union[Sampler[int], Iterable[int]],
        batch_size: int,
        drop_last: bool,
    ) -> None:
        # `sampler` is deliberately not type-checked: collections.abc.Iterable
        # ignores `__getitem__`, which is also a valid way to be iterable.
        size_is_valid = (
            isinstance(batch_size, int)
            and not isinstance(batch_size, bool)
            and batch_size > 0
        )
        if not size_is_valid:
            raise ValueError(
                f"batch_size should be a positive integer value, but got batch_size={batch_size}"
            )
        if not isinstance(drop_last, bool):
            raise ValueError(
                f"drop_last should be a boolean value, but got drop_last={drop_last}"
            )
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
        if self.drop_last:
            # Important: the iterator is created exactly once and then shared
            # by every batch.
            source = iter(self.sampler)
            while True:
                try:
                    full_batch = [next(source) for _ in range(self.batch_size)]
                    yield full_batch
                except StopIteration:
                    # The sampler ran dry mid-batch: the partial batch is dropped.
                    return

        # A for loop also creates the iterator once, i.e. it is equivalent to
        # `for index in iter(self.sampler)`.
        slots = [0] * self.batch_size
        filled = 0
        for index in self.sampler:
            slots[filled] = index
            filled += 1
            if filled == self.batch_size:
                yield slots
                filled = 0
                slots = [0] * self.batch_size
        if filled > 0:
            yield slots[:filled]

    def __len__(self) -> int:
        # Can only be called if self.sampler has __len__ implemented
        # We cannot enforce this condition, so we turn off typechecking for the
        # implementation below.
        # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
        total = len(self.sampler)  # type: ignore[arg-type]
        if self.drop_last:
            return total // self.batch_size
        return (total + self.batch_size - 1) // self.batch_size

next() : A built-in function used to return the next item in an iterator. 

iter() : A built-in function used to convert an iterable to an iterator. 

yield : A Python keyword similar to return, except that a function containing yield returns a generator when called; each yield hands back one value and pauses the function until the next value is requested.

A generator is lazily evaluated: it produces items one at a time, so the data does not all need to be loaded up front.

In [24]:
class test_naive_iterator:
    """Toy iterable whose ``__iter__`` returns a plain ``range`` iterator.

    ``for x in obj`` always starts a fresh pass over ``0 .. len(data) - 1``.
    ``next(obj)`` walks through the same indices using one internal iterator
    that is created on first use.
    """

    def __init__(self):
        self.data = torch.randn(10, 10)
        self.labels = torch.randint(0, 2, (10,))
        # Backing iterator for ``next(obj)``; created lazily in ``__next__``.
        self._cursor = None

    def __iter__(self):
        """Return a new iterator over the row indices."""
        return iter(range(len(self.data)))

    def __next__(self):
        """Return the next row index, raising ``StopIteration`` when exhausted."""
        # Building a new iterator on every call would restart at 0 each time,
        # so ``next(obj)`` could never advance or finish. Reuse one iterator.
        if self._cursor is None:
            self._cursor = iter(self)
        return next(self._cursor)

class test_generator_iterator:
    """Toy iterable whose ``__iter__`` is a generator function.

    ``for x in obj`` always starts a fresh pass over ``0 .. len(data) - 1``.
    ``next(obj)`` walks through the same indices using one internal generator
    that is created on first use.
    """

    def __init__(self):
        self.data = torch.randn(10, 10)
        self.labels = torch.randint(0, 2, (10,))
        # Backing generator for ``next(obj)``; created lazily in ``__next__``.
        self._cursor = None

    def __iter__(self):
        """Lazily yield the row indices, one at a time."""
        yield from range(len(self.data))

    def __next__(self):
        """Return the next row index, raising ``StopIteration`` when exhausted."""
        # Building a new generator on every call would restart at 0 each time,
        # so ``next(obj)`` could never advance or finish. Reuse one generator.
        if self._cursor is None:
            self._cursor = iter(self)
        return next(self._cursor)

In [25]:
# One instance of each toy class, to compare how they behave in a for loop.
naive = test_naive_iterator()
generator = test_generator_iterator()


In [26]:
# The for loop calls naive.__iter__() once and consumes the returned iterator.
for i in naive:
    print(i)

0
1
2
3
4
5
6
7
8
9


In [27]:
# The for loop calls generator.__iter__() once and consumes the generator it returns.
for i in generator:
    print(i)

0
1
2
3
4
5
6
7
8
9
