In [10]:
import numpy as np
from torch.utils.data import DataLoader

from datastore.data import RandomMultiTaskData
from datastore.api.data import Subset

In [2]:
# Build a small synthetic multi-task dataset: 10 samples, 2 tasks,
# constructed with num_classes=10.
# NOTE(review): no random seed is set anywhere in this notebook, so the
# values shown in the outputs below presumably change on every fresh run —
# confirm how RandomMultiTaskData draws its samples.
multitask_data = RandomMultiTaskData(
    num_samples=10,
    num_tasks=2,
    num_classes=10,
)

In [3]:
# Wrap the dataset in a PyTorch DataLoader that yields two samples per batch.
dataloader = DataLoader(dataset=multitask_data, batch_size=2)

In [4]:
# Walk through every batch and print it. Each batch is a (data, labels) pair
# in which `labels` is a dict keyed by task name, as the printed output shows.
# The loader is iterated directly: the batch index was never used.
for data, labels in dataloader:
    print(f'Data: {data}\nLabels: {labels}\n')

Data: tensor([-0.7124,  0.7538], dtype=torch.float64)
Labels: {'task0': tensor([1, 2]), 'task1': tensor([5, 8])}

Data: tensor([-0.0445,  0.4518], dtype=torch.float64)
Labels: {'task0': tensor([8, 8]), 'task1': tensor([3, 8])}

Data: tensor([1.3451, 0.5323], dtype=torch.float64)
Labels: {'task0': tensor([6, 2]), 'task1': tensor([5, 1])}

Data: tensor([1.3502, 0.8612], dtype=torch.float64)
Labels: {'task0': tensor([4, 5]), 'task1': tensor([8, 1])}

Data: tensor([ 1.4787, -1.0454], dtype=torch.float64)
Labels: {'task0': tensor([7, 3]), 'task1': tensor([7, 7])}



In [5]:
# Inspect the raw samples held by the dataset (displayed as a NumPy array).
multitask_data.data

array([-0.71239066,  0.75376638, -0.04450308,  0.45181234,  1.34510171,
        0.53233789,  1.3501879 ,  0.86121137,  1.47868574, -1.04537713])

In [6]:
# Inspect the labels: a dict with one entry per task name, each holding that
# task's label array.
multitask_data.labels

{'task0': array([1, 2, 8, 8, 6, 2, 4, 5, 7, 3]),
 'task1': array([5, 8, 3, 8, 5, 1, 8, 1, 7, 7])}

In [8]:
# Tabular view of the dataset: one row per sample, with a `data` column and
# one label column per task.
multitask_data.dataframe()

Unnamed: 0,data,task0,task1
0,-0.712391,1,5
1,0.753766,2,8
2,-0.044503,8,3
3,0.451812,8,8
4,1.345102,6,5
5,0.532338,2,1
6,1.350188,4,8
7,0.861211,5,1
8,1.478686,7,7
9,-1.045377,3,7


### Test the ability to subset with labels contained in dictionaries

In [11]:
# Restrict the dataset to its first two samples (indices 0 and 1).
subset = Subset(
    multitask_data,
    indices=[0, 1],
)

In [12]:
# Batch the two-sample subset so that it comes back as a single batch.
loader = DataLoader(dataset=subset, batch_size=2)

In [13]:
# Print each batch of the subset; the labels should still arrive as a dict
# keyed by task name.
for x, y in loader:
    print(f'Data: {x}\nLabel: {y}')

Data: tensor([-0.7124,  0.7538], dtype=torch.float64)
Label: {'task0': tensor([1, 2]), 'task1': tensor([5, 8])}


### Test methods common to MultiTaskDatasets

In [14]:
# Find out which tasks we are working with (the task names come back as a
# dict_keys view).
multitask_data.get_tasks()

dict_keys(['task0', 'task1'])

In [15]:
# Get the label array for a single task, looked up by its task name.
multitask_data.get_label('task1')

array([5, 8, 3, 8, 5, 1, 8, 1, 7, 7])

In [16]:
# Delete a task. This mutates `multitask_data` in place, so the earlier cell
# outputs that show 'task1' no longer reflect the dataset's current state.
# NOTE(review): re-running this cell presumably fails once 'task1' is gone —
# confirm, and consider operating on a copy to keep the notebook re-runnable.
multitask_data.del_label('task1')

In [17]:
# Confirm the deletion: only 'task0' should remain.
multitask_data.get_tasks()

dict_keys(['task0'])