# (C-)DUSVGD for Sampling from a Mixture of Gaussian Distributions


## Settings


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import Adam, SGD, RMSprop
from torch.distributions import Normal, Categorical
from torch.distributions.mixture_same_family import MixtureSameFamily

import numpy as np
import math
import random
import seaborn as sns
import matplotlib.pyplot as plt

# Print up to 1000 leading/trailing entries per dimension when printing tensors.
torch.set_printoptions(edgeitems=1000)
# Seaborn "white" style for every figure drawn below.
sns.set(style="white")

def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) and request determinism.

    cuDNN is switched to deterministic kernels, and PyTorch is asked to prefer
    deterministic algorithms, emitting a warning (not an error) for ops that lack one.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(True, warn_only=True)


seed = 42
set_seed(seed)

use_cuda = torch.cuda.is_available()
print(use_cuda)

# Run on the GPU when present; new float tensors are then created there by default.
device = torch.device("cuda" if use_cuda else "cpu")
torch.set_default_tensor_type("torch.cuda.FloatTensor" if use_cuda else "torch.FloatTensor")
print("device =", device)

## Gaussian Mixture Model

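$$p(x) = a_1\,\mathcal{N}(x;\mu_1,\sigma_1^2) + a_2\,\mathcal{N}(x;\mu_2,\sigma_2^2)$$

with $a_1 = 0.75$, $a_2 = 0.25$, $\mu_1 = -2$, $\mu_2 = 2.5$ and $\sigma_1 = \sigma_2 = 1$.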


In [None]:
class GMM:
    """One-dimensional Gaussian mixture used as the SVGD target distribution.

    The defaults reproduce the two-component mixture studied in this notebook
    (weights 0.75/0.25, means -2/2.5, unit standard deviations). Other mixtures
    can be built by passing the component parameters explicitly.

    Args:
        weights: mixture weights, one per component (normalised by ``Categorical``).
        means: component means.
        stds: component standard deviations.
    """

    def __init__(self, weights=(0.75, 0.25), means=(-2.0, 2.5), stds=(1.0, 1.0)):
        self.weights = torch.tensor(weights)
        self.means = torch.tensor(means)
        self.stds = torch.tensor(stds)

    def model(self):
        """Return a ``MixtureSameFamily`` distribution built from the current parameters."""
        mix = Categorical(self.weights)
        comp = Normal(self.means, self.stds)
        return MixtureSameFamily(mix, comp)

    def print_parameters(self):
        """Print the mixture weights, means and standard deviations."""
        print('Target Distribution (GMM): weights =', self.weights.tolist(), 
              ', means =', self.means.tolist(), ', stds =', self.stds.tolist())


model_gmm = GMM()
target_model = model_gmm.model()

## Hyperparameters


In [None]:
# Hyperparameters for SVGD and C-DUSVGD
num_particles = 100         # Particles per particle set (second data dimension)
data_size = 1000            # Number of particle sets generated (first data dimension)
train_batch_size = 50       # Particle sets per training batch
test_batch_size = 100       # Particle sets per test batch
num_epochs = 10             # Epochs per unroll depth in DUSVGD incremental training
num_epochs_c = 40           # Epochs of C-DUSVGD training
max_du_iterations = 10      # Unrolled SVGD iterations (one learnable DUSVGD step size each)
lr_adam = 0.005             # Adam learning rate for the DUSVGD step sizes
lr_adam_c = 0.0005          # Adam learning rate for the C-DUSVGD schedule parameters
init_params = 2.0           # Initial value of every learnable DUSVGD step size
init_params_c = torch.tensor([0.3, 1.0])  # Initial (a, b) of the C-DUSVGD Chebyshev step schedule
init_dist_mean = -2.0       # Mean of the Normal the initial particles are drawn from
init_dist_std = 1.0         # Standard deviation of that Normal


def print_hyperparameters():
    """Print the hyperparameters above, followed by the target GMM's parameters."""
    labelled_values = (
        ('num_particles:\t', num_particles),
        ('data_size:\t', data_size),
        ('train_batch_size:\t', train_batch_size),
        ('test_batch_size:\t', test_batch_size),
        ('num_epochs:\t', num_epochs),
        ('num_epochs_c:\t', num_epochs_c),
        ('max_du_iterations:\t', max_du_iterations),
        ('lr_adam:\t', lr_adam),
        ('lr_adam_c:\t', lr_adam_c),
        ('init_params:\t', init_params),
        ('init_params_c:\t', init_params_c.tolist()),
    )
    for label, value in labelled_values:
        print(label, value)
    print('init_distribution (Normal): mean =', init_dist_mean, ', std =', init_dist_std)
    model_gmm.print_parameters()


print_hyperparameters()

## Generating Datasets


In [None]:
# Row r of x_data is one set of `num_particles` initial particles drawn from the
# initial Normal; row r of t_data is a set of the same size sampled from the target GMM.
x_data = torch.normal(init_dist_mean, init_dist_std, size=(data_size, num_particles)).to(device)
t_data = target_model.sample([data_size, num_particles]).to(device)

# 90/10 train/test split of the rows; each split uses its own fresh generator on `device`.
train_size = int(0.9 * data_size)
test_size = data_size - train_size
split_sizes = [train_size, test_size]
x_train, x_test = torch.utils.data.random_split(x_data, split_sizes, generator=torch.Generator(device))
t_train, t_test = torch.utils.data.random_split(t_data, split_sizes, generator=torch.Generator(device))

class CustomDataset:
    """Map-style dataset pairing item ``i`` of ``X`` with item ``i`` of ``t``.

    Its length is the length of ``X``; indexing returns the tuple ``(X[i], t[i])``.
    """

    def __init__(self, X, t):
        self.X, self.t = X, t

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        pair = (self.X[index], self.t[index])
        return pair

def _make_loader(dataset, batch_size):
    """Return a shuffling DataLoader whose sampler uses a fresh generator on `device`."""
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, generator=torch.Generator(device))


train_dataset = CustomDataset(x_train, t_train)
train_dataloader = _make_loader(train_dataset, train_batch_size)

test_dataset = CustomDataset(x_test, t_test)
test_dataloader = _make_loader(test_dataset, test_batch_size)


## SVGD


In [None]:
def median(tensor):
    """Return the median of all non-batch entries of ``tensor``, shaped ``(bs, 1, 1)``.

    ``torch.median`` returns the lower of the two middle values when the count is
    even. Appending each row's maximum shifts the median up by one position.
    Averaging that shifted median with the plain one gives the conventional median
    (the mean of the two middle values); for odd counts both medians coincide.
    """
    flat = tensor.clone().flatten(1)
    row_max = flat.max(1).values.unsqueeze(1)
    upper_mid = torch.cat((flat, row_max), 1).median(1).values
    lower_mid = flat.median(1).values
    return ((upper_mid + lower_mid) / 2.0).view(-1, 1, 1)

def get_gradient(model_gmm, inputs, alpha, beta, retain_graph):
    """Return the negated SVGD update direction for a batch of particle sets.

    Args:
        model_gmm: object whose ``model()`` returns a torch distribution; its
            ``log_prob`` supplies the score ``grad log p``.
        inputs: tensor of shape ``(bs, n, d)``, i.e. ``bs`` particle sets of ``n`` particles.
        alpha: weight of the driving term (kernel-weighted scores).
        beta: weight of the repulsive term (kernel gradients).
        retain_graph: forwarded to ``torch.autograd.grad``. Pass ``True`` when the caller
            will later backpropagate through a graph shared with ``inputs``.

    Returns:
        Tensor shaped like ``inputs`` equal to
        ``-(alpha * K @ score + beta * sum_j grad_{x_j} k(x_j, x_i)) / n``.
        The kernel is ``k(x, y) = exp(-||x - y||^2 / h)`` with ``h`` from the median
        heuristic, so ``x - lr * get_gradient(...)`` is an SVGD step.
    """
    inputs = inputs.clone().requires_grad_(True)
    _, n, _ = inputs.shape

    log_prob = model_gmm.model().log_prob(inputs)
    # NOTE(review): create_graph is not set, so the returned score has no autograd history.
    # When step sizes are trained through unrolled iterations, the score term therefore acts
    # as a constant; only the kernel terms carry gradient back to earlier iterates.
    log_prob_grad = torch.autograd.grad(log_prob.sum(), inputs, retain_graph=retain_graph)[0]

    # Squared pairwise distances within each particle set, shape (bs, n, n).
    # cdist already returns its result on inputs' device.
    pairwise_distance = torch.cdist(inputs, inputs, p=2.0).pow(2)
    # Median-heuristic bandwidth. log(max(n, 2)) avoids dividing by log(1) = 0 for a
    # single particle. A zero median (single or fully collapsed particles) would make
    # h = 0 and the kernel NaN via 0/0, so fall back to h = 1 in that case.
    h = median(pairwise_distance) / math.log(max(n, 2))
    h = torch.where(h > 0, h, torch.ones_like(h))
    kernel = torch.exp(-pairwise_distance / h)
    # sum_j grad_{x_j} k(x_j, x_i) = (2 / h) * (x_i * sum_j k_ij - sum_j k_ij * x_j)
    kernel_grad = 2 * (kernel.sum(2).diag_embed() - kernel).matmul(inputs) / h

    gradient = -(alpha * torch.matmul(kernel, log_prob_grad) + beta * kernel_grad) / n
    return gradient

# DUSVGD

In [None]:
class DUGD(nn.Module):
    """DUSVGD: unrolled SVGD with one learnable step size per iteration.

    Args:
        itr: number of learnable step sizes. All are initialised to ``init_params``.
    """

    def __init__(self, itr):
        super().__init__()
        self.gamma = nn.Parameter(init_params * torch.ones(itr))

    def forward(self, iteration, x_data):
        """Apply ``iteration`` SVGD updates to ``x_data``.

        Iteration ``i`` uses step size ``|gamma[i % len(gamma)]|``.

        Returns:
            Tuple ``(updated particles, gamma)``.
        """
        # Cycle by the parameter's own length rather than the global max_du_iterations,
        # so a model built with any ``itr`` indexes only its own step sizes.
        num_steps = self.gamma.numel()
        for i in range(iteration):
            j = i % num_steps
            # abs() keeps the applied step size non-negative whatever sign gamma takes.
            x_data = x_data - self.gamma[j].abs() * get_gradient(model_gmm, x_data, 1.0, 1.0, retain_graph=True)

        return x_data, self.gamma


model_dugd = DUGD(max_du_iterations).to(device)
opt_dugd = optim.Adam(model_dugd.parameters(), lr=lr_adam)

# C-DUSVGD

In [None]:
def chebyshev_step(a, b, t, T):
    """Return the ``t``-th step size of a ``T``-step Chebyshev schedule.

    The step is ``1 / lambda`` for a Chebyshev node ``lambda`` of the interval
    ``[a**2, a**2 + b**2]``. The interval is valid for any real ``a``, ``b``.
    Node index ``k = T - t`` is used, so ``t = 0`` gives the node nearest ``a**2``
    (the largest step), and the nodes increase with ``t``.
    """
    lower = a ** 2
    upper = a ** 2 + b ** 2
    centre = (lower + upper) / 2
    half_width = (upper - lower) / 2
    angle = (2 * (T - t) - 1) * torch.pi / (2 * T)
    return 1 / (centre + half_width * math.cos(angle))

class CDUGD(nn.Module):
    """C-DUSVGD: unrolled SVGD whose step sizes follow a learnable Chebyshev schedule.

    ``gamma = (a, b)`` parameterises ``chebyshev_step``. Only these two scalars are learned.
    """

    def __init__(self):
        super().__init__()
        # clone(): nn.Parameter wraps the given tensor without copying it. Wrapping the
        # global init_params_c directly would let the optimizer overwrite it in place,
        # and print_hyperparameters() would then report trained values as the initial ones.
        self.gamma = nn.Parameter(init_params_c.clone())

    def forward(self, iteration, x_data):
        """Apply ``iteration`` SVGD updates with Chebyshev-schedule step sizes.

        Returns:
            Tuple ``(updated particles, gamma)``.
        """
        for i in range(iteration):
            lr_c = chebyshev_step(self.gamma[0], self.gamma[1], i, max_du_iterations)
            x_data = x_data - lr_c * get_gradient(model_gmm, x_data, 1.0, 1.0, retain_graph=True)

        return x_data, self.gamma


model_cdugd = CDUGD().to(device)
opt_cdugd = optim.Adam(model_cdugd.parameters(), lr=lr_adam_c)


## Training


In [None]:
class MMDLoss(nn.Module):
    """Squared MMD with a Gaussian kernel, averaged over a batch of particle sets.

    Uses the (biased) V-statistic

        MMD^2 = mean_{i,i'} k(x_i, x_i') + mean_{j,j'} k(y_j, y_j') - 2 mean_{i,j} k(x_i, y_j)

    with ``k(a, b) = exp(-||a - b||^2 / (2 sigma^2))``. This estimate is always >= 0,
    so its log10 can be plotted.

    The previous version summed the diagonal ``k(x_i, x_i)`` terms but normalised by
    ``n (n - 1)`` (the unbiased-estimator denominator). That added a spurious offset of
    about ``1/(n-1) + 1/(m-1)`` to every value.

    Args:
        sigma: Gaussian kernel bandwidth (default 1, matching the original kernel).
    """

    def __init__(self, sigma=1.0):
        super().__init__()
        self.sigma = sigma

    def forward(self, outputs, targets):
        """Return the batch-mean squared MMD.

        Args:
            outputs: tensor ``(bs, n, d)``.
            targets: tensor ``(bs, m, d)``.
        """
        x, y = outputs, targets
        bs, n, _ = x.shape
        _, m, _ = y.shape

        # Both inputs must already share a device; torch.cat enforces that.
        xy = torch.cat([x, y], dim=1)
        dists = torch.cdist(xy, xy, p=2.0)
        k = torch.exp(-0.5 * dists**2 / self.sigma**2)

        k_x = k[:, :n, :n]
        k_y = k[:, n:, n:]
        k_xy = k[:, :n, n:]
        mmd = (k_x.sum() / (n * n) + k_y.sum() / (m * m) - 2 * k_xy.sum() / (n * m)) / bs
        return mmd


loss_func = MMDLoss()

def incremental_training():
    """Train the DUSVGD step sizes incrementally over the unroll depth.

    For each depth ``k = 1 .. max_du_iterations``, run ``num_epochs`` epochs over
    ``train_dataloader``. Each batch unrolls ``k`` SVGD iterations and takes one Adam
    step on the MMD to the target samples. After each epoch, print the step sizes
    and the loss of the epoch's last batch.
    """
    for depth in range(1, max_du_iterations + 1):
        print()
        for epoch in range(1, num_epochs + 1):
            for batch in train_dataloader:
                # Enables autograd anomaly detection globally (slows backward);
                # presumably a debugging aid.
                torch.autograd.set_detect_anomaly(True)
                x, t = (tensor.unsqueeze(2).to(device) for tensor in batch)

                opt_dugd.zero_grad()
                x_hat, gamma = model_dugd(depth, x)
                loss = loss_func(x_hat, t)
                loss.backward()
                opt_dugd.step()

            params = ', '.join(f'{g:.2f}' for g in gamma.tolist())
            print(f"\ri: {depth}\te: {epoch}\tparams: {params}\tloss: {loss.item():.3f}", end="    ")

def static_training():
    """Train the two C-DUSVGD schedule parameters at a fixed unroll depth.

    Run ``num_epochs_c`` epochs over ``train_dataloader``. Each batch unrolls
    ``max_du_iterations`` SVGD iterations and takes one Adam step on the MMD.
    After each epoch, print the parameters and last-batch loss on one refreshed
    line, and start a new line every 10 epochs.
    """
    for epoch in range(1, num_epochs_c + 1):
        for batch in train_dataloader:
            # Enables autograd anomaly detection globally (slows backward);
            # presumably a debugging aid.
            torch.autograd.set_detect_anomaly(True)
            x, t = (tensor.unsqueeze(2).to(device) for tensor in batch)

            opt_cdugd.zero_grad()
            x_hat, gamma = model_cdugd(max_du_iterations, x)
            loss = loss_func(x_hat, t)
            loss.backward()
            opt_cdugd.step()

        params = ', '.join(f'{g:.2f}' for g in gamma.tolist())
        print(f"\re: {epoch}\tparams: {params}\tloss: {loss.item():.3f}", end="    ")
        if epoch % 10 == 0:
            print()


In [None]:
# Train DUSVGD step sizes, unroll depth 1 .. max_du_iterations (num_epochs epochs per depth).
incremental_training()

In [None]:
# Train the C-DUSVGD Chebyshev schedule (a, b) with max_du_iterations unrolled steps.
static_training()

# Execute SVGD and Test Performance

In [None]:
# Step sizes used in evaluation. DUGD.forward steps with abs(gamma), so take the
# magnitudes here as well. detach() keeps evaluation from touching the training graph.
lr_dugd = model_dugd.gamma.detach().abs()
lr_cdugd = model_cdugd.gamma.detach()

# D: number of compared methods; par: record the MMD every `par` iterations;
# iteration_: total SVGD iterations in the evaluation run.
D, par, iteration_ = 4, max_du_iterations, 500
# One column per recorded iteration (every i with i % par == 0), i.e. ceil(iteration_ / par).
mmds = np.zeros((D, math.ceil(iteration_ / par)))
# The max_du_iterations C-DUSVGD step sizes, cycled through during evaluation.
lr_c = np.array([chebyshev_step(lr_cdugd[0], lr_cdugd[1], i, max_du_iterations).item() for i in range(max_du_iterations)])

lr_fixed = 2.0
lr_RMSprop = 0.002
lr_Adam = 0.02  # NOTE(review): not used by any evaluated method.

def _svgd_direction(particles):
    """Return the SVGD update direction for ``particles`` with no autograd history.

    ``get_gradient`` builds its kernel terms from a ``requires_grad`` copy of its
    input, so its result carries a graph. Subtracting that result in place would make
    the particles carry the graph forward and chain it across every evaluation
    iteration, so memory would grow without bound. Evaluation only needs the values.
    """
    return get_gradient(model_gmm, particles, 1.0, 1.0, retain_graph=False).detach()


# Step with the per-iteration step sizes learned by the unrolled DUGD model (DUSVGD).
def dugd_step(particles, lr_val):
    """Update ``particles`` in place with step size ``|lr_val|`` and return them.

    abs() mirrors DUGD.forward, which steps with ``abs(gamma[j])`` during training.
    float() drops any autograd history carried by a learned-parameter entry.
    """
    particles -= abs(float(lr_val)) * _svgd_direction(particles)
    return particles


# Step with the Chebyshev step-size schedule learned by C-DUGD (C-DUSVGD).
def cdugd_step(particles, lr_val):
    """Update ``particles`` in place with step size ``lr_val`` and return them."""
    particles -= float(lr_val) * _svgd_direction(particles)
    return particles


# RMSprop optimization step for particles.
def rmsprop_step(particles, lr_val=lr_RMSprop, optimizer=None):
    """Apply one RMSprop update to the leaf tensor ``particles`` and return it.

    Args:
        particles: leaf tensor, updated in place by the optimizer.
        lr_val: learning rate, used only when a new optimizer has to be created.
        optimizer: ``RMSprop`` over ``[particles]`` to reuse across calls.
            RMSprop keeps a running average of squared gradients. Creating a fresh
            optimizer on every call (what happens when this is ``None``) resets that
            average, and each step degenerates to about ``lr * sign(grad) / sqrt(1 - alpha)``.
    """
    if optimizer is None:
        optimizer = RMSprop([particles], lr=lr_val)
    optimizer.zero_grad()
    particles.grad = _svgd_direction(particles)
    optimizer.step()
    return particles


# Fixed-step update for particles.
def fixed_step(particles, lr_val=lr_fixed):
    """Update ``particles`` in place with the constant step ``lr_val`` and return them."""
    particles -= lr_val * _svgd_direction(particles)
    return particles


def test():
    """Run every method for ``iteration_`` SVGD iterations on one test batch.

    Every ``par`` iterations (before that iteration's update), the MMD of each method's
    particles to the target samples is written into the global ``mmds``.

    Returns:
        The final particle sets in the order [DUSVGD, C-DUSVGD, RMSprop, fixed step].
    """
    init_particles, target_particles = [tensor.unsqueeze(2) for tensor in next(iter(test_dataloader))]

    particles = [init_particles.clone() for _ in range(D)]
    # One optimizer for the whole run so RMSprop's running statistics persist.
    # NOTE(review): lr_RMSprop was tuned while the optimizer was recreated every step;
    # it may need retuning now that RMSprop keeps its state.
    rmsprop_optimizer = RMSprop([particles[2]], lr=lr_RMSprop)
    idx = 0

    for i in range(iteration_):
        print(f"\r{i / iteration_ * 100:.0f}% complete", end="")

        # Every `par` iterations, record the MMD between particles and target
        if i % par == 0:
            with torch.no_grad():
                for d in range(D):
                    mmds[d, idx] = loss_func(particles[d], target_particles).item()
            idx += 1

        j = i % max_du_iterations
        particles[0] = dugd_step(particles[0], lr_dugd[j])
        particles[1] = cdugd_step(particles[1], lr_c[j])
        particles[2] = rmsprop_step(particles[2], optimizer=rmsprop_optimizer)
        particles[3] = fixed_step(particles[3])

    return particles


particles_list = test()


# Plot Results of SVGD Experiments


In [None]:
# Print hyperparameters to verify settings
print_hyperparameters()


def plotf():
    """Plot the evaluation results in three figures.

    1. KDEs of each method's final particles against a KDE of 10,000 target samples.
    2. log10 MMD to the target versus iteration.
    3. Step sizes of DUSVGD (as applied, and sorted descending) and of C-DUSVGD.
    """
    particles = [p.detach().flatten().cpu().numpy() for p in particles_list]
    # DUGD.forward steps with abs(gamma), so report and plot the magnitudes actually applied.
    dugd_steps = np.abs(lr_dugd.detach().cpu().numpy())

    print("\nLearned Step Sizes:")
    print("DUSVGD:", dugd_steps)
    print("C-DUSVGD:", lr_cdugd.detach().cpu().numpy())

    w, w2 = 2, 2  # line widths: method curves, target curve

    plt.figure(dpi=300)
    target = target_model.sample([10000])
    sns.kdeplot(target.squeeze().detach().cpu().numpy(), linewidth=w2, bw_method=0.2, color="black", label="Target")

    styles = ["dashed", "dashdot", "dotted", (0, (1, 1))]
    colors = ["red", "blue", "orange", "green", "purple"]
    labels = ["DUSVGD:Proposed", "C-DUSVGD:Proposed", "RMSProp", "Fixed step size"]

    for p, style, color, label in zip(particles, styles, colors, labels):
        sns.kdeplot(p, linewidth=w, bw_method=0.2, linestyle=style, color=color, label=label)

    plt.legend()
    plt.title("Particle Distributions Compared to Target")
    plt.show()

    plt.figure(dpi=300)
    # The MMD is recorded every `par` iterations. Derive the x positions from the number
    # of recorded values so both axes always have the same length.
    x_plt = np.arange(mmds.shape[1]) * par
    for d in range(D):
        plt.plot(x_plt, np.log10(mmds[d]), linewidth=w, linestyle=styles[d], color=colors[d], label=labels[d])

    plt.xlabel("Iteration")
    plt.ylabel("Log10 MMD")
    plt.title("Log MMD over Iterations")
    plt.legend()
    plt.tight_layout()
    plt.show()

    plt.figure(dpi=300)
    plt.plot(np.arange(len(dugd_steps)), dugd_steps, marker=".", color="red", linewidth=w, label="DUSVGD")
    plt.plot(np.arange(len(dugd_steps)), np.sort(dugd_steps)[::-1], marker="o", color="red", alpha=0.2, linewidth=w, label="DUSVGD Sorted")
    plt.plot(np.arange(len(lr_c)), lr_c, marker=".", color="blue", linewidth=w, label="C-DUSVGD")

    plt.xlabel("Index t")
    plt.ylabel("Step Size")
    plt.title("Learned Step Sizes for DUSVGD and C-DUSVGD")
    plt.legend()
    plt.tight_layout()
    plt.show()


plotf()