In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

sys.path.append('../tools/sampling_utils')

from easydict import EasyDict as edict
import torch
import numpy as np
from torch.distributions import (MultivariateNormal, 
                                 Normal, 
                                 Independent, 
                                 Uniform)

In [3]:
from ebm_sampling import mala_dynamics
from sir_ais_sampling import sir_independent_dynamics
from adaptive_mc import adaptive_sir_correlated_dynamics, ex2_mcmc_mala
from distributions import (Target, 
                           Gaussian_mixture, 
                           IndependentNormal,
                           init_independent_normal,
                           init_independent_normal_scale)
from metrics import Evolution

In [4]:
# Use the first CUDA device when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [5]:
def random_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds PyTorch (via ``torch.manual_seed`` and ``torch.cuda.manual_seed``),
    NumPy's global generator and Python's built-in ``random`` module.

    Args:
        seed: integer seed forwarded unchanged to each generator.
    """
    # Imported here because the notebook's import cell never imports `random`;
    # without it the final line raised NameError on every call.
    import random

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

In [6]:
def define_target(loc_1_target = -3, loc_2_target = 3, scale_target = 1, dim=100, device='cpu'):
    """Build a two-component Gaussian mixture with equal weights.

    Component ``k`` has mean ``loc_k_target * ones(dim)`` and covariance
    ``scale_target**2 * eye(dim)``. Means and covariances are float64 tensors
    created on ``device``.

    Args:
        loc_1_target: value of every coordinate of the first component mean.
        loc_2_target: value of every coordinate of the second component mean.
        scale_target: scale of both components; the covariance uses its square.
        dim: dimensionality of the mixture.
        device: device on which the mean and covariance tensors are created.

    Returns:
        The ``Gaussian_mixture`` built from these settings.
    """
    target_args = edict()
    target_args.device = device
    target_args.num_gauss = 2
    target_args.dim = dim

    # Equal mixture weights; the same weight tensor is repeated per component.
    coef_gaussian = 1. / target_args.num_gauss
    target_args.p_gaussians = [torch.tensor(coef_gaussian)] * target_args.num_gauss

    # Allocate directly on the requested device instead of on CPU followed by
    # a `.to(device)` copy.
    target_args.locs = [
        loc * torch.ones(dim, dtype=torch.float64, device=device)
        for loc in (loc_1_target, loc_2_target)
    ]

    # One isotropic covariance matrix shared (same tensor object) by all
    # components, as before.
    shared_cov = (scale_target**2) * torch.eye(dim, dtype=torch.float64, device=device)
    target_args.covs = [shared_cov] * target_args.num_gauss

    return Gaussian_mixture(target_args)

In [17]:
def compute_metrics(sample, target, trunc_chain_len=None):
    """Compute mode-coverage metrics for a batch of sampled chains.

    Args:
        sample: chain states, either a list of tensors (stacked here along a
            new leading axis) or a single tensor. After stacking it must be
            3-D; its axes are unpacked as ``(chain, batch_size, dim)``.
        target: object exposing ``locs`` (a list of mode-centre tensors) and
            ``covs`` (a list of covariance matrices).
        trunc_chain_len: if given, only ``sample[-trunc_chain_len - 1:-1]`` is
            scored, i.e. the ``trunc_chain_len`` states preceding the final
            one. If ``None`` the whole chain is scored.

    Returns:
        dict with keys
        ``jsd``, ``modes_var``, ``hqr``: per-chain values averaged over the batch;
        ``mode1_mean``, ``mode2_mean``: coordinate-averaged mean estimate of
        each mode, averaged over the chains that found it. NaN when no chain
        found that mode; when no chain found either mode both are set to the
        overall mean of the scored states;
        ``fraction_found2_modes``: fraction of chains that found both modes;
        ``fraction_found1_mode``: fraction of chains that found exactly one.
    """
    if trunc_chain_len is not None:
        trunc_sample = sample[(-trunc_chain_len - 1):-1]
    else:
        trunc_sample = sample

    # Normalise list and tensor inputs to a single CPU tensor. Only the scored
    # window is stacked; the full chain is no longer copied to NumPy just to
    # read its shape, and a tensor input living on the GPU is handled too.
    if isinstance(trunc_sample, list):
        trunc_sample = torch.stack(trunc_sample, dim=0)
    trunc_sample = trunc_sample.detach().cpu()
    _, batch_size, dim = trunc_sample.shape

    locs = target.locs
    # NOTE(review): `sigma` receives the [0, 0] entry of the first covariance
    # matrix (a variance) - confirm whether Evolution expects a std instead.
    evolution = Evolution(None, locs=torch.stack(locs, 0).cpu(), sigma=target.covs[0][0, 0])

    modes_var_arr = []
    hqr_arr = []
    jsd_arr = []
    means_est_1 = torch.zeros(dim)
    means_est_2 = torch.zeros(dim)
    num_found_1_mode = 0
    num_found_2_mode = 0
    num_found_both_modes = 0

    for i in range(batch_size):
        # All scored states of the i-th chain.
        X_gen = trunc_sample[:, i, :]
        assignment = Evolution.make_assignment(X_gen, evolution.locs, evolution.sigma)
        mode_var = Evolution.compute_mode_std(X_gen, assignment).item()**2

        modes_mean, found_modes_ind = Evolution.compute_mode_mean(X_gen, assignment)

        found_first = 0 in found_modes_ind
        found_second = 1 in found_modes_ind
        if found_first and found_second:
            num_found_both_modes += 1
        if found_first:
            num_found_1_mode += 1
            means_est_1 += modes_mean[0]
        if found_second:
            num_found_2_mode += 1
            means_est_2 += modes_mean[1]

        hqr = Evolution.compute_high_quality_rate(assignment).item()
        jsd = Evolution.compute_jsd(assignment).item()

        modes_var_arr.append(mode_var)
        hqr_arr.append(hqr)
        jsd_arr.append(jsd)

    jsd = np.array(jsd_arr).mean()
    modes_var = np.array(modes_var_arr).mean()
    hqr = np.array(hqr_arr).mean()

    # The message texts below still say "zero" although the fallback value is
    # NaN; they are kept verbatim so existing recorded output stays comparable.
    if num_found_1_mode == 0:
        print("Unfortunalely, no points were assigned to 1st mode, default estimation - zero")
        modes_mean_1_result = np.nan
    else:
        modes_mean_1_result = (means_est_1/num_found_1_mode).mean().item()
    if num_found_2_mode == 0:
        print("Unfortunalely, no points were assigned to 2nd mode, default estimation - zero")
        modes_mean_2_result = np.nan
    else:
        modes_mean_2_result = (means_est_2/num_found_2_mode).mean().item()
    if num_found_1_mode == 0 and num_found_2_mode == 0:
        # Neither mode was found by any chain: report the overall mean instead.
        modes_mean_1_result = modes_mean_2_result = trunc_sample.mean().item()

    # Chains counted in exactly one of the two per-mode tallies.
    num_found_single_mode = num_found_1_mode + num_found_2_mode - 2*num_found_both_modes

    result = dict(
        jsd=jsd,
        modes_var=modes_var,
        hqr=hqr,
        mode1_mean=modes_mean_1_result,
        mode2_mean=modes_mean_2_result,
        fraction_found2_modes=num_found_both_modes/batch_size,
        fraction_found1_mode=num_found_single_mode/batch_size,
    )
    return result



In [241]:
# Shared experiment configuration.
# Note: the sampler cells in view set their own `method_args.n_steps`;
# `args.n_steps` is not read by any of them.
args = edict(
    # target mixture: per-coordinate mode locations and common scale
    loc_1_target=-0.5,
    loc_2_target=0.5,
    scale_target=1,
    # proposal distribution
    scale_proposal=2,
    loc_proposal=0,
    # list of dimensionalities; the first entry is used below
    dim=[100],
    # number of chains started per run
    batch_size=200,
    n_steps=150,
)


In [242]:
# `args.dim` is a list of dimensionalities; this notebook uses only its first entry.
dim = args.dim[0]

In [243]:
# Target mixture (the full object is kept; cells that need only the density
# pass `target.log_prob`) and the proposal built by `init_independent_normal`.
target = define_target(
    loc_1_target=args.loc_1_target,
    loc_2_target=args.loc_2_target,
    scale_target=args.scale_target,
    dim=dim,
    device=device,
)
proposal = init_independent_normal(
    args.scale_proposal,
    dim,
    device,
    args.loc_proposal,
)

In [244]:
# Run `sir_independent_dynamics` for 1000 steps with N = 6, starting every
# chain from a draw of the proposal.
method_args = edict(n_steps=1000, N=6)

start = proposal.sample([args.batch_size])
sample = sir_independent_dynamics(
    start,
    target,
    proposal,
    **method_args,
)

In [245]:
# Score a 900-state window of the chains produced by the cell above.
trunc_chain_len = 900
result = compute_metrics(
    sample=sample,
    target=target,
    trunc_chain_len=trunc_chain_len,
)
print(result)

Unfortunalely, no points were assigned to 1st mode, default estimation - zero
Unfortunalely, no points were assigned to 2nd mode, default estimation - zero
{'jsd': 0.0, 'modes_var': nan, 'hqr': 0.0, 'mode1_mean': -0.0029020842630416155, 'mode2_mean': -0.0029020842630416155, 'fraction_found2_modes': 0.0, 'fraction_found1_mode': 0.0}


In [246]:
# Run `mala_dynamics` with the 'Hastings' acceptance rule and step-size
# adaptation; eps_scale is set to sqrt(2 * grad_step).
method_args = edict(n_steps=1000, grad_step=1e-2)
method_args.eps_scale = (2 * method_args.grad_step) ** 0.5

start = proposal.sample([args.batch_size])
sample, acceptance = mala_dynamics(
    start,
    target.log_prob,
    proposal,
    **method_args,
    acceptance_rule='Hastings',
    adapt_stepsize=True,
)

In [247]:
# Score a 900-state window of the chains produced by the cell above.
trunc_chain_len = 900
result = compute_metrics(
    sample=sample,
    target=target,
    trunc_chain_len=trunc_chain_len,
)
print(result)

{'jsd': 0.040205003209412095, 'modes_var': 0.9572283194203931, 'hqr': 0.9499833327531815, 'mode1_mean': -0.4875207841396332, 'mode2_mean': 0.4794863760471344, 'fraction_found2_modes': 0.025, 'fraction_found1_mode': 0.975}


In [248]:
# `ex2_mcmc_mala` run with N = 10 and both correlation settings set to 0;
# noise_scale is sqrt(2 * grad_step).
method_args = edict(n_steps=1000, grad_step=1e-2)
method_args.update(
    noise_scale=(2 * method_args.grad_step) ** 0.5,
    N=10,
    corr_coef=0,  # 0.9
    bernoulli_prob_corr=0,  # 0.5
)

start = proposal.sample([args.batch_size])
sample, acceptance = ex2_mcmc_mala(
    start,
    target.log_prob,
    proposal,
    **method_args,
    adapt_stepsize=True,
)

In [249]:
# Score a 900-state window of the chains produced by the cell above.
trunc_chain_len = 900
result = compute_metrics(
    sample=sample,
    target=target,
    trunc_chain_len=trunc_chain_len,
)
print(result)

{'jsd': 0.04044845063239336, 'modes_var': 0.9514185357343075, 'hqr': 0.9511444461345673, 'mode1_mean': -0.48025742173194885, 'mode2_mean': 0.4788275957107544, 'fraction_found2_modes': 0.03, 'fraction_found1_mode': 0.97}


In [252]:
# `ex2_mcmc_mala` run with N = 10, corr_coef = 0.95, bernoulli_prob_corr = 0.9;
# noise_scale is sqrt(2 * grad_step).
method_args = edict(n_steps=1000, grad_step=1e-2)
method_args.update(
    noise_scale=(2 * method_args.grad_step) ** 0.5,
    N=10,
    corr_coef=0.95,
    bernoulli_prob_corr=0.9,
)

start = proposal.sample([args.batch_size])
sample, acceptance = ex2_mcmc_mala(
    start,
    target.log_prob,
    proposal,
    **method_args,
    adapt_stepsize=True,
)

In [253]:
# Score a 900-state window of the chains produced by the cell above.
trunc_chain_len = 900
result = compute_metrics(
    sample=sample,
    target=target,
    trunc_chain_len=trunc_chain_len,
)
print(result)

{'jsd': 0.0431051641702652, 'modes_var': 0.9574164348796834, 'hqr': 0.9504277762770653, 'mode1_mean': -0.4801580309867859, 'mode2_mean': 0.49236881732940674, 'fraction_found2_modes': 0.02, 'fraction_found1_mode': 0.98}


In [256]:
# `ex2_mcmc_mala` run with N = 10, corr_coef = 0.95, bernoulli_prob_corr = 0.25;
# noise_scale is sqrt(2 * grad_step).
method_args = edict(n_steps=1000, grad_step=1e-2)
method_args.update(
    noise_scale=(2 * method_args.grad_step) ** 0.5,
    N=10,
    corr_coef=0.95,
    bernoulli_prob_corr=0.25,
)

start = proposal.sample([args.batch_size])
sample, acceptance = ex2_mcmc_mala(
    start,
    target.log_prob,
    proposal,
    **method_args,
    adapt_stepsize=True,
)

In [257]:
# Score a 900-state window of the chains produced by the cell above.
trunc_chain_len = 900
result = compute_metrics(
    sample=sample,
    target=target,
    trunc_chain_len=trunc_chain_len,
)
print(result)

{'jsd': 0.044252626560628415, 'modes_var': 0.9441087437931205, 'hqr': 0.9501277777552605, 'mode1_mean': -0.4614450931549072, 'mode2_mean': 0.4808114767074585, 'fraction_found2_modes': 0.04, 'fraction_found1_mode': 0.96}


In [286]:
# `ex2_mcmc_mala` run with N = 25, corr_coef = 0.9, bernoulli_prob_corr = 0.3;
# noise_scale is sqrt(2 * grad_step).
method_args = edict(n_steps=1000, grad_step=1e-2)
method_args.update(
    noise_scale=(2 * method_args.grad_step) ** 0.5,
    N=25,
    corr_coef=0.9,
    bernoulli_prob_corr=0.3,
)

start = proposal.sample([args.batch_size])
sample, acceptance = ex2_mcmc_mala(
    start,
    target.log_prob,
    proposal,
    **method_args,
    adapt_stepsize=True,
)

In [287]:
# Score a 900-state window of the chains produced by the cell above.
trunc_chain_len = 900
result = compute_metrics(
    sample=sample,
    target=target,
    trunc_chain_len=trunc_chain_len,
)
print(result)

{'jsd': 0.0405741473287344, 'modes_var': 0.9585413747998734, 'hqr': 0.9508777776360512, 'mode1_mean': -0.500348687171936, 'mode2_mean': 0.4692520499229431, 'fraction_found2_modes': 0.02, 'fraction_found1_mode': 0.98}


In [280]:
# `ex2_mcmc_mala` run with N = 10, corr_coef = 0.98, bernoulli_prob_corr = 0.35;
# noise_scale is sqrt(2 * grad_step).
method_args = edict(n_steps=1000, grad_step=1e-2)
method_args.update(
    noise_scale=(2 * method_args.grad_step) ** 0.5,
    N=10,
    corr_coef=0.98,
    bernoulli_prob_corr=0.35,
)

start = proposal.sample([args.batch_size])
sample, acceptance = ex2_mcmc_mala(
    start,
    target.log_prob,
    proposal,
    **method_args,
    adapt_stepsize=True,
)

In [281]:
# Score a 900-state window of the chains produced by the cell above.
trunc_chain_len = 900
result = compute_metrics(
    sample=sample,
    target=target,
    trunc_chain_len=trunc_chain_len,
)
print(result)

{'jsd': 0.04425475563853979, 'modes_var': 0.9506992716573822, 'hqr': 0.9494777771830559, 'mode1_mean': -0.4737156629562378, 'mode2_mean': 0.4795755445957184, 'fraction_found2_modes': 0.035, 'fraction_found1_mode': 0.965}


In [276]:
# `ex2_mcmc_mala` run with N = 10, corr_coef = 0.3, bernoulli_prob_corr = 0.95;
# noise_scale is sqrt(2 * grad_step).
method_args = edict(n_steps=1000, grad_step=1e-2)
method_args.update(
    noise_scale=(2 * method_args.grad_step) ** 0.5,
    N=10,
    corr_coef=0.3,
    bernoulli_prob_corr=0.95,
)

start = proposal.sample([args.batch_size])
sample, acceptance = ex2_mcmc_mala(
    start,
    target.log_prob,
    proposal,
    **method_args,
    adapt_stepsize=True,
)

In [277]:
# Score a 900-state window of the chains produced by the cell above.
trunc_chain_len = 900
result = compute_metrics(
    sample=sample,
    target=target,
    trunc_chain_len=trunc_chain_len,
)
print(result)

{'jsd': 0.03869704980403185, 'modes_var': 0.965659766954748, 'hqr': 0.9495277762413025, 'mode1_mean': -0.49223682284355164, 'mode2_mean': 0.4912351369857788, 'fraction_found2_modes': 0.01, 'fraction_found1_mode': 0.99}


In [32]:
# Disabled run of `adaptive_sir_correlated_dynamics`, kept for reference.
# NOTE(review): `batch_size`, `n_steps`, `N` and `flow` are not defined as
# notebook-level names in the cells visible here - define them before
# re-enabling this cell.
# alpha = 0.1

# start = proposal.sample([batch_size])
# sample = adaptive_sir_correlated_dynamics(start, target, proposal, n_steps, N, alpha, flow)