In [14]:
import torch
from itertools import cycle, zip_longest
import math
from torch.utils.data import Dataset, DataLoader, Sampler, TensorDataset
import numpy as np
import random
# Seed PyTorch's global RNG. Note that the samplers defined below shuffle with
# Python's `random` module, which GroupSampler.__init__ seeds on construction
# via `random.seed(random_seed)`; this torch seed does not affect those shuffles.
torch.manual_seed(0)
# random.seed(0)

<torch._C.Generator at 0x7f181c09a930>

In [15]:
class GWASDataset(Dataset):
    """Map-style dataset pairing a data tensor with a label tensor.

    Args:
        data: Sample data. A ``torch.Tensor`` is stored as passed (no copy,
            no dtype change); anything else is converted with
            ``torch.tensor(..., dtype=torch.float)``.
        labels: One label per sample, handled as follows:
            * ``data`` is a tensor and ``labels`` is a tensor: stored as passed.
            * ``data`` is a tensor and ``labels`` is not: converted to a
              float tensor (previously stored unconverted, which broke
              tensor comparisons such as ``dataset.labels == 1``).
            * ``data`` is not a tensor: converted to a float tensor.
    """

    def __init__(self, data, labels):
        if isinstance(data, torch.Tensor):
            self.data = data
            # Labels can still arrive as a list/ndarray even when the data is
            # already a tensor; convert them so `labels` is always a tensor.
            if isinstance(labels, torch.Tensor):
                self.labels = labels
            else:
                self.labels = torch.tensor(labels, dtype=torch.float)
        else:
            self.data = torch.tensor(data, dtype=torch.float)
            self.labels = torch.tensor(labels, dtype=torch.float)

    def __len__(self):
        """Return the size of the first dimension of ``data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(data, label)`` pair selected by ``idx``.

        ``idx`` is forwarded unchanged to tensor indexing on both ``data``
        and ``labels``, so an index tensor selects several rows at once.
        """
        return (self.data[idx], self.labels[idx])
    
class GroupSampler(Sampler):
    """Yield shuffled, non-overlapping groups of positions into a data source.

    Every iteration draws a fresh permutation of ``range(len(data_source))``
    and emits it in consecutive slices of ``grp_size`` positions. Positions
    that do not fill a whole group are dropped for that pass.

    Args:
        data_source: Any sized object; only its length is used.
        grp_size: Number of positions per yielded group.
        random_seed: Passed to ``random.seed``. This reseeds Python's global
            ``random`` module rather than a generator private to the sampler.
    """

    def __init__(self, data_source, grp_size, random_seed):
        self.data_source = data_source
        self.grp_size = grp_size
        # Count of positions that fit into whole groups.
        self.data_size = len(self) * grp_size
        random.seed(random_seed)

    def __iter__(self):
        permutation = list(range(len(self.data_source)))
        random.shuffle(permutation)
        width = self.grp_size
        yield from (
            permutation[lo:lo + width]
            for lo in range(0, self.data_size, width)
        )

    def __len__(self):
        """Number of whole groups produced per pass."""
        n_items = len(self.data_source)
        return n_items // self.grp_size

class BalancedBatchGroupSampler(Sampler):
    """Batch sampler that pairs one case group with one control group.

    Samples are split by label (``1`` = case, ``0`` = control) and each class
    is cut into random groups of ``grp_size`` by its own ``GroupSampler``.
    One pass walks the control groups once; case groups are recycled with
    ``itertools.cycle`` for as long as control groups remain.

    Each yielded batch is a list of index tensors, one tensor per group.
    ``batch_size`` therefore counts groups, not individual samples.

    Args:
        dataset: Object exposing a ``labels`` tensor.
        batch_size: Number of groups per batch; must be even because groups
            are added in case/control pairs.
        grp_size: Number of sample indices per group.
        random_seed: Forwarded to both ``GroupSampler`` instances.
    """

    def __init__(self, dataset, batch_size, grp_size, random_seed):
        # Dataset positions of the case and control samples.
        self.case_idxs = torch.where(dataset.labels==1)[0]
        self.cont_idxs = torch.where(dataset.labels==0)[0]

        # Groups are appended two at a time, so only an even batch_size can
        # ever be matched exactly by len(batch) in __iter__.
        assert batch_size % 2 == 0
        self.batch_size = batch_size

        self.case_sampler = GroupSampler(self.case_idxs,
                                         grp_size=grp_size,
                                         random_seed=random_seed)
        self.cont_sampler = GroupSampler(self.cont_idxs,
                                         grp_size=grp_size,
                                         random_seed=random_seed)

    def __iter__(self):
        batch = []
        # zip stops when the control groups run out. cycle() replays the case
        # groups it saw on its first pass, and yields nothing at all when
        # there are no case groups, which ends the loop immediately.
        for case, cont in zip(cycle(self.case_sampler), self.cont_sampler):
            # The samplers yield positions into case_idxs / cont_idxs; map
            # them back to dataset indices.
            batch.append(self.case_idxs[case])
            batch.append(self.cont_idxs[cont])
            if len(batch) == self.batch_size:
                random.shuffle(batch)
                yield batch
                batch = []

        # Trailing partial batch: yielded in case/control order, unshuffled.
        if batch:
            yield batch

    def __len__(self):
        """Number of batches one pass of ``__iter__`` yields."""
        # Without any whole case group __iter__ yields nothing, so report 0
        # instead of a count derived from the control groups alone.
        if len(self.case_sampler) == 0:
            return 0
        return int(math.ceil((len(self.cont_sampler)*2)/self.batch_size))



In [25]:
# Synthetic sanity-check data. Column 0 of `x` holds each sample's own row
# index, so the rows that end up in a batch can be read straight off the
# batch tensor; the remaining columns are all ones.
N_CASES = 1000
N_CONTROLS = 5000
N_SAMPLES = N_CASES + N_CONTROLS  # derived so the data and labels always agree in length
N_FEATURES = 4

iids = torch.arange(N_SAMPLES)
vals = torch.ones((N_SAMPLES, N_FEATURES))
x = torch.column_stack((iids, vals))
# Cases (label 1) occupy rows 0..N_CASES-1; controls (label 0) follow.
labels = torch.cat((torch.ones(N_CASES), torch.zeros(N_CONTROLS)))

dataset = GWASDataset(x, labels)
# batch_size counts groups: each batch holds up to 32 groups of 4 indices.
batch_sampler = BalancedBatchGroupSampler(dataset, batch_size=32, grp_size=4, random_seed=0)
dataloader = DataLoader(dataset, batch_sampler=batch_sampler)

In [26]:
# Batches per epoch, as computed by BalancedBatchGroupSampler.__len__:
# ceil(2 * (5000 // 4) / 32) = ceil(78.125) = 79.
len(dataloader)

79

In [27]:
# Walk the loader for five full epochs and print column 0 (the row ids) of the
# first group in each epoch's first batch, to show the order changes per epoch.
for _ in range(5):
    for i, d in enumerate(dataloader):
        if i != 0:
            # Keep consuming the epoch; only the first batch is printed.
            continue
        print(d[0][0, :, 0])
        

tensor([ 34.,  54., 399., 559.])
tensor([157.,  30., 768., 370.])
tensor([124., 534., 666., 614.])
tensor([227., 643., 289., 963.])
tensor([837., 200., 198., 284.])
