Adds ActionIndex and Harmonize use of action-masks #129

Merged
merged 25 commits into from
Apr 11, 2024
Commits
25 commits
2a6b9d7
docs: added abstract method construct_batch() in GFNAlgorithm()
julienroyd Mar 11, 2024
d6277f3
chore: cleaning up GraphAction.relabel
julienroyd Mar 11, 2024
5d1adf9
chore: replaced Tuple[int,int,int] by ActionIndex() named-tuple.
julienroyd Mar 11, 2024
a6f141a
chore: centralises masking in GraphActionCategorical(), specifically:
julienroyd Mar 12, 2024
4f77507
feat: now mask logits by setting them to -inf (should avoid silent un…
julienroyd Mar 28, 2024
10174fc
fix: wrong variable
julienroyd Mar 28, 2024
c6afaa1
chore: tox
julienroyd Mar 28, 2024
b29fab4
fix: in test
julienroyd Mar 28, 2024
c0d67e1
minor: added assert for safety
julienroyd Mar 29, 2024
b3ee035
debug: reverted to multiplicative masking
julienroyd Apr 1, 2024
b379694
fix: corrected variable name
julienroyd Apr 1, 2024
70a4ec3
fix: detach tensors entering the buffer and sending to cpu
julienroyd Apr 1, 2024
5443f6f
minor: added case 'tuple'
julienroyd Apr 1, 2024
f92690a
fix: added detach and cpu() at the begining of create_batch()
julienroyd Apr 1, 2024
15031a3
minor: removed type cast, now support tuples
julienroyd Apr 1, 2024
f5abfc1
feat: added StrictDataClass to prevent creating new config attributes…
julienroyd Apr 2, 2024
d7d6997
minor: better error message
julienroyd Apr 2, 2024
fb6dac1
Merge branch 'julien-fix-gpu-mem-bust' into julien-harmonize-use-of-m…
julienroyd Apr 3, 2024
ad0db6c
fix: made focus_dir and preferences accessible at the batch level
julienroyd Apr 4, 2024
5ab80df
tox
julienroyd Apr 4, 2024
5f953c6
Merge branch 'julien-fix-gpu-mem-bust' into julien-harmonize-use-of-m…
julienroyd Apr 4, 2024
215e857
revert: back from multiplicative masking to -inf
julienroyd Apr 4, 2024
57da20e
tox
julienroyd Apr 4, 2024
70c68a9
Merge branch 'trunk' into julien-harmonize-use-of-masks
julienroyd Apr 4, 2024
d27a265
refactor: moving action_type_to_mask in graph_building_env.py
julienroyd Apr 11, 2024
10 changes: 4 additions & 6 deletions docs/implementation_notes.md
@@ -10,7 +10,6 @@ We separate experiment concerns in four categories:
- maps graphs to torch_geometric `Data` instances
- maps GraphActions to action indices
- produces action masks
- communicates to the model what inputs it should expect
- The Task class is responsible for computing the reward of a state, and for sampling conditioning information
- The Trainer class is responsible for instantiating everything, and running the training & testing loop
@@ -24,7 +23,7 @@ This library is built around the idea of generating graphs. We use the `networkx`

Some notes:
- graphs are (for now) assumed to be _undirected_. This is encoded for `torch_geometric` by duplicating the edges (contiguously) in both directions. Models still only produce one logit(-row) per edge, so the policy is still assumed to operate on undirected graphs.
- When converting from `GraphAction`s (nx) to so-called `aidx`s, the `aidx`s are encoding-bound, i.e. they point to specific rows and columns in the torch encoding.
- When converting from `GraphAction`s (nx) to `ActionIndex`s (tuple of ints), the action indexes are encoding-bound, i.e. they point to specific rows and columns in the torch encoding.
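As a rough sketch of the idea (the field names below are illustrative assumptions, not necessarily the exact ones introduced by this PR), an `ActionIndex` is just a named triple pointing into the torch encoding:

```python
from typing import NamedTuple

class ActionIndex(NamedTuple):
    # Illustrative field names: which action type (i.e. which logit tensor),
    # which row of that tensor, and which column within that row.
    action_type: int
    row: int
    col: int

# For example, "AddNode with node type 1 on the node at row 2" could be encoded as
# ActionIndex(action_type=1, row=2, col=1), replacing the bare Tuple[int, int, int]
# used before this PR.
```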


### Graph policies & graph action categoricals
@@ -33,12 +32,11 @@ The code contains a specific categorical distribution type for graph actions, `G

Consider for example the `AddNode` and `SetEdgeAttr` actions, one applies to nodes and one to edges. An efficient way to produce logits for these actions would be to take the node/edge embeddings and project them (e.g. via an MLP) to a `(n_nodes, n_node_actions)` and `(n_edges, n_edge_actions)` tensor respectively. We thus obtain a list of tensors representing the logits of different actions, but logits are mixed between graphs in the minibatch, so one cannot simply apply a `softmax` operator on the tensor.

The `GraphActionCategorical` class handles this and can be used to compute various other things, such as entropy, log probabilities, and so on; it can also be used to sample from the distribution.
The `GraphActionCategorical` class handles this and can be used to compute various other things, such as entropy, log probabilities, action masks and so on; it can also be used to sample from the distribution.

To expand, the logits are always 2d tensors, and there’s going to be one such tensor per “action type” that the agent is allowed to take.
Since graphs have variable number of nodes, and since each node has `n` associated possible action/logits, then the `(n_nodes, n)` tensor will vary from minibatch to minibatch.
In addition,the nodes in said logit tensor belong to different graphs in the minibatch; this is indicated by a `batch` tensor of shape `(n_nodes,)` for nodes (for e.g. edges it would be of shape `(n_edges,)`).

Since graphs have variable number of nodes, and since each node has `n_node_actions` associated possible action/logits, then the `(n_nodes, n_node_actions)` tensor will vary from minibatch to minibatch.
In addition, the nodes in said logit tensor belong to different graphs in the minibatch; this is indicated by a `batch` tensor of shape `(n_nodes,)` for nodes (for e.g. edges it would be of shape `(n_edges,)`).

Here’s an example: say we have 2 graphs in a minibatch, the first has 3 nodes, the second 2 nodes. The logits associated with AddNode will be of shape `(5, n)` (assuming there are `n` types of nodes in the problem). Say `n=2`, and `logits[AddNode] = [[1,2],[3,4],[5,6],[7,8],[9,0]]`, and `batch=[0,0,0,1,1]`.
Then to compute the policy, we have to compute a softmax appropriately, i.e. the softmax for the first graph would be `softmax([1,2,3,4,5,6])` and for the second `softmax([7,8,9,0])` . This is possible thanks to `batch` and is what `GraphActionCategorical` does behind the scenes.
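A toy sketch of that per-graph softmax, using the numbers from the example above (an illustration of the idea, not the library's actual implementation):

```python
import torch

logits = torch.tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.], [9., 0.]])  # (n_nodes, n), n=2
batch = torch.tensor([0, 0, 0, 1, 1])  # which graph each node belongs to
num_graphs = 2

# Per-graph max, for numerical stability (same trick as a regular softmax).
max_per_graph = torch.full((num_graphs,), -torch.inf).scatter_reduce(
    0, batch, logits.max(dim=1).values, reduce="amax"
)
exp = torch.exp(logits - max_per_graph[batch][:, None])
# Per-graph normalizer: sum the exponentiated logits of all nodes of the same graph.
denom = torch.zeros(num_graphs).index_add_(0, batch, exp.sum(dim=1))
probs = exp / denom[batch][:, None]

print(probs[:3].sum(), probs[3:].sum())  # each graph's probabilities sum to 1
```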
21 changes: 21 additions & 0 deletions src/gflownet/__init__.py
@@ -45,6 +45,27 @@ def compute_batch_losses(
"""
raise NotImplementedError()

def construct_batch(self, trajs, cond_info, log_rewards):
"""Construct a batch from a list of trajectories and their information

Typically calls ctx.graph_to_Data and ctx.collate to convert the trajectories into
a batch of graphs and adds the necessary attributes for training.

Parameters
----------
trajs: List[List[tuple[Graph, GraphAction]]]
A list of N trajectories.
cond_info: Tensor
The conditional info that is considered for each trajectory. Shape (N, n_info)
log_rewards: Tensor
The transformed log-reward (e.g. torch.log(R(x) ** beta) ) for each trajectory. Shape (N,)
Returns
-------
batch: gd.Batch
A (CPU) Batch object with relevant attributes added
"""
raise NotImplementedError()

def get_random_action_prob(self, it: int):
if self.is_eval:
return self.global_cfg.algo.valid_random_action_prob
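For reference, a concrete override of `construct_batch` typically follows the pattern visible in the algorithm diffs below (e.g. `advantage_actor_critic.py`). The sketch here assumes `torch` is imported and that the extra batch attributes (`actions`, `log_rewards`, `cond_info`) are the ones the algorithm's loss expects:

```python
def construct_batch(self, trajs, cond_info, log_rewards):
    # Sketch of a typical override: encode states, convert actions to ActionIndex,
    # collate into a single torch_geometric Batch, then attach training attributes.
    torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]]
    actions = [
        self.ctx.GraphAction_to_ActionIndex(g, a)
        for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]])
    ]
    batch = self.ctx.collate(torch_graphs)
    batch.traj_lens = torch.tensor([len(i["traj"]) for i in trajs])
    batch.actions = torch.tensor(actions)  # assumed attribute name; see each algorithm for specifics
    batch.log_rewards = log_rewards
    batch.cond_info = cond_info
    return batch
```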
3 changes: 2 additions & 1 deletion src/gflownet/algo/advantage_actor_critic.py
@@ -118,7 +118,8 @@ def construct_batch(self, trajs, cond_info, log_rewards):
"""
torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]]
actions = [
self.ctx.GraphAction_to_aidx(g, a) for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]])
self.ctx.GraphAction_to_ActionIndex(g, a)
for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]])
]
batch = self.ctx.collate(torch_graphs)
batch.traj_lens = torch.tensor([len(i["traj"]) for i in trajs])
55 changes: 26 additions & 29 deletions src/gflownet/algo/envelope_q_learning.py
@@ -64,43 +64,39 @@ def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False):
src_anchor_logits = self.emb2set_edge_attr(torch.cat([edge_emb, node_embeddings[e_row]], 1))
dst_anchor_logits = self.emb2set_edge_attr(torch.cat([edge_emb, node_embeddings[e_col]], 1))

def _mask(x, m):
# mask logit vector x with binary mask m
return x * m + self.mask_value * (1 - m)

def _mask_obj(x, m):
# mask logit vector x with binary mask m
return (
x.reshape(x.shape[0], x.shape[1] // self.num_objectives, self.num_objectives) * m[:, :, None]
+ self.mask_value * (1 - m[:, :, None])
).reshape(x.shape)

cat = GraphActionCategorical(
g,
logits=[
raw_logits=[
F.relu(self.emb2stop(graph_embeddings)),
_mask(F.relu(self.emb2add_node(node_embeddings)), g.add_node_mask),
_mask_obj(F.relu(torch.cat([src_anchor_logits, dst_anchor_logits], 1)), g.set_edge_attr_mask),
F.relu(self.emb2add_node(node_embeddings)),
F.relu(torch.cat([src_anchor_logits, dst_anchor_logits], 1)),
],
action_masks=[
1,
g.add_node_mask.repeat(1, self.num_objectives),
g.set_edge_attr_mask.repeat(1, self.num_objectives),
],
keys=[None, "x", "edge_index"],
types=self.action_type_order,
)
r_pred = self.emb2reward(graph_embeddings)
if output_Qs:
return cat, r_pred
cat.masks = [1, g.add_node_mask.cpu(), g.set_edge_attr_mask.cpu()]
# Compute the greedy policy
# See algo.envelope_q_learning.EnvelopeQLearning.compute_batch_losses for further explanations
# TODO: this makes assumptions about how conditional vectors are created! Not robust to upstream changes
w = cond[:, -self.num_objectives :]
w_dot_Q = [
(qi.reshape((qi.shape[0], qi.shape[1] // w.shape[1], w.shape[1])) * w[b][:, None, :]).sum(2)
for qi, b in zip(cat.logits, cat.batch)
]
# Set the softmax distribution to a very low temperature to make sure only the max gets
# sampled (and we get random argmax tie breaking for free!):
cat.logits = [i * 100 for i in w_dot_Q]
return cat, r_pred

else:
# Compute the greedy policy
# See algo.envelope_q_learning.EnvelopeQLearning.compute_batch_losses for further explanations
# TODO: this makes assumptions about how conditional vectors are created! Not robust to upstream changes
w = cond[:, -self.num_objectives :]
w_dot_Q = [
(qi.reshape((qi.shape[0], qi.shape[1] // w.shape[1], w.shape[1])) * w[b][:, None, :]).sum(2)
for qi, b in zip(cat.logits, cat.batch)
]
cat.action_masks = [1, g.add_node_mask.cpu(), g.set_edge_attr_mask.cpu()]
# Set the softmax distribution to a very low temperature to make sure only the max gets
# sampled (and we get random argmax tie breaking for free!):
cat.logits = [i * 100 for i in w_dot_Q]
return cat, r_pred
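For context on the masking change in this hunk (cf. commits 4f77507, b3ee035 and 215e857): the old `_mask` helpers blended invalid logits towards a large negative `mask_value`, whereas the categorical now receives `raw_logits` plus `action_masks` and masks invalid entries itself. A minimal sketch of the two schemes, with an illustrative constant standing in for -inf:

```python
import torch

NEG_INF = torch.finfo(torch.float32).min  # illustrative stand-in for -inf

def mask_multiplicative(logits, mask, mask_value=-1e6):
    # Old scheme: keep valid logits, blend invalid ones towards a large negative constant.
    return logits * mask + mask_value * (1 - mask)

def mask_to_neg_inf(logits, mask):
    # New scheme: invalid logits are pushed to (effectively) -inf, so masked
    # actions get exactly zero probability under the softmax.
    return torch.where(mask.bool(), logits, torch.full_like(logits, NEG_INF))

logits = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[1.0, 0.0, 1.0]])
print(mask_multiplicative(logits, mask))                     # [[1., -1e6, 3.]]
print(torch.softmax(mask_to_neg_inf(logits, mask), dim=-1))  # masked middle entry gets ~0 probability
```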


class GraphTransformerEnvelopeQL(nn.Module):
@@ -134,7 +130,7 @@ def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False):
e_row, e_col = g.edge_index[:, ::2]
cat = GraphActionCategorical(
g,
logits=[
raw_logits=[
self.emb2stop(graph_embeddings),
self.emb2add_node(node_embeddings),
self.emb2set_node_attr(node_embeddings),
@@ -272,7 +268,8 @@ def construct_batch(self, trajs, cond_info, log_rewards):
"""
torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]]
actions = [
self.ctx.GraphAction_to_aidx(g, a) for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]])
self.ctx.GraphAction_to_ActionIndex(g, a)
for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]])
]
batch = self.ctx.collate(torch_graphs)
batch.traj_lens = torch.tensor([len(i["traj"]) for i in trajs])
11 changes: 7 additions & 4 deletions src/gflownet/algo/flow_matching.py
@@ -99,20 +99,23 @@ def construct_batch(self, trajs, cond_info, log_rewards):
# there are invalid states that make episodes end prematurely (when those invalid states
# have multiple possible parents).

# convert actions to aidx
# convert actions to ActionIndex
parent_actions = [pact for parent in parents for pact, pstate in parent]
parent_actionidcs = [self.ctx.GraphAction_to_aidx(gdata, a) for gdata, a in zip(parent_graphs, parent_actions)]
parent_actionidxs = [
self.ctx.GraphAction_to_ActionIndex(gdata, a) for gdata, a in zip(parent_graphs, parent_actions)
]
# convert state to Data
state_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"][1:]]
terminal_actions = [
self.ctx.GraphAction_to_aidx(self.ctx.graph_to_Data(tj["traj"][-1][0]), tj["traj"][-1][1]) for tj in trajs
self.ctx.GraphAction_to_ActionIndex(self.ctx.graph_to_Data(tj["traj"][-1][0]), tj["traj"][-1][1])
for tj in trajs
]

# Create a batch from [*parents, *states]. This order will make it easier when computing the loss
batch = self.ctx.collate(parent_graphs + state_graphs)
batch.num_parents = torch.tensor([len(i) for i in parents])
batch.traj_lens = torch.tensor([len(i["traj"]) for i in trajs])
batch.parent_acts = torch.tensor(parent_actionidcs)
batch.parent_acts = torch.tensor(parent_actionidxs)
batch.terminal_acts = torch.tensor(terminal_actions)
batch.log_rewards = log_rewards
batch.cond_info = cond_info
49 changes: 24 additions & 25 deletions src/gflownet/algo/graph_sampling.py
@@ -6,7 +6,13 @@
import torch.nn as nn
from torch import Tensor

from gflownet.envs.graph_building_env import Graph, GraphAction, GraphActionCategorical, GraphActionType
from gflownet.envs.graph_building_env import (
Graph,
GraphAction,
GraphActionCategorical,
GraphActionType,
action_type_to_mask,
)
from gflownet.models.graph_transformer import GraphTransformerGFN


@@ -79,6 +85,8 @@ def sample_from_model(
Conditional information of each trajectory, shape (n, n_info)
dev: torch.device
Device on which data is manipulated
random_action_prob: float
Probability of taking a random action at each step

Returns
-------
Expand All @@ -92,17 +100,17 @@ def sample_from_model(
# This will be returned
data = [{"traj": [], "reward_pred": None, "is_valid": True, "is_sink": []} for i in range(n)]
# Let's also keep track of trajectory statistics according to the model
fwd_logprob: List[List[Tensor]] = [[] for i in range(n)]
bck_logprob: List[List[Tensor]] = [[] for i in range(n)]
fwd_logprob: List[List[Tensor]] = [[] for _ in range(n)]
bck_logprob: List[List[Tensor]] = [[] for _ in range(n)]

graphs = [self.env.new() for i in range(n)]
done = [False] * n
graphs = [self.env.new() for _ in range(n)]
done = [False for _ in range(n)]
# TODO: instead of padding with Stop, we could have a virtual action whose probability
# always evaluates to 1. Presently, Stop should convert to a [0,0,0] aidx, which should
# always evaluates to 1. Presently, Stop should convert to a (0,0,0) ActionIndex, which should
# always be at least a valid index, and will be masked out anyways -- but this isn't ideal.
# Here we have to pad the backward actions with something, since the backward actions are
# evaluated at s_{t+1} not s_t.
bck_a = [[GraphAction(GraphActionType.Stop)] for i in range(n)]
bck_a = [[GraphAction(GraphActionType.Stop)] for _ in range(n)]

def not_done(lst):
return [e for i, e in enumerate(lst) if not done[i]]
@@ -116,27 +124,22 @@ def not_done(lst):
# TODO: compute bck_cat.log_prob(bck_a) when relevant
fwd_cat, *_, log_reward_preds = model(self.ctx.collate(torch_graphs).to(dev), cond_info[not_done_mask])
if random_action_prob > 0:
masks = [1] * len(fwd_cat.logits) if fwd_cat.masks is None else fwd_cat.masks
# Decide which graphs in the minibatch will get their action randomized
is_random_action = torch.tensor(
self.rng.uniform(size=len(torch_graphs)) < random_action_prob, device=dev
).float()
# Set the logits to some large value if they're not masked, this way the masked
# actions have no probability of getting sampled, and there is a uniform
# distribution over the rest
# Set the logits to some large value to have a uniform distribution
fwd_cat.logits = [
# We don't multiply m by i on the right because we're assume the model forward()
# method already does that
is_random_action[b][:, None] * torch.ones_like(i) * m * 100 + i * (1 - is_random_action[b][:, None])
for i, m, b in zip(fwd_cat.logits, masks, fwd_cat.batch)
is_random_action[b][:, None] * torch.ones_like(i) * 100 + i * (1 - is_random_action[b][:, None])
for i, b in zip(fwd_cat.logits, fwd_cat.batch)
]
if self.sample_temp != 1:
sample_cat = copy.copy(fwd_cat)
sample_cat.logits = [i / self.sample_temp for i in fwd_cat.logits]
actions = sample_cat.sample()
else:
actions = fwd_cat.sample()
graph_actions = [self.ctx.aidx_to_GraphAction(g, a) for g, a in zip(torch_graphs, actions)]
graph_actions = [self.ctx.ActionIndex_to_GraphAction(g, a) for g, a in zip(torch_graphs, actions)]
log_probs = fwd_cat.log_prob(actions)
# Step each trajectory, and accumulate statistics
for i, j in zip(not_done(range(n)), range(n)):
@@ -259,21 +262,17 @@ def not_done(lst):
else:
gbatch = self.ctx.collate(torch_graphs)
action_types = self.ctx.bck_action_type_order
masks = [getattr(gbatch, i.mask_name) for i in action_types]
action_masks = [action_type_to_mask(t, gbatch, assert_mask_exists=True) for t in action_types]
bck_cat = GraphActionCategorical(
gbatch,
logits=[m * 1e6 for m in masks],
keys=[
# TODO: This is not very clean, could probably abstract this away somehow
GraphTransformerGFN._graph_part_to_key[GraphTransformerGFN._action_type_to_graph_part[t]]
for t in action_types
],
masks=masks,
raw_logits=[torch.ones_like(m) for m in action_masks],
keys=[GraphTransformerGFN.action_type_to_key(t) for t in action_types],
action_masks=action_masks,
types=action_types,
)
bck_actions = bck_cat.sample()
graph_bck_actions = [
self.ctx.aidx_to_GraphAction(g, a, fwd=False) for g, a in zip(torch_graphs, bck_actions)
self.ctx.ActionIndex_to_GraphAction(g, a, fwd=False) for g, a in zip(torch_graphs, bck_actions)
]
bck_logprobs = bck_cat.log_prob(bck_actions)

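The backward categorical above is built from constant `raw_logits` plus the per-action-type `action_masks`, which yields a uniform distribution over the valid backward actions. A toy, single-tensor illustration of why (assuming the categorical masks invalid logits to -inf, as elsewhere in this PR):

```python
import torch

action_mask = torch.tensor([1.0, 0.0, 1.0, 1.0])  # 3 of the 4 backward actions are valid
raw_logits = torch.ones_like(action_mask)          # constant logits everywhere
masked = torch.where(action_mask.bool(), raw_logits, torch.tensor(-torch.inf))
print(torch.softmax(masked, dim=0))  # tensor([0.3333, 0.0000, 0.3333, 0.3333])
```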
3 changes: 2 additions & 1 deletion src/gflownet/algo/soft_q_learning.py
@@ -114,7 +118,8 @@ def construct_batch(self, trajs, cond_info, log_rewards):
"""
torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]]
actions = [
self.ctx.GraphAction_to_aidx(g, a) for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]])
self.ctx.GraphAction_to_ActionIndex(g, a)
for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]])
]
batch = self.ctx.collate(torch_graphs)
batch.traj_lens = torch.tensor([len(i["traj"]) for i in trajs])