# Conditional normalizing flow

Trying out a conditional normalizing flow on a single sample (one row of the data), conditioning on position, with the coordinates randomly permuted during training.

In [116]:
import numpy as np
import pandas as pd
import torch
from nflows.distributions import StandardNormal
from nflows.flows import MaskedAutoregressiveFlow, Flow
from nflows.transforms import CompositeTransform, BatchNorm, MaskedAffineAutoregressiveTransform, RandomPermutation, ReversePermutation
from torch.optim import Adam
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib
import itertools

from __future__ import print_function
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import matplotlib

In [117]:
# --- Flow architecture -------------------------------------------------------
n_dim = 6                            # number of values modelled jointly by the flow
n_layers = 4                         # number of (permutation, autoregressive transform[, BatchNorm]) stacks
n_blocks = 4                         # num_blocks inside each autoregressive transform's network
n_hidden_features = 5                # hidden width of each autoregressive transform's network
use_residual_blocks = True
use_random_masks = False
use_random_permutations = False      # False -> ReversePermutation between layers
dropout_probability = 0.0
use_batch_norm_within_layers = True
use_batch_norm_between_layers = True

# --- Conditioning ------------------------------------------------------------
# The context is one pair of coordinates per value (the training loop reshapes
# `position` into (n_dim, 2)), so derive the width instead of hard-coding 12.
n_condition_features = 2 * n_dim


In [118]:
# Build a conditional masked-autoregressive flow over `n_dim` values: each of
# the `n_layers` stacks is a permutation, a context-conditioned masked affine
# autoregressive transform and (optionally) a BatchNorm, on top of a
# standard-normal base distribution.
permutation_constructor = (
    RandomPermutation if use_random_permutations else ReversePermutation
)

layers = []
for _ in range(n_layers):
    layers.append(permutation_constructor(n_dim))
    layers.append(
        MaskedAffineAutoregressiveTransform(
            features=n_dim,
            hidden_features=n_hidden_features,
            context_features=n_condition_features,
            num_blocks=n_blocks,
            use_residual_blocks=use_residual_blocks,
            random_mask=use_random_masks,
            dropout_probability=dropout_probability,
            use_batch_norm=use_batch_norm_within_layers,
        )
    )
    if use_batch_norm_between_layers:
        layers.append(BatchNorm(n_dim))

flow = Flow(
    transform=CompositeTransform(layers),
    distribution=StandardNormal([n_dim]),
)

optimizer = Adam(flow.parameters(), lr=1e-3)

# Load the data (indexed by its 'dates' column) and the point coordinates.
data = pd.read_csv('data/df_all.csv').set_index('dates')
position = torch.tensor(np.load('data/position.npy').astype(np.float32))

# Only the first row is used as the single training observation. Convert via
# NumPy explicitly: handing a pandas Series straight to torch.tensor() makes
# torch iterate it with integer __getitem__, which pandas deprecates for
# label-indexed Series.
sst = torch.tensor(data.iloc[0].to_numpy(dtype=np.float32))

# Fail early if the data does not match the flow's dimensions; otherwise the
# mismatch only surfaces deep inside the training loop.
if sst.shape != (n_dim,):
    raise ValueError(
        f"expected {n_dim} values per data row, got shape {tuple(sst.shape)}"
    )
if position.numel() != 2 * n_dim:
    raise ValueError(
        f"expected {2 * n_dim} coordinates in position.npy, got {position.numel()}"
    )

In [119]:
# Train the flow on the single observation `sst`. Each step builds a batch of
# `batch_size` random orderings of the n_dim points, permuting the values and
# their coordinate pairs with the same permutation so each value keeps its own
# coordinates in the context.
N = 10000
batch_size = 16

# The evaluation cell calls flow.eval(); switch back so BatchNorm layers train
# correctly if this cell is re-run afterwards.
flow.train()

with tqdm(total=N) as pbar:

    for _ in range(N):

        optimizer.zero_grad()

        # One torch.randperm per sample, drawn in sample order.
        perms = torch.stack([torch.randperm(n_dim) for _ in range(batch_size)])
        ssts = sst[perms]                                   # (batch_size, n_dim)
        # `position` is treated as n_dim consecutive coordinate pairs; whole
        # pairs are permuted, then each row is flattened back to 2 * n_dim.
        positions = position.reshape(n_dim, 2)[perms].reshape(batch_size, -1)

        loss = torch.mean(-flow.log_prob(ssts, context=positions))
        loss.backward()
        optimizer.step()

        pbar.update(1)
        pbar.set_description(f"{loss.item():.4f}")

    optimizer.zero_grad()


-6.3999: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [02:04<00:00, 80.24it/s]


In [120]:
def sample(flow, noise, position):
    """Map fixed base-distribution noise through the inverse of ``flow``.

    The single context vector ``position`` is tiled once per row of ``noise``,
    so every noise row is transformed under the same context. The private
    ``flow._transform`` is used because it accepts caller-supplied noise.

    Returns the transformed samples with all size-1 dimensions squeezed out.
    """
    n_rows = len(noise)
    tiled_context = position.unsqueeze(0).repeat(n_rows, 1)
    samples, _log_abs_det = flow._transform.inverse(noise, context=tiled_context)
    return samples.squeeze()


In [121]:
def plot(i):
    """Compare the observed values with a flow sample for the ``i``-th ordering.

    ``i`` indexes the permutations of ``range(n_dim)`` in lexicographic order
    (``0 <= i < n_dim!``). An out-of-range ``i`` raises ``IndexError``.

    The left panel draws each observed ``sst`` value at its point. For the
    right panel, the points are reordered by that permutation and the flow is
    conditioned on the reordered coordinates. The values sampled from the
    global ``noise`` are then drawn at those points. Both panels share one
    colour scale, taken from the observed values.

    Uses the notebook globals ``flow``, ``noise``, ``sst``, ``position``,
    ``n_dim`` and ``sample``.
    """
    # 0-based entries so the permutation can index tensors directly.
    perm = torch.tensor(list(itertools.permutations(range(n_dim)))[i])

    # View `position` exactly as the training loop does: n_dim consecutive
    # coordinate pairs, permuted as whole pairs. The context handed to the
    # flow then has the layout it was trained on.
    # NOTE(review): this assumes position.npy stores each point's two
    # coordinates next to each other (c0_a, c0_b, c1_a, c1_b, ...). If it is
    # [all lats, all lons] instead, the training loop's reshape is what needs
    # fixing, and this view must change with it. Confirm.
    points = position.reshape(n_dim, 2)
    permuted = points[perm]

    cmap = plt.get_cmap('viridis')
    norm = matplotlib.colors.Normalize(vmin=sst.min().item(), vmax=sst.max().item())

    _, ax = plt.subplots(ncols=2, figsize=(10, 4))

    for (lat, lon), value in zip(points.tolist(), sst.tolist()):
        ax[0].scatter(lat, lon, color=cmap(norm(value)), s=100)
    ax[0].set_title('observed')

    sst_pred = sample(flow, noise, permuted.reshape(-1)).detach().numpy()
    for (lat, lon), value in zip(permuted.tolist(), sst_pred.tolist()):
        ax[1].scatter(lat, lon, color=cmap(norm(value)), s=100)
    ax[1].set_title(f'sampled, permutation {i}')


In [122]:
# Put the flow (including its BatchNorm layers) in evaluation mode and fix one
# draw of base noise, which every plot() call reuses. The slider covers every
# permutation index 0 .. n_dim! - 1, derived from n_dim instead of
# hard-coding 719.
import math

flow.eval()
noise = torch.randn((1, n_dim))

interact(plot, i=widgets.IntSlider(min=0, max=math.factorial(n_dim) - 1, step=1, value=10))

interactive(children=(IntSlider(value=10, description='i', max=719), Output()), _dom_classes=('widget-interact…

<function __main__.plot(i)>