Skip to content

Commit

Permalink
Added identity init to CoupledRationalQuadraticSpline (#55)
Browse files Browse the repository at this point in the history
* Added identity init to CoupledRationalQuadraticSpline

* using torch.nn instead of nn
  • Loading branch information
mattcleigh committed Dec 4, 2023
1 parent df54633 commit 8272cdc
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion normflows/flows/neural_spline/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(
activation=nn.ReLU,
dropout_probability=0.0,
reverse_mask=False,
init_identity=True,
):
"""Constructor
Expand All @@ -43,11 +44,12 @@ def __init__(
activation (torch module): Activation function
dropout_probability (float): Dropout probability of the NN
reverse_mask (bool): Flag whether the reverse mask should be used
init_identity (bool): Flag, initialize transform as identity
"""
super().__init__()

def transform_net_create_fn(in_features, out_features):
return ResidualNet(
net = ResidualNet(
in_features=in_features,
out_features=out_features,
context_features=num_context_channels,
Expand All @@ -57,6 +59,12 @@ def transform_net_create_fn(in_features, out_features):
dropout_probability=dropout_probability,
use_batch_norm=False,
)
if init_identity:
torch.nn.init.constant_(net.final_layer.weight, 0.0)
torch.nn.init.constant_(
net.final_layer.bias, np.log(np.exp(1 - DEFAULT_MIN_DERIVATIVE) - 1)
)
return net

self.prqct = PiecewiseRationalQuadraticCoupling(
mask=create_alternating_binary_mask(num_input_channels, even=reverse_mask),
Expand Down

0 comments on commit 8272cdc

Please sign in to comment.