# Estimating mean and variance through backpropagation.

Although deep learning is most often used for classification, it is also a powerful approach to regression problems. As any entry-level ML course will tell you, the standard loss function for regression is MSE (mean squared error).

One of the core reasons MSE is the right loss is that it corresponds to assuming the errors follow a normal distribution whose variance (or standard deviation) is constant. Normality is a reasonable default because of the CLT (central limit theorem): a sum or average of many small, independent effects tends toward a normal distribution, whatever the distributions of the individual effects (as long as their variance is finite).

Looking at the PDF (probability density function) of a normal distribution:

$$
f(x \mid \mu, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x - \mu)^2}{2\sigma^2}\right)
$$

Hidden in there is the square of the error, $(x - \mu)^2$. That's why we use the Squared Error, and then we just take the Mean over all the examples. That's how we end up with MSE.
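To make the formula concrete, here is a small sketch that evaluates the normal PDF term by term and compares it with PyTorch's `torch.distributions.Normal`. Note that `Normal` takes the standard deviation (not the variance) as its scale argument. The helper name `normal_pdf` and the example values are just for illustration.

```python
import math
import torch

def normal_pdf(x, mu, var):
    """Normal density f(x | mu, var), written out term by term (illustrative helper)."""
    return torch.exp(-(x - mu) ** 2 / (2 * var)) / torch.sqrt(2 * math.pi * var)

x = torch.tensor([0.0, 1.0, 2.5])
mu = torch.tensor(1.0)
var = torch.tensor(4.0)

manual = normal_pdf(x, mu, var)
# Normal expects the standard deviation, so pass sqrt(var); log_prob(...).exp() is the pdf.
reference = torch.distributions.Normal(mu, torch.sqrt(var)).log_prob(x).exp()

print(manual)
print(torch.allclose(manual, reference))
```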

Why can we ignore the variance, $\sigma^2$? Because we assume the variance of the error is constant. That's called "homoscedastic": homo means "the same", as in homonym. A constant $\sigma^2$ only multiplies the squared error by a constant and adds a constant term, so it rescales the gradient without changing which $\mu$ minimizes the loss. That's why we can drop it.
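A quick numerical check of that claim, as a sketch: with a fixed $\sigma^2$, the gradient of the Gaussian NLL with respect to the predicted mean is the MSE gradient divided by $2\sigma^2$. Same direction and same minimizer, just a different step size. The targets and the value $\sigma^2 = 4$ are arbitrary choices for illustration.

```python
import math
import torch

y = torch.tensor([0.3, -1.2, 2.0, 0.7])   # targets (arbitrary example values)
sigma2 = 4.0                              # fixed, "homoscedastic" variance

mu_mse = torch.tensor([0.5, 0.5, 0.5, 0.5], requires_grad=True)
mu_nll = torch.tensor([0.5, 0.5, 0.5, 0.5], requires_grad=True)

# Plain MSE.
mse = torch.mean((y - mu_mse) ** 2)
mse.backward()

# Mean Gaussian NLL with the variance held constant.
nll = torch.mean(0.5 * math.log(2 * math.pi * sigma2) + (y - mu_nll) ** 2 / (2 * sigma2))
nll.backward()

# The NLL gradient is the MSE gradient scaled by 1 / (2 * sigma2).
print(mu_mse.grad, mu_nll.grad)
print(torch.allclose(mu_nll.grad, mu_mse.grad / (2 * sigma2)))
```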

What about that exponential function? Training neural networks is typically based on MLE (maximum likelihood estimation): choose the parameters under which the observed data has the highest joint probability (the likelihood). For independent examples, that joint probability is a product of lots of small numbers, which quickly underflows. For numerical stability we take the log, which turns the product into a sum; because the log is monotonically increasing, maximizing the log-likelihood picks the same parameters. That's one reason we find the log in the classic cross-entropy loss used for classification:

$$
\text{Cross-Entropy Loss} = - \left[ y \cdot \log(\hat{y}) + (1 - y) \cdot \log(1 - \hat{y}) \right]
$$
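A small sketch of both points, using made-up predictions and labels for 1,000 examples. The product of the per-example likelihoods underflows to exactly zero in 32-bit floats, while the sum of their logs stays representable, and the averaged negative log-likelihood of a Bernoulli model equals the cross-entropy formula above, as computed by PyTorch's `torch.nn.functional.binary_cross_entropy` (which averages by default).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n = 1000
y_hat = torch.rand(n) * 0.8 + 0.1        # predicted P(class = 1), kept in [0.1, 0.9)
y = (torch.rand(n) < 0.5).float()        # made-up 0/1 labels

# Likelihood of each label under Bernoulli(y_hat): y_hat if y == 1, else 1 - y_hat.
per_example = torch.where(y == 1, y_hat, 1 - y_hat)

joint = torch.prod(per_example)                 # product of 1,000 numbers < 1: underflows
log_joint = torch.sum(torch.log(per_example))   # sum of logs: no underflow

# The averaged negative log-likelihood is exactly the cross-entropy formula.
nll = -log_joint / n
bce = F.binary_cross_entropy(y_hat, y)

print(joint.item(), log_joint.item())
print(nll.item(), bce.item())
```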

It's also why MSE doesn't contain any pesky exponentials: taking the log of the normal PDF cancels the exponential, leaving the squared error (scaled by $\frac{1}{2\sigma^2}$) plus a term that doesn't depend on $\mu$.
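Writing the cancellation out for a single observation $x$:

$$
-\log f(x \mid \mu, \sigma^2) = -\log\left(\frac{1}{\sqrt{2\pi\sigma^2}}\right) - \log\exp\left(-\frac{(x - \mu)^2}{2\sigma^2}\right) = \frac{1}{2}\log\left(2\pi\sigma^2\right) + \frac{(x - \mu)^2}{2\sigma^2}
$$

The first term does not depend on $\mu$, and with a constant $\sigma^2$ the second term is the squared error times the constant $\frac{1}{2\sigma^2}$.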

If the variance is not constant, the data is "heteroscedastic" (hetero means different, as in heterogeneous), and we run into lots of problems when trying to do things like t-tests, but that's not the focus of this article.

The focus here is how to get a model trained with backpropagation to estimate the variance too. The answer is simply to output two heads, one for the mean and one for the variance, and undo the two shortcuts we took above (constant variance, and the log cancelling the exponential). All we have to do is minimize the NLL of the PDF: the negative log of the likelihood, where the probability density becomes a "likelihood" once we treat it as a function of the parameters under MLE. Again: the loss is just $-\log \text{pdf}(x \mid \mu, \sigma^2)$. It's really that simple.
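As a minimal sketch, here is that loss written out by hand for a two-headed model's outputs (`mean` and `var` tensors), checked against `torch.distributions.Normal.log_prob`. The helper name `gaussian_nll` and the example numbers are illustrative, not part of any library.

```python
import math
import torch

def gaussian_nll(mean, var, target):
    """-log pdf(target | mean, var), averaged over the batch (illustrative helper)."""
    return torch.mean(0.5 * torch.log(2 * math.pi * var) + (target - mean) ** 2 / (2 * var))

mean = torch.tensor([[1.0], [-1.0]])     # mean head output
var = torch.tensor([[1.0], [4.0]])       # variance head output (must be > 0)
target = torch.tensor([[0.5], [2.0]])

manual = gaussian_nll(mean, var, target)
reference = -torch.distributions.Normal(mean, torch.sqrt(var)).log_prob(target).mean()

print(manual.item(), reference.item())
```

The two terms pull against each other: $\frac{(x - \mu)^2}{2\sigma^2}$ penalizes a variance that is too small for the observed errors, while $\frac{1}{2}\log(2\pi\sigma^2)$ penalizes simply inflating the variance.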

And just to be explicit: when the variance is constant, $-\log \text{pdf}$ reduces to MSE up to a positive scale factor and an additive constant, neither of which changes the minimizer. That's why using MSE is perfectly consistent with MLE and its use of NLL.
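Averaging over $N$ examples with a constant $\sigma^2$ makes the link explicit:

$$
\frac{1}{N}\sum_{i=1}^{N} -\log f(x_i \mid \mu_i, \sigma^2) = \frac{1}{2\sigma^2} \cdot \underbrace{\frac{1}{N}\sum_{i=1}^{N} (x_i - \mu_i)^2}_{\text{MSE}} + \frac{1}{2}\log\left(2\pi\sigma^2\right)
$$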

Some of the first ideas here came from econometrics, where people modeling time series realized that some periods of time have higher variance than others. This [book](https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=db869fa192a3222ae4f2d766674a378e47013b1b) discusses it further in the context of ML, and of course we owe MLE to Fisher, and the CLT to Gauss and Laplace.

## First, set up some data with different variances.

We'll train a model to estimate the expected value and variance of the output of a function, based on the input. 

We'll have two clusters of data. If the input is between zero and one, the output follows a normal distribution with mean 1 and standard deviation 1 (variance 1).
If the input is between ten and eleven, the output follows a normal distribution with mean -1 and standard deviation 2 (variance 4).

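```python
import torch
import torch.nn as nn

# Cluster 1: x uniform on [0, 1), y ~ N(mean=1, std=1).
# (torch.randint(0, 1, ...) would return only zeros, because its upper bound is exclusive.)
x1 = torch.rand(10000, 1)
y1 = torch.randn(10000, 1) + 1

# Cluster 2: x uniform on [10, 11), y ~ N(mean=-1, std=2).
x2 = torch.rand(10000, 1) + 10
y2 = torch.randn(10000, 1) * 2 - 1

x = torch.cat([x1, x2])
y = torch.cat([y1, y2])
```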
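A quick sanity check on the generated data is cheap. This sketch (seeded with an arbitrary `torch.manual_seed(0)` for repeatability) confirms that `torch.randint(0, 1, ...)` only ever produces zeros, regenerates the two clusters as described above, and compares each cluster's input range and sample mean and variance with the intended values.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for repeatability

# torch.randint's `high` is exclusive, so randint(0, 1, ...) is always 0.
zeros_only = torch.randint(0, 1, (1000,)).unique()
print(zeros_only)

# Regenerate the two clusters as described in the text.
x1 = torch.rand(10000, 1)                 # inputs in [0, 1)
y1 = torch.randn(10000, 1) + 1            # N(1, 1)
x2 = torch.rand(10000, 1) + 10            # inputs from 10 up to (about) 11
y2 = torch.randn(10000, 1) * 2 - 1        # N(-1, 4)

for name, xs, ys in [("cluster 1", x1, y1), ("cluster 2", x2, y2)]:
    print(f"{name}: x in [{xs.min().item():.2f}, {xs.max().item():.2f}], "
          f"mean={ys.mean().item():.3f}, var={ys.var().item():.3f}")
```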

### Now let's set up a data loader for that data.

```python
class Data(torch.utils.data.Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y
    
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

dataloader = torch.utils.data.DataLoader(Data(x, y), batch_size=100, shuffle=True, drop_last=True)
```
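The custom `Data` class works fine. As an alternative sketch, PyTorch's built-in `torch.utils.data.TensorDataset` wraps tensors the same way. The tensors `x_demo` and `y_demo` here are random stand-ins with the same shapes as `x` and `y` above. With 20,000 samples, `batch_size=100`, and `drop_last=True`, each epoch has 200 batches of shape `(100, 1)`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-ins with the same shapes as the x and y built above (20,000 x 1 each).
x_demo = torch.rand(20000, 1)
y_demo = torch.randn(20000, 1)

dataset = TensorDataset(x_demo, y_demo)   # dataset[i] returns (x_demo[i], y_demo[i])
loader = DataLoader(dataset, batch_size=100, shuffle=True, drop_last=True)

xb, yb = next(iter(loader))
print(len(loader), xb.shape, yb.shape)
```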

## Now set up a model and loss function: NLL(PDF)

```python
class Predictor(nn.Module):
    def __init__(self):
        super().__init__()
        # Shared trunk: 1 input feature -> 10 hidden units.
        # The last Linear feeds both heads directly, without an activation.
        self.trunk = nn.Sequential(
            nn.Linear(1, 10), nn.Tanh(),
            nn.Linear(10, 10), nn.Tanh(),
            nn.Linear(10, 10), nn.Tanh(),
            nn.Linear(10, 10), nn.Tanh(),
            nn.Linear(10, 10),
        )
        # Two heads: one for the mean, one for the (pre-activation) variance.
        self.mean_y = nn.Linear(10, 1)
        self.var_y = nn.Linear(10, 1)

    def forward(self, x):
        h = self.trunk(x)
        mean = self.mean_y(h)
        # Variance is always positive, so we activate with exp, not sigmoid/softmax
        # (which would cap it at 1). The 1e-2 floor keeps it away from zero.
        var = torch.exp(self.var_y(h)) + 1e-2
        return mean, var

model = Predictor()

# Clip every gradient element to [-1, 1] as it is computed, to keep training stable.
for p in model.parameters():
    p.register_hook(lambda grad: torch.clamp(grad, -1.0, 1.0))


class MeanVarianceLoss(nn.Module):
    def forward(self, mean, var, target):
        # Normal is parameterized by the standard deviation, hence the sqrt.
        normal = torch.distributions.Normal(mean, torch.sqrt(var))

        # For numerical stability, the torch distributions library
        # returns log(pdf(target | mean, variance)), not the pdf itself.
        log_prob = normal.log_prob(target)

        # NLL is the negative log-probability, averaged over the batch.
        return torch.mean(-log_prob)

criterion = MeanVarianceLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# Divide the learning rate by 10 every 20 epochs: 1e-3, 1e-4, 1e-5, 1e-6, 1e-7.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

for epoch in range(100):
    # Use xb/yb so the full x/y tensors defined earlier are not overwritten.
    for xb, yb in dataloader:
        optimizer.zero_grad()
        mean, var = model(xb)
        loss = criterion(mean, var, yb)
        loss.backward()
        optimizer.step()
    scheduler.step()
    if epoch % 10 == 0:
        # Loss of the last mini-batch in the epoch.
        print(f"Epoch {epoch}, Loss: {loss.item()}")
```

```
Epoch 0, Loss: 2.5946478843688965
Epoch 10, Loss: 1.9665488004684448
Epoch 20, Loss: 1.7211822271347046
Epoch 30, Loss: 1.8307368755340576
Epoch 40, Loss: 1.676632046699524
Epoch 50, Loss: 1.7625548839569092
Epoch 60, Loss: 1.8011854887008667
Epoch 70, Loss: 1.6816025972366333
Epoch 80, Loss: 1.6646822690963745
Epoch 90, Loss: 1.7892930507659912
```
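PyTorch also ships a built-in `nn.GaussianNLLLoss`. As a sketch, with `full=True` (which adds the constant $\frac{1}{2}\log 2\pi$ term) it should match the hand-written `MeanVarianceLoss`; note that its argument order is `(input, target, var)`, where `input` is the predicted mean. The class is re-declared here so the block runs on its own, and the tensors are random examples.

```python
import torch
import torch.nn as nn

class MeanVarianceLoss(nn.Module):
    # Same loss as in the training cell: mean of -log N(target | mean, var).
    def forward(self, mean, var, target):
        normal = torch.distributions.Normal(mean, torch.sqrt(var))
        return torch.mean(-normal.log_prob(target))

torch.manual_seed(0)
pred_mean = torch.randn(100, 1)
pred_var = torch.rand(100, 1) + 0.1   # keep variances positive and away from zero
target = torch.randn(100, 1)

ours = MeanVarianceLoss()(pred_mean, pred_var, target)
builtin = nn.GaussianNLLLoss(full=True)(pred_mean, target, pred_var)
print(ours.item(), builtin.item())
```

One difference worth knowing: `GaussianNLLLoss` clamps the variance from below at its `eps` argument (1e-6 by default), while the model above enforces its own floor of 1e-2 inside `forward`.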


## Did it work? Let's spot-check its estimates of mean and variance on one input from each cluster.

```python
# Inputs have shape (batch=1, features=1); no_grad skips building the autograd graph.
with torch.no_grad():
    mean, var = model(torch.tensor([[0.5]]))
print(f"Estimate of mean and var for x = 0.5 should be mean=1, var=1. Mean estimate: {mean.item()}, Var estimate: {var.item()}")

with torch.no_grad():
    mean, var = model(torch.tensor([[10.5]]))
print(f"Estimate of mean and var for x = 10.5 should be mean=-1, var=4. Mean estimate: {mean.item()}, Var estimate: {var.item()}")
```

```
Estimate of mean and var for x = 0.5 should be mean=1, var=1. Mean estimate: 0.8190476298332214, Var estimate: 1.3436105251312256
Estimate of mean and var for x = 10.5 should be mean=-1, var=4. Mean estimate: -0.8588033318519592, Var estimate: 4.115914344787598
```
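A spot check at two points is a good start. A broader check of the variance head is calibration: if the predicted mean and variance are right, about 95% of targets should fall within mean ± 1.96 standard deviations. The sketch below uses the true cluster parameters as stand-in "predictions" (an assumption, so it runs on its own); with the trained model you would pass its `mean` and `var` outputs for each input instead. `interval_coverage` is a hypothetical helper name.

```python
import torch

def interval_coverage(mean, var, target, z=1.96):
    """Fraction of targets inside mean +/- z * std (hypothetical helper)."""
    inside = (target - mean).abs() <= z * torch.sqrt(var)
    return inside.float().mean().item()

torch.manual_seed(0)
n = 10000
# Stand-in "predictions": the true parameters of the two clusters.
pred_mean = torch.cat([torch.full((n, 1), 1.0), torch.full((n, 1), -1.0)])
pred_var = torch.cat([torch.full((n, 1), 1.0), torch.full((n, 1), 4.0)])
target = torch.cat([torch.randn(n, 1) + 1, torch.randn(n, 1) * 2 - 1])

overall = interval_coverage(pred_mean, pred_var, target)
print(f"per-input variance, overall coverage: {overall:.3f}")

# For contrast: one pooled variance for every input (what a homoscedastic model implies).
pooled_var = torch.full_like(pred_var, pred_var.mean().item())   # 2.5 everywhere
c1 = interval_coverage(pred_mean[:n], pooled_var[:n], target[:n])
c2 = interval_coverage(pred_mean[n:], pooled_var[n:], target[n:])
print(f"pooled variance, cluster 1: {c1:.3f}, cluster 2: {c2:.3f}")
```

With the pooled variance, the interval is too wide for the first cluster (coverage close to 1) and too narrow for the second (well under 0.95), which is exactly the heteroscedasticity problem a separate variance head is meant to fix.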
