Better iterators #122
Taking a first look. Great refactor, I think it is much clearer now! Left a few comments (probably more to come).
Looks good to me! And a much-needed refactor, very nice result!
Small suggestion: could add just a few words about `DataSource` in `implementation_notes.md`, e.g.:
## Data sources
The data used for training GFlowNets can come from a variety of sources. `DataSource` implements these different use-cases as individual iterators that collectively assemble the training batches before passing them to the trainer. Some of these use-cases include:
- Generating new trajectories on-policy
- Sampling trajectories from past policies from a replay buffer
- Sampling trajectories from a fixed, offline dataset
`DataSource` also covers validation sets, including cases such as:
- Generating new trajectories (w.r.t. a fixed dataset of conditioning goals)
- Evaluating the model's likelihood on trajectories from a fixed, offline dataset
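To make the "individual iterators that collectively assemble the training batches" idea concrete, here is a minimal, standard-library-only sketch. It is not the `DataSource` class from this PR: `ToyDataSource`, `on_policy`, `from_replay`, `from_dataset` and the dict-based trajectories are all hypothetical names chosen for illustration.

```python
import random
from typing import Callable, Iterator, List

Batch = List[dict]  # hypothetical stand-in for a list of trajectories


def on_policy(sample_fn: Callable[[], dict], n: int) -> Iterator[Batch]:
    """Endlessly yield n freshly sampled trajectories."""
    while True:
        yield [sample_fn() for _ in range(n)]


def from_replay(buffer: List[dict], n: int, rng: random.Random) -> Iterator[Batch]:
    """Endlessly yield up to n trajectories drawn from a replay buffer."""
    while True:
        yield rng.sample(buffer, min(n, len(buffer)))


def from_dataset(dataset: List[dict], n: int) -> Iterator[Batch]:
    """Yield consecutive chunks of a fixed dataset, then stop."""
    for start in range(0, len(dataset), n):
        yield dataset[start:start + n]


class ToyDataSource:
    """Hypothetical sketch: combines several iterators into joint batches."""

    def __init__(self) -> None:
        self._iterators: List[Iterator[Batch]] = []

    def add(self, iterator: Iterator[Batch]) -> "ToyDataSource":
        self._iterators.append(iterator)
        return self

    def __iter__(self) -> Iterator[Batch]:
        # zip stops as soon as the shortest iterator is exhausted
        for parts in zip(*self._iterators):
            yield [traj for part in parts for traj in part]


replay = [{"source": "replay", "id": i} for i in range(10)]
offline = [{"source": "offline", "id": i} for i in range(4)]

source = (
    ToyDataSource()
    .add(on_policy(lambda: {"source": "policy"}, n=2))
    .add(from_replay(replay, n=1, rng=random.Random(0)))
    .add(from_dataset(offline, n=2))
)
batches = list(source)  # iteration ends when the offline dataset runs out
```

Each batch here mixes two on-policy, one replay and two offline trajectories; the real class presumably also handles workers, conditioning and rewards, which this sketch leaves out.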
This PR refactors the `SamplingIterator` class into a `DataSource` class, where much of the existing functionality is decomposed into methods that can be combined with more clarity. Instead of interacting with `SamplingIterator` through a long list of sometimes unclear arguments, `DataSource` works by creating and combining iterators that perform more specific jobs.

Changes:
- `SamplingIterator` is now `DataSource`
- `get_worker_device()`. This avoids passing around a device object (which, because of workers, can get confusing)
- `is_eval` member to algorithms, and shifts the responsibility of determining random action probabilities to `GFNAlgorithm`
- `SQLiteLogHook` to generalize logging as a hook that's added to `DataSource` instances
- `AvgRewardHook`, a simple sampling hook to report average reward
- `MultiObjectiveStatsHook`

TODO:
- `get_worker_device` and remove device objects being passed around
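
As a rough illustration of the sampling hooks listed under Changes, here is a sketch of what a hook that reports average reward could look like. The call signature `(trajs, rewards) -> dict`, the `run_hooks` helper and the `sampled_reward_avg` key are assumptions made for this example; they are not taken from the PR's actual `AvgRewardHook`.

```python
from typing import Callable, Dict, List, Sequence

# Hypothetical hook signature: takes sampled trajectories and their rewards,
# returns a dict of scalar statistics to log.
Hook = Callable[[Sequence[dict], Sequence[float]], Dict[str, float]]


class ToyAvgRewardHook:
    """Hypothetical stand-in for a hook that reports the mean reward."""

    def __call__(self, trajs: Sequence[dict], rewards: Sequence[float]) -> Dict[str, float]:
        # Guard against an empty batch to avoid dividing by zero
        return {"sampled_reward_avg": sum(rewards) / max(len(rewards), 1)}


def run_hooks(hooks: List[Hook], trajs: Sequence[dict], rewards: Sequence[float]) -> Dict[str, float]:
    """Call every hook on a batch and merge their statistics into one dict."""
    info: Dict[str, float] = {}
    for hook in hooks:
        info.update(hook(trajs, rewards))
    return info


stats = run_hooks([ToyAvgRewardHook()], trajs=[{}, {}], rewards=[1.0, 3.0])
```

Keeping hooks as plain callables that return dicts means logging, reward averaging and multi-objective statistics can be attached or removed independently of the sampling code.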