
Add class subset selection #151

Merged: 20 commits into main from class-subsampling, Jun 8, 2021
Conversation

mustafa1728 (Contributor) commented Jun 4, 2021:

Associated with #149 (comment).

Description

Adding class subset selection to core kale API.
This will allow datasets to easily select only a subset of classes for training, validation and testing.

API update description

Adding a class_ids=[id_0, id_1, ...] parameter to the relevant dataset initialisations makes them use only the class-subset data. If the parameter is omitted or set to None, the dataset uses all classes.
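As an illustration of the intended behaviour, here is a hedged sketch using a toy torch dataset (the toy data and variable names are made up for illustration; the filtering mirrors the torch.utils.data.Subset approach used in this PR):

```python
import torch
from torch.utils.data import TensorDataset, Subset

# Toy dataset: one sample per class, labels 0..9.
dataset = TensorDataset(torch.zeros(10, 1), torch.arange(10))

class_ids = [0, 3, 7]  # select only these classes

if class_ids is None:
    selected = dataset  # parameter omitted or None: use all classes
else:
    # Keep only the samples whose label is in class_ids.
    indices = [i for i in range(len(dataset)) if dataset[i][1] in class_ids]
    selected = Subset(dataset, indices)

# selected now contains 3 of the 10 samples
```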

Status

Ready

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • In-line docstrings updated.

@mustafa1728 mustafa1728 added work-in-progress Work in progress that should NOT be merged new feature New feature/module (including request) labels Jun 4, 2021
@mustafa1728 mustafa1728 self-assigned this Jun 4, 2021
codecov-commenter commented Jun 4, 2021:

Codecov Report

Merging #151 (2a230eb) into main (8177ad9) will decrease coverage by 0.09%.
The diff coverage is 48.57%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #151      +/-   ##
==========================================
- Coverage   88.18%   88.08%   -0.10%     
==========================================
  Files          44       44              
  Lines        4122     4156      +34     
==========================================
+ Hits         3635     3661      +26     
- Misses        487      495       +8     
Impacted Files Coverage Δ
kale/loaddata/video_multi_domain.py 74.80% <40.00%> (-0.90%) ⬇️
kale/loaddata/multi_domain.py 87.80% <50.00%> (-0.59%) ⬇️
kale/loaddata/dataset_access.py 89.47% <100.00%> (+1.97%) ⬆️

Continue to review full report at Codecov.

Legend:
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 8177ad9...2a230eb. Read the comment docs.

@mustafa1728 mustafa1728 marked this pull request as draft June 4, 2021 04:38
@mustafa1728 mustafa1728 changed the title from "Class subsampling" to "Add Class subsampling" Jun 5, 2021
return test_dataset
else:
sub_indices = [i for i in range(0, len(test_dataset)) if test_dataset[i][1] in class_ids]
return torch.utils.data.Subset(test_dataset, sub_indices)
Member:

Lines 72-76 are largely a repetition of lines 39-43. Is it possible and worthwhile to reduce this repetition (e.g. define those lines as a function taking the ids and the dataset)?

Contributor Author:

Yes, I feel that is definitely worthwhile. I am thinking of a private _get_subset function? Also, I wanted to ask whether this would be alright, since this function will appear in the Read the Docs pages as well.

Member:

The code will be reviewed and can be updated, so make whatever changes seem best to you and we will review them. When unsure, check how PyTorch implements something similar and learn from that.

)
dataset_subsampled.prepare_data_loaders()

assert len(dataset_subsampled) <= len(dataset)
Member:

You can have stronger assertions here. Since these digit datasets have the same number of samples for each class, we should have len(dataset_subsampled) == 0.3 * len(dataset), as you take 3 out of 10 classes, right?
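The stronger assertion can be sketched with a toy balanced dataset (the sizes, names, and filtering code below are illustrative, not the actual digits datasets):

```python
import torch
from torch.utils.data import TensorDataset, Subset

n_per_class = 5
labels = torch.arange(10).repeat_interleave(n_per_class)  # 10 balanced classes
dataset = TensorDataset(torch.zeros(len(labels), 1), labels)

class_ids = [0, 4, 9]  # take 3 of the 10 classes
sub_indices = [i for i in range(len(dataset)) if dataset[i][1] in class_ids]
dataset_subsampled = Subset(dataset, sub_indices)

# Weak assertion (current test): the subset is no larger than the dataset.
assert len(dataset_subsampled) <= len(dataset)
# Stronger assertion (suggested): with balanced classes the size is exact.
assert len(dataset_subsampled) == len(class_ids) * n_per_class  # 0.3 * len(dataset)
```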

assert len(source_train) <= len(source.get_train())
assert len(source_test) <= len(source.get_test())

assert isinstance(source_train, torch.utils.data.Dataset)
haipinglu (Member) commented Jun 6, 2021:

These type assertions are less useful, right? We used similar assertions when we did not have stronger ones implemented. In your case, the assertions above can be made stronger. Consider removing these.

Member:

Agree with Haiping. In addition, most of my assertions are too general and simple and only check parameter types, because we wanted to achieve high coverage as a first step. From now on, however, we are going to improve our tests by replacing some assertions with stronger ones. So in your test, you can explore assertions that check more detail than the current code does.

@@ -51,3 +53,38 @@ def test_get_train_test(dataset_name, download_path):
assert source.n_classes() == 10
assert isinstance(source_train, torch.utils.data.Dataset)
assert isinstance(source_test, torch.utils.data.Dataset)


@pytest.mark.parametrize("dataset_name", ALL)
Member:

It may not be necessary to loop through all variations, if they were tested elsewhere. We can talk tomorrow.

if dataset_subsampled.flow:
dataset_subsampled._flow_source_by_split = {"train": subsampled_train_val[0]}
dataset_subsampled._flow_target_by_split = {"train": subsampled_train_val[0]}
assert len(dataset_subsampled) <= len(dataset)
Member:

See the comments for digits. Stronger versions are possible.

assert isinstance(subsampled_train_val[1], torch.utils.data.Dataset)

assert len(subsampled_train_val[0]) <= len(train_val[0])
assert len(subsampled_train_val[1]) <= len(train_val[1])
Member:

See the comments for digits. Stronger versions are possible.

haipinglu (Member) left a review comment:

Good job. Please see the comments. If any further clarification is needed, we can discuss tomorrow (Monday) and/or Tuesday.
You can mark it as Ready for review (bottom of PR) now.

@xianyuanliu Please take a look too before we meet on Monday.

@mustafa1728 mustafa1728 marked this pull request as ready for review June 6, 2021 10:42
xianyuanliu (Member) left a review comment:

Well done. The code needs some reorganizing; if you have other ideas, please comment.

)

logging.debug("Load source Test")
self._source_by_split["test"] = self._source_access.get_test()
self._source_by_split["test"] = self._source_access.get_test_class_subset(self.class_ids)
Member:

We may keep get_train_val() and get_test() as the main entry points because they are clearer to understand than get_test_class_subset(). We can set a flag (if...else...) in get_train() and get_test() to trigger the process of getting samples of specific classes from the dataset.

Contributor Author:

I think the issue with this is that get_train() and get_test() are redefined by the child classes, so similar flags would have to be handled by each child class separately. get_train_val() is not redefined, so I have added flags inside it.

mustafa1728 (Contributor Author) commented Jun 6, 2021:

I think another option could be to have the child classes redefine a different function (maybe _get_train() private or get_train_all()) and call this inside get_train() with class_id flags?

Member:

I think another option could be to have the child classes redefine a different function (maybe _get_train() private or get_train_all()) and call this inside get_train() with class_id flags?

That would be a good idea, if feasible.
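That option might be sketched as follows (class and method names other than get_train() are hypothetical; the real DatasetAccess has more machinery than shown here):

```python
import torch
from torch.utils.data import TensorDataset, Subset

class DatasetAccess:
    """Sketch: the base class keeps the class_ids flag logic in get_train(),
    so child classes only redefine the private _get_train()."""

    def __init__(self, class_ids=None):
        self._class_ids = class_ids

    def _get_train(self):
        # Child classes redefine this instead of get_train().
        raise NotImplementedError

    def get_train(self):
        dataset = self._get_train()
        if self._class_ids is None:
            return dataset  # no flag: use all classes
        sub_indices = [i for i in range(len(dataset)) if dataset[i][1] in self._class_ids]
        return Subset(dataset, sub_indices)

class ToyAccess(DatasetAccess):
    """Hypothetical child class: one sample per class, labels 0..9."""

    def _get_train(self):
        return TensorDataset(torch.zeros(10, 1), torch.arange(10))
```

For example, ToyAccess(class_ids=[1, 2]).get_train() would yield a 2-sample subset, while ToyAccess().get_train() would return all 10 samples.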

logging.debug("Load target Test")
self._target_by_split["test"] = self._target_access.get_test()
self._target_by_split["test"] = self._target_access.get_test_class_subset(self.class_ids)
Member:

Like above.

Returns:
Dataset: a torch.utils.data.Dataset
"""
train_dataset = self.get_train()
train_dataset = self.get_train_class_subset(class_ids)
Member:

We could pass a class_ids flag to get_train() directly, without changing the self.get_train() call here.


mustafa1728 (Contributor Author):

A simpler solution could be to keep a single separate function for getting class subsets and to use this function in MultiDomainDatasets and VideoMultiDomainDatasets. The advantage is that the base DatasetAccess class is not modified, and the code becomes a bit easier to understand as well. The last two commits are in this direction; the separate function is in kale.utils.class_subset.
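A minimal sketch of such a standalone helper, consistent with how get_class_subset is called later in this thread (the actual implementation in kale.utils.class_subset may differ):

```python
import torch
from torch.utils.data import Subset

def get_class_subset(dataset, class_ids):
    """Return a Subset of `dataset` keeping only the samples whose label
    (the second element of each item) is in `class_ids`."""
    sub_indices = [i for i in range(len(dataset)) if dataset[i][1] in class_ids]
    return Subset(dataset, sub_indices)
```

Because it takes a plain torch dataset, the same helper can serve both the image and flow branches of the multi-domain loaders.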

xianyuanliu (Member) left a review comment:

Good job!

It would be better to put get_class_subset in the DatasetAccess class if possible; I think it will be easier for people to find it there. get_class_subset is also an enhancement of DatasetAccess.

self._flow_source_by_split["test"] = self._source_access_dict["flow"].get_test()
self._flow_target_by_split["test"] = self._target_access_dict["flow"].get_test()
if self.class_ids is not None:
self._flow_source_by_split["test"] = get_class_subset(
Member:

Could we put this function in the DatasetAccess? Is it feasible?
For example,
self._flow_source_by_split["test"] = self._flow_source_by_split["test"].get_class_subset(self.class_ids)

Contributor Author:

Yes agreed, that will be better. Thanks, I'll update it.

Contributor Author:

Ohh, so it may not be feasible as is, since self._flow_source_by_split["test"] is a plain torch.utils.data.Dataset object and not a DatasetAccess object.

The function could be put inside the DatasetAccess class as a static method, so it would be used like: dataset = DatasetAccess.get_class_subset(dataset, class_ids). Would this be appropriate?

Member:

Oh, I overlooked the type. No need to change it then; the static-method form seems a little weird :) The current one is okay, but kale/utils/class_subset.py only has a few lines. Do we have anywhere else to put this function? We can keep the current one for now and discuss with Haiping at the meeting. Thanks!

xianyuanliu (Member) left a review comment:

Well done! Thanks!


@pytest.mark.parametrize("class_subset", CLASS_SUBSETS)
@pytest.mark.parametrize("val_ratio", VAL_RATIO)
def test_class_subsampling(class_subset, val_ratio, download_path):
Member:

Try to avoid using "subsampling".

@haipinglu haipinglu enabled auto-merge June 8, 2021 14:04
@haipinglu haipinglu merged commit 136ef5f into main Jun 8, 2021
@haipinglu haipinglu deleted the class-subsampling branch June 8, 2021 15:57
@bobturneruk bobturneruk mentioned this pull request Jun 8, 2021
1 task
@github-actions github-actions bot mentioned this pull request Jun 21, 2021
1 task
@mustafa1728 mustafa1728 changed the title from "Add Class subsampling" to "Add class subset selection" Jun 21, 2021
4 participants