In [1]:
import torch
from itertools import cycle, zip_longest
import math
from torch.utils.data import Dataset, DataLoader, Sampler, TensorDataset
import numpy as np
import random
torch.manual_seed(0)
# random.seed(0)

  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7f32ec1a9930>

In [3]:
class GWASDataset(Dataset):
    """Map-style dataset that pairs a feature tensor with a label tensor.

    Args:
        data: Features, indexed along the first dimension. A ``torch.Tensor``
            is stored as-is; any other input is copied into a ``torch.float``
            tensor.
        labels: Labels aligned with ``data`` along the first dimension. Stored
            as-is only when both ``data`` and ``labels`` are already tensors;
            otherwise they end up as a ``torch.float`` tensor.
    """

    def __init__(self, data, labels):
        data_is_tensor = isinstance(data, torch.Tensor)
        labels_is_tensor = isinstance(labels, torch.Tensor)

        self.data = data if data_is_tensor else torch.tensor(data, dtype=torch.float)

        if data_is_tensor and labels_is_tensor:
            # Both inputs are ready to use; keep them (and their dtypes) untouched.
            self.labels = labels
        elif labels_is_tensor:
            # Data needed conversion, so labels are cast to float alongside it.
            # detach().clone() gives the same float copy torch.tensor() would,
            # without the copy-construct warning.
            self.labels = labels.detach().clone().to(torch.float)
        else:
            # Always convert non-tensor labels. This used to be skipped when
            # `data` was already a tensor, which left `self.labels` as a raw
            # list/array and broke tensor comparisons such as `labels == 1`.
            self.labels = torch.tensor(labels, dtype=torch.float)

    def __len__(self):
        """Return the number of entries along the first dimension of ``data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(data[idx], labels[idx])``.

        ``idx`` is forwarded unchanged to tensor indexing, so it may be a
        single integer or an index tensor selecting several rows at once.
        """
        return (self.data[idx], self.labels[idx])
    
class GroupSampler(Sampler):
    """Yield shuffled, non-overlapping groups of positions into ``data_source``.

    Every pass shuffles ``range(len(data_source))`` with the module-level
    ``random`` generator and hands out consecutive slices of ``grp_size``
    positions as lists. Positions that do not fill a complete group are
    left out of that pass.

    Args:
        data_source: Sized collection; only its length is used here.
        grp_size: Number of positions in each yielded group.
        random_seed: Forwarded to ``random.seed`` when the sampler is built.
            Note that this reseeds the process-wide ``random`` generator.
    """

    def __init__(self, data_source, grp_size, random_seed):
        self.grp_size = grp_size
        self.data_source = data_source
        # Count of positions that fall inside complete groups.
        full_groups = len(self)
        self.data_size = full_groups * grp_size
        random.seed(random_seed)

    def __len__(self):
        """Return how many complete groups one pass produces."""
        total = len(self.data_source)
        return total // self.grp_size

    def __iter__(self):
        """Shuffle all positions, then yield them one full group at a time."""
        order = list(range(len(self.data_source)))
        random.shuffle(order)
        yield from (
            order[start:start + self.grp_size]
            for start in range(0, self.data_size, self.grp_size)
        )

class BalancedBatchGroupSampler(Sampler):
    """Batch sampler that alternates case groups and control groups.

    Each yielded batch is a list of index tensors; every tensor holds the
    dataset indices of one group, and each batch contains as many case groups
    as control groups. One pass walks through the control groups once. Case
    groups are drawn through ``itertools.cycle``, so when there are fewer case
    groups than control groups the case groups from the first pass are
    replayed as-is.

    Args:
        dataset: Object exposing a ``labels`` tensor where ``1`` marks a case
            and ``0`` marks a control.
        batch_size: Number of *groups* (not samples) per batch; must be even.
        grp_size: Number of samples per group.
        random_seed: Seed forwarded to both underlying ``GroupSampler`` objects.
    """

    def __init__(self, dataset, batch_size, grp_size, random_seed):
        # Dataset indices of the cases (label == 1) and controls (label == 0).
        self.case_idxs = torch.where(dataset.labels==1)[0]
        self.cont_idxs = torch.where(dataset.labels==0)[0]

        assert batch_size % 2 == 0, "batch_size must be even (half case groups, half control groups)"
        self.batch_size = batch_size

        # The group samplers yield positions into case_idxs / cont_idxs,
        # which __iter__ maps back to dataset indices.
        self.case_sampler = GroupSampler(self.case_idxs,
                                         grp_size=grp_size,
                                         random_seed=random_seed)
        self.cont_sampler = GroupSampler(self.cont_idxs,
                                         grp_size=grp_size,
                                         random_seed=random_seed)

    def __iter__(self):
        """Yield batches of group index tensors until the controls run out."""
        batch = []
        # zip stops when the control groups are exhausted, or immediately if
        # there is no case group at all (cycle over an empty iterable is empty).
        for case, cont in zip(cycle(self.case_sampler), self.cont_sampler):
            batch.append(self.case_idxs[case])
            batch.append(self.cont_idxs[cont])
            if len(batch) == self.batch_size:
                # Mix case and control groups within the batch.
                random.shuffle(batch)
                yield batch
                batch = []

        # Emit the trailing, smaller batch if the groups did not divide evenly.
        if batch:
            random.shuffle(batch)
            yield batch

    def __len__(self):
        """Return the number of batches one pass of ``__iter__`` yields."""
        # Without a complete case group nothing can be paired, so __iter__
        # yields no batches; report that instead of a control-only count.
        if len(self.case_sampler) == 0:
            return 0
        return math.ceil((len(self.cont_sampler)*2)/self.batch_size)



In [17]:
# Toy data: column 0 carries the sample id and the other 9 columns are all
# ones, so printing column 0 of a batch shows which samples form each group.
iids = torch.arange(60)
vals = torch.ones(60, 9)
x = torch.column_stack([iids, vals])
# The first 30 samples are labelled 1 (cases), the last 30 are labelled 0 (controls).
labels = torch.cat([torch.ones(30), torch.zeros(30)])

dataset = GWASDataset(x, labels)
batch_sampler = BalancedBatchGroupSampler(
    dataset,
    batch_size=4,
    grp_size=5,
    random_seed=1,
)
dataloader = DataLoader(dataset, batch_sampler=batch_sampler)

print(len(dataloader))

# Layers used only by the disabled shape check inside the loop below.
c = torch.nn.Conv1d(10, 8, 1)
pool = torch.nn.AvgPool1d(5)
for epoch in range(5):
    print(f'Epoch: {epoch}')
    for i, d in enumerate(dataloader):
        # d[0] holds the batched features, d[1] the batched labels.
        print('Data groups', d[0][:, :, 0], sep='\n')
        print(f'Labels: {d[1][:, 0]}', end='\n\n')
        # Disabled shape check: convolve each group, then average-pool over it.
        # a = c(torch.transpose(d[0], 1, 2))
        # print(a.shape)
        # a = torch.squeeze(pool(a), dim=-1)
        # print(a.shape)
    print('\n')

3
Epoch: 0
Data groups
tensor([[39., 54., 52., 57., 34.],
        [56., 41., 32., 31., 38.],
        [ 1.,  5., 29.,  7., 20.],
        [26., 16., 11., 10., 23.]])
Labels: tensor([0., 0., 1., 1.])

Data groups
tensor([[49., 53., 40., 33., 35.],
        [ 9., 28., 17., 13.,  0.],
        [19., 22.,  6., 12., 21.],
        [48., 45., 44., 37., 46.]])
Labels: tensor([0., 1., 1., 0.])

Data groups
tensor([[14., 15.,  3.,  8.,  2.],
        [58., 47., 50., 59., 30.],
        [55., 43., 36., 51., 42.],
        [24., 25., 27., 18.,  4.]])
Labels: tensor([1., 0., 0., 1.])



Epoch: 1
Data groups
tensor([[59., 57., 49., 33., 31.],
        [ 4., 26., 22.,  9.,  3.],
        [48., 43., 32., 53., 54.],
        [ 0., 24., 23., 20.,  6.]])
Labels: tensor([0., 1., 0., 1.])

Data groups
tensor([[21., 28., 25.,  8., 10.],
        [14.,  2., 19., 17., 11.],
        [39., 38., 40., 34., 45.],
        [44., 51., 41., 56., 42.]])
Labels: tensor([1., 1., 0., 0.])

Data groups
tensor([[ 1., 18., 12., 16., 27