# Estimating mean and variance through back propagation.

Although deep learning is most often applied to classification problems, it is still a powerful approach
for regression problems. As any entry-level ML course teaches you, the standard loss function for regression
is MSE (mean squared error).

One of the core reasons MSE is the right loss is that it follows from two assumptions: the errors follow
a normal distribution, and the variance (or standard deviation) of those errors is constant. The former is often a reasonable assumption because the CLT (central limit theorem) says that the sum of many small, independent effects tends toward a normal distribution.

Looking at the PDF (probability density function) of a normal distribution

$$
f(x \mid \mu, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x - \mu)^2}{2\sigma^2}\right)
$$

we see hidden in there the square of the delta between the mean and a datapoint: ${(x - \mu)^2}$. That's actually the heart of why we use the MSE.

When we run a regression model, we are predicting some value y based on x, such as height (y) based on age (x). We know that we can't possibly predict any specific person's height from their age alone. What our model is really doing is predicting the average height y of all people with age x.

If we assume that heights at a given age follow a normal distribution, each observed height plays the role of $x$ in the PDF above and our predicted mean height plays the role of $\mu$. We treat the difference between them as an Error, Square the Error, and then take the Mean over all the datapoints. That's how we end up using MSE.

Why can we ignore the variance, $\sigma^2$? Because we assumed that the variance of the error is constant. That's called "homoscedastic" (homo meaning the same, like in homonym). A constant variance only adds a constant term to the loss and rescales the squared error by a constant factor. Neither changes which parameters minimize the loss, so we can drop both.

What about that exponential function? Well, training neural networks is typically based on MLE (Maximum Likelihood Estimation), which picks the parameters that give the observed outputs the highest joint probability. Joint probability involves taking the product of lots of small numbers, so for numerical stability you take the log of the whole thing, which turns it into a sum of logs instead.

$$
\log\left(\prod_{i=1}^{n} x_i\right) = \sum_{i=1}^{n} \log(x_i)
$$
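
To make the numerical stability point concrete, here is a minimal sketch using only the standard library. The 2,000 per-example likelihoods of `1e-3` are made-up values chosen for illustration, not numbers from this article.

```python
import math

# Hypothetical per-example likelihoods: 2,000 small numbers.
probs = [1e-3] * 2000

# Naive joint probability: the running product underflows float64.
product = 1.0
for p in probs:
    product *= p

# Log-likelihood: the sum of logs stays an ordinary, usable number.
log_likelihood = sum(math.log(p) for p in probs)

print(product)         # 0.0
print(log_likelihood)  # roughly -13815.5
```

The product collapses to zero, so it can no longer tell two sets of parameters apart, while the sum of logs still can.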

Those logs are also why MSE doesn't have any pesky exponential or log functions. The log is applied to the exponential in the PDF, and the two cancel each other out.

$$
\log(\exp(x)) = x
$$
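
Putting those pieces together, the derivation can be sketched as follows. Here $y$ is the observed value (it plays the role of $x$ in the PDF above), and $\mu$ and $\sigma^2$ are the mean and variance. Taking the negative log of the PDF for a single datapoint gives

$$
-\log f(y \mid \mu, \sigma^2) = \frac{1}{2}\log(2\pi\sigma^2) + \frac{(y - \mu)^2}{2\sigma^2}
$$

Averaging over $n$ datapoints, each with its own predicted mean $\mu_i$ but a shared, constant $\sigma^2$:

$$
\frac{1}{n}\sum_{i=1}^{n} -\log f(y_i \mid \mu_i, \sigma^2) = \underbrace{\frac{1}{2}\log(2\pi\sigma^2)}_{\text{constant}} + \underbrace{\frac{1}{2\sigma^2}}_{\text{constant}} \cdot \underbrace{\frac{1}{n}\sum_{i=1}^{n} (y_i - \mu_i)^2}_{\text{MSE}}
$$

The first term does not depend on the predictions and the second is MSE times a positive constant, so minimizing this quantity is the same as minimizing MSE.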

If the variance is not constant then the data has "heteroscedasticity" - hetero means different, like in heterogeneous, and we actually run into lots of problems when trying to do things like t-tests, but that's not the focus of this article.

The focus is how to get a model trained with backprop to estimate variance. The answer is simply to output two heads, one for the mean of our y value and one for the variance, and to undo the two shortcuts we applied above (constant variance, log of exponential). Based on MLE, we still need the NLL (negative log-likelihood). The likelihood is the same function as the probability density, just viewed as a function of the parameters with the data held fixed. And the density is given by the PDF: a function giving the probability density of observing x given the parameters mean and variance.

Basically, all we have to do is compute $-\log(\text{pdf})$. The form you'll see in ML textbooks will be something like the following, which says the loss is the negative log of the PDF evaluated at an example's true value

$$
-\log(\text{pdf}(y_{true} | \theta))
$$

where $\theta$ represents the predictions of your neural network, namely, mean and variance. It's really that simple.
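
As a rough sketch of what that looks like in PyTorch, the snippet below writes the Gaussian $-\log(\text{pdf})$ out by hand and compares it with two library routes: `torch.distributions.Normal.log_prob` and `nn.GaussianNLLLoss`. The helper name `gaussian_nll` and the example tensors are made up for illustration.

```python
import math
import torch
import torch.nn as nn

def gaussian_nll(mean, var, target):
    """Per-example -log(pdf(target | mean, var)) for a normal distribution."""
    return 0.5 * torch.log(2 * math.pi * var) + (target - mean) ** 2 / (2 * var)

# Hypothetical network outputs (theta) and observed targets.
mean = torch.tensor([1.0, -1.0, 0.5])
var = torch.tensor([1.0, 4.0, 0.25])
y_true = torch.tensor([1.3, 0.2, 0.4])

manual = gaussian_nll(mean, var, y_true)

# Same quantity via torch.distributions, which takes a standard deviation.
via_distribution = -torch.distributions.Normal(mean, var.sqrt()).log_prob(y_true)

# Same quantity via the built-in loss; full=True keeps the log(2*pi) constant.
via_builtin = nn.GaussianNLLLoss(full=True, reduction="none")(mean, y_true, var)

print(manual)
print(torch.allclose(manual, via_distribution), torch.allclose(manual, via_builtin))
```

All three should agree up to floating-point error, so which one to use is mostly a matter of taste.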

This happens to apply to any probability distribution, not just the normal one... see below for a description of how to use this equation for classification.

So just to reiterate! $-\log(\text{pdf})$ simplifies to MSE (up to a constant scale and offset) when the variance is constant. MSE is perfectly consistent with the MLE concept of NLL, and the negative log just happens to cancel out the exponential found in the PDF of the normal distribution, leaving us a simple equation.
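
A quick numerical check of that claim, as a sketch: with a fixed variance (the value 4.0 and the random targets are arbitrary choices for illustration), the mean NLL equals a constant plus MSE times $1/(2\sigma^2)$, so the gradients with respect to the predicted means differ only by that factor.

```python
import math
import torch

torch.manual_seed(0)
y_true = torch.randn(8)
sigma2 = torch.tensor(4.0)  # assumed constant variance

# Two copies of the same predictions, so each loss gets its own gradient.
mean_a = torch.zeros(8, requires_grad=True)
mean_b = torch.zeros(8, requires_grad=True)

nll = -torch.distributions.Normal(mean_a, sigma2.sqrt()).log_prob(y_true).mean()
mse = ((y_true - mean_b) ** 2).mean()
nll.backward()
mse.backward()

scale = 1.0 / (2.0 * sigma2)
const = 0.5 * torch.log(2 * math.pi * sigma2)

# NLL = const + scale * MSE, and the gradients differ only by `scale`.
print(torch.allclose(nll.detach(), const + scale * mse.detach()))
print(torch.allclose(mean_a.grad, scale * mean_b.grad))
```

Since a positive rescaling of the gradient can be absorbed into the learning rate, both losses push the predicted means in the same direction.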

## Additional Notes (advanced)

BTW classification is a much more common task for deep learning, and when you are classifying something, you are working with a Generalized Bernoulli (categorical) distribution. The NLL in that case is still -log(pdf), or strictly -log(pmf) since the distribution is discrete, and if you plug in the PMF of the Bernoulli distribution, you directly get the equation for the cross-entropy loss.

Although you could also get to the equation using the concepts of Shannon's Entropy and KL Divergence, and in fact that's why it's called "entropy"... in the context of MLE and NLL, we could just as well call the equation the NLL of the Generalized Bernoulli distribution. For the two-class (plain Bernoulli) case, with true label $y$ and predicted probability $\hat{y}$, it reads:

$$
\text{Cross-Entropy Loss} = - \left[ y \cdot \log(\hat{y}) + (1 - y) \cdot \log(1 - \hat{y}) \right]
$$
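
The equivalence can be checked numerically. The sketch below uses made-up labels and predicted probabilities and compares three things: the NLL of a Bernoulli distribution, the formula above written out by hand, and PyTorch's `binary_cross_entropy`.

```python
import torch
import torch.nn.functional as F

# Hypothetical true labels and predicted probabilities.
y = torch.tensor([1.0, 0.0, 1.0, 0.0])
y_hat = torch.tensor([0.9, 0.2, 0.6, 0.7])

# Negative log-likelihood of each label under Bernoulli(y_hat).
nll = -torch.distributions.Bernoulli(probs=y_hat).log_prob(y)

# The cross-entropy formula, term by term.
formula = -(y * torch.log(y_hat) + (1 - y) * torch.log(1 - y_hat))

# PyTorch's built-in binary cross-entropy, kept per-example.
bce = F.binary_cross_entropy(y_hat, y, reduction="none")

print(nll)
print(torch.allclose(nll, formula, atol=1e-6), torch.allclose(nll, bce, atol=1e-6))
```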

Of course, just because we know that NLL tells us the most likely parameters of a model... it does not guarantee that backprop will get us there. For that, we need a property called convexity, and that's a whole other topic. Cross-entropy with a sigmoid or softmax output, and MSE, are convex with respect to the final layer's pre-activation outputs (and therefore for linear models), although not in general with respect to the weights of a deep network. Other distributions like Gaussian Mixture Models are prone to local minima even in simple settings.

Some of the first ideas here came from econometrics, where they wanted to estimate time-series data but realized that during some periods of time, there is higher variance than others. This [book](https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=db869fa192a3222ae4f2d766674a378e47013b1b) talks about it further in the context of ML, and of course we have to go to Fisher for MLE, Gauss and Laplace for CLT, and Rockafellar for Convexity. 

## Let's test this empirically. First, set up some data with different variances.

We'll train a model to estimate the expected value and variance of the output of a function, based on the input. 

We'll have two clusters of data. If the input is between zero and one, the output is drawn from a normal distribution with mean 1 and stdev 1 (var 1).
If the input is between two and three, the output is drawn from a normal distribution with mean -1 and stdev 2 (var 4).

In [48]:
import torch
import torch.nn as nn

x1 = torch.rand(10000, 1) # Uniform between 0 and 1.
y1 = torch.randn(10000, 1) + 1

x2 = torch.rand(10000, 1) + 2 # Uniform between 2 and 3.
y2 = torch.randn(10000, 1) * 2 -1

x = torch.cat([x1, x2])
y = torch.cat([y1, y2])

### Now let's set up a data loader for that data.

In [None]:
class Data(torch.utils.data.Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y
    
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

dataloader = torch.utils.data.DataLoader(Data(x, y), batch_size=100, shuffle=True, drop_last=True)

In [None]:
len(dataloader)

200

## Now set up a model.

Nothing fancy here. Just an MLP (multilayer perceptron) with two heads. 

In [None]:
class Predictor(nn.Module):
    def __init__(self):
        super(Predictor, self).__init__()

        hidden_depth = 100

        self.dense1 = nn.Linear(1, hidden_depth)
        self.act1 = nn.Sigmoid()
        
        self.dense2 = nn.Linear(hidden_depth, hidden_depth)
        self.act2 = nn.Sigmoid()
        self.dense3 = nn.Linear(hidden_depth, hidden_depth)
        self.act3 = nn.Sigmoid()
        self.dense4 = nn.Linear(hidden_depth, hidden_depth)
        self.act4 = nn.Sigmoid()
        self.dense5 = nn.Linear(hidden_depth, hidden_depth)
        self.act5 = nn.Sigmoid()
        self.dense6 = nn.Linear(hidden_depth, hidden_depth)
        self.act6 = nn.Sigmoid()
        self.dense7 = nn.Linear(hidden_depth, hidden_depth)
        self.act7 = nn.Sigmoid()
        self.dense8 = nn.Linear(hidden_depth, hidden_depth)
        self.act8 = nn.Sigmoid()
        self.dense9 = nn.Linear(hidden_depth, hidden_depth)
        self.act9 = nn.Sigmoid()
        
        self.dense10 = nn.Linear(hidden_depth, hidden_depth)

        self.mean_y = nn.Linear(hidden_depth, 1)
        self.var_y = nn.Linear(hidden_depth, 1)
    
    def forward(self, x):
        x = self.act1(self.dense1(x))

        x = x + self.act2(self.dense2(x))
        x = x + self.act3(self.dense3(x))
        x = x + self.act4(self.dense4(x))
        x = x + self.act5(self.dense5(x))
        x = x + self.act6(self.dense6(x))
        x = x + self.act7(self.dense7(x))
        x = x + self.act8(self.dense8(x))
        x = x + self.act9(self.dense9(x))

        x = self.dense10(x)

        mean = self.mean_y(x)
        var = torch.exp(self.var_y(x)) # var should be positive
        return mean, var

model = Predictor()

for p in model.parameters():
    p.register_hook(lambda grad: torch.clamp(grad, -1.0, +1.0))
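
To see what that hook does in isolation, here is a small sketch on a single toy tensor. The tensor `w` and the deliberately steep loss are made up for illustration. The hook clamps each gradient element into `[-1, 1]` before it is stored in `.grad`.

```python
import torch

# A toy parameter with the same kind of hook as the model's parameters.
w = torch.tensor([1.0, -2.0, 0.01], requires_grad=True)
w.register_hook(lambda grad: torch.clamp(grad, -1.0, 1.0))

# A steep loss: the raw gradient is 200 * w = [200, -400, 2].
loss = (100.0 * w ** 2).sum()
loss.backward()

print(w.grad)  # every element clamped into [-1, 1]: values 1, -1, 1
```

This is worth having here because the NLL divides by the predicted variance, so a very small variance early in training can produce very large gradients. PyTorch also provides `torch.nn.utils.clip_grad_value_`, which gives a comparable elementwise clip when called after `backward()`.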

## Let's set up that fancy loss function

Again, it's not fancy at all, it's just NLL applied to the key parameters, or $\theta$, of a Gaussian distribution: mean and variance.

In [64]:
class MeanVarianceLoss(nn.Module):
    """Calculates the negative log likelihood of seeing a target value given a mean and a variance."""
    def __init__(self):
        super(MeanVarianceLoss, self).__init__()
    
    def forward(self, mean, var, target):
        normal = torch.distributions.Normal(mean, torch.sqrt(var))

        # For numerical stability, the torch distributions library
        # returns the log(pdf(target | mean, variance)), not the pdf(target | mean, variance) directly.
        log_prob = normal.log_prob(target)

        # NLL is the negative log probability
        nll = -log_prob

        return torch.mean(nll)

criterion = MeanVarianceLoss()
# The lr passed here is only a placeholder: the loop below overwrites it at the start of every epoch.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for epoch in range(100):
    if 0 <= epoch < 20:
        optimizer.param_groups[0]['lr'] = 1e-3
    elif 20 <= epoch < 40:
        optimizer.param_groups[0]['lr'] = 1e-4
    elif 40 <= epoch < 60:
        optimizer.param_groups[0]['lr'] = 1e-5
    elif 60 <= epoch < 80:
        optimizer.param_groups[0]['lr'] = 1e-6
    else:
        optimizer.param_groups[0]['lr'] = 1e-7

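    # Use separate names for the batches so the full x and y tensors are not overwritten.
    for x_batch, y_batch in dataloader:
        optimizer.zero_grad()
        mean, var = model(x_batch)
        loss = criterion(mean, var, y_batch)
        loss.backward()
        optimizer.step()
    if epoch % 10 == 0:
        # This is the loss of the last mini-batch only, so expect it to be noisy.
        print(f"Epoch {epoch}, Loss: {loss.item()}")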

Epoch 0, Loss: 2.313269853591919
Epoch 10, Loss: 1.6423704624176025
Epoch 20, Loss: 2.411954402923584
Epoch 30, Loss: 1.5081806182861328
Epoch 40, Loss: 1.5562080144882202
Epoch 50, Loss: 1.6598745584487915
Epoch 60, Loss: 1.8379545211791992
Epoch 70, Loss: 1.839219093322754
Epoch 80, Loss: 1.7417705059051514
Epoch 90, Loss: 1.529351830482483
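
One way to put those loss values in context: for a model that predicted the true mean and variance exactly, the expected NLL would equal the differential entropy of the data-generating normal distribution, $\frac{1}{2}\log(2\pi e \sigma^2)$. The sketch below computes that floor for the two clusters defined earlier, assuming they are equally represented (which they are here, with 10,000 points each). The helper name `gaussian_entropy` is made up for illustration.

```python
import math

def gaussian_entropy(var):
    """Differential entropy of a normal distribution with variance `var`, in nats."""
    return 0.5 * math.log(2 * math.pi * math.e * var)

floor_cluster_1 = gaussian_entropy(1.0)  # inputs in [0, 1], var 1
floor_cluster_2 = gaussian_entropy(4.0)  # inputs in [2, 3], var 4

# Both clusters hold the same number of points, so average the two floors.
expected_floor = 0.5 * (floor_cluster_1 + floor_cluster_2)

print(round(floor_cluster_1, 3), round(floor_cluster_2, 3), round(expected_floor, 3))
# 1.419 2.112 1.766
```

The printed losses come from single mini-batches of 100 points, so they bounce around this value rather than settling exactly on it.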


## Did it work? Let's spot check its estimates of mean and variance on some inputs from our two clusters.

In [65]:
mean, var = model(torch.tensor([0.5]))
print(f"Estimate of mean and var for first cluster of data should be mean=1, var=1.")
print(f"Mean estimate: {mean.item()}, Var estimate: {var.item()}")
print()

mean, var = model(torch.tensor([2.5]))
print(f"Estimate of mean and var for second cluster of data should be mean=-1, var=4.")
print(f"Mean estimate: {mean.item()}, Var estimate: {var.item()}")

Estimate of mean and var for first cluster of data should be mean=1, var=1.
Mean estimate: 1.0142967700958252, Var estimate: 0.8464043736457825

Estimate of mean and var for second cluster of data should be mean=-1, var=4.
Mean estimate: -0.8514268398284912, Var estimate: 3.932140827178955


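It worked! The estimates land close to the true values for both clusters (mean 1, var 1 and mean -1, var 4), though not exactly on them.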