In [41]:
import os
import argparse
import glob
from PIL import Image
import numpy as np
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons, make_circles
import torch
import torch.nn as nn
import torch.optim as optim
from torchdiffeq import odeint_adjoint as odeint
# Prefer the first CUDA device, but fall back to the CPU so the notebook can
# still run (slowly) on a machine without a usable GPU instead of failing on
# the first `.to(device)` call.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [75]:
class ARGS():
    """Plain container of run settings, standing in for parsed CLI arguments.

    Every attribute is set to a fixed default in ``__init__``; edit the values
    here (or overwrite them on the instance) to change a run.
    """
    def __init__(self):
        # Not read by the cells shown here; the notebook imports
        # `odeint_adjoint` unconditionally.
        self.adjoint = True
        # When true, the visualization cell renders and saves the figure.
        self.viz = True
        # Number of optimisation iterations in the training loop.
        self.niters = 1000
        # Learning rate handed to the Adam optimizer.
        self.lr = 1e-3
        # Number of data points drawn per training batch.
        self.num_samples = 512
        # `width` and `hidden_dim` are forwarded to the CNF constructor.
        self.width = 64
        self.hidden_dim = 32
        # GPU index; presumably selects the CUDA device -- verify against the
        # cell that builds `device`.
        self.gpu = 0
        # Only referenced by the commented-out checkpoint-loading code.
        self.train_dir = None
        # Directory that the visualization cell creates for its output.
        self.save = "cnf"

class CNF(nn.Module):
    """ODE right-hand side for a continuous normalizing flow.

    The dynamics are a width-averaged bank of gated tanh units whose weights
    are produced, for each time ``t``, by a ``HyperNetwork``.

    Adapted from the NumPy implementation at:
    https://gist.github.com/rtqichen/91924063aa4cc95e7ef30b3a5491cc52
    """
    def __init__(self, in_out_dim, hidden_dim, width):
        super().__init__()
        self.in_out_dim = in_out_dim
        self.hidden_dim = hidden_dim
        self.width = width
        self.hyper_net = HyperNetwork(in_out_dim, hidden_dim, width)

    def forward(self, t, states):
        """Evaluate the time derivatives of the state pair at time ``t``.

        Args:
            t: scalar time tensor, passed straight to the hyper-network.
            states: sequence whose first entry is ``z`` (one row per point)
                and whose second entry is the accumulated log-density change.

        Returns:
            Tuple ``(dz_dt, dlogp_z_dt)``; ``dlogp_z_dt`` has shape
            ``(batch, 1)`` and equals minus the trace of ``d(dz_dt)/dz``.
        """
        z = states[0]
        logp_z = states[1]  # read but not needed: the derivative does not depend on it
        n_points = z.shape[0]

        # Gradients are switched on locally because the Jacobian trace is
        # obtained through autograd even when the caller runs under no_grad.
        with torch.set_grad_enabled(True):
            z.requires_grad_(True)

            weights, biases, out_weights = self.hyper_net(t)

            # One copy of the batch per hidden unit: (width, batch, in_out_dim).
            z_tiled = z.unsqueeze(0).repeat(self.width, 1, 1)

            hidden = torch.tanh(z_tiled @ weights + biases)
            dz_dt = (hidden @ out_weights).mean(0)

            jac_trace = trace_df_dz(dz_dt, z)
            dlogp_z_dt = (-jac_trace).view(n_points, 1)

        return (dz_dt, dlogp_z_dt)


def trace_df_dz(f, z):
    """Return, for every row, the trace of the Jacobian of ``f`` w.r.t. ``z``.

    The diagonal is built one column at a time: column ``i`` of ``f`` is
    summed over the batch, differentiated with respect to ``z``, and the
    ``i``-th column of that gradient is added to the running total.
    ``create_graph=True`` keeps the result differentiable.

    Taken from: https://github.com/rtqichen/ffjord/blob/master/lib/layers/odefunc.py#L13
    """
    trace = 0.
    for col in range(z.shape[1]):
        (grad_wrt_z,) = torch.autograd.grad(f[:, col].sum(), z, create_graph=True)
        diag_entry = grad_wrt_z.contiguous()[:, col].contiguous()
        trace += diag_entry

    return trace.contiguous()


class HyperNetwork(nn.Module):
    """Maps a scalar time ``t`` to the weights used by the flow dynamics.

    A three-layer MLP emits one flat vector that is cut into four pieces:
    input weights ``W``, output weights ``U``, a gate ``G`` applied to ``U``
    through a sigmoid, and biases ``B``.

    Adapted from the NumPy implementation at:
    https://gist.github.com/rtqichen/91924063aa4cc95e7ef30b3a5491cc52
    """
    def __init__(self, in_out_dim, hidden_dim, width):
        super().__init__()

        self.in_out_dim = in_out_dim
        self.hidden_dim = hidden_dim
        self.width = width
        # Number of scalars in each of the W, U and G pieces.
        self.blocksize = width * in_out_dim

        self.fc1 = nn.Linear(1, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        # Three blocks (W, U, G) plus one bias per hidden unit.
        self.fc3 = nn.Linear(hidden_dim, 3 * self.blocksize + width)

    def forward(self, t):
        """Return ``[W, B, U]`` for time ``t``.

        Shapes: ``W`` is ``(width, in_out_dim, 1)``, ``B`` is
        ``(width, 1, 1)`` and ``U`` is ``(width, 1, in_out_dim)``.
        """
        # Run the MLP on t viewed as a single 1-feature sample.
        hidden = torch.tanh(self.fc1(t.reshape(1, 1)))
        hidden = torch.tanh(self.fc2(hidden))
        flat = self.fc3(hidden).reshape(-1)

        # Offsets of the consecutive pieces inside the flat vector.
        block, width, dim = self.blocksize, self.width, self.in_out_dim
        u_end = 2 * block
        g_end = 3 * block

        W = flat[:block].reshape(width, dim, 1)
        gate = torch.sigmoid(flat[u_end:g_end].reshape(width, 1, dim))
        U = flat[block:u_end].reshape(width, 1, dim) * gate
        B = flat[g_end:].reshape(width, 1, 1)
        return [W, B, U]

def get_batch(num_samples):
    """Draw a batch of noisy two-moons points plus a zero log-density offset.

    Returns a tuple ``(x, logp_diff_t1)``: ``x`` holds the sampled points as
    float32 and ``logp_diff_t1`` is a ``(num_samples, 1)`` float32 tensor of
    zeros. Both are placed on the module-level ``device``.
    """
    # Alternative dataset:
    # moons, _ = make_circles(n_samples=num_samples, noise=0.06, factor=0.5)
    moons, _ = make_moons(n_samples=num_samples, noise=0.06)

    x = torch.tensor(moons, dtype=torch.float32).to(device)
    logp_diff_t1 = torch.zeros((num_samples, 1), dtype=torch.float32).to(device)
    return x, logp_diff_t1

In [76]:
# Integration interval: data lives at t1, the base distribution at t0.
t0 = 0
t1 = 10
args = ARGS()
# Use the configured GPU when CUDA is available; otherwise fall back to the
# CPU so the cell does not crash on a machine without a GPU.
device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

# model
func = CNF(in_out_dim=2, hidden_dim=args.hidden_dim, width=args.width).to(device)
optimizer = optim.Adam(func.parameters(), lr=args.lr)
# Base distribution at t0: isotropic 2-D Gaussian with variance 0.1.
p_z0 = torch.distributions.MultivariateNormal(
    loc=torch.tensor([0.0, 0.0]).to(device),
    covariance_matrix=torch.tensor([[0.1, 0.0], [0.0, 0.1]]).to(device)
)
# Per-iteration loss values, in order.
loss_hist = []

# Checkpoint restore (currently disabled):
# if args.train_dir is not None:
#     if not os.path.exists(args.train_dir):
#         os.makedirs(args.train_dir)
#     ckpt_path = os.path.join(args.train_dir, 'ckpt.pth')
#     if os.path.exists(ckpt_path):
#         checkpoint = torch.load(ckpt_path)
#         func.load_state_dict(checkpoint['func_state_dict'])
#         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#         print('Loaded ckpt from {}'.format(ckpt_path))

# The solver runs backwards in time, from t1 down to t0. The time grid never
# changes, so it is built once instead of on every iteration.
integration_times = torch.tensor([t1, t0]).type(torch.float32).to(device)

for itr in range(1, args.niters + 1):
    optimizer.zero_grad()

    x, logp_diff_t1 = get_batch(args.num_samples)

    z_t, logp_diff_t = odeint(
        func,
        (x, logp_diff_t1),
        integration_times,
        atol=1e-5,
        rtol=1e-5,
        method='dopri5',
    )

    # The last entry of each trajectory is the state at t0.
    z_t0, logp_diff_t0 = z_t[-1], logp_diff_t[-1]

    # Log-likelihood of the data: base log-density at t0 minus the
    # accumulated log-density change; the loss is its negative batch mean.
    logp_x = p_z0.log_prob(z_t0).to(device) - logp_diff_t0.view(-1)
    loss = -logp_x.mean(0)

    loss.backward()
    optimizer.step()

    loss_hist.append(loss.item())

    # The printed value is this iteration's loss (no averaging is applied),
    # so the label says just that.
    print('Iter: {}, loss: {:.4f}'.format(itr, loss_hist[-1]))



Iter: 1, running avg loss: 6.6777
Iter: 2, running avg loss: 6.1141
Iter: 3, running avg loss: 5.6708
Iter: 4, running avg loss: 5.2306
Iter: 5, running avg loss: 4.8262
Iter: 6, running avg loss: 4.4754
Iter: 7, running avg loss: 4.0927
Iter: 8, running avg loss: 3.7789
Iter: 9, running avg loss: 3.4813
Iter: 10, running avg loss: 3.1485
Iter: 11, running avg loss: 2.9058
Iter: 12, running avg loss: 2.6994
Iter: 13, running avg loss: 2.4585
Iter: 14, running avg loss: 2.3144
Iter: 15, running avg loss: 2.1600
Iter: 16, running avg loss: 2.0869
Iter: 17, running avg loss: 2.0153
Iter: 18, running avg loss: 1.9760
Iter: 19, running avg loss: 1.9636
Iter: 20, running avg loss: 1.9679
Iter: 21, running avg loss: 1.9635
Iter: 22, running avg loss: 1.9585
Iter: 23, running avg loss: 1.9587
Iter: 24, running avg loss: 1.9633
Iter: 25, running avg loss: 1.9505
Iter: 26, running avg loss: 1.9541
Iter: 27, running avg loss: 1.9456
Iter: 28, running avg loss: 1.9321
Iter: 29, running avg loss: 1

In [80]:
%matplotlib inline
if args.viz:
    viz_samples = 30000
    viz_timesteps = 9
    # One panel per timestep plus a final panel for the target samples. All
    # panels must share the same grid, otherwise the target panel lands on
    # top of the last timestep panel.
    num_panels = viz_timesteps + 1
    target_sample, _ = get_batch(viz_samples)

    # Create the output directory; a pre-existing directory is fine.
    os.makedirs(args.save, exist_ok=True)
    with torch.no_grad():
        # Generate evolution of samples: push base-distribution draws forward
        # from t0 to t1.
        z_t0 = p_z0.sample([viz_samples]).to(device)
        logp_diff_t0 = torch.zeros(viz_samples, 1).type(torch.float32).to(device)

        z_t_samples, _ = odeint(
            func,
            (z_t0, logp_diff_t0),
            torch.tensor(np.linspace(t0, t1, viz_timesteps)).to(device),
            atol=1e-5,
            rtol=1e-5,
            method='rk4',
        )

        # Generate evolution of density: pull a regular 100x100 grid on
        # [-2, 2]^2 back from t1 to t0.
        x = np.linspace(-2, 2, 100)
        y = np.linspace(-2, 2, 100)
        points = np.vstack(np.meshgrid(x, y)).reshape([2, -1]).T

        z_t1 = torch.tensor(points).type(torch.float32).to(device)
        logp_diff_t1 = torch.zeros(z_t1.shape[0], 1).type(torch.float32).to(device)

        z_t_density, logp_diff_t = odeint(
            func,
            (z_t1, logp_diff_t1),
            torch.tensor(np.linspace(t1, t0, viz_timesteps)).to(device),
            atol=1e-5,
            rtol=1e-5,
            method='rk4',
        )

        # Create plots for each timestep
        plt.figure(figsize=(25, 4))
        # plt.show()
        for ii, (t, z_sample, z_density, logp_diff) in enumerate(zip(
                np.linspace(t0, t1, viz_timesteps),
                z_t_samples, z_t_density, logp_diff_t
        ), start=1):
            # fig = plt.figure(figsize=(12, 4), dpi=200)
            # plt.tight_layout()
            # plt.axis('off')
            # plt.margins(0, 0)
            # fig.suptitle(f'{t:.2f}s')

            # ax1 = fig.add_subplot(1, 3, 1)
            # ax1.set_title('Target')
            # ax1.get_xaxis().set_ticks([])
            # ax1.get_yaxis().set_ticks([])
            # ax2 = fig.add_subplot(1, 3, 2)
            # ax2.set_title('Samples')
            # ax2.get_xaxis().set_ticks([])
            # ax2.get_yaxis().set_ticks([])
            # ax3 = fig.add_subplot(1, 3, 3)
            # ax3.set_title('Log Probability')
            # ax3.get_xaxis().set_ticks([])
            # ax3.get_yaxis().set_ticks([])

            # ax1.hist2d(*target_sample.detach().cpu().numpy().T, bins=300, density=True,
            #             range=[[-2, 2], [-2, 2]],cmap='Greys')

            # ax2.hist2d(*z_sample.detach().cpu().numpy().T, bins=300, density=True,
            #             range=[[-2, 2], [-2, 2]],cmap='Greys')

            # Only consumed by the disabled 'Log Probability' panel below.
            logp = p_z0.log_prob(z_density) - logp_diff.view(-1)
            # ax3.tricontourf(*z_t1.detach().cpu().numpy().T,
            #                 np.exp(logp.detach().cpu().numpy()), 200,cmap='Greys')

            # plt.savefig(os.path.join(args.save, f"cnf-viz-{int(t*1000):05d}.jpg"),
            #             pad_inches=0.2, bbox_inches='tight')
            # plt.close()

            # Panel ii of the shared grid: 2-D histogram of the flowed samples.
            plt.subplot(1, num_panels, ii)
            plt.axis('off')
            plt.title(f't={t:.2f}')
            plt.hist2d(*z_sample.detach().cpu().numpy().T, bins=300, density=True,range=[[-2, 3], [-1.5, 1.5]],cmap='Greys')
        # img, *imgs = [Image.open(f) for f in sorted(glob.glob(os.path.join(args.save, f"cnf-viz-*.jpg")))]
        # img.save(fp=os.path.join(args.save, "cnf-viz.gif"), format='GIF', append_images=imgs,
        #             save_all=True, duration=250, loop=0)
        # plt.show()
        # Final panel of the same grid: histogram of samples from the target.
        plt.subplot(1, num_panels, num_panels)
        plt.axis('off')
        plt.title('Target')
        plt.hist2d(*target_sample.detach().cpu().numpy().T, bins=300, density=True,range=[[-2, 3], [-1.5, 1.5]],cmap='Greys')
    # print('Saved visualization animation at {}'.format(os.path.join(args.save, "cnf-viz.gif")))
    # Write into the directory created above rather than a hard-coded name, so
    # changing `args.save` cannot point savefig at a missing folder.
    plt.savefig(os.path.join(args.save, "cnf-viz.pdf"), bbox_inches='tight')
