Adjusted ActNorm to work as described in the paper (#167)
* Adjusted ActNorm to work as described in the paper

* Fix off-by-one

* Fix log jacobian computation
LarsKue committed Oct 4, 2023
1 parent bb080cd commit 6912465
Showing 2 changed files with 27 additions and 7 deletions.
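For context: ActNorm (activation normalization) applies a per-channel affine transform y = (x − μ_c) / σ_c, with μ_c and σ_c initialized from the statistics of the first batch; the paper referenced in the commit message is presumably Glow (Kingma & Dhariwal, 2018), where ActNorm was introduced. Since the same σ_c is shared by every spatial position of channel c, the forward log-determinant for an (N, C, H, W) input is −H·W·Σ_c log σ_c. The diff below collapses the learned parameters to one scale and offset per channel, draws their initial values from per-channel batch statistics, and scales the log-determinant by the number of spatial positions accordingly.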
26 changes: 21 additions & 5 deletions FrEIA/modules/invertible_resnet.py
@@ -30,7 +30,8 @@ def __init__(self, dims_in, dims_c=None, init_data: torch.Tensor = None):
 
         self.register_buffer("is_initialized", torch.tensor(False))
 
-        dims = next(iter(dims_in))
+        dims = list(next(iter(dims_in)))
+        dims[1:] = [1] * len(dims[1:])
         self.log_scale = nn.Parameter(torch.empty(1, *dims))
         self.loc = nn.Parameter(torch.empty(1, *dims))
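A minimal sketch of what the new `dims` handling produces, assuming a hypothetical convolutional input with per-sample shape (3, 32, 32):

```python
import torch

# Hypothetical FrEIA input spec: one input of shape (C, H, W) = (3, 32, 32).
dims = list((3, 32, 32))
dims[1:] = [1] * len(dims[1:])     # collapse non-channel dims -> [3, 1, 1]

log_scale = torch.empty(1, *dims)  # shape (1, 3, 1, 1): one entry per channel
print(log_scale.shape)             # broadcasts over batch, height, and width
```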

@@ -42,9 +43,24 @@ def scale(self):
         return torch.exp(self.log_scale)
 
     def initialize(self, batch: torch.Tensor):
+        if batch.ndim != self.log_scale.ndim:
+            raise ValueError(f"Expected batch of dimension {self.log_scale.ndim}, but got {batch.ndim}.")
+
+        # we draw the mean and std over all dimensions except the channel dimension
+        dims = [0] + list(range(2, batch.ndim))
+
+        loc = torch.mean(batch, dim=dims, keepdim=True)
+        scale = torch.std(batch, dim=dims, keepdim=True)
+
+        # check for zero std
+        if torch.any(torch.isclose(scale, torch.tensor(0.0))):
+            raise ValueError("Failed to initialize ActNorm: One or more channels have zero standard deviation.")
+
+        # slice here to avoid silent device move
+        self.log_scale.data[:] = torch.log(scale)
+        self.loc.data[:] = loc
+
         self.is_initialized.data = torch.tensor(True)
-        self.log_scale.data = torch.log(torch.std(batch, dim=0, keepdim=True))
-        self.loc.data = torch.mean(batch, dim=0, keepdim=True)
 
     def output_dims(self, input_dims):
         assert len(input_dims) == 1, "Can only use one input"
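A short, self-contained sketch of what the new data-dependent initialization computes (batch shape and statistics are made up for illustration):

```python
import torch

# Synthetic batch: 8 samples, 3 channels, 32x32 spatial, shifted and scaled.
batch = torch.randn(8, 3, 32, 32) * 2.5 + 1.0

dims = [0] + list(range(2, batch.ndim))           # [0, 2, 3]: all but channel
loc = torch.mean(batch, dim=dims, keepdim=True)   # shape (1, 3, 1, 1)
scale = torch.std(batch, dim=dims, keepdim=True)  # shape (1, 3, 1, 1)

# The forward pass (batch - loc) / scale then standardizes each channel
# of the initialization batch:
out = (batch - loc) / scale
print(out.mean(dim=dims))  # ~0 per channel
print(out.std(dim=dims))   # ~1 per channel
```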
@@ -61,10 +77,10 @@ def forward(self, x, c=None, rev=False, jac=True):
 
         if not rev:
             out = (x - self.loc) / self.scale
-            log_jac_det = -utils.sum_except_batch(self.log_scale)
+            log_jac_det = -utils.sum_except_batch(self.log_scale) * torch.prod(torch.tensor(x.shape[2:]).float())
         else:
             out = self.scale * x + self.loc
-            log_jac_det = utils.sum_except_batch(self.log_scale)
+            log_jac_det = utils.sum_except_batch(self.log_scale) * torch.prod(torch.tensor(x.shape[2:]).float())
 
         return (out,), log_jac_det
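Why the new factor `torch.prod(torch.tensor(x.shape[2:]).float())` is needed: the per-channel scale acts on every spatial position, so each position contributes that channel's log-scale to the log-determinant once. A minimal sanity check under assumed shapes:

```python
import torch

log_scale = torch.randn(1, 3, 1, 1)  # one log-scale per channel
x = torch.randn(1, 3, 4, 4)

# Each of the H*W = 16 positions is divided by the same per-channel scale,
# so the forward log-determinant is -(sum of log_scale) * 16.
n_positions = torch.prod(torch.tensor(x.shape[2:]).float())
log_jac_det = -log_scale.sum() * n_positions

# Cross-check: the Jacobian of x -> (x - loc) / scale is diagonal, with
# entry 1/scale_c repeated at every position of channel c.
expected = -log_scale.expand_as(x).sum()
assert torch.allclose(log_jac_det, expected)
```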

8 changes: 6 additions & 2 deletions tests/test_invertible_resnet.py
@@ -36,8 +36,12 @@ def test_conv(self):
         self.assertStandardMoments(y_)
 
     def assertStandardMoments(self, data):
-        self.assertTrue(torch.allclose(torch.mean(data, dim=0), torch.zeros(data.shape[-1]), atol=1e-7))
-        self.assertTrue(torch.allclose(torch.std(data, dim=0), torch.ones(data.shape[-1])))
+        dims = [0] + list(range(2, data.ndim))
+        mean = torch.mean(data, dim=dims)
+        std = torch.std(data, dim=dims)
+
+        self.assertTrue(torch.allclose(mean, torch.zeros_like(mean), atol=1e-7))
+        self.assertTrue(torch.allclose(std, torch.ones_like(std)))
 
 
 class IResNetTest(unittest.TestCase):
