In [1]:
class Distribution:
    """Abstract base class for an unnormalized target distribution.

    Subclasses implement `log_density`; `grad_log_density` is derived from it
    via autograd and normally does not need to be overridden.
    """

    def log_density(self, x):
        r"""
            Computes vectorized log of the unnormalized density

            x (torch tensor of shape BxD): B points at which we compute log density
            returns (torch tensor of shape B): \log \hat{\pi}(x)
        """
        raise NotImplementedError

    def grad_log_density(self, x):
        r"""
            Computes vectorized gradient \nabla_x \log \pi(x)

            x (torch tensor of shape BxD): points at which we compute \nabla \log \pi
            returns (torch tensor of shape BxD): gradients of log density
        """
        # Take gradients w.r.t. a fresh leaf tensor. A plain clone() of a tensor
        # that already requires grad is a non-leaf whose .grad stays None, and
        # it would also tie this computation into the caller's autograd graph.
        x = x.detach().requires_grad_(True)
        # enable_grad makes this work even when called under torch.no_grad().
        with torch.enable_grad():
            logp = self.log_density(x)
            # autograd.grad returns the gradient directly, without
            # accumulating .grad into tensors that log_density may close over
            # (which logp.sum().backward() would do).
            (grad,) = torch.autograd.grad(logp.sum(), x)
        return grad


class Proposal:
    """Abstract base class for an MCMC proposal kernel q(x' | x).

    Subclasses implement `sample` and `log_density`; both raise
    NotImplementedError here.
    """

    def sample(self, x):
        """
            Computes vectorized sample from proposal q(x' | x)

            x (torch tensor of shape BxD): current point from which we propose
            returns: (torch tensor of shape BxD) new points
        """
        raise NotImplementedError

    def log_density(self, x, x_prime):
        r"""
            Computes vectorized log density of the proposal (possibly unnormalized)

            x (torch tensor of shape BxD): B current points we propose from
            x_prime (torch tensor of shape BxD): B proposed points
            returns (torch tensor of shape B): \log q(x' | x)
        """
        raise NotImplementedError


class MCMC:
    """Vectorized accept/reject MCMC sampler running several chains in parallel.

    NOTE(review): `_step` calls `self.acceptance_prob(x_prime, x)`, which is not
    defined in this class; presumably subclasses provide it, returning a tensor
    of B acceptance probabilities — confirm.
    """

    def __init__(self, distribution, proposal):
        """
            Constructs MCMC sampler

            distribution (Distribution): distribution from which we sample
            proposal (Proposal): MCMC proposal
        """
        self.distribution = distribution
        self.proposal = proposal

    def _step(self, x, reject=True):
        """
            Performs one vectorized accept/reject step for all B chains

            x (torch tensor of shape BxD): current states (not modified)
            reject (bool): if False, every proposal is accepted
            returns (torch tensor of shape BxD): next states
        """
        batch_size = x.shape[0]

        x_prime = self.proposal.sample(x)
        if reject:
            acceptance_prob = self.acceptance_prob(x_prime, x)
        else:
            acceptance_prob = torch.ones(batch_size)

        # Accept each chain's proposal with probability acceptance_prob
        # (torch.rand is in [0, 1), so a probability of 1 always accepts).
        mask = torch.rand(batch_size) < acceptance_prob

        # Update a copy so the caller's tensor is not mutated in place.
        x = x.clone()
        x[mask] = x_prime[mask]

        # Count rejections per chain. Use `~mask`, not `1 - mask`:
        # PyTorch does not support the `-` operator on bool tensors.
        self._rejected += (~mask).to(self._rejected.dtype)

        return x

    def simulate(self, initial_point, n_steps, n_parallel=10):
        """
            Run `n_parallel` simulations for `n_steps` starting from `initial_point`

            initial_point (torch tensor of shape D): starting point for all chains
            n_steps (int): number of samples in Markov chain (must be >= 1)
            n_parallel (int): number of parallel chains
            returns: dict(
                points (numpy array of shape n_parallel x n_steps x D): samples
                n_rejected (numpy array of shape n_parallel): number of rejections for each chain
                rejection_rate (float): mean rejection rate over all chains
                means (numpy array of shape n_parallel x n_steps x D): means[c, s] = mean(points[c, :s + 1])
                variances (numpy array of shape n_parallel x n_steps x D): variances[c, s, d] = variance(points[c, :s + 1, d]),
                    the population variance (normalized by s + 1)
            )
            raises ValueError: if n_steps < 1
        """
        if n_steps < 1:
            raise ValueError(f"n_steps must be >= 1, got {n_steps}")

        x = initial_point.repeat(n_parallel, 1)
        self._rejected = torch.zeros(n_parallel)

        dim = initial_point.shape[0]
        # Running per-chain, per-dimension statistics via Welford's algorithm.
        # It is numerically stable, unlike E[x^2] - E[x]^2, which cancels badly
        # (and can even go negative) when |mean| >> std.
        mean = np.zeros([n_parallel, dim])
        m2 = np.zeros([n_parallel, dim])

        xs, means, variances = [], [], []
        for count in range(1, n_steps + 1):
            x = self._step(x)
            # detach() so .numpy() also works if x carries autograd history.
            sample = x.detach().numpy().copy()
            xs.append(sample)

            delta = sample - mean
            mean += delta / count
            m2 += delta * (sample - mean)

            means.append(mean.copy())
            variances.append(m2 / count)

        return dict(
            points=np.stack(xs, axis=1),
            n_rejected=self._rejected.numpy(),
            rejection_rate=(self._rejected / n_steps).mean().item(),
            means=np.stack(means, axis=1),
            variances=np.stack(variances, axis=1),
        )