# Diffusion Models

Implement a score-matching generative model in JAX to emulate posterior distributions. The package is named `fusions`, for "diffusion meets ns"; the choice was primarily constrained by which names were available on PyPI.
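As a rough illustration of what "score matching" means here, the following is a minimal NumPy sketch of a denoising score matching loss. It is not the `fusions` implementation (which is written in JAX): the function name `dsm_loss`, the simple additive-noise kernel, and the two toy score functions are all assumptions made for illustration.

```python
import numpy as np

def dsm_loss(score_fn, x0, sigma, rng):
    """Denoising score matching loss for the kernel x_t = x0 + sigma * eps."""
    eps = rng.normal(size=x0.shape)
    x_t = x0 + sigma * eps
    # Score of the Gaussian perturbation kernel N(x_t; x0, sigma^2 I) w.r.t. x_t
    target = -eps / sigma
    residual = score_fn(x_t, sigma) - target
    return np.mean(np.sum(residual**2, axis=-1))

rng = np.random.default_rng(0)
x0 = rng.normal(size=(20_000, 2))  # toy data drawn from N(0, I)
sigma = 0.5

# For N(0, I) data the noised marginal is N(0, (1 + sigma^2) I),
# so the exact score is available in closed form.
true_score = lambda x, s: -x / (1.0 + s**2)
zero_score = lambda x, s: np.zeros_like(x)  # deliberately poor baseline

loss_true = dsm_loss(true_score, x0, sigma, rng)
loss_zero = dsm_loss(zero_score, x0, sigma, rng)
print(loss_true, loss_zero)
```

The exact score should give the lower of the two losses. In a real model a trained network plays the role of `score_fn`.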

In [None]:
from fusions import DiffusionModel
import numpy as np
import matplotlib.pyplot as plt

The package wraps everything in a single class, `DiffusionModel`. It is designed to wrap around anesthetic, although it currently only takes NumPy arrays.

In [None]:
from numpy.random import default_rng
from scipy.stats import multivariate_normal
from sklearn.datasets import make_spd_matrix

dims = 4
rng = default_rng(0)

# Random SPD matrix; note it is not passed to the prior below
cov = make_spd_matrix(dims)
# Gaussian prior with a random mean; with no cov given, scipy defaults to the identity
prior = multivariate_normal(mean=rng.normal(size=dims))
model = DiffusionModel(prior)
# model=DiffusionModel()
# model.beta_max=20

## Posterior to emulate

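Emulate a Gaussian mixture model with two equally weighted components: 1000 samples are drawn from each of two isotropic Gaussians with covariance 0.1 I and randomly drawn means.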

In [None]:
cov = make_spd_matrix(dims)
# Alternative target: a single Gaussian with a random SPD covariance
# data = multivariate_normal(mean=rng.normal(size=dims), cov=cov).rvs(1000)

# Two isotropic components with random means, 1000 samples each
data_1 = multivariate_normal(mean=rng.normal(size=dims), cov=np.eye(dims) * 0.1).rvs(1000)
data_2 = multivariate_normal(mean=rng.normal(size=dims), cov=np.eye(dims) * 0.1).rvs(1000)
data = np.concatenate([data_1, data_2])
plt.scatter(data[:, 0], data[:, 1])

## Train the model and use it to make predictions

In [None]:
model.train(data, n_epochs=1000, batch_size=256, lr=1e-3)

In [None]:
# Assumes each entry of model.state.losses is a (loss, step) pair
loss_hist = np.asarray(model.state.losses)
plt.plot(loss_hist[..., 1], loss_hist[..., 0])

Currently, time zero of the diffusion process is defined to be a Gaussian prior. In theory, we should be able to supply the training process with any generative prior we like.
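To make "any generative prior" concrete, here is a hedged sketch of a non-Gaussian prior that mimics the `rvs`/`logpdf` interface of a frozen scipy distribution. The class `UniformBoxPrior` is hypothetical, and whether `DiffusionModel` would accept it depends on which methods `fusions` actually calls on the prior, which is not confirmed here.

```python
import numpy as np

class UniformBoxPrior:
    """Hypothetical uniform prior on an axis-aligned box, scipy-style interface."""

    def __init__(self, low, high, seed=None):
        self.low = np.asarray(low, dtype=float)
        self.high = np.asarray(high, dtype=float)
        self.dim = self.low.size
        self._rng = np.random.default_rng(seed)

    def rvs(self, size=1):
        # Draw `size` independent points, returned with shape (size, dim)
        return self._rng.uniform(self.low, self.high, size=(size, self.dim))

    def logpdf(self, x):
        # Constant log-density inside the box, -inf outside
        x = np.atleast_2d(x)
        inside = np.all((x >= self.low) & (x <= self.high), axis=-1)
        log_volume = np.sum(np.log(self.high - self.low))
        return np.where(inside, -log_volume, -np.inf)

box_prior = UniformBoxPrior(low=[-1.0] * 4, high=[1.0] * 4, seed=0)
samples = box_prior.rvs(5)
print(samples.shape, box_prior.logpdf(samples))
```

Any object that can generate samples (and, if the training loop needs it, evaluate a log-density) could in principle stand in for the Gaussian in the same way.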

In [None]:
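x0 = model.sample_prior(1000)
# Push these specific prior samples through the trained model
x1_pred = model.predict(x0)
# Draw posterior samples directly; history=True also returns x1_t
x1, x1_t = model.sample_posterior(1000, history=True)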


In [None]:
plt.scatter(x0[:,0],x0[:,1],label="Prior")
plt.scatter(x1[:,0],x1[:,1],label="SGM Posterior")
plt.scatter(data[:,0],data[:,1],label="Training Samples")
plt.legend()

The accuracy still needs tuning, and there is plenty to experiment with, such as the beta schedule. We can also do things like plot the trajectories of prior samples as a function of time as they are diffused.
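For a sense of what the beta schedule controls, the sketch below uses the standard linear schedule of a variance-preserving SDE and its closed-form marginals. This is an assumption about the general technique, not a description of what `fusions` does internally: `beta_max=20` is taken from the commented-out setting earlier in this notebook, `beta_min=0.1` is a commonly used default, and the function names are illustrative. It also follows the usual convention that t = 0 is the data and t = 1 is noise, which may be the reverse of the time convention used in this notebook.

```python
import numpy as np

def beta(t, beta_min=0.1, beta_max=20.0):
    """Linear noise schedule beta(t) on t in [0, 1]."""
    return beta_min + t * (beta_max - beta_min)

def vp_marginal(t, beta_min=0.1, beta_max=20.0):
    """Mean coefficient and std of x_t given x_0 for the variance-preserving SDE."""
    # Closed-form integral of beta(s) from 0 to t
    int_beta = beta_min * t + 0.5 * (beta_max - beta_min) * t**2
    mean_coef = np.exp(-0.5 * int_beta)
    std = np.sqrt(1.0 - np.exp(-int_beta))
    return mean_coef, std

def diffuse(x0, t, rng, **schedule):
    """Sample x_t given x_0 from the Gaussian perturbation kernel."""
    mean_coef, std = vp_marginal(t, **schedule)
    return mean_coef * x0 + std * rng.normal(size=x0.shape)

rng = np.random.default_rng(0)
x0 = rng.normal(loc=3.0, scale=0.1, size=(1000, 2))  # toy "data" cloud
times = np.linspace(0.0, 1.0, 11)
# Independent noise is drawn at each time, so these are samples from the
# marginals at each t rather than continuous per-particle paths.
trajectory = np.stack([diffuse(x0, t, rng) for t in times])
print(trajectory.shape)
```

A larger `beta_max` drives the samples to unit-variance noise faster. Plotting `trajectory[:, :, 0]` against `times` shows the cloud relaxing from the data towards the noise distribution.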

In [None]:
import anesthetic as ns
a=ns.MCMCSamples(x1).plot_2d([0,1])
ns.MCMCSamples(data).plot_2d(a)

In [None]:
# Plot every dimension; the data has `dims` columns (0 to dims - 1)
a = ns.MCMCSamples(x1).plot_2d(np.arange(dims))
ns.MCMCSamples(data).plot_2d(a)