## Diffusion model
Wenhao Zhang \
Feb 20, 2025

#### Forward process
$x_t = x_{t-1} + f dt + g dw$

corresponding to

$p(x_t|x_{t-1}) = \mathcal{N}(x_t | x_{t-1} + f dt, g^2)$

#### Backward process (also run forward in time)

$x_t = x_{t-1} + [f - g^2 \nabla \ln p(x_t)] dt + g dw$


#### Forward process: posterior Langevin sampling
In neural circuits, the continuous attractor network (CAN) runs the forward process by implementing Langevin sampling dynamics that sample the stimulus posterior.

$x_t = x_{t-1} + \tau^{-1} \nabla \ln \pi(x_{t-1}) dt + \sqrt{2\tau^{-1}} dw$

where $\tau$ is the time constant of the Langevin sampling. For example, the posterior could be a Gaussian,

$\pi(x) = \mathcal{N}(x| \mu, \sigma^2)$,

which leads to 

$\nabla \ln \pi(x) = \sigma^{-2} (\mu - x)$
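As a quick sanity check of this Gaussian score, the sketch below compares $\sigma^{-2}(\mu - x)$ with a central finite difference of $\ln \pi(x)$. The helper names `log_gauss` and `score_gauss` and the parameter values are illustrative assumptions, not part of the model.

```python
import numpy as np

def log_gauss(x, mu, sigma):
    # log N(x | mu, sigma^2), including the normalizing constant
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)

def score_gauss(x, mu, sigma):
    # Analytic score: d/dx ln N(x | mu, sigma^2) = (mu - x) / sigma^2
    return (mu - x) / sigma**2

mu, sigma, h = 1.0, 0.5, 1e-5      # hypothetical posterior parameters and step size
x = np.linspace(-2.0, 3.0, 11)

# Central finite difference of the log-density
fd_score = (log_gauss(x + h, mu, sigma) - log_gauss(x - h, mu, sigma)) / (2 * h)
max_err = np.max(np.abs(fd_score - score_gauss(x, mu, sigma)))
print(max_err)  # tiny: the log-density is quadratic in x, so only round-off error remains
```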

The Langevin sampling dynamics then become

$x_t = x_{t-1} + \tau^{-1} \sigma^{-2} (\mu - x_{t-1}) dt + \sqrt{2\tau^{-1}} dw$

which corresponds to setting

$f = \tau^{-1} \sigma^{-2} (\mu - x_{t-1})$

and

$g = \sqrt{2\tau^{-1}}$

in the forward process.
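A minimal Euler-Maruyama sketch of this Langevin sampler, assuming the Gaussian posterior above with illustrative values of $\mu$, $\sigma$, $\tau$ and $dt$, and taking $dw$ as a Wiener increment with variance $dt$. Each entry of `x` is an independent chain started at $x = 0$; after several relaxation times ($\tau\sigma^2$) the ensemble mean and variance should be close to $\mu$ and $\sigma^2$.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma, tau = 1.0, 0.5, 1.0      # hypothetical posterior mean/std and time constant
dt, n_steps, n_chains = 0.01, 2000, 5000

def f(x):
    # Drift tau^{-1} sigma^{-2} (mu - x); multiplied by dt in the update below
    return (mu - x) / (tau * sigma**2)

g = np.sqrt(2.0 / tau)               # diffusion coefficient g = sqrt(2 / tau)

x = np.zeros(n_chains)               # all chains start at x = 0
for _ in range(n_steps):
    dw = np.sqrt(dt) * rng.standard_normal(n_chains)   # Wiener increment, Var(dw) = dt
    x = x + f(x) * dt + g * dw

print(x.mean(), x.var())   # should be close to mu = 1.0 and sigma^2 = 0.25
```

The small residual bias in the variance comes from the finite step size $dt$ of the discretization.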


#### Approximation of the score function in the circuit
We propose that the inputs from congruent neurons (modeled by a CAN) to opposite neurons approximate the score function via sampling. Specifically,

$ p_{t}(x_t) = \int p(x_t|x_{t-1}) p_{t-1}(x_{t-1}) dx_{t-1}$

where $p(x_t|x_{t-1})$ is the transition probability defined by the forward process (the Langevin sampling dynamics) in the model.
Approximating the integral with $N$ samples,

$ p_t(x_t) \approx \frac{1}{N} \sum_{i=1}^N p(x_t|\tilde{x}_{t-1, i})$ where  $ \tilde{x}_{t-1, i} \sim p_{t-1}(x_{t-1})$

We assume that at each time step the neural circuit __draws only one sample__, so the above equation reduces to

$ p_t(x_t) \approx p(x_t|\tilde{x}_{t-1}) $
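To see what the single-sample approximation gives up, the sketch below estimates $p_t(x_t)$ by averaging the transition density over $N$ samples from $p_{t-1}$ and compares it with the exact marginal. It assumes a Gaussian $p_{t-1}$ with hypothetical mean and variance, the linear-Gaussian Langevin transition of the previous section, and the transition variance $2\,dt/\tau$ used as `trans_var` in the demo code; under these assumptions $p_t$ is Gaussian in closed form. With $N = 1$ the estimate is just $p(x_t|\tilde{x}_{t-1})$ and fluctuates from sample to sample.

```python
import numpy as np

def gauss_pdf(x, mean, var):
    # Normal density N(x | mean, var)
    return np.exp(-(x - mean) ** 2 / (2 * var)) / np.sqrt(2 * np.pi * var)

mu, sigma, tau, dt = 0.0, 1.0, 1.0, 0.1   # hypothetical posterior and step size
a = 1 - dt / (tau * sigma**2)             # x_{t-1} + f dt = a * x_{t-1} + c
c = mu * dt / (tau * sigma**2)
v = 2 * dt / tau                          # transition variance
m_prev, s2_prev = 1.0, 0.5                # hypothetical Gaussian p_{t-1}

def p_t_estimate(x_t, n_samples, rng):
    # Monte Carlo estimate (1/N) sum_i p(x_t | x_tilde_i), with x_tilde_i ~ p_{t-1}
    x_prev = rng.normal(m_prev, np.sqrt(s2_prev), n_samples)
    return np.mean(gauss_pdf(x_t, a * x_prev + c, v))

rng = np.random.default_rng(0)
x_t = 0.8
exact = gauss_pdf(x_t, a * m_prev + c, a**2 * s2_prev + v)   # closed-form Gaussian marginal
many = p_t_estimate(x_t, 100_000, rng)                       # N = 100000
single = [p_t_estimate(x_t, 1, rng) for _ in range(5)]       # N = 1, repeated 5 times
print(exact, many, single)
```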

#### Backward process with the approximated score function
Using the single-sample approximation above, we obtain

$
\nabla _{x_t} \ln p_t(x_t) 
\approx 
\nabla _{x_t} \ln p(x_t|\tilde{x}_{t-1})
= g^{-2} (\tilde{x}_{t-1} + fdt - x_t)
$

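Then, identifying the single sample $\tilde{x}_{t-1}$ with $x_{t-1}$, the backward process is

$$
\begin{aligned}
x_t
&\approx x_{t-1} + [f - g^2 \nabla \ln p(x_t|x_{t-1})] dt + g dw \\
&= x_{t-1} + [f - (x_{t-1} + f dt - x_t)] dt + g dw \\
&= x_{t-1} (1-dt) + x_t dt + f(1-dt)dt + g dw \\
&\approx x_{t-1} (1-dt) + (x_t + f) dt + g dw
\end{aligned}
$$

where the last step drops the $O(dt^2)$ term $-f\,dt^2$.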

Note that the $x_t$ on the RHS is generated by the forward process (congruent neurons) and is then fed into the backward process (opposite neurons) to generate the $x_t$ on the LHS.
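A sketch of one step of this backward update as a plain function. The hypothetical helper `backward_step` takes the standard normal sample `xi` explicitly (an assumption made so both forms see the same noise) and returns both the exact expression from the second line of the derivation and the truncated form from its last line; the two should differ by exactly $-f\,dt^2$.

```python
import numpy as np

def backward_step(x_prev, x_t_forward, f, g, dt, xi):
    """One step of the backward process driven by the forward sample x_t_forward.

    x_prev      : backward state x_{t-1} (opposite neurons)
    x_t_forward : x_t generated by the forward process (congruent neurons)
    f           : forward drift evaluated at x_{t-1}
    xi          : standard normal sample, so that g * dw = g * sqrt(dt) * xi
    """
    noise = g * np.sqrt(dt) * xi
    exact = x_prev + (f - (x_prev + f * dt - x_t_forward)) * dt + noise
    truncated = x_prev * (1 - dt) + (x_t_forward + f) * dt + noise
    return exact, truncated

exact, truncated = backward_step(x_prev=0.3, x_t_forward=0.35, f=-0.3,
                                 g=np.sqrt(2.0), dt=0.01, xi=0.5)
print(exact - truncated)   # equals -f * dt**2 = 3e-05 up to round-off
```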

### Loss function (evidence lower bound)
$E[\log p(x_0)] \geq E_q[\log p(x_T) + \sum_{t=1}^T \log \frac{p(x_{t-1}|x_t)}{q(x_t|x_{t-1})}]$

$p(x_{t-1}|x_t) = \frac{p(x_t|x_{t-1})p(x_{t-1})}{p(x_t)}$

Therefore 

$\sum_{t=1}^T \log \frac{p(x_{t-1}|x_t)}{q(x_t|x_{t-1})} = 
\sum_{t=1}^T \log \frac{p(x_t|x_{t-1})}{q(x_t|x_{t-1})} \frac{p(x_{t-1})}{p(x_t)}
$

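Substituting into the original form, the ratio of marginals telescopes, $\sum_{t=1}^T \log \frac{p(x_{t-1})}{p(x_t)} = \log p(x_0) - \log p(x_T)$, and cancels the $\log p(x_T)$ term:

$E[\log p(x_0)] \geq E_q[\log p(x_0) + \sum_{t=1}^T \log \frac{p(x_t|x_{t-1})}{q(x_t|x_{t-1})}]$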
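For a fixed $x_{t-1}$, the expectation under $q$ of each summand is $-\mathrm{KL}\big(q(x_t|x_{t-1})\,\|\,p(x_t|x_{t-1})\big) \le 0$. For two Gaussian kernels with a shared variance $v$ (as the forward and backward kernels in the demo code have), this equals $-(m_p - m_q)^2/(2v)$. The Monte Carlo sketch below checks this; the means and the variance are hypothetical values.

```python
import numpy as np

def log_gauss(x, mean, var):
    # log N(x | mean, var)
    return -0.5 * np.log(2 * np.pi * var) - (x - mean) ** 2 / (2 * var)

rng = np.random.default_rng(1)
v = 0.02                  # shared transition variance (hypothetical: 2 dt / tau with dt = 0.01, tau = 1)
m_q, m_p = 0.10, 0.13     # hypothetical means of q(x_t | x_{t-1}) and p(x_t | x_{t-1})

x_t = rng.normal(m_q, np.sqrt(v), 200_000)                     # samples x_t ~ q(x_t | x_{t-1})
mc = np.mean(log_gauss(x_t, m_p, v) - log_gauss(x_t, m_q, v))  # Monte Carlo E_q[log p / q]
exact = -(m_p - m_q) ** 2 / (2 * v)                            # -KL(q || p) for equal variances
print(mc, exact)
```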

## Denoising score matching (DSM)

The loss function of DSM is 

$J_{DSM}(\mathbf{w}) = \mathbb{E} _{q_\sigma (\tilde{x}, x)} ||s_\mathbf{w}(\tilde{x}) - \nabla_{\tilde{x}} \ln q_\sigma(\tilde{x}|x)||^2$

where $s_\mathbf{w}(\tilde{x})$ is the score network with parameter $\mathbf{w}$, and $q_\sigma(\tilde{x}|x)$ is the transition probability for the noise perturbation.

To gain theoretical insight into the DSM solution, we consider a one-layer linear score network, i.e.,

$s_\mathbf{w}(\tilde{x}) = \mathbf{w}^\top \tilde{x}$ 

with $\mathbf{w}$ the feedforward weights to the score network.

Taking the gradient with respect to $\mathbf{w}$ (dropping the constant factor 2) and setting it to zero,

$\frac{\partial J}{\partial \mathbf{w}} = \mathbb{E}_{q(\tilde{x},x)} \{ 
    [\mathbf{w}^\top \tilde{x} - \nabla_{\tilde{x}} \ln q_\sigma(\tilde{x}|x)] \cdot \tilde{x}^\top \}= 0$

which is equivalent to 

$\mathbf{w}^\top \mathbb{E}(\tilde{x} \tilde{x}^\top) = \mathbb{E} [\nabla_{\tilde{x}} \ln q_\sigma(\tilde{x}|x) \cdot \tilde{x}^\top]$

Finally

$\mathbf{w}^\top = \mathbb{E} [\nabla_{\tilde{x}} \ln q_\sigma(\tilde{x}|x) \cdot \tilde{x}^\top] \, [\mathbb{E}(\tilde{x} \tilde{x}^\top)]^{-1}$

This solution has the familiar form of linear regression (input-output correlation times the inverse input correlation) and can be learned with biologically plausible learning rules; see, e.g., studies by Dmitri Chklovskii and Cengiz Pehlevan.
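A sketch of the closed-form solution in the vector case, assuming zero-mean Gaussian data with a hypothetical covariance `C` and the perturbation $\tilde{x} = x + \sigma\epsilon$, so the DSM target is $\sigma^{-2}(x - \tilde{x})$. Solving the sample version of the normal equation $\mathbf{w}^\top \mathbb{E}(\tilde{x}\tilde{x}^\top) = \mathbb{E}[\nabla_{\tilde{x}} \ln q_\sigma(\tilde{x}|x) \cdot \tilde{x}^\top]$ gives the same weights as ordinary least squares (`np.linalg.lstsq`), and for this data they should be close to $-(C + \sigma^2 I)^{-1}$, the score matrix of the perturbed marginal $\mathcal{N}(0, C + \sigma^2 I)$.

```python
import numpy as np

rng = np.random.default_rng(0)
n, sigma = 200_000, 0.5
C = np.array([[1.0, 0.3],
              [0.3, 0.5]])              # hypothetical data covariance (zero mean)

x = rng.multivariate_normal(np.zeros(2), C, size=n)   # clean data, shape (n, 2)
x_tilde = x + sigma * rng.standard_normal((n, 2))     # perturbed data
target = (x - x_tilde) / sigma**2                     # DSM target: grad log q_sigma(x_tilde | x)

# Sample estimates of E[x_tilde x_tilde^T] and E[target x_tilde^T]
Exx = x_tilde.T @ x_tilde / n
Esx = target.T @ x_tilde / n

# Solve w^T Exx = Esx for w^T (Exx is symmetric)
wT = np.linalg.solve(Exx, Esx.T).T

# Same answer from ordinary least squares on x_tilde @ W ~ target, i.e. s_w(x_tilde) = W^T x_tilde
W_lstsq, *_ = np.linalg.lstsq(x_tilde, target, rcond=None)

print(wT)
print(-np.linalg.inv(C + sigma**2 * np.eye(2)))   # score matrix of the perturbed marginal
```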


#### Insight
We can regard $\tilde{x}$ as the congruent neurons' responses, and $\nabla_{\tilde{x}} \ln q_\sigma(\tilde{x}|x)$ as the approximate score carried by the inputs from congruent neurons to opposite neurons.
Therefore the update of the weights $\mathbf{w}$ uses only local information.
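To illustrate the locality argument, here is a sketch of an online delta (LMS) rule for the scalar case, $\Delta w = \eta\,[s - w\tilde{x}]\,\tilde{x}$, where $s$ is the sampled score input. The update uses only the presynaptic activity $\tilde{x}$ and the postsynaptic error. This particular rule, the data distribution and the learning rate are illustrative assumptions, not necessarily what the studies cited above use. The time-averaged weight should approach the closed-form value $\mathbb{E}[s\tilde{x}]/\mathbb{E}[\tilde{x}^2] = -1/(\sigma_0^2 + \sigma^2)$.

```python
import numpy as np

rng = np.random.default_rng(0)
sigma0, sigma = 1.0, 0.5       # hypothetical data std and perturbation std
eta, n_steps = 5e-4, 200_000   # learning rate and number of online updates

x_all = sigma0 * rng.standard_normal(n_steps)              # clean stimuli (zero mean)
x_tilde_all = x_all + sigma * rng.standard_normal(n_steps)  # congruent-neuron responses
s_all = (x_all - x_tilde_all) / sigma**2                   # sampled score inputs

w = 0.0
w_trace = np.empty(n_steps)
for k in range(n_steps):
    x_tilde, s = x_tilde_all[k], s_all[k]
    w += eta * (s - w * x_tilde) * x_tilde   # local delta rule: error times presynaptic input
    w_trace[k] = w

w_avg = w_trace[n_steps // 2:].mean()        # average over the second half to reduce fluctuations
print(w_avg, -1 / (sigma0**2 + sigma**2))    # closed-form value: -0.8
```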

#### Theoretical analysis of $w$ in the score network

Suppose the true score function is 

$$ \nabla_{\tilde{x}} \ln q_\sigma(\tilde{x}|x) = \sigma^{-2} (x - \tilde{x})$$

Then 

$$
\begin{aligned}
\mathbb{E} [\nabla_{\tilde{x}} \ln q_\sigma(\tilde{x}|x) \cdot \tilde{x}^\top]
&= \mathbb{E} [ \sigma^{-2} (x - \tilde{x})\cdot \tilde{x}^\top ] \\
&= \sigma^{-2} [ x \mathbb{E}(\tilde{x}) - \mathbb{E}(\tilde{x}\tilde{x}^\top) ]
\end{aligned}
$$
where, in the scalar case and for fixed $x$,
$\mathbb{E}(\tilde{x}) = x$ and
$\mathbb{E}(\tilde{x}\tilde{x}^\top) = x^2 + \sigma^2$, so the expectation equals $\sigma^{-2}[x^2 - (x^2 + \sigma^2)] = -1$.
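A Monte Carlo check of this calculation for a single, fixed $x$ (a hypothetical value) with $\tilde{x} = x + \sigma\epsilon$: the sample means of $\tilde{x}$, $\tilde{x}^2$ and $\sigma^{-2}(x - \tilde{x})\tilde{x}$ should be close to $x$, $x^2 + \sigma^2$ and $-1$.

```python
import numpy as np

rng = np.random.default_rng(0)
x, sigma, n = 1.5, 0.5, 1_000_000            # hypothetical clean value, noise std, sample count

x_tilde = x + sigma * rng.standard_normal(n)  # samples from q_sigma(x_tilde | x)
score = (x - x_tilde) / sigma**2              # true score of q_sigma(x_tilde | x)

print(x_tilde.mean())              # close to x = 1.5
print((x_tilde**2).mean())         # close to x^2 + sigma^2 = 2.5
print((score * x_tilde).mean())    # close to sigma^{-2} [x^2 - (x^2 + sigma^2)] = -1
```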


## A toy example: linear score function with a Gaussian data distribution

Problem setting:
- Data follows a normal distribution: $x \sim \mathcal{N}(\mu, \sigma^2)$
- We perturb it with noise: $\tilde{x} = x + \epsilon$ where $\epsilon \sim \mathcal{N}(0, \sigma_\epsilon^2)$
- We want to learn a linear score function: $s_\theta(\tilde{x}) = w\tilde{x} + b$

Our goal is to find the optimal values of $w$ and $b$.

#### __1. The true score function__

For a normal distribution $\mathcal{N}(\mu, \sigma^2)$, the probability density function is:

$$p(x) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x-\mu)^2}{2\sigma^2}\right)$$

The true score function:

$$\nabla_x \log p(x) = \frac{\mu-x}{\sigma^2}$$


####  __2. Denoising score matching and linear score function__

In denoising score matching, we train the score network to match the score of the perturbation kernel, $\nabla_{\tilde{x}} \ln q(\tilde{x}|x) = (x - \tilde{x})/\sigma_\epsilon^2 = -\epsilon/\sigma_\epsilon^2$, which points from the noisy sample back toward the clean one. The objective function is:

$$\mathcal{L}(\theta) 
= \mathbb{E}_{x \sim p(x), \epsilon \sim \mathcal{N}(0,\sigma_\epsilon^2)} 
\left[\left\|s_\theta(x+\epsilon) + \frac{\epsilon}{\sigma_\epsilon^2}\right\|^2\right]$$

For our linear model $s_\theta(\tilde{x}) = w\tilde{x} + b$, we can find the optimal parameters by taking derivatives of the loss with respect to $w$ and $b$ and setting them to zero.

First, let's rewrite our objective function for the scalar case:

$$\mathcal{L}(w,b) = \mathbb{E}_{x,\epsilon}\left[\left(w(x+\epsilon) + b + \frac{\epsilon}{\sigma_\epsilon^2}\right)^2\right]$$

Taking derivatives:

$$\frac{\partial \mathcal{L}}{\partial w} 
= \mathbb{E}_{x,\epsilon}\left[2(w(x+\epsilon) + b + \frac{\epsilon}{\sigma_\epsilon^2})(x+\epsilon)\right]
\propto \mathbb{E}_{x,\epsilon}\left[(w(x+\epsilon) + b + \frac{\epsilon}{\sigma_\epsilon^2})(x+\epsilon)\right] = 0
$$

$$\frac{\partial \mathcal{L}}{\partial b} 
= \mathbb{E}_{x,\epsilon}\left[2(w(x+\epsilon) + b + \frac{\epsilon}{\sigma_\epsilon^2})\right]
\propto \mathbb{E}_{x,\epsilon}\left[w(x+\epsilon) + b + \frac{\epsilon}{\sigma_\epsilon^2}\right] = 0
$$

- __Solving b__

    From the second equation:
    $$w\mathbb{E}[x+\epsilon] + b + \mathbb{E}\left[\frac{\epsilon}{\sigma_\epsilon^2}\right] = 0$$

    Since $\mathbb{E}[\epsilon] = 0$, we have:
    $$w\mu + b = 0$$
    $$b = -w\mu$$

- __Solving w__

    From the first equation, expanding:
    $$\mathbb{E}[w(x+\epsilon)^2 + b(x+\epsilon) + \frac{\epsilon(x+\epsilon)}{\sigma_\epsilon^2}] = 0$$

    Substituting $b = -w\mu$:
    $$w\mathbb{E}[(x+\epsilon)^2] - w\mu\mathbb{E}[x+\epsilon] + \mathbb{E}\left[\frac{\epsilon x + \epsilon^2}{\sigma_\epsilon^2}\right] = 0$$

    Since $x$ and $\epsilon$ are independent, $\mathbb{E}[x\epsilon] = 0$, which simplifies this to:
    $$w(\mathbb{E}[x^2] + \mathbb{E}[\epsilon^2]) - w\mu^2 + \frac{\mathbb{E}[\epsilon^2]}{\sigma_\epsilon^2} = 0$$

    Given $\mathbb{E}[x^2] = \sigma^2 + \mu^2$ and $\mathbb{E}[\epsilon^2] = \sigma_\epsilon^2$:
    $$w(\sigma^2 + \mu^2 + \sigma_\epsilon^2) - w\mu^2 + 1 = 0$$ 
    $$w = -\frac{1}{\sigma^2 + \sigma_\epsilon^2}$$

    Hence $s_\theta(\tilde{x}) = (\mu - \tilde{x})/(\sigma^2 + \sigma_\epsilon^2)$, which is exactly the score of the perturbed marginal $\tilde{x} \sim \mathcal{N}(\mu, \sigma^2 + \sigma_\epsilon^2)$.

#### __3. Score function in the limit of small noise perturbation__

As $\sigma_\epsilon^2 \to 0$ (small noise limit), we get:
$$w = -\frac{1}{\sigma^2} \cdot \frac{\sigma^2}{\sigma^2 + \sigma_\epsilon^2} \approx -\frac{1}{\sigma^2}$$

Since $b = -w\mu$:
$$b = \frac{\mu}{\sigma^2}$$

These are exactly the slope and offset of the true score function from step 1, $\nabla_x \log p(x) = (\mu - x)/\sigma^2$, so in this limit the linear DSM solution recovers the true score.

## Conclusion

In the small-noise limit, the optimal parameters for a linear score function $s_\theta(\tilde{x}) = w\tilde{x} + b$ to approximate the score of a normal distribution $\mathcal{N}(\mu, \sigma^2)$ are:

$$w = -\frac{1}{\sigma^2}$$
$$b = \frac{\mu}{\sigma^2}$$

This aligns with our intuition: the score function should point toward the mean with strength inversely proportional to the variance.
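A sketch of the toy example under hypothetical parameter values: draw $x \sim \mathcal{N}(\mu, \sigma^2)$ and $\epsilon \sim \mathcal{N}(0, \sigma_\epsilon^2)$, form the DSM target $\nabla_{\tilde{x}} \ln q(\tilde{x}|x) = -\epsilon/\sigma_\epsilon^2$, and fit $w$ and $b$ by least squares. For finite noise the population optimum is $w = -1/(\sigma^2+\sigma_\epsilon^2)$ and $b = \mu/(\sigma^2+\sigma_\epsilon^2)$, which approach $-1/\sigma^2$ and $\mu/\sigma^2$ as $\sigma_\epsilon \to 0$.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma, sigma_eps = 1.0, 1.0, 0.1    # hypothetical data mean/std and noise std
n = 1_000_000

x = rng.normal(mu, sigma, n)
eps = rng.normal(0.0, sigma_eps, n)
x_tilde = x + eps
target = -eps / sigma_eps**2             # grad of ln N(x_tilde | x, sigma_eps^2) w.r.t. x_tilde

# Least squares for s(x_tilde) = w * x_tilde + b
A = np.column_stack((x_tilde, np.ones(n)))
(w, b), *_ = np.linalg.lstsq(A, target, rcond=None)

print(w, -1 / (sigma**2 + sigma_eps**2))   # fitted vs. finite-noise optimum
print(b, mu / (sigma**2 + sigma_eps**2))
```

The single-sample targets are very noisy (standard deviation $1/\sigma_\epsilon$), which is why a large `n` is needed for a precise fit at small $\sigma_\epsilon$.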

In [None]:
# Demo of the diffusion model
# Forward process: Langevin sampling dynamics of a given (posterior) distribution p(x)
#                  Diffuse from the initial state x_0 = 0 toward p(x)
# Backward process: the score function is approximated via sampling

# Wen-Hao Zhang, wenhao.zhang@utsouthwestern.edu
# Feb 20, 2024
# UT Southwestern, Dallas, TX

import numpy as np
import matplotlib.pyplot as plt

# --------------------------------------------------
# Parameters of the distribution to be sampled
mu = 0
sigma = 1
tau = 1
num_trials = int(5e3)

# Simulation parameters
dt = 0.01 * tau
tLen = 10

# Assemble parameters into a dictionary
ParamDiffusion = {'mu': mu, 'sigma': sigma, 'tau': tau, 'dt': dt}

# --------------------------------------------------
# Define the forward process transition kernel
def logProbForward(x1, x2, Params):
    # Unnormalized log forward transition probability log q(x2 | x1)
    # x1: current state
    # x2: next state
    # The Gaussian normalizer is omitted; it cancels in the ELBO because the
    # forward and backward kernels share the same variance trans_var.
    drift_term = (Params['mu'] - x1) / Params['tau'] / Params['sigma']**2
    trans_mean = x1 + drift_term * Params['dt']
    trans_var = 2 * Params['dt'] / Params['tau']

    logProb = -(x2 - trans_mean)**2 / (2 * trans_var)
    return logProb

def logProbBackward(x1, x2, Params):
    # Unnormalized log backward transition probability
    # x1: current state
    # x2: next state
    drift_term = (Params['mu'] - x1) / Params['tau'] / Params['sigma']**2
    trans_mean_forward = x1 + drift_term * Params['dt']
    trans_var = 2 * Params['dt'] / Params['tau']

    # Approximate the score of the marginal distribution at time t+1 via sampling:
    # p(x_t+1) = \int q(x_t+1 | x_t) p(x_t) dx_t
    #          \approx q(x_t+1 | \tilde{x}_t), where \tilde{x}_t ~ p(x_t)
    # Therefore \nabla_{x_t+1} log p(x_t+1) \approx \nabla_{x_t+1} log q(x_t+1 | \tilde{x}_t)
    score_approx = (trans_mean_forward - x2) / trans_var
    # Score correction: trans_var * score_approx * dt = (trans_mean_forward - x2) * dt,
    # matching the term (x_{t-1} + f dt - x_t) dt in the backward-process derivation
    trans_mean_backward = trans_mean_forward - trans_var * score_approx * Params['dt']

    logProb = -(x2 - trans_mean_backward)**2 / (2 * trans_var)
    return logProb

# --------------------------------------------------
# Simulate the forward and backward processes

# Initialization
num_steps = int(round(tLen / dt))
x = np.zeros((num_steps + 1, num_trials))
ELBO = np.zeros((num_steps + 1, num_trials))
logProb_Back_array = np.zeros((num_steps + 1, num_trials))
logProb_Forward_array = np.zeros((num_steps + 1, num_trials))
x[0, :] = 0
# x[0, :] = np.random.normal(mu, sigma, num_trials)  # alternative: start from p(x)

# Simulate the Langevin dynamics (Euler-Maruyama), all trials in parallel
for k in range(num_steps):
    dx = sigma**-2 * (mu - x[k, :]) / tau * dt + np.sqrt(2 / tau * dt) * np.random.normal(0, 1, num_trials)
    x[k + 1, :] = x[k, :] + dx

    # Log backward and forward transition probabilities of this step (computed once)
    logProb_Back_array[k + 1, :] = logProbBackward(x[k, :], x[k + 1, :], ParamDiffusion)
    logProb_Forward_array[k + 1, :] = logProbForward(x[k, :], x[k + 1, :], ParamDiffusion)
    # Accumulate log p_backward - log q_forward along the trajectory
    ELBO[k + 1, :] = ELBO[k, :] + logProb_Back_array[k + 1, :] - logProb_Forward_array[k + 1, :]
    
# Plot the position over time
trial_example = np.random.randint(0, num_trials);
plt.plot(np.arange(num_steps+1) * dt, x[:,trial_example])
plt.xlabel('Time')
plt.ylabel('Position')
plt.title('Langevin Dynamics Simulation')
plt.show()

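# Log of the trial-averaged exp(ELBO), computed stably (log-mean-exp)
ELBO_max = np.max(ELBO, axis=1, keepdims=True)
ELBO_mean = ELBO_max[:, 0] + np.log(np.mean(np.exp(ELBO - ELBO_max), axis=1))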
# plt.plot(np.arange(num_steps+1) * dt, np.mean(ELBO, axis=1))
plt.plot(np.arange(num_steps+1) * dt, ELBO_mean)
plt.xlabel('Time')
plt.ylabel('ELBO')
plt.show()


plt.plot(np.arange(num_steps+1) * dt, logProb_Forward_array[:,trial_example])
# plt.plot(np.arange(num_steps+1) * dt, logProb_Back_array[:,trial_example])
# plt.plot(np.arange(num_steps+1) * dt, logProb_Back_array[:,trial_example] - logProb_Forward_array[:,trial_example])
plt.xlabel('Time')
plt.ylabel('log probability')
plt.show()


In [None]:

# Demo of the Denoising Score Matching (DSM)

import numpy as np
import matplotlib.pyplot as plt
# --------------------------------------------------
# Parameters of the data distribution
mu0 = 0
sigma0 = 2

# Parameters of the transition probability
mu = 0
sigma = 1
tau = 1
num_data = 50
num_samples = int(1e4)

# Simulation parameters
dt = 0.01 * tau
tLen = 10

# Assemble parameters into a dictionary
ParamDiffusion = {'mu': mu, 'sigma': sigma, 'tau': tau, 'dt': dt}

x = np.random.normal(mu0, sigma0, (num_data, 1))   # clean data, shape (num_data, 1)
# --------------------------------------------------
# Approximate the score via sampling (single-sample approximation):
# p(x_t+1) = \int q(x_t+1 | x_t) p(x_t) dx_t
#          \approx q(x_t+1 | \tilde{x}_t), where \tilde{x}_t ~ p(x_t)
# Therefore \nabla_{x_t+1} log p(x_t+1) \approx \nabla_{x_t+1} log q(x_t+1 | \tilde{x}_t)

drift_term = (mu - x) / tau / sigma**2
trans_var = 2 * dt / tau

# Noise perturbation via one step of Langevin dynamics, shape (num_data, num_samples)
trans_mean = x + drift_term * dt
x_perturb = trans_mean + np.sqrt(trans_var) * np.random.normal(0, 1, (num_data, num_samples))
# DSM target: gradient of log q(x_perturb | x) with respect to x_perturb
score_approx = (trans_mean - x_perturb) / trans_var

# Alternative (not used): unit-variance perturbation without drift
# x_perturb = x + np.random.normal(0, 1, (num_data, num_samples))
# score_approx = x - x_perturb

# --------------------------------------------------
# A linear model s(x) = w * x + b to approximate the score function.
# Augment x_perturb with a constant input so that the offset b is learned as a weight.
X_aug = np.stack((x_perturb.ravel(), np.ones(x_perturb.size)))   # shape (2, num_data * num_samples)
S = score_approx.ravel()

# Closed-form DSM solution w^T = E[s x^T] [E(x x^T)]^{-1}; X_aug @ X_aug.T is symmetric
w, b = np.linalg.solve(X_aug @ X_aug.T, X_aug @ S)
print('w =', w, ', b =', b, '; true score: w =', -1 / sigma0**2, ', b =', mu0 / sigma0**2)

# --------------------------------------------------
# Compare with the true score of the data distribution
x_grid = np.sort(x, axis=0)
score_true = (mu0 - x_grid) / sigma0**2
fig, ax = plt.subplots()
ax.plot(x_perturb[:, :10], score_approx[:, :10], '.', color='0.7', markersize=2)   # a few per-sample DSM targets
ax.set_xlabel('Data x')
ax.set_ylabel('Sampled score (DSM target)')
ax2 = ax.twinx()
ax2.plot(x_grid, score_true, 'k', label='True score')
ax2.plot(x_grid, w * x_grid + b, 'r--', label='Linear fit')
ax2.set_ylabel('Score (gradient of the log data distribution)')
ax2.legend()
plt.show()

In [None]:
# Augment x_perturb with a constant input so that a linear model can learn an offset
x_perturb_aug = np.expand_dims(x_perturb, axis=1)   # shape (num_data, 1, num_samples)
x_perturb_aug = np.concatenate((x_perturb_aug, np.ones(x_perturb_aug.shape)), axis=1)
print(x_perturb_aug.shape)   # (num_data, 2, num_samples)