In [1]:
from multiprocessing import Pool
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
import time 

# Display options
np.set_printoptions(precision=2)
import matplotlib as mpl
mpl.rcParams['figure.dpi'] = 300
from matplotlib import rc

import matplotlib as mpl
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "sans-serif",
    "font.sans-serif": ["Computer Modern Sans serif"]})
## for Palatino and other serif fonts use:
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Palatino"],
})

# Local modules
import sys
sys.path.append("modules/")
import utils
import unbiased_estimation
import sampling

# Relo
import imp
imp.reload(unbiased_estimation)
imp.reload(utils)
imp.reload(sampling)

from time import clock_gettime
clock = lambda : clock_gettime(1)

In [2]:
def plot_chains(data, clusts1, clusts2, save_name=None):
    f, axarr = plt.subplots(ncols=2)
    utils.plot_clusts(data, clusts1, axarr[0])
    utils.plot_clusts(data, clusts2, axarr[1])
    if save_name is not None: plt.savefig(save_name)
    plt.show()

In [3]:
def ex6_gen_data(Ndata, sd, sd0=1, K=2, dp_alpha=10):
    # TRANSLATION OF TAMARA's CODE INTO PYTHON
    #
    # generate Gaussian mixture model data for inference later
    #
    # Args:
    #  Ndata: number of data points to generate
    #  sd: covariance matrix of data points around the
    #      cluster-specific mean is [sd^2, 0; 0, sd^2];
    #      i.e. this is the standard deviation in either direction
    #  sd0: std for prior mean
    #
    # Returns:
    #  x: an Ndata x 2 matrix of data points
    #  z: an Ndata-long vector of cluster assignments
    #  mu: a K x 2 matrix of cluster means,
    #      where K is the number of clusters

    # matrix of cluster centers: one in each quadrant
    mu = np.random.normal(scale=sd0, size=[K, 2])
    # vector of component frequencies
    #rho = np.array([0.4,0.3,0.2,0.1])
    rho = stats.dirichlet.rvs(alpha=dp_alpha*np.ones(K))[0]

    # assign each data point to a component
    z = np.random.choice(range(K), p=rho, replace=True, size=Ndata)
    # draw each data point according to the cluster-specific
    # likelihood of its component
    x = mu[z] + np.random.normal(scale=sd, size=[Ndata,2])  
    
    return x

In [5]:
def crp_gibbs_couple(
    data, sd, sd0, initz1, initz2,alpha=0.01, plot=True,
    log_freq=None, maxIters=100, coupling="Maximal", save_base=None):
    """
    
    Args:
        coupling: method of coupling must be "Common_RNG", "Maximal" or "Optimal" ("Common_RNG" used to be "Naive")
    
    """
    
    # initialize the sampler
    z1, z2 = initz1, initz2
    z1s, z2s = [z1.copy()], [z2.copy()]
    
    dists_by_iter = []
    
    # set frequency at which to log state of the chain
    if log_freq is None: log_freq = int(maxIters/10)
    
    # run the Gibbs sampler
    for I in range(maxIters):
            
        z1, z2 = sampling.gibbs_sweep_couple(
            data, z1.copy(), z2.copy(), sd, sd0,
            alpha=alpha, coupling=coupling)
            
        # data counts at each cluster
        clusts1, clusts2 = utils.z_to_clusts(z1), utils.z_to_clusts(z2)  
        z1s.append(z1); z2s.append(z2)
        
        
        dist_between_partitions = utils.adj_dists_fast(clusts1, clusts2)
        dists_by_iter.append(dist_between_partitions)
        
        if (I%log_freq==0 or dist_between_partitions==0) and plot:
            print("Iteration %04d/%04d"%(I, maxIters))
            print("n_clusts: ", len(clusts1), len(clusts2))
            save_name = save_base + "_%04d.png"%I if save_base is not None else None
            plot_chains(data, clusts1, clusts2, save_name=save_name)
            
        if dist_between_partitions == 0:
            print("Chains coupled after %d iterations!"%I)
            break
        
    return z1, dists_by_iter

# Run several replicates to compare performance of couplings

In [7]:
def run_rep(K, Ndata, sd=2., sd0=2., alpha=0.5, lag=200, maxIters=int(1e5)):
    """run_rep runs a replicate and returns the trace and time to coupling for maximal and optimal couplings"""
    np.random.seed()
    data = ex6_gen_data(Ndata, sd, sd0, K=K)
    initz1 = sampling.crp_gibbs(data, sd, sd0, initz, alpha=alpha, plot=False, maxIters=lag)
    initz2 = initz.copy()
    
    # simulate maximal coupling
    st = clock()
    _, trace_maximal = crp_gibbs_couple(
        data, sd, sd0, initz1.copy(), initz2.copy(), alpha=alpha, plot=False, maxIters=maxIters,
        coupling="Maximal", save_base=None)
    end = clock()
    time_maximal = end-st
    
    # simulate common rng coupling
    st = clock()
    _, trace_rng = crp_gibbs_couple(
        data, sd, sd0, initz1.copy(), initz2.copy(), alpha=alpha, plot=False, maxIters=maxIters,
        coupling="Common_RNG", save_base=None)
    end = clock()
    time_rng = end-st
    
    # simulate optimal coupling
    st = clock()
    _, trace_optimal = crp_gibbs_couple(
        data, sd, sd0, initz1.copy(), initz2.copy(), alpha=alpha, plot=False, maxIters=maxIters,
        coupling="Optimal", save_base=None)
    end = clock()
    time_optimal = end-st
    
    return trace_maximal, trace_optimal, trace_rng, time_maximal, time_optimal, time_rng

In [8]:
n_reps = 200
Ndata, K, sd, sd0, alpha = 150, 4, 2., 2.5, 0.2
initz = np.zeros(Ndata, dtype=np.int)
lag = 250 # number of lag iterations
# maxIters = 2000
maxIters = 100

traces_by_coupling = {"Optimal":[], "Maximal":[], "Common_RNG":[]}
times_by_coupling = {"Optimal":[], "Maximal":[], "Common_RNG":[]}


run_in_parallel = True
if run_in_parallel:
    pool_size = 18
    def simulate(rep):
        result = run_rep(K=K, Ndata=Ndata, sd=sd, sd0=sd0, alpha=alpha, lag=lag, maxIters=maxIters)
        print("completed rep %04d"%rep)
        return result
    with Pool(pool_size) as p:
        results = p.map(simulate, range(n_reps))
    
    for (trace_maximal, trace_optimal, trace_rng, time_maximal, time_optimal, time_rng) in results:
        traces_by_coupling["Optimal"].append(trace_optimal)
        traces_by_coupling["Maximal"].append(trace_maximal)
        traces_by_coupling["Common_RNG"].append(trace_rng)
        times_by_coupling["Optimal"].append(time_optimal)
        times_by_coupling["Maximal"].append(time_maximal)
        times_by_coupling["Common_RNG"].append(time_rng)
else:
    for rep in range(n_reps):
        if (10*rep)%n_reps==0: print("Rep %04d/%04d"%(rep, n_reps))

        trace_maximal, trace_optimal, trace_rng, time_maximal, time_optimal, time_rng = run_rep(
            K=K, Ndata=Ndata, sd=sd, sd0=sd0, alpha=alpha, lag=lag, maxIters=maxIters)
        traces_by_coupling["Optimal"].append(trace_optimal)
        traces_by_coupling["Maximal"].append(trace_maximal)
        traces_by_coupling["Common_RNG"].append(trace_rng)
        times_by_coupling["Optimal"].append(time_optimal)
        times_by_coupling["Maximal"].append(time_maximal)
        times_by_coupling["Common_RNG"].append(time_rng)

Chains coupled after 14 iterations!
Chains coupled after 6 iterations!
Chains coupled after 17 iterations!
Chains coupled after 25 iterations!
Chains coupled after 11 iterations!
completed rep 0021Chains coupled after 42 iterations!

Chains coupled after 12 iterations!
Chains coupled after 54 iterations!
Chains coupled after 28 iterations!Chains coupled after 57 iterations!

Chains coupled after 6 iterations!
completed rep 0015
Chains coupled after 69 iterations!
Chains coupled after 32 iterations!
completed rep 0039
Chains coupled after 23 iterations!
Chains coupled after 50 iterations!
Chains coupled after 49 iterations!
Chains coupled after 1 iterations!
Chains coupled after 98 iterations!
Chains coupled after 65 iterations!
Chains coupled after 73 iterations!
Chains coupled after 5 iterations!
Chains coupled after 47 iterations!Chains coupled after 1 iterations!

completed rep 0048completed rep 0022

Chains coupled after 6 iterations!
completed rep 0009Chains coupled after 37 itera

Chains coupled after 16 iterations!
completed rep 0083
Chains coupled after 27 iterations!
completed rep 0105
Chains coupled after 40 iterations!
Chains coupled after 5 iterations!
Chains coupled after 35 iterations!
completed rep 0059
Chains coupled after 45 iterations!
completed rep 0065
Chains coupled after 31 iterations!
Chains coupled after 9 iterations!
Chains coupled after 37 iterations!
completed rep 0111
Chains coupled after 18 iterations!Chains coupled after 12 iterations!

Chains coupled after 30 iterations!
Chains coupled after 38 iterations!completed rep 0100Chains coupled after 88 iterations!


completed rep 0097
Chains coupled after 10 iterations!
Chains coupled after 13 iterations!
completed rep 0095
Chains coupled after 76 iterations!
Chains coupled after 0 iterations!
Chains coupled after 0 iterations!
Chains coupled after 0 iterations!
Chains coupled after 71 iterations!completed rep 0129

Chains coupled after 55 iterations!
completed rep 0103
Chains coupled after 43

Chains coupled after 7 iterations!
Chains coupled after 56 iterations!completed rep 0176

Chains coupled after 33 iterations!
Chains coupled after 18 iterations!
completed rep 0181
Chains coupled after 15 iterations!
Chains coupled after 44 iterations!
completed rep 0195
Chains coupled after 43 iterations!
completed rep 0193
Chains coupled after 17 iterations!
completed rep 0187
Chains coupled after 21 iterations!
Chains coupled after 6 iterations!
Chains coupled after 38 iterations!
Chains coupled after 10 iterations!
completed rep 0182
Chains coupled after 29 iterations!
Chains coupled after 54 iterations!
completed rep 0190
Chains coupled after 37 iterations!
completed rep 0185
Chains coupled after 57 iterations!
completed rep 0199
Chains coupled after 50 iterations!
completed rep 0194
Chains coupled after 25 iterations!
completed rep 0188
Chains coupled after 35 iterations!
completed rep 0179
Chains coupled after 35 iterations!
Chains coupled after 24 iterations!
completed rep 0196

In [9]:
imp.reload(utils)
#Ndata, K, sd, sd0, alpha = 150, 4, 2., 2.5, 0.2
#initz = np.zeros(Ndata, dtype=np.int)
#lag = 250 # number of lag iterations
#maxIters = 2000
fn_base = "../toy_data_results/N=150_K=4_sd=2_sd0=2.5_alpha=0.2"
#traces_by_coupling_150_4_2_25_02 = copy.deepcopy(traces_by_coupling)
traces_fn = fn_base + "_traces.npy"
#np.save(traces_fn, traces_by_coupling_150_4_2_25_02)
traces_by_coupling_150_4_2_25_02 = np.load(traces_fn, allow_pickle=True).item()

#times_by_coupling_150_4_2_25_02 = copy.deepcopy(times_by_coupling)
times_fn = fn_base + "_meeting_times.npy"
#np.save(times_fn, times_by_coupling_150_4_2_25_02)
times_by_coupling_150_4_2_25_02 = np.load(times_fn, allow_pickle=True).item()
title = "Dirichlet Process Mixture Model"
np.random.seed(54)

utils.meeting_times_plots(
    traces_by_coupling_150_4_2_25_02, times_by_coupling_150_4_2_25_02, 
    couplings_plot=['Optimal', 'Maximal', 'Common_RNG'],
    couplings_colors=['#2025df', '#39f810','#fe01b5'], title=title, alpha=1.0, nbins=8, max_time=200,
    linewidth=1.5, iter_interval=5, n_traces_plot=2, max_iter=1000
    )

FileNotFoundError: [Errno 2] No such file or directory: '../toy_data_results/N=150_K=4_sd=2_sd0=2.5_alpha=0.2_traces.npy'