# Setup

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns; sns.set()

In [None]:
import matplotlib.cm as cm

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
from ml import np
from ml.rbms import RBM

In [None]:
def sample_rbm(rbm, n, initial, burnin=1000, sample_ever=10, sampler='cd', sample_every=10, **sampler_kwargs):
    """Run one of the RBM's samplers as a Markov chain and record `n` thinned samples.

    Parameters
    ----------
    rbm : RBM
        Model providing ``parallel_tempering(v, **kw) -> (v, h)`` and
        ``contrastive_divergence(v, **kw) -> (_, _, v, h)``.
    n : int
        Number of samples to record.
    initial : array
        Starting visible state. For 'cd'/'pcd' it must be 2-D; for 'pt' it is
        indexed as ``initial[0]`` (e.g. the ``np.tile(v, (len(betas), 1, 1))``
        replicas built later in this notebook).
    burnin : int
        Number of sampler steps discarded before recording starts (may be 0).
    sample_ever : int
        Misspelled duplicate of `sample_every`, kept only so existing calls keep
        working. It is ignored.
    sampler : str
        'cd', 'pcd' or 'pt', case-insensitive. Anything else falls back to 'cd'.
        'pcd' passes ``persistent=True`` to `contrastive_divergence`; 'cd' uses
        that method's default (assumed non-persistent -- confirm against RBM).
    sample_every : int
        Thinning interval: one sample is recorded every `sample_every` steps.
    **sampler_kwargs
        Forwarded unchanged to the sampler method.

    Returns
    -------
    visibles, hiddens : ndarray
        Arrays with `n` rows each. For 'pt' only the first replica (``v[0]``,
        ``h[0]``) is recorded. Both are ``None`` when ``n == 0`` and
        ``burnin == 0``, since no state exists to size them from.
    """
    # Normalise once so every branch agrees (e.g. 'PT' is treated as 'pt').
    method = sampler.lower()

    def step(v):
        # Advance the chain by one sampler call and return the new (v, h).
        if method == 'pt':
            return rbm.parallel_tempering(v, **sampler_kwargs)
        if method == 'pcd':
            _, _, v, h = rbm.contrastive_divergence(v, persistent=True, **sampler_kwargs)
        else:
            _, _, v, h = rbm.contrastive_divergence(v, **sampler_kwargs)
        return v, h

    def select(v, h):
        # For PT keep only the first replica (presumably the beta = 1 chain).
        return (v[0], h[0]) if method == 'pt' else (v, h)

    def allocate(v, h):
        # Size the output buffers from the last dimension of the recorded state.
        v_rec, h_rec = select(v, h)
        return np.zeros((n, v_rec.shape[1])), np.zeros((n, h_rec.shape[1]))

    v, h = initial, None
    for _ in range(burnin):
        v, h = step(v)

    visibles = hiddens = None
    if h is not None:
        visibles, hiddens = allocate(v, h)

    for i in range(n * sample_every):
        v, h = step(v)
        if visibles is None:
            # burnin == 0: no hidden state existed before the first step.
            visibles, hiddens = allocate(v, h)
        if i % sample_every == 0:
            v_rec, h_rec = select(v, h)
            visibles[i // sample_every] = v_rec
            hiddens[i // sample_every] = h_rec

    return visibles, hiddens

In [None]:
def gibbs_sample_rbm(rbm, n, initial, burnin=1000, sample_every=10):
    """Draw `n` thinned samples from an RBM with block Gibbs sampling.

    The chain alternates ``v = rbm.sample_visible(h)`` and
    ``h = rbm.sample_hidden(v)``. After `burnin` discarded sweeps, the current
    chain state is recorded every `sample_every` sweeps. Each stored
    (visible, hidden) row pair is therefore a state the chain actually visited.

    Parameters
    ----------
    rbm : RBM
        Model providing `sample_visible(h)` and `sample_hidden(v)`.
    n : int
        Number of samples to record.
    initial : array
        Starting visible vector; presumably 1-D, since its ``shape[0]`` sizes
        the output.
    burnin : int
        Number of Gibbs sweeps discarded before recording.
    sample_every : int
        Thinning interval between recorded sweeps.

    Returns
    -------
    visibles : ndarray, shape (n, len(v))
    hiddens : ndarray, shape (n, len(h))
    """
    v = initial
    h = rbm.sample_hidden(v)
    for _ in range(burnin):
        v = rbm.sample_visible(h)
        h = rbm.sample_hidden(v)

    visibles = np.zeros((n, v.shape[0]))
    hiddens = np.zeros((n, h.shape[0]))
    for i in range(n * sample_every):
        v = rbm.sample_visible(h)
        h = rbm.sample_hidden(v)
        if i % sample_every == 0:
            # Store the chain state itself. Drawing fresh samples here would
            # decouple the stored v/h pair and double the sampling cost.
            visibles[i // sample_every] = v
            hiddens[i // sample_every] = h

    return visibles, hiddens

In [None]:
def plot_gaussian_mixtures(data, samples_v, title=None, include_means=False, xlim=(-5, 15)):
    """Overlay per-dimension histograms of model samples and real data.

    One row of axes is drawn per column of `data`. Each row shows normalised
    histograms of ``samples_v[:, j]`` ("fake") and ``data[:, j]`` ("real").

    Parameters
    ----------
    data : ndarray, shape (N, D)
        Real data.
    samples_v : ndarray, shape (M, D)
        Visible samples drawn from the model.
    title : str, optional
        Figure super-title.
    include_means : bool
        If True, draw a vertical line at the sample mean and a thick green line
        at the data mean of each dimension.
    xlim : (float, float) or None
        Shared x-limits for every row; None keeps matplotlib's automatic limits.

    Returns
    -------
    fig : matplotlib Figure
    axes : 1-D array of Axes, one per dimension.
    """
    num_dims = data.shape[1]
    # squeeze=False keeps `axes` indexable even when there is only one dimension.
    fig, axes = plt.subplots(num_dims, 1, figsize=(15, 16), sharex=True, sharey=True, squeeze=False)
    axes = axes[:, 0]

    if title is not None:
        fig.suptitle(title, fontsize='x-large')

    for j, ax in enumerate(axes):
        ax.hist(samples_v[:, j], alpha=0.5, bins=100, density=True, label=f"{j} fake")
        ax.hist(data[:, j], alpha=0.5, bins=100, density=True, label=f"{j} real")
        if include_means:
            ax.vlines(np.mean(samples_v[:, j]), ymin=0, ymax=0.5)
            ax.vlines(np.mean(data[:, j]), ymin=0, ymax=0.1, color='g', linewidth=5, alpha=0.7)
        ax.legend()
        if xlim is not None:
            ax.set_xlim(*xlim)

    return fig, axes

# Toy problems

## Multivariate Gaussian

In [None]:
!mkdir -p runs/test

In [None]:
!ls runs

In [None]:
!rm runs/test/*

In [None]:
visible_size = 6
hidden_size = 6

In [None]:
means = np.arange(visible_size) + np.random.random(size=visible_size) * 3.0

In [None]:
cov = np.zeros((visible_size, visible_size))
for i in range(visible_size):
    cov[i, i] = 1.0
    cov[max(i - 1, 0), i] = 1.0
    cov[min(i + 1, visible_size - 1), i] = 1.0

# cov[0, 1] = cov[1, 0] = 0.8
# cov[5, 4] = cov[4, 5] = -0.5
# cov[5, 5] = 3.0
    
cov = np.matmul(cov, cov)

In [None]:
fig = plt.figure(figsize=cov.shape)
coloraxes = plt.imshow(cov, cmap=cm.viridis)
cbar = fig.colorbar(coloraxes)

In [None]:
data = np.random.multivariate_normal(means, np.matmul(cov, cov) , size=100000)

In [None]:
train_data, test_data = train_test_split(data)

In [None]:
from ml.rbms.core import RBM

In [None]:
# Training parameters
LR = 0.001
BATCH_SIZE = 128
NUM_EPOCHS = 10
K = 1

V_SIGMA = 0.1

### Contrastive Divergence

In [None]:
rbm.v_sigma

In [None]:
rbm = RBM(visible_size, hidden_size, 
          visible_type='gaussian', hidden_type='bernoulli',
          sampler_method="cd",
          estimate_visible_sigma=True)

In [None]:
# rbm.v_sigma = np.std(train_data, axis=0)

In [None]:
def _dump_epoch_vars(model, epoch):
    # Registered as the 'pre_epoch' hook: dump the model parameters, tagged by epoch.
    model.dump(f"runs/test/{epoch:04d}_vars.pkl", 'v_bias', 'h_bias', 'W')


def _dump_step_vars(model, epoch, end):
    # Registered as the 'post_step' hook: dump the parameters, tagged by epoch and `end`.
    model.dump(f"runs/test/{epoch:04d}_{end:04d}_vars.pkl", 'v_bias', 'h_bias', 'W')


# Checkpointing hooks for `rbm.fit(..., callbacks=callbacks)`.
callbacks = {
    'pre_epoch': [_dump_epoch_vars],
    'post_step': [_dump_step_vars],
}

In [None]:
# CD-k training on the Gaussian data. The result is used as a mapping below
# (stats.keys(), stats['nll_train'], stats['nll_test']).
# To record parameter snapshots for the "Weights through time" section, add
# `callbacks=callbacks` to these keyword arguments (disabled here).
cd_fit_kwargs = dict(
    k=K,
    batch_size=BATCH_SIZE,
    num_epochs=NUM_EPOCHS * 10,
    learning_rate=LR,
    test_data=test_data,
)
stats = rbm.fit(train_data, **cd_fit_kwargs)

In [None]:
!ls runs/test

In [None]:
list(stats.keys())

In [None]:
plt.plot(stats['nll_train'])
plt.plot(stats['nll_test'])

In [None]:
samples_v, samples_h = gibbs_sample_rbm(rbm, 100000, data[123], burnin=10000)

In [None]:
plot_gaussian_mixtures(data, samples_v, title=f"CD-k with $k={K}$")

In [None]:
v = data

np.mean(np.abs(np.sum(np.matmul(v, rbm.W), axis=1)))

In [None]:
im = np.cov(np.round(samples_v), rowvar=False)

fig = plt.figure(figsize=im.shape)
coloraxes = plt.imshow(im, cmap=cm.viridis)
cbar = fig.colorbar(coloraxes)

In [None]:
fig = plt.figure(figsize=(visible_size, visible_size))
coloraxes = plt.imshow(np.matmul(rbm.W, rbm.W.T), cmap=cm.viridis)
cbar = fig.colorbar(coloraxes)

### Weights through time

In [None]:
import os
import glob

In [None]:
# [os.remove(p) for p in glob.glob("runs/test/00*_*_vars.pkl")]

In [None]:
import pickle

def load(p):
    """Unpickle and return the object stored at path `p`.

    Only use this on files this notebook wrote itself: unpickling untrusted data
    can execute arbitrary code.
    """
    with open(p, "rb") as fh:
        payload = fh.read()
    return pickle.loads(payload)

In [None]:
import time
from IPython import display

# Replay the parameter snapshots matching runs/test/00*_*_vars.pkl (the naming
# used by the 'post_step' callback), showing every FRAME_EVERY-th one.
FRAME_EVERY = 50
FRAME_DELAY_S = 1.0

fig, axes = plt.subplots(1, 4, figsize=(10 + 6 + 10, 20))

snapshot_paths = sorted(glob.glob("runs/test/00*_*_vars.pkl"))
panel_titles = ["$W W^T$", "v_bias", "h_bias", "$W^T W$"]

# Frame numbers are 1-based positions in `snapshot_paths`. Only the snapshots
# that are shown get unpickled.
for frame in range(FRAME_EVERY, len(snapshot_paths) + 1, FRAME_EVERY):
    params = load(snapshot_paths[frame - 1])
    W = params['W']
    images = [
        np.matmul(W, W.T),
        params['v_bias'].reshape(-1, 1),
        params['h_bias'].reshape(-1, 1),
        np.matmul(W.T, W),
    ]
    fig.suptitle(frame)
    for ax, image, panel_title in zip(axes, images, panel_titles):
        ax.clear()  # replace the previous frame's image instead of stacking another on top
        ax.imshow(image, cmap=cm.viridis)
        ax.set_title(panel_title)
    display.clear_output(wait=True)
    display.display(fig)
    time.sleep(FRAME_DELAY_S)

# Close so the inline backend does not render the last frame a second time.
plt.close(fig)

### Parallel Tempering

In [None]:
rbm = RBM(visible_size, hidden_size, 
          visible_type='gaussian', hidden_type='bernoulli',
          sampler_method="pt",
          estimate_visible_sigma=True)

In [None]:
# rbm.v_sigma = np.var(train_data, axis=0)

In [None]:
train_nll, test_nll = rbm.fit(
    train_data, 
    k=K, 
    batch_size=BATCH_SIZE, 
    num_epochs=NUM_EPOCHS * 10, 
    learning_rate=LR,
    test_data=test_data,
    persist=True,
#     callbacks=callbacks
)

In [None]:
plt.plot(train_nll)
plt.plot(test_nll)

In [None]:
samples_v, samples_h = gibbs_sample_rbm(rbm, 10000, data[123], burnin=1000)

In [None]:
plot_gaussian_mixtures(data, samples_v, title="PT with $k = 1$ and $R = 10$", include_means=True)

In [None]:
im = np.matmul(rbm.W, rbm.W.T)

fig = plt.figure(figsize=im.shape)
coloraxes = plt.imshow(im, cmap=cm.viridis)
cbar = fig.colorbar(coloraxes)

# 1D Ising

In [None]:
import pandas as pd

In [None]:
df = pd.read_csv("data/ising_binary_1d.csv", skiprows=1)

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
data = df.values.astype(np.float64)

In [None]:
# data += np.random.normal(scale=0.1, size=data.shape)

In [None]:
train_data, test_data = train_test_split(data)

In [None]:
LR = 0.01
BATCH_SIZE = 128
NUM_EPOCHS = 10
K = 1

V_SIGMA = 0.1

In [None]:
visible_size = 6
hidden_size = 6

## Model

In [None]:
rbm_cd = RBM(visible_size, hidden_size, 
          visible_type='gaussian', hidden_type='bernoulli',
          sampler_method="cd",
          estimate_visible_sigma=False)

rbm_cd.v_sigma = V_SIGMA

In [None]:
train_nll, test_nll = rbm_cd.fit(
    train_data, 
    k=K, 
    batch_size=BATCH_SIZE, 
    num_epochs=NUM_EPOCHS, 
    learning_rate=LR,
    test_data=test_data,
#     reset_per_epoch=True,
#     burnin=1000,
#     persist=True
)

In [None]:
plt.plot(train_nll)
plt.plot(test_nll)

In [None]:
train_test_split?

In [None]:
fig = plt.figure(figsize=(visible_size, visible_size))
coloraxes = plt.imshow(np.matmul(rbm_cd.W, rbm_cd.W.T), cmap=cm.viridis)
cbar = fig.colorbar(coloraxes)

### Sample

In [None]:
num_samples = 20000
burnin = 2000
sample_every = 10

In [None]:
samples_v, samples_h = gibbs_sample_rbm(rbm_cd, num_samples, test_data[0], burnin=burnin, sample_every=sample_every)

In [None]:
v = test_data[0].reshape(1, -1)
betas = [1.0, 0.5, 0.25, 0.1]
v_pt = np.tile(v, (len(betas), 1, 1))

samples_v_pt, samples_h_pt = sample_rbm(
    rbm_cd, num_samples, 
    v_pt, 
    burnin=burnin, 
    sample_every=sample_every,
    sampler='pt',
    k=1,
    betas=betas
#     betas=np.linspace(0.0, 1.0, 2 * K + 1)[::-1][:-1] # drop 0.0
#     num_temps=10, max_temp=1000
)

In [None]:
fig, axes = plt.subplots(data.shape[1], 2, figsize=(15, 16), sharex=False, sharey=False)

fig.suptitle("CD with $k = 10$", fontsize='x-large')

for j in range(data.shape[1]):
    axes[j][0].hist(samples_v[:, j], alpha=0.5, bins=100, density=True, label=f"{j} fake (Gibbs sampling)")
    axes[j][1].hist(samples_v_pt[:, j], alpha=0.5, bins=100, density=True, color='red', label=f"{j} fake (PT sampling)")
#     axes[j][0].hist(data[:, j], alpha=0.5, bins=100, density=True, label=f"{j} real")
#     axes[j].vlines(np.mean(samples_v[:, j]), ymin=0, ymax=0.5)
#     axes[j].vlines(np.mean(data[:, j]), ymin=0, ymax=0.1, color='g', linewidth=5, alpha=0.7)
    axes[j][0].legend()
    axes[j][1].legend()
#     axes[j].set_xlim(-5, 15)

plt.savefig("gibbs_and_pt_sampling_1d_ising_samples_200000_every_100_burnin_20000.png")

In [None]:
fig = plt.figure(figsize=(visible_size, visible_size))
coloraxes = plt.imshow(np.matmul(rbm_cd.W, rbm_cd.W.T), cmap=cm.viridis)
cbar = fig.colorbar(coloraxes)

In [None]:
fig = plt.figure(figsize=(visible_size, visible_size))
coloraxes = plt.imshow(np.corrcoef(samples_v, rowvar=False), cmap=cm.viridis)
cbar = fig.colorbar(coloraxes)

In [None]:
fig = plt.figure(figsize=(visible_size, visible_size))
coloraxes = plt.imshow(np.corrcoef(samples_v_pt, rowvar=False), cmap=cm.viridis)
cbar = fig.colorbar(coloraxes)

In [None]:
im = np.cov(data, rowvar=False)

fig = plt.figure(figsize=im.shape)
coloraxes = plt.imshow(im, cmap=cm.viridis)
cbar = fig.colorbar(coloraxes)

### Parallel Tempering

In [None]:
rbm = RBM(visible_size, hidden_size, 
          visible_type='bernoulli', hidden_type='gaussian',
          sampler_method="pt",
          estimate_visible_sigma=False)

In [None]:
rbm.v_sigma = V_SIGMA

In [None]:
train_nll, test_nll = rbm.fit(
    train_data,
    batch_size=BATCH_SIZE, 
    num_epochs=NUM_EPOCHS, 
    learning_rate=LR * 0.1,  # use smaller learning-rate when using PT
    test_data=test_data,
    k=1,
    num_temps=K,  # k = 1 per temp, but use K temps
    max_temp=1000,
)

In [None]:
plt.plot(train_nll)
plt.plot(test_nll)

#### Sample

In [None]:
samples_v, samples_h = gibbs_sample_rbm(rbm, 100000, test_data[0], burnin=10000, sample_every=1)

In [None]:
fig, axes = plt.subplots(data.shape[1], 1, figsize=(15, 16), sharex=True, sharey=True)

fig.suptitle("PT with $k = 1$ and $R = 10$", fontsize='x-large')

for j in range(data.shape[1]):
    axes[j].hist(samples_v[:, j], alpha=0.5, bins=100, density=True, label=f"{j} fake")
    axes[j].hist(data[:, j], alpha=0.5, bins=100, density=True, label=f"{j} real")
    axes[j].vlines(np.mean(samples_v[:, j]), ymin=0, ymax=0.5)
#     axes[j].vlines(np.mean(data[:, j]), ymin=0, ymax=0.1, color='g', linewidth=5, alpha=0.7)
    axes[j].legend()
#     axes[j].set_xlim(-5, 15)

In [None]:
fig = plt.figure(figsize=(visible_size, visible_size))
coloraxes = plt.imshow(np.matmul(rbm.W, rbm.W.T), cmap=cm.viridis)
cbar = fig.colorbar(coloraxes)

In [None]:
fig = plt.figure(figsize=(visible_size, visible_size))
coloraxes = plt.imshow(np.cov(train_data, rowvar=False), cmap=cm.viridis)
cbar = fig.colorbar(coloraxes)

In [None]:
fig = plt.figure(figsize=(visible_size, visible_size))
coloraxes = plt.imshow(np.cov(samples_v, rowvar=False), cmap=cm.viridis)
cbar = fig.colorbar(coloraxes)