In [10]:
import os

from torch.utils.data import DataLoader

from datastore.data import RandomData
from datastore.data import KuzushijiMNIST
from datastore.sampling.cross_validation import stratified_split

In [18]:
# Random Dataset
# Small synthetic set (100 samples, 2 classes) divided into 5 splits.
# The seed is fixed -- presumably so the partitioning is repeatable; confirm
# against stratified_split.
random_data = RandomData(num_classes=2, num_samples=100)
random_data_splits = stratified_split(
    random_data,
    seed=13,
    num_splits=5,
)

In [19]:
# Kuzushiji Dataset
# Resolve the data directory under the current user's home instead of a
# hard-coded absolute path, so the cell also runs on other machines.
kmnist_root = os.path.expanduser('~/data')
# download=True is passed through to KuzushijiMNIST; presumably it fetches the
# data into kmnist_root when missing -- confirm against the datastore package.
kmnist = KuzushijiMNIST(kmnist_root, partition='train', download=True)
kmnist_splits = stratified_split(kmnist, num_splits=5, seed=13)

In [20]:
def get_dataloaders(split, batch_size=1):
    """Wrap the train and validation partitions of a split in DataLoaders.

    Args:
        split: Object exposing ``train`` and ``valid`` attributes, each a
            dataset accepted by ``torch.utils.data.DataLoader``.
        batch_size: Number of samples per batch, applied to both loaders.
            Defaults to 1, i.e. one sample per batch.

    Returns:
        Tuple ``(trainloader, validloader)``.
    """
    trainloader = DataLoader(split.train, batch_size=batch_size)
    validloader = DataLoader(split.valid, batch_size=batch_size)
    return trainloader, validloader

In [21]:
def count_label_ratio(dataloader, label=1):
    """Print the fraction of samples in `dataloader` whose target equals `label`.

    Args:
        dataloader: Iterable yielding ``(inputs, target)`` pairs, where
            ``target`` is a tensor holding the label(s) of the batch. Any
            batch size works because samples are counted, not batches.
        label: Class label to count. Defaults to 1.

    Returns:
        None. The proportion is written to stdout. An empty loader reports
        ``nan`` instead of raising ``ZeroDivisionError``.
    """
    num_label = 0
    num_samples = 0
    for _, target in dataloader:
        # Compare elementwise so batches with more than one sample are handled;
        # calling .item() on the raw target only works for a single element.
        num_label += (target == label).sum().item()
        num_samples += target.numel()

    # Divide by the number of samples seen rather than len(dataloader), which
    # counts batches and only matches the sample count when batch_size == 1.
    proportion = num_label / num_samples if num_samples else float('nan')
    print(f'Proportion of label {label} in split: {proportion}')

In [22]:
def inspect_splits(splits, label=1):
    """Report how often `label` occurs in each partition of every split.

    For each element of `splits`, a header with its running index is printed,
    loaders are built with `get_dataloaders`, and `count_label_ratio` is
    called on the train loader and then on the validation loader. A row of
    asterisks closes each split's report.

    Args:
        splits: Iterable of split objects accepted by `get_dataloaders`.
        label: Class label whose proportion is reported. Defaults to 1.
    """
    separator = '*' * 30
    for split_idx, split in enumerate(splits):
        print(f'Split: {split_idx}')
        loaders = get_dataloaders(split)
        # get_dataloaders returns (train, valid), so the first line reported
        # is the training partition and the second is validation.
        for loader in loaders:
            count_label_ratio(loader, label)
        print(separator)

In [23]:
# Compare the share of class 1 between the train and valid parts of each
# random-data split.
inspect_splits(splits=random_data_splits, label=1)

Split: 0
Proportion of label 1 in split: 0.569620253164557
Proportion of label 1 in split: 0.5714285714285714
******************************
Split: 1
Proportion of label 1 in split: 0.569620253164557
Proportion of label 1 in split: 0.5714285714285714
******************************
Split: 2
Proportion of label 1 in split: 0.575
Proportion of label 1 in split: 0.55
******************************
Split: 3
Proportion of label 1 in split: 0.5679012345679012
Proportion of label 1 in split: 0.5789473684210527
******************************
Split: 4
Proportion of label 1 in split: 0.5679012345679012
Proportion of label 1 in split: 0.5789473684210527
******************************


In [8]:
# Compare the share of class 2 between the train and valid parts of each
# Kuzushiji-MNIST split.
inspect_splits(splits=kmnist_splits, label=2)

Split: 0
Proportion of label 2 in split: 0.1
Proportion of label 2 in split: 0.1
******************************
Split: 1
Proportion of label 2 in split: 0.1
Proportion of label 2 in split: 0.1
******************************
Split: 2
Proportion of label 2 in split: 0.1
Proportion of label 2 in split: 0.1
******************************
Split: 3
Proportion of label 2 in split: 0.1
Proportion of label 2 in split: 0.1
******************************
Split: 4
Proportion of label 2 in split: 0.1
Proportion of label 2 in split: 0.1
******************************
