In [None]:
#####################
# ## COLAB SETUP ## #
#####################

# set to True if running on colab to connect google drive, otherwise set to False
USE_COLAB = True
# set to True if using TPU connected to google cloud project
USE_TPU = True
# project name for TPU connection
PROJECT_NAME=""

if USE_COLAB:
  from google.colab import drive
  drive.mount('/content/drive')

  if USE_TPU:
    # authenticate google cloud credentials
    from google.colab import auth
    auth.authenticate_user()
    # give colab access to project
    !gcloud config set project $PROJECT_NAME

In [None]:
if USE_COLAB:
  !pip install -q tfds-nightly
  !pip install tensorflow_addons
  !git clone https://github.com/point0bar1/ebm-life-cycle

In [3]:
####################
# ## PARAMETERS ## #
####################

# Experiment configuration. The dict is passed to the project helpers
# (init_strategy, initialize_nets_and_optim, initialize_data) and to save_samples.
# NOTE: 'device_type' reads USE_TPU, so the Colab setup cell has to run first.
config = {
  # paths for connecting to cloud storage
  # output root is os.path.join(exp_dir, exp_folder); results go to its
  # 'images' and 'numpy_out' subfolders
  "exp_folder": 'fid_out/',  # folder name for results
  "exp_dir": '',  # location for saving files

  # device type ('tpu' or 'gpu' or 'cpu')
  # only 'tpu' or 'gpu' is ever selected here; set 'cpu' by hand if needed
  "device_type": 'tpu' if USE_TPU else 'gpu',

  # exp params
  # 'exp_type' and 'split' are not read in this notebook's own code
  # (presumably consumed by the imported project helpers -- TODO confirm)
  "exp_type": "folder",
  # number of batches saved per image bank: 520 * 96 = 49,920 images
  "num_fid_rounds": 520,
  # global batch size; divided by the number of replicas to get the per-replica size
  "batch_size": 96,
  "image_dims": [128, 128, 3], # cifar10: 32x32, celeb_a 64x64, imagenet: 128x128
  "split": "train",

  # data type and augmentation parameters
  # (presumably read by initialize_data -- TODO confirm)
  "data_type": 'imagenet2012', # cifar10, celeb_a, imagenet2012
  "random_crop": False,

  # ebm network
  # (presumably read by initialize_nets_and_optim -- TODO confirm)
  "net_type": 'ebm_sngan',
  "ebm_weights": "",  # path to EBM weights

  # langevin sampling parameters
  # each step: x <- x - (epsilon^2 / 2) * grad(sum(energy) / mcmc_temp) + epsilon * noise
  "mcmc_steps": 320,  # number of langevin updates per batch (0 skips sampling)
  "epsilon": 3e-3,  # langevin step size
  "mcmc_init": "coop",  # set to "coop" for generator or "data" for longrun from data
  "mcmc_temp": 1e-7,  # summed energy is divided by this value before differentiating
  # clipping parameters
  # when enabled, the per-image gradient norm is clipped to
  # max_langevin_norm / (epsilon^2 / 2), which caps the gradient part of one step
  # at max_langevin_norm
  "clip_langevin_grad": False,
  "max_langevin_norm": 0.25,

  # generator network
  # (presumably read by initialize_nets_and_optim -- TODO confirm)
  "gen_type": "gen_sngan",
  "z_sz": 128,
  "gen_weights": ""  # path to generator net weights
}

In [None]:
# save data images and tf2 model samples as uint8 numpy arrays (.npy) so the original fid code can be used for evaluation

import os
import sys
from datetime import datetime
import pickle
from tqdm import tqdm
import importlib
from pathlib import Path

import numpy as np
from PIL import Image

import tensorflow as tf
import tensorflow_datasets as tfds

sys.path.insert(0, '/content/ebm-life-cycle')
from init import init_strategy, initialize_nets_and_optim, initialize_data
from data import get_dataset
from utils import setup_exp, plot_ims

import argparse


def save_samples(strategy, config, ebm, gen=None, train_iterator=None, save_str='samples.pdf'):
  """Save paired batches of data images and Langevin samples for FID evaluation.

  Runs config['num_fid_rounds'] rounds. Each round draws one batch from
  train_iterator, builds an initial MCMC state, runs Langevin updates on the EBM
  energy, and appends both batches as uint8 arrays to
  <exp_dir>/<exp_folder>/numpy_out/images1.npy (data) and
  <exp_dir>/<exp_folder>/numpy_out/images2.npy (samples). In the first round the
  final samples, the initial states and the data batch are also passed to plot_ims
  with paths under <exp_dir>/<exp_folder>/images/.

  Args:
    strategy: distribution strategy; must provide num_replicas_in_sync, run,
      gather and experimental_distribute_values_from_function.
    config: dict with the keys 'batch_size', 'image_dims', 'num_fid_rounds',
      'mcmc_init', 'mcmc_steps', 'mcmc_temp', 'epsilon', 'clip_langevin_grad',
      'max_langevin_norm', 'exp_dir' and 'exp_folder'.
    ebm: energy network, called as ebm(images, training=False).
    gen: for mcmc_init == 'coop', a generator that is callable on latent codes and
      provides generate_latent_z(batch_size); for mcmc_init == 'data', an iterator
      whose batches are used directly as initial states.
    train_iterator: iterator yielding the data batches written to images1.npy.
    save_str: base file name used for the three first-round plots.

  Raises:
    ValueError: if config['mcmc_init'] is neither 'data' nor 'coop'.

  Note:
    The .npy files are opened in append mode: every round adds one more array
    record to the same file, and a re-run adds to whatever is already there.
  """
  # shapes and step scale are computed here so the traced update below does not
  # depend on a notebook-level global being defined before this function is called
  per_replica_batch_size = config['batch_size'] // strategy.num_replicas_in_sync
  noise_shape = [per_replica_batch_size] + list(config['image_dims'])
  step_scale = (config['epsilon'] ** 2) / 2

  @tf.function
  def langevin_update(states_in):
    if config['mcmc_init'] == 'coop':
      # states are latent codes: run the generator inside the replica function
      # (original note: re-draw samples to avoid duplication on tpu device)
      images_samp = tf.identity(gen(states_in))
    else:
      # states are already images
      images_samp = tf.identity(states_in)

    # initial samples for visual check
    images_samp_init = tf.identity(images_samp)

    # langevin updates
    if config['mcmc_steps'] > 0:
      for _ in tf.range(int(config['mcmc_steps'])):
        with tf.GradientTape() as tape:
          tape.watch(images_samp)
          energy = tf.math.reduce_sum(ebm(images_samp, training=False)) / config['mcmc_temp']
        grads = tape.gradient(energy, images_samp)
        # clip gradient norm (set to large value that won't interfere with standard dynamics)
        if config['clip_langevin_grad']:
          grads = tf.clip_by_norm(grads, config['max_langevin_norm'] / step_scale, axes=[1, 2, 3])

        # update images: gradient step followed by gaussian noise
        images_samp -= step_scale * grads
        images_samp += config['epsilon'] * tf.random.normal(shape=noise_shape)

    return images_samp, images_samp_init

  def append_uint8(path, images):
    # gather all replicas, map [-1, 1] floats to [0, 255] bytes and append to the file
    with path.open('ab') as f:
      images_rescale = np.rint(255 * (np.clip(strategy.gather(images, 0).numpy(), -1, 1) + 1) / 2)
      np.save(f, images_rescale.astype(np.uint8))

  # output locations do not change between rounds
  out_dir = os.path.join(config['exp_dir'], config['exp_folder'])
  data_path = Path(os.path.join(out_dir, 'numpy_out/images1.npy'))
  sample_path = Path(os.path.join(out_dir, 'numpy_out/images2.npy'))

  for i in range(config['num_fid_rounds']):
    print('Batch {} of {}'.format(i+1, config['num_fid_rounds']))

    # data images
    images_data = next(train_iterator)

    # generate samples from model
    if config['mcmc_init'] == 'data':
      sample_init = next(gen)
    elif config['mcmc_init'] == 'coop':
      # draw latent codes for the full batch, then hand each replica its own slice
      z_init_tf = gen.generate_latent_z(config['batch_size'])
      def get_z_init(ctx):
        rep_id = ctx.replica_id_in_sync_group
        return z_init_tf[(rep_id*per_replica_batch_size):((rep_id+1)*per_replica_batch_size)]
      sample_init = strategy.experimental_distribute_values_from_function(get_z_init)
    else:
      raise ValueError('Invalid mcmc_init')

    # run langevin updates on initial state
    images_sample, images_sample_init = strategy.run(langevin_update, args=(sample_init,))

    # visualize initial and final samples for first batch
    if i == 0:
      plot_ims(os.path.join(out_dir, 'images/' + save_str),
               strategy.gather(images_sample, 0))
      plot_ims(os.path.join(out_dir, 'images/init_' + save_str),
               strategy.gather(images_sample_init, 0))
      plot_ims(os.path.join(out_dir, 'images/data_' + save_str),
               strategy.gather(images_data, 0))

    # record batch images
    append_uint8(data_path, images_data)
    append_uint8(sample_path, images_sample)


###############
# ## SETUP ## #
###############

# setup folders, save code, set seed and get device
# arguments: the output root, the subfolder names that save_samples writes into
# ('images', 'numpy_out'), and the code files to keep with the results
# NOTE(review): setup_exp is a project helper; the behavior above is taken from the
# original comment and the argument names, not verified here
setup_exp(os.path.join(config['exp_dir'], config['exp_folder']), 
          ['images', 'numpy_out'],
          [os.path.join(config['exp_dir'], code_file) for 
           code_file in ['fid_save_ims.ipynb', 'nets.py', 'utils.py', 'data.py', 'init.py']],
          gs_path=None, save_to_cloud=False)

# build the distribution strategy; the TPU branch needs the COLAB_TPU_ADDR
# environment variable and raises KeyError when it is missing
if config['device_type'] == 'tpu':
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
  tf.config.experimental_connect_to_cluster(resolver)
  # This is the TPU initialization code that has to be at the beginning.
  tf.tpu.experimental.initialize_tpu_system(resolver)
  print("All devices: ", tf.config.list_logical_devices('TPU'))
  # Set up TPU Distribution
  strategy = tf.distribute.TPUStrategy(resolver)
else:
  # any other device type is delegated to the project helper
  strategy = init_strategy(config)


##################################################
# ## INITIALIZE NETS, DATA, PERSISTENT STATES ## #
##################################################

# load nets and optim
# only the first and third return values (ebm, gen) are kept; both are frozen
# because this notebook only evaluates them
ebm, _, gen, _ = initialize_nets_and_optim(config, strategy)
ebm.trainable = False
if gen is not None:
  gen.trainable = False

# test deterministic output of ebm
# the same first image is evaluated inside a batch of 3 and inside a batch of 2;
# the printed value is the largest absolute difference between the two outputs
with strategy.scope():
  state_test = tf.random.normal(shape=[3]+config['image_dims'])
  ebm_out_1 = ebm(state_test)
  ebm_out_2 = ebm(state_test[0:2])
ebm_out_1 = strategy.gather(ebm_out_1, axis=0)
ebm_out_2 = strategy.gather(ebm_out_2, axis=0)
print('EBM Determinism Test (should be close to 0): ', 
      tf.math.reduce_max(tf.math.abs(ebm_out_1[0] - ebm_out_2[0])))

# test deterministic output of gen
# same check as above, using 3 latent codes and the first 2 of them
if gen is not None:
  with strategy.scope():
    gen_z = gen.generate_latent_z(3)
    gen_out_1 = gen(gen_z)
    gen_out_2 = gen(gen_z[0:2])
  gen_out_1 = strategy.gather(gen_out_1, axis=0)
  gen_out_2 = strategy.gather(gen_out_2, axis=0)
  print('Gen Determinism Test (should be close to 0): ', 
        tf.math.reduce_max(tf.math.abs(gen_out_1[0] - gen_out_2[0])))

# generator for data
# only the first return value of initialize_data is used
train_iterator, _, _ = initialize_data(config, strategy)
if config['mcmc_init'] == 'data':
  # generator for data mcmc init
  # NOTE: this rebinds `gen` from the generator network to a second data iterator
  # obtained from a separate initialize_data call
  gen, _, _ = initialize_data(config, strategy)

# Calculate per replica batch size, and distribute the datasets
# batch_size is the global batch size rounded down to a multiple of the replica count
per_replica_batch_size = config['batch_size'] // strategy.num_replicas_in_sync
batch_size = per_replica_batch_size * strategy.num_replicas_in_sync
# shape of one replica's image batch: [per-replica batch] + image dims
tpu_tensor_size = [per_replica_batch_size] + config['image_dims']

# save bank of data and model samples as two np arrays
save_samples(strategy, config, ebm, gen, train_iterator)