-
-
Notifications
You must be signed in to change notification settings - Fork 64
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add class subset selection #151
Conversation
Codecov Report
@@ Coverage Diff @@
## main #151 +/- ##
==========================================
- Coverage 88.18% 88.08% -0.10%
==========================================
Files 44 44
Lines 4122 4156 +34
==========================================
+ Hits 3635 3661 +26
- Misses 487 495 +8
Continue to review full report at Codecov.
|
kale/loaddata/dataset_access.py
Outdated
return test_dataset | ||
else: | ||
sub_indices = [i for i in range(0, len(test_dataset)) if test_dataset[i][1] in class_ids] | ||
return torch.utils.data.Subset(test_dataset, sub_indices) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Line 72-76 is somehow a repetition of line 39-43. Is it possible and worthy to reduce such repetition (e.g. define that 4 lines as a function taking ids and dataset in)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes I feel that is definitely worthy. I am thinking a private _get_subset function? Also, I wanted to ask whether this would be alright since this function will come in the 'read the docs' as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code will be reviewed and can be updated so you may make the changes good in your opinion and we will review. When not sure, check how pytorch implements something similar to learn.
tests/loaddata/test_digits_access.py
Outdated
) | ||
dataset_subsampled.prepare_data_loaders() | ||
|
||
assert len(dataset_subsampled) <= len(dataset) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can have stronger assertions here. Since these digit datasets have the same number of samples for each class, we should have len(dataset_subsampled) == 0.3*len(dataset) since you take 3 out of 10 classes, right?
tests/loaddata/test_digits_access.py
Outdated
assert len(source_train) <= len(source.get_train()) | ||
assert len(source_test) <= len(source.get_test()) | ||
|
||
assert isinstance(source_train, torch.utils.data.Dataset) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These type assertions are less useful, right? We used similar assertions when we do not have stronger assertions implementation. In your case, the above can be stronger assertions. Consider to remove these.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree with Haiping. In addition, my most assertions are too general and simple and only check some parameter type, because we want to achieve high coverage in the first step. However, from now on, we are going to improve our test example by replacing some assertions with stronger ones. Therefore, in your test example, you can explore assertions for a more detailed test than the current code.
tests/loaddata/test_digits_access.py
Outdated
@@ -51,3 +53,38 @@ def test_get_train_test(dataset_name, download_path): | |||
assert source.n_classes() == 10 | |||
assert isinstance(source_train, torch.utils.data.Dataset) | |||
assert isinstance(source_test, torch.utils.data.Dataset) | |||
|
|||
|
|||
@pytest.mark.parametrize("dataset_name", ALL) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It may not be necessary to loop through all variations, if they were tested elsewhere. We can talk tomorrow.
tests/loaddata/test_video_access.py
Outdated
if dataset_subsampled.flow: | ||
dataset_subsampled._flow_source_by_split = {"train": subsampled_train_val[0]} | ||
dataset_subsampled._flow_target_by_split = {"train": subsampled_train_val[0]} | ||
assert len(dataset_subsampled) <= len(dataset) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See the comments for digits. Stronger versions are possible.
tests/loaddata/test_video_access.py
Outdated
assert isinstance(subsampled_train_val[1], torch.utils.data.Dataset) | ||
|
||
assert len(subsampled_train_val[0]) <= len(train_val[0]) | ||
assert len(subsampled_train_val[1]) <= len(train_val[1]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See the comments for digits. Stronger versions are possible.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good job. Please see the comments. If any further clarification is needed, we can discuss tomorrow (Monday) and/or Tuesday.
You can mark it as Ready for review
(bottom of PR) now.
@xianyuanliu Please take a look too before we meet on Monday.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well done. Need to reorganize the code. If you have other ideas, please comment.
kale/loaddata/multi_domain.py
Outdated
) | ||
|
||
logging.debug("Load source Test") | ||
self._source_by_split["test"] = self._source_access.get_test() | ||
self._source_by_split["test"] = self._source_access.get_test_class_subset(self.class_ids) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We may keep get_train_val()
and get_test()
in the mainstream because they are more clear to understand than get_test_class_subset()
. We can set a flag (if...else...) in get_train()
and get_test()
to trigger on the process to get specific class samples from the dataset.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the issue with this is that get_train()
and get_test()
are being redefined by its child classes. So, similar flags would have to be handled by each of the child classes separately. get_train_val()
is not being re-defined, so I have added flags inside it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think another option could be to have the child classes redefine a different function (maybe _get_train()
private or get_train_all()
) and call this inside get_train()
with class_id flags?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think another option could be to have the child classes redefine a different function (maybe
_get_train()
private orget_train_all()
) and call this insideget_train()
with class_id flags?
It will be a good idea if feasible.
kale/loaddata/multi_domain.py
Outdated
logging.debug("Load target Test") | ||
self._target_by_split["test"] = self._target_access.get_test() | ||
self._target_by_split["test"] = self._target_access.get_test_class_subset(self.class_ids) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Like above.
kale/loaddata/dataset_access.py
Outdated
Returns: | ||
Dataset: a torch.utils.data.Dataset | ||
""" | ||
train_dataset = self.get_train() | ||
train_dataset = self.get_train_class_subset(class_ids) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We may give a flag class_ids
to get_train()
directly without changing self.get_train()
here.
tests/loaddata/test_digits_access.py
Outdated
assert len(source_train) <= len(source.get_train()) | ||
assert len(source_test) <= len(source.get_test()) | ||
|
||
assert isinstance(source_train, torch.utils.data.Dataset) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree with Haiping. In addition, my most assertions are too general and simple and only check some parameter type, because we want to achieve high coverage in the first step. However, from now on, we are going to improve our test example by replacing some assertions with stronger ones. Therefore, in your test example, you can explore assertions for a more detailed test than the current code.
A simpler solution could be to keep a single separate function for getting class-subsets and using this function in |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good job!
It will be better to put get_class subset
in the class Dataset Access
if possible. I think it will be easier for people to find it. get_class subset
is also an enhancement for Dataset Access
.
self._flow_source_by_split["test"] = self._source_access_dict["flow"].get_test() | ||
self._flow_target_by_split["test"] = self._target_access_dict["flow"].get_test() | ||
if self.class_ids is not None: | ||
self._flow_source_by_split["test"] = get_class_subset( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we put this function in the DatasetAccess
? Is it feasible?
For example,
self._flow_source_by_split["test"] = self._flow_source_by_split["test"].get_class_subset(self.class_ids)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes agreed, that will be better. Thanks, I'll update it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ohh, so it may not be feasible as is, since self._flow_source_by_split["test"]
is a simple torch.Dataset
object and not a DatasetAccess
object.
The function can be put inside DatasetAccess class as a static function, so it will be used like: dataset = DatasetAccess.get_class_subset(dataset, class_ids)
. Would this be appropriate?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, I ignore the type. No need to change. That seems a little weird :) The current one is okay but it only has a few lines in kale/utils/class_subset.py
. Do we have any other choice to put this function? We can keep the current now and discuss with Haiping at the meeting. Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well done! Thanks!
tests/loaddata/test_digits_access.py
Outdated
|
||
@pytest.mark.parametrize("class_subset", CLASS_SUBSETS) | ||
@pytest.mark.parametrize("val_ratio", VAL_RATIO) | ||
def test_class_subsampling(class_subset, val_ratio, download_path): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
try to avoid using "subsampling"
Associated with #149 (comment).
Description
Adding class subset selection to core kale API.
This will allow datasets to easily select only a subset of classes for training, validation and testing.
API update description
Adding
class_ids=[id_0, id_1 ...]
parameter in initialisations of:should make it use class subset data. With no parameter or None value, dataset will use all classes.
Status
Ready
Types of changes