Skip to content

Commit

Permalink
revert: back from multiplicative masking to -inf
Browse files Browse the repository at this point in the history
This reverts commit b3ee035.
  • Loading branch information
julienroyd committed Apr 4, 2024
1 parent 5f953c6 commit 215e857
Showing 1 changed file with 2 additions and 23 deletions.
25 changes: 2 additions & 23 deletions src/gflownet/envs/graph_building_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,12 +613,8 @@ def _apply_action_masks(self):
)

def _mask(self, x, m):
"""
mask logit vector x with binary mask m, -1000 is a tiny log-value
Note to self: we can't use torch.inf here, because inf * 0 is nan
"""
assert m.dtype == torch.float
return x * m + -1000 * (1 - m)
return x.masked_fill(m == 0., -torch.inf)

def detach(self):
new = copy.copy(self)
Expand Down Expand Up @@ -756,25 +752,8 @@ def sample(self) -> List[ActionIndex]:
u = [torch.rand(i.shape, device=self.dev) for i in self._masked_logits]
# Gumbel noise
gumbel = [logit - (-noise.log()).log() for logit, noise in zip(self._masked_logits, u)]

if self._action_masks is not None:
gumbel_safe = [
torch.where(
mask == 1,
torch.maximum(
x,
torch.nextafter(
torch.tensor(torch.finfo(x.dtype).min, dtype=x.dtype), torch.tensor(0.0, dtype=x.dtype)
).to(x.device),
),
torch.finfo(x.dtype).min,
)
for x, mask in zip(gumbel, self._action_masks)
]
else:
gumbel_safe = gumbel
# Take the argmax
return self.argmax(x=gumbel_safe)
return self.argmax(x=gumbel)

def argmax(
self,
Expand Down

0 comments on commit 215e857

Please sign in to comment.