In [None]:
%load_ext autoreload
%autoreload 2
import os

import matplotlib.pyplot as plt
import torch
import torchsde
from torchdyn.core import NeuralODE
from torchvision import datasets, transforms
from torchvision.transforms import ToPILImage
from torchvision.utils import make_grid
from tqdm import tqdm

from torchcfm.conditional_flow_matching import *
from torchcfm.models.unet import UNetModel
from myUnetWrapper import *
import torchdiffeq

import numpy as np
from numpy import *
import matplotlib.pyplot as plt
import torch

from typing import List
import time
from torchdyn.core import NeuralODE

from tqdm import tqdm
from torch.distributions.multivariate_normal import MultivariateNormal
import pickle
from copy import deepcopy
import gc

# Data directory; set the HWD_ROOT environment variable to run elsewhere.
# A trailing separator is kept because file names are appended by string concatenation.
root_dir = os.path.join(os.environ.get("HWD_ROOT", "/hpc/group/mastatlab/gw74/HWD/"), "")
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# Tensors created without an explicit device go to `device` from here on.
torch.set_default_device(device)

In [None]:
import metric.pytorch_ssim
from metric.IS_score import *
from metric.Fid_score import *
from torchmetrics.image.kid import KernelInceptionDistance

# 0. Load Data

In [None]:
# Raw digit images and per-image metadata.
Images = np.load(root_dir + "Images(28x28).npy")
WriterInfo = np.load(root_dir + "WriterInfo.npy")
# WriterInfo: column 0 is the digit, column 1 the writer ID

# scale pixel values by 1/255 and store as float32 on the compute device
Images = torch.from_numpy(Images / 255).float().to(device)

In [None]:
# Digits placed at t = 0, 0.5 and 1 of each writer's path.
num00, num05, num10 = 0, 8, 6

# boolean row masks per digit (WriterInfo column 0 holds the digit)
selId_00, selId_05, selId_10 = (WriterInfo[:, 0] == digit for digit in (num00, num05, num10))

# images and writer IDs (WriterInfo column 1) of each digit
image_00, image_05, image_10 = (Images[mask, :, :] for mask in (selId_00, selId_05, selId_10))
id_00, id_05, id_10 = (WriterInfo[mask, 1] for mask in (selId_00, selId_05, selId_10))

In [None]:
# Train/test split by row position. The same row indices are applied to all
# three digits, so image_00 / image_05 / image_10 (and their IDs) are assumed
# to be row-aligned by writer -- only the row counts are checked here.
n_total = image_00.shape[0]
if image_05.shape[0] != n_total or image_10.shape[0] != n_total:
    raise ValueError("image_00, image_05 and image_10 must have the same number of rows "
                     "to share one train/test split")
train_prop = 0.8
# Fixed seed so models trained in one session and reloaded in another are
# evaluated on the same held-out rows.
split_seed = 0
split_rng = np.random.default_rng(split_seed)

train_id = split_rng.choice(n_total, int(n_total*train_prop), replace=False)
test_id = np.setdiff1d(np.arange(n_total), train_id)

image_00_train = image_00[train_id,:,:]
image_05_train = image_05[train_id,:,:]
image_10_train = image_10[train_id,:,:]
id_00_train = id_00[train_id]
id_05_train = id_05[train_id]
id_10_train = id_10[train_id]

image_00_test = image_00[test_id,:,:]
image_05_test = image_05[test_id,:,:]
image_10_test = image_10[test_id,:,:]
id_00_test = id_00[test_id]
id_05_test = id_05[test_id]
id_10_test = id_10[test_id]

In [None]:
# Split the training writers into groups of n_per (the last may be smaller);
# each group is one optimisation step in the fitting loops.
id_unique = np.unique(id_00_train)
n_per = 10
n_sec = int(np.ceil(id_unique.shape[0] / n_per))

id_grp = [id_unique[start:start + n_per]
          for start in range(0, id_unique.shape[0], n_per)]

# 1. Functions

## 1.1 ICFM

In [None]:
def icfm_fit(model, optimizer, d0, d1, id0, id1, n_epochs, id_grp, sigma = 0.0,
            cond = True):
    """Train `model` with conditional flow matching between two image sets.

    Each optimisation step uses one entry of `id_grp`: all rows of those
    writers are taken from `d1` (targets) and, unless `d0` is None, from `d0`
    (sources), shuffled within each writer. With `d0` None the source is
    standard Gaussian noise shaped like the target batch.

    Args:
        model: velocity network, called as model(t, x).
        optimizer: optimiser over the model's parameters.
        d0, d1: image tensors indexed as d[rows, :, :]; d0 may be None.
        id0, id1: writer-ID arrays aligned with the rows of d0 / d1.
        n_epochs: number of passes over all writer groups.
        id_grp: sequence of writer-ID arrays, one optimisation step each.
        sigma: noise level handed to ConditionalFlowMatcher.
        cond: if True the network input is x0 and x_t stacked on dim 1.

    Returns:
        The trained model.
    """
    matcher = ConditionalFlowMatcher(sigma=sigma)

    def writer_batch(images, ids, writers):
        # rows of each writer, shuffled per writer, reshaped to (N, 1, 28, 28)
        rows = np.concatenate(
            [np.random.permutation(np.where(ids == w)[0]) for w in writers], axis=0)
        return images[rows, :, :].reshape(-1, 1, 28, 28)

    for _ in tqdm(range(n_epochs)):
        for writers in id_grp:
            optimizer.zero_grad()
            x1 = writer_batch(d1, id1, writers)
            x0 = torch.randn_like(x1) if d0 is None else writer_batch(d0, id0, writers)

            t, xt, ut = matcher.sample_location_and_conditional_flow(x0, x1)
            net_input = torch.cat((x0, xt), 1) if cond else xt
            vt = model(t, net_input)

            loss = ((vt - ut) ** 2).mean()
            loss.backward()
            optimizer.step()

    return model

## 1.2 GP-ICFM

In [None]:
def calc_r(ti, tj):
    """Pairwise differences r[..., i, j] = ti[..., i] - tj[..., j].

    Entries that are exactly zero are replaced by 1e-15.
    """
    diff = ti.unsqueeze(-1) - tj.unsqueeze(-2)
    diff[diff == 0] = 1e-15
    return diff
def k11(r, alpha, l):
    """Squared-exponential covariance: alpha^2 * exp(-r^2 / (2 l^2))."""
    envelope = torch.exp(-0.5 * ((r / l) ** 2))
    return (alpha ** 2) * envelope
def k12(r, alpha, l):
    """Derivative of k11 with respect to t_j (r = t_i - t_j)."""
    envelope = torch.exp(-0.5 * ((r / l) ** 2))
    return (alpha ** 2 / l ** 2) * r * envelope
def k22(r, alpha, l):
    """Mixed second derivative of k11 with respect to t_i and t_j."""
    envelope = torch.exp(-0.5 * ((r / l) ** 2))
    return (alpha ** 2 / l ** 4) * (l ** 2 - r ** 2) * envelope

In [None]:
def cov_mat2(ti, tj, alpha, l, sig2_diag = 1e-8):
    """Joint covariance of a GP path x and its time-derivative dx.

    Builds the block matrix
        [[cov(x, x),  cov(x, dx)],
         [cov(dx, x), cov(dx, dx)]]
    from k11 / k12 / k22. `sig2_diag` jitter is added to the cov(x, x) diagonal.
    The result is symmetrised to remove round-off asymmetry.

    NOTE(review): the jitter is an nt x nt identity, and the lower-left block is
    the transpose of the upper-right one. Both are only valid when ti and tj are
    the same time grid -- confirm before calling with different grids.

    Returns:
        Tensor of shape (batch, 2*nt, 2*nt).
    """
    dist = calc_r(ti, tj)
    n_batch, n_time = dist.shape[0], dist.shape[1]

    jitter = (torch.eye(n_time) * sig2_diag).repeat(n_batch, 1, 1)
    cov_xx = k11(dist, alpha, l) + jitter
    cov_xdx = k12(dist, alpha, l)
    cov_dxdx = k22(dist, alpha, l)

    upper = torch.cat([cov_xx, cov_xdx], dim=2)
    lower = torch.cat([cov_xdx.transpose(1, 2), cov_dxdx], dim=2)
    joint = torch.cat([upper, lower], dim=1)
    return 0.5 * (joint + joint.transpose(1, 2))

In [None]:
def samp_x_dx2(t_mat, alpha, l, x_obs, t_obs, sig2_diag = 1e-8):
    """Sample a GP path x(t) and its derivative dx/dt at query times,
    conditioned on observations x_obs taken at times t_obs.

    Each of the `dim` coordinates of x_obs is an independent GP with the same
    squared-exponential kernel (amplitude `alpha`, length-scale `l`). One
    conditional covariance is therefore built per batch element and shared by
    all coordinates.

    Args:
        t_mat: (nB, nt) query times, one row per batch element.
        alpha, l: kernel amplitude and length-scale.
        x_obs: (nB, nt_obs, dim) observed values.
        t_obs: (nt_obs,) observation times shared by the whole batch.
        sig2_diag: jitter added to the covariance diagonals.

    Returns:
        (x_samps, dx_samps): two tensors of shape (nB, nt, dim).
    """
    nB = x_obs.shape[0]
    dim = x_obs.shape[2]
    nt = t_mat.shape[1]
    nt_obs = t_obs.shape[0]

    r_obs_x = calc_r(t_obs, t_mat)      # (nB, nt_obs, nt)
    r_obs_obs = calc_r(t_obs, t_obs)    # (nt_obs, nt_obs)

    # prior joint covariance of (x, dx) at the query times
    Sig_11 = cov_mat2(t_mat, t_mat, alpha, l, sig2_diag)
    # covariance between the observations and (x, dx) at the query times
    Sig_21 = torch.cat([k11(r_obs_x, alpha, l), k12(r_obs_x, alpha, l)], dim=2)
    Sig_12 = Sig_21.permute(0, 2, 1)

    # observation covariance is identical for every batch element: invert once
    Sig_22_inv = torch.linalg.inv(
        k11(r_obs_obs, alpha, l) + torch.eye(nt_obs) * sig2_diag
    ).repeat(nB, 1, 1)

    # Gaussian conditioning: mean = mu_A @ x_obs, cov = Sig_11 - mu_A @ Sig_21
    mu_A = torch.bmm(Sig_12, Sig_22_inv)
    Sig_cond = Sig_11 - torch.bmm(mu_A, Sig_21)
    Sig_cond = (Sig_cond + Sig_cond.permute(0, 2, 1)) / 2

    # Round-off can leave small negative eigenvalues. For those batch elements,
    # rebuild the matrix from its SVD as V diag(s + 1e-8) V^T.
    bad = (torch.linalg.eigvalsh(Sig_cond) < 0).any(dim=1)
    if bad.any():
        _, S, Vh = torch.linalg.svd(Sig_cond[bad])
        Sig_fix = torch.bmm(torch.bmm(Vh.permute(0, 2, 1), torch.diag_embed(S + 1e-8)), Vh)
        Sig_cond[bad] = (Sig_fix + Sig_fix.permute(0, 2, 1)) / 2

    # conditional means for all coordinates at once: (nB, 2*nt, dim)
    mu_all = torch.bmm(mu_A, x_obs)

    # The covariance does not depend on the coordinate, so factor it once.
    scale_tril, info = torch.linalg.cholesky_ex(Sig_cond)
    if bool((info == 0).all()):
        eps = torch.randn(nB, 2 * nt, dim, dtype=mu_all.dtype, device=mu_all.device)
        x_dx = mu_all + torch.bmm(scale_tril, eps)
    else:
        # Fallback: numpy's sampler tolerates positive semi-definite matrices.
        # Move to host memory first -- numpy cannot read CUDA tensors.
        print('torch fail')
        mu_np = mu_all.detach().cpu().double().numpy()
        cov_np = Sig_cond.detach().cpu().double().numpy()
        zero_mean = np.zeros(2 * nt)
        draws = np.empty((nB, 2 * nt, dim))
        for bb in range(nB):
            # one zero-mean draw per coordinate -> (dim, 2*nt)
            noise = np.random.multivariate_normal(zero_mean, cov_np[bb], size=dim)
            draws[bb] = mu_np[bb] + noise.T
        x_dx = torch.from_numpy(draws).to(device=mu_all.device, dtype=mu_all.dtype)

    x_samps = x_dx[:, :nt, :]
    dx_samps = x_dx[:, nt:, :]
    return x_samps, dx_samps

In [None]:
def gp_icfm_fit(model, optimizer, d00, d05, d10,
                id00, id05, id10,
                n_epochs, id_grp, alpha, l, sig2_diag = 0, cond = True):
    """Train `model` on GP paths through three observed digit images.

    For each writer group in `id_grp`, the writers' images from d00, d05 and
    d10 (shuffled within writer) are observations of one GP path at
    t = 0, 0.5 and 1. If d00 is None, Gaussian noise is used at t = 0.
    A time t ~ U(0, 1) is drawn per sample, and x(t) and dx/dt(t) are sampled
    from the conditional GP (samp_x_dx2). The model is regressed onto dx/dt
    with a mean-squared error.

    Batches whose GP sampling raises are reported and skipped.

    Returns:
        The trained model.
    """
    # observation times of the three digits (constant, so built once)
    t_obs = torch.tensor([0, 0.5, 1])

    def writer_batch(images, ids, writers):
        # rows of each writer, shuffled per writer, reshaped to (N, 1, 28, 28)
        rows = np.concatenate(
            [np.random.permutation(np.where(ids == w)[0]) for w in writers], axis=0)
        return images[rows, :, :].reshape(-1, 1, 28, 28)

    for epoch in tqdm(range(n_epochs)):
        for ll in range(len(id_grp)):
            optimizer.zero_grad()

            x10 = writer_batch(d10, id10, id_grp[ll])
            x05 = writer_batch(d05, id05, id_grp[ll])
            if d00 is None:
                x00 = torch.randn_like(x10)
            else:
                x00 = writer_batch(d00, id00, id_grp[ll])

            n_samp = x10.shape[0]

            xall_trans = torch.zeros(n_samp, 3, 28*28)
            xall_trans[:,0,:] = torch.reshape(x00, (n_samp, -1))
            xall_trans[:,1,:] = torch.reshape(x05, (n_samp, -1))
            xall_trans[:,2,:] = torch.reshape(x10, (n_samp, -1))

            t_mat = torch.rand((n_samp,1))
            try:
                xt_batch, ut_batch = samp_x_dx2(t_mat, alpha, l, xall_trans,
                                                t_obs, sig2_diag)
            except Exception as err:
                # Without fresh samples there is nothing valid to regress onto;
                # skip the batch instead of reusing samples from an earlier step.
                print('sample fail', err)
                continue

            t = torch.reshape(t_mat, (-1, ))
            xt = torch.reshape(xt_batch, (n_samp, 1, 28, 28))
            ut = torch.reshape(ut_batch, (n_samp, 1, 28, 28))

            if cond:
                vt = model(t, torch.cat((x00, xt), 1))
            else:
                vt = model(t, xt)

            loss = torch.mean((vt - ut) ** 2)
            loss.backward()
            optimizer.step()

    return model

## 1.3 Fitting

In [None]:
def fit_icfm_grad(model, d0, d1, id0, id1, id_grp,
                  sigma = 0.0,
                  lr_grad = (1e-3, 8e-4, 5e-4, 2e-4),
                  n_epoch_grad = (100, 100, 100, 100),
                  cond = True):
    """Run icfm_fit in stages with a stepped learning-rate schedule.

    Stage k creates a fresh Adam optimiser with learning rate lr_grad[k] and
    trains for n_epoch_grad[k] epochs. The remaining arguments are passed
    through to icfm_fit.

    Raises:
        ValueError: if lr_grad and n_epoch_grad have different lengths. This is
            checked before any training, so a bad schedule cannot fail only
            after earlier stages have already run.
    """
    if len(lr_grad) != len(n_epoch_grad):
        raise ValueError(f"lr_grad has {len(lr_grad)} stages but n_epoch_grad "
                         f"has {len(n_epoch_grad)}")
    for lr, n_epochs in zip(lr_grad, n_epoch_grad):
        optimizer = torch.optim.Adam(model.parameters(), lr = lr)
        model = icfm_fit(model, optimizer, d0, d1,
                         id0, id1, n_epochs,
                         id_grp, sigma = sigma, cond = cond)
    return model

In [None]:
def fit_gp_icfm_grad(model, d00, d05, d10,
                    id00, id05, id10, id_grp, alpha, l, sig2_diag = 0,
                    lr_grad = (1e-3, 8e-4, 5e-4, 2e-4),
                    n_epoch_grad = (100, 100, 100, 100),
                    cond = True):
    """Run gp_icfm_fit in stages with a stepped learning-rate schedule.

    Stage k creates a fresh Adam optimiser with learning rate lr_grad[k] and
    trains for n_epoch_grad[k] epochs. The remaining arguments are passed
    through to gp_icfm_fit.

    Raises:
        ValueError: if lr_grad and n_epoch_grad have different lengths. This is
            checked before any training starts.
    """
    if len(lr_grad) != len(n_epoch_grad):
        raise ValueError(f"lr_grad has {len(lr_grad)} stages but n_epoch_grad "
                         f"has {len(n_epoch_grad)}")
    for lr, n_epochs in zip(lr_grad, n_epoch_grad):
        optimizer = torch.optim.Adam(model.parameters(), lr = lr)

        model = gp_icfm_fit(model, optimizer, d00, d05, d10,
                            id00, id05, id10,
                            n_epochs, id_grp, alpha, l,
                            sig2_diag = sig2_diag, cond = cond)

    return model

## 1.4 Plotting

In [None]:
def gen_samp(x0, model, device, cond = True):
    """Integrate the learned flow from t = 0 to t = 1 and return the endpoint.

    Args:
        x0: batch of starting images (4-D; the trajectory is indexed
            traj[-1, :, :, :, :]).
        model: velocity network called as model.forward(t, x). When `cond` is
            True its input is x0 and the current state concatenated on dim 1.
        device: device for the integration time grid.
        cond: whether the model is conditioned on x0.

    Returns:
        Endpoint images clipped to [0, 1]. The result is computed under
        torch.no_grad(), so it carries no autograd graph.
    """
    if cond:
        def velocity(t, x):
            return model.forward(t, torch.cat((x0, x), 1))
    else:
        def velocity(t, x):
            return model.forward(t, x)

    # Sampling only. Without no_grad, every solver stage of the adaptive
    # integration would be recorded for backprop through the model's
    # trainable parameters.
    with torch.no_grad():
        traj = torchdiffeq.odeint(
            velocity,
            x0,
            torch.linspace(0, 1, 2, device=device),
            atol=1e-4,
            rtol=1e-4,
            method="dopri5",
        )

    samp_out = traj[-1,:,:,:,:].clip(0,1)

    return samp_out

In [None]:
def plot_grid(x0_gen, model, device, n_sec = 5, cond = True):
    """Integrate the flow from x0_gen and draw the trajectory as an image grid.

    The trajectory is evaluated at `n_sec` equally spaced times in [0, 1].
    Frames are clipped to [0, 1] and tiled by make_grid with 10 images per row,
    so each row is one time point when x0_gen holds 10 images. The grid is
    drawn with plt.imshow.
    """
    if cond:
        def velocity(t, x):
            return model.forward(t, torch.cat((x0_gen, x), 1))
    else:
        def velocity(t, x):
            return model.forward(t, x)

    # plotting only: no autograd graph is needed through the ODE solve
    with torch.no_grad():
        traj = torchdiffeq.odeint(
            velocity,
            x0_gen,
            torch.linspace(0, 1, n_sec, device=device),
            atol=1e-4,
            rtol=1e-4,
            method="dopri5",
        )

    grid = make_grid(
        traj.view([-1, 1, 28, 28]).clip(0, 1), value_range=(0, 1), padding=0, nrow=10
    )
    img = ToPILImage()(grid)
    plt.imshow(img)

In [None]:
def plot_grid_comb(x0_gen, model0, model1, device, cond = True):
    """Draw a two-stage trajectory (model0, then model1) as one image grid.

    model0 carries x0_gen over 5 equally spaced times in [0, 1]. Its clipped
    endpoint is the starting point of model1's 5-time trajectory. When `cond`
    is True, each model is conditioned on its own starting images. model1's
    first frame repeats model0's endpoint and is dropped, giving 9 frames,
    tiled with 10 images per row.
    """
    def flow(model, start):
        # integrate `model` from `start` over 5 equally spaced times in [0, 1]
        if cond:
            def velocity(t, x):
                return model.forward(t, torch.cat((start, x), 1))
        else:
            def velocity(t, x):
                return model.forward(t, x)
        return torchdiffeq.odeint(
            velocity,
            start,
            torch.linspace(0, 1, 5, device=device),
            atol=1e-4,
            rtol=1e-4,
            method="dopri5",
        )

    # plotting only: no autograd graph is needed through the ODE solves
    with torch.no_grad():
        traj0 = flow(model0, x0_gen)
        x_mid = traj0[-1,:].view([-1, 1, 28, 28]).clip(0, 1)
        traj1 = flow(model1, x_mid)

    traj0_trans = traj0.view([-1, 1, 28, 28]).clip(0, 1)
    traj1_trans = traj1[1:,:,:,:,:].view([-1, 1, 28, 28]).clip(0, 1)
    traj_cat = torch.cat((traj0_trans, traj1_trans), 0)

    grid = make_grid(
        traj_cat, value_range=(0, 1), padding=0, nrow=10
    )

    img = ToPILImage()(grid)
    plt.imshow(img)

## 1.5 Evaluation

In [None]:
def gen_1_traj(x0_gen, model, device, n_sec = 5, cond = True):
    """Integrate the flow from x0_gen and return the whole trajectory.

    Args:
        x0_gen: starting images.
        model: velocity network called as model.forward(t, x). When `cond` is
            True its input is x0_gen and the current state concatenated on dim 1.
        device: device for the integration time grid.
        n_sec: number of equally spaced output times in [0, 1].
        cond: whether the model is conditioned on x0_gen.

    Returns:
        The trajectory at the n_sec times (time on dim 0), clipped to [0, 1].
        It is computed under torch.no_grad(), so it carries no autograd graph.
    """
    gc.collect()
    torch.cuda.empty_cache()

    if cond:
        def velocity(t, x):
            return model.forward(t, torch.cat((x0_gen, x), 1))
    else:
        def velocity(t, x):
            return model.forward(t, x)

    # evaluation only: avoid recording the adaptive solver's graph
    with torch.no_grad():
        traj = torchdiffeq.odeint(
            velocity,
            x0_gen,
            torch.linspace(0, 1, n_sec, device=device),
            atol=1e-4,
            rtol=1e-4,
            method="dopri5",
        )

    traj = traj.clip(0, 1)

    return traj

In [None]:
def fid_calc(all_images, test_data):
    """FID between generated images and reference images.

    NOTE: fid_calc is defined again in section 4.2. That later definition
    replaces this one before any FID is computed in this notebook.

    Both sets are bilinearly upsampled to 299x299, replicated to 3 channels and
    passed channel-last to calculate_fid.

    Args:
        all_images: generated images, array-like of shape (N, 1, H, W).
        test_data: reference images, tensor or array of shape (M, 1, H, W);
            may live on the GPU.
    """
    gc.collect()
    torch.cuda.empty_cache()
    # interpolation only (no parameters), so it simply runs on the CPU
    up = torch.nn.Upsample(size=(299, 299), mode='bilinear')

    def to_fid_input(images):
        # (N, 1, H, W) -> float32 CPU tensor -> (N, 299, 299, 3) numpy array
        if torch.is_tensor(images):
            batch = images.detach()
        else:
            batch = torch.from_numpy(np.asarray(images))
        batch = batch.to(device='cpu', dtype=torch.float32).repeat(1, 3, 1, 1)
        with torch.no_grad():
            return np.transpose(up(batch).numpy(), (0, 2, 3, 1))

    Fid = calculate_fid(to_fid_input(all_images), to_fid_input(test_data),
                        use_multiprocessing=False, batch_size=4)

    return Fid

# 2. Fitting

## 2.0 Noise to $x_0$

In [None]:
# Noise -> x_0 generator: trained to map Gaussian noise (d0=None) to the
# digit-num00 training images. Only d1 = image_00_train is indexed by writer
# (id0 is unused when d0 is None), so its writer IDs must be id_00_train.
model0 = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)
model0 = fit_icfm_grad(model0, None, image_00_train, id_00_train, id_00_train, id_grp,
              sigma = 0.0,
              lr_grad = [1e-3, 8e-4, 5e-4, 2e-4, 8e-5, 2e-5, 8e-6, 2e-6],
              n_epoch_grad = [200, 200, 200, 200, 100, 100, 100, 100],
              cond = False)

## 2.1 ICFM

$x_0$ to $x_{0.5}$

In [None]:
%%capture output
# unconditional
model_icfm_uncond0 = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)
model_icfm_uncond0 = fit_icfm_grad(model_icfm_uncond0, image_00_train, image_05_train,
                                   id_00_train, id_05_train, id_grp,
                                   sigma = 0.0,
                                   lr_grad = [8e-5, 2e-5, 8e-6, 2e-6],
                                   n_epoch_grad = [100, 100, 100, 100],
                                   cond = False)

# conditional
model_icfm_cond0 = UNetModelWrapper2(dim=(1, 28, 28), in_channels = 2,
                                     out_channels = 1,
                                     num_channels=32, num_res_blocks=1).to(device)
model_icfm_cond0 = fit_icfm_grad(model_icfm_cond0, image_00_train, image_05_train,
                                 id_00_train, id_05_train, id_grp,
                                 sigma = 0.0,
                                 lr_grad = [8e-5, 2e-5, 8e-6, 2e-6],
                                 n_epoch_grad = [100, 100, 100, 100],
                                 cond = True)

$x_{0.5}$ to $x_1$

In [None]:
%%capture output
# unconditional
model_icfm_uncond1 = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)
model_icfm_uncond1 = fit_icfm_grad(model_icfm_uncond1, image_05_train, image_10_train,
                                   id_05_train, id_10_train, id_grp,
                                   sigma = 0.0,
                                   lr_grad = [8e-5, 2e-5, 8e-6, 2e-6],
                                   n_epoch_grad = [100, 100, 100, 100],
                                   cond = False)

# conditional
model_icfm_cond1 = UNetModelWrapper2(dim=(1, 28, 28), in_channels = 2,
                                     out_channels = 1,
                                     num_channels=32, num_res_blocks=1).to(device)
model_icfm_cond1 = fit_icfm_grad(model_icfm_cond1, image_05_train, image_10_train,
                                 id_05_train, id_10_train, id_grp,
                                 sigma = 0.0,
                                 lr_grad = [8e-5, 2e-5, 8e-6, 2e-6],
                                 n_epoch_grad = [100, 100, 100, 100],
                                 cond = True)

## 2.2 GP-ICFM

In [None]:
# GP kernel hyper-parameters for GP-ICFM: amplitude, length-scale, diagonal jitter
alpha, l, sig2_diag = 1, 6, 0

In [None]:
%%capture output
# unconditional
model_gp_icfm_uncond = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)
model_gp_icfm_uncond = fit_gp_icfm_grad(model_gp_icfm_uncond,
                                        image_00_train, image_05_train, image_10_train,
                                        id_00_train, id_05_train, id_10_train, id_grp,
                                        alpha, l, sig2_diag = sig2_diag,
                                        lr_grad = [8e-5, 2e-5, 8e-6, 2e-6],
                                        n_epoch_grad = [100, 100, 100, 100],
                                        cond = False)

# conditional
model_gp_icfm_cond = UNetModelWrapper2(dim=(1, 28, 28), in_channels = 2,
                                       out_channels = 1,
                                       num_channels=32, num_res_blocks=1).to(device)
model_gp_icfm_cond = fit_gp_icfm_grad(model_gp_icfm_cond,
                                      image_00_train, image_05_train, image_10_train,
                                      id_00_train, id_05_train, id_10_train, id_grp,
                                      alpha, l, sig2_diag = sig2_diag,
                                      lr_grad = [8e-5, 2e-5, 8e-6, 2e-6],
                                      n_epoch_grad = [100, 100, 100, 100],
                                      cond = True)

In [None]:
rootFolder = "/hpc/group/mastatlab/gw74/HWD"

# Checkpoint file stem (rootFolder/<stem>.pt) for every model trained above.
checkpoints = {
    "initial0": model0,
    "icfm_uncond0": model_icfm_uncond0,
    "icfm_cond0": model_icfm_cond0,
    "icfm_uncond1": model_icfm_uncond1,
    "icfm_cond1": model_icfm_cond1,
    "gp_icfm_uncond": model_gp_icfm_uncond,
    "gp_icfm_cond": model_gp_icfm_cond,
}

# Set to True to overwrite the checkpoint files with the weights just trained.
save_checkpoints = False
if save_checkpoints:
    for stem, net in checkpoints.items():
        torch.save(net.state_dict(), rootFolder + "/" + stem + ".pt")

# Load the saved weights. map_location=device works with or without a GPU, and
# load_state_dict copies the values into each model's existing parameters.
for stem, net in checkpoints.items():
    net.load_state_dict(torch.load(rootFolder + "/" + stem + ".pt", map_location=device))

# 3. Plotting

In [None]:
plot_dir = "/hpc/home/gw74/diff_model/FM/submission/plots/5_HWD"
os.makedirs(plot_dir, exist_ok=True)  # plt.savefig does not create missing directories

In [None]:
# Ten fixed starting images (seed 0) from model0, shared by the figures below.
gc.collect()
torch.cuda.empty_cache()

torch.manual_seed(0)
x0_noise = torch.randn(10, 1, 28, 28, device=device)
x0_gen = gen_samp(x0_noise, model0, device, cond=False)

In [None]:
# Unconditional ICFM (stage 1 then stage 2): one row per frame,
# t = 0, 0.125, ..., 1, and one column per starting image in x0_gen.
plot_grid_comb(x0_gen, model_icfm_uncond0, model_icfm_uncond1, device, cond=False)
plt.xticks([])
plt.yticks(28 * np.arange(9) + 15, np.linspace(0, 1, 9))
plt.savefig(plot_dir + "/icfm_uncond.svg")

In [None]:
# Conditional ICFM (stage 1 then stage 2): one row per frame,
# t = 0, 0.125, ..., 1, and one column per starting image in x0_gen.
plot_grid_comb(x0_gen, model_icfm_cond0, model_icfm_cond1, device, cond=True)
plt.xticks([])
plt.yticks(28 * np.arange(9) + 15, np.linspace(0, 1, 9))
plt.savefig(plot_dir + "/icfm_cond.svg")

In [None]:
# Unconditional GP-ICFM: a single model over [0, 1], 9 frames (one per row).
plot_grid(x0_gen, model_gp_icfm_uncond, device, n_sec=9, cond=False)
plt.xticks([])
plt.yticks(28 * np.arange(9) + 15, np.linspace(0, 1, 9))
plt.savefig(plot_dir + "/gp_icfm_uncond.svg")

In [None]:
# Conditional GP-ICFM: a single model over [0, 1], 9 frames (one per row).
plot_grid(x0_gen, model_gp_icfm_cond, device, n_sec=9, cond=True)
plt.xticks([])
plt.yticks(28 * np.arange(9) + 15, np.linspace(0, 1, 9))
plt.savefig(plot_dir + "/gp_icfm_cond.svg")

# 4. FID

## 4.1 Generate samples

In [None]:
batch_eval = 32 # 32
rep_eval = 10 # 40 # 10 is enough, total test size = 272

torch.set_default_device('cpu')
device_def = torch.get_default_device() 

In [None]:
# %%capture output
# Generate rep_eval batches of batch_eval starting images with model0 and push
# each batch through all four methods. For each method, keep 9 frames
# (t = 0, 0.125, ..., 1). Every step is tried on `device` first and retried on
# the CPU (`device_def`) if it raises, e.g. on a CUDA out-of-memory error.
# NOTE(review): the next cell loads traces from pickle files; nothing visible
# here writes them -- confirm how those files were produced.

def _free_memory():
    gc.collect()
    torch.cuda.empty_cache()


def _with_cpu_fallback(step):
    """Return step(device); if that raises, free memory and return step(device_def)."""
    _free_memory()
    try:
        return step(device)
    except Exception:
        _free_memory()
        return step(device_def)


def _start_images(dev):
    # fresh Gaussian noise -> generated x_0 images
    noise = torch.randn(batch_eval, 1, 28, 28, device=dev)
    return gen_samp(noise, model0.to(dev), dev, cond = False)


def _two_stage_frames(model_a, model_b, x0, cond):
    # ICFM: model_a then model_b, 5 frames each; model_b's first frame repeats
    # model_a's last one and is dropped -> 9 frames
    def step(dev):
        first = gen_1_traj(x0.to(dev), model_a.to(dev), dev, n_sec = 5, cond = cond)
        second = gen_1_traj(first[-1], model_b.to(dev), dev, n_sec = 5, cond = cond)
        return torch.cat((first, second[1:]), dim = 0).detach().cpu().numpy()
    return _with_cpu_fallback(step)


def _one_stage_frames(model, x0, cond):
    # GP-ICFM: one model over [0, 1] evaluated at 9 frames
    def step(dev):
        return gen_1_traj(x0.to(dev), model.to(dev), dev,
                          n_sec = 9, cond = cond).detach().cpu().numpy()
    return _with_cpu_fallback(step)


init_list = []
frames_icfm_uncond = [[] for _ in range(9)]
frames_icfm_cond = [[] for _ in range(9)]
frames_gpicfm_uncond = [[] for _ in range(9)]
frames_gpicfm_cond = [[] for _ in range(9)]

for ii in tqdm(range(rep_eval)):
    x0_tm = _with_cpu_fallback(_start_images)
    init_list.append(x0_tm)

    batch_frames = (
        (frames_icfm_uncond, _two_stage_frames(model_icfm_uncond0, model_icfm_uncond1, x0_tm, False)),
        (frames_icfm_cond, _two_stage_frames(model_icfm_cond0, model_icfm_cond1, x0_tm, True)),
        (frames_gpicfm_uncond, _one_stage_frames(model_gp_icfm_uncond, x0_tm, False)),
        (frames_gpicfm_cond, _one_stage_frames(model_gp_icfm_cond, x0_tm, True)),
    )
    for store, frames in batch_frames:
        for jj in range(9):
            store[jj].append(frames[jj])

# one array per frame (9 grids), batches stacked along the sample axis
trace_list_icfm_uncond = [np.concatenate(parts, axis=0) for parts in frames_icfm_uncond if parts]
trace_list_icfm_cond = [np.concatenate(parts, axis=0) for parts in frames_icfm_cond if parts]
trace_list_gpicfm_uncond = [np.concatenate(parts, axis=0) for parts in frames_gpicfm_uncond if parts]
trace_list_gpicfm_cond = [np.concatenate(parts, axis=0) for parts in frames_gpicfm_cond if parts]

In [None]:
import pickle

# The evaluation traces are stored in 8 chunk files per method
# (<prefix>1 ... <prefix>8 in the working directory). Each chunk holds one array
# per frame (9 frames). Stitch the chunks together along the sample axis.
# NOTE: pickle.load can execute arbitrary code -- only load files you produced.
def _load_trace_chunks(prefix):
    per_frame = [[] for _ in range(9)]
    for chunk_id in range(1, 9):
        with open(prefix + str(chunk_id), "rb") as fp:
            chunk = pickle.load(fp)
        for frame, parts in enumerate(per_frame):
            parts.append(chunk[frame])
    return [np.concatenate(parts, axis=0) for parts in per_frame]

image_list_icfm_uncond = _load_trace_chunks("trace_list_icfm_uncond_")  # 9 grids
image_list_icfm_cond = _load_trace_chunks("trace_list_icfm_cond_")  # 9 grids
image_list_gpicfm_uncond = _load_trace_chunks("trace_list_gpicfm_uncond_")  # 9 grids
image_list_gpicfm_cond = _load_trace_chunks("trace_list_gpicfm_cond_")  # 9 grids

## 4.2 Calculate FID

In [None]:
def fid_calc(all_images_raw, test_data, offset = 0):
    """FID between a block of generated images and the reference images.

    Rows offset : offset + len(test_data) of `all_images_raw` are compared with
    `test_data`, so both sets have the same size. Both are bilinearly upsampled
    to 299x299, replicated to 3 channels and passed channel-last to
    calculate_fid.

    Args:
        all_images_raw: generated images, array-like of shape (N, 1, H, W).
        test_data: reference images, tensor or array of shape (M, 1, H, W);
            may live on the GPU.
        offset: first generated row to use. Defaults to 0 because the FID loop
            below already passes exactly M subsampled rows; a fixed offset of
            272 would select an empty slice for such input.

    Raises:
        ValueError: if fewer than M generated rows are available from `offset`.
    """
    n_ref = test_data.shape[0]
    all_images = all_images_raw[offset:(offset + n_ref),:,:,:]
    if all_images.shape[0] != n_ref:
        raise ValueError(f"need {n_ref} generated images starting at row {offset}, "
                         f"got {all_images.shape[0]}")

    gc.collect()
    torch.cuda.empty_cache()
    # interpolation only (no parameters), so it simply runs on the CPU
    up = torch.nn.Upsample(size=(299, 299), mode='bilinear')

    def to_fid_input(images):
        # (N, 1, H, W) -> float32 CPU tensor -> (N, 299, 299, 3) numpy array
        if torch.is_tensor(images):
            batch = images.detach()
        else:
            batch = torch.from_numpy(np.asarray(images))
        batch = batch.to(device='cpu', dtype=torch.float32).repeat(1, 3, 1, 1)
        with torch.no_grad():
            return np.transpose(up(batch).numpy(), (0, 2, 3, 1))

    Fid = calculate_fid(to_fid_input(all_images), to_fid_input(test_data),
                        use_multiprocessing=False, batch_size=4)

    return Fid

In [None]:
# Seeded subsample of the generated images, the same size as the test set, so
# every FID compares equally sized sets (the same rows are used everywhere).
np.random.seed(0)
n_tot, n_test = image_list_icfm_uncond[0].shape[0], image_00_test.shape[0]
img_idx = np.random.choice(n_tot, size=n_test, replace=False)

In [None]:
%%capture output
fid_00_icfm_uncond = np.zeros(9)
fid_00_icfm_cond = np.zeros(9)
fid_00_gp_uncond = np.zeros(9)
fid_00_gp_cond = np.zeros(9)

fid_05_icfm_uncond = np.zeros(9)
fid_05_icfm_cond = np.zeros(9)
fid_05_gp_uncond = np.zeros(9)
fid_05_gp_cond = np.zeros(9)

fid_10_icfm_uncond = np.zeros(9)
fid_10_icfm_cond = np.zeros(9)
fid_10_gp_uncond = np.zeros(9)
fid_10_gp_cond = np.zeros(9)

for ll in tqdm(range(9)):
    
    fid_00_icfm_uncond[ll] = fid_calc(image_list_icfm_uncond[ll][img_idx,:,:,:], image_00_test.view([-1, 1, 28, 28]))
    fid_00_icfm_cond[ll] = fid_calc(image_list_icfm_cond[ll][img_idx,:,:,:], image_00_test.view([-1, 1, 28, 28]))
    fid_00_gp_uncond[ll] = fid_calc(image_list_gpicfm_uncond[ll][img_idx,:,:,:], image_00_test.view([-1, 1, 28, 28]))
    fid_00_gp_cond[ll] = fid_calc(image_list_gpicfm_cond[ll][img_idx,:,:,:], image_00_test.view([-1, 1, 28, 28]))

    fid_05_icfm_uncond[ll] = fid_calc(image_list_icfm_uncond[ll][img_idx,:,:,:], image_05_test.view([-1, 1, 28, 28]))
    fid_05_icfm_cond[ll] = fid_calc(image_list_icfm_cond[ll][img_idx,:,:,:], image_05_test.view([-1, 1, 28, 28]))
    fid_05_gp_uncond[ll] = fid_calc(image_list_gpicfm_uncond[ll][img_idx,:,:,:], image_05_test.view([-1, 1, 28, 28]))
    fid_05_gp_cond[ll] = fid_calc(image_list_gpicfm_cond[ll][img_idx,:,:,:], image_05_test.view([-1, 1, 28, 28]))

    fid_10_icfm_uncond[ll] = fid_calc(image_list_icfm_uncond[ll][img_idx,:,:,:], image_10_test.view([-1, 1, 28, 28]))
    fid_10_icfm_cond[ll] = fid_calc(image_list_icfm_cond[ll][img_idx,:,:,:], image_10_test.view([-1, 1, 28, 28]))
    fid_10_gp_uncond[ll] = fid_calc(image_list_gpicfm_uncond[ll][img_idx,:,:,:], image_10_test.view([-1, 1, 28, 28]))
    fid_10_gp_cond[ll] = fid_calc(image_list_gpicfm_cond[ll][img_idx,:,:,:], image_10_test.view([-1, 1, 28, 28]))

fid_all = np.column_stack((fid_00_icfm_uncond, fid_00_icfm_cond, fid_00_gp_uncond, fid_00_gp_cond,
           fid_05_icfm_uncond, fid_05_icfm_cond, fid_05_gp_uncond, fid_05_gp_cond,
           fid_10_icfm_uncond, fid_10_icfm_cond, fid_10_gp_uncond, fid_10_gp_cond))
np.savetxt("fid.csv", fid_all, delimiter=",")

## 4.3 Plot results

In [None]:
plot_dir = "/hpc/home/gw74/diff_model/FM/submission/plots/5_HWD"
# SVG output keeps text as text (not paths); no LaTeX rendering; 14 pt base font.
plt.rcParams.update({
    'svg.fonttype': 'none',
    'text.usetex': False,
    'font.size': 14,
})

In [None]:
# FID of each method against held-out '0' images at each of the 9 frames.
plt.plot(fid_00_icfm_uncond);
plt.plot(fid_00_icfm_cond);
plt.plot(fid_00_gp_uncond);
plt.plot(fid_00_gp_cond);
plt.xticks(np.arange(9), np.arange(0, 1.125, 0.125))  # frame index -> t
plt.xlabel("t")
plt.ylabel("FID")
plt.title("FID to '0'")
plt.legend(['ICFM-uncond', 'ICFM-cond', 'GP-ICFM-uncond', 'GP-ICFM-cond'])
plt.savefig(plot_dir + "/fid0.svg")

In [None]:
# FID of each method against held-out '8' images at each of the 9 frames.
plt.plot(fid_05_icfm_uncond);
plt.plot(fid_05_icfm_cond);
plt.plot(fid_05_gp_uncond);
plt.plot(fid_05_gp_cond);
plt.xticks(np.arange(9), np.arange(0, 1.125, 0.125))  # frame index -> t
plt.xlabel("t")
plt.ylabel("FID")
plt.title("FID to '8'")
plt.legend(['ICFM-uncond', 'ICFM-cond', 'GP-ICFM-uncond', 'GP-ICFM-cond'])
plt.savefig(plot_dir + "/fid8.svg")

In [None]:
# FID of each method against held-out '6' images at each of the 9 frames.
plt.plot(fid_10_icfm_uncond);
plt.plot(fid_10_icfm_cond);
plt.plot(fid_10_gp_uncond);
plt.plot(fid_10_gp_cond);
plt.xticks(np.arange(9), np.arange(0, 1.125, 0.125))  # frame index -> t
plt.xlabel("t")
plt.ylabel("FID")
plt.title("FID to '6'");
plt.legend(['ICFM-uncond', 'ICFM-cond', 'GP-ICFM-uncond', 'GP-ICFM-cond'])
plt.savefig(plot_dir + "/fid6.svg")