Hi, great work on the library!

I am attempting to implement a task sampler that samples 1-vs-all tasks. That is, every task is a binary classification problem: class 0 contains shots of one randomly chosen label, and class 1 contains shots of all other labels except the one that represents class 0.
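To make the task definition above concrete, here is a library-independent sketch of how one such episode could be sampled from a flat list of labels. `sample_one_vs_all` is a hypothetical helper, not part of torchmeta. It assumes examples are addressed by integer index and that every class involved has enough examples for `num_shots` draws.

```python
import random
from collections import defaultdict


def sample_one_vs_all(labels, positive_label, num_shots, rng=random):
    """Hypothetical helper: return (indices, binary_targets) for one episode.

    `labels[i]` is the original class label of example i. Class 0 gets
    `num_shots` examples of `positive_label`; class 1 gets `num_shots`
    examples drawn from all other labels.
    """
    by_label = defaultdict(list)
    for i, label in enumerate(labels):
        by_label[label].append(i)

    positives = rng.sample(by_label[positive_label], num_shots)
    # Pool every example whose label differs from the positive label.
    negative_pool = [i for lbl, idxs in by_label.items()
                     if lbl != positive_label for i in idxs]
    negatives = rng.sample(negative_pool, num_shots)

    indices = positives + negatives
    # Binary targets: 0 for the "one" class, 1 for "all" the others.
    targets = [0] * num_shots + [1] * num_shots
    return indices, targets


labels = [0, 0, 0, 1, 1, 2, 2, 2, 3]  # toy labels for illustration
indices, targets = sample_one_vs_all(labels, positive_label=2, num_shots=2,
                                     rng=random.Random(0))
print(targets)  # [0, 0, 1, 1]
```

In this sketch negatives are drawn uniformly over examples, so classes with more examples are over-represented in class 1; sampling a class first and then an example from it is an alternative if a balanced mix of the "all" classes matters.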
I inherited from `MetaDataset` and did the following (my question is below the code):
```python
from torchmeta.transforms import Categorical
from torchmeta.transforms.utils import wrap_transform
from torchmeta.utils.data import ConcatTask, MetaDataset


class OneVsAllMetaDataset(MetaDataset):
    # __init__ and the `_copy_categorical_*` helpers are omitted here.

    def __getitem__(self, index):
        if not isinstance(index, int):
            raise ValueError('The index of a `OneVsAllMetaDataset` must be an integer')
        # Create 2 datasets for the task: the first one corresponds to
        # label=index, the second one contains all other labels.
        idx_set = list(range(len(self.dataset)))
        del idx_set[index]
        # Use deepcopy on `Categorical` target transforms, to avoid any side
        # effect across tasks.
        dataset_one = ConcatTask(
            [self.dataset[index]], 1,
            target_transform=wrap_transform(Categorical(),
                                            self._copy_categorical_one,
                                            transform_type=Categorical))
        dataset_vs_all = ConcatTask(
            [self.dataset[i] for i in idx_set], 1,
            target_transform=wrap_transform(Categorical(),
                                            self._copy_categorical_vs_all,
                                            transform_type=Categorical))
        # The final binary task concatenates the "one" and the "all" groups.
        task = ConcatTask([dataset_one, dataset_vs_all],
                          self.num_classes_per_task)
        if self.dataset_transform is not None:
            task = self.dataset_transform(task)
        return task
```
After applying the `ClassSplitter` and the `BatchMetaDataLoader`, I get correct tasks, but I am not happy with the labels. They look like this for a single task with 5 shots:
`tensor([[  0,   0,   0,   0,   0, 211, 727, 613, 198, 435]])`
These are the outputs of the `Categorical` transform. Since I want to represent a binary task, I want this instead:

`tensor([[ 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]])`

Do you have any hints on how to fix this?
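One possible direction, sketched under an assumption: because the "one" and the "all" groups are already built as separate `ConcatTask`s, each group could get a target transform that ignores the incoming label and returns a fixed value, instead of a `Categorical()`. `ConstantTarget` below is a hypothetical class, not part of torchmeta, and this assumes (as the `Categorical` wiring in the code above already does) that the `target_transform` handed to each inner `ConcatTask` is applied to that group's targets; that behavior is not verified here.

```python
class ConstantTarget:
    """Hypothetical target transform that ignores the incoming label.

    Every target passed through it is replaced by `value`, so all shots of
    the group it is attached to share one binary label.
    """

    def __init__(self, value):
        self.value = value

    def __call__(self, target):
        # The original label (e.g. 211 or 727) is discarded on purpose.
        return self.value

    def __repr__(self):
        return f"{type(self).__name__}({self.value!r})"


one_label = ConstantTarget(0)     # would replace Categorical() for dataset_one
vs_all_label = ConstantTarget(1)  # would replace Categorical() for dataset_vs_all

raw_targets_one = [7] * 5  # hypothetical raw label of the "one" class
raw_targets_rest = [211, 727, 613, 198, 435]
binary = ([one_label(t) for t in raw_targets_one]
          + [vs_all_label(t) for t in raw_targets_rest])
print(binary)  # [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
```

A small class is used rather than a lambda because lambdas cannot be pickled, which matters when a `DataLoader` ships transforms to worker processes. Since `ConstantTarget` is stateless, it also should not need the per-task deepcopy that `Categorical` gets via `wrap_transform`.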
I would be happy to open a pull request with the code once it is fixed, since 1-vs-all samplers are commonly used these days.
Cheers!