In [1]:
import os

from torch.utils.data import DataLoader

from datastore.data import P3B3
from datastore.data import RandomData
from datastore.data import KuzushijiMNIST
from datastore.data import RandomMultiTaskData
from datastore.sampling.cross_validation import (
    stratified_split, multitask_stratified_split
)

In [2]:
# Random Dataset
# Small random dataset (num_samples=20, num_classes=2), split with
# stratified_split into num_splits=5 using a fixed seed of 13.
random_data = RandomData(num_samples=20, num_classes=2)
random_data_splits = stratified_split(random_data, num_splits=5, seed=13)

In [None]:
# Kuzushiji Dataset
# Resolve the data directory from the current user's home (~/data) instead of
# hard-coding a single user's absolute path, so the cell runs on any machine.
kmnist_root = os.path.expanduser('~/data')
kmnist = KuzushijiMNIST(kmnist_root, partition='train', download=True)
# Same split settings as the random dataset: 5 splits, seed 13.
kmnist_splits = stratified_split(kmnist, num_splits=5, seed=13)

In [None]:
# Random Multitask
# Random multitask dataset built with num_samples=100, num_tasks=3, num_classes=2.
random_multitask = RandomMultiTaskData(num_samples=100, num_tasks=3, num_classes=2)
# Splits are requested with label='task0' (presumably the task used for
# stratification -- confirm against multitask_stratified_split).
# NOTE(review): unlike the stratified_split calls in this notebook, no seed is
# passed here -- confirm whether multitask_stratified_split accepts one.
multitask_splits = multitask_stratified_split(random_multitask, num_splits=5, label='task0')

In [None]:
# P3B3 Synthetic
# NOTE(review): this cell is disabled, so `p3b3_splits` is never defined on a
# fresh kernel; re-enable both lines before running any cell that uses it.
#p3b3 = P3B3('/Users/yngtodd/data', partition='train')
#p3b3_splits = multitask_stratified_split(p3b3, num_splits=5, label='subsite')

In [None]:
def get_dataloaders(split, batch_size=1):
    """Build train and validation ``DataLoader``s for one split.

    Args:
        split: Object exposing ``train`` and ``valid`` datasets.
        batch_size: Number of samples per batch for both loaders. Defaults to
            1, so each batch holds exactly one sample.

    Returns:
        Tuple ``(trainloader, validloader)``.
    """
    trainloader = DataLoader(split.train, batch_size=batch_size)
    validloader = DataLoader(split.valid, batch_size=batch_size)
    return trainloader, validloader

In [None]:
def count_label_ratio(dataloader, label=1):
    """Print the fraction of samples in ``dataloader`` whose target equals ``label``.

    Matches are counted element-wise within each batch and divided by the
    number of samples seen, so the result does not depend on the batch size.

    Args:
        dataloader: Iterable yielding ``(inputs, target)`` pairs, where
            ``target`` is a tensor holding the labels of the batch.
        label: Label value to count. Defaults to 1.

    Returns:
        None. The proportion is printed to stdout.
    """
    num_label = 0
    num_samples = 0
    for _, target in dataloader:
        matches = target == label
        num_label += int(matches.sum())
        num_samples += matches.numel()

    if num_samples == 0:
        # An empty split has no defined proportion; report that rather than
        # raising ZeroDivisionError.
        print(f'Proportion of label {label} in split: n/a (empty split)')
        return

    print(f'Proportion of label {label} in split: {num_label/num_samples}')

In [None]:
def count_multitask_ratios(dataloader, tasks, label=1):
    """Print, for each task, the fraction of samples whose target equals ``label``.

    The dataloader is traversed once for all tasks. Matches are counted
    element-wise within each batch and divided by the number of samples seen,
    so the result does not depend on the batch size.

    Args:
        dataloader: Iterable yielding ``(inputs, target)`` pairs, where
            ``target[task]`` is a tensor holding that task's labels for the batch.
        tasks: Iterable of keys used to index ``target``.
        label: Label value to count. Defaults to 1.

    Returns:
        None. One line per entry of ``tasks`` is printed to stdout, in order.
    """
    tasks = list(tasks)
    # Counters are keyed by task, so a task listed twice is only counted once
    # per batch; the report loop below still prints one line per listed entry.
    num_label = dict.fromkeys(tasks, 0)
    num_samples = dict.fromkeys(tasks, 0)

    for _, target in dataloader:
        for task in num_label:
            matches = target[task] == label
            num_label[task] += int(matches.sum())
            num_samples[task] += matches.numel()

    for task in tasks:
        if num_samples[task] == 0:
            # An empty split has no defined proportion; report that rather
            # than raising ZeroDivisionError.
            print(f'Proportion of label {label} of {task} in split: n/a (empty split)')
        else:
            print(f'Proportion of label {label} of {task} in split: {num_label[task]/num_samples[task]}')

In [None]:
def inspect_splits(splits, label=1):
    """Report the share of ``label`` in the train and valid parts of each split.

    For every split, prints a ``Split: <n>`` header, then the proportion for
    the train loader followed by the valid loader, then a row of asterisks.
    """
    divider = '*' * 30
    for number, fold in enumerate(splits):
        print(f'Split: {number}')
        train_loader, valid_loader = get_dataloaders(fold)
        for loader in (train_loader, valid_loader):
            count_label_ratio(loader, label)
        print(divider)

In [None]:
def inspect_multitask_splits(splits, tasks, label=1):
    """Report the per-task share of ``label`` in the train and valid parts of each split.

    For every split, prints a ``Split: <n>`` header, then the per-task
    proportions for the train loader followed by the valid loader, then a row
    of asterisks.
    """
    divider = '*' * 30
    for number, fold in enumerate(splits):
        print(f'Split: {number}')
        train_loader, valid_loader = get_dataloaders(fold)
        for loader in (train_loader, valid_loader):
            count_multitask_ratios(loader, tasks, label)
        print(divider)

In [None]:
# Print the proportion of label 1 in the train and valid loaders of every random-data split.
inspect_splits(random_data_splits, label=1)

In [None]:
# Same check on the KMNIST splits, counting label 2.
inspect_splits(kmnist_splits, label=2)

In [None]:
# Per-task proportions of label 1 for 'task0' and 'task1' on the random multitask splits.
inspect_multitask_splits(multitask_splits, tasks=['task0', 'task1'], label=1)

In [None]:
# Note: split on subsite here
# The cell that builds `p3b3_splits` is commented out in this notebook, so the
# name may not exist on a fresh kernel. Skip the inspection in that case
# instead of failing Restart & Run All with a NameError.
if 'p3b3_splits' in globals():
    inspect_multitask_splits(
        p3b3_splits,
        tasks=['subsite', 'laterality', 'behavior', 'grade'],
        label=1
    )
else:
    print('Skipping P3B3 inspection: p3b3_splits is not defined.')