Skip to content

Commit

Permalink
typo + impl notes
Browse files Browse the repository at this point in the history
  • Loading branch information
bengioe committed Mar 11, 2024
1 parent 4811e7c commit c6d5613
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
20 changes: 20 additions & 0 deletions docs/implementation_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,23 @@ The code contains a specific categorical distribution type for graph actions, `G
Consider for example the `AddNode` and `SetEdgeAttr` actions, one applies to nodes and one to edges. An efficient way to produce logits for these actions would be to take the node/edge embeddings and project them (e.g. via an MLP) to a `(n_nodes, n_node_actions)` and `(n_edges, n_edge_actions)` tensor respectively. We thus obtain a list of tensors representing the logits of different actions, but logits are mixed between graphs in the minibatch, so one cannot simply apply a `softmax` operator on the tensor.

The `GraphActionCategorical` class handles this and can be used to compute various other things, such as entropy, log probabilities, and so on; it can also be used to sample from the distribution.

To expand, the logits are always 2d tensors, and there’s going to be one such tensor per “action type” that the agent is allowed to take.
Since graphs have variable number of nodes, and since each node has `n` associated possible action/logits, then the `(n_nodes, n)` tensor will vary from minibatch to minibatch.
In addition,the nodes in said logit tensor belong to different graphs in the minibatch; this is indicated by a `batch` tensor of shape `(n_nodes,)` for nodes (for e.g. edges it would be of shape `(n_edges,)`).


Here’s an example: say we have 2 graphs in a minibatch, the first has 3 nodes, the second 2 nodes. The logits associated with AddNode will be of shape `(5, n)` (assuming there are `n` types of nodes in the problem). Say `n=2`, and `logits[AddNode] = [[1,2],[3,4],[5,6],[7,8],[9,0]]`, and `batch=[0,0,0,1,1]`.
Then to compute the policy, we have to compute a softmax appropriately, i.e. the softmax for the first graph would be `softmax([1,2,3,4,5,6])` and for the second `softmax([7,8,9,0])` . This is possible thanks to `batch` and is what `GraphActionCategorical` does behind the scenes.
Now that would be for when we only have the `AddNode` action. With more than one action we also have to compute the log-softmax log-normalization factor over the logits of these other tensors, log-add them together and then substract it from all corresponding logits.

## Data sources

The data used for training GFlowNets can come from a variety of sources. `DataSource` implements these different use-cases as individual iterators that collectively assemble the training batches before passing it to the trainer. Some of these use-cases include:
- Generating new trajectories on-policy
- Sampling trajectories from passed policies from a replay buffer
- Sampling trajectories from a fixed, offline dataset

`DataSource` also covers validation sets, including cases such as:
- Generating new trajectories (w.r.t a fixed dataset of conditioning goals)
- Evaluating the model's likelihood on trajectories from a fixed, offline dataset
2 changes: 1 addition & 1 deletion src/gflownet/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def _make_data_loader(self, src):

def build_training_data_loader(self) -> DataLoader:
# Since the model may be used by a worker in a different process, we need to wrap it.
# See implementation_nodes.md for more details.
# See implementation_notes.md for more details.
model = self._wrap_for_mp(self.sampling_model)
replay_buffer = self._wrap_for_mp(self.replay_buffer)

Expand Down

0 comments on commit c6d5613

Please sign in to comment.