In [None]:
!pip install ex2mcmc

In [None]:
import gc
import time
from abc import abstractmethod

import jax
import matplotlib.pyplot as plt
import numpy as np
import ot
import pyro
import torch
import torch.nn.functional as F
from pyro.infer import MCMC, NUTS
from torch import nn
from tqdm import trange

from ex2mcmc.metrics.chain import ESS, acl_spectrum
from ex2mcmc.metrics.total_variation import (
    average_total_variation,
)
from ex2mcmc.models.rnvp import RNVP
from ex2mcmc.sampling_utils.adaptive_mc import Ex2MCMC, FlowMCMC
from ex2mcmc.sampling_utils.distributions import (
    Banana,
    IndependentNormal,
)

try:
    # Type aliases (Tensor, List, Any, ...) used in the VAE annotations.
    from .types_ import *
except ImportError:
    # A notebook kernel has no parent package, so the relative import fails;
    # fall back to the same names from their canonical modules.
    from typing import Any, Callable, List, Tuple, TypeVar, Union
    from torch import Tensor


# Use the first GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
def sample_nuts(target, proposal, device = 'cpu', num_samples=1000, burn_in=1000, batch_size=1, rand_seed = 42):
    """Draw reference samples from ``target`` with Pyro's NUTS sampler.

    Args:
        target: callable on a batch of points whose negation is used as the
            NUTS potential energy (presumably an unnormalised log-density).
        proposal: object with a ``sample(shape)`` method; only used to draw the
            initial state of the chain.
        device: device the initial state is moved to.
        num_samples: number of post-warm-up samples to keep.
        burn_in: number of NUTS warm-up steps.
        batch_size: number of points drawn from ``proposal``; NUTS evolves them
            as one joint state of shape ``(batch_size, dim)``.
        rand_seed: seed passed to ``pyro.set_rng_seed``.

    Returns:
        ``(elapsed_seconds, samples)`` where ``samples`` is a NumPy array of
        shape ``(num_chains * num_samples, batch_size, dim)`` (Pyro's default is
        a single chain).
    """
    def energy(params):
        # Pyro hands the state back as a dict keyed like ``initial_params``.
        return (-target(params["points"])).sum()

    start_time = time.perf_counter()  # monotonic clock for measuring durations
    kernel = NUTS(potential_fn=energy, full_mass=False)
    pyro.set_rng_seed(rand_seed)
    init_samples = proposal.sample((batch_size,)).to(device)
    dim = init_samples.shape[-1]
    mcmc = MCMC(
        kernel=kernel,
        num_samples=num_samples,
        initial_params={"points": init_samples},
        warmup_steps=burn_in,
    )
    mcmc.run()
    # group_by_chain=True -> (num_chains, num_samples, batch_size, dim); merge chains.
    chain = mcmc.get_samples(group_by_chain=True)["points"]
    samples = chain.detach().cpu().reshape(-1, batch_size, dim).numpy()
    elapsed = time.perf_counter() - start_time
    return elapsed, samples

In [None]:
def compute_metrics(
    xs_true,
    xs_pred,
    name=None,
    n_samples=1000,
    scale=1.0,
    trunc_chain_len=None,
    ess_rar=1,
):
    """Compare a sampler's chain against reference samples.

    Args:
        xs_true: reference samples, a 2-D array indexed ``[sample, dim]``.
        xs_pred: sampler output indexed ``[step, chain, dim]``.
        name: optional label printed above the summary.
        n_samples: forwarded to ``average_total_variation``.
        scale: both sample sets are divided by this before building the EMD cost.
        trunc_chain_len: if given, only the last ``trunc_chain_len`` steps of
            ``xs_pred`` are used for TV and EMD. ESS always uses the full chain.
        ess_rar: thinning stride applied to the chain before computing ESS.

    Returns:
        dict with keys ``"ess"``, ``"tv_mean"``, ``"tv_conf_sigma"`` and
        ``"emd"`` (EMD averaged over the chain axis of ``xs_pred``).
    """
    metrics = dict()
    key = jax.random.PRNGKey(0)
    n_steps = 25

    thinned = xs_pred[::ess_rar]
    ess = ESS(
        acl_spectrum(thinned - thinned.mean(0)[None, ...]),
    ).mean()
    metrics["ess"] = ess

    # ``xs_pred[-None:]`` raises TypeError, so only truncate when a length is given.
    if trunc_chain_len is not None:
        xs_pred = xs_pred[-trunc_chain_len:]

    tracker = average_total_variation(
        key,
        xs_true,
        xs_pred,
        n_steps=n_steps,
        n_samples=n_samples,
    )
    mean = tracker.mean()
    std = tracker.std()
    metrics["tv_mean"] = mean
    metrics["tv_conf_sigma"] = tracker.std_of_mean()

    n_chains = xs_pred.shape[1]
    emd_mean = 0.0
    for chain_idx in range(n_chains):
        cost = ot.dist(xs_true / scale, xs_pred[:, chain_idx, :] / scale)
        # Empty weight lists mean uniform weights on both sample sets.
        emd_mean += ot.lp.emd2([], [], cost, numItermax=1_000_000) / n_chains
    metrics["emd"] = emd_mean

    if name is not None:
        print(f"===={name}====")
    # Report the batch-averaged EMD, not just the last chain's value.
    print(
        f"TV distance. Mean: {mean:.3f}, Std: {std:.3f}. \nESS: {ess:.3f} \nEMD: {metrics['emd']:.3f}",
    )

    return metrics

In [None]:
# --- Target / proposal configuration ---
dim = 100        # dimensionality of the banana target
sigma = 5.0      # std of the even coordinates (see the ground-truth sampler)
b = 0.02         # coefficient of the quadratic shift of the odd coordinates
dist = "banana"
dist_class = "Banana"
dist_params = {"b": b, "sigma": sigma}
n_steps = 1
scale_proposal = 1.   # scale of the proposal used by NUTS and the flow samplers
scale_isir = 5.       # scale of the proposal used by Ex2MCMC

In [None]:
target = Banana(
    dim=dim,
    device=device,
    b=b,
    sigma=sigma,
)

# Build the per-coordinate tensors under new names so the scalar settings from
# the config cell are not overwritten and re-running this cell is idempotent.
loc_proposal = torch.zeros(dim, device=device)
scale_proposal_vec = scale_proposal * torch.ones(dim, device=device)
scale_isir_vec = scale_isir * torch.ones(dim, device=device)

# Narrow proposal: NUTS initialisation and flow training/sampling.
proposal = IndependentNormal(
    dim=dim,
    loc=loc_proposal,
    scale=scale_proposal_vec,
    device=device,
)

# Wider proposal: Ex2MCMC.
proposal_ex2 = IndependentNormal(
    dim=dim,
    loc=loc_proposal,
    scale=scale_isir_vec,
    device=device,
)

### Ground-truth samples

In [None]:
N_samples = 2*10**3
np.random.seed(42)
True_samples = np.random.randn(N_samples, dim)
# Exact banana samples: every even coordinate is N(0, sigma^2); each odd
# coordinate is standard normal shifted by b * (preceding even coord)^2 - b * sigma^2.
True_samples[:, ::2] *= sigma
True_samples[:, 1::2] += b * True_samples[:, 0:2 * (dim // 2):2]**2 - (sigma**2) * b

In [None]:
fig, ax = plt.subplots(1, 1)
ax.scatter(True_samples[:, 2], True_samples[:, 3], alpha=0.5)
ax.set_title('Ground truth samples')
ax.set_xlabel('$x_2$')
ax.set_ylabel('$x_3$')
plt.show()

### Ground-truth with NUTS

In [None]:
# --- NUTS reference-run settings ---
Nuts_samples_ground_truth = 2000  # NOTE(review): not referenced elsewhere in this notebook
trunc_chain_len = 1000            # NUTS chain length; also the tail kept for metrics
nuts_burn_in = 500                # NUTS warm-up steps
nuts_batch = 1                    # points per NUTS state

In [None]:
# Reference chain from NUTS; its length matches the metric window.
rand_seed = 42
time_nuts, sample_nuts_ref = sample_nuts(
    target, proposal, device,
    num_samples=trunc_chain_len,
    burn_in=nuts_burn_in,
    batch_size=nuts_batch,
    rand_seed=rand_seed,
)
print(sample_nuts_ref.shape)

In [None]:
fig, ax = plt.subplots(1, 1)
ax.scatter(sample_nuts_ref[:, 0, 0], sample_nuts_ref[:, 0, 1], alpha=0.5)
ax.set_title('NUTS samples')
ax.set_xlabel('$x_0$')
ax.set_ylabel('$x_1$')
plt.show()

In [None]:
# Score the NUTS chain against the exact banana samples.
metrics = compute_metrics(True_samples, sample_nuts_ref, name="NUTS",
                          trunc_chain_len=trunc_chain_len, ess_rar=1)

### Sample with Ex2MCMC

In [None]:
# Ex2MCMC hyper-parameters.
params = dict(
    N=200,
    grad_step=0.1,
    adapt_stepsize=True,
    corr_coef=0.0,
    bernoulli_prob_corr=0.0,  # alternative value noted by the author: 0.75
    mala_steps=3,
)
n_steps_ex2 = 5000   # length of the Ex2MCMC chain
batch_size = 1       # number of starting points

In [None]:
mcmc = Ex2MCMC(**params, dim=dim)
pyro.set_rng_seed(43)
start = proposal_ex2.sample((batch_size,)).to(device)
out = mcmc(start, target, proposal_ex2, n_steps=n_steps_ex2)
# The sampler returns either the chain or a tuple whose first element is the chain.
sample = out[0] if isinstance(out, tuple) else out
# ``.numpy()`` only works on host tensors, so move each step to the CPU first.
sample = np.array(
    [step.detach().cpu().numpy() for step in sample],
).reshape(-1, batch_size, dim)
sample_ex2_final = sample[:, 0, :]
print(sample_ex2_final.shape)

In [None]:
fig, ax = plt.subplots(1, 1)
ax.scatter(sample_ex2_final[:, 0], sample_ex2_final[:, 1], alpha=0.5)
ax.set_title('Ex2 samples')
ax.set_xlabel('$x_0$')
ax.set_ylabel('$x_1$')
plt.show()

### Sample with Flex2MCMC (adaptive version)

In [None]:
# Flex2MCMC hyper-parameters: Ex2MCMC settings plus the flow-training setup.
params_flex = dict(
    N=200,
    grad_step=0.2,
    adapt_stepsize=True,
    corr_coef=0.0,
    bernoulli_prob_corr=0.0,
    mala_steps=0,
    flow=dict(
        num_blocks=4,   # number of normalizing layers
        lr=1e-3,        # learning rate
        batch_size=100,
        n_steps=800,
    ),
)
batch_size = 1
torch.cuda.empty_cache()

In [None]:
pyro.set_rng_seed(47)
mcmc = Ex2MCMC(**params_flex, dim=dim)
# Silence the sampler during flow training; the setting is restored at the end.
verbose = mcmc.verbose
mcmc.verbose = False
flow = RNVP(params_flex["flow"]["num_blocks"], dim=dim, device=device)
flow_mcmc = FlowMCMC(
    target,
    proposal,
    device,
    flow,
    mcmc,
    batch_size=params_flex["flow"]["batch_size"],
    lr=params_flex["flow"]["lr"],
)
flow.train()
out_samples, nll = flow_mcmc.train(
    n_steps=params_flex["flow"]["n_steps"],
)
# Check every parameter for NaNs, not just a single entry of the first one.
assert not any(torch.isnan(p).any().item() for p in flow.parameters()), \
    "NaN in flow parameters after training"
gc.collect()
torch.cuda.empty_cache()
flow.eval()
# Attach the trained flow to the sampler used in the cells below.
mcmc.flow = flow
mcmc.verbose = verbose

In [None]:
# Sample with the trained normalizing flow attached to Ex2MCMC.
n_steps_flex2 = 2000
batch_size = 1
pyro.set_rng_seed(42)
start = proposal.sample((batch_size,))
mcmc.N = 200
mcmc.grad_step = 0.1

# First pass: no MALA steps (plotted as "Adaptive i-sir").
mcmc.mala_steps = 0
out = mcmc(start, target, proposal, n_steps=n_steps_flex2)
sample = out[0] if isinstance(out, tuple) else out
# Move each step to the CPU before converting (``.numpy()`` fails on CUDA tensors).
sample = np.array(
    [step.detach().cpu().numpy() for step in sample],
).reshape(-1, batch_size, dim)
sample_flex2_new = sample
torch.cuda.empty_cache()

# Second pass from the same start, now with 5 MALA steps per iteration (Flex2).
mcmc.mala_steps = 5
out_new = mcmc(start, target, proposal, n_steps=n_steps_flex2)
out_new = out_new[0] if isinstance(out_new, tuple) else out_new
out_new = np.array(
    [step.detach().cpu().numpy() for step in out_new],
).reshape(-1, batch_size, dim)
sample_flex2_final = out_new
print(sample_flex2_final.shape)

In [None]:
fig, ax = plt.subplots(1, 1)
ax.scatter(sample_flex2_new[:, 0, 0], sample_flex2_new[:, 0, 1], alpha=0.5)
ax.set_title('Adaptive i-sir samples')
ax.set_xlabel('$x_0$')
ax.set_ylabel('$x_1$')
plt.show()

In [None]:
trunc_chain_len = 1000
# Score the flow-based chain without MALA steps against the exact samples.
metrics = compute_metrics(True_samples, sample_flex2_new, name="Adaptive i-sir",
                          trunc_chain_len=trunc_chain_len, ess_rar=1)

In [None]:
fig, ax = plt.subplots(1, 1)
ax.scatter(sample_flex2_final[:, 0, 0], sample_flex2_final[:, 0, 1], alpha=0.5)
ax.set_title('Flex2 samples')
ax.set_xlabel('$x_0$')
ax.set_ylabel('$x_1$')
plt.show()

In [None]:
# Score the Flex2 chain (flow proposals + 5 MALA steps) against the exact samples.
metrics = compute_metrics(True_samples, sample_flex2_final, name="Flex2",
                          trunc_chain_len=trunc_chain_len, ess_rar=1)

### VAE-MCMC (a VAE in place of the normalizing flow)

In [None]:
class Vae_loss(nn.Module):
    """VAE objective: reconstruction MSE plus a weighted Gaussian KL term.

    The total is ``mse(recons, input) + beta * kld_weight * KL`` where KL is the
    closed-form KL(N(mu, exp(log_var)) || N(0, I)), summed over dim 1 and
    averaged over the batch (dim 0).

    Args:
        flow: model reference stored as ``self.flow`` (not used by ``forward``).
        kld_weight: weight on the KL term.
        beta: extra multiplier on the KL term (default 0.0 disables it).
    """

    def __init__(self, flow, kld_weight, beta=0.0):
        super().__init__()
        self.flow = flow
        self.kld_weight = kld_weight
        self.beta = beta

    def forward(self, recons, input, mu, log_var):
        reconstruction = F.mse_loss(recons, input)
        kld_per_sample = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1)
        kld = torch.mean(kld_per_sample, dim=0)
        return reconstruction + self.beta * self.kld_weight * kld

In [None]:
class BaseVAE(nn.Module):
    """Minimal VAE interface; subclasses override the methods they support."""

    def __init__(self) -> None:
        super(BaseVAE, self).__init__()

    def encode(self, input: Tensor) -> List[Tensor]:
        raise NotImplementedError

    def decode(self, input: Tensor) -> Any:
        raise NotImplementedError

    def sample(self, batch_size:int, current_device: int, **kwargs) -> Tensor:
        raise NotImplementedError

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        raise NotImplementedError

    # NOTE: BaseVAE does not use ABCMeta, so @abstractmethod is advisory only and
    # does not stop subclasses that omit these methods from being instantiated.
    @abstractmethod
    def forward(self, *inputs: Tensor) -> Tensor:
        pass

    @abstractmethod
    def loss_function(self, *inputs: Any, **kwargs) -> Tensor:
        pass


class VAE(BaseVAE):
    """Fully connected VAE exposing a flow-like ``forward``/``inverse`` pair.

    ``forward(x)`` encodes and returns ``([mu, log_var], 0)``; ``inverse(z)``
    decodes and returns ``(x_reconstructed, 0)``. The constant second output is
    a placeholder, presumably so the VAE can stand in where a flow returning
    ``(output, log_det)`` is expected.
    """

    num_iter = 0  # class-level iteration counter; not used by the current methods

    def __init__(self,
                 input_size: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 beta: int = 4,
                 gamma:float = 1000.,
                 max_capacity: int = 25,
                 Capacity_max_iter: int = 1e5,
                 loss_type:str = 'B',
                 **kwargs) -> None:
        super(VAE, self).__init__()

        self.latent_dim = latent_dim
        # beta / gamma / loss_type / C_max / C_stop_iter are stored for API
        # compatibility; no method of this class reads them.
        self.beta = beta
        self.gamma = gamma
        self.loss_type = loss_type
        self.C_max = torch.Tensor([max_capacity])
        self.C_stop_iter = Capacity_max_iter

        # Copy so the caller's list is never mutated.
        hidden_dims = [100] if hidden_dims is None else list(hidden_dims)

        # Encoder: input_size -> hidden_dims[0] -> ... -> hidden_dims[-1].
        modules = []
        in_features = input_size
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Linear(in_features, h_dim),
                    nn.LeakyReLU())
            )
            in_features = h_dim
        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1], latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1], latent_dim)

        # Decoder mirrors the encoder:
        # latent -> hidden_dims[-1] -> ... -> hidden_dims[0] -> input_size.
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1])
        reversed_dims = hidden_dims[::-1]
        modules = [
            nn.Sequential(
                nn.Linear(reversed_dims[0], reversed_dims[0]),
                nn.LeakyReLU()
            )
        ]
        for d_in, d_out in zip(reversed_dims[:-1], reversed_dims[1:]):
            modules.append(
                nn.Sequential(
                    nn.Linear(d_in, d_out),
                    nn.LeakyReLU())
            )
        self.decoder = nn.Sequential(*modules)
        # Map back to the original input size (the encoder loop no longer overwrites it).
        self.final_layer = nn.Linear(reversed_dims[-1], input_size)

    def encode(self, input: Tensor) -> List[Tensor]:
        """Return ``[mu, log_var]`` of the latent Gaussian for ``input`` (batch-first)."""
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """Map latent codes ``z`` of shape ``(batch, latent_dim)`` back to input space."""
        result = self.decoder_input(z)
        result = self.decoder(result)
        return self.final_layer(result)

    def forward(self, input: Tensor) -> tuple:
        """Encode ``input``.

        :param input: (Tensor) batch of inputs, shape ``(batch, input_size)``
        :return: ``([mu, log_var], 0)``
        """
        prob = 0
        return self.encode(input), prob

    def inverse(self, z: Tensor) -> tuple:
        """Decode latent codes ``z``; returns ``(reconstruction, 0)``."""
        prob = 0
        return self.decode(z), prob

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterisation trick: draw ``z = mu + eps * exp(0.5 * logvar)``
        with ``eps ~ N(0, I)``.
        :param mu: (Tensor) Mean of the latent Gaussian
        :param logvar: (Tensor) Log-variance of the latent Gaussian
        :return: (Tensor) one sample per row
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Sample latent codes from N(0, I) and decode them to input space.
        :param num_samples: (Int) Number of samples
        :param current_device: device to run the model on
        :return: (Tensor) shape ``(num_samples, input_size)``
        """
        z = torch.randn(num_samples,
                        self.latent_dim)
        z = z.to(current_device)
        return self.decode(z)

class VaeMCMC:
    """Alternate MCMC runs with VAE training (VAE counterpart of ``FlowMCMC``).

    Each training step calls ``mcmc_call`` from the current points, keeps the
    final state of the returned chain, and (optionally) takes one optimizer
    step on ``Vae_loss`` evaluated on those points.
    """

    def __init__(self, target, proposal, device, flow, mcmc_call: callable, **kwargs):
        self.flow = flow
        self.proposal = proposal
        self.target = target
        self.device = device
        self.batch_size = kwargs.get("batch_size", 64)
        self.mcmc_call = mcmc_call
        self.grad_clip = kwargs.get("grad_clip", 1.0)
        self.jump_tol = kwargs.get("jump_tol", 1e6)  # stored for compatibility; unused
        optimizer = kwargs.get("optimizer", "adam")
        self.flow.to(self.device)
        # kld_weight=0.005, beta=4.
        self.loss = Vae_loss(self.flow, 0.005, 4)

        lr = kwargs.get("lr", 1e-3)
        wd = kwargs.get("wd", 1e-4)
        if isinstance(optimizer, torch.optim.Optimizer):
            self.optimizer = optimizer
        elif isinstance(optimizer, str) and optimizer.lower() == "adam":
            self.optimizer = torch.optim.Adam(
                flow.parameters(), lr=lr, weight_decay=wd
            )
        else:
            # Fail here instead of with an AttributeError later in train_step.
            raise ValueError(f"Unsupported optimizer: {optimizer!r}")

    def train_step(self, inp=None, alpha=0.5, do_step=True, inv=True):
        """Run ``mcmc_call`` once and, if ``do_step``, update the VAE.

        Args:
            inp: starting points; if None, ``batch_size`` points are drawn from
                the proposal.
            alpha: kept for interface compatibility with ``FlowMCMC``; the VAE
                loss does not use it.
            do_step: whether to take an optimizer step.
            inv: if True and ``inp`` is given, pass ``inp`` through
                ``self.flow.forward`` before sampling.

        Returns:
            ``(out, nll)``: the last state of the chain and the mean negative
            ``target.log_prob`` over it.
        """
        if inp is None:
            inp = self.proposal.sample((self.batch_size,))
        elif inv:
            # NOTE(review): for ``VAE`` this yields ``[mu, log_var]`` rather than
            # points in input space; confirm what ``mcmc_call`` expects here.
            inp, _ = self.flow.forward(inp)
        out = self.mcmc_call(inp, self.target, self.proposal, flow=self.flow)
        if isinstance(out, tuple):
            out = out[0]
        out = out[-1].to(self.device)
        nll = -self.target.log_prob(out).mean().item()

        if do_step:
            # Standard VAE objective on the sampled points: encode, sample z,
            # decode, then score with Vae_loss(recons, input, mu, log_var).
            x = out.detach()
            (mu, log_var), _ = self.flow.forward(x)
            z = self.flow.reparameterize(mu, log_var)
            recons, _ = self.flow.inverse(z)
            loss = self.loss(recons, x, mu, log_var)

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                self.flow.parameters(),
                self.grad_clip,
            )
            self.optimizer.step()

        return out, nll

    def train(self, n_steps=100, start_optim=10, init_points=None, alpha=None):
        """Run ``n_steps`` train steps, optimising only from ``start_optim`` on.

        Returns:
            ``(samples, neg_log_likelihood)``: per-step CPU copies of the chain's
            final state and the per-step NLL values.
        """
        samples = []
        inp = self.proposal.sample((self.batch_size,))
        neg_log_likelihood = []

        for step_id in trange(n_steps):
            a = min(0.5, 3 * step_id / n_steps)
            use_init = step_id == 0 and init_points is not None
            out, nll = self.train_step(
                alpha=a,
                do_step=step_id >= start_optim,
                inp=init_points if use_init else inp,
                inv=True,
            )
            inp = out.detach().requires_grad_()
            samples.append(inp.detach().cpu())
            neg_log_likelihood.append(nll)

        return samples, neg_log_likelihood

    def sample(self):
        """Not implemented; returns None."""
        pass

In [None]:
# Sampler / training hyper-parameters for the VAE experiment
# (same values as the Flex2MCMC section).
flow_settings = None  # placeholder removed below; keeps namespace unchanged
del flow_settings
params_flex = {
    "N": 200,
    "grad_step": 0.2,
    "adapt_stepsize": True,
    "corr_coef": 0.0,
    "bernoulli_prob_corr": 0.0,
    "mala_steps": 0,
    "flow": {
        "num_blocks": 4,     # number of normalizing layers
        "lr": 1e-3,          # learning rate
        "batch_size": 100,
        "n_steps": 800,
    },
}
batch_size = 1
torch.cuda.empty_cache()

In [None]:
pyro.set_rng_seed(47)
mcmc = Ex2MCMC(**params_flex, dim=dim)
# Silence the sampler during VAE training; the setting is restored at the end.
verbose = mcmc.verbose
mcmc.verbose = False
# The VAE replaces the RNVP flow; its input size must match the target's
# dimension, since it is fed dim-dimensional proposal/MCMC samples.
vae = VAE(dim, 5)
vae_mcmc = VaeMCMC(
    target,
    proposal,
    device,
    vae,
    mcmc,
    batch_size=params_flex["flow"]["batch_size"],
    lr=params_flex["flow"]["lr"],
)
vae.train()
out_samples, nll = vae_mcmc.train(
    n_steps=params_flex["flow"]["n_steps"],
)
# Check the model that was just trained, across all of its parameters.
assert not any(torch.isnan(p).any().item() for p in vae.parameters()), \
    "NaN in VAE parameters after training"
gc.collect()
torch.cuda.empty_cache()
vae.eval()
# Attach the trained VAE to the sampler used in the cells below.
mcmc.flow = vae
mcmc.verbose = verbose

In [None]:
# Sample with the trained VAE attached to Ex2MCMC.
n_steps_flex2 = 2000
batch_size = 1
pyro.set_rng_seed(42)
start = proposal.sample((batch_size,))
mcmc.N = 200
mcmc.grad_step = 0.1

# First pass: no MALA steps.
mcmc.mala_steps = 0
out = mcmc(start, target, proposal, n_steps=n_steps_flex2)
sample = out[0] if isinstance(out, tuple) else out
# Move each step to the CPU before converting (``.numpy()`` fails on CUDA tensors).
sample = np.array(
    [step.detach().cpu().numpy() for step in sample],
).reshape(-1, batch_size, dim)
sample_flex2_new = sample
torch.cuda.empty_cache()

# Second pass from the same start, now with 5 MALA steps per iteration.
mcmc.mala_steps = 5
out_new = mcmc(start, target, proposal, n_steps=n_steps_flex2)
out_new = out_new[0] if isinstance(out_new, tuple) else out_new
out_new = np.array(
    [step.detach().cpu().numpy() for step in out_new],
).reshape(-1, batch_size, dim)
sample_flex2_final = out_new
print(sample_flex2_final.shape)

In [None]:
fig, ax = plt.subplots(1, 1)
ax.scatter(sample_flex2_new[:, 0, 0], sample_flex2_new[:, 0, 1], alpha=0.5)
# Label the VAE run distinctly from the normalizing-flow run above.
ax.set_title('Adaptive i-sir samples (VAE)')
ax.set_xlabel('$x_0$')
ax.set_ylabel('$x_1$')
plt.show()

In [None]:
trunc_chain_len = 1000
# Score the VAE-based chain without MALA steps; labelled distinctly from the flow run.
metrics = compute_metrics(
    True_samples,
    sample_flex2_new,
    name="Adaptive i-sir (VAE)",
    trunc_chain_len=trunc_chain_len,
    ess_rar=1,
)

In [None]:
fig, ax = plt.subplots(1, 1)
ax.scatter(sample_flex2_final[:, 0, 0], sample_flex2_final[:, 0, 1], alpha=0.5)
# Label the VAE run distinctly from the normalizing-flow run above.
ax.set_title('Flex2 samples (VAE)')
ax.set_xlabel('$x_0$')
ax.set_ylabel('$x_1$')
plt.show()

In [None]:
# Score the VAE-based chain with 5 MALA steps; labelled distinctly from the flow run.
metrics = compute_metrics(
    True_samples,
    sample_flex2_final,
    name="Flex2 (VAE)",
    trunc_chain_len=trunc_chain_len,
    ess_rar=1,
)