In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import division

from copy import deepcopy
import os
import pickle
import uuid

from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import numpy as np
from brokenaxes import brokenaxes

from pico.gan import PicoGAN
from pico.user_models import MLPPolicy
from pico.discrim_models import MLPDiscrim
from pico.encoder_models import BTCVAEEncoder
from pico.envs import MNISTEnv
from pico.compression_models import Masker, MLPCompressor
from pico import utils
from pico import viz

In [None]:
from matplotlib import pyplot as plt
import matplotlib as mpl
%matplotlib inline

In [None]:
# Session handed to every model constructed below (MLPPolicy, MLPDiscrim, MLPCompressor).
# gpu_mode=False presumably keeps execution on the CPU -- confirm in utils.make_tf_session.
sess = utils.make_tf_session(gpu_mode=False)

In [None]:
# Root directory for every artifact this notebook reads or writes
# (cached data, model checkpoints, metrics).
data_dir = utils.mnist_data_dir
fig_dir = os.path.join(data_dir, 'figures')
# exist_ok=True makes the cell idempotent and removes the window between
# checking for the directory and creating it.
os.makedirs(fig_dir, exist_ok=True)

In [None]:
# Load the MNIST dataset dict; later cells read its 'n_classes', 'img_shape',
# 'feats' and 'labels' entries.
dataset = utils.make_mnist_dataset()

In [None]:
# One user action per digit class.
n_act_dims = dataset['n_classes']

# Append a trailing axis of size 1 to the dataset's image shape.
img_shape = tuple(dataset['img_shape']) + (1,)

# Number of scalars in one image: the product of all axis lengths.
flat_img_size = 1
for axis_len in img_shape:
  flat_img_size = flat_img_size * axis_len

# View the feature rows as images; -1 lets numpy infer the number of rows.
dataset['imgs'] = dataset['feats'].reshape((-1,) + img_shape)
img_shape, flat_img_size

In [None]:
# Encoder used below to map images to latent vectors (encode) and back (decode);
# its latent_dim sizes the observation inputs of every model in this notebook.
encoder = BTCVAEEncoder('mnist')

In [None]:
# Index of the example image inspected in the next few cells.
img_idx = 30000

In [None]:
# Display the selected dataset image with the binary colormap.
img = dataset['imgs'][img_idx]
plt.imshow(img, cmap=mpl.cm.binary)
plt.show()

In [None]:
# Encode a one-image batch (the slice keeps the leading batch axis), then
# display the latent together with its largest and smallest entries.
latent = encoder.encode(dataset['imgs'][img_idx:img_idx+1])
latent, np.max(latent), np.min(latent)

In [None]:
# Overwrite the encoded latent with one standard-normal draw; np.newaxis adds a
# leading batch axis (shape (1, latent_dim), assuming latent_dim is an int).
# NOTE(review): no random seed is set, so the decoded image differs on every run.
latent = np.random.normal(0, 1, encoder.latent_dim)[np.newaxis]

In [None]:
# Decode the current latent batch and display the first decoded image.
img = encoder.decode(latent)[0]
plt.imshow(img, cmap=mpl.cm.binary)
plt.show()

In [None]:
# Observations are the encoder's latents for every image.
latent_obses = encoder.encode(dataset['imgs'])

# Actions are the digit labels, one-hot encoded over n_act_dims classes.
onehot_actions = np.array([
  utils.onehot_encode(int(label), n_act_dims)
  for label in dataset['labels']
])

data = {
  'obses': latent_obses,
  'imgs': dataset['imgs'],
  'actions': onehot_actions
}

In [None]:
# Split the data with a 0.9 training fraction; a later cell reads
# data['train_idxes'] from the result.
# NOTE(review): rebinding `data` makes this cell non-idempotent -- re-running it
# passes the already-split result back into split_user_data.
data = utils.split_user_data(data, train_frac=0.9)

In [None]:
# Pickle file used to cache `data` between sessions (read and written below).
sim_user_model_data_path = os.path.join(data_dir, 'sim_user_model_data.pkl')

In [None]:
# Reuse the cached dataset when one exists. On a fresh checkout the cache has
# not been written yet (the dump cell comes after this one), so fall back to
# the `data` computed by the cells above instead of raising FileNotFoundError.
# NOTE: pickle.load can execute arbitrary code -- only load a cache file that
# this notebook wrote itself.
if os.path.exists(sim_user_model_data_path):
  with open(sim_user_model_data_path, 'rb') as f:
    data = pickle.load(f)
else:
  print('No cached data at %s; keeping the in-memory data' % sim_user_model_data_path)

In [None]:
# Write `data` to the cache file with the highest available pickle protocol.
with open(sim_user_model_data_path, 'wb') as f:
  pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)

In [None]:
# Keyword arguments forwarded verbatim to sim_user_model.train(data, **...).
# NOTE(review): the exact meaning of each key (e.g. 'ftol', 'val_update_freq')
# is defined by MLPPolicy.train -- confirm there before tuning.
sim_user_model_train_kwargs = {
  'iterations': 10000,
  'ftol': 1e-6,
  'learning_rate': 1e-3,
  'batch_size': 32,
  'val_update_freq': 1000,
  'verbose': True
}

In [None]:
# Simulated user: a policy taking encoder latents (n_obs_dims) and producing
# one of n_act_dims actions. The explicit random scope is commented out, so only
# scope_file/tf_file under data_dir are supplied.
# NOTE(review): presumably the scope is then read from / written to scope_file --
# confirm in MLPPolicy.
sim_user_model = MLPPolicy(
  sess, 
  n_act_dims=n_act_dims, 
  n_obs_dims=encoder.latent_dim,
  n_layers=2,
  layer_size=256,
  #scope=str(uuid.uuid4()),
  scope_file=os.path.join(data_dir, 'sim_user_model_scope.pkl'),
  tf_file=os.path.join(data_dir, 'sim_user_model.tf')
)

In [None]:
# Fit the simulated user on the split data using the kwargs defined above.
sim_user_model.train(data, **sim_user_model_train_kwargs)

In [None]:
# Persist the trained simulated user via its save() method.
sim_user_model.save()

In [None]:
# Restore the simulated user via its load() method.
# NOTE(review): presumably lets a new session skip the training cell once a
# save exists -- confirm what load() reads.
sim_user_model.load()

In [None]:
# Restrict the per-example arrays to the training indices; this subset is what
# the environments receive as demonstration data.
idxes = data['train_idxes']
demo_keys = ('obses', 'imgs', 'actions')
demo_data = {key: data[key][idxes] for key in demo_keys}

In [None]:
# Per-dimension mean and standard deviation of the demonstration observations
# (reduced over axis 0); apply_mask passes both to utils.apply_mask.
demo_obses = demo_data['obses']
obs_prior_mean = np.mean(demo_obses, axis=0)
obs_prior_std = np.std(demo_obses, axis=0)

In [None]:
def apply_mask(real_obses, mask, **kwargs):
  """Mask observations using the notebook-level observation prior.

  Thin wrapper that binds `obs_prior_mean` and `obs_prior_std` as the last two
  arguments of `utils.apply_mask`.

  Args:
    real_obses: observations to mask, forwarded unchanged.
    mask: mask to apply, forwarded unchanged.
    **kwargs: accepted for call-site compatibility and ignored.

  Returns:
    Whatever `utils.apply_mask` returns.
  """
  masked_obses = utils.apply_mask(
    real_obses,
    mask,
    obs_prior_mean,
    obs_prior_std
  )
  return masked_obses

In [None]:
# One maskable block per latent dimension.
n_act_blocks = encoder.latent_dim
# Mask limits handed to MNISTEnv when make_env is called with val_mode=False.
train_mask_limits = (0.5, 0.5)
def make_env(val_mode=True):
  """Build an MNISTEnv backed by a newly constructed discriminator and reward model.

  Each call creates two new MLPDiscrim networks, each under its own random
  (uuid4) scope name.

  Args:
    val_mode: forwarded to MNISTEnv. When falsy the env receives
      `train_mask_limits`; otherwise it receives `(None, None)`.

  Returns:
    The constructed MNISTEnv.
  """
  # Settings common to both MLPDiscrim networks.
  shared_net_kwargs = dict(
    n_obs_dims=encoder.latent_dim,
    n_layers=2,
    layer_size=256
  )

  discrim = MLPDiscrim(
    sess,
    n_act_dims=n_act_dims,
    struct=True,
    scope=str(uuid.uuid4()),
    scope_file=os.path.join(data_dir, 'discrim_scope.pkl'),
    tf_file=os.path.join(data_dir, 'discrim.tf'),
    **shared_net_kwargs
  )
  # NOTE(review): confirm MLPDiscrim reads an attribute spelled exactly
  # `noverfit`; nothing in this notebook uses it.
  discrim.noverfit = True

  # The reward model's action space is the set of mask blocks, not user actions.
  rew_mod = MLPDiscrim(
    sess,
    n_act_dims=n_act_blocks,
    scope=str(uuid.uuid4()),
    scope_file=os.path.join(data_dir, 'rew_mod_scope.pkl'),
    tf_file=os.path.join(data_dir, 'rew_mod.tf'),
    **shared_net_kwargs
  )

  mask_limits = (None, None) if val_mode else train_mask_limits

  return MNISTEnv(
    sim_user_model,
    encoder,
    demo_data,
    apply_mask,
    rew_mod,
    discrim,
    val_mode=val_mode,
    n_act_blocks=n_act_blocks,
    mask_limits=mask_limits
  )

In [None]:
def make_model(env, model_path):
  """Construct an MLPCompressor whose scope and checkpoint files live in `model_path`.

  Args:
    env: environment whose `rew_mod` attribute is shared with the compressor.
    model_path: directory for 'scope.pkl' and 'model.tf'; created if missing.

  Returns:
    The constructed (not yet trained or loaded) MLPCompressor.

  Raises:
    FileExistsError: if `model_path` exists but is not a directory.
  """
  # exist_ok=True replaces the exists()-then-makedirs() pair, which could fail
  # if the directory appeared between the check and the creation.
  os.makedirs(model_path, exist_ok=True)
  model = MLPCompressor(
    sess,
    rew_mod=env.rew_mod,
    n_obs_dims=encoder.latent_dim,
    n_act_dims=n_act_blocks,
    n_user_act_dims=n_act_dims,
    n_layers=2,
    layer_size=64,
    # No explicit scope is passed; only the scope file path is supplied.
    #scope=str(uuid.uuid4()),
    scope_file=os.path.join(model_path, 'scope.pkl'),
    tf_file=os.path.join(model_path, 'model.tf')
  )
  return model

In [None]:
# Passed as the first positional argument to PicoGAN.train in run_gan_training.
# NOTE(review): key semantics are defined by the training code -- confirm there.
model_train_kwargs = {
  'iterations': 10000,
  'ftol': 1e-6,
  'learning_rate': 1e-3,
  'batch_size': 32,
  'val_update_freq': 1000,
  'verbose': True
}

# Forwarded to PicoGAN.train as n_iter.
n_iter = 1

In [None]:
# Forwarded to PicoGAN.train as discrim_train_kwargs.
discrim_train_kwargs = {
  'iterations': 2000,
  'ftol': 1e-6,
  'learning_rate': 1e-3,
  'batch_size': 32,
  'val_update_freq': 100,
  'verbose': True
}

# Forwarded to PicoGAN.train as rew_mod_train_kwargs.
rew_mod_train_kwargs = {
  'iterations': 10000,
  'ftol': 1e-6,
  'learning_rate': 1e-3,
  'batch_size': 32,
  'val_update_freq': 1000,
  'verbose': True
}

# Forwarded to PicoGAN.train as rew_mod_update_freq.
# NOTE(review): units (iterations vs. steps) are not visible here -- confirm.
rew_mod_update_freq = 1000

In [None]:
def run_gan_training(model_path, using_mae=False):
  """Train a compressor with PicoGAN in a newly built training environment.

  Args:
    model_path: directory handed to `make_model` for the compressor's files.
    using_mae: forwarded to `PicoGAN.train` as `using_mae`.

  Returns:
    The value returned by `PicoGAN.train`.
  """
  train_env = make_env(val_mode=False)
  compressor = make_model(train_env, model_path)
  gan = PicoGAN(compressor, train_env)

  # Options for the GAN loop itself; model_train_kwargs is passed separately
  # as the positional argument.
  gan_options = dict(
    verbose=False,
    n_iter=n_iter,
    rew_mod_update_freq=rew_mod_update_freq,
    rew_mod_train_kwargs=rew_mod_train_kwargs,
    discrim_train_kwargs=discrim_train_kwargs,
    discrim_zero_val=0.5,
    using_mae=using_mae
  )
  return gan.train(model_train_kwargs, **gan_options)

In [None]:
# Checkpoint directories: one for the model trained with using_mae=False and
# one for the model trained with using_mae=True.
model_path = os.path.join(data_dir, 'model_0')
mae_model_path = os.path.join(data_dir, 'mae_model_0')

In [None]:
# Train the compressor with using_mae=False (plotted later as 'learned').
model = run_gan_training(
  model_path=model_path, 
  using_mae=False
)

In [None]:
# Persist the trained compressor via its save() method.
model.save()

In [None]:
# Train the comparison compressor with using_mae=True (plotted later as 'mae').
mae_model = run_gan_training(
  model_path=mae_model_path, 
  using_mae=True
)

In [None]:
# Persist the using_mae=True compressor via its save() method.
mae_model.save()

In [None]:
# Evaluation environment: val_mode=True, so make_env passes mask_limits=(None, None).
eval_env = make_env(val_mode=True)

In [None]:
def load_model(model_path):
  """Rebuild a compressor bound to `eval_env` and restore it with `load()`.

  Args:
    model_path: directory handed to `make_model`.

  Returns:
    The compressor after its `load()` method has been called.
  """
  restored = make_model(eval_env, model_path)
  restored.load()
  return restored

In [None]:
# Rebind `model` to a compressor rebuilt against eval_env and loaded from model_path.
model = load_model(model_path)

In [None]:
# Same for the using_mae=True compressor, loaded from mae_model_path.
mae_model = load_model(mae_model_path)

In [None]:
def local_eval_model(compression_model, verbosity=0):
  """Evaluate a compression model against this notebook's data, encoder and simulated user.

  Args:
    compression_model: model to evaluate, forwarded to `utils.eval_model`.
    verbosity: forwarded to `utils.eval_model`.

  Returns:
    The metrics returned by `utils.eval_model`.
  """
  eval_options = dict(n_eval_obses=None, verbosity=verbosity)
  return utils.eval_model(
    compression_model,
    data,
    encoder,
    sim_user_model,
    **eval_options
  )

In [None]:
# Third argument to Masker for the single-setting comparisons below.
mask_limit = 0.5

In [None]:
def baseline_mask_policy(real_obses):
  """Return uniform-random values in [0, 1), one row per observation.

  The output has shape (number of observations, eval_env.n_act_blocks) and does
  not depend on the observation contents.
  """
  n_obses = real_obses.shape[0]
  return np.random.random((n_obses, eval_env.n_act_blocks))

baseline_compression_model = Masker(baseline_mask_policy, eval_env, mask_limit)

In [None]:
# Evaluate the random-mask baseline and display its metrics.
baseline_metrics = local_eval_model(baseline_compression_model)
baseline_metrics

In [None]:
# Use the loaded compressor's act method as the mask policy.
learned_mask_policy = model.act
learned_compression_model = Masker(learned_mask_policy, eval_env, mask_limit)

In [None]:
# Evaluate the learned mask policy and display its metrics.
learned_metrics = local_eval_model(learned_compression_model)
learned_metrics

In [None]:
# Use the using_mae=True compressor's act method as the mask policy.
mae_mask_policy = mae_model.act
mae_compression_model = Masker(mae_mask_policy, eval_env, mask_limit)

In [None]:
# Evaluate the using_mae=True mask policy and display its metrics.
mae_metrics = local_eval_model(mae_compression_model)
mae_metrics

In [None]:
# Mask-limit values swept by viz.sweep_mask_limits below.
mask_limits = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 1]

In [None]:
# Mask policies to sweep, keyed by the names the plotting cell looks up.
mask_policy_of_model = {
  'baseline': baseline_mask_policy,
  'mae': mae_mask_policy,
  'learned': learned_mask_policy
}

In [None]:
def eval_mask_policy(mask_policy, mask_limit, **kwargs):
  """Wrap a mask policy in a Masker on `eval_env` and evaluate it.

  Args:
    mask_policy: callable used as the Masker's policy.
    mask_limit: limit passed as the Masker's third argument.
    **kwargs: accepted for call-site compatibility and ignored.

  Returns:
    The metrics from `local_eval_model`.
  """
  return local_eval_model(Masker(mask_policy, eval_env, mask_limit))

In [None]:
# Sweep the mask limits for every policy; the plotting cell indexes the result
# by model name and then by metric key.
mets_of_model = viz.sweep_mask_limits(
  mask_limits, 
  eval_env,
  mask_policy_of_model,
  eval_mask_policy
)

In [None]:
# Pickle file used to cache the sweep metrics (read and written below).
mets_path = os.path.join(data_dir, 'mets.pkl')

In [None]:
# Reuse cached sweep metrics when present. On a fresh checkout the cache has
# not been written yet (the dump cell comes after this one), so keep the
# `mets_of_model` computed above instead of raising FileNotFoundError.
# NOTE: pickle.load can execute arbitrary code -- only load a cache file that
# this notebook wrote itself.
if os.path.exists(mets_path):
  with open(mets_path, 'rb') as f:
    mets_of_model = pickle.load(f)
else:
  print('No cached metrics at %s; keeping the in-memory metrics' % mets_path)

In [None]:
# Write the sweep metrics to the cache file with the highest available pickle protocol.
with open(mets_path, 'wb') as f:
  pickle.dump(mets_of_model, f, pickle.HIGHEST_PROTOCOL)

In [None]:
# Metric keys: x is read from 'kldiv', y from 'act_acc', and the y error bars
# from the matching '<y_key>_stderr' entry.
y_key = 'act_acc'
x_key = 'kldiv'

# (key into mets_of_model, line colour, legend label), in plotting order.
configs = [
  ('learned', 'orange', 'PICO (Ours)'),
  ('baseline', 'gray', 'Non-Adaptive (Baseline)'),
  ('mae', 'red', 'Perceptual Similarity (Baseline)')
]

fig, ax = plt.subplots()
ax.set_title('MNIST Digits')
ax.set_xlabel('Bitrate (Bits)')
ax.set_ylabel("User's Action Agreement")

for model_name, color, label in configs:
  model_mets = mets_of_model[model_name]
  ax.errorbar(
    model_mets[x_key],
    model_mets[y_key],
    yerr=model_mets['%s_stderr' % y_key],
    color=color,
    marker='o',
    capsize=2,
    label=label
  )

ax.legend(loc='lower right')
plt.show()