Better iterators #122

Merged
merged 14 commits into from
Mar 11, 2024

Collaborator

@bengioe bengioe commented Mar 1, 2024

This PR refactors the SamplingIterator class into a DataSource class, in which much of the existing functionality is decomposed into methods that can be combined more clearly. Instead of interacting with SamplingIterator through a long list of sometimes unclear arguments, users of DataSource create and combine iterators that each perform a more specific job.
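To illustrate the idea of creating and combining iterators, here is a minimal, self-contained sketch. It is not the actual DataSource API: `ToyDataSource`, `do_constant`, and `add_hook` are hypothetical names invented for this example, and the "trajectories" are plain strings.

```python
import itertools
from typing import Callable, Iterator, List


class ToyDataSource:
    """Hypothetical stand-in for DataSource: combines several batch-producing iterators."""

    def __init__(self) -> None:
        self._iterators: List[Callable[[], Iterator[list]]] = []
        self._hooks: List[Callable[[list], dict]] = []

    def do_constant(self, item: str, num_samples: int) -> "ToyDataSource":
        # Each "do_*" method registers one iterator with a single, specific job.
        def make_iter() -> Iterator[list]:
            while True:
                yield [item] * num_samples

        self._iterators.append(make_iter)
        return self

    def add_hook(self, hook: Callable[[list], dict]) -> "ToyDataSource":
        # Hooks see every assembled batch and return extra info to log.
        self._hooks.append(hook)
        return self

    def __iter__(self):
        iterators = [make() for make in self._iterators]
        # One element is drawn from every registered iterator per step,
        # and the pieces are concatenated into a single batch.
        for parts in zip(*iterators):
            batch = list(itertools.chain.from_iterable(parts))
            info: dict = {}
            for hook in self._hooks:
                info.update(hook(batch))
            yield batch, info


source = ToyDataSource().do_constant("on_policy", 2).do_constant("replay", 1)
source.add_hook(lambda batch: {"batch_size": len(batch)})
batch, info = next(iter(source))
```

In this toy version, each registered iterator contributes its share of the batch and the hook reports on the assembled result, which mirrors the described design only in spirit.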

Changes:

  • SamplingIterator is now DataSource
  • refactors device handling: code that needs a device should now call get_worker_device(). This avoids passing around a device object (which, because of workers, can get confusing)
  • adds an is_eval member to algorithms, and shifts the responsibility of determining random action probabilities to GFNAlgorithm
  • adds SQLiteLogHook to generalize logging as a hook that's added to DataSource instances
  • adds AvgRewardHook, a simple sampling hook to report average reward
  • fixes minor MOO incongruities
  • fixes Nested dataclasses do not reinitialize #123, whereby nested dataclasses would only be initialized once
  • fixes a timeout condition typo in MultiObjectiveStatsHook
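For readers unfamiliar with the nested-dataclass problem mentioned in the list above, the following sketch shows one common way such a bug is avoided in Python. It assumes the issue is of the "shared default instance" kind; the actual fix in this PR may differ, and `InnerConfig` and `OuterConfig` are hypothetical names.

```python
from dataclasses import dataclass, field


@dataclass
class InnerConfig:
    lr: float = 1e-3


@dataclass
class OuterConfig:
    # default_factory builds a new InnerConfig for every OuterConfig instance,
    # so nested state is never shared between two outer configs.
    inner: InnerConfig = field(default_factory=InnerConfig)


a = OuterConfig()
b = OuterConfig()
a.inner.lr = 0.5  # must not leak into b
```

With a single shared default instance, the change to `a.inner` would also be visible through `b.inner`; with `default_factory`, each outer config owns its own nested object.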

TODO:

  • address TODO items in this draft PR
  • complete transition to get_worker_device and remove device objects being passed around
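The pattern behind get_worker_device can be sketched as a module-level lookup that replaces device arguments. This is only an illustration under stated assumptions, not the repository's implementation: `set_devices` and `set_worker_id` are hypothetical helpers, and devices are plain strings. In a PyTorch DataLoader setting, the worker check could instead consult `torch.utils.data.get_worker_info()`, which returns `None` in the main process.

```python
from typing import Optional

# Module-level state: configured once, then read wherever a device is needed.
_state = {"main": "cpu", "worker": "cpu", "worker_id": None}


def set_devices(main: str, worker: str) -> None:
    """Hypothetical setup call: record the main-process and worker devices."""
    _state["main"] = main
    _state["worker"] = worker


def set_worker_id(worker_id: Optional[int]) -> None:
    """Hypothetical stand-in for a real worker-info lookup."""
    _state["worker_id"] = worker_id


def get_worker_device() -> str:
    # The main process (no worker id) uses the main device;
    # data-loading workers use the worker device.
    if _state["worker_id"] is None:
        return _state["main"]
    return _state["worker"]


set_devices(main="cuda:0", worker="cpu")
main_device = get_worker_device()
set_worker_id(0)
worker_device = get_worker_device()
set_worker_id(None)
```

Callers never receive a device argument; they ask for the device at the point of use, so the answer is correct regardless of which process runs the code.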

Contributor

@julienroyd julienroyd left a comment

Taking a first look. Great refactor, I think it is much clearer now! Left a few comments (probably more to come).

@bengioe bengioe marked this pull request as ready for review March 7, 2024 16:42
Contributor

@julienroyd julienroyd left a comment

Looks good to me! A much-needed refactor, with a very nice result!

Small suggestion: could add just a few words about DataSource in implementation_notes.md, e.g.:

## Data sources

The data used for training GFlowNets can come from a variety of sources. `DataSource` implements these different use-cases as individual iterators that collectively assemble the training batches before passing them to the trainer. Some of these use-cases include:
- Generating new trajectories on-policy
- Sampling trajectories from past policies from a replay buffer
- Sampling trajectories from a fixed, offline dataset

`DataSource` also covers validation sets, including cases such as:
- Generating new trajectories (w.r.t. a fixed dataset of conditioning goals)
- Evaluating the model's likelihood on trajectories from a fixed, offline dataset

@bengioe bengioe merged commit 9bf35cd into trunk Mar 11, 2024
4 checks passed
@bengioe bengioe deleted the bengioe-better-iterators branch March 11, 2024 17:23
Development

Successfully merging this pull request may close these issues.

Nested dataclasses do not reinitialize
2 participants