In [1]:
import os

from torch.utils.data import DataLoader

from datastore.data import RandomData
from datastore.data import KuzushijiMNIST
from datastore.data import RandomMultiTaskData
from datastore.sampling.cross_validation import (
    stratified_split, multitask_stratified_split
)

In [2]:
# Random Dataset
# RandomData with num_samples=100 and num_classes=2, partitioned into 5 splits
# by stratified_split. seed=13 presumably makes the split assignment
# repeatable across runs (TODO confirm in stratified_split).
random_data = RandomData(num_samples=100, num_classes=2)
random_data_splits = stratified_split(random_data, num_splits=5, seed=13)

In [3]:
# Kuzushiji Dataset
# The data root comes from the KMNIST_ROOT environment variable and defaults to
# ~/data, so the notebook runs on any machine without editing a hard-coded,
# user-specific absolute path. expanduser also expands a "~" in the env value.
KMNIST_ROOT = os.path.expanduser(os.environ.get('KMNIST_ROOT', '~/data'))
kmnist = KuzushijiMNIST(KMNIST_ROOT, partition='train', download=True)
kmnist_splits = stratified_split(kmnist, num_splits=5, seed=13)

In [4]:
# Random Multitask
# RandomMultiTaskData with num_samples=100, num_tasks=3 and num_classes=2,
# partitioned into 5 splits; label='task0' presumably selects the task whose
# labels drive the stratification.
# NOTE(review): unlike the two cells above, no seed is passed here, so these
# splits may differ between runs. Pass one if multitask_stratified_split
# accepts it (TODO confirm its signature).
random_multitask = RandomMultiTaskData(num_samples=100, num_tasks=3, num_classes=2)
multitask_splits = multitask_stratified_split(random_multitask, num_splits=5, label='task0')

In [5]:
def get_dataloaders(split, batch_size=1):
    """Wrap the train and validation partitions of a split in DataLoaders.

    Parameters
    ----------
    split : object
        Any object exposing ``train`` and ``valid`` datasets, such as one
        element of the splits iterated by ``inspect_splits``.
    batch_size : int, optional
        Batch size used for both loaders. Defaults to 1, which yields one
        sample per iteration.

    Returns
    -------
    tuple of DataLoader
        ``(trainloader, validloader)``. Neither loader shuffles, because the
        DataLoader default ``shuffle=False`` is kept.
    """
    trainloader = DataLoader(split.train, batch_size=batch_size)
    validloader = DataLoader(split.valid, batch_size=batch_size)
    return trainloader, validloader

In [6]:
def count_label_ratio(dataloader, label=1):
    """Print and return the fraction of samples whose target equals ``label``.

    Parameters
    ----------
    dataloader : iterable
        Yields ``(inputs, target)`` pairs, e.g. a ``DataLoader``. ``target``
        may hold a single label or a batch of labels; every element counts
        as one sample.
    label : int, optional
        The label value to count. Defaults to 1.

    Returns
    -------
    float
        ``matching_samples / total_samples``, or ``nan`` when the loader
        yields no samples.
    """
    import torch  # local import: the notebook header only imports DataLoader

    num_label = 0
    num_samples = 0
    for _, target in dataloader:
        # Count elementwise so any batch size works. ``.item()`` only
        # supports exactly one label per batch, and dividing by the number
        # of batches is only correct when batch_size == 1.
        target = torch.as_tensor(target)
        num_label += int((target == label).sum().item())
        num_samples += target.numel()

    # Guard against an empty loader instead of raising ZeroDivisionError.
    ratio = num_label / num_samples if num_samples else float('nan')
    print(f'Proportion of label {label} in split: {ratio}')
    return ratio

In [7]:
def count_multitask_ratios(dataloader, tasks, label=1):
    """Print and return, per task, the fraction of samples labelled ``label``.

    Parameters
    ----------
    dataloader : iterable
        Yields ``(inputs, target)`` pairs where ``target`` is indexable by
        task name (e.g. the dict of tensors a DataLoader collates from dict
        targets). Each ``target[task]`` may hold one label or a batch.
    tasks : iterable of str
        Task keys to report on; lines are printed in this order.
    label : int, optional
        The label value to count. Defaults to 1.

    Returns
    -------
    dict
        Maps each task to ``matching_samples / total_samples`` (``nan`` when
        the loader yields no samples).
    """
    import torch  # local import: the notebook header only imports DataLoader

    tasks = list(tasks)
    counts = dict.fromkeys(tasks, 0)
    totals = dict.fromkeys(tasks, 0)
    # One pass over the data serves every task. Re-reading the loader once
    # per task multiplies data-loading cost and silently yields zero counts
    # for later tasks when the loader is a one-shot iterator.
    for _, target in dataloader:
        for task in tasks:
            values = torch.as_tensor(target[task])
            counts[task] += int((values == label).sum().item())
            totals[task] += values.numel()

    ratios = {}
    for task in tasks:
        ratio = counts[task] / totals[task] if totals[task] else float('nan')
        ratios[task] = ratio
        print(f'Proportion of label {label} of {task} in split: {ratio}')
    return ratios

In [8]:
def inspect_splits(splits, label=1):
    """Print the proportion of ``label`` in the train and valid part of each split.

    Each split produces a ``Split: <index>`` header, then two proportion
    lines from ``count_label_ratio`` (train loader first, then validation
    loader), then a separator of 30 asterisks.

    Parameters
    ----------
    splits : iterable
        Split objects accepted by ``get_dataloaders`` (they must expose
        ``train`` and ``valid``).
    label : int, optional
        Label whose proportion is reported. Defaults to 1.
    """
    for split_idx, split in enumerate(splits):
        print(f'Split: {split_idx}')
        trainloader, validloader = get_dataloaders(split)
        count_label_ratio(trainloader, label)  # first line: train partition
        count_label_ratio(validloader, label)  # second line: valid partition
        print('*' * 30)

In [9]:
def inspect_multitask_splits(splits, tasks, label=1):
    """Print per-task proportions of ``label`` for the train and valid part of each split.

    Each split produces a ``Split: <index>`` header, then one line per task
    from ``count_multitask_ratios`` for the train loader, then the same for
    the validation loader, then a separator of 30 asterisks.

    Parameters
    ----------
    splits : iterable
        Split objects accepted by ``get_dataloaders`` (they must expose
        ``train`` and ``valid``).
    tasks : sequence of str
        Task keys to report on, in print order.
    label : int, optional
        Label whose proportion is reported. Defaults to 1.
    """
    for split_idx, split in enumerate(splits):
        print(f'Split: {split_idx}')
        trainloader, validloader = get_dataloaders(split)
        count_multitask_ratios(trainloader, tasks, label)  # train block of lines
        count_multitask_ratios(validloader, tasks, label)  # valid block of lines
        print('*' * 30)

In [10]:
# Stratification check: per fold, the first line is the train proportion of
# label 1 and the second the valid proportion; they are expected to be close.
inspect_splits(random_data_splits, label=1)

Split: 0
Proportion of label 1 in split: 0.4430379746835443
Proportion of label 1 in split: 0.42857142857142855
******************************
Split: 1
Proportion of label 1 in split: 0.4375
Proportion of label 1 in split: 0.45
******************************
Split: 2
Proportion of label 1 in split: 0.4375
Proportion of label 1 in split: 0.45
******************************
Split: 3
Proportion of label 1 in split: 0.4375
Proportion of label 1 in split: 0.45
******************************
Split: 4
Proportion of label 1 in split: 0.4444444444444444
Proportion of label 1 in split: 0.42105263157894735
******************************


In [11]:
# Same check on Kuzushiji-MNIST for label 2 (per fold: train line, then valid line).
inspect_splits(kmnist_splits, label=2)

Split: 0
Proportion of label 2 in split: 0.1
Proportion of label 2 in split: 0.1
******************************
Split: 1
Proportion of label 2 in split: 0.1
Proportion of label 2 in split: 0.1
******************************
Split: 2
Proportion of label 2 in split: 0.1
Proportion of label 2 in split: 0.1
******************************
Split: 3
Proportion of label 2 in split: 0.1
Proportion of label 2 in split: 0.1
******************************
Split: 4
Proportion of label 2 in split: 0.1
Proportion of label 2 in split: 0.1
******************************


In [13]:
# Per-task check of the multitask splits. For each fold, the task0/task1
# lines for the train loader come first, then those for the valid loader.
# NOTE(review): the dataset was built with num_tasks=3, but only task0 and
# task1 are inspected here.
# NOTE(review): the execution count jumps from In [11] to In [13]; restart and
# run all to confirm these outputs are reproducible.
inspect_multitask_splits(multitask_splits, tasks=['task0', 'task1'], label=1)

Split: 0
Proportion of label 1 of task0 in split: 0.4430379746835443
Proportion of label 1 of task1 in split: 0.46835443037974683
Proportion of label 1 of task0 in split: 0.42857142857142855
Proportion of label 1 of task1 in split: 0.5238095238095238
******************************
Split: 1
Proportion of label 1 of task0 in split: 0.4375
Proportion of label 1 of task1 in split: 0.4375
Proportion of label 1 of task0 in split: 0.45
Proportion of label 1 of task1 in split: 0.65
******************************
Split: 2
Proportion of label 1 of task0 in split: 0.4375
Proportion of label 1 of task1 in split: 0.5
Proportion of label 1 of task0 in split: 0.45
Proportion of label 1 of task1 in split: 0.4
******************************
Split: 3
Proportion of label 1 of task0 in split: 0.4375
Proportion of label 1 of task1 in split: 0.475
Proportion of label 1 of task0 in split: 0.45
Proportion of label 1 of task1 in split: 0.5
******************************
Split: 4
Proportion of label 1 of task0 