In [None]:
%load_ext autoreload
%autoreload 2
import os

import matplotlib.pyplot as plt
import torch
import torchsde
from torchdyn.core import NeuralODE
from torchvision import datasets, transforms
from torchvision.transforms import ToPILImage
from torchvision.utils import make_grid
from tqdm import tqdm

from torchcfm.conditional_flow_matching import *
from torchcfm.models.unet import UNetModel
from torchcfm.optimal_transport import OTPlanSampler

# Local checkpoint directory, created up front.
# NOTE(review): nothing visible in this notebook writes here — checkpoints go to `rootFolder`.
savedir = "models/mnist"
os.makedirs(name=savedir, exist_ok=True)

import numpy as np
import matplotlib.pyplot as plt
import torch
# from torchcfm.optimal_transport import OTPlanSampler

from typing import List
import time
from torchdyn.core import NeuralODE

from tqdm import tqdm
from torch.distributions.multivariate_normal import MultivariateNormal
# import ot
# import ot.plot
import pickle
from copy import deepcopy
import gc

import metric.pytorch_ssim
from metric.IS_score import *
from metric.Fid_score import *
from torchmetrics.image.kid import KernelInceptionDistance

# 0. Load Data

In [None]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
# Newly created tensors land on `device` unless a later cell changes the default.
torch.set_default_device(device)

batch_size = 128

# ToTensor yields pixels in [0, 1]; Normalize((0.5,), (0.5,)) maps them to [-1, 1].
_train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
trainset = datasets.MNIST("../data", train=True, download=False, transform=_train_transform)

train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    # NOTE(review): the shuffling generator lives on `device`, presumably so sampling
    # works while a CUDA default device is set — confirm before changing.
    generator=torch.Generator(device),
)

In [None]:
# Held-out MNIST digits, normalised to [-1, 1] exactly like the training set.
testset = datasets.MNIST(
    root="../data",
    train=False,
    download=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]),
)

# 1. Functions

Compare 4 models:
1. Independent CFM (I-CFM)
2. Optimal transport CFM (OT-CFM)
3. GP independent CFM (GP-I-CFM)
4. GP optimal transport CFM (GP-OT-CFM)

## 1.0 Common Functions

In [None]:
def nonGP_model_fit(model, optimizer, FM, n_epochs, train_loader, device, subset = False, iMax = None):
    """Train `model` with a (non-GP) conditional flow matcher `FM`.

    For each batch, data x1 is paired with Gaussian noise x0 of the same shape, and
    `FM.sample_location_and_conditional_flow(x0, x1)` supplies (t, x_t, u_t). The model is
    regressed onto u_t with a mean-squared-error loss.

    Args:
        model: network called as `model(t, x_t)`.
        optimizer: optimizer over `model`'s parameters.
        FM: flow matcher exposing `sample_location_and_conditional_flow`.
        n_epochs: number of passes over `train_loader`.
        train_loader: iterable of batches whose first element is the data tensor.
        device: device the data batches are moved to.
        subset: if True, truncate every epoch after batch index `iMax`.
        iMax: last batch index used when `subset` is True; None means no limit.

    Returns:
        The same `model` object, trained in place.
    """
    for _epoch in tqdm(range(n_epochs)):
        for i, data in enumerate(train_loader):
            # `i > iMax` keeps batches 0..iMax (iMax + 1 batches). iMax=None used to raise
            # TypeError here; it now means "use every batch".
            if subset and iMax is not None and i > iMax:
                break

            optimizer.zero_grad()
            x1 = data[0].to(device)
            x0 = torch.randn_like(x1)
            t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
            vt = model(t, xt)
            loss = torch.mean((vt - ut) ** 2)
            loss.backward()
            optimizer.step()

    return model

In [None]:
def calc_r(ti, tj):
    """Pairwise time differences r[..., i, j] = ti[..., i] - tj[..., j].

    Entries that are exactly zero are overwritten with 1e-15 on the freshly built
    difference tensor. NOTE(review): with integer inputs 1e-15 is presumably truncated
    back to 0 — confirm the intended dtype.
    """
    diff = ti.unsqueeze(-1) - tj.unsqueeze(-2)
    diff[diff == 0] = 1e-15
    return diff


def _se_envelope(r, l):
    """Squared-exponential factor exp(-r^2 / (2 l^2)) shared by all kernel blocks."""
    return torch.exp(-0.5 * ((r / l) ** 2))


def k11(r, alpha, l):
    """SE-kernel covariance cov(f(s), f(t)) with r = s - t."""
    return (alpha ** 2) * _se_envelope(r, l)


def k12(r, alpha, l):
    """Cross-covariance cov(f(s), f'(t)), i.e. d k11 / dt, with r = s - t."""
    return (alpha ** 2 / l ** 2) * r * _se_envelope(r, l)


def k22(r, alpha, l):
    """Derivative covariance cov(f'(s), f'(t)), i.e. d^2 k11 / (ds dt), with r = s - t."""
    return (alpha ** 2 / l ** 4) * (l ** 2 - r ** 2) * _se_envelope(r, l)

In [None]:
def cov_mat2(ti, tj, alpha, l, sig2_diag = 1e-8):
    """Joint SE-kernel covariance of [f(ti), f'(ti)] against [f(tj), f'(tj)].

    Args:
        ti, tj: time tensors. With (nB, nt) inputs the result is batched over nB; any
            leading batch shape (including none) is accepted.
        alpha, l: kernel amplitude and length-scale.
        sig2_diag: jitter added to the diagonal of the value-value block (this requires
            ti and tj to have the same length).

    Returns:
        The symmetrised block matrix [[K11 + sig2_diag * I, K12], [K12^T, K22]] with shape
        (..., 2 * nt, 2 * nt).
    """
    r = calc_r(ti, tj)
    nt = r.shape[-1]

    # Jitter is built on r's device (not the global default device) and broadcast over
    # the batch, instead of being materialised once per batch element.
    Sig11 = k11(r, alpha, l) + torch.eye(nt, device=r.device) * sig2_diag
    Sig12 = k12(r, alpha, l)
    Sig21 = Sig12.transpose(-1, -2)
    Sig22 = k22(r, alpha, l)

    Sig = torch.cat([
        torch.cat([Sig11, Sig12], dim=-1),
        torch.cat([Sig21, Sig22], dim=-1),
    ], dim=-2)
    # Average with the transpose to remove round-off asymmetry.
    return (Sig + Sig.transpose(-1, -2)) / 2

In [None]:
def _samp_mvn_numpy(mu, cov):
    """NumPy fallback: return mu[b] + z with z ~ N(0, cov[b]) drawn independently per column.

    `np.random.multivariate_normal` only warns (default check_valid='warn') on covariances
    that are not PSD, so it can still sample where a Cholesky factorisation fails. Inputs are
    moved to host memory first; the result is returned on `mu`'s device and dtype.
    """
    mu_np = mu.detach().cpu().numpy()
    cov_np = cov.detach().cpu().numpy()
    nB, n, dim = mu_np.shape
    out = np.empty((nB, n, dim))
    for bb in range(nB):
        # size=dim draws one n-vector per coordinate -> (dim, n); transpose to (n, dim).
        draws = np.random.multivariate_normal(np.zeros(n), cov_np[bb], size=dim)
        out[bb] = mu_np[bb] + draws.T
    return torch.from_numpy(out).to(device=mu.device, dtype=mu.dtype)


def samp_x_dx2(t_mat, alpha, l, x_obs, t_obs, sig2_diag = 1e-8):
    """Jointly sample GP values x(t) and derivatives dx/dt(t) conditioned on observations.

    Each of the `dim` coordinates is treated as an independent zero-mean GP with the
    SE kernel (k11/k12/k22), conditioned on x(t_obs) = x_obs, and sampled jointly for
    the value and its derivative at the query times `t_mat`.

    Args:
        t_mat: (nB, nt) query times per batch element.
        alpha, l: kernel amplitude and length-scale.
        x_obs: (nB, nt_obs, dim) observed values at `t_obs`.
        t_obs: (nt_obs,) observation times shared by the whole batch.
        sig2_diag: diagonal jitter for the value-value covariance blocks.

    Returns:
        (x_samps, dx_samps), each of shape (nB, nt, dim).
    """
    nt = t_mat.shape[1]
    nt_obs = t_obs.shape[0]

    r_obs_x = calc_r(t_obs, t_mat)      # (nB, nt_obs, nt)
    r_obs_obs = calc_r(t_obs, t_obs)    # (nt_obs, nt_obs)

    # Prior covariance of [x(t_mat), dx(t_mat)]: (nB, 2*nt, 2*nt).
    Sig_11 = cov_mat2(t_mat, t_mat, alpha, l, sig2_diag)
    # Covariance between the observations and [x, dx]: (nB, nt_obs, 2*nt).
    Sig_21 = torch.cat([k11(r_obs_x, alpha, l), k12(r_obs_x, alpha, l)], dim=2)
    Sig_12 = Sig_21.transpose(1, 2)

    # The observation covariance is identical for every batch element: invert it once
    # and let matmul broadcast it over the batch.
    Sig_22 = k11(r_obs_obs, alpha, l) + torch.eye(nt_obs, device=r_obs_obs.device) * sig2_diag
    Sig_22_inv = torch.linalg.inv(Sig_22)

    Sig_cond = Sig_11 - Sig_12 @ Sig_22_inv @ Sig_21
    Sig_cond = (Sig_cond + Sig_cond.transpose(1, 2)) / 2

    # Repair batch entries whose conditional covariance has a negative (or NaN) eigenvalue
    # by rebuilding them as V diag(S + 1e-8) V^T from an SVD. Sig_cond is symmetric, so
    # eigvalsh (real eigenvalues) is the appropriate routine.
    evals = torch.linalg.eigvalsh(Sig_cond)
    svd_add_idx = ~(evals >= 0).all(dim=1)
    if svd_add_idx.any():
        _, S, Vh = torch.linalg.svd(Sig_cond[svd_add_idx])
        Sig_cond_add = Vh.transpose(1, 2) @ torch.diag_embed(S + 1e-8) @ Vh
        Sig_cond[svd_add_idx] = (Sig_cond_add + Sig_cond_add.transpose(1, 2)) / 2

    # Conditional means for all coordinates at once: (nB, 2*nt, dim).
    mu = (Sig_12 @ Sig_22_inv) @ x_obs

    try:
        # Sig_cond does not depend on the coordinate, so factor it once instead of once
        # per coordinate; mu + L @ eps is exactly the MVN reparameterised sample.
        chol = torch.linalg.cholesky(Sig_cond)
        x_dx = mu + chol @ torch.randn_like(mu)
    except (ValueError, RuntimeError):
        # torch.linalg.LinAlgError (a RuntimeError) is raised for non-PD covariances.
        print('use numpy')
        x_dx = _samp_mvn_numpy(mu, Sig_cond)

    x_samps = x_dx[:, :nt, :].contiguous()
    dx_samps = x_dx[:, nt:, :].contiguous()
    return x_samps, dx_samps

## 1.1 I-CFM

In [None]:
def icfm_fit(model, optimizer, sigma, n_epochs, train_loader, device, subset = False, iMax = None):
    """Train `model` with independent-coupling CFM (I-CFM) at noise level `sigma`.

    Thin wrapper: builds a `ConditionalFlowMatcher` and delegates to `nonGP_model_fit`.
    Returns the trained model.
    """
    return nonGP_model_fit(
        model, optimizer, ConditionalFlowMatcher(sigma=sigma), n_epochs,
        train_loader, device, subset=subset, iMax=iMax,
    )

## 1.2 OT-CFM

In [None]:
def otcfm_fit(model, optimizer, sigma, n_epochs, train_loader, device, subset = False, iMax = None):
    """Train `model` with exact optimal-transport CFM (OT-CFM) at noise level `sigma`.

    Thin wrapper: builds an `ExactOptimalTransportConditionalFlowMatcher` and delegates to
    `nonGP_model_fit`. Returns the trained model.
    """
    return nonGP_model_fit(
        model, optimizer, ExactOptimalTransportConditionalFlowMatcher(sigma=sigma), n_epochs,
        train_loader, device, subset=subset, iMax=iMax,
    )

## 1.3 GP-ICFM

In [None]:
def gp_icfm_fit(model, optimizer, alpha, l, sig2_diag, n_epochs, train_loader, device,
               subset = False, iMax = None):
    """Train `model` with GP independent-coupling CFM (GP-I-CFM).

    Noise x0 ~ N(0, I) is paired with data x1 by batch index. For one uniform time t per
    sample, (x_t, u_t) are drawn jointly with `samp_x_dx2` from a GP conditioned on
    x(0) = x0 and x(1) = x1 (per flattened pixel). The model is regressed onto u_t (MSE).

    Args:
        model: network called as `model(t, x_t)`.
        optimizer: optimizer over `model`'s parameters.
        alpha, l, sig2_diag: GP kernel amplitude, length-scale and diagonal jitter.
        n_epochs: number of passes over `train_loader`.
        train_loader: iterable of batches whose first element is the image tensor.
        device: device the batches are moved to.
        subset: if True, truncate every epoch after batch index `iMax` (None = no limit).
        iMax: last batch index used when `subset` is True.

    Returns:
        The same `model` object, trained in place.
    """
    # GP observation times: x0 at t=0 and x1 at t=1 (identical for every batch).
    t_obs = torch.tensor([0, 1], device=device)

    for _epoch in tqdm(range(n_epochs)):
        for i, data in enumerate(train_loader):
            if subset and iMax is not None and i > iMax:
                break

            optimizer.zero_grad()

            x1 = data[0].to(device)
            x0 = torch.randn_like(x1)

            btch_size = x1.shape[0]
            # (batch, 2, n_pixels): observation 0 is x0, observation 1 is x1.
            x01_trans = torch.stack(
                (x0.reshape(btch_size, -1), x1.reshape(btch_size, -1)), dim=1)

            t_mat = torch.rand((btch_size, 1), device=x1.device)

            try:
                xt_batch, ut_batch = samp_x_dx2(t_mat, alpha, l, x01_trans, t_obs, sig2_diag)
            except Exception as err:
                # Skip the batch: continuing would train on the previous batch's (x_t, u_t)
                # paired with this batch's t (or hit an undefined name on the first batch).
                print('fail', repr(err))
                continue

            t = t_mat.reshape(-1)
            xt = xt_batch.reshape(x1.shape)
            ut = ut_batch.reshape(x1.shape)
            vt = model(t, xt)

            loss = torch.mean((vt - ut) ** 2)
            loss.backward()
            optimizer.step()

    return model

## 1.4 GP-OTCFM

In [None]:
def gp_otcfm_fit(model, optimizer, alpha, l, sig2_diag, n_epochs, train_loader, device,
                subset = False, iMax = None):
    """Train `model` with GP optimal-transport CFM (GP-OT-CFM).

    Like GP-I-CFM, except that each noise/data batch is first re-paired with an exact OT
    plan (`OTPlanSampler(method="exact").sample_plan`). (x_t, u_t) are then drawn with
    `samp_x_dx2` from a GP conditioned on x(0) = x0 and x(1) = x1, and the model is regressed
    onto u_t (MSE).

    Args:
        model: network called as `model(t, x_t)`.
        optimizer: optimizer over `model`'s parameters.
        alpha, l, sig2_diag: GP kernel amplitude, length-scale and diagonal jitter.
        n_epochs: number of passes over `train_loader`.
        train_loader: iterable of batches whose first element is the image tensor.
        device: device the batches are moved to.
        subset: if True, truncate every epoch after batch index `iMax` (None = no limit).
        iMax: last batch index used when `subset` is True.

    Returns:
        The same `model` object, trained in place.
    """
    ot_sampler = OTPlanSampler(method="exact")
    # GP observation times: x0 at t=0 and x1 at t=1 (identical for every batch).
    t_obs = torch.tensor([0, 1], device=device)

    for _epoch in tqdm(range(n_epochs)):
        for i, data in enumerate(train_loader):
            if subset and iMax is not None and i > iMax:
                break

            optimizer.zero_grad()

            x1 = data[0].to(device)
            x0 = torch.randn_like(x1)

            x0, x1 = ot_sampler.sample_plan(x0, x1)

            btch_size = x1.shape[0]
            # (batch, 2, n_pixels): observation 0 is x0, observation 1 is x1.
            x01_trans = torch.stack(
                (x0.reshape(btch_size, -1), x1.reshape(btch_size, -1)), dim=1)

            t_mat = torch.rand((btch_size, 1), device=x1.device)

            try:
                xt_batch, ut_batch = samp_x_dx2(t_mat, alpha, l, x01_trans, t_obs, sig2_diag)
            except Exception as err:
                # Skip the batch: continuing would train on the previous batch's (x_t, u_t)
                # paired with this batch's t (or hit an undefined name on the first batch).
                print('fail', repr(err))
                continue

            t = t_mat.reshape(-1)
            xt = xt_batch.reshape(x1.shape)
            ut = ut_batch.reshape(x1.shape)
            vt = model(t, xt)

            loss = torch.mean((vt - ut) ** 2)
            loss.backward()
            optimizer.step()

    return model

# 3. Model Training

In [None]:
sigma = 0.0      # CFM noise level for the I-CFM / OT-CFM baselines
n_epochs = 5

# When subset is True, each epoch stops after batch index iMax (iMax + 1 batches).
subset = False
iMax = None

# GP hyper-parameters for GP-I-CFM / GP-OT-CFM.
alpha = 1        # kernel amplitude
l = 5            # kernel length-scale (TODO: also try l = 4)
sig2_diag = 0    # diagonal jitter passed to samp_x_dx2 (its own default is 1e-8)

nRep = 5         # independent training replicates per method
# Checkpoint store; override with the CFM_MODEL_STORE environment variable.
rootFolder = os.environ.get(
    "CFM_MODEL_STORE", "/hpc/home/gw74/diff_model/FM/image/model_store/unconditional_short")

In [None]:
%%capture output
torch.set_default_device(device)
load_pre = False

gc.collect()
torch.cuda.empty_cache()

for ll in range(nRep):
    
    # 1. icfm
    model_icfm = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)
    if load_pre:
        model_icfm.load_state_dict(torch.load(rootFolder + "/icfm" + str(ll) + ".pt"))
    optimizer_icfm = torch.optim.Adam(model_icfm.parameters(), lr = 2e-4)
    model_icfm = icfm_fit(model_icfm, optimizer_icfm, sigma, n_epochs,
                          train_loader, device, subset = subset, iMax = iMax)
    torch.save(model_icfm.state_dict(), rootFolder + "/icfm" + str(ll) + ".pt")
    
    
    # 2. otcfm
    model_otcfm = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)
    if load_pre:
        model_otcfm.load_state_dict(torch.load(rootFolder + "/otcfm" + str(ll) + ".pt"))
    optimizer_otcfm = torch.optim.Adam(model_otcfm.parameters(), lr = 2e-4)
    model_otcfm = otcfm_fit(model_otcfm, optimizer_otcfm, sigma, n_epochs,
                          train_loader, device, subset = subset, iMax = iMax)
    torch.save(model_otcfm.state_dict(), rootFolder + "/otcfm" + str(ll) + ".pt")
    
    
    # 3. gp-icfm
    model_gp_icfm = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)
    if load_pre:
        model_gp_icfm.load_state_dict(torch.load(rootFolder + "/gp_icfm" + str(ll) + ".pt"))
    optimizer_gp_icfm = torch.optim.Adam(model_gp_icfm.parameters(), lr = 2e-4)
    model_gp_icfm = gp_icfm_fit(model_gp_icfm, optimizer_gp_icfm, alpha, l, sig2_diag, n_epochs, train_loader,
                                device, subset = subset, iMax = iMax)
    torch.save(model_gp_icfm.state_dict(), rootFolder + "/gp_icfm" + str(ll) + ".pt")
    
    
    # 4. gp-otcfm
    model_gp_otcfm = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)
    if load_pre:
        model_gp_otcfm.load_state_dict(torch.load(rootFolder + "/gp_otcfm" + str(ll) + ".pt"))
    optimizer_gp_otcfm = torch.optim.Adam(model_gp_otcfm.parameters(), lr = 2e-4)
    model_gp_otcfm = gp_otcfm_fit(model_gp_otcfm, optimizer_gp_otcfm, alpha, l, sig2_diag, n_epochs, train_loader,
                                device, subset = subset, iMax = iMax)
    torch.save(model_gp_otcfm.state_dict(), rootFolder + "/gp_otcfm" + str(ll) + ".pt")

# 4. Plotting (check)

In [None]:
def plotFun(model, x0_gen):
    """Integrate `model` from t=0 to t=1 starting at `x0_gen` and show the endpoints.

    The first 100 endpoints are clipped to [-1, 1] and drawn as a 10-wide image grid on
    the current matplotlib axes.
    """
    ode = NeuralODE(model, solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4)
    t_span = torch.linspace(0, 1, 2, device=device)

    with torch.no_grad():
        trajectory = ode.trajectory(x0_gen, t_span=t_span)

    endpoints = trajectory[-1, :100].view([-1, 1, 28, 28]).clip(-1, 1)
    grid = make_grid(endpoints, value_range=(-1, 1), padding=0, nrow=10)
    plt.imshow(ToPILImage()(grid))

In [None]:
nCheck = 0  # which training replicate's checkpoints to visualise

# Fixed starting noise so the four models' samples are directly comparable.
torch.manual_seed(0)
x0_gen = torch.randn(100, 1, 28, 28, device=device)

In [None]:
ckpt_icfm = f"{rootFolder}/icfm{nCheck}.pt"
model_icfm = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)
model_icfm.load_state_dict(torch.load(ckpt_icfm))
plotFun(model_icfm, x0_gen)

In [None]:
ckpt_otcfm = f"{rootFolder}/otcfm{nCheck}.pt"
model_otcfm = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)
model_otcfm.load_state_dict(torch.load(ckpt_otcfm))
plotFun(model_otcfm, x0_gen)

In [None]:
ckpt_gp_icfm = f"{rootFolder}/gp_icfm{nCheck}.pt"
model_gp_icfm = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)
model_gp_icfm.load_state_dict(torch.load(ckpt_gp_icfm))
plotFun(model_gp_icfm, x0_gen)

In [None]:
ckpt_gp_otcfm = f"{rootFolder}/gp_otcfm{nCheck}.pt"
model_gp_otcfm = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)
model_gp_otcfm.load_state_dict(torch.load(ckpt_gp_otcfm))
plotFun(model_gp_otcfm, x0_gen)

# 5. Model Evaluation

In [None]:
import metric.pytorch_ssim
from metric.IS_score import *
from metric.Fid_score import *
from torchmetrics.image.kid import KernelInceptionDistance

In [None]:
batch_eval = 32   # images per generation / evaluation batch
rep_eval = 40     # number of batches per model

# From here on, new tensors default to the CPU.
torch.set_default_device('cpu')
test_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_eval,
    shuffle=True,
    drop_last=True,
    generator=torch.Generator('cpu'),
)

In [None]:
def gen_1_img(node, x0_gen):
    """Integrate `node` from t=0 to t=1 starting at `x0_gen`; return the endpoint clipped to [-1, 1]."""
    t_span = torch.linspace(0, 1, 2, device=device)
    with torch.no_grad():
        traj = node.trajectory(x0_gen, t_span=t_span)
    return traj[-1, :].clip(-1, 1)


def gen_img(node, rep_eval, batch_eval):
    """Generate `rep_eval * batch_eval` images from `node` as a NumPy array.

    Batch i starts from noise drawn right after `torch.manual_seed(i)`, so every model is
    evaluated on the same starting noise. Output is presumably (N, 1, 28, 28) in [-1, 1]
    (the node's output shape matches its (batch_eval, 1, 28, 28) input).
    """
    batches = []
    for i in range(rep_eval):
        torch.manual_seed(i)
        x0_gen_id = torch.randn(batch_eval, 1, 28, 28, device=device)
        batches.append(gen_1_img(node, x0_gen_id).detach().cpu().numpy())

    if not batches:
        # rep_eval == 0 used to raise UnboundLocalError.
        return np.empty((0, 1, 28, 28), dtype=np.float32)
    # Concatenate once instead of re-copying the accumulated array on every iteration.
    return np.concatenate(batches, axis=0)

In [None]:
%%capture output
node_list_icfm = []
node_list_otcfm = []
node_list_gp_icfm = []
node_list_gp_otcfm = []

img_list_icfm = []
img_list_otcfm = []
img_list_gp_icfm = []
img_list_gp_otcfm = []

for ll in tqdm(range(nRep)):
    
    # 1. icfm
    model_icfm = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)
    model_icfm.load_state_dict(torch.load(rootFolder + "/icfm" + str(ll) + ".pt"))
    node = NeuralODE(model_icfm, solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4)
    all_images = gen_img(node, rep_eval, batch_eval)
    node_list_icfm.append(node)
    img_list_icfm.append(all_images)
    
    # 2. otcfm
    model_otcfm = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)
    model_otcfm.load_state_dict(torch.load(rootFolder + "/otcfm" + str(ll) + ".pt"))
    node = NeuralODE(model_otcfm, solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4)
    all_images = gen_img(node, rep_eval, batch_eval)
    node_list_otcfm.append(node)
    img_list_otcfm.append(all_images)
    
    # 3. gp_icfm
    model_gp_icfm = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)
    model_gp_icfm.load_state_dict(torch.load(rootFolder + "/gp_icfm" + str(ll) + ".pt"))
    node = NeuralODE(model_gp_icfm, solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4)
    all_images = gen_img(node, rep_eval, batch_eval)
    node_list_gp_icfm.append(node)
    img_list_gp_icfm.append(all_images)
    
    # 4. gp_otcfm
    model_gp_otcfm = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)
    model_gp_otcfm.load_state_dict(torch.load(rootFolder + "/gp_otcfm" + str(ll) + ".pt"))
    node = NeuralODE(model_gp_otcfm, solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4)
    all_images = gen_img(node, rep_eval, batch_eval)
    node_list_gp_otcfm.append(node)
    img_list_gp_otcfm.append(all_images)

In [None]:
import pickle

# Persist generated images so the metric cells can be re-run without regenerating.
for fname, img_list in (
    ("img_list_icfm", img_list_icfm),
    ("img_list_otcfm", img_list_otcfm),
    ("img_list_gp_icfm", img_list_gp_icfm),
    ("img_list_gp_otcfm", img_list_gp_otcfm),
):
    with open(fname, "wb") as fp:
        pickle.dump(img_list, fp)

# To reload instead of regenerating (trusted local files only — pickle can run code):
# with open("img_list_icfm", "rb") as fp: img_list_icfm = pickle.load(fp);
# with open("img_list_otcfm", "rb") as fp: img_list_otcfm = pickle.load(fp);
# with open("img_list_gp_icfm", "rb") as fp: img_list_gp_icfm = pickle.load(fp);
# with open("img_list_gp_otcfm", "rb") as fp: img_list_gp_otcfm = pickle.load(fp);

## 5.1 KID

In [None]:
# Collect unreferenced Python objects first, then hand cached CUDA blocks back.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Stack the (shuffled) test set as 3-channel NumPy images in [-1, 1], concatenating once.
# Converting each batch with .numpy() keeps the result an ndarray (np.repeat on a tensor
# may hand back a tensor, e.g. when the loader yields a single batch).
real_images_all = np.concatenate(
    [np.repeat(images.numpy(), 3, axis=1) for images, _labels in test_loader],
    axis=0,
)

In [None]:
all_real_n = len(real_images_all)  # number of real test images available for sampling

In [None]:
def kid_calc(real_images, all_images, subset_size = rep_eval):
    """Kernel Inception Distance between real and generated images.

    Args:
        real_images: NumPy images in [-1, 1], presumably already 3-channel (N, 3, H, W).
        all_images: generated NumPy images in [-1, 1], single-channel (N, 1, H, W);
            repeated to 3 channels here.
        subset_size: subset size for `KernelInceptionDistance`.

    Returns:
        The KID mean as a Python float (the std over subsets is discarded).
    """
    # [-1, 1] -> [0, 255] uint8. NOTE(review): the cast truncates rather than rounds.
    to_uint8 = lambda arr: torch.from_numpy(255 * (arr + 1) / 2).type(torch.uint8)

    metric = KernelInceptionDistance(subset_size = subset_size)
    metric.update(to_uint8(real_images), real=True)
    metric.update(to_uint8(np.repeat(all_images, 3, axis=1)), real=False)
    kid_mean, _kid_std = metric.compute()
    return kid_mean.item()

In [None]:
%%capture output

gc.collect()
torch.cuda.empty_cache()

kid_icfm = np.zeros(nRep)
kid_otcfm = np.zeros(nRep)
kid_gp_icfm = np.zeros(nRep)
kid_gp_otcfm = np.zeros(nRep)

for ll in tqdm(range(nRep)):
    gc.collect()
    torch.cuda.empty_cache()
    real_img_id = np.random.choice(all_real_n, batch_eval*rep_eval, replace = False)
    real_images = real_images_all[real_img_id,:,:,:]
    
    kid_icfm[ll] = kid_calc(real_images, img_list_icfm[ll], subset_size = rep_eval)
    kid_otcfm[ll] = kid_calc(real_images, img_list_otcfm[ll], subset_size = rep_eval)
    kid_gp_icfm[ll] = kid_calc(real_images, img_list_gp_icfm[ll], subset_size = rep_eval)
    kid_gp_otcfm[ll] = kid_calc(real_images, img_list_gp_otcfm[ll], subset_size = rep_eval)

In [None]:
# Mean +- (population) standard deviation of KID across replicates.
for fmt, scores in (
    ('icfm: {:.3f} +- {:.3f}', kid_icfm),
    ('otcfm: {:.3f} +- {:.3f}', kid_otcfm),
    ('gp_icfm: {:.3f} +- {:.3f}', kid_gp_icfm),
    ('gp_otcfm: {:.3f} +- {:.3f}', kid_gp_otcfm),
):
    print(fmt.format(np.mean(scores), np.std(scores)))

## 5.2 FID

In [None]:
def fid_calc(all_images, test_loader, rep_eval = rep_eval, device = None):
    """FID between generated images and the first `rep_eval` batches of `test_loader`.

    Both sets are mapped from [-1, 1] to [0, 1] exactly once, bilinearly upsampled to
    299x299, converted to channels-last and repeated to 3 channels before `calculate_fid`.
    (Real images used to be rescaled twice, landing in [0.5, 1] while generated images were
    in [0, 1], which biased the score.)

    Args:
        all_images: generated NumPy images in [-1, 1], shape (N, 1, H, W).
        test_loader: loader yielding (images, labels) with images in [-1, 1].
        rep_eval: number of real batches to use.
        device: where to run the upsampling; defaults to cuda:0 when available, else CPU.

    Returns:
        The value returned by `calculate_fid`.
    """
    gc.collect()
    torch.cuda.empty_cache()
    if device is None:
        device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    up = torch.nn.Upsample(size=(299, 299), mode='bilinear')

    def _prep(images_01):
        """(N, 1, H, W) tensor in [0, 1] -> (N, 299, 299, 3) NumPy array."""
        upsampled = up(images_01.to(device)).cpu().numpy()
        return np.repeat(np.transpose(upsampled, (0, 2, 3, 1)), 3, axis=3)

    gen_images = _prep(torch.as_tensor((all_images + 1) / 2, dtype=torch.float32))

    real_batches = []
    for i, (images, _labels) in enumerate(test_loader):
        if i == rep_eval:
            break
        real_batches.append(_prep((images + 1) / 2))
    if not real_batches:
        raise ValueError("fid_calc needs at least one real batch (rep_eval >= 1)")
    real_images = np.concatenate(real_batches, axis=0)

    return calculate_fid(gen_images, real_images, use_multiprocessing=False, batch_size=4)

In [None]:
gc.collect()
torch.cuda.empty_cache()

fid_scores = {name: np.zeros(nRep) for name in ("icfm", "otcfm", "gp_icfm", "gp_otcfm")}
generated_by_method = {
    "icfm": img_list_icfm,
    "otcfm": img_list_otcfm,
    "gp_icfm": img_list_gp_icfm,
    "gp_otcfm": img_list_gp_otcfm,
}

for ll in tqdm(range(nRep)):
    # Methods are scored in a fixed order; each call draws fresh shuffled real batches.
    for name, generated in generated_by_method.items():
        fid_scores[name][ll] = fid_calc(generated[ll], test_loader, rep_eval = rep_eval)

fid_icfm = fid_scores["icfm"]
fid_otcfm = fid_scores["otcfm"]
fid_gp_icfm = fid_scores["gp_icfm"]
fid_gp_otcfm = fid_scores["gp_otcfm"]

In [None]:
# Mean +- (population) standard deviation of FID across replicates.
for fmt, scores in (
    ('icfm: {:.3f} +- {:.3f}', fid_icfm),
    ('otcfm: {:.3f} +- {:.3f}', fid_otcfm),
    ('gp_icfm: {:.3f} +- {:.3f}', fid_gp_icfm),
    ('gp_otcfm: {:.3f} +- {:.3f}', fid_gp_otcfm),
):
    print(fmt.format(np.mean(scores), np.std(scores)))

# 6. Summarize

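Each method is fit 100 times ($5 \times 20$: 20 saved evaluation runs with `nRep = 5` replicates each), and the resulting KID/FID scores are summarized via histograms and mean ± standard deviation.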

In [None]:
# Directory of saved per-run KID/FID pickles; override with the MNIST_EVAL_DIR env variable.
readFolder = os.environ.get("MNIST_EVAL_DIR", "/hpc/group/mastatlab/gw74/mnist_eval")

In [None]:
def _load_pickle(path):
    """Unpickle `path`. Only use on trusted files: unpickling can execute arbitrary code."""
    with open(path, "rb") as fp:
        return pickle.load(fp)


n_eval_runs = 20  # number of saved evaluation runs in readFolder
kid_by_method = {name: [] for name in ("icfm", "otcfm", "gp_icfm", "gp_otcfm")}
fid_by_method = {name: [] for name in ("icfm", "otcfm", "gp_icfm", "gp_otcfm")}

for ll in range(n_eval_runs):
    for name in kid_by_method:
        # The file name and the destination list are derived from the same `name`, so a
        # method's scores cannot land in another method's list (the gp_icfm / gp_otcfm FID
        # files were previously loaded into each other's lists).
        kid_by_method[name].append(_load_pickle(readFolder + "/kid_" + name + "_" + str(ll)))
        fid_by_method[name].append(_load_pickle(readFolder + "/fid_" + name + "_" + str(ll)))

kid_icfm = kid_by_method["icfm"]
kid_otcfm = kid_by_method["otcfm"]
kid_gp_icfm = kid_by_method["gp_icfm"]
kid_gp_otcfm = kid_by_method["gp_otcfm"]

fid_icfm = fid_by_method["icfm"]
fid_otcfm = fid_by_method["otcfm"]
fid_gp_icfm = fid_by_method["gp_icfm"]
fid_gp_otcfm = fid_by_method["gp_otcfm"]

In [None]:
# Flatten each method's per-run score arrays into one 1-D array (presumably runs x nRep).
kid_icfm_all, kid_otcfm_all, kid_gp_icfm_all, kid_gp_otcfm_all = map(
    np.ravel, (kid_icfm, kid_otcfm, kid_gp_icfm, kid_gp_otcfm))

fid_icfm_all, fid_otcfm_all, fid_gp_icfm_all, fid_gp_otcfm_all = map(
    np.ravel, (fid_icfm, fid_otcfm, fid_gp_icfm, fid_gp_otcfm))

## 6.1 Mean ± standard deviation

In [None]:
# KID across all pooled replicates: mean +- population standard deviation.
for fmt, scores in (
    ('icfm: {:.4f} +- {:.4f}', kid_icfm_all),
    ('otcfm: {:.4f} +- {:.4f}', kid_otcfm_all),
    ('gp_icfm: {:.4f} +- {:.4f}', kid_gp_icfm_all),
    ('gp_otcfm: {:.4f} +- {:.4f}', kid_gp_otcfm_all),
):
    print(fmt.format(np.mean(scores), np.std(scores)))

In [None]:
# FID across all pooled replicates: mean +- population standard deviation.
for fmt, scores in (
    ('icfm: {:.4f} +- {:.4f}', fid_icfm_all),
    ('otcfm: {:.4f} +- {:.4f}', fid_otcfm_all),
    ('gp_icfm: {:.4f} +- {:.4f}', fid_gp_icfm_all),
    ('gp_otcfm: {:.4f} +- {:.4f}', fid_gp_otcfm_all),
):
    print(fmt.format(np.mean(scores), np.std(scores)))

## 6.2 Histograms

In [None]:
# Output directory for figures; override with the CFM_PLOT_DIR environment variable.
plot_dir = os.environ.get("CFM_PLOT_DIR", "/hpc/home/gw74/diff_model/FM/submission/plots/4_MNIST")
plt.rcParams['svg.fonttype'] = 'none'   # keep text as text in saved SVGs
plt.rcParams['text.usetex'] = False
plt.rcParams.update({'font.size': 14})

In [None]:
plt.rcParams['figure.figsize'] = [4, 3]
kid_series = [kid_icfm_all, kid_otcfm_all, kid_gp_icfm_all, kid_gp_otcfm_all]
# Shared bin edges so the overlaid histograms are directly comparable
# (separate bins=50 calls gave each method its own edges).
kid_bins = np.histogram_bin_edges(np.concatenate(kid_series), bins=50)

fig, ax = plt.subplots()
for series in kid_series:
    ax.hist(series, bins=kid_bins, alpha=0.5)
ax.legend(['i-cfm', 'ot-cfm', 'gp-icfm', 'gp-otcfm'])
ax.set_title(f'KID, {len(kid_icfm_all)} seeds')
ax.set_xlabel('KID')
ax.set_ylabel('count')
fig.savefig(plot_dir + "/1_kid.svg")

In [None]:
fid_series = [fid_icfm_all, fid_otcfm_all, fid_gp_icfm_all, fid_gp_otcfm_all]
# Shared bin edges so the overlaid histograms are directly comparable.
fid_bins = np.histogram_bin_edges(np.concatenate(fid_series), bins=50)

fig, ax = plt.subplots()
for series in fid_series:
    ax.hist(series, bins=fid_bins, alpha=0.5)
ax.legend(['i-cfm', 'ot-cfm', 'gp-icfm', 'gp-otcfm'])
ax.set_title(f'FID, {len(fid_icfm_all)} seeds')
ax.set_xlabel('FID')
ax.set_ylabel('count')
fig.savefig(plot_dir + "/2_fid.svg")