<h1>Spectral Analysis on Forward Pass</h1>

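In this notebook we examine whether the forward (noising) pass of flow-matching models is *spectrally autoregressive*, i.e. whether the added noise drowns out high spatial frequencies of the images before low ones. For all models we assume: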

\begin{align}
x_0 &\sim \mathcal{N}(0, I) \\
x_1 &\sim \text{CIFAR-10}
\end{align}

We then move from $x_1$ (data) to $x_0$ (noise) along the probability paths listed in Table 1 of [arXiv:2302.00482](https://arxiv.org/abs/2302.00482).

## Imports

In [1]:
from matplotlib import rc
rc('animation', html='jshtml')
import matplotlib.pyplot as plt
from matplotlib import animation

import torch
import torch.nn as nn
from torchvision import datasets, transforms

## Load the dataset (CIFAR-10)

In [2]:
# CIFAR-10 training split, read from a local copy only (download=False: nothing is fetched).
dataset = datasets.CIFAR10(
    root="../../../data",
    train=True,
    download=False,
    transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # Per-channel (x - 0.5) / 0.5, i.e. the same scale the plotting cells undo with (x + 1) / 2.
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]),
)

# Shuffled batches of 16 images; drop_last keeps every batch exactly the same size.
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=16, shuffle=True, num_workers=4, drop_last=True
)

In [3]:
from torchcfm.conditional_flow_matching import (
    TargetConditionalFlowMatcher, 
    SchrodingerBridgeConditionalFlowMatcher, 
    VariancePreservingConditionalFlowMatcher, 
    ExactOptimalTransportConditionalFlowMatcher,
    VariancePreservingDiffusion,
)

def get_forward_pass(model, x0, x1, t):
    """Return the sample x_t that ``model`` places on its conditional path at time ``t``.

    Thin wrapper around ``model.sample_location_and_conditional_flow``: the returned
    time and conditional velocity are discarded and only the location x_t is kept.

    NOTE(review): for the OT / Schrodinger-bridge matchers, torchcfm presumably re-pairs
    (x0, x1) inside that call, so successive calls may not follow one fixed pairing — confirm.
    """
    _, location, _ = model.sample_location_and_conditional_flow(x0, x1, t)
    return location

import numpy as np

# Which flow-matching objective to trace; the names mirror the imported torchcfm classes.
FM_NAME = "otcfm"
# Seed both RNGs so the CIFAR batch, the Gaussian noise and any numpy-based sampling
# inside the matcher are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)  # NOTE(review): assumes torchcfm's OT sampler draws from numpy's global RNG — confirm

# Map names to classes (not instances) so only the selected matcher is constructed.
FLOW_MATCHERS = {
    "tcfm": TargetConditionalFlowMatcher,
    "sbcfm": SchrodingerBridgeConditionalFlowMatcher,
    "vpcfm": VariancePreservingConditionalFlowMatcher,
    "otcfm": ExactOptimalTransportConditionalFlowMatcher,
    "vpdiff": VariancePreservingDiffusion,
}
if FM_NAME not in FLOW_MATCHERS:
    # Fail here rather than with a NameError on `FM` several cells later.
    raise ValueError(f"Unknown FM_NAME {FM_NAME!r}; expected one of {sorted(FLOW_MATCHERS)}")
FM = FLOW_MATCHERS[FM_NAME]()

# x1: one batch of CIFAR-10 images; x0: i.i.d. standard-normal noise of the same shape.
x1 = next(iter(dataloader))[0]
x0 = torch.randn_like(x1)

batch_size = x0.size(0)

TIME_STEPS = 100
# The same TIME_STEPS evenly spaced times in [0, 1] for every example: shape (batch_size, TIME_STEPS).
t = torch.linspace(0, 1, TIME_STEPS).repeat(batch_size, 1)


In [4]:
import matplotlib.pyplot as plt
import numpy as np

# Trace x_t along the forward path for every time step, alongside the path's mean and std.
# Samples are stored on the dataloader's own scale (Normalize(0.5, 0.5) -> [-1, 1]) WITHOUT
# rescaling or clipping: clipping would distort the power spectra analysed below, and the
# plotting cells already map values to [0, 1] themselves with (x + 1) / 2.
num_time_steps = t.size(1)

# Layout (time, batch, H, W, C): channels-last so frames can go straight to imshow / pysteps.
all_images = (
    torch.zeros_like(x1)
    .permute(0, 2, 3, 1)
    .cpu()
    .unsqueeze(0)
    .repeat(num_time_steps, 1, 1, 1, 1)
)
mu_ts = torch.zeros_like(all_images)           # same layout and scale as all_images
sigma_ts = torch.zeros_like(t).permute(1, 0)   # (time, batch)

for i in range(num_time_steps):
    t_ = t[:, i]
    xt = get_forward_pass(FM, x0, x1, t_).detach()
    all_images[i] = xt.permute(0, 2, 3, 1)
    # NOTE(review): for OT / SB matchers the sample above may use a re-paired (x0, x1),
    # in which case mu_t (computed from the original pairing) need not match xt — confirm.
    mu_ts[i] = FM.compute_mu_t(x0, x1, t_).permute(0, 2, 3, 1)
    sigma_ts[i] = FM.compute_sigma_t(t_)

all_images = all_images.numpy()

In [5]:
import pysteps.utils.spectral as spectral

def calc_mean_log_rapsd(x, num_examples=None):
  """Average the log radially-averaged power spectral density (RAPSD) over examples.

  Args:
    x: array indexed as ``x[k, ...]``; each ``x[k]`` is the 2-D field handed to
      ``pysteps.utils.spectral.rapsd`` (in this notebook: one colour channel of one image).
    num_examples: number of leading entries of ``x`` to average over. ``None``
      (the default) uses all of them.

  Returns:
    Tuple ``(mean_log_rapsd, frequencies)``: the per-frequency mean of
    ``log(rapsd + 1e-30)`` and the frequency bins pysteps returned for the last example.

  Raises:
    ValueError: if fewer than one example would be averaged (there would be no
      spectrum and no frequencies to return).
  """
  if num_examples is None:
    num_examples = len(x)
  if num_examples < 1:
    raise ValueError(f"num_examples must be >= 1, got {num_examples}")

  log_spectra = []
  for k in range(num_examples):
    rapsd, frequencies = spectral.rapsd(x[k, ...], fft_method=np.fft, return_freq=True)
    # The tiny offset keeps log() finite if a frequency bin has exactly zero power.
    log_spectra.append(np.log(rapsd + 1e-30))

  return np.mean(log_spectra, axis=0), frequencies

Pysteps configuration file found at: /Users/oskarjor/miniconda3/envs/torchcfm/lib/python3.10/site-packages/pysteps/pystepsrc



In [6]:
# all_images is laid out as (time, example, H, W, C).
total_time_steps, num_examples = all_images.shape[:2]

# The last time step is t = 1, the x_1 (CIFAR-10) end of the path; its spectrum
# (channel index 1) is the fixed reference drawn in the animations below.
images = all_images[-1]
mean_log_rapsd, frequencies = calc_mean_log_rapsd(images[..., 1], num_examples)

In [7]:
# Inspect a single stored time step (index 0 is the first entry of t, i.e. t = 0).
i = 0
mu_t = mu_ts[i]
# Broadcast the per-example std (shape (batch,)) over mu_t's (batch, H, W, C) layout,
# taking the image shape from mu_t instead of hard-coding CIFAR-10's 32x32x3.
sigma_t = sigma_ts[i].reshape(-1, 1, 1, 1).expand_as(mu_t)

In [8]:
# Animate the spectrum of x_t as the forward pass runs from t = 1 (data) to t = 0 (noise).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5), width_ratios=[2, 1])

# Static reference curve (red): spectrum of the t = 1 images, present in every frame.
ax1.plot(frequencies[1:], np.exp(mean_log_rapsd)[1:], c='red', marker='o', markersize=3, label='images')
ax1.set_xscale('log')
ax1.set_yscale('log')
ax1.set_xlabel('frequency')
ax1.set_ylabel('power')
ax2.axis('off')

num_steps = total_time_steps
cycles_per_sec = 0.2    # one full pass through all frames every 1 / 0.2 = 5 seconds
selected_image_idx = 0  # batch element shown in the right-hand panel
# Set True to also plot the spectrum of a fresh draw from N(mu_t, sigma_t^2) (blue).
SHOW_PATH_SAMPLE_SPECTRUM = False

artists = []
for i in range(num_steps):
  step = -i - 1  # walk time backwards: last stored step (t = 1) first, t = 0 last
  noisy_images = all_images[step]
  mean_log_rapsd_sum, _ = calc_mean_log_rapsd(noisy_images[..., 1], num_examples=num_examples)
  container = ax1.plot(frequencies[1:], np.exp(mean_log_rapsd_sum)[1:], c='green', marker='o', markersize=3, label='noisy images')

  if SHOW_PATH_SAMPLE_SPECTRUM:
    mu_t = mu_ts[step]
    # Broadcast the per-example std (batch,) over the (batch, H, W, C) mean.
    sigma_t = sigma_ts[step].reshape(-1, 1, 1, 1).expand_as(mu_t)
    noise = torch.normal(mu_t, sigma_t).cpu().numpy()
    mean_log_rapsd_noise, _ = calc_mean_log_rapsd(noise[..., 1], num_examples=num_examples)
    container += ax1.plot(frequencies[1:], np.exp(mean_log_rapsd_noise)[1:], c='blue', marker='o', markersize=3, label='noise')

  # Maps values from the [-1, 1] normalisation scale to [0, 1] for display only;
  # assumes all_images is stored on that scale.
  container += [ax2.imshow(np.clip((noisy_images[selected_image_idx] + 1) / 2, 0, 1))]
  artists.append(container)

plt.close(fig)  # Avoid showing static plot, we only want to see the animation.

anim = animation.ArtistAnimation(fig=fig, artists=artists, interval=1000 / (num_steps * cycles_per_sec))
anim

In [9]:
# Pillow ships with matplotlib, so unlike 'imagemagick' this needs no external binary
# (and does not trigger the "MovieWriter imagemagick unavailable" fallback warning).
anim.save(f'cifar10_noise_scale_{FM_NAME}.gif', writer='pillow', fps=30)

MovieWriter imagemagick unavailable; using Pillow instead.


In [19]:
# Baseline: add isotropic Gaussian noise of increasing std to the t = 1 images and animate
# the spectrum of the noisy images (green) next to the spectrum of the noise alone (blue).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5), width_ratios=[2, 1])

ax1.plot(frequencies[1:], np.exp(mean_log_rapsd)[1:], c='red', marker='o', markersize=3, label='images')
ax1.set_xscale('log')
ax1.set_yscale('log')
ax1.set_xlabel('frequency')
ax1.set_ylabel('power')
ax2.axis('off')

NOISE_SEED = 0
rng = np.random.default_rng(NOISE_SEED)  # seeded so the noise frames are reproducible

# Noise std per frame, log-spaced from 0.02 to 2.0.
noise_scales = np.logspace(np.log10(0.02), np.log10(2.0), num_steps)

artists = []
for noise_scale in noise_scales:
  current_noise = rng.normal(0, noise_scale, size=images.shape)
  noisy_images = images + current_noise
  mean_log_rapsd_sum_noise, _ = calc_mean_log_rapsd(current_noise[..., 1], num_examples=num_examples)
  mean_log_rapsd_sum, _ = calc_mean_log_rapsd(noisy_images[..., 1], num_examples=num_examples)
  container = ax1.plot(frequencies[1:], np.exp(mean_log_rapsd_sum_noise)[1:], c='blue', marker='o', markersize=3, label='noise')
  container += ax1.plot(frequencies[1:], np.exp(mean_log_rapsd_sum)[1:], c='green', marker='o', markersize=3, label='noisy images')

  # Right panel shows the noise field itself, mapped from a [-1, 1] scale to [0, 1].
  container += [ax2.imshow(np.clip((current_noise[selected_image_idx] + 1) / 2, 0, 1))]
  artists.append(container)

plt.close(fig)  # Avoid showing static plot, we only want to see the animation.

anim = animation.ArtistAnimation(fig=fig, artists=artists, interval=1000 / (num_steps * cycles_per_sec))
anim