Overlap between train and test sets within task on multiple splits #31
I believe that the checks failed because the expected values for The values seem still the same, the sequence of values and train/test splits have changed
Another option could be to have a cache somewhere in the splitter for
What about using the hash function hash(task) for the random seed? It is defined by default but can be overwritten by
That's an even better idea! I would then suggest to
All these changes might be a bit heavy for this single PR, I will make the changes regarding the hash function in a separate PR today and you'll be able to use it to update the splitter. What do you think?
Great! That works out for me.
Except some minor comments, everything looks great! Thank you for the PR!
torchmeta/transforms/splitters.py
@@ -148,13 +152,16 @@ def get_indices_task(self, task):
num_samples, self._min_samples_per_class))

if self.shuffle:
dataset_indices = self.np_random.permutation(num_samples)
seed = hash(task) % (2 ** 32 - 1) # Seed must be between 0 and 2**32 - 1
This can be just % (2 ** 32) to be < 2 ** 32. Also, what do you think about adding random_state_seed here as well as hash(task), to have possibly different datasets for different seeds? Something like
seed = (hash(task) + self.random_state_seed) % (2 ** 32)
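A minimal sketch of the seed derivation suggested above, to make the range argument concrete. The function name `task_seed` and the use of plain integers as stand-ins for a task are assumptions for illustration only; they are not part of torchmeta.

```python
def task_seed(task_key, random_state_seed=0):
    """Derive a per-task seed in the range [0, 2**32 - 1].

    `task_key` is a hypothetical stand-in for a task object; anything
    hashable works. hash() can be negative, but Python's % with a
    positive modulus always returns a non-negative result.
    """
    return (hash(task_key) + random_state_seed) % (2 ** 32)


# A negative hash is mapped into the valid range.
print(task_seed(-5))               # 4294967291, i.e. 2**32 - 5

# Different base seeds give different per-task seeds for the same task.
print(task_seed(7, 0), task_seed(7, 1))   # 7 8

# With % (2 ** 32) the largest valid seed, 2**32 - 1, stays reachable.
# With % (2 ** 32 - 1) the same input would wrap around to 0.
print(task_seed(0, 2 ** 32 - 1))   # 4294967295
```

One caveat worth keeping in mind: if a task's hash is ultimately derived from strings, Python's hash randomization (controlled by `PYTHONHASHSEED`) can make it differ between interpreter runs, so reproducibility across processes would depend on how the task's hash is defined.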
Yes, I agree. That makes sense.
I also removed the comment afterwards since the line was getting a bit long.
torchmeta/transforms/splitters.py
@@ -173,13 +180,16 @@ def get_indices_concattask(self, task):
self._min_samples_per_class))

if self.shuffle:
dataset_indices = self.np_random.permutation(num_samples)
seed = hash(task) % (2 ** 32 - 1) # Seed must be between 0 and 2**32 - 1
Same comment as above regarding (hash(task) + self.random_state_seed) % (2 ** 32)
torchmeta/transforms/splitters.py
@@ -297,7 +310,8 @@ def get_indices_task(self, task):
num_samples = (min_samples if self.force_equal_per_class
else len(class_indices))
if self.shuffle:
dataset_indices = self.np_random.permutation(num_samples)
seed = hash(task) % (2 ** 32 - 1) # Seed must be between 0 and 2**32 - 1
Same comment as above regarding (hash(task) + self.random_state_seed) % (2 ** 32)
torchmeta/transforms/splitters.py
@@ -327,14 +343,17 @@ def get_indices_concattask(self, task):
num_samples = (min_samples if self.force_equal_per_class
else len(dataset))
if self.shuffle:
dataset_indices = self.np_random.permutation(num_samples)
seed = hash(task) % (2 ** 32 - 1) # Seed must be between 0 and 2**32 - 1
Same comment as above regarding (hash(task) + self.random_state_seed) % (2 ** 32)
torchmeta/tests/test_splitters.py
assert len(train_samples.intersection(test_samples)) == 0

#print("train split: " + str([train_task[i] for i in range(len(train_task))]))
#print("test split: " + str([test_task[i] for i in range(len(train_task))]))
Can you remove these comments?
Alright. I addressed the comments. Also, I removed the comments after the seeds since the line was getting a bit long. Also, I renamed the test from
Hi,
Thanks for the clean and well-organized code of this repository.
I investigated the train/test split within each dataset and explored the ClassSplitter class. I noticed that the training and testing partitions of an individual task can overlap across multiple class splits.
The random state is initialized once in self.seed() in
pytorch-meta/torchmeta/transforms/splitters.py
Line 17 in 35f4cc2
and then used for each task split in (e.g.) get_indices_task()
pytorch-meta/torchmeta/transforms/splitters.py
Line 176 in 35f4cc2
Since the internal state of the random state advances, samples from dataset_indices are redrawn at each function call and can overlap across multiple function calls. I wrote this test code on the current version
that returns
Each split nicely separates samples into separate train and test datasets. However, samples overlap between different splits. For instance, sample "4" is in train-test-train-test-train.
This pull request fixes the samples for train/test splits with the random state using the initial seed np.random.RandomState(self.random_state_seed) and then shuffles the separate split indices again with the running self.np_random_seed
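The two-stage idea described above (fixed train/test membership per task, then a reshuffle of the order within each split) can be sketched roughly as follows. This is an illustration under stated assumptions, not the code of this pull request: it uses Python's standard `random.Random` as a stand-in for `np.random.RandomState`, takes the per-task seed from `hash` as discussed in the review, and the names `split_fixed_membership`, `task_key`, `base_seed`, and `running_rng` are hypothetical.

```python
import random


def split_fixed_membership(task_key, num_samples, num_train,
                           base_seed, running_rng):
    """Split sample indices so that membership never changes between calls.

    Stage 1 uses a fresh generator seeded deterministically per task, so the
    same samples always land in train and in test.
    Stage 2 uses the shared, advancing generator only to reorder samples
    inside each split, which cannot move a sample across the boundary.
    """
    seed = (hash(task_key) + base_seed) % (2 ** 32)
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)      # stage 1: fixed membership
    train, test = indices[:num_train], indices[num_train:]
    running_rng.shuffle(train)                # stage 2: order only
    running_rng.shuffle(test)
    return train, test


# Draw the same (hypothetical) task several times, as a data loader would.
running_rng = random.Random(0)
all_train, all_test = set(), set()
for _ in range(5):
    train, test = split_fixed_membership((0, 1), 10, 6,
                                         base_seed=0,
                                         running_rng=running_rng)
    # No overlap within a single split.
    assert set(train).isdisjoint(test)
    all_train.update(train)
    all_test.update(test)

# No overlap between train and test across all splits of the task.
assert all_train.isdisjoint(all_test)
```

If the advancing generator were used for stage 1 as well, each call would draw a new permutation and a sample could end up in train in one call and in test in the next, which is the overlap reported here.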
I also added a pytest with two asserts:
assert len(train_samples.intersection(test_samples)) == 0 checks if there is any overlap within one split (the current version passes this test)
assert len(samples_in_all_test_splits.intersection(samples_in_all_train_splits)) == 0 checks if there is an overlap of samples between different splits (the current version fails this test)
The print output from above looks like this with this pull request:
Thanks again for this repository. It is a pleasure working with it.