In [None]:
import torch
from torch.utils.data import Dataset, DataLoader, Sampler

# for time series subsampling
class TimeSeriesSubsampleDataset(Dataset):
    """Sliding-window dataset over a single sequence.

    Item ``i`` is the pair ``(timeseries[i:i + input_len],
    timeseries[i + input_len:i + input_len + target_len])``, i.e. an input
    window immediately followed by its target window.

    Args:
        timeseries: Any sliceable object supporting ``len()`` (list, tensor, ...).
        input_len: Number of elements in each input window.
        target_len: Number of elements in each target window.
    """

    def __init__(self, timeseries, input_len, target_len):
        self.timeseries = timeseries
        self.input_len = input_len
        self.target_len = target_len
        self.total_len = input_len + target_len

    def __len__(self):
        """Return the number of complete (input, target) windows.

        Clamped at zero: a series shorter than one full window holds no
        samples, and ``len()`` rejects negative values.
        """
        return max(0, len(self.timeseries) - self.total_len + 1)

    def __getitem__(self, idx):
        """Return the ``(input_seq, target_seq)`` pair starting at window ``idx``.

        Negative indices count from the last window.

        Raises:
            IndexError: If ``idx`` does not address a complete window. Raising
                (rather than returning a truncated slice) keeps batches
                uniformly shaped and lets plain iteration terminate.
        """
        n_windows = len(self)
        # int() also accepts 0-dim integer tensors handed over by samplers.
        start = int(idx)
        if start < 0:
            start += n_windows
        if not 0 <= start < n_windows:
            raise IndexError(
                f"window index {idx} out of range for {n_windows} windows"
            )
        split = start + self.input_len
        input_seq = self.timeseries[start:split]
        target_seq = self.timeseries[split:start + self.total_len]
        return input_seq, target_seq
  
  
# for time series forecasting
class TimeseriesSubsampler(Sampler):
    """Uniform random sampler that draws a subset of dataset indices.

    Args:
        data_source: Sized object whose indices ``0 .. len - 1`` are sampled.
        num_samples: Number of indices to draw per pass; a falsy value
            (``None`` or ``0``) means ``len(data_source)``.
        replacement: If truthy, indices are drawn independently and may
            repeat; otherwise a prefix of a random permutation is used, so
            indices are unique.
    """

    def __init__(self, data_source, num_samples=None, replacement=None):
        self.data_source = data_source
        self.data_len = len(data_source)
        self.num_samples = num_samples or len(data_source)
        self.replacement = replacement

    def _draw_count(self):
        """Return how many indices one pass actually yields."""
        if self.data_len == 0:
            return 0
        if self.replacement:
            return self.num_samples
        # Without replacement there are at most data_len distinct indices.
        return min(self.num_samples, self.data_len)

    def __iter__(self):
        """Yield the sampled indices as plain Python ints."""
        count = self._draw_count()
        if count == 0:
            # Also avoids torch.randint with an empty [0, 0) range.
            return iter([])
        if self.replacement:
            idx = torch.randint(
                low=0,
                high=self.data_len,
                size=(count,),
                dtype=torch.int64,
            )
        else:
            idx = torch.randperm(self.data_len)[:count]
        # tolist() yields ints; iterating the tensor would yield 0-dim tensors.
        return iter(idx.tolist())

    def __len__(self):
        """Return the number of indices produced by one pass of ``__iter__``."""
        return self._draw_count()
        
# for time series classification
class StratifiedTimeseriesSubsampler(Sampler):
    """Random sampler that keeps each label's share of the drawn indices.

    Every label (stratum) receives a quota proportional to its frequency in
    ``stratify``; indices are drawn uniformly inside each stratum and the
    combined result is shuffled.

    Args:
        data_source: Sized object whose indices ``0 .. len - 1`` are sampled.
        num_samples: Total number of indices to draw per pass; a falsy value
            (``None`` or ``0``) means ``len(data_source)``.
        replacement: If truthy, indices may repeat within a stratum;
            otherwise each index is drawn at most once.
        stratify: One non-negative integer label per dataset item (tensor,
            list, or anything ``torch.as_tensor`` accepts).

    Raises:
        ValueError: If ``stratify`` is missing or its length differs from
            ``len(data_source)``.
    """

    def __init__(self, data_source, num_samples=None, replacement=None, stratify=None):
        self.data_source = data_source
        self.data_len = len(data_source)
        self.num_samples = num_samples or len(data_source)
        self.replacement = replacement

        if stratify is None or len(stratify) != self.data_len:
            raise ValueError("Stratify must be a list of the same length as the dataset")
        # Normalise to a long tensor so lists work as the error message promises.
        self.stratify = torch.as_tensor(stratify, dtype=torch.long)

        # Group indices by label, then fix how many to draw from each group.
        self.strata_indices = self._compute_strata_indices()
        self.strata_sample_counts = self._compute_sample_counts()

    def _compute_strata_indices(self):
        """Map each label to the list of dataset indices carrying it."""
        strata_indices = {}
        for idx, label in enumerate(self.stratify.tolist()):
            strata_indices.setdefault(label, []).append(idx)
        return strata_indices

    def _compute_sample_counts(self):
        """Return a long tensor of per-label quotas, indexed by label value.

        Quotas use largest-remainder apportionment in exact integer
        arithmetic, so they sum to ``num_samples`` instead of losing samples
        to per-stratum flooring. Without replacement each quota is capped at
        the stratum size, because no more distinct indices exist.
        """
        sizes = torch.bincount(self.stratify).tolist()
        total = sum(sizes)
        if total == 0:
            return torch.zeros(len(sizes), dtype=torch.long)

        wanted = self.num_samples
        quotas = [size * wanted // total for size in sizes]
        shortfall = wanted - sum(quotas)
        # Hand the leftover samples to the strata with the largest fractional
        # share; ties go to the smaller label so the result is deterministic.
        by_remainder = sorted(
            range(len(sizes)),
            key=lambda k: (-(sizes[k] * wanted % total), k),
        )
        for k in by_remainder[:shortfall]:
            quotas[k] += 1

        if not self.replacement:
            quotas = [min(quota, size) for quota, size in zip(quotas, sizes)]
        return torch.tensor(quotas, dtype=torch.long)

    def __iter__(self):
        """Yield the sampled indices as Python ints in shuffled order."""
        chosen = []
        for label, indices in self.strata_indices.items():
            quota = int(self.strata_sample_counts[label])
            if quota == 0:
                continue
            pool = torch.tensor(indices, dtype=torch.long)
            if self.replacement:
                # Uniform draw with repeats; positions index into the pool.
                positions = torch.randint(len(indices), (quota,))
            else:
                positions = torch.randperm(len(indices))[:quota]
            chosen.append(pool[positions])

        if not chosen:
            return iter([])
        merged = torch.cat(chosen)
        # Shuffle so consecutive batches are not grouped by label.
        return iter(merged[torch.randperm(len(merged))].tolist())

    def __len__(self):
        """Return the number of indices produced by one pass of ``__iter__``."""
        return int(self.strata_sample_counts.sum())

 
    
# test dataset and sampler
# Demo configuration: window sizes, batch size and how many windows to draw.
input_len = 10
target_len = 5
batch_size = 32
num_samples = 200

# Toy series 0..999 with three label groups of 300 / 400 / 300 items.
time_series = torch.arange(1000)
labels = torch.tensor([0]*300 + [1]*400 + [2]*300)

dataset = TimeSeriesSubsampleDataset(
    timeseries=time_series,
    input_len=input_len,
    target_len=target_len,
)

# Plain random sampler, shown for reference only: the name is rebound by the
# stratified sampler right below, so this one never reaches the DataLoader.
sampler = TimeseriesSubsampler(
    data_source=dataset,
    num_samples=num_samples,
    replacement=False,
)
# One label per window start, hence the slice down to len(dataset).
sampler = StratifiedTimeseriesSubsampler(
    data_source=dataset,
    num_samples=num_samples,
    replacement=False,
    stratify=labels[:len(dataset)],
)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, sampler=sampler)

# Print the first batch only.
for batch_input, batch_target in dataloader:
    print(f'Batch Inputs: {batch_input}')
    print(f'Batch Targets: {batch_target}')
    break

Batch Inputs: tensor([[714, 715, 716, 717, 718, 719, 720, 721, 722, 723],
        [678, 679, 680, 681, 682, 683, 684, 685, 686, 687],
        [573, 574, 575, 576, 577, 578, 579, 580, 581, 582],
        [190, 191, 192, 193, 194, 195, 196, 197, 198, 199],
        [756, 757, 758, 759, 760, 761, 762, 763, 764, 765],
        [891, 892, 893, 894, 895, 896, 897, 898, 899, 900],
        [340, 341, 342, 343, 344, 345, 346, 347, 348, 349],
        [241, 242, 243, 244, 245, 246, 247, 248, 249, 250],
        [527, 528, 529, 530, 531, 532, 533, 534, 535, 536],
        [ 60,  61,  62,  63,  64,  65,  66,  67,  68,  69],
        [558, 559, 560, 561, 562, 563, 564, 565, 566, 567],
        [792, 793, 794, 795, 796, 797, 798, 799, 800, 801],
        [510, 511, 512, 513, 514, 515, 516, 517, 518, 519],
        [455, 456, 457, 458, 459, 460, 461, 462, 463, 464],
        [646, 647, 648, 649, 650, 651, 652, 653, 654, 655],
        [283, 284, 285, 286, 287, 288, 289, 290, 291, 292],
        [470, 471, 472, 47