In [40]:
from datastore.data import P3B3
from datastore.data import RandomData
from datastore.data import KuzushijiMNIST
from datastore.data import RandomMultiTaskData
from datastore.sampling.bootstrap import bootstrap

from torch.utils.data import DataLoader

In [41]:
# Random dataset: built with num_samples=10 and num_classes=2, with no seed supplied.
# NOTE(review): seed=None presumably means the generated values differ between
# runs, so the printed output is not reproducible — confirm against RandomData
# and pass an integer seed if stable output is wanted.
random_data = RandomData(num_samples=10, num_classes=2, seed=None)
# Request 2 bootstraps of the dataset; the result is iterated over below and each
# element is read through its `.train` / `.valid` attributes.
random_samples = bootstrap(random_data, num_bootstraps=2)

In [42]:
def get_dataloaders(sample, batch_size=1):
    """Wrap the train and validation splits of a sample in DataLoaders.

    Args:
        sample: object exposing `train` and `valid` attributes, each of which
            is passed to `torch.utils.data.DataLoader` as the dataset.
        batch_size: number of items per batch, applied to both loaders.
            Defaults to 1, matching the previous hard-coded value. Note that
            `get_set` in this notebook calls `.item()` on every batch, which
            only works for single-element tensors, so keep the default when
            the loaders are fed to it.

    Returns:
        Tuple `(trainloader, validloader)`.
    """
    trainloader = DataLoader(sample.train, batch_size=batch_size)
    validloader = DataLoader(sample.valid, batch_size=batch_size)
    return trainloader, validloader

In [43]:
def get_set(dataloader, verbose=False):
    """Collect the distinct scalar values produced by a dataloader.

    Args:
        dataloader: iterable yielding `(data, target)` pairs; `data.item()`
            is called on each `data`, and `len(dataloader)` is used when
            `verbose` is set.
        verbose: when true, print the loader length first and then every
            `data` value as it is consumed.

    Returns:
        Set of the values returned by `data.item()`; repeated values are
        collapsed into one entry.
    """
    if verbose:
        print(f'number of samples: {len(dataloader)}')

    seen = set()
    for data, _target in dataloader:
        # The target is unpacked but deliberately unused: only data values matter here.
        seen.add(data.item())
        if verbose:
            print(data)

    return seen

In [47]:
def inspect_sample(sample, verbose=False):
    """Print the values shared by the training and validation splits of a sample.

    Builds both loaders with `get_dataloaders`, gathers the distinct values of
    each with `get_set`, and prints their intersection. An empty set in the
    printed line means the two splits share no value.

    Args:
        sample: object accepted by `get_dataloaders`.
        verbose: when true, also print a heading for each split and let
            `get_set` print the individual values.
    """
    train_loader, valid_loader = get_dataloaders(sample)

    if verbose:
        print('Training set\n')
    train_values = get_set(train_loader, verbose)

    if verbose:
        print('\nValidation set\n')
    valid_values = get_set(valid_loader, verbose)

    overlap = train_values & valid_values
    print(f'\nIntersection of training and validation sets: {overlap}')

In [48]:
def inspect_all_samples(samples, verbose=False):
    """Run `inspect_sample` on every sample, printing a numbered heading for each.

    Args:
        samples: iterable of samples accepted by `inspect_sample`.
        verbose: forwarded unchanged to `inspect_sample`.
    """
    sample_number = 0  # zero-based position shown in the heading
    for sample in samples:
        print(f'\nSample {sample_number}')
        inspect_sample(sample, verbose)
        sample_number += 1

In [49]:
# Inspect every bootstrap sample, printing each batch plus the train/validation overlap.
inspect_all_samples(random_samples, verbose=True)


Sample 0
Training set

number of samples: 5
tensor([0.8472], dtype=torch.float64)
tensor([0.9642], dtype=torch.float64)
tensor([1.7783], dtype=torch.float64)
tensor([-1.3539], dtype=torch.float64)
tensor([-1.3539], dtype=torch.float64)

Validation set

number of samples: 6
tensor([-0.5503], dtype=torch.float64)
tensor([-0.8690], dtype=torch.float64)
tensor([-0.1275], dtype=torch.float64)
tensor([1.4629], dtype=torch.float64)
tensor([-0.1406], dtype=torch.float64)
tensor([0.1021], dtype=torch.float64)

Intersection of training and validation sets: set()

Sample 1
Training set

number of samples: 5
tensor([0.9642], dtype=torch.float64)
tensor([-0.1406], dtype=torch.float64)
tensor([-0.8690], dtype=torch.float64)
tensor([-0.1406], dtype=torch.float64)
tensor([0.8472], dtype=torch.float64)

Validation set

number of samples: 6
tensor([-0.5503], dtype=torch.float64)
tensor([-0.1275], dtype=torch.float64)
tensor([1.4629], dtype=torch.float64)
tensor([-1.3539], dtype=torch.float64)
tensor([1