# Example of training a normalizing flow using `glasflow`

This notebook trains a RealNVP flow on the two-moons dataset from scikit-learn.

In [None]:
from glasflow.flows import RealNVP
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm
import seaborn as sns
import sklearn.datasets as datasets
import torch
from torch import optim

# Plot styling: colour-blind-friendly colour cycle and notebook-scaled text.
sns.set_palette("colorblind")
sns.set_context("notebook")

In [None]:
# Preview the two-moons data. Use the same noise level (0.1) as the training
# loop so the plot shows the distribution the flow is actually trained on.
x, _ = datasets.make_moons(128, noise=0.1)
plt.scatter(x[:, 0], x[:, 1])
plt.xlabel("x_0")
plt.ylabel("x_1")
plt.show()

In [None]:
# RealNVP configuration. n_inputs=2 matches the two features of the moons data;
# batch norm is inserted between the transforms.
flow_kwargs = dict(
    n_inputs=2,
    n_transforms=5,
    n_neurons=32,
    batch_norm_between_transforms=True,
)
flow = RealNVP(**flow_kwargs)

In [None]:
# Adam over all flow parameters; lr=1e-3 is Adam's default, written out so it
# is easy to tune.
optimizer = optim.Adam(params=flow.parameters(), lr=1e-3)

In [None]:
num_iter = 5000
train_loss = []

# The grid used for the density snapshots never changes, so build it once.
xline = torch.linspace(-1.5, 2.5, 100)
yline = torch.linspace(-0.75, 1.25, 100)
xgrid, ygrid = torch.meshgrid(xline, yline, indexing="ij")
xyinput = torch.cat([xgrid.reshape(-1, 1), ygrid.reshape(-1, 1)], dim=1)

# Ensure training mode even if this cell is re-run after `flow.eval()` below.
flow.train()

for i in range(num_iter):
    # Fresh mini-batch of 128 two-moons points each iteration.
    x, _ = datasets.make_moons(128, noise=0.1)
    x = torch.tensor(x, dtype=torch.float32)

    # Maximum-likelihood training: minimise the mean negative log-probability.
    optimizer.zero_grad()
    loss = -flow.log_prob(inputs=x).mean()
    loss.backward()
    optimizer.step()
    train_loss.append(loss.item())

    if (i + 1) % 500 == 0:
        # Evaluate in eval mode: in training mode the batch-norm layers would
        # normalise with the statistics of the (uniform) grid and fold those
        # grid statistics into their running averages, even under no_grad.
        flow.eval()
        with torch.no_grad():
            zgrid = flow.log_prob(xyinput).exp().reshape(100, 100)
        flow.train()

        plt.contourf(xgrid.numpy(), ygrid.numpy(), zgrid.numpy())
        plt.title(f"iteration {i + 1}")
        plt.show()

In [None]:
# Training curve: one point per iteration, each the batch mean of -log p(x).
ax = plt.gca()
ax.plot(train_loss)
ax.set_xlabel("Iteration", fontsize=12)
ax.set_ylabel("Training loss", fontsize=12)
plt.show()

## Drawing samples from the flow

We can now draw samples from the trained flow.

In [None]:
# Number of samples to draw from the trained flow.
n = 1000
# Switch to eval mode so batch norm uses its running statistics.
flow.eval()
with torch.no_grad():
    generated_samples = flow.sample(n)

In [None]:
# Scatter the generated 2-D samples for visual comparison with the data.
sample_x = generated_samples[:, 0]
sample_y = generated_samples[:, 1]
plt.scatter(sample_x, sample_y)
plt.show()

## Plotting the latent space

We can pass data samples through the flow to produce samples in the latent space. If the flow is well trained, these latent samples (z) should be approximately distributed as a standard Gaussian.

In [None]:
# Draw a fresh batch from the training distribution instead of reusing the
# last mini-batch `x` left over from the training loop.
x_test, _ = datasets.make_moons(1000, noise=0.1)
x_test = torch.tensor(x_test, dtype=torch.float32)

flow.eval()
with torch.no_grad():
    # Map data to the latent space; the discarded second output is presumably
    # the log-Jacobian determinant (confirm against glasflow's Flow.forward).
    z_, _ = flow.forward(x_test)

In [None]:
# Compare the marginal latent distributions with a standard Gaussian pdf.
g = np.linspace(-5, 5, 100)
plt.plot(g, norm.pdf(g), label="Standard Gaussian")

# Semi-transparent histograms so the second one does not hide the first.
plt.hist(z_[:, 0].numpy(), bins=20, density=True, alpha=0.5, label="z_0")
plt.hist(z_[:, 1].numpy(), bins=20, density=True, alpha=0.5, label="z_1")
plt.xlabel("z")
plt.ylabel("Density")
plt.legend()
plt.show()