[tf.data] graduate rejection_resample API from experimental to tf.data.Dataset #48894
Conversation
In [81]: import numpy as np
    ...: import tensorflow as tf
    ...: from collections import defaultdict
    ...:
    ...: init_dist = [0.6, 0.4]
    ...: target_dist = [0.5, 0.5]
    ...: num_classes = len(init_dist)
    ...: num_samples = 10000
    ...: data_np = np.random.choice(num_classes, num_samples, p=init_dist)
    ...: dataset = tf.data.Dataset.from_tensor_slices(data_np)
    ...: vals = defaultdict(int)
    ...: for i in dataset:
    ...:     vals[i.numpy()] += 1
    ...: print("Initial distribution: {}".format(vals))
Initial distribution: defaultdict(<class 'int'>, {1: 4040, 0: 5960})
In [82]: resampler = tf.data.experimental.rejection_resample(
    ...:     class_func=lambda x: x,
    ...:     target_dist=target_dist,
    ...:     initial_dist=init_dist)
    ...:
    ...: dataset = dataset.apply(resampler)
    ...:
    ...: from collections import defaultdict
    ...: vals = defaultdict(int)
    ...: for i in dataset:
    ...:     vals[i[-1].numpy()] += 1  # elements are (class, value) tuples
    ...: print("Resampled distribution: {}".format(vals))
Proportion of examples rejected by sampler is high: [0.6][0.6 0.4][0 1]
(the warning above is printed repeatedly, 10 times in this run)
Resampled distribution: defaultdict(<class 'int'>, {1: 8080, 0: 5960})

cc: @jsimsa I was playing around with the API and observed some weird behavior. In the above example, it seems like elements are being added to the dataset: class 1 grows from 4040 to 8080 elements while class 0 stays at 5960, so the result is not the 50/50 target distribution.
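For intuition about what the transformation is supposed to do, the classic rejection-resampling rule can be sketched in pure Python. This is a hypothetical standalone helper, not the tf.data implementation: each element of class c is kept with probability (target[c]/initial[c]) / max_c(target[c]/initial[c]), so the surviving stream approximately follows the target distribution while never needing to duplicate elements.

```python
import random

def rejection_resample(elements, class_func, initial_dist, target_dist, seed=0):
    """Keep each element of class c with probability proportional to
    target[c] / initial[c], normalized so the most under-represented
    class is always kept. Accepted elements follow ~target_dist."""
    rng = random.Random(seed)
    ratios = [t / i for t, i in zip(target_dist, initial_dist)]
    max_ratio = max(ratios)
    accept_prob = [r / max_ratio for r in ratios]  # per-class keep probability
    return [x for x in elements if rng.random() < accept_prob[class_func(x)]]

# Reproduce the 60/40 input skew from the report, then resample toward 50/50.
rng = random.Random(42)
data = [0 if rng.random() < 0.6 else 1 for _ in range(10000)]
kept = rejection_resample(data, lambda x: x, [0.6, 0.4], [0.5, 0.5])
counts = {c: kept.count(c) for c in (0, 1)}
print(len(kept), counts)
```

With these distributions the acceptance probabilities are roughly [0.67, 1.0], so about 2000 class-0 elements are rejected and the output is close to balanced — and, notably, the output is always a subset of the input, unlike the behavior observed above.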
The dataset produced by rejection_resample may repeat input elements. In other words, it is not unexpected that the number of output elements is greater than the cardinality of the original input dataset, but it is unexpected that the distribution does not match the target distribution. This is related to an issue we have recently fixed. @yangustc07 FYI. With my fix patched, the following program (which I adapted from your example):
Produces the following output:
Force-pushed from 6c91ddd to 1e21166
PR #49009 has been raised as a pre-requisite for the current one.
@kvignesh1420 Can you please resolve conflicts? Thanks!
@gbaned the file changes in this PR conflict with the ones in the prerequisite PRs. I will resolve all the conflicts in this PR once the prereqs are merged. Hope that's fine.
@kvignesh1420 Can you please resolve conflicts? Thanks!
Force-pushed from 7996827 to 7e0ace2
@aaudiber could you please take a look? Thanks!
Thanks @kvignesh1420

Review context:
    A `Dataset`
    """

    target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist")
Moving all of our implementations into `dataset_ops.py` is making `dataset_ops.py` quite long and hard to navigate. In later PRs we should consider moving the dataset transformations from `dataset_ops.py` into their own files, similar to what we do for experimental ops. It would also make graduating experimental ops more straightforward. @jsimsa what are your thoughts?
This is a valid concern and I think it would be a useful refactor. That said, we should do this in a manner which does not require updating callsites that import `dataset_ops`. One option would be to keep `dataset_ops.py` as a shim that imports symbols from the "per transformation" modules.
@aaudiber @jsimsa we might also want to consider the circular dependencies and file duplication during the promotion/refactor process if we are going to maintain separate files per transformation and use them in `dataset_ops.py`. Also, if we are using `dataset_ops.py` as a shim where the API layout is unaffected for users, we might have to maintain a new file with the actual functionality of the APIs in `dataset_ops.py` so that circular dependencies can be prevented. WDYT?
+1 for making `dataset_ops` a shim. The refactor should have no user-facing impact. We can use LazyLoader in the dataset impl files to handle the circular reference on `dataset_ops.py`. Keeping `dataset_ops` readable is much more important than avoiding circular dependencies between dataset implementations and `dataset_ops`.
okay sounds good!
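A minimal sketch of the shim idea discussed above, under stated assumptions: `dataset_ops.py` would keep its public names but source them from per-transformation modules through a lazy loader, so the circular import is only resolved on first attribute access. This mirrors the spirit of TensorFlow's LazyLoader utility rather than its exact code, and the module names below are stand-ins (`json` plays the role of a hypothetical per-transformation module).

```python
import importlib
import types

class LazyLoader(types.ModuleType):
    """Defers the real import until an attribute is first accessed,
    which breaks import-time circular dependencies between the shim
    and the per-transformation modules."""

    def __init__(self, module_name):
        super().__init__(module_name)
        self._resolved = None  # the real module, loaded on demand

    def __getattr__(self, attr):
        # Called only when normal lookup fails, i.e. for forwarded names.
        if self._resolved is None:
            self._resolved = importlib.import_module(self.__name__)
        return getattr(self._resolved, attr)

# In a hypothetical dataset_ops.py shim, "json" stands in for a
# per-transformation module such as a resample implementation file.
resample_mod = LazyLoader("json")
print(resample_mod.dumps({"graduated": True}))  # prints {"graduated": true}
```

Nothing is imported when the shim module itself loads; the first attribute access (`resample_mod.dumps` here) triggers the real import, so two modules that reference each other this way can both finish loading before either body runs.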
This PR graduates the `tf.data.experimental.rejection_resample` API into `tf.data.Dataset.rejection_resample` by making the following changes:
- Add a `rejection_resample()` method to the `DatasetV2` class.
- Move the `rejection_resample_test` target from `experimental/kernel_tests` to `kernel_tests`.

TEST LOG