In [5]:
import numpy as np
import scipy as sp
import traceback
import yaml
import hashlib
import os
import pandas as pd
import bridgestan as bs
import time
import matplotlib.pyplot as plt
import pprint

# --- Tools ---

In [6]:
def logsubexp(a, b):
    if a > b:
        return a + np.log1p(-np.exp(b - a))
    elif a < b:
        return b + np.log1p(-np.exp(a - b))
    else:
        return np.inf

def hash_string(s):
    return hashlib.md5(s.encode('utf-8')).hexdigest()

def stan_initializations(model, num_chains=4, num_warmup=1000, num_samples=1000):
    stan_file = model.stan_file
    data = model.data
    model_name = hash_string(open(stan_file).read())
    output_file = f'{model_name}.output.csv'
    if not os.path.exists(output_file):
        cmd = f'stanc {stan_file} --o={model_name}.stan'
        os.system(cmd)
        cmd = f'{model_name}.stan sample num_chains={num_chains} num_warmup={num_warmup} num_samples={num_samples} data file={data} output file={output_file}'
        os.system(cmd)

    df = pd.read_csv(output_file, comment='#')
    stepsize = np.median(df['stepsize__'].values)
    return stepsize

# --- MCMC Base ---

In [7]:
class MCMCBase:
    def __init__(self, model, stepsize, seed=None):
        self.model = model
        self.stepsize = stepsize
        self.rng = np.random.default_rng(seed)
        self.D = self.model.dims()
        self.sampler_name = "MCMCBase"

    def draw(self):
        raise NotImplementedError

    def log_joint(self, theta, rho):
        return self.model.log_density(theta) - 0.5 * np.sum(rho**2)

    def leapfrog_step(self, theta, rho):
        grad = self.model.log_density_gradient(theta)
        rho_mid = rho + 0.5 * self.stepsize * grad
        theta_new = theta + self.stepsize * rho_mid
        grad_new = self.model.log_density_gradient(theta_new)
        rho_new = rho_mid + 0.5 * self.stepsize * grad_new
        return theta_new, rho_new

    def leapfrog(self, theta, rho, steps):
        for _ in range(steps):
            theta, rho = self.leapfrog_step(theta, rho)
        return theta, rho

    def sample_constrained(self, iterations):
        samples = []
        for _ in range(iterations):
            theta_unconstrained = self.draw()
            theta_constrained = self.model.constrain_pars(theta_unconstrained)
            samples.append(theta_constrained)
        return np.array(samples)


# --- HMC Base ---

In [8]:
class HMCBase(MCMCBase):
    def __init__(self, model, stepsize, seed=None, theta=None):
        super().__init__(model, stepsize, seed)
        if theta is None:
            self.theta = self.rng.normal(size=self.D)
        else:
            self.theta = theta

# --- BS Model ---

In [9]:
class BSModel:
    def __init__(self, stan_file, data):
        self.stan_file = stan_file
        self.data = data
        self.model = bs.StanModel(stan_file=stan_file, data=data)

    def log_density(self, theta):
        return self.model.log_density(theta)

    def log_density_gradient(self, theta):
        return self.model.log_density_gradient(theta)

    def dims(self):
        return self.model.dims

    def unconstrain_pars(self, pars):
        return self.model.unconstrain_pars(pars)

    def constrain_pars(self, pars):
        return self.model.constrain_pars(pars)

# --- SARSA HMC ---

In [10]:
class SARSAHMC(HMCBase):
    def __init__(self, model, stepsize, seed=None, theta=None, alpha=0.1, gamma=0.9, epsilon=0.1, num_actions=5):
        super().__init__(model, stepsize, seed, theta)
        self.sampler_name = "SARSA-HMC"
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.num_actions = num_actions
        self.Q = np.zeros((1, num_actions))
        self.action_space = np.linspace(0.5, 1.5, num_actions) * self.stepsize
        self.steps = 0
        self.prop_accepted = 0.0
        self.draws = 0

    def choose_action(self):
        if self.rng.uniform() < self.epsilon:
            return self.rng.choice(self.num_actions)
        else:
            return np.argmax(self.Q[0])

    def draw(self):
        self.draws += 1
        try:
            theta = self.theta
            rho = self.rng.normal(size=self.D)
            H = self.log_joint(theta, rho)

            action = self.choose_action()
            stepsize = self.action_space[action]
            theta_star, rho_star = self.leapfrog(theta, rho, 1)
            H_star = self.log_joint(theta_star, rho_star)

            log_alpha = H_star - H
            accepted = 0
            if np.log(self.rng.uniform()) < np.minimum(0.0, log_alpha):
                accepted = 1
                self.theta = theta_star

            self.prop_accepted += (accepted - self.prop_accepted) / self.draws

            reward = accepted
            next_action = self.choose_action()

            self.Q[0, action] += self.alpha * (reward + self.gamma * self.Q[0, next_action] - self.Q[0, action])

        except Exception as e:
            traceback.print_exc()
            pass
        return self.theta

# --- Random Walk ---

In [11]:
class RW(MCMCBase):
    def __init__(self, model, stepsize, seed=None):
        super().__init__(model, stepsize, seed)
        self.sampler_name = "Random Walk"

    def draw(self):
        theta = self.theta
        theta_star = theta + self.rng.normal(scale=self.stepsize, size=self.D)
        H = self.model.log_density(theta)
        H_star = self.model.log_density(theta_star)
        log_alpha = H_star - H
        if np.log(self.rng.uniform()) < np.minimum(0.0, log_alpha):
            self.theta = theta_star
        return self.theta

# --- Main ---

In [12]:
if __name__ == "__main__":
    stan_file = 'normal3d.stan'
    data = 'normal3d.data.R'

    model = BSModel(stan_file, data)
    stepsize = stan_initializations(model)
    sampler = SARSAHMC(model, stepsize, seed=42)

    num_samples = 1000
    samples = []
    for _ in range(num_samples):
        samples.append(sampler.draw())

    samples = np.array(samples)
    print("Mean:", np.mean(samples, axis=0))
    print("Std:", np.std(samples, axis=0))

TypeError: StanModel.__init__() got an unexpected keyword argument 'stan_file'