In [5]:
%env CUDA_VISIBLE_DEVICES=0
%load_ext autoreload
%autoreload 2
import os

import numpy as np
from matplotlib import pyplot as plt

from d3exp.beta_schedules import get_beta_schedule
from d3exp.config.defaults import defaults_dict
from d3exp.diffusion import jax_float64_context

env: CUDA_VISIBLE_DEVICES=0
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
# Build {beta_schedule_type: betas} for every default config. Evaluation runs
# inside `jax_float64_context()` (presumably enabling JAX float64 precision).
# NOTE(review): configs are keyed by their schedule type, so if two configs
# share a `beta_schedule_type` the one iterated last silently wins.
with jax_float64_context():
    beta_schedules = dict(
        (
            cfg.beta_schedule_type,
            np.asarray(get_beta_schedule(cfg).get_betas(cfg.num_timesteps)),
        )
        for cfg in defaults_dict.values()
    )

In [11]:
# Shared x-axis used by later cells; sized from the 'absorbing' config only.
steps = np.arange(defaults_dict['absorbing'].num_timesteps)


def plot_figure(betas, saveto, steps=None):
    """Save a small log-scale plot of a single beta schedule to ``saveto``.

    Args:
        betas: per-step beta values (array-like; presumably 1-D).
        saveto: output file path. Its parent directory is created if missing.
        steps: optional x-axis values. Defaults to ``np.arange(len(betas))`` so
            each schedule is plotted against its own length instead of the
            global ``steps``, which only matches the 'absorbing' config's
            ``num_timesteps`` and would make ``plot`` raise on a mismatch.
    """
    betas = np.asarray(betas)
    if steps is None:
        steps = np.arange(len(betas))
    out_dir = os.path.dirname(saveto)
    if out_dir:
        # savefig does not create directories; avoid failing on a fresh checkout.
        os.makedirs(out_dir, exist_ok=True)
    fig, ax = plt.subplots(figsize=(2, 2), dpi=600)
    try:
        ax.plot(steps, betas)
        ax.set_yscale('log')
        ax.set_xlabel('Step')
        ax.set_ylabel(r'$\beta$')
        fig.savefig(saveto, bbox_inches='tight', pad_inches=0)
    finally:
        # Release the figure even if plotting/saving fails, so a loop over many
        # schedules does not accumulate open figures.
        plt.close(fig)


for key, betas in beta_schedules.items():
    plot_figure(betas, f'../results/beta_{key}.pdf')

In [24]:
# Overlay every schedule on one linear-scale axis and save it.
# The 'step' schedule is divided by STEP_SCALE (presumably so it fits on the
# same axis as the others). Its legend label states the rescaling; otherwise
# the plotted curve would be misread as the raw 'step' betas.
STEP_SCALE = 5
fig, ax = plt.subplots(figsize=(3, 2), dpi=600)
for key, betas in beta_schedules.items():
    label = key
    if key == 'step':
        betas = betas / STEP_SCALE
        label = f'{key} / {STEP_SCALE}'
    # 'jsd' is drawn dashed; every other schedule uses the default line style.
    style = {'linestyle': '--'} if key == 'jsd' else {}
    # Each schedule is plotted against its own length, so configs with
    # different num_timesteps do not make `plot` raise.
    ax.plot(np.arange(len(betas)), betas, label=label, alpha=0.7, **style)
ax.set_xlabel('Step')
ax.set_ylabel(r'$\beta$')
ax.legend()
fig.savefig('../results/beta_schedules.pdf', bbox_inches='tight', pad_inches=0)
plt.close(fig)

In [30]:
# Log-scale comparison of the 'cosine' and 'jsd' schedules.
fig, ax = plt.subplots(figsize=(3, 2), dpi=600)
for key, style in (
    # Invisible 'linear' line: it only advances the color cycle by one
    # ("keep color"), presumably so 'cosine'/'jsd' get the same colors as in
    # the overlay figure. Verify if the schedule order changes.
    ('linear', {'alpha': 0}),
    ('cosine', {'label': 'cosine', 'alpha': 0.7}),
    ('jsd', {'label': 'jsd', 'alpha': 0.7, 'linestyle': '--'}),
):
    betas = beta_schedules[key]
    ax.plot(np.arange(len(betas)), betas, **style)
ax.set_yscale('log')
ax.set_xlabel('Step')
# The axis is log-scaled but its tick values are beta itself, so the label is
# beta; "log(beta)" would wrongly imply the values were log-transformed.
ax.set_ylabel(r'$\beta$')
ax.legend()
fig.savefig('../results/beta_schedules_logscale.pdf', bbox_inches='tight', pad_inches=0)
plt.close(fig)