In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()

from collections import defaultdict

In [None]:
# Beta schedules
from source.diff_util import cosine_beta_schedule, get_named_beta_schedule, log_1_min_a

num_timesteps = 1000

# Per-step alphas derived from the named 'cosine' schedule.
# NOTE(review): assumes get_named_beta_schedule returns per-step betas, so
# alpha_t = 1 - beta_t -- confirm against source.diff_util.
# A cosine_beta_schedule(num_timesteps) result used to be assigned to `alphas`
# here but was overwritten before any use, so that dead store is dropped.
alphas = 1 - get_named_beta_schedule('cosine', num_timesteps)

fig, ax = plt.subplots()
ax.plot(alphas)
ax.set(title="alphas ('cosine' schedule)", xlabel='timestep', ylabel='alpha')
plt.show()

In [None]:
import torch
# torch.as_tensor returns its input unchanged when `alphas` is already a tensor,
# so re-running this cell does not trigger the copy-construct warning that
# torch.tensor(<tensor>) emits.
alphas = torch.as_tensor(alphas)

# Use torch ops directly instead of NumPy functions on tensors; the latter only
# works through NumPy's __array_wrap__ round-trip, which is version-sensitive.
log_alpha = torch.log(alphas)
# NOTE(review): assumes `alphas` is 1-D (one entry per timestep); np.cumsum with
# its default axis flattened the input, whereas dim=0 accumulates along the
# first axis only.
log_cumprod_alpha = torch.cumsum(log_alpha, dim=0)
log_1_min_alpha = log_1_min_a(log_alpha)
log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha)

fig, axes = plt.subplots(2, 2, figsize=(8, 8))

# Panel order follows axes.flat: top-left, top-right, bottom-left, bottom-right.
panels = {
    'log_alpha': log_alpha,
    'log_cumprod_alpha': log_cumprod_alpha,
    'log_1_min_alpha': log_1_min_alpha,
    'log_1_min_cumprod_alpha': log_1_min_cumprod_alpha,
}
for ax, (title, values) in zip(axes.flat, panels.items()):
    ax.plot(values)
    ax.set_title(title)

fig.tight_layout()
plt.show()

In [None]:
# Number of categories and the timestep index inspected by this cell.
num_classes, t = 10, 500

# One-hot vector for class 0. Zeros are floored at 1e-30 before the log so the
# result stays finite instead of becoming -inf.
x_start = np.eye(num_classes)[0]
log_x_start = np.log(np.maximum(x_start, 1e-30))

def log_add_exp(a, b):
    """Return ``log(exp(a) + exp(b))`` element-wise, evaluated stably in log space.

    Parameters
    ----------
    a, b : array_like
        Log-domain values. Scalars and arrays are accepted on either side and
        are broadcast against each other following NumPy rules.

    Returns
    -------
    numpy.ndarray or numpy scalar
        Element-wise log-sum-exp of ``a`` and ``b``.
    """
    # np.logaddexp handles the max-shift internally, does not modify its inputs,
    # and works when `b` is an array (the previous masked assignment
    # `maximum[a < b] = b` only worked for a scalar `b`).
    return np.logaddexp(a, b)

# Log-space mixture at step t of the one-hot start (weighted by the cumulative
# alpha term) and a uniform distribution over the classes.
# NOTE(review): presumably log_1_min_cumprod_alpha holds log(1 - prod(alpha));
# confirm against log_1_min_a in source.diff_util.
log_cumprod_alpha_t = log_cumprod_alpha[t].numpy()
log_1_min_cumprod_alpha_t = log_1_min_cumprod_alpha[t].numpy()

log_probs = log_add_exp(
    log_x_start + log_cumprod_alpha_t,
    log_1_min_cumprod_alpha_t - np.log(num_classes)
)

logits = log_probs

# Gumbel-max sampling: argmax(logits + Gumbel noise) picks a class index.
# All draws are made in one vectorised call from a seeded generator so the
# histogram is reproducible across kernel restarts.
num_samples = 1000
rng = np.random.default_rng(0)
uniform = rng.random((num_samples, *logits.shape))
# The two 1e-30 offsets keep both logs finite when a uniform draw hits 0.
gumbel_noise = -np.log(-np.log(uniform + 1e-30) + 1e-30)
data = (gumbel_noise + logits).argmax(axis=-1)

# One unit-wide bin per class index between the smallest and largest sample.
ax = sns.histplot(data, bins=np.arange(data.min(), data.max() + 2), kde=False)
ax.set(xlabel='sampled class', ylabel='count', title=f'Gumbel-max samples at t={t}')
plt.show()

In [None]:
# Probability assigned to the starting class (index 0) at every timestep.

# The one-hot start and its log do not depend on the timestep, so build them once.
x_start = np.zeros(num_classes)
x_start[0] += 1
log_x_start = np.log(x_start.clip(min=1e-30))
log_num_classes = np.log(num_classes)

data = []
# Iterate over the schedule's own length rather than a hard-coded 1000, and use
# dedicated loop names so the `t` / `log_probs` set by other cells are not
# silently overwritten.
for step in range(len(log_cumprod_alpha)):
    step_log_probs = log_add_exp(
        log_x_start + log_cumprod_alpha[step].numpy(),
        log_1_min_cumprod_alpha[step].numpy() - log_num_classes
    )
    data.append(step_log_probs[0])

data = np.array(data)

fig, ax = plt.subplots()
ax.plot(np.exp(data))
ax.set(xlabel='timestep', ylabel='probability of class 0',
       title='Probability of the starting class over time')
plt.show()

In [None]:
# Same setup as before, but inspecting a later timestep index.
num_classes = 10
t = 800

# One-hot vector for class 0, built from an index comparison; zero entries are
# floored at 1e-30 so their log is finite.
x_start = (np.arange(num_classes) == 0).astype(float)
log_x_start = np.log(np.clip(x_start, 1e-30, None))

def log_add_exp(a, b):
    """Return ``log(exp(a) + exp(b))`` element-wise without leaving log space.

    Parameters
    ----------
    a, b : array_like
        Log-domain values; scalars or arrays, broadcast under NumPy rules.

    Returns
    -------
    numpy.ndarray or numpy scalar
        Element-wise log-sum-exp of ``a`` and ``b``.
    """
    # Delegating to np.logaddexp keeps the computation numerically stable,
    # leaves the inputs untouched, and supports an array-valued `b`; the earlier
    # `maximum[a < b] = b` assignment raised a shape error in that case.
    return np.logaddexp(a, b)

# Log-space mixture at step t of the one-hot start (weighted by the cumulative
# alpha term) and a uniform distribution over the classes.
# NOTE(review): presumably log_1_min_cumprod_alpha holds log(1 - prod(alpha));
# confirm against log_1_min_a in source.diff_util.
log_cumprod_alpha_t = log_cumprod_alpha[t].numpy()
log_1_min_cumprod_alpha_t = log_1_min_cumprod_alpha[t].numpy()

log_probs = log_add_exp(
    log_x_start + log_cumprod_alpha_t,
    log_1_min_cumprod_alpha_t - np.log(num_classes)
)

logits = log_probs

# Weight the class probabilities by flat Dirichlet draws and take the argmax.
# All draws come from one vectorised call on a seeded generator so the
# histogram is reproducible across kernel restarts.
num_samples = 1000
rng = np.random.default_rng(0)
dist = rng.dirichlet(np.ones_like(logits), size=num_samples)  # (num_samples, num_classes)

weighted = dist * np.exp(logits)
# Row-wise renormalisation and the log are both monotone within a row, so they
# do not change which class wins the argmax; they are kept to mirror the
# log-probability form used elsewhere in this notebook.
weighted /= weighted.sum(axis=-1, keepdims=True)
weighted = np.log(weighted)
data = weighted.argmax(axis=-1)

# One unit-wide bin per class index between the smallest and largest sample.
ax = sns.histplot(data, bins=np.arange(data.min(), data.max() + 2), kde=False)
ax.set(xlabel='selected class', ylabel='count',
       title=f'Argmax of Dirichlet-weighted probabilities at t={t}')
plt.show()

In [None]:
# Side-by-side look at one draw of exponentiated Gumbel noise and one flat
# Dirichlet sample, each over 10 entries.
num_bars = 10
rng = np.random.default_rng(0)  # seeded so both figures are reproducible
positions = np.arange(num_bars)

uniform = rng.random(num_bars)
# The two 1e-30 offsets keep both logs finite when a uniform draw hits 0.
gumbel_noise = -np.log(-np.log(uniform + 1e-30) + 1e-30)

# Pass x and y explicitly: how seaborn interprets a bare vector handed to
# barplot positionally differs between releases, whereas explicit coordinates
# always give one bar per entry.
ax = sns.barplot(x=positions, y=np.exp(gumbel_noise))
ax.set(xlabel='index', ylabel='exp(gumbel noise)', title='Exponentiated Gumbel noise')
plt.show()

dirichlet = rng.dirichlet(np.ones(num_bars))

ax = sns.barplot(x=positions, y=dirichlet)
ax.set(xlabel='index', ylabel='weight', title='Flat Dirichlet sample')
plt.show()

In [None]:
import torch

# A (2, 2) concentration of ones defines two independent flat Dirichlet
# distributions over two categories; one sample is a (2, 2) tensor whose rows
# each sum to 1.
concentration = torch.ones((2, 2))
flat_dirichlet = torch.distributions.Dirichlet(concentration)
dist = flat_dirichlet.sample()
print(dist)