In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import random
import glob
from tqdm import tqdm
from functools import partial
import ot
from matplotlib import pyplot as plt
from collections import defaultdict
import torch, torch.nn as nn
from torch.distributions import Normal
from easydict import EasyDict as edict

In [None]:
from iterative_sir.toy_examples_utils.toy_examples_utils import prepare_swissroll_data
from iterative_sir.toy_examples_utils.gan_fc_models import (
    Generator_fc, 
    Discriminator_fc,
    )
from iterative_sir.sampling_utils.visualization import (
                           sample_fake_data,
                           plot_discriminator_2d,
                           mh_sampling_plot_2d,
                           langevin_sampling_plot_2d,
                           mala_sampling_plot_2d,
                           plot_chain_metrics)
from iterative_sir.sampling_utils.ebm_sampling import (
                          langevin_sampling,
                          mala_dynamics, 
                          mala_sampling,
                          gan_energy,
                          IndependentNormal)
from iterative_sir.sampling_utils.adaptive_mc import ex2_mcmc_mala
from iterative_sir.toy_examples_utils.params_swissroll_wasserstein import (random_seed,
                                          train_dataset_size,
                                          n_dim,
                                          n_layers_d,
                                          n_layers_g,
                                          n_hid_d,
                                          n_hid_g,
                                          n_out,
                                          device)

from iterative_sir.sampling_utils.metrics import Evolution

In [None]:
from pathlib import Path

# Output directory for figures produced by this notebook.
figpath = Path('../figs')

# Fail early with an actionable message; is_dir() (unlike exists()) rejects a
# plain file that happens to be named "figs".
assert figpath.is_dir(), f"Figure directory not found: {figpath.resolve()}"

In [None]:
# Seed the Python, NumPy and PyTorch RNGs before any random work in this notebook.
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# Swiss-roll training set of `train_dataset_size` points.
X_train = prepare_swissroll_data(train_dataset_size)
# Despite the "_std" suffix no scaling is applied: this is an alias of X_train.
X_train_std = X_train

In [None]:
# Fully connected generator and discriminator, both with ReLU activations and
# sized by the hyper-parameters imported from params_swissroll_wasserstein.
# Each model receives its own nn.ReLU() instance.
G = Generator_fc(
    n_dim=n_dim,
    n_out=n_out,
    n_layers=n_layers_g,
    n_hid=n_hid_g,
    non_linear=nn.ReLU(),
    device=device,
)
D = Discriminator_fc(
    n_in=n_dim,
    n_layers=n_layers_d,
    n_hid=n_hid_d,
    non_linear=nn.ReLU(),
    device=device,
)

In [None]:
# Directory holding the pretrained Wasserstein-GAN weights for the swiss roll.
models_dir = Path('../models/models_swissroll/wasserstein')


def _first_checkpoint(directory, pattern):
    """Return the lexicographically first file in `directory` matching `pattern`.

    Raises:
        FileNotFoundError: if no file matches, instead of a bare IndexError.
    """
    # NOTE(review): ordering is lexicographic, so numeric prefixes such as
    # "1000_" sort before "200_"; confirm this picks the intended checkpoint.
    matches = sorted(directory.glob(pattern))
    if not matches:
        raise FileNotFoundError(
            f"No checkpoint matching {pattern!r} in {directory.resolve()}")
    return matches[0]


gen_path = _first_checkpoint(models_dir, '*_generator.pth')
disc_path = _first_checkpoint(models_dir, '*_discriminator.pth')

# map_location lets checkpoints saved on another device load onto `device`.
G.load_state_dict(torch.load(gen_path, map_location=device))
D.load_state_dict(torch.load(disc_path, map_location=device))

In [None]:
# Latent-space proposal and the GAN-based target passed to the samplers below.
n_dim = G.n_dim  # generator input dimensionality, read from the loaded model
scale = torch.ones(n_dim, device=G.device)
loc = torch.zeros(n_dim, device=G.device)
normal = Normal(loc, scale)  # standard normal over the generator input; unused in this section
log_prob = True
normalize_to_0_1 = True

proposal = IndependentNormal(dim=n_dim, device=device, loc=loc, scale=scale)

target = partial(
    gan_energy,
    generator=G,
    discriminator=D,
    proposal=proposal,
    normalize_to_0_1=normalize_to_0_1,
    log_prob=log_prob,
)

In [None]:
evols = {}  # sampler name -> {metric: (mean, standard error) across chains}

In [None]:
# Sampling configuration:
#   batch_size - number of parallel chains (passed as both batch_size and n);
#                with a single chain the across-chain standard error is NaN
#   n_steps    - sampler steps per chain
#   every      - steps per chunk over which the evolution metrics are computed
batch_size, n_steps, every = 1, 800, 200

In [None]:
# Unadjusted Langevin algorithm (ULA) in latent space. Metrics are computed per
# chunk of `every` steps for each chain, then aggregated across chains.
n_target_samples = 1000  # reference points drawn (with replacement) from X_train
target_sample = X_train[np.random.choice(X_train.shape[0], n_target_samples)]

grad_step = 1e-1  # ULA step size
eps_scale = (grad_step * 2) ** 0.5  # noise scale sqrt(2 * step size)

z_last_np, zs = langevin_sampling(target,
                                  proposal,
                                  batch_size=batch_size,
                                  n=batch_size,
                                  grad_step=grad_step,
                                  eps_scale=eps_scale,
                                  n_steps=n_steps)

# NOTE(review): assumes `zs` is a NumPy array and zs[0] is laid out as
# (recorded_steps, n_chains, z_dim); confirm against langevin_sampling.
n_chunks = len(zs[0]) // every
assert n_chunks > 0, f"need at least {every} recorded steps, got {len(zs[0])}"
z_dim = zs.shape[-1]
# Keep the last n_chunks * every steps and split the step axis into chunks.
zs_tail = zs[0, -n_chunks * every:].reshape(n_chunks, every, -1, z_dim)
# (n_chunks, n_chains, every, z_dim)
zs = zs_tail.transpose(0, 2, 1, 3)
# (n_chains, n_chunks, every, z_dim). The axes must be permuted, not reshaped:
# reshaping (n_chunks, n_chains, ...) into (n_chains, n_chunks, ...) would mix
# chunks of different chains whenever batch_size > 1.
zs_gen = np.ascontiguousarray(zs.transpose(1, 0, 2, 3))

# Map latent chains through the generator; no gradients are needed here.
with torch.no_grad():
    Xs_gen = G(torch.FloatTensor(zs_gen).to(device)).cpu().numpy()

evol = defaultdict(list)
for X_gen in Xs_gen:  # one chain: (n_chunks, every, out_dim)
    # NOTE(review): `target` appears to be the latent-space energy (it drives the
    # sampler above), while the chunks are generator outputs; confirm this is
    # what Evolution's target_log_prob expects.
    evolution = Evolution(target_sample, target_log_prob=target)
    for chunk in X_gen:
        evolution.invoke(torch.FloatTensor(chunk))
    for k, v in evolution.as_dict().items():
        evol[k].append(v)

# Replace each metric's per-chain list with (mean, standard error) across chains.
for k, v in evol.items():
    per_chain = np.array(v)
    mean = np.mean(per_chain, 0)
    if len(per_chain) > 1:
        stderr = np.std(per_chain, 0, ddof=1) / np.sqrt(len(per_chain))
    else:
        # ddof=1 is undefined for a single chain: return NaN explicitly instead
        # of letting NumPy emit a degrees-of-freedom warning.
        stderr = np.full(np.shape(mean), np.nan)
    evol[k] = (mean, stderr)
evols['ULA'] = evol