Skip to content

Commit

Permalink
debug: reverted to multiplicative masking
Browse files Browse the repository at this point in the history
  • Loading branch information
julienroyd committed Apr 1, 2024
1 parent c0d67e1 commit b3ee035
Showing 1 changed file with 23 additions and 2 deletions.
25 changes: 23 additions & 2 deletions src/gflownet/envs/graph_building_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,8 +613,12 @@ def _apply_action_masks(self):
)

def _mask(self, x, m):
    """Mask the logit tensor ``x`` with the binary float mask ``m``.

    Entries where ``m`` is 1 keep their logit; entries where ``m`` is 0 are
    replaced by -1000, a finite stand-in for a "tiny" log-value.

    The masking is done multiplicatively (``x * m``) with a finite fill
    value on purpose: ``-torch.inf`` cannot be used here because
    ``inf * 0`` is ``nan``.

    Args:
        x: tensor of logits.
        m: mask of dtype ``torch.float`` holding 0./1. values, broadcastable
            against ``x``.

    Returns:
        A tensor equal to ``x`` where ``m == 1`` and to -1000 where
        ``m == 0``.

    Raises:
        AssertionError: if ``m`` is not of dtype ``torch.float``.
    """
    assert m.dtype == torch.float, "action mask must be a torch.float tensor"
    # Finite fill value: keeps masked entries finite (see the docstring).
    masked_logit = -1000
    return x * m + masked_logit * (1 - m)

def detach(self):
new = copy.copy(self)
Expand Down Expand Up @@ -752,8 +756,25 @@ def sample(self) -> List[ActionIndex]:
u = [torch.rand(i.shape, device=self.dev) for i in self._masked_logits]
# Gumbel noise
gumbel = [logit - (-noise.log()).log() for logit, noise in zip(self._masked_logits, u)]

if self._action_masks is not None:
gumbel_safe = [
torch.where(
mask == 1,
torch.maximum(
x,
torch.nextafter(
torch.tensor(torch.finfo(x.dtype).min, dtype=x.dtype), torch.tensor(0.0, dtype=x.dtype)
).to(x.device),
),
torch.finfo(x.dtype).min,
)
for x, mask in zip(gumbel, self._action_masks)
]
else:
gumbel_safe = gumbel
# Take the argmax
return self.argmax(x=gumbel)
return self.argmax(x=gumbel_safe)

def argmax(
self,
Expand Down

0 comments on commit b3ee035

Please sign in to comment.