In [1]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
from conditional_nf import NormFlow, ParameterNetwork

# Seed torch's global RNG before anything random happens: NormFlow presumably
# initialises its weights randomly, and later cells draw inputs with torch.randn,
# so an unseeded notebook gives different numbers on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Flow hyperparameters, passed positionally to NormFlow below.
# NOTE(review): meanings inferred from names — D is presumably the flow's data
# dimensionality; TODO confirm against conditional_nf.NormFlow.
D = 4
arch_type = 'coupling'
num_stages = 1
num_layers = 2
num_units = 15

nf = NormFlow(D, arch_type, num_stages, num_layers, num_units)

# Size of the flat parameter vector the flow expects; a conditioning network
# must output exactly this many values.
print(nf.D_params)

1268


In [2]:
# NOTE(review): this cell is disabled. Its code is wrapped in a bare string
# literal, so nothing in it executes and D_eta / D_params / hidden_layers / pn /
# eta / theta are never defined by this cell; running it only echoes the string.
# TODO: delete the cell, or restore it as real code if the ParameterNetwork
# example is still wanted.
"""D_eta = 4
D_params = nf.D_params
hidden_layers = [100]

pn = ParameterNetwork(D_eta, hidden_layers, D_params)
eta = torch.zeros(31, D_eta)
theta = pn(eta)
"""

'D_eta = 4\nD_params = nf.D_params\nhidden_layers = [100]\n\npn = ParameterNetwork(D_eta, hidden_layers, D_params)\neta = torch.zeros(31, D_eta)\ntheta = pn(eta)\n'

In [5]:
from collections import OrderedDict

class ConditionedNormFlow(torch.nn.Module):
    """Normalizing flow whose parameters are predicted from a conditioning input.

    An MLP (``param_net``) maps a conditioning vector ``x`` of width ``D_x`` to a
    flat vector of ``nf.D_params`` flow parameters. That vector is then passed to
    the wrapped flow ``nf``.

    Args:
        nf: flow object exposing an integer ``D_params`` and callable as
            ``nf(params, N=N) -> (z, log_det)``. Presumably a
            ``conditional_nf.NormFlow``; TODO confirm the params shape it accepts.
        D_x: width of the conditioning input ``x`` (its last dimension).
        hidden_layers: sizes of the hidden ReLU layers of ``param_net``. May be
            empty, in which case ``param_net`` is a single linear map
            ``D_x -> nf.D_params``.
    """

    def __init__(self, nf, D_x, hidden_layers):
        super(ConditionedNormFlow, self).__init__()
        self.nf = nf
        self.D_x = D_x
        self.hidden_layers = hidden_layers
        self.D_params = nf.D_params

        # Widths along the MLP: D_x -> hidden_1 -> ... -> hidden_k. The input
        # width must come from D_x, not from a notebook global.
        widths = [D_x] + list(hidden_layers)
        layers = []
        for i in range(len(hidden_layers)):
            layers.append(('linear%d' % (i + 1), torch.nn.Linear(widths[i], widths[i + 1])))
            layers.append(('relu%d' % (i + 1), torch.nn.ReLU()))
        # Final projection onto the flow's parameter vector. No activation is
        # applied, so the predicted parameters are unconstrained.
        layers.append(('linear%d' % (len(hidden_layers) + 1),
                       torch.nn.Linear(widths[-1], self.D_params)))

        # Named layers (linear1, relu1, ...) keep state_dict keys readable/stable.
        self.param_net = torch.nn.Sequential(OrderedDict(layers))

    def forward(self, x, N=100):
        """Predict flow parameters from ``x`` and run the wrapped flow.

        Implemented as ``forward`` (not ``__call__``) so that ``nn.Module``
        hooks still run when the module is called as ``cnf(x, N=...)``.

        Args:
            x: conditioning input whose last dimension is ``D_x``.
            N: forwarded unchanged to ``self.nf`` as its ``N`` keyword
               (presumably the number of samples; TODO confirm in NormFlow).

        Returns:
            The ``(z, log_det)`` pair returned by ``self.nf``.
        """
        params = self.param_net(x)
        # Use the flow passed to __init__, not whatever global `nf` happens to be.
        z, log_det = self.nf(params, N=N)
        return z, log_det
        
    
# Conditioned flow: a 4-wide conditioning input feeds an MLP with a single
# hidden layer of 100 units, which predicts all nf.D_params flow parameters.
D_x, hidden_layers = 4, [100]
cnf = ConditionedNormFlow(nf=nf, D_x=D_x, hidden_layers=hidden_layers)


In [6]:
# Smoke run: push a random batch of conditioning inputs through the conditioned
# flow. The input width is taken from the model itself (cnf.D_x); the name
# D_eta is never defined on a fresh kernel.
batch_size = 31
eta = torch.randn(batch_size, cnf.D_x)
z, log_det = cnf(eta, N=100)
# Assumes the flow returns batch-first outputs (as the recorded output shows);
# TODO confirm against NormFlow.
assert z.shape[0] == log_det.shape[0] == batch_size
print(z.shape, log_det.shape)

torch.Size([31, 200, 2]) torch.Size([31, 2])
