In [1]:
import numpy as np
from torch.utils.data import DataLoader

from datastore.api.data import InMemoryDataset, MultiTaskDataset

In [27]:
class RandomMultiTaskData(InMemoryDataset, MultiTaskDataset):
    """Random multi-task, multi-class dataset - useful for quick iterating.

    Each sample is a single float drawn from ``randn``; each task ``task{i}``
    gets its own array of integer labels drawn uniformly from
    ``[0, num_classes)``, one label per sample.

    Args:
        num_samples: number of samples (length of ``data`` and of every
            label array).
        num_tasks: number of label arrays, keyed ``'task0'``, ``'task1'``, ...
        num_classes: exclusive upper bound for the integer labels.
        seed: seed for a private ``np.random.RandomState``. The global NumPy
            RNG is not touched, so building a dataset does not reseed
            unrelated code.
    """

    def __init__(self, num_samples: int, num_tasks: int, num_classes: int, seed: int = 13):
        # A private legacy RandomState yields exactly the same stream as
        # np.random.seed(seed) followed by np.random.* calls, so the generated
        # values are unchanged -- but the global RNG is no longer reseeded as a
        # hidden side effect of constructing the dataset.
        rng = np.random.RandomState(seed)
        # Draw order matters for reproducibility: data first, then labels.
        self.data = rng.randn(num_samples)
        self.create_labels(num_tasks, num_classes, num_samples, rng=rng)

    def create_labels(self, num_tasks, num_classes, num_samples, rng=None):
        """Generate random integer labels per task and store them in ``self.labels``.

        A fresh dict is assigned instead of writing into an existing one.
        This class never initialises ``labels`` itself, so mutating whatever
        ``self.labels`` resolves to could modify a dict shared through a base
        class. Repeated calls would also keep stale task keys.
        NOTE(review): confirm how the base classes define ``labels``.

        Args:
            num_tasks: number of tasks; keys are ``'task0'`` .. ``f'task{num_tasks - 1}'``.
            num_classes: exclusive upper bound for the labels.
            num_samples: number of labels generated per task.
            rng: source of randomness exposing ``randint``. Defaults to the
                global ``np.random`` module, matching the previous behaviour
                for direct callers.
        """
        if rng is None:
            rng = np.random
        # Tasks are generated in index order, consuming the RNG in the same
        # sequence as a plain for-loop would.
        self.labels = {
            f'task{i}': rng.randint(num_classes, size=num_samples)
            for i in range(num_tasks)
        }

    def load_data(self):
        """Return the ``(data, labels)`` pair held by this dataset."""
        return self.data, self.labels

    def __repr__(self):
        return 'Random multitask supervised dataset'

    def __len__(self):
        """Number of samples (length of ``data``)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, labels)`` for index ``idx``.

        The labels come from ``index_labels``, which is inherited from a base
        class not visible here. Judging by the collated DataLoader output, it
        presumably returns a per-task dict of labels for ``idx`` -- confirm.
        """
        return self.data[idx], self.index_labels(idx)

In [28]:
# Small toy dataset (default seed): 10 samples, 2 tasks, labels in [0, 10).
multitask_data = RandomMultiTaskData(
    num_samples=10,
    num_classes=10,
    num_tasks=2,
)

In [29]:
# No shuffle is requested, so batches follow dataset index order.
dataloader = DataLoader(dataset=multitask_data, batch_size=2)

In [30]:
# Each batch is (data, labels); the DataLoader collates labels into a dict of per-task tensors.
for batch_data, batch_labels in dataloader:
    print(f'Data: {batch_data}\nLabels: {batch_labels}\n')

Data: tensor([-0.7124,  0.7538], dtype=torch.float64)
Labels: {'task0': tensor([1, 2]), 'task1': tensor([5, 8])}

Data: tensor([-0.0445,  0.4518], dtype=torch.float64)
Labels: {'task0': tensor([8, 8]), 'task1': tensor([3, 8])}

Data: tensor([1.3451, 0.5323], dtype=torch.float64)
Labels: {'task0': tensor([6, 2]), 'task1': tensor([5, 1])}

Data: tensor([1.3502, 0.8612], dtype=torch.float64)
Labels: {'task0': tensor([4, 5]), 'task1': tensor([8, 1])}

Data: tensor([ 1.4787, -1.0454], dtype=torch.float64)
Labels: {'task0': tensor([7, 3]), 'task1': tensor([7, 7])}



In [33]:
# Raw samples backing the dataset (first element of load_data()).
multitask_data.load_data()[0]

array([-0.71239066,  0.75376638, -0.04450308,  0.45181234,  1.34510171,
        0.53233789,  1.3501879 ,  0.86121137,  1.47868574, -1.04537713])

In [32]:
# Per-task label arrays backing the dataset (second element of load_data()).
multitask_data.load_data()[1]

{'task0': array([1, 2, 8, 8, 6, 2, 4, 5, 7, 3]),
 'task1': array([5, 8, 3, 8, 5, 1, 8, 1, 7, 7])}