In [None]:
# Shell out to nvidia-smi to list the GPUs on this machine before picking one below.
!nvidia-smi

In [None]:
import os

# Restrict this process to a single physical GPU. The variable must be set
# before CUDA is initialised for it to take effect.
gpu_id = 2
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)

In [None]:
from torch import distributions
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from tqdm import tqdm

import matplotlib.pyplot as plt
plt.style.use('ggplot')

from sklearn import cluster, datasets, mixture
from sklearn.preprocessing import StandardScaler

In [None]:
# Make CUDA device index 0 the current device. CUDA_VISIBLE_DEVICES was set to a
# single GPU id above, so index 0 is presumably that GPU (requires a CUDA build).
torch.cuda.set_device(0)

In [None]:
# Build the two-moons dataset, standardise each coordinate to zero mean and
# unit variance, and plot it inside a fixed [-2, 2] x [-2, 2] window.
n_samples = 2000
noisy_moons = datasets.make_moons(n_samples=n_samples, noise=0.05)
X, y = noisy_moons
X = StandardScaler().fit_transform(X)

xlim = [-2, 2]
ylim = [-2, 2]
plt.scatter(*X.T, s=10, color='red')
plt.xlim(xlim)
plt.ylim(ylim)

In [None]:
def sample_n01(N, D=2):
    """Draw ``N`` samples from a standard normal distribution in ``D`` dimensions.

    Args:
        N: Number of samples (rows) to draw.
        D: Dimensionality of each sample. Defaults to 2, the value that was
            previously hard-coded.

    Returns:
        A NumPy array of shape ``(N, D)`` with i.i.d. normal(0, 1) entries,
        drawn from NumPy's global random state.
    """
    return np.random.normal(size=(N, D))

In [None]:
# Draw two independent sets of 1000 points from normal(0, I_2).
# The first set has 1 added to both coordinates, so it is centred at (1, 1);
# the second set is kept unshifted as the reference sample.
X_normal_shifted = sample_n01(1000) + 1
X_normal = sample_n01(1000)

In [None]:
# Overlay the standardised two-moons points (red, opaque) on the
# normal(0, I_2) sample (green, half transparent).
plt.scatter(*X.T, s=10, color='red', alpha=1)
plt.scatter(*X_normal.T, s=10, color='green', alpha=0.5)
# Optional third layer, left disabled: the sample shifted to (1, 1).
# plt.scatter(X_normal_shifted[:, 0], X_normal_shifted[:, 1], s=10, color='blue', alpha=0.2)
plt.show()

In [None]:
def log_prob_n01(x):
    """Log-density of each row of ``x`` under a standard normal distribution.

    The per-coordinate term ``-x**2 / 2 - log(sqrt(2 * pi))`` is computed
    element-wise and then summed over the last axis, so an input of shape
    ``(..., D)`` yields an output of shape ``(...)``.
    """
    log_normaliser = np.log(np.sqrt(2 * np.pi))
    per_coordinate = - np.square(x) / 2 - log_normaliser
    return np.sum(per_coordinate, axis=-1)

### Loglikelihood of the two moons data under the normal distribution.
For the two moons data, the distribution of these log-likelihoods is clearly bimodal.

In [None]:
# Histogram of the standard-normal log-likelihood of every two-moons point.
moons_log_probs = log_prob_n01(X)
plt.hist(moons_log_probs, bins=50)
plt.show()

### For comparison, loglikelihood of the normal(0, I_2) distributed data

In [None]:
# Same histogram for points that really were drawn from normal(0, I_2).
normal_log_probs = log_prob_n01(X_normal)
plt.hist(normal_log_probs, bins=50)
plt.show()

In [None]:
# Two-dimensional standard normal: zero mean, identity covariance.
prior = distributions.MultivariateNormal(
    loc=torch.zeros(2),
    covariance_matrix=torch.eye(2),
)

In [None]:
class NVP(nn.Module):
    """Stack of RealNVP-style affine coupling layers.

    Every layer leaves one half of the input untouched and applies an
    element-wise affine map ``x * exp(log_scale) + shift`` to the other half.
    ``shift`` and ``log_scale`` come from a small MLP that reads the untouched
    half. ``flips[i]`` selects which half layer ``i`` transforms: ``False``
    transforms the second half, ``True`` the first.

    Args:
        flips: Sequence of booleans, one entry per coupling layer.
        D: Dimensionality of the data. Inputs are split at ``D // 2`` and the
            MLP output is sliced at the same index, so ``D`` is expected to be
            even. Defaults to 2.
    """

    def __init__(self, flips, D=2):
        super().__init__()
        self.D = D
        self.flips = flips
        # Base density over all D dimensions (previously hard-coded to 2).
        self.prior = distributions.MultivariateNormal(torch.zeros(D), torch.eye(D))
        # One conditioner MLP per layer: reads the untouched half (D // 2
        # features) and emits shift and log-scale concatenated (D features).
        self.shift_log_scale_fns = nn.ModuleList([
            nn.Sequential(
                nn.Linear(D // 2, 256),
                nn.ReLU(),
                nn.Linear(256, 256),
                nn.ReLU(),
                nn.Linear(256, D),
            )
            for _ in flips
        ])

    def _shift_and_log_scale(self, conditioner_input, flip_idx):
        """Run layer ``flip_idx``'s MLP and split its output into (shift, log_scale)."""
        net_out = self.shift_log_scale_fns[flip_idx](conditioner_input)
        half = self.D // 2
        return net_out[:, :half], net_out[:, half:]

    def forward(self, x, flip_idx):
        """Apply coupling layer ``flip_idx`` in the sampling direction.

        Args:
            x: Tensor of shape ``[B, D]``.
            flip_idx: Index of the layer to apply.

        Returns:
            Tensor of shape ``[B, D]`` with one half transformed.
        """
        d = x.shape[-1] // 2
        first, second = x[:, :d], x[:, d:]
        if self.flips[flip_idx]:
            # Flipped layer: the second half conditions, the first half is transformed.
            shift, log_scale = self._shift_and_log_scale(second, flip_idx)
            first = first * torch.exp(log_scale) + shift
        else:
            shift, log_scale = self._shift_and_log_scale(first, flip_idx)
            second = second * torch.exp(log_scale) + shift
        return torch.cat([first, second], -1)

    def inverse_forward(self, y, flip_idx):
        """Invert coupling layer ``flip_idx``.

        Args:
            y: Tensor of shape ``[B, D]`` produced by ``forward``.
            flip_idx: Index of the layer to invert.

        Returns:
            Tuple ``(x, log_scale)``: the recovered input of shape ``[B, D]``
            and the layer's log-scale of shape ``[B, D - D // 2]``.
        """
        d = y.shape[-1] // 2
        first, second = y[:, :d], y[:, d:]
        if self.flips[flip_idx]:
            shift, log_scale = self._shift_and_log_scale(second, flip_idx)
            first = (first - shift) * torch.exp(-log_scale)
        else:
            shift, log_scale = self._shift_and_log_scale(first, flip_idx)
            second = (second - shift) * torch.exp(-log_scale)
        return torch.cat([first, second], -1), log_scale

    @staticmethod
    def base_log_prob_fn(x):
        """Log-density of ``x`` under a standard normal, summed over the last axis."""
        return torch.sum(- (x ** 2) / 2 - np.log(np.sqrt(2 * np.pi)), -1)

    def base_sample_fn(self, N):
        """Draw ``N`` samples from the base normal(0, I) as a ``[N, D]`` tensor.

        The samples are moved to the device that holds the model parameters
        (CPU when the model has none) instead of unconditionally to CUDA, so
        sampling works for both CPU and GPU models.
        """
        param = next(self.parameters(), None)
        device = param.device if param is not None else torch.device('cpu')
        return self.prior.sample((N,)).to(device)

    def log_prob(self, y, flip_idx):
        """Log-likelihood of ``y`` under the single coupling layer ``flip_idx``."""
        x, log_scale = self.inverse_forward(y, flip_idx)
        # The Jacobian of the affine map is diagonal with entries exp(log_scale),
        # so the inverse log-determinant is minus the sum of the log-scales.
        ildj = - torch.sum(log_scale, -1)
        return self.base_log_prob_fn(x) + ildj

    def sample_nvp_chain(self, N):
        """Sample ``N`` points by pushing base samples through every layer.

        Returns:
            Tuple ``(x, xs)``: the final samples and a list holding the base
            samples followed by the output of each layer
            (``len(flips) + 1`` tensors).
        """
        x = self.base_sample_fn(N)
        xs = [x]
        for layer_idx in range(len(self.flips)):
            x = self.forward(x, flip_idx=layer_idx)
            xs.append(x)
        return x, xs

    def log_prob_chain(self, y):
        """Log-likelihood of ``y`` under the full stack of layers.

        The layers are inverted from last to first while the per-sample
        log-scales are accumulated; the result is the base log-density of the
        recovered latent minus that accumulated log-determinant.

        Args:
            y: Tensor of shape ``[B, D]``.

        Returns:
            Tensor of shape ``[B]``.
        """
        latent = y
        total_log_scale = y.new_zeros(y.shape[0])
        for flip_idx in reversed(range(len(self.flips))):
            latent, log_scale = self.inverse_forward(latent, flip_idx=flip_idx)
            # Sum over the transformed coordinates (a single one when D == 2).
            total_log_scale = total_log_scale + log_scale.sum(-1)
        return self.base_log_prob_fn(latent) - total_log_scale

In [None]:
# Six coupling layers that alternate which half of the input they transform.
# This instance is left untrained and on the CPU.
flips = [False, True] * 3
my_nvp = NVP(flips=flips)

### Loglikelihood of the two moons data under this new distribution.
These are still quite low and bimodal, which makes sense: the model hasn't been trained yet.

In [None]:
# Log-likelihood of the standardised moons under the untrained flow.
untrained_log_probs = my_nvp.log_prob_chain(torch.FloatTensor(X))
plt.hist(untrained_log_probs.data.cpu().numpy())
plt.show()

## Training!

In [None]:
# Here we decide to stack six layers.
flips = [False, True] * 3
learning_rate = 1e-4
model = NVP(flips).cuda()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Iteration budget read by the training loop in the next cell.
iters = 10000

In [None]:
# Maximum-likelihood training: minimise the mean negative log-likelihood of
# freshly sampled two-moons minibatches under the flow.
train_enum = range(iters - 1)
min_loss = float('inf')
best_state = None
for i in train_enum:
    noisy_moons = datasets.make_moons(n_samples=128, noise=.05)[0].astype(np.float32)
    optimizer.zero_grad()
    batch = torch.FloatTensor(noisy_moons).cuda()
    loss = - torch.mean(model.log_prob_chain(batch))
    loss_value = loss.item()
    if loss_value < min_loss:
        # Record the new minimum and snapshot the weights that produced it.
        # Cloning is required: state_dict() tensors share storage with the
        # live parameters and would otherwise be overwritten by later steps.
        # The snapshot is taken before optimizer.step(), so it matches the
        # weights this loss was computed with. Note that a single minibatch
        # loss is only a noisy estimate of model quality.
        min_loss = loss_value
        best_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
    loss.backward()
    optimizer.step()
    if i % 500 == 0:
        print('Iter {}, loss is {:.3f}'.format(i, loss_value))

# Materialise the best snapshot as its own model so that `bestmodel` is not
# just another name for the still-training `model`. If no snapshot was taken
# (no iterations ran, or every loss was NaN), fall back to the current weights.
bestmodel = NVP(model.flips, D=model.D).cuda()
bestmodel.load_state_dict(best_state if best_state is not None else model.state_dict())

In [None]:
# Draw 10000 samples from the trained flow; only the final layer's output is
# needed here, converted to a NumPy array for plotting.
new_Xs, _ = bestmodel.sample_nvp_chain(N=10000)
new_Xs = new_Xs.data.cpu().numpy()

In [None]:
# Scatter plot of the samples generated by the trained flow.
plt.scatter(*new_Xs.T, c='r')
plt.show()

## Make that GIF now

In [None]:
from matplotlib import animation, rc
from IPython.display import HTML, Image

def animate(i):
    """Move the scatter points to their position at animation frame ``i``.

    Each layer transition is given 48 frames. Within a transition the points
    are linearly blended between two consecutive entries of the global
    ``xs_list``. The global ``paths`` collection is updated in place and
    returned in a 1-tuple.
    """
    layer, step = divmod(i, 48)
    weight = float(step) / 48
    start, end = xs_list[layer], xs_list[layer + 1]
    paths.set_offsets((1 - weight) * start + weight * end)
    return (paths,)

In [None]:
# Sample once from the trained flow, keeping the base samples and the output
# of every coupling layer as NumPy arrays.
new_Xs, xs_list = bestmodel.sample_nvp_chain(10000)
new_Xs = new_Xs.data.cpu().numpy()
xs_list = [x.data.cpu().numpy() for x in xs_list]

# One static scatter plot per stage of the flow.
for x in xs_list:
    plt.scatter(x[:, 0], x[:, 1], c='r', s=1)
    plt.show()

# Canvas for the animation, starting from the base samples.
fig, ax = plt.subplots()
ax.set_xlim(-3, 3)
ax.set_ylim(-3, 3)
paths = ax.scatter(xs_list[0][:, 0], xs_list[0][:, 1], s=1, color='red')

# `animate` uses 48 frames per transition and reads xs_list[l] and
# xs_list[l + 1]. The frame count is therefore derived from the number of
# transitions in the sampled chain, not from the global `flips`, so it cannot
# index past the end of xs_list.
n_frames = 48 * (len(xs_list) - 1)
anim = animation.FuncAnimation(fig, animate, frames=n_frames, interval=1, blit=False)
anim.save('new_anim_circle_curve.gif', writer='imagemagick', fps=60)