## Score matching

Suppose you have n data points drawn from an unknown distribution and you want to build a generative model that can sample new data points. There are several approaches to this problem. One is to estimate the probability density function (pdf) of the distribution.

$$
p(x) = \frac{f(x)}{\int f(x)dx}
$$

Here $p(x)$ is the probability density at x.

Where:
- f(x) is an unnormalized density function
- p(x) is the properly normalized probability density function

The integral in the denominator is difficult to compute for complex data such as images and text. A 32x32x3 image has 3072 dimensions, so computing the denominator means evaluating $\int f(x_{1}, x_{2}, \dots, x_{n})\,dx_{1}\,dx_{2} \dots dx_{n}$ with n = 3072. In general there is no analytical way to solve this, and we call such an integral `intractable`.
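To get a feel for why this is intractable, here is a minimal sketch. The Gaussian-shaped `f`, the helper `normalizer_1d`, and the grid settings are illustrative assumptions, not part of the derivation. In one dimension a simple Riemann sum recovers the normalizing constant, but the same grid strategy in 3072 dimensions would need an astronomical number of evaluations.

```python
import math

def f(x):
    # Illustrative unnormalized density: a standard Gaussian without its 1/sqrt(2*pi) factor
    return math.exp(-0.5 * x * x)

def normalizer_1d(f, lo=-10.0, hi=10.0, points=10_000):
    # Midpoint Riemann sum approximating the integral of f over [lo, hi]
    step = (hi - lo) / points
    return sum(f(lo + (i + 0.5) * step) for i in range(points)) * step

Z = normalizer_1d(f)
print(Z, math.sqrt(2 * math.pi))  # the two values should agree closely

# With only 10 grid points per axis, a 32x32x3 image needs 10**3072 evaluations of f
dims = 32 * 32 * 3
print(f"grid evaluations needed: 10^{dims}")
```

The 1D case takes ten thousand evaluations; the image case has more grid cells than could ever be enumerated, which is why we look for a way around the denominator.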


If we model this density with parameters $\theta$, the equation becomes

$$
p(x ; \theta) = \frac{f(x; \theta)}{\int f(x; \theta)dx}
$$

Taking the log of both sides:

$$
\log{p(x; \theta)} = \log(f(x;\theta)) - \log(\int f(x; \theta)dx)
$$

Integration eliminates the integration variable, so $\int f(x; \theta)dx$ can be rewritten as $Z_{\theta}$, which no longer depends on x.

$$
\log{p(x;\theta)} = \log(f(x;\theta)) - \log(Z_{\theta})
$$

Taking the gradient with respect to x:

$$
\nabla_{x} \log{p(x;\theta)} = \nabla_{x} \log(f(x; \theta)) - \nabla_{x} \log(Z_{\theta})
$$

Since $Z_{\theta}$ does not depend on x, the last term is zero when we differentiate with respect to x, so

$$
\nabla_{x} \log{p(x;\theta)} = \nabla_{x} \log(f(x; \theta))
$$
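A quick numerical check of this cancellation, as a sketch only: `log_f`, `log_f_scaled`, and the finite-difference helper `score` are illustrative names, and the quadratic log-density is an assumed toy choice. Multiplying f by a constant changes the normalizer but should leave the gradient of the log untouched.

```python
import math

def log_f(x):
    # Illustrative unnormalized log-density: log f(x) = -x^2 / 2
    return -0.5 * x * x

def log_f_scaled(x, c=5.0):
    # Same density multiplied by a constant c, i.e. a different normalizer
    return math.log(c) + log_f(x)

def score(log_density, x, eps=1e-5):
    # Central finite difference of the log-density with respect to x
    return (log_density(x + eps) - log_density(x - eps)) / (2 * eps)

x0 = 1.3
s_plain = score(log_f, x0)
s_scaled = score(log_f_scaled, x0)
print(s_plain, s_scaled)  # both should be close to -x0
```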

We define $s_{\theta}(x) = \nabla_{x} \log(f(x; \theta))$ and call it the score function. We model $s_{\theta}(x)$ with a neural network that takes x as input and outputs the gradient vector at that point.

How do we optimize this? The network outputs a vector field (the model score) and the data distribution has its own vector field (the data score). We can take the squared difference between the two at each point and average over the data. If both scores are the same, the loss is zero. This is exactly what the Fisher divergence captures.

$$
\frac{1}{2} \mathbb{E}_{p_{\text{data}}(\mathbf{x})} \left[ \left\| \nabla_{\mathbf{x}} \log p_{\text{data}}(\mathbf{x}) - s_{\theta}(\mathbf{x}) \right\|^2_2 \right]
$$
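As a sketch of what this expectation measures, here is a Monte Carlo estimate in a case where the data score happens to be known. The assumptions are mine: the data is N(0, 1), so the data score is -x, and the model is N(0, sigma^2) with score -x / sigma^2. The function name `fisher_divergence` is illustrative.

```python
import random

def fisher_divergence(model_sigma, samples):
    # Data score for N(0, 1) is -x; model score for N(0, sigma^2) is -x / sigma^2
    total = 0.0
    for x in samples:
        data_score = -x
        model_score = -x / model_sigma ** 2
        total += 0.5 * (data_score - model_score) ** 2
    return total / len(samples)

random.seed(0)
samples = [random.gauss(0.0, 1.0) for _ in range(50_000)]

matched = fisher_divergence(1.0, samples)     # model equals the data distribution
mismatched = fisher_divergence(2.0, samples)  # model is too wide
print(matched, mismatched)
```

The matched model gives zero loss and the mismatched one a positive loss. With real data the `data_score` line is exactly the part we cannot compute.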

However, we don't know the ground-truth value of the data score. `Score matching`, introduced by Hyvärinen (explained [here](https://andrewcharlesjones.github.io/journal/21-score-matching.html)), gets around this: the loss can be rewritten so that it depends only on the model score.

$$
\mathbb{E}_{p_{\text{data}}(\mathbf{x})} \left[ \frac{1}{2} \|s_{\theta}(\mathbf{x})\|^2_2 + \text{trace}( \underbrace{\nabla_{\mathbf{x}} s_{\theta}(\mathbf{x})}_{\text{Jacobian of } s_{\theta}(\mathbf{x})} ) \right]
$$

- The first term is simply the sum of the squared components of the network output.
- To obtain the second term, we need the Jacobian matrix of the outputs with respect to the inputs. With 3072 inputs and 3072 outputs the Jacobian is a 3072x3072 matrix, and the second term is the sum of its diagonal elements. Computing this matrix is expensive because it takes 3072 gradient computations, one per output. There are approximations for this, but for large images such as 512x512 it is still very time-consuming and resource-intensive.
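To see that this objective really does not need the data score, here is a 1D sketch under assumed settings: data drawn from N(0, 1) and a model family N(0, sigma^2), whose score is -x / sigma^2 and whose "Jacobian" is the scalar -1 / sigma^2. The name `score_matching_loss` and the candidate list are illustrative.

```python
import random

def score_matching_loss(model_sigma, samples):
    # Model score s(x) = -x / sigma^2; in 1-D the Jacobian trace is just ds/dx = -1 / sigma^2
    total = 0.0
    for x in samples:
        s = -x / model_sigma ** 2
        ds_dx = -1.0 / model_sigma ** 2
        total += 0.5 * s * s + ds_dx
    return total / len(samples)

random.seed(0)
samples = [random.gauss(0.0, 1.0) for _ in range(50_000)]

candidates = [0.5, 0.75, 1.0, 1.5, 2.0]
losses = {sigma: score_matching_loss(sigma, samples) for sigma in candidates}
best_sigma = min(losses, key=losses.get)
print(losses)
print(best_sigma)  # should be the candidate closest to the data's true standard deviation
```

Only samples and the model score appear in the loss, yet the minimum lands on the sigma that matches the data. Note that the loss value itself can be negative, since a constant was dropped relative to the Fisher divergence.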


Let's work through a toy example. We have a 2D input and we fit a linear layer followed by a sigmoid.

## 1st layer 
$$
\begin{bmatrix} z_1 \\ z_2 \end{bmatrix} = \begin{bmatrix} w_{11} & w_{21} \\ w_{12} & w_{22} \end{bmatrix} \begin{bmatrix} x_1 \\ x_2 \end{bmatrix} + \begin{bmatrix} b_1 \\ b_2 \end{bmatrix}
$$

$$
 \begin{bmatrix} z_1 \\ z_2 \end{bmatrix} = \begin{bmatrix} w_{11} x_1 + w_{21} x_2 + b_1 \\ w_{12} x_1 + w_{22} x_2 + b_2 \end{bmatrix}
$$

## 2nd layer 
$$
\begin{bmatrix} y_1 \\ y_2 \end{bmatrix} = \begin{bmatrix} \sigma(z_1) \\ \sigma(z_2) \end{bmatrix}
$$


Now let's calculate the Jacobian of the score function $s_\theta(x)$, i.e. the Jacobian of the output with respect to the input.

$$
J = \begin{bmatrix}
\frac{\partial y_1}{\partial x_1} & \frac{\partial y_1}{\partial x_2} \\
\frac{\partial y_2}{\partial x_1} & \frac{\partial y_2}{\partial x_2}
\end{bmatrix}
$$

Now let's compute the entries we need. The first output is

$$
y_1 = \sigma(w_{11} x_1 + w_{21} x_2 + b_1)
$$

The derivative of the sigmoid is

$$
\sigma'(x) = \sigma(x) (1 - \sigma(x))
$$

Applying the chain rule gives the diagonal entries of the Jacobian:

$$
\frac{\partial y_1}{\partial x_1} = \sigma (w_{11} x_1 + w_{21} x_2 + b_1) (1 - \sigma (w_{11} x_1 + w_{21} x_2 + b_1)) w_{11}
$$

$$
\frac{\partial y_2}{\partial x_2} = \sigma (w_{12} x_1 + w_{22} x_2 + b_2) (1 - \sigma (w_{12} x_1 + w_{22} x_2 + b_2)) w_{22}
$$
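These two formulas can be sanity-checked against finite differences. This is only a sketch: the weights, biases, and input below are arbitrary illustrative values, and `forward` is a hypothetical helper that follows the same index layout as the expanded equations.

```python
import math

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def forward(x1, x2, w11, w12, w21, w22, b1, b2):
    # Same index layout as the expanded equations: y1 uses w11 and w21, y2 uses w12 and w22
    y1 = sigmoid(w11 * x1 + w21 * x2 + b1)
    y2 = sigmoid(w12 * x1 + w22 * x2 + b2)
    return y1, y2

# Arbitrary illustrative values
params = dict(w11=0.3, w12=-0.7, w21=1.1, w22=0.5, b1=0.2, b2=-0.4)
x1, x2 = 0.8, -0.6

# Analytic diagonal entries from the chain rule
y1, y2 = forward(x1, x2, **params)
analytic_dy1_dx1 = y1 * (1 - y1) * params["w11"]
analytic_dy2_dx2 = y2 * (1 - y2) * params["w22"]

# Central finite differences for the same entries
eps = 1e-6
numeric_dy1_dx1 = (forward(x1 + eps, x2, **params)[0] - forward(x1 - eps, x2, **params)[0]) / (2 * eps)
numeric_dy2_dx2 = (forward(x1, x2 + eps, **params)[1] - forward(x1, x2 - eps, **params)[1]) / (2 * eps)

print(analytic_dy1_dx1, numeric_dy1_dx1)
print(analytic_dy2_dx2, numeric_dy2_dx2)
```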


Now let's plug in some values. We will take $w_{11} = 1$, $w_{12} = 2$, $w_{21} = 3$, $w_{22} = 4$, $b_1 = 0$, $b_2 = 0$, $x_1 = 1$, $x_2 = 1$.

$$
\frac{\partial y_1}{\partial x_1} = \sigma (1 + 3 + 0) (1 - \sigma (1 + 3 + 0)) \cdot 1 \approx 0.0177
$$
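The same arithmetic can be carried through for the whole Jacobian and for the two terms of the score matching objective at this single point. This is a plain-Python sketch using the values from the text; the variable names `jacobian`, `trace`, and `loss_at_x` are mine.

```python
import math

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

# Values from the text
w11, w12, w21, w22 = 1.0, 2.0, 3.0, 4.0
b1, b2 = 0.0, 0.0
x1, x2 = 1.0, 1.0

# Network output, treated as the model score at (x1, x2)
y1 = sigmoid(w11 * x1 + w21 * x2 + b1)
y2 = sigmoid(w12 * x1 + w22 * x2 + b2)

# Jacobian entries from the chain rule: row i is the gradient of y_i with respect to (x1, x2)
jacobian = [
    [y1 * (1 - y1) * w11, y1 * (1 - y1) * w21],
    [y2 * (1 - y2) * w12, y2 * (1 - y2) * w22],
]

# Per-point score matching loss: half the squared norm plus the trace of the Jacobian
trace = jacobian[0][0] + jacobian[1][1]
loss_at_x = 0.5 * (y1 ** 2 + y2 ** 2) + trace

print([[round(v, 4) for v in row] for row in jacobian])
print(round(trace, 4), round(loss_at_x, 4))
```

Only the two diagonal entries enter the loss, but computing all four makes it easy to compare against autograd output later.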

We will see how to calculate this in PyTorch.



```python
import torch
import torch.nn as nn

# Simple score network
class ToyScoreNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)  # 2D input -> 2D output
        # Row i holds the weights of output i, so y1 = 1*x1 + 3*x2 and y2 = 2*x1 + 4*x2
        self.linear.weight.data = torch.tensor([[1, 3], [2, 4]]).float()
        self.linear.bias.data = torch.tensor([0, 0]).float()

    def forward(self, x):
        return torch.sigmoid(self.linear(x))  # Returns score: [∂/∂x₁ log p, ∂/∂x₂ log p]

# Create network and sample data
score_net = ToyScoreNetwork()
x = torch.tensor([[1.0, 1.0]], requires_grad=True)  # One 2D point

# Evaluate the score at x
score = score_net(x)  # Shape: [1, 2]
print("Score output:", score)
```

```
Score output: tensor([[0.9820, 0.9975]], grad_fn=<SigmoidBackward0>)
```


```python
# Gradient of the first score component with respect to x (first row of the Jacobian)
torch.autograd.grad(
    score[:, 0],
    x,
    create_graph=True
)[0]
```

```
tensor([[0.0177, 0.0530]], grad_fn=<MmBackward0>)
```

```python
# Gradient of the second score component with respect to x (second row of the Jacobian)
torch.autograd.grad(
    score[:, 1],
    x,
    create_graph=True
)[0]
```

```
tensor([[0.0049, 0.0099]], grad_fn=<MmBackward0>)
```

> As you can see, 0.0177 matches the value we calculated by hand for the gradient of the first output with respect to $x_1$.
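The trace term of the loss only needs the diagonal of the Jacobian, but with plain autograd that still means one gradient call per input dimension. Below is a minimal sketch of that loop. `jacobian_trace` is a hypothetical helper and `toy_score` is a stand-in that hard-codes the same weights as the toy network above, so treat it as an illustration rather than a tuned implementation.

```python
import torch

def jacobian_trace(score_fn, x):
    # x has shape [1, d]; one autograd call per dimension, which is what makes this costly
    x = x.detach().clone().requires_grad_(True)
    score = score_fn(x)
    trace = torch.zeros(())
    for i in range(x.shape[1]):
        # Gradient of the i-th score component with respect to all inputs
        grad_i = torch.autograd.grad(score[0, i], x, retain_graph=True)[0]
        trace = trace + grad_i[0, i]  # keep only the diagonal entry
    return trace

# Stand-in for the toy network: same weights, zero bias
W = torch.tensor([[1.0, 3.0], [2.0, 4.0]])

def toy_score(x):
    return torch.sigmoid(x @ W.T)

x_point = torch.tensor([[1.0, 1.0]])
trace = jacobian_trace(toy_score, x_point)
print(trace)  # sum of the two diagonal entries, roughly 0.0177 + 0.0099
```

For a 2D input the loop runs twice. For a 32x32x3 image it would run 3072 times per data point, which is the cost described earlier.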

Now we will implement our first score matching network and see how it learns.