
Overlap between train and test sets within task on multiple splits #31

Merged
merged 8 commits into tristandeleu:master on Feb 20, 2020

Conversation

@MarcCoru (Contributor) commented Feb 14, 2020

Hi,

Thanks for the clean and well-organized code of this repository.
I investigated the train/test split within each dataset and explored the ClassSplitter class.
I noticed that, across multiple splits of the same task, the training and testing partitions can overlap: a sample assigned to the training set in one split may end up in the test set of another.
The random state is initialized once in self.seed() in

self.np_random = np.random.RandomState(seed=seed)

and then used for each task split in (e.g.) get_indices_task()
dataset_indices = self.np_random.permutation(num_samples)

Since the internal state of the random generator advances with each draw, dataset_indices is redrawn at every function call, so samples can swap between train and test across calls.
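
A minimal standalone illustration of the mechanism (plain NumPy, independent of torchmeta):

import numpy as np

rng = np.random.RandomState(seed=0)
# Each call advances the generator's internal state, so the two permutations differ:
print(rng.permutation(10))
print(rng.permutation(10))

# Recreating the state from the same seed reproduces the same permutation every time:
print(np.random.RandomState(0).permutation(10))
print(np.random.RandomState(0).permutation(10))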

I wrote this test code against the current version:

import numpy as np
from torchmeta.transforms import ClassSplitter
from torchmeta.utils.data.task import Task  # exact import path may vary across torchmeta versions

class DemoTask(Task):
    def __init__(self):
        super(DemoTask, self).__init__(None)  # the pre-PR Task takes only num_classes
        self._inputs = np.arange(10)

    def __len__(self):
        return len(self._inputs)

    def __getitem__(self, index):
        return self._inputs[index]

splitter = ClassSplitter(shuffle=True, num_train_per_class=5, num_test_per_class=5)
task = DemoTask()

# split the same task five times into train and test
for _ in range(5):
    tasks_split = splitter(task)
    train_task = tasks_split["train"]
    test_task = tasks_split["test"]

    print("train split: " + str([train_task[i] for i in range(len(train_task))]))
    print("test split:  " + str([test_task[i] for i in range(len(test_task))]))

which prints:

train split: [2, 8, 4, 9, 1]
test split:  [6, 7, 3, 0, 5]

train split: [3, 5, 1, 2, 9]
test split:  [8, 0, 6, 7, 4]

train split: [2, 3, 8, 4, 5]
test split:  [1, 0, 6, 9, 7]

train split: [6, 1, 9, 2, 7]
test split:  [5, 8, 0, 3, 4]

train split: [5, 2, 7, 4, 1]
test split:  [0, 6, 8, 9, 3]

Each split nicely separates samples into disjoint train and test sets. However, samples overlap across splits: for instance, sample 4 appears in train, test, train, test, train over the five runs.

This pull request fixes the train/test assignment by drawing it from a random state seeded with the initial seed, np.random.RandomState(self.random_state_seed), and then shuffling the indices within each split with the running self.np_random state.
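
A minimal sketch of that two-stage scheme (names are illustrative, not the exact splitter internals):

import numpy as np

random_state_seed = 0  # stands in for self.random_state_seed
np_random = np.random.RandomState(random_state_seed)  # running state, advances per call

def split(num_samples, num_train):
    # 1) fixed assignment: a fresh RandomState from the initial seed always
    #    yields the same train/test membership
    fixed = np.random.RandomState(random_state_seed).permutation(num_samples)
    train, test = fixed[:num_train], fixed[num_train:]
    # 2) per-call shuffle within each split, using the running state
    return np_random.permutation(train), np_random.permutation(test)

train_idx, test_idx = split(10, 5)  # membership is stable, order varies per call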
I also added a pytest with two asserts:

  1. assert len(train_samples.intersection(test_samples)) == 0 checks that there is no overlap within a single split (the current version passes this test)
  2. assert len(samples_in_all_test_splits.intersection(samples_in_all_train_splits)) == 0 checks that no sample appears in a train split of one call and in a test split of another (the current version fails this test; a sketch of the full test follows below)
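
For illustration, a hedged sketch of how such a test could be structured (reusing the DemoTask and splitter from above; variable names follow the asserts, and the test name matches the one added in this PR):

def test_class_splitter_for_fold_overlaps():
    splitter = ClassSplitter(shuffle=True, num_train_per_class=5, num_test_per_class=5)
    task = DemoTask()

    samples_in_all_train_splits = set()
    samples_in_all_test_splits = set()
    for _ in range(5):
        tasks_split = splitter(task)
        train_task, test_task = tasks_split["train"], tasks_split["test"]
        train_samples = {int(train_task[i]) for i in range(len(train_task))}
        test_samples = {int(test_task[i]) for i in range(len(test_task))}

        # 1. no overlap within a single split
        assert len(train_samples.intersection(test_samples)) == 0

        samples_in_all_train_splits.update(train_samples)
        samples_in_all_test_splits.update(test_samples)

    # 2. no sample switches sides between splits
    assert len(samples_in_all_test_splits.intersection(samples_in_all_train_splits)) == 0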

With this pull request, the print output from above looks like this:

train split: [4, 2, 8, 9, 1]
test split:  [6, 3, 7, 5, 0]

train split: [8, 9, 1, 2, 4]
test split:  [0, 7, 3, 5, 6]

train split: [1, 4, 9, 8, 2]
test split:  [3, 0, 5, 7, 6]

train split: [4, 1, 8, 9, 2]
test split:  [0, 7, 5, 3, 6]

train split: [2, 1, 4, 8, 9]
test split:  [5, 3, 0, 7, 6]

Thanks again for this repository. It is a pleasure working with it.

@MarcCoru (Contributor, Author) commented Feb 14, 2020

I believe the checks failed because the expected values for test_seed_class_splitter() changed: the PR changes the sequence in which samples are drawn.

The values themselves are still the same; only their order and the train/test assignments have changed.

@tristandeleu (Owner)

This would be a great change! Getting different images in two separate calls for the same task is not great, and this would solve that.

However, fixing the seed this way creates an issue with CombinationMetaDataset: if two tasks share a class, the examples for that class will be identical in both tasks. Ideally, you would want different samples for different tasks, even when they have a class in common.
Here is an example (where class 0 appears in both tasks):
[figure: Omniglot samples from the two tasks; the shared class 0 shows identical images in both]

One solution I can see is to make the seed dependent on the task id (a unique integer identifier for each task). This would require a slight change in the API of the dataset_transform, which would not be an issue. The only problem is that getting a unique integer identifier might not always be feasible because of the (combinatorially large) number of tasks in CombinationMetaDataset.

Here is the snippet to show the tasks:

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from torchmeta.datasets.helpers import omniglot

dataset = omniglot('data', ways=5, shots=5, seed=1, meta_train=True)

task1 = dataset[(0, 1, 2, 3, 4)]
task1_train = task1['train']
task2 = dataset[(5, 6, 7, 8, 0)]
task2_train = task2['train']

fig = plt.figure(figsize=(10, 5))
gs = gridspec.GridSpec(1, 2, figure=fig)

ax0 = fig.add_subplot(gs[0])
ax0.axis('off')
ax0.set_title('Task 1 - Index (0, 1, 2, 3, 4)')

gs0 = gs[0].subgridspec(5, 5)
for index in range(len(task1_train)):
    image, label = task1_train[index]
    ax = fig.add_subplot(gs0[index // 5, index % 5])
    ax.imshow(image[0].numpy(), cmap='gray')
    ax.axis('off')

ax1 = fig.add_subplot(gs[1])
ax1.axis('off')
ax1.set_title('Task 2 - Index (5, 6, 7, 8, 0)')

gs1 = gs[1].subgridspec(5, 5)
for index in range(len(task2_train)):
    image, label = task2_train[index]
    ax = fig.add_subplot(gs1[index // 5, index % 5])
    ax.imshow(image[0].numpy(), cmap='gray')
    ax.axis('off')

plt.show()
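
As for the task-id idea above, a hypothetical sketch (task_id is the assumed unique integer identifier, not an existing API):

import numpy as np

def split_indices(num_samples, task_id, num_train):
    # Derive the seed from the task id, so the same task always gets the
    # same assignment, while different tasks get different ones.
    rng = np.random.RandomState(task_id % (2 ** 32))
    perm = rng.permutation(num_samples)
    return perm[:num_train], perm[num_train:]

train_idx, test_idx = split_indices(10, task_id=42, num_train=5)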

@tristandeleu (Owner)

Another option could be to keep a cache in the splitter mapping task -> selected indices for the train/test datasets (in that case, the task key could be a tuple for CombinationMetaDataset).
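
A hypothetical sketch of that caching alternative (class and method names are made up for illustration):

import numpy as np

class CachingSplitterSketch:
    def __init__(self, seed=0):
        self.np_random = np.random.RandomState(seed)
        self._cache = {}  # task key -> permutation of sample indices

    def get_indices(self, task_key, num_samples):
        # task_key could be the tuple of class indices for CombinationMetaDataset
        if task_key not in self._cache:
            self._cache[task_key] = self.np_random.permutation(num_samples)
        return self._cache[task_key]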

@MarcCoru (Contributor, Author) commented Feb 19, 2020

> One solution I can see is to make the seed dependent on the task id (a unique integer identifier for each task). This would require a slight change in the API of the dataset_transform, which would not be an issue. The only problem is that getting a unique integer identifier might not always be feasible because of the (combinatorially large) number of tasks in CombinationMetaDataset

What about using the hash function hash(task) for the random seed? It is defined by default but can be overridden via __hash__(self) on the object. A hash should be unique for each task.
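
A minimal sketch of the idea (the integer tuple here stands in for a task's identity):

import numpy as np

task_key = (0, 1, 2, 3, 4)         # any hashable representation of a task
seed = hash(task_key) % (2 ** 32)  # RandomState seeds must be below 2**32
rng = np.random.RandomState(seed)
print(rng.permutation(10))         # the same permutation every time this task is split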

@tristandeleu (Owner) commented Feb 19, 2020

> What about using the hash function hash(task) for the random seed? It is defined by default but can be overridden via __hash__(self) on the object. A hash should be unique for each task.

That's an even better idea! I would then suggest to:

  • Change the API of torchmeta.utils.task.Dataset to include a required argument index
  • Change the __hash__ function for Task/ConcatTask/SubsetTask (a sketch follows below)
  • Update the existing datasets to reflect these changes
  • Use the hash in Splitter

All these changes might be a bit heavy for this single PR; I will make the changes regarding the hash function in a separate PR today, and you'll be able to use it to update the splitter. What do you think?
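
A hypothetical sketch of the __hash__ change from the second bullet (signatures simplified; not the actual torchmeta code):

class Task(object):
    def __init__(self, index, num_classes):
        self.index = index  # the new required unique identifier
        self.num_classes = num_classes

    def __hash__(self):
        # Tie the task's hash to its index, so a splitter seeded with
        # hash(task) is stable for a given task.
        return hash(self.index)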

@MarcCoru (Contributor, Author)

Great! That works for me.

@MarcCoru (Contributor, Author)

Hi Tristan,
I replaced the fixed seeds with the hash function.

I ran the plotting code; class 0 now draws different samples for the two tasks, so it looks fine to my eyes.
[figure: updated Omniglot plot; the shared class 0 now shows different samples in each task]

It was a pleasure,
Marc

@tristandeleu (Owner) left a comment

Except for some minor comments, everything looks great! Thank you for the PR!

@@ -148,13 +152,16 @@ def get_indices_task(self, task):
                 num_samples, self._min_samples_per_class))

         if self.shuffle:
-            dataset_indices = self.np_random.permutation(num_samples)
+            seed = hash(task) % (2 ** 32 - 1) # Seed must be between 0 and 2**32 - 1
@tristandeleu (Owner):

This can be just % (2 ** 32) to be < 2 ** 32. Also, what do you think about adding random_state_seed here as well as hash(task), to have possibly different datasets for different seeds? Something like

seed = (hash(task) + self.random_state_seed) % (2 ** 32)
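
For reference, a hedged sketch of how the shuffle branch could then read (assuming np is NumPy and self.random_state_seed is the splitter's stored seed):

if self.shuffle:
    seed = (hash(task) + self.random_state_seed) % (2 ** 32)
    dataset_indices = np.random.RandomState(seed).permutation(num_samples)
else:
    dataset_indices = np.arange(num_samples)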

@MarcCoru (Contributor, Author):

Yes, I agree. That makes sense.
I also removed the inline comment after the seed expression, since the line was getting a bit long.

@@ -173,13 +180,16 @@ def get_indices_concattask(self, task):
                 self._min_samples_per_class))

         if self.shuffle:
-            dataset_indices = self.np_random.permutation(num_samples)
+            seed = hash(task) % (2 ** 32 - 1) # Seed must be between 0 and 2**32 - 1
@tristandeleu (Owner):

Same comment as above regarding (hash(task) + self.random_state_seed) % (2 ** 32)

@@ -297,7 +310,8 @@ def get_indices_task(self, task):
             num_samples = (min_samples if self.force_equal_per_class
                            else len(class_indices))
             if self.shuffle:
-                dataset_indices = self.np_random.permutation(num_samples)
+                seed = hash(task) % (2 ** 32 - 1) # Seed must be between 0 and 2**32 - 1
@tristandeleu (Owner):

Same comment as above regarding (hash(task) + self.random_state_seed) % (2 ** 32)

@@ -327,14 +343,17 @@ def get_indices_concattask(self, task):
             num_samples = (min_samples if self.force_equal_per_class
                            else len(dataset))
             if self.shuffle:
-                dataset_indices = self.np_random.permutation(num_samples)
+                seed = hash(task) % (2 ** 32 - 1) # Seed must be between 0 and 2**32 - 1
@tristandeleu (Owner):

Same comment as above regarding (hash(task) + self.random_state_seed) % (2 ** 32)

         assert len(train_samples.intersection(test_samples)) == 0

+        #print("train split: " + str([train_task[i] for i in range(len(train_task))]))
+        #print("test split: " + str([test_task[i] for i in range(len(train_task))]))
@tristandeleu (Owner):

Can you remove these comments?

@MarcCoru (Contributor, Author)

Alright, I addressed the comments. I removed the inline comments after the seeds since the lines were getting a bit long, and renamed the test from test_fixed_random_state_class_splitter to test_class_splitter_for_fold_overlaps to reflect the modifications.

@tristandeleu merged commit 63662e3 into tristandeleu:master on Feb 20, 2020