In [1]:
import os
import io

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import PIL
from tqdm.notebook import tqdm
from IPython.display import display, Image

## Parameters

In [2]:
rand_gen = np.random.RandomState(1)
n_samples = 100
mu_gt = np.array([[1, 1], [1, -1]]) @ np.array([0, 0])
sigma_gt = np.array([[1, 1], [1, -1]]) @ (np.diag([5e-2, 5e-1]) ** 2) @ np.array([[1, 1], [1, -1]])

mu = np.array([1, 1])
sigma = np.diag([5e-3, 5e-3])

alpha = 0.03
n_steps_mcmc = 200
n_steps_cd = 5

In [3]:
color1 = '#1f77b4'
color2 = '#ff7f0e'
color3 = '#2ca02c'
color4 = '#d62728'

## Distribution color map
color4_rgb = np.array([int(color4[1:3], 16), int(color4[3:5], 16), int(color4[5:7], 16)]) / 255
cmap = np.repeat(color4_rgb[None, :], 100, axis=0)
cmap = np.concatenate((cmap, np.linspace(0.75, 0, cmap.shape[0])[:, None]), axis=1)
cmap = mpl.colors.ListedColormap(cmap)

## Initialize

In [4]:
## Create output folder
if not os.path.isdir('./output'):
    os.mkdir('./output')

## Auxiliary function

In [5]:
class GIFWrite:
    def __init__(self, fig):
        self._fig = fig
        self._imgs = []
    def append(self):
        fig.canvas.draw()
        with io.BytesIO() as buf:
            self._fig.savefig(buf, pad_inches=0, format='png', transparent=True)
            buf.seek(0)
            img = PIL.Image.open(buf).copy()
        self._imgs.append(img)
    def save(self, filename, duration):
        print(int(duration * 1000 / len(self._imgs)))
        self._imgs[0].save(filename,
                           save_all=True,
                           append_images=self._imgs[1:],
                           optimize=False,
                           duration=int(duration * 1000 / len(self._imgs)),
                           # transparency=0,
                           loop=1,
                           # disposal=2,
                           )

## Define distribution

In [6]:
## Energy functions
def energy_func(x, mu, sigma):
    return np.log(x).sum(axis=1) + 0.5 * (((np.log(x) - mu[None, :]) @ np.linalg.inv(sigma)) * (np.log(x) - mu[None, :])).sum(axis=1)

def grad_energy_func(x, mu, sigma):
    return ((np.log(x) - mu[None, :]) @ np.linalg.inv(sigma) - 1) / x

## Generate GT
samples_gt = np.exp(rand_gen.randn(n_samples, 2) @ np.linalg.cholesky(sigma_gt).T + mu_gt[None, :])

## Generate model's samples
samples = np.exp(rand_gen.randn(n_samples, 2) @ np.linalg.cholesky(sigma).T + mu[None, :])

## Calculate energy
grid = np.stack(np.meshgrid(np.linspace(1e-10, 4, 400), np.linspace(1e-10, 4, 400)), axis=2)
energy = energy_func(grid.reshape(-1, 2), mu, sigma).reshape(grid.shape[:2])
energy_gt = energy_func(grid.reshape(-1, 2), mu_gt, sigma_gt).reshape(grid.shape[:2])

## Run MCMC

In [7]:
samples_mcmc = rand_gen.rand(*samples.shape) + 1.5
steps_mcmc = [samples_mcmc.copy()]
for _ in range(n_steps_mcmc):
    samples_mcmc += -alpha ** 2 / 2 * grad_energy_func(samples_mcmc, mu, sigma) + alpha * rand_gen.randn(*samples_mcmc.shape)
    steps_mcmc.append(samples_mcmc.copy())
steps_mcmc = np.stack(steps_mcmc, axis=0)

## Run CD-MCMC

In [8]:
samples_cd = samples_gt.copy()
steps_cd = [samples_cd.copy()]
for _ in range(n_steps_cd):
    samples_cd += -alpha ** 2 / 2 * grad_energy_func(samples_cd, mu, sigma) + alpha * rand_gen.randn(*samples_cd.shape)
    steps_cd.append(samples_cd.copy())
steps_cd = np.stack(steps_cd, axis=0)

## Plots

In [39]:
## Prepare figure
## --------------
fig = plt.figure(figsize=(4, 4), dpi=200)
ax = fig.add_axes((0, 0, 1, 1))
ax.set_xlim(-0.5, 4)
ax.set_ylim(-0.5, 4)
ax.axis('off')

## Draw axis
## ---------
ax.arrow(0, -0.5, 0, 4, head_width=0.1, head_length=0.1, fc='k', ec='k')
ax.arrow(-0.5, 0, 4, 0, head_width=0.1, head_length=0.1, fc='k', ec='k')

## Plot dataset
## ------------
gt_points = ax.plot(samples_gt[:, 0], samples_gt[:, 1], '.', color=color1, label='Dataset')[0]

## Save: samples.png
## -----------------
fig.canvas.draw()
fig.savefig('./output/samples.png', pad_inches=0, transparent=True)

## Plot initial model
## ------------------
eng_img = ax.imshow(energy,
                    extent=[1e-10, 4, 1e-10, 4],
                    origin='lower',
                    cmap=cmap,
                    vmin=energy.min(),
                    vmax=energy.min() + 8,
                    alpha=0.9,
                    interpolation='nearest',
                    label='Model'
                    )

## Save: samples_and_model.png
## ---------------------------
fig.canvas.draw()
fig.savefig('./output/samples_and_model.png', pad_inches=0, transparent=True)

## Generate animate: train.png
## ---------------------------
gif_writer = GIFWrite(fig)
for step, t in enumerate(tqdm(np.linspace(0, 1, 50))):
    mu_t = t * mu_gt + (1 - t) * mu
    sigma_t = t * sigma_gt + (1 - t) * sigma
    energy_t = energy_func(grid.reshape(-1, 2), mu_t, sigma_t).reshape(grid.shape[:2])
    eng_img.set_data(energy_t)
    gif_writer.append()
gif_writer.save('./output/train.png', duration=2)

## Plot model's points
## -------------------
eng_img.set_data(energy)
samples_points = ax.plot(samples[:, 0], samples[:, 1], '.', color=color3)[0]

## Save: contrastive_samples.png
## -----------------------------
fig.canvas.draw()
fig.savefig('./output/contrastive_samples.png', pad_inches=0, transparent=True)
samples_points.remove()

## Generate animate: mcmc.png
## --------------------------
mcmc_line = ax.plot([], [], '-', color=color3, lw=1)[0]
samples_points = ax.plot([], [], '.', color=color3)[0]
gif_writer = GIFWrite(fig)
for t in tqdm(range(0, steps_mcmc.shape[0], 3)):
    mcmc_line.set_data(steps_mcmc[:(t + 1), 0, 0], steps_mcmc[:(t + 1), 0, 1])
    gif_writer.append()
samples_points.set_data(steps_mcmc[t, 0, 0], steps_mcmc[t, 0, 1])
gif_writer.append()
gif_writer.save('./output/mcmc.png', duration=2)
samples_points.remove()
mcmc_line.remove()

## Generate animate: cd.png
## ------------------------
n = 20
cd_lines = ax.plot(np.zeros((0, n)), np.zeros((0, n)), '-', color=color2, lw=1, zorder=-1)
gif_writer = GIFWrite(fig)
for t in tqdm(range(steps_cd.shape[0])):
    for i, line in enumerate(cd_lines):
        line.set_data(steps_cd[:(t + 1), i, 0], steps_cd[:(t + 1), i, 1])
    gif_writer.append()
cd_points = ax.plot(steps_cd[t, :n, 0], steps_cd[t, :n, 1], '.', color=color2)[0]
gif_writer.append()
gif_writer.save('./output/cd.png', duration=2)
# fig.savefig('./output/cd.png', pad_inches=0, transparent=True)

## Save: cd_static.png
## -----------------
fig.canvas.draw()
fig.savefig('./output/cd_static.png', pad_inches=0, transparent=True)

## Save: cd_samples.png
## --------------------
for line in cd_lines:
    line.remove()
fig.canvas.draw()
fig.savefig('./output/cd_samples.png', pad_inches=0, transparent=True)

## Save: all_samples.png
## ---------------------
samples_points = ax.plot(samples[:, 0], samples[:, 1], '.', color=color3)[0]
fig.canvas.draw()
fig.savefig('./output/all_samples.png', pad_inches=0, transparent=True)
samples_points.remove()
cd_points.remove()

## Generate animate: single_cd.png
## -------------------------------
gt_points.set_alpha(0.2)
cd_line = ax.plot([], [], '.-', color=color2, lw=1)[0]
gt_point = ax.plot(samples_gt[0, 0], samples_gt[0, 1], '.', color=color1, label='Dataset')[0]
gif_writer = GIFWrite(fig)
for t in tqdm(range(steps_cd.shape[0])):
    cd_line.set_data(steps_cd[:(t + 1), 0, 0], steps_cd[:(t + 1), 0, 1])
    gif_writer.append()
# cd_points = ax.plot(steps_cd[t, 0, 0], steps_cd[t, 0, 1], '.', color='#ff7f0e')[0]
gif_writer.append()
gif_writer.save('./output/single_cd.png', duration=2)
# fig.savefig('./output/single_cd.png', pad_inches=0, transparent=True)

# plt.close(fig)
# display(Image(open('./output/samples.png','rb').read()))


  0%|          | 0/50 [00:00<?, ?it/s]

40


  0%|          | 0/67 [00:00<?, ?it/s]

29


  0%|          | 0/6 [00:00<?, ?it/s]

285


  0%|          | 0/6 [00:00<?, ?it/s]

285


## Plot single CD zoom

In [25]:
fig, ax = plt.subplots(figsize=(1.5, 2.25), dpi=200)
ax.axis('off')
ax.axis('equal')
plt.tight_layout()
cd_line = ax.plot(steps_cd[:, 0, 0], steps_cd[:, 0, 1], '.-', color='#ff7f0e', lw=1)[0]
gt_points = ax.plot(samples_gt[0, 0], samples_gt[0, 1], '.', color='#1f77b4', label='Dataset')[0]
steps_texts = []
for i in range(steps_cd.shape[0]):
     steps_texts.append(ax.text(steps_cd[i, 0, 0] + 0.02, steps_cd[i, 0, 1], f'$x^{{({i})}}$', va='center'))

## Save: single_cd_zoom.png
## ------------------------
fig.savefig('./output/single_cd_zoom.png', pad_inches=0, transparent=True)

## Save: single_cd_zoom.png
## ------------------------
for steps_text in steps_texts:
    steps_text.remove()
fig.savefig('./output/single_cd_original.png', pad_inches=0, transparent=True)

## Save: single_cd_revesed_zoom.png
## ------------------------
gt_points.set_data(steps_cd[-1, 0, 0], steps_cd[-1, 0, 1])
fig.savefig('./output/single_cd_revesed_zoom.png', pad_inches=0, transparent=True)

## Save: single_cd_clean.png
## -------------------------
gt_points.remove()
fig.savefig('./output/single_cd_clean.png', pad_inches=0, transparent=True)

$$
\nabla_{\boldsymbol{\theta}}\mathcal{O}=
\mathbb{E}\left[\nabla_{\boldsymbol{\theta}} E(\tilde{\mathbf{x}};\boldsymbol{\theta})
-\nabla_{\boldsymbol{\theta}} E(\mathbf{x};\boldsymbol{\theta})\right]
+\frac{\partial D_{\text{KL}}(p_{\text{CD}}||p_{\text{Model}})}{\partial p_{\text{CD}}}\nabla_{\boldsymbol{\theta}}p_{\text{CD}}
$$

$$
\boldsymbol{\theta}^{(t+1)}
=\boldsymbol{\theta}^{(t)}+\eta\mathbb{E}\left[
\alpha(\mathbf{x}^{(0)},\mathbf{x}^{(1)},\dots\mathbf{x}^{(5)})
\left(
\nabla_{\boldsymbol{\theta}} E(\mathbf{x}^{(5)};\boldsymbol{\theta})
-\nabla_{\boldsymbol{\theta}} E(\mathbf{x}^{(0)};\boldsymbol{\theta})
\right)\right]
$$

$$
\alpha(\mathbf{x}^{(0)},\mathbf{x}^{(1)},\dots\mathbf{x}^{(5)})=\left(1+\frac
{p(\mathbf{x}^{(0)}|\mathbf{x}^{(1)})\cdots p(\mathbf{x}^{(3)}|\mathbf{x}^{(4)})p(\mathbf{x}^{(5)};\boldsymbol{\theta})}
{p(\mathbf{x}^{(0)}|\mathbf{x}^{(1)})\cdots p(\mathbf{x}^{(1)}|\mathbf{x}^{(0)})p(\mathbf{x}^{(0)};\boldsymbol{\theta})}
\right)^{-1}
$$

In [23]:
colors_hex = """1F77B4
FF7F0E
2CA02C
D62728
9467BD
8C564B
CFECF9
7F7F7F
BCBD22
17BECF"""

for i, color_hex in enumerate(colors_hex.split('\n')):
    color_rgb = np.array([int(color_hex[0:2], 16),
                          int(color_hex[2:4], 16),
                          int(color_hex[4:6], 16)]) / 255
    print(f'\\\\definecolor{{color{i + 1}}}{{rgb}}{{{color_rgb[0]:.2f},{color_rgb[1]:.2f},{color_rgb[2]:.2f}}}')


\\definecolor{color1}{rgb}{0.12,0.47,0.71}
\\definecolor{color2}{rgb}{1.00,0.50,0.05}
\\definecolor{color3}{rgb}{0.17,0.63,0.17}
\\definecolor{color4}{rgb}{0.84,0.15,0.16}
\\definecolor{color5}{rgb}{0.58,0.40,0.74}
\\definecolor{color6}{rgb}{0.55,0.34,0.29}
\\definecolor{color7}{rgb}{0.81,0.93,0.98}
\\definecolor{color8}{rgb}{0.50,0.50,0.50}
\\definecolor{color9}{rgb}{0.74,0.74,0.13}
\\definecolor{color10}{rgb}{0.09,0.75,0.81}


In [36]:
txt = '→➤➜➱➨➽➙➫'
for ch in txt:
    print(ch + str(ch.encode("unicode_escape")))

→b'\\u2192'
➤b'\\u27a4'
➜b'\\u279c'
➱b'\\u27b1'
➨b'\\u27a8'
➽b'\\u27bd'
➙b'\\u2799'
➫b'\\u27ab'


In [37]:
import urllib.parse
urllib.parse.quote_plus(txt)

'%E2%86%92%E2%9E%A4%E2%9E%9C%E2%9E%B1%E2%9E%A8%E2%9E%BD%E2%9E%99%E2%9E%AB'