In [None]:
# Automatically reload edited imported modules (e.g. the local pico.* package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import division

import os
import pickle
import uuid
import time
from copy import deepcopy

from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import numpy as np
from brokenaxes import brokenaxes

from pico.gan import PicoGAN
from pico.user_models import ConvPolicy
from pico.discrim_models import ConvDiscrim
from pico.encoder_models import NVAEEncoder
from pico.envs import CelebAEnv
from pico.compression_models import Masker, ConvCompressor
from pico import utils
from pico import viz

In [None]:
from matplotlib import pyplot as plt
import matplotlib as mpl
%matplotlib inline

In [None]:
# Session shared by every model built below (ConvPolicy, ConvDiscrim, ConvCompressor).
sess = utils.make_tf_session(gpu_mode=True)

In [None]:
# Selects which binary label the dataset uses (presumably a CelebA attribute column).
# Switch tasks by (un)commenting; it also names the output directory below.
#task_idx = 15 # eyeglasses
task_idx = 35 # hat

In [None]:
# Per-task output locations: checkpoints/caches in data_dir, plots in fig_dir.
data_dir = os.path.join(utils.celeba_data_dir, str(task_idx))
fig_dir = os.path.join(data_dir, 'figures')
for required_dir in (data_dir, fig_dir):
  if not os.path.exists(required_dir):
    os.makedirs(required_dir)

In [None]:
# Dict consumed below via the keys 'imgs', 'labels', 'n_classes' and 'img_shape'.
dataset = utils.make_celeba_dataset(use_cache=True, task_idx=task_idx)

In [None]:
# The user's action space is the label set; observations are raw images.
n_act_dims = dataset['n_classes']
img_shape = dataset['img_shape']
# Number of scalars in one image: the product of all dimensions in img_shape.
flat_img_size = 1
for dim_size in img_shape:
  flat_img_size = flat_img_size * dim_size
img_shape

In [None]:
# Encoder used throughout; encoder.compress(img, mask, temp, bn_imgs=...) returns (comp_img, kldiv).
encoder = NVAEEncoder()

In [None]:
# Collect bn_n_batches batches of bn_batch_size images for the encoder's
# `bn_imgs` argument (presumably batch-norm calibration data -- confirm).
# Each batch is channel-first, scaled to [0, 1], converted to torch, on CUDA.
# NOTE: requires at least bn_n_batches * bn_batch_size images in the dataset.
bn_n_batches = 32
bn_batch_size = 32
bn_imgs = []
for batch_start in range(0, bn_n_batches * bn_batch_size, bn_batch_size):
  batch = dataset['imgs'][batch_start:batch_start + bn_batch_size]
  batch = utils.front_img_ch(batch).astype(float) / 255.
  bn_imgs.append(utils.numpy_to_torch(batch).to('cuda'))

In [None]:
# Discards the batches built in the previous cell: every encoder.compress call
# below (including apply_mask) receives bn_imgs=None. Skip this cell to pass them.
bn_imgs = None

In [None]:
# Indices of positive examples (label == 1), then the positive-class fraction.
img_idxes = [idx for idx, label in enumerate(dataset['labels']) if label == 1]
len(img_idxes) / len(dataset['labels'])

In [None]:
# Sample image for sanity checks: the second positive example.
img = dataset['imgs'][img_idxes[1]]
plt.imshow(img)
plt.show()

In [None]:
# Temperature passed to encoder.compress (the apply_mask cell re-sets the same value).
temp = 0.5

In [None]:
# 8x8 mask grid flattened to 64 entries; all zeros by default. Uncomment the
# row slice to set rows 3-4 (how compress interprets 0/1 is not shown here).
mask = np.zeros((8, 8))
#mask[3:5, :] = 1
mask = mask.ravel()

In [None]:
# Compress the sample image once and report (kldiv, wall-clock seconds).
tic = time.time()
comp_img, kldiv = encoder.compress(img, mask, temp, bn_imgs=bn_imgs)
elapsed = time.time() - tic
kldiv, elapsed

In [None]:
# Visualize the compressed reconstruction from the previous cell.
plt.imshow(comp_img)
plt.show()

In [None]:
# Supervised data for the simulated user: images -> one-hot labels.
data = dict(
  obses=dataset['imgs'],
  actions=np.array([utils.onehot_encode(int(label), n_act_dims) for label in dataset['labels']])
)

In [None]:
# Adds the split indices used later (e.g. data['train_idxes']); 99.9% train.
data = utils.split_user_data(data, train_frac=0.999)

In [None]:
# Cache file for the split user data, so repeated runs reuse the same split.
sim_user_model_data_path = os.path.join(data_dir, 'sim_user_model_data.pkl')

In [None]:
# Reuse a previously cached split if one exists; otherwise keep the freshly
# split `data` (the next cell writes it). The guard lets Restart & Run All
# succeed on the first run. Only unpickle files this notebook wrote itself,
# because pickle.load can execute arbitrary code.
if os.path.exists(sim_user_model_data_path):
  with open(sim_user_model_data_path, 'rb') as f:
    data = pickle.load(f)

In [None]:
# Persist the current split so later runs can reload it.
with open(sim_user_model_data_path, 'wb') as out_file:
  pickle.dump(data, out_file, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Optimizer settings for fitting the simulated user model.
sim_user_model_train_kwargs = dict(
  iterations=2000,
  ftol=1e-6,
  learning_rate=5e-4,
  batch_size=32,
  val_update_freq=200,
  verbose=True
)

In [None]:
# Simulated user: a conv policy over images with n_act_dims actions, trained on
# data['obses'] -> data['actions']. Checkpoints live under data_dir; the uuid
# scope is commented out, so no explicit scope is passed.
sim_user_model = ConvPolicy(
  sess, 
  n_act_dims=n_act_dims, 
  n_obs_dims=img_shape,
  n_layers=2,
  layer_size=256,
  #scope=str(uuid.uuid4()),
  scope_file=os.path.join(data_dir, 'sim_user_model_scope.pkl'),
  tf_file=os.path.join(data_dir, 'sim_user_model.tf')
)

In [None]:
# Fit the simulated user (either train + save, or skip both and use load below).
sim_user_model.train(data, **sim_user_model_train_kwargs)

In [None]:
# Write the trained simulated-user checkpoint to its tf_file/scope_file.
sim_user_model.save()

In [None]:
# Restore the simulated-user checkpoint instead of retraining.
sim_user_model.load()

In [None]:
# Training split only: the environment sees just the training observations/actions.
idxes = data['train_idxes']
env_data = dict((key, data[key][idxes]) for key in ('obses', 'actions'))

In [None]:
temp = 0.5
def apply_mask(real_obs, obs_mask):
  """Compress each observation with its own mask.

  Args:
    real_obs: batch of observations indexed along axis 0 (needs `.shape`).
    obs_mask: per-observation masks; obs_mask[i] pairs with real_obs[i].

  Returns:
    (np.array of compressed images, np.array of KL divergences), one entry per
    observation, in input order. Uses the module-level encoder, temp and bn_imgs.
  """
  results = [
    encoder.compress(real_obs[i], obs_mask[i], temp, bn_imgs=bn_imgs)
    for i in range(real_obs.shape[0])
  ]
  comp_imgs = [comp_img for comp_img, _ in results]
  kldivs = [kldiv for _, kldiv in results]
  return np.array(comp_imgs), np.array(kldivs)

In [None]:
n_act_blocks = 8
train_mask_limits = (0.25, 0.25)
def make_env(val_mode=True):
  """Build a CelebAEnv with a fresh reward model and discriminator.

  Both ConvDiscrim instances get a new random uuid scope on every call, with
  checkpoint paths under data_dir.

  Args:
    val_mode: if falsy, the env uses train_mask_limits; otherwise it uses
      (None, None).

  Returns:
    The constructed CelebAEnv.
  """
  shared_net_kwargs = dict(n_obs_dims=img_shape, n_layers=2, layer_size=256)
  # Reward model: outputs over the n_act_blocks mask blocks.
  rew_mod = ConvDiscrim(
    sess,
    n_act_dims=n_act_blocks,
    scope=str(uuid.uuid4()),
    scope_file=os.path.join(data_dir, 'rew_mod_scope.pkl'),
    tf_file=os.path.join(data_dir, 'rew_mod.tf'),
    **shared_net_kwargs
  )
  # Discriminator: outputs over the user's n_act_dims actions, struct=True.
  discrim = ConvDiscrim(
    sess,
    n_act_dims=n_act_dims,
    struct=True,
    scope=str(uuid.uuid4()),
    scope_file=os.path.join(data_dir, 'discrim_scope.pkl'),
    tf_file=os.path.join(data_dir, 'discrim.tf'),
    **shared_net_kwargs
  )
  mask_limits = (None, None) if val_mode else train_mask_limits
  return CelebAEnv(
    sim_user_model,
    encoder,
    env_data,
    apply_mask,
    rew_mod,
    discrim,
    val_mode=val_mode,
    n_act_blocks=n_act_blocks,
    mask_limits=mask_limits
  )

In [None]:
def make_model(env, model_path):
  """Construct a ConvCompressor whose checkpoint files live in model_path.

  Creates model_path if it does not exist yet. The compressor shares
  env.rew_mod, acts over n_act_blocks mask blocks, and is sized for the
  user's n_act_dims actions.

  Args:
    env: environment exposing `rew_mod` (as built by make_env).
    model_path: directory for 'scope.pkl' and 'model.tf'.

  Returns:
    The new ConvCompressor (weights are not loaded here).
  """
  if not os.path.exists(model_path):
    os.makedirs(model_path)
  scope_file = os.path.join(model_path, 'scope.pkl')
  tf_file = os.path.join(model_path, 'model.tf')
  return ConvCompressor(
    sess,
    rew_mod=env.rew_mod,
    n_obs_dims=img_shape,
    n_act_dims=n_act_blocks,
    n_user_act_dims=n_act_dims,
    n_layers=2,
    layer_size=256,
    #scope=str(uuid.uuid4()),
    scope_file=scope_file,
    tf_file=tf_file
  )

In [None]:
# Optimizer settings for the compression model inside each GAN iteration.
model_train_kwargs = dict(
  iterations=10000,
  ftol=1e-6,
  learning_rate=1e-3,
  batch_size=32,
  val_update_freq=1000,
  verbose=True
)

# Number of outer PicoGAN training iterations.
n_iter = 1

In [None]:
# Reward model and discriminator currently use identical optimizer settings;
# discrim_train_kwargs is a separate copy so either can be tuned independently.
rew_mod_train_kwargs = dict(
  iterations=2000,
  ftol=1e-6,
  learning_rate=1e-3,
  batch_size=32,
  val_update_freq=100,
  verbose=True
)

discrim_train_kwargs = dict(rew_mod_train_kwargs)

rew_mod_update_freq = 1000

In [None]:
def run_gan_training(model_path, using_mae=False):
  """Train a compression model with PicoGAN in a training-mode environment.

  Args:
    model_path: checkpoint directory passed to make_model.
    using_mae: forwarded to PicoGAN.train (True for the perceptual-similarity
      baseline, per the cells below).

  Returns:
    The model returned by PicoGAN.train.
  """
  env = make_env(val_mode=False)
  gan = PicoGAN(make_model(env, model_path), env)
  return gan.train(
    model_train_kwargs,
    verbose=False,
    n_iter=n_iter,
    rew_mod_update_freq=rew_mod_update_freq,
    rew_mod_train_kwargs=rew_mod_train_kwargs,
    discrim_train_kwargs=discrim_train_kwargs,
    discrim_zero_val=0.25,
    using_mae=using_mae
  )

In [None]:
# Checkpoint directories for the PICO model and the MAE (perceptual-similarity) baseline.
model_path = os.path.join(data_dir, 'model_0')
mae_model_path = os.path.join(data_dir, 'mae_model_0')

In [None]:
# Re-create the encoder and run one compression of the sample image
# (presumably to reset encoder state -- confirm). encoder.compress returns
# (comp_img, kldiv); unpack it so comp_img stays an image, as in the timing cell.
encoder = NVAEEncoder()
comp_img, kldiv = encoder.compress(img, mask, temp, bn_imgs=bn_imgs)

In [None]:
# Train the PICO compressor (using_mae=False).
model = run_gan_training(
  model_path=model_path, 
  using_mae=False
)

In [None]:
# Persist the PICO compressor under model_path.
model.save()

In [None]:
# Train the perceptual-similarity (MAE) baseline compressor.
# Fixed: the original was missing the comma between the two keyword
# arguments, which is a SyntaxError.
mae_model = run_gan_training(
  model_path=mae_model_path,
  using_mae=True
)

In [None]:
# Persist the MAE baseline under mae_model_path.
mae_model.save()

In [None]:
# Evaluation environment: val_mode=True, so mask_limits are (None, None).
eval_env = make_env(val_mode=True)

In [None]:
def load_model(model_path):
  """Rebuild a ConvCompressor against eval_env and restore it from model_path.

  Args:
    model_path: checkpoint directory previously written by model.save().

  Returns:
    The compressor after calling .load() on it.
  """
  restored = make_model(eval_env, model_path)
  restored.load()
  return restored

In [None]:
# Restore the PICO compressor for evaluation.
model = load_model(model_path)

In [None]:
# Restore the MAE baseline for evaluation.
mae_model = load_model(mae_model_path)

In [None]:
def local_eval_model(compression_model, verbosity=0):
  """Evaluate a compression model with utils.eval_model on this notebook's state.

  Passes the current module-level data, encoder and sim_user_model
  (looked up at call time).

  Args:
    compression_model: e.g. a Masker wrapping a mask policy.
    verbosity: forwarded to utils.eval_model.

  Returns:
    Whatever utils.eval_model returns (displayed as metrics below).
  """
  positional = (compression_model, data, encoder, sim_user_model)
  return utils.eval_model(*positional, verbosity=verbosity)

In [None]:
# Re-create the encoder before evaluation and run one compression of the
# sample image. encoder.compress returns (comp_img, kldiv); unpack it so
# comp_img stays an image rather than a tuple.
encoder = NVAEEncoder()
comp_img, kldiv = encoder.compress(img, mask, temp=temp, bn_imgs=bn_imgs)

In [None]:
# 2 of 8 blocks = 0.25 (true division via __future__), matching train_mask_limits.
mask_limit = 2/8

In [None]:
def baseline_mask_policy(real_obses):
  """Non-adaptive baseline: uniform random scores per (observation, mask block).

  NOTE(review): draws from the unseeded global np.random, so baseline metrics
  vary between runs.
  """
  return np.random.random((real_obses.shape[0], eval_env.n_act_blocks))
baseline_compression_model = Masker(baseline_mask_policy, eval_env, mask_limit)

In [None]:
# Evaluate the random-mask baseline at mask_limit.
baseline_metrics = local_eval_model(baseline_compression_model, verbosity=20)
baseline_metrics

In [None]:
# Wrap the trained PICO policy in a Masker with the same mask_limit.
learned_mask_policy = model.act
learned_compression_model = Masker(learned_mask_policy, eval_env, mask_limit)

In [None]:
# Evaluate the PICO policy at mask_limit.
learned_metrics = local_eval_model(learned_compression_model, verbosity=20)
learned_metrics

In [None]:
# Wrap the MAE baseline policy in a Masker with the same mask_limit.
mae_mask_policy = mae_model.act
mae_compression_model = Masker(mae_mask_policy, eval_env, mask_limit)

In [None]:
# Evaluate the MAE baseline at mask_limit.
mae_metrics = local_eval_model(mae_compression_model, verbosity=20)
mae_metrics

In [None]:
# Sweep points 0, 1/n_act_blocks, ..., 1 (n_act_blocks + 1 values).
# np.linspace includes both endpoints exactly. np.arange with a fractional
# step can gain or drop the last point through floating-point rounding when
# n_act_blocks changes. For n_act_blocks = 8 the values are unchanged.
mask_limits = np.linspace(0, 1, n_act_blocks + 1)

In [None]:
# Policies compared in the mask-limit sweep, keyed by the names the plot uses.
mask_policy_of_model = dict(
  baseline=baseline_mask_policy,
  mae=mae_mask_policy,
  learned=learned_mask_policy
)

In [None]:
def eval_mask_policy(mask_policy, mask_limit, **kwargs):
  """Evaluate mask_policy on eval_env at one mask limit.

  Args:
    mask_policy: callable wrapped by Masker.
    mask_limit: limit passed to Masker.
    **kwargs: accepted so callers (viz.sweep_mask_limits) may pass extras;
      they are ignored.

  Returns:
    The metrics from local_eval_model (default verbosity).
  """
  return local_eval_model(Masker(mask_policy, eval_env, mask_limit))

In [None]:
# Evaluate every policy at every mask limit via eval_mask_policy; the result is
# indexed below as mets_of_model[model_name][metric_key].
mets_of_model = viz.sweep_mask_limits(
  mask_limits, 
  eval_env,
  mask_policy_of_model,
  eval_mask_policy
)

In [None]:
# Cache file for the sweep results.
mets_path = os.path.join(data_dir, 'mets.pkl')

In [None]:
# Reuse a previously pickled sweep if one exists; otherwise keep the freshly
# computed mets_of_model (the next cell writes it). The guard lets Restart &
# Run All succeed on the first run. Only unpickle files this notebook wrote
# itself, because pickle.load can execute arbitrary code.
if os.path.exists(mets_path):
  with open(mets_path, 'rb') as f:
    mets_of_model = pickle.load(f)

In [None]:
# Persist the sweep so the plot can be regenerated without re-evaluating.
with open(mets_path, 'wb') as out_file:
  pickle.dump(mets_of_model, out_file, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Rate vs. user-action agreement for each compressor, with stderr error bars.
y_key = 'act_acc'
x_key = 'kldiv'
configs = [
  ('learned', 'orange', 'PICO (Ours)'),
  ('baseline', 'gray', 'Non-Adaptive (Baseline)'),
  ('mae', 'red', 'Perceptual Similarity (Baseline)')
]
fig, ax = plt.subplots()
ax.set_title('CelebA Faces')
ax.set_xlabel('Bitrate (Bits)')
ax.set_ylabel("User's Action Agreement")
for model_name, color, label in configs:
  model_mets = mets_of_model[model_name]
  ax.errorbar(
    np.array(model_mets[x_key]),
    model_mets[y_key],
    model_mets['%s_stderr' % y_key],
    color=color,
    marker='o',
    capsize=2,
    label=label
  )
ax.legend(loc='lower right', fontsize=11)
plt.show()