Skip to content

Commit

Permalink
use adamax and 2 hidden layers for coupling
Browse files Browse the repository at this point in the history
  • Loading branch information
rtqichen committed Jan 11, 2019
1 parent d612674 commit 809f528
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
8 changes: 6 additions & 2 deletions lib/layers/coupling.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@
class CouplingLayer(nn.Module):
"""Used in 2D experiments."""

def __init__(self, d, intermediate_dim=32, swap=False):
def __init__(self, d, intermediate_dim=64, swap=False):
nn.Module.__init__(self)
self.d = d - (d // 2)
self.swap = swap
self.net_s_t = nn.Sequential(
nn.Linear(self.d, intermediate_dim), nn.ReLU(inplace=True), nn.Linear(intermediate_dim, (d - self.d) * 2)
nn.Linear(self.d, intermediate_dim),
nn.ReLU(inplace=True),
nn.Linear(intermediate_dim, intermediate_dim),
nn.ReLU(inplace=True),
nn.Linear(intermediate_dim, (d - self.d) * 2),
)

def forward(self, x, logpx=None, reverse=False):
Expand Down
4 changes: 2 additions & 2 deletions train_discrete_toy.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
parser.add_argument('--JoffdiagFrobint', type=float, default=None, help="int_t ||df/dx - df_i/dx_i||_F")

parser.add_argument('--save', type=str, default='experiments/cnf')
parser.add_argument('--viz_freq', type=int, default=5000)
parser.add_argument('--viz_freq', type=int, default=1000)
parser.add_argument('--val_freq', type=int, default=1000)
parser.add_argument('--log_freq', type=int, default=100)
parser.add_argument('--gpu', type=int, default=0)
Expand Down Expand Up @@ -120,7 +120,7 @@ def compute_loss(args, model, batch_size=None):
logger.info(model)
logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
optimizer = optim.Adamax(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

time_meter = utils.RunningAverageMeter(0.98)
loss_meter = utils.RunningAverageMeter(0.98)
Expand Down

0 comments on commit 809f528

Please sign in to comment.