Commit f859640: small fixes
bengioe committed May 9, 2024 · 1 parent 60722a7
Showing 4 changed files with 3 additions and 10 deletions.
5 changes: 1 addition & 4 deletions docs/implementation_notes.md
@@ -53,7 +53,7 @@ The data used for training GFlowNets can come from a variety of sources. `DataSo
- Generating new trajectories (w.r.t. a fixed dataset of conditioning goals)
- Evaluating the model's likelihood on trajectories from a fixed, offline dataset

-## Multiprocessing
+## Multiprocessing

We use the multiprocessing features of torch's `DataLoader` to parallelize data generation and featurization. This is done by setting the `DataLoader`'s `num_workers` parameter (via `cfg.num_workers`) to a value greater than 0. Because workers cannot (easily) use a CUDA handle, we have to resort to a number of tricks.
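As a rough illustration of this pattern (a minimal, self-contained sketch; `ToySource` and the tensor shapes are hypothetical stand-ins, not the repository's actual `DataSource`):

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class ToySource(IterableDataset):
    """Stand-in for a data source that featurizes trajectories on CPU."""

    def __init__(self, num_batches: int = 4):
        self.num_batches = num_batches

    def __iter__(self):
        # Runs inside each worker process; no CUDA handle needed here.
        for _ in range(self.num_batches):
            yield torch.randn(64, 128)  # stand-in for a featurized Batch


# batch_size=None because the source already yields whole batches;
# each of the 2 workers iterates its own copy of the source.
loader = DataLoader(ToySource(), batch_size=None, num_workers=2)
for batch in loader:
    batch = batch.to("cuda" if torch.cuda.is_available() else "cpu")
```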

@@ -66,6 +66,3 @@ On message serialization, naively sending batches of data and results (`Batch` a
We implement two solutions to this problem (in order of preference; see the sketch below):
- using `SharedPinnedBuffer`s, which are shared tensors of a fixed size (`cfg.mp_buffer_size`), initialized once and pinned. This is the fastest solution, but it requires that the size of the largest possible batch/return value be known in advance. It should work for any message, but has only been tested with `Batch` and `GraphActionCategorical` messages.
- using `cfg.pickle_mp_messages`, which simply serializes messages with `pickle`. This avoids creating many shared-memory files, but is slower than the `SharedPinnedBuffer` solution. It should work for any message that `pickle` can handle.
-
-
-
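A sketch of how the two options above could be selected (the `Config`/`init_empty` import follows the usage shown in `seh_frag.py` below; the values are illustrative, not prescribed defaults):

```python
from gflownet.config import Config, init_empty

config = init_empty(Config())
config.num_workers = 8
# Preferred: fixed-size shared pinned buffers; the size (in bytes) must
# bound the largest possible Batch/GraphActionCategorical message.
config.mp_buffer_size = 32 * 1024**2  # 32 MB
# Fallback: plain pickle serialization (slower, but more general).
# config.pickle_mp_messages = True
```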
1 change: 0 additions & 1 deletion src/gflownet/algo/config.py
@@ -196,7 +196,6 @@ class AlgoConfig(StrictDataClass):
train_det_after: Optional[int] = None
valid_random_action_prob: float = 0.0
sampling_tau: float = 0.0
-compute_log_n: bool = False
tb: TBConfig = field(default_factory=TBConfig)
moql: MOQLConfig = field(default_factory=MOQLConfig)
a2c: A2CConfig = field(default_factory=A2CConfig)
2 changes: 1 addition & 1 deletion src/gflownet/config.py
@@ -80,7 +80,7 @@ class Config(StrictDataClass):
pickle_mp_messages : bool
Whether to pickle messages sent between processes (only relevant if num_workers > 0)
mp_buffer_size : Optional[int]
-If specified, use a buffer of this size for passing tensors between processes.
+If specified, use a buffer of this size in bytes for passing tensors between processes.
Note that this is only relevant if num_workers > 0.
Also note that this will allocate `num_workers + 2 * number of wrapped objects` buffers.
git_hash : Optional[str]
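As a quick sanity check of the allocation rule in that docstring (the number of wrapped objects here is a hypothetical count, not something stated in the diff):

```python
num_workers = 8
num_wrapped_objects = 2          # hypothetical, e.g. a model and its sampler
mp_buffer_size = 32 * 1024**2    # 32 MB per buffer, in bytes

num_buffers = num_workers + 2 * num_wrapped_objects  # 8 + 4 = 12
total_mb = num_buffers * mp_buffer_size // 1024**2   # 12 * 32 = 384
print(f"{num_buffers} buffers, {total_mb} MB total")
```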
5 changes: 1 addition & 4 deletions src/gflownet/tasks/seh_frag.py
@@ -129,6 +129,7 @@ class SEHFragTrainer(StandardOnlineTrainer):
def set_default_hps(self, cfg: Config):
cfg.hostname = socket.gethostname()
cfg.pickle_mp_messages = False
+cfg.mp_buffer_size = 32 * 1024**2  # 32 MB
cfg.num_workers = 8
cfg.opt.learning_rate = 1e-4
cfg.opt.weight_decay = 1e-8
@@ -195,18 +196,14 @@ def main():
config.log_dir = f"./logs/debug_run_seh_frag_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
config.device = "cuda" if torch.cuda.is_available() else "cpu"
config.overwrite_existing_exp = True
-config.algo.num_from_policy = 64
config.num_training_steps = 1_00
config.validate_every = 20
config.num_final_gen_steps = 10
config.num_workers = 1
config.opt.lr_decay = 20_000
-config.opt.clip_grad_type = "total_norm"
config.algo.sampling_tau = 0.99
config.cond.temperature.sample_dist = "uniform"
config.cond.temperature.dist_params = [0, 64.0]
-config.mp_buffer_size = 32 * 1024**2
-# config.pickle_mp_messages = True

trial = SEHFragTrainer(config)
trial.run()
