In [1]:
# #@title Autoload all modules
%load_ext autoreload
%autoreload 2

from dataclasses import dataclass, field
import matplotlib.pyplot as plt
import io
import csv
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib
import importlib
import os
import functools
import itertools
import torch
from losses import get_optimizer
from models.ema import ExponentialMovingAverage

import torch.nn as nn
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_gan as tfgan
import tqdm
import io
import likelihood
import controllable_generation
from utils import restore_checkpoint
# Configure seaborn in a single call: every sns.set(...) call re-applies the
# full theme, so a separate sns.set(style=...) afterwards would put
# font_scale back to its default of 1.
sns.set(font_scale=2, style="whitegrid")

import models
from models import utils as mutils
from models import ncsnv2
from models import ncsnpp
from models import ddpm as ddpm_model
from models import layerspp
from models import layers
from models import normalization
import sampling
from likelihood import get_likelihood_fn
from sde_lib import VESDE, VPSDE, subVPSDE
from sampling import (ReverseDiffusionPredictor, 
                      LangevinCorrector, 
                      EulerMaruyamaPredictor, 
                      AncestralSamplingPredictor, 
                      NoneCorrector, 
                      NonePredictor,
                      AnnealedLangevinDynamics)
import datasets
import wrapper
from fid_utils import get_fid

2022-10-18 21:10:11.036232: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0





In [2]:
# @title Load the score-based model
# Select the SDE family, load its config and checkpoint path, build the score
# model and restore the EMA weights into it.
sde = 'VESDE' #@param ['VESDE', 'VPSDE', 'subVPSDE'] {"type": "string"}
sde_name = sde.lower()
if sde_name == 'vesde':
  from configs.ve import cifar10_ncsnpp_continuous as configs
  ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
  config = configs.get_config()
  sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
  sampling_eps = 1e-5
elif sde_name == 'vpsde':
  from configs.vp import cifar10_ddpmpp_continuous as configs
  ckpt_filename = "exp/vp/cifar10_ddpmpp_continuous/checkpoint_8.pth"
  config = configs.get_config()
  sde = VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
  sampling_eps = 1e-3
elif sde_name == 'subvpsde':
  from configs.subvp import cifar10_ddpmpp_continuous as configs
  ckpt_filename = "exp/subvp/cifar10_ddpmpp_continuous/checkpoint_26.pth"
  config = configs.get_config()
  sde = subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
  sampling_eps = 1e-3
else:
  # Fail here with a clear message instead of a NameError on `config` below.
  raise ValueError(
      "Unknown SDE {!r}; expected one of 'VESDE', 'VPSDE', 'subVPSDE'.".format(sde))

batch_size =   64#@param {"type":"integer"}
config.training.batch_size = batch_size
config.eval.batch_size = batch_size

random_seed = 0 #@param {"type": "integer"}
# Apply the seed so that sampling in the cells below is repeatable.
torch.manual_seed(random_seed)
np.random.seed(random_seed)

sigmas = mutils.get_sigmas(config)
scaler = datasets.get_data_scaler(config)
inverse_scaler = datasets.get_data_inverse_scaler(config)
score_model = mutils.create_model(config)

# restore_checkpoint expects the same state layout that was used for training:
# step counter, optimizer, model and EMA tracker.
optimizer = get_optimizer(config, score_model.parameters())
ema = ExponentialMovingAverage(score_model.parameters(),
                               decay=config.model.ema_rate)
state = dict(step=0, optimizer=optimizer,
             model=score_model, ema=ema)

state = restore_checkpoint(ckpt_filename, state, config.device)
# Overwrite the model parameters with their EMA-averaged values.
ema.copy_to(score_model.parameters())

In [3]:
#@title Visualization code

def image_grid(x):
  """Tile a batch of channel-last images into one square grid image.

  Image size and channel count are read from the notebook-level `config`.

  Args:
    x: array reshapeable to (n, size, size, channels).

  Returns:
    Array of shape (w * size, w * size, channels) with w = floor(sqrt(n)).
    Images are laid out row by row; when n is not a perfect square only the
    first w * w images are used (previously this raised on reshape).
  """
  size = config.data.image_size
  channels = config.data.num_channels
  img = x.reshape(-1, size, size, channels)
  w = int(np.sqrt(img.shape[0]))
  # Keep the largest square prefix so the grid reshape below always succeeds.
  img = img[:w * w]
  # (row, col, h, w, c) -> (row, h, col, w, c) so rows of pixels from adjacent
  # tiles become contiguous before flattening into the final grid.
  img = img.reshape((w, w, size, size, channels)).transpose((0, 2, 1, 3, 4)).reshape((w * size, w * size, channels))
  return img

def show_samples(x):
  """Display a batch of channel-first torch images as one grid figure.

  Args:
    x: torch tensor indexed as (batch, channels, height, width).
  """
  # Move channels last and hand a NumPy copy to the grid builder.
  channels_last = x.permute(0, 2, 3, 1)
  host_array = channels_last.detach().cpu().numpy()
  grid = image_grid(host_array)

  plt.figure(figsize=(8,8))
  plt.axis('off')
  plt.imshow(grid)
  plt.show()


In [5]:
#@title PC sampling
# Build a predictor-corrector sampler for the loaded model and compute FID.
img_size = config.data.image_size
channels = config.data.num_channels
shape = (batch_size, channels, img_size, img_size)
predictor = ReverseDiffusionPredictor #@param ["EulerMaruyamaPredictor", "AncestralSamplingPredictor", "ReverseDiffusionPredictor", "None"] {"type": "raw"}
corrector = LangevinCorrector #@param ["LangevinCorrector", "AnnealedLangevinDynamics", "None"] {"type": "raw"}
snr = 0.16 #@param {"type": "number"}
n_steps =  1#@param {"type": "integer"}
probability_flow = False #@param {"type": "boolean"}
sampling_fn = sampling.get_pc_sampler(sde, shape, predictor, corrector,
                                      inverse_scaler, snr, n_steps=n_steps,
                                      probability_flow=probability_flow,
                                      continuous=config.training.continuous,
                                      eps=sampling_eps, device=config.device)
# x, n = sampling_fn(score_model)
# show_samples(x)

# The FID computation runs a TensorFlow Inception graph on the same GPU that
# PyTorch is already using. By default TensorFlow tries to reserve nearly all
# GPU memory up front, which commonly shows up as
# "Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR".
# Ask TensorFlow to allocate GPU memory on demand instead.
for _gpu in tf.config.list_physical_devices('GPU'):
  try:
    tf.config.experimental.set_memory_growth(_gpu, True)
  except RuntimeError as err:
    # Memory growth can only be changed before TensorFlow initialises its
    # GPUs; if that already happened, restart the kernel to apply it.
    print("Could not enable TF memory growth for {}: {}".format(_gpu.name, err))

# NOTE(review): `shape` above was built with `batch_size`, while the eval batch
# size is set to 32 here -- confirm which one get_fid is expected to use.
config.eval.batch_size = 32
get_fid(config, sampling_fn, score_model, eval_dir='assets/stats', job_name='_'.join(ckpt_filename.split('/')[:-1]))

2022-10-18 21:11:34.564823: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2022-10-18 21:11:34.565050: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1
2022-10-18 21:11:34.567121: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 0 with properties: 
pciBusID: 0000:01:00.0 name: NVIDIA GeForce RTX 3090 computeCapability: 8.6
coreClock: 1.695GHz coreCount: 82 deviceMemorySize: 23.70GiB deviceMemoryBandwidth: 871.81GiB/s
2022-10-18 21:11:34.568755: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 1 with properties: 
pciBusID: 0000:25:00.0 name: NVIDIA GeForce RTX 3090 computeCapability: 8.6
coreClock: 1.695GHz coreCount: 82 deviceMemorySize: 23.70GiB deviceMemoryBandwidth: 871.81GiB/s
2022-10-18 21:11:34.568818: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11

num of all generations: 50000 eval_batch_size: 64
sampling -- round: 0


2022-10-18 21:14:45.564794: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2022-10-18 21:14:45.617105: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2300000000 Hz
2022-10-18 21:14:46.594582: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.11
2022-10-18 21:14:47.284792: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublasLt.so.11
2022-10-18 21:14:47.291098: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudnn.so.8
2022-10-18 21:14:47.356700: E tensorflow/stream_executor/cuda/cuda_dnn.cc:336] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2022-10-18 21:14:47.368810: E tensorflow/stream_executor/cuda/cuda_dnn.cc:336] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR


UnknownError: 3 root error(s) found.
  (0) Unknown:  Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.
	 [[{{node PartitionedCall_1/PartitionedCall/PartitionedCall/inception/conv}}]]
  (1) Unknown:  Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.
	 [[{{node PartitionedCall_1/PartitionedCall/PartitionedCall/inception/conv}}]]
	 [[PartitionedCall_1/PartitionedCall/PartitionedCall/inception/MatMul/_8]]
  (2) Unknown:  Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.
	 [[{{node PartitionedCall_1/PartitionedCall/PartitionedCall/inception/conv}}]]
	 [[split/_2]]
0 successful operations.
0 derived errors ignored. [Op:__inference_run_inception_distributed_1645]

Function call stack:
run_inception_distributed -> run_inception_distributed -> run_inception_distributed


In [4]:
# PC sampler that goes through the score wrapper, followed by FID evaluation.
# Relies on `shape`, `predictor`, `corrector`, `snr`, `n_steps` and
# `probability_flow` from the "PC sampling" cell, so run that cell first.

# Derive the mean-score file from the loaded checkpoint so it always matches
# the selected SDE/model. For the VE checkpoint this yields
# 'mean_scores/exp_ve_cifar10_ncsnpp_continuous_checkpoint_24.pth'.
score_mean_path = 'mean_scores/' + ckpt_filename.replace('/', '_')

sampling_fn = sampling.get_pc_sampler(sde, shape, predictor, corrector,
                                      inverse_scaler, snr, n_steps=n_steps,
                                      probability_flow=probability_flow,
                                      continuous=config.training.continuous,
                                      eps=sampling_eps, device=config.device, 
                                      use_wrapper=True,
                                      calibration=False,
                                      score_mean=score_mean_path)

# x, n = sampling_fn(score_model)
# show_samples(x)
get_fid(config, sampling_fn, score_model, eval_dir='assets/stats', job_name='our')

NameError: name 'shape' is not defined

In [None]:
# plot the momentum of estimated scores
#
# For every timestep this accumulates, over the dataset, the sum of the model
# scores and the sum of their squared norms. After each batch it saves the
# running mean score and plots E||s||^2, ||E[s]||^2 and the square root of
# their ratio. Uses `channels` and `img_size` from the "PC sampling" cell.

score_fn = mutils.get_score_fn(sde, score_model, train=False, continuous=config.training.continuous)
n_estimates = 1
# NOTE: this changes the shared config; later cells that read
# config.eval.batch_size will see 1024 unless they reset it.
config.eval.batch_size = 1024

# Build data iterators
train_ds, _, _ = datasets.get_dataset(config,
                                      uniform_dequantization=config.data.uniform_dequantization,
                                      evaluation=True)
train_iter = iter(train_ds)

# torch.save does not create missing directories.
os.makedirs("mean_scores", exist_ok=True)

with torch.no_grad():
    timesteps = torch.linspace(sde.T, sampling_eps, sde.N, device=config.device)

    score_sum = torch.zeros(sde.N, channels, img_size, img_size, device=config.device)
    score_normsqr_sum = torch.zeros(sde.N, device=config.device)
    n_data = 0

    for batch in train_iter:
        # Dataset batches are channel-last TF tensors; convert to
        # channel-first torch tensors and apply the model's data scaling.
        x = torch.from_numpy(batch['image']._numpy()).to(config.device).float()
        x = x.permute(0, 3, 1, 2)
        x = scaler(x)

        for i in range(sde.N):
            t = timesteps[i]
            vec_t = torch.ones(x.shape[0], device=t.device) * t
            mean, std = sde.marginal_prob(x, vec_t)

            for _ in range(n_estimates):
                # Draw a fresh perturbation of the batch at time t.
                perturbed_data = mean + std[:, None, None, None] * torch.randn_like(x)
                score = score_fn(perturbed_data, vec_t)
                score_sum[i] += score.sum(0)
                score_normsqr_sum[i] += (score.flatten(1).norm(dim=1) ** 2).sum(0)

        n_data += n_estimates * x.shape[0]
        print("data points processed:", n_data // n_estimates)

        score_mean = score_sum / n_data
        score_mean_normsqr = score_mean.flatten(1).norm(dim=1) ** 2
        score_normsqr_mean = score_normsqr_sum / n_data
        ratio = (score_mean_normsqr / score_normsqr_mean).sqrt()
        # Snapshot of the running mean, prefixed with the number of data points
        # seen so far. NOTE(review): the cells that load
        # 'mean_scores/<checkpoint name>' use no such prefix -- the chosen
        # snapshot presumably has to be renamed by hand; confirm.
        torch.save(score_mean.cpu(), "mean_scores/{}_".format(n_data // n_estimates) + ckpt_filename.replace("/", "_"))

        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 4))
        ax1.plot(range(sde.N), score_normsqr_mean.cpu().numpy())
        ax1.set_title("mean of ||score||^2")
        ax2.plot(range(sde.N), score_mean_normsqr.cpu().numpy())
        ax2.set_title("||mean score||^2")
        ax3.plot(range(sde.N), ratio.cpu().numpy())
        ax3.set_title("sqrt of ratio")
        plt.show()
        # One figure is created per batch; release it so they do not pile up.
        plt.close(fig)


In [None]:
### a baseline
# Reverse-diffusion sampling with the plain score function: one discretised
# reverse-SDE step per timestep, no corrector.
batch_size = 64
img_size = config.data.image_size
channels = config.data.num_channels
shape = (batch_size, channels, img_size, img_size)
probability_flow = False

score_fn = mutils.get_score_fn(sde, score_model, train=False, continuous=config.training.continuous)
rsde = sde.reverse(score_fn, probability_flow)

with torch.no_grad():
    # Start from a draw of the SDE prior and walk time from T down to eps.
    x = sde.prior_sampling(shape).to(config.device)
    timesteps = torch.linspace(sde.T, sampling_eps, sde.N, device=config.device)

    for t in timesteps:
        # Broadcast the scalar time to one entry per sample.
        vec_t = t * torch.ones(shape[0], device=t.device)

        drift, diffusion = rsde.discretize(x, vec_t)
        noise = torch.randn_like(x)
        x_mean = x - drift
        x = x_mean + diffusion[:, None, None, None] * noise

    # Use the noise-free mean of the last step as the final sample.
    x = inverse_scaler(x_mean)

show_samples(x)

In [None]:
# Same reverse-diffusion loop as the baseline cell, but with the score function
# wrapped by wrapper.score_fn_wrapper and a precomputed mean score.
# Relies on `score_fn`, `shape` and `probability_flow` from the baseline cell.
timesteps = torch.linspace(sde.T, sampling_eps, sde.N, device=config.device)

# Derive the mean-score file from the loaded checkpoint so it always matches
# the selected SDE/model. For the VE checkpoint this yields
# 'mean_scores/exp_ve_cifar10_ncsnpp_continuous_checkpoint_24.pth'.
score_mean_path = 'mean_scores/' + ckpt_filename.replace('/', '_')

new_score_fn = wrapper.score_fn_wrapper(sde, score_fn, config.training.continuous, score_mean=score_mean_path, device=timesteps.device)
# `functools` is imported in the first cell; bind the time grid once here.
new_score_fn = functools.partial(new_score_fn, timesteps=timesteps)
rsde = sde.reverse(new_score_fn, probability_flow)

with torch.no_grad():
    # Initial sample
    x = sde.prior_sampling(shape).to(config.device)
    
    for i in range(sde.N):
        t = timesteps[i]
        vec_t = torch.ones(shape[0], device=t.device) * t

        f, G = rsde.discretize(x, vec_t)
        z = torch.randn_like(x)
        x_mean = x - f
        x = x_mean + G[:, None, None, None] * z

    # Use the noise-free mean of the last step as the final sample.
    x = inverse_scaler(x_mean)
    
show_samples(x)
