# The Varieties of Covariance

The parameterization of spherical, diagonal, and full covariances offers a progression of richer receptive fields.

These forms of the covariance $\Sigma$ are

- **spherical** $\Sigma = \begin{bmatrix}\sigma^2 & 0 \\ 0 & \sigma^2\end{bmatrix}$,
- **diagonal** $\Sigma = \begin{bmatrix}\sigma_y^2 & 0 \\ 0 & \sigma_x^2\end{bmatrix}$, and
- **full** $\Sigma = \begin{bmatrix}\sigma_y^2 & \rho\,\sigma_y\sigma_x \\ \rho\,\sigma_y\sigma_x & \sigma_x^2\end{bmatrix}$,

with $\sigma_y, \sigma_x$ the standard deviations along $y$ and $x$, and $\rho$ the correlation of $x$ and $y$.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from PIL import Image

import torch
import torch.nn.functional as F

# set display defaults
plt.rcParams['figure.figsize'] = (10, 10)        # large images
plt.rcParams['image.interpolation'] = 'nearest'  # don't interpolate: show square pixels
plt.rcParams['image.cmap'] = 'gray'  # use grayscale output rather than a (potentially misleading) color heatmap

# work from project root for local imports
import os
import sys
import subprocess
from pathlib import Path

# root here refers to the segmentron-master folder: ask git for the repository
# top level, descend into segmentron-master, then make it both the working
# directory (for relative data paths) and importable (for the local package)
repo_top = subprocess.check_output(['git', 'rev-parse', '--show-toplevel']).strip().decode("utf-8")
root_dir = Path(repo_top) / "segmentron-master"
os.chdir(root_dir)
sys.path.append(str(root_dir))

from sigma.blur import blur2d_full, gauss2d_full, sigma2logchol, logchol2sigma

torch.manual_seed(1337)

Before inspecting covariance, let's glance at variance in 1D, by forming Gaussians through the unnormalized density $\exp{\frac{-x^2}{2\sigma^2}}$.
We approximate the continuous distribution with the *sampled* Gaussian at linearly-interpolated points, truncate at two standard deviations to include 95\% of the density, then normalize.
Note there is a *discrete* Gaussian approach, derived through diffusion with discrete time, but that's a step for another day.

In [None]:
sigma = torch.tensor([1.0, 2., 4.])

plt.figure(figsize=(15, 5))
for i, s in enumerate(sigma, 1):
    # determine kernel size to cover +/- 2 sigma s.t. ~95% of density is included
    half_size = int(max(1, torch.ceil(s * 2.)))
    # always make odd kernel to keep coordinates centered
    kernel_size = half_size*2 + 1
    # calculate unnormalized density then normalize
    x = torch.linspace(-half_size, half_size, steps=kernel_size)
    filter_ = torch.exp(-x**2 / (2*s**2))
    filter_sum = filter_.sum()
    filter_norm = filter_ / filter_sum

    plt.subplot(1, len(sigma), i)
    # raw f-string: "\s" is an invalid escape sequence in a normal string literal
    plt.title(rf"Gaussian $\sigma=${s:0.1f}")
    plt.plot(x.numpy(), filter_norm.numpy())

Now let's consider spherical covariance in 2D, which analogously adjusts the scale of the filter.
This case reduces to the (outer) product of two 1D Gaussians, because spherical covariance gives an *isotropic* distribution that is identical in every direction.

In [None]:
sigma = torch.tensor([1.0, 2., 4.])

# size the shared grid for the largest sigma (+/- 2 sigma, odd width). kernel_size
# is defined here rather than reused from the previous cell's loop, so this cell
# no longer depends on which cell ran last
kernel_size = int(max(1, torch.ceil(sigma.max() * 2.))) * 2 + 1
# widen the extent by 10% and sample 3x more densely for smoother contours
half_size = (kernel_size // 2) * 1.1
x = torch.linspace(-half_size, half_size, steps=kernel_size * 3)

plt.figure(figsize=(15, 5))
for i, s in enumerate(sigma, 1):
    # calculate unnormalized density then normalize
    filter_ = torch.exp(-x**2 / (2*s**2))
    # 2D is product of 1D b.c. this is isotropic
    filter_ = filter_.view(-1, 1) @ filter_.view(1, -1)
    filter_sum = filter_.sum()
    filter_norm = filter_ / filter_sum

    plt.subplot(1, len(sigma), i)
    # raw f-string: "\s" is an invalid escape sequence in a normal string literal
    plt.title(rf"Gaussian $\sigma=${s:0.1f}")
    plt.contour(filter_norm.numpy(), cmap='viridis')
    plt.axis('off')

There are useful special cases of the Gaussian for composition:

- $\sigma \to 0$ gives the delta filter, yielding the identity;
- $\sigma < 1$ gives a small, identity-like filter, which can learn to be sharper or smoother, for a good initialization;
- $\sigma \to \infty$ gives average pooling;

so that smoothing is learnable, can reduce to the identity if no smoothing is desired, and can reduce to averaging (approximately) if global pooling is desired.

In [None]:
sigma = torch.tensor([0.1, 0.7, 16.])

half_size = 7  # hardcode to focus on center of filter
kernel_size = half_size * 2 + 1
x = torch.linspace(-half_size, half_size, steps=kernel_size)
# display only the central (2 * view_radius + 1)^2 taps of the kernel
view_radius = 2
center = slice(half_size - view_radius, half_size + view_radius + 1)

plt.figure(figsize=(15, 5))
for i, s in enumerate(sigma, 1):
    # calculate unnormalized density then normalize
    filter_ = torch.exp(-x**2 / (2*s**2))
    # 2D is product of 1D b.c. this is isotropic
    filter_ = filter_.view(-1, 1) @ filter_.view(1, -1)
    filter_sum = filter_.sum()
    filter_norm = filter_ / filter_sum

    plt.subplot(1, len(sigma), i)
    # raw f-string: "\s" is an invalid escape sequence in a normal string literal
    plt.title(rf"Gaussian $\sigma=${s:0.1f}")
    plt.imshow(filter_norm[center, center].numpy(), vmin=0, cmap='gray')
    plt.axis('off')
# lay out once, after all panels exist
plt.tight_layout()

For general covariance, we need the general multivariate Gaussian density (again, unnormalized, dropping the constant $(2\pi)^{-d/2}$): $\det(\Sigma)^{-1/2} \exp\left\{-\frac{1}{2} x' \Sigma^{-1} x\right\}$.
Let's inspect densities with spherical, diagonal, and full covariances.

In [None]:
# example 2x2 covariances, rows/cols ordered (y, x)
sigma_sphere = 2. * torch.eye(2)                       # equal variance on both axes
sigma_diag = torch.diag(torch.tensor([2., 4.]))       # axis-aligned, unequal variances
sigma_full = torch.tensor([[1., 1.5], [1.5, 4.]])     # correlated y and x
sigmas = [('sphere', sigma_sphere), ('diag.', sigma_diag), ('full', sigma_full)]

In [None]:
# see sigma.blur.gauss2d_full
def sigma_filter(sigma):
    """Sample a 2D Gaussian with covariance `sigma` on an integer grid, normalized to sum to one.

    The grid spans +/- 2 standard deviations per axis (rounded up, at least 1),
    with rows running from +y (top) to -y (bottom) and columns from -x to +x.

    Args:
        sigma: 2x2 covariance matrix ordered (y, x).

    Returns:
        Kernel of shape (2*h_y + 1, 2*h_x + 1).
    """
    std = sigma.diag().sqrt()
    half = (2. * std).ceil().clamp(min=1.).int()
    h_y, h_x = int(half[0]), int(half[1])
    n_rows, n_cols = 2 * h_y + 1, 2 * h_x + 1
    rows = torch.linspace(h_y, -h_y, steps=n_rows)
    cols = torch.linspace(-h_x, h_x, steps=n_cols)
    # all (y, x) pairs in row-major order, i.e. meshgrid('ij') flattened
    grid = torch.cartesian_prod(rows, cols)
    # quadratic form x^T Sigma^-1 x for every grid point at once
    quad = ((grid @ sigma.inverse()) * grid).sum(dim=1)
    density = torch.det(sigma).pow(-0.5) * torch.exp(-0.5 * quad)
    return (density / density.sum()).view(n_rows, n_cols)
    
plt.figure(figsize=(15, 5))

filters = [(name, sigma_filter(s)) for name, s in sigmas]

# kernels differ in size and aspect; pad each one to a common square canvas
# whose side is the largest dimension over all kernels
max_size = max(max(f.shape) for _, f in filters)
for i, (name, f) in enumerate(filters, 1):
    plt.subplot(1, len(filters), i)
    plt.title(name)
    # kernel sides are odd (2 * half + 1), so each size difference splits evenly
    pad_h = (max_size - f.shape[0]) // 2
    pad_w = (max_size - f.shape[1]) // 2
    f = torch.nn.functional.pad(f, (pad_w, pad_w, pad_h, pad_h))
    plt.imshow(f.numpy())
    plt.axis('off')

Instead of calculating the density at linearly-interpolated coordinates, let's inspect the placement of coordinates by mapping the coordinates of a standard Gaussian through the covariance.
First, let's look at the coordinates for a 1D Gaussian when linearly interpolating through the inverse cumulative distribution function, for a sense of the spacing.

In [None]:
# standard deviations to compare
sigma = torch.tensor([1.0, 2.0, 4.0])
# probability trimmed from each tail: the central 1 - 2 * 0.025 = 95% of the
# density is kept, i.e. roughly +/- 2 standard deviations
eps = torch.tensor(0.025)

In [None]:
num_steps = 5

fig = plt.figure(figsize=(15, 1))
for i, s in enumerate(sigma, 1):
    normal = torch.distributions.normal.Normal(torch.tensor(0.), s)
    # evenly spaced probabilities in [eps, 1 - eps] mapped through the inverse CDF
    x = normal.icdf(torch.linspace(eps, -eps + 1, steps=num_steps))
    subplt = plt.subplot(1, len(sigma), i)
    # raw f-string: "\s" is an invalid escape sequence in a normal string literal
    plt.title(rf"$\sigma$ = {s:}")
    plt.plot(x.numpy(), torch.zeros_like(x).numpy(), linestyle='none', marker='.')
    plt.ylim(-10., 10)
    plt.xlim(-16., 16)
    subplt.get_yaxis().set_visible(False)

For 2D, we switch to polar coordinates and choose the radii according to the spacing above.
We choose the number of angles for angular resolution, and the number of steps for distance resolution (where steps count the rings from the center).
We take a two sigma truncation and plot the coordinates for different standard deviations.

In [None]:
import math

num_angles = 8
num_steps = 2

plt.figure(figsize=(15, 5))
for i, s in enumerate(sigma, 1):
    normal = torch.distributions.normal.Normal(torch.tensor(0.), s)
    angles = torch.linspace(0., 2*math.pi, steps=num_angles + 1)[:-1]  # exclude end == beginning
    # ring radii: quantiles from the median (radius 0, dropped) out to 1 - eps
    radii = normal.icdf(torch.linspace(0.5, -eps + 1, steps=num_steps + 1))[1:]

    # every (radius, angle) pair, ring-major / angle-minor -- the same order as
    # meshgrid + flatten, without the meshgrid warning about `indexing`
    rho, theta = torch.cartesian_prod(radii, angles).unbind(dim=1)
    # prepend the center point
    y = torch.cat((torch.tensor([0.]), rho * torch.sin(theta)))
    x = torch.cat((torch.tensor([0.]), rho * torch.cos(theta)))

    plt.subplot(1, len(sigma), i)
    # raw f-string: "\s" is an invalid escape sequence in a normal string literal
    plt.title(rf"$\sigma$ = {s:}")
    plt.plot(x.numpy(), y.numpy(), linestyle='none', marker='.')
    plt.ylim(-10, 10)
    plt.xlim(-10, 10)
    plt.axis('off')

Mapping the coordinates of the standard Gaussian through the covariance gives us a sampled approximation to that Gaussian.
Adjusting the number of steps and angles trades accuracy against computation.
Coloring the points shows the correspondences among the covariances.

In [None]:
num_steps = 1
num_angles = 8

# make a standard Gaussian and sample it on rings (radii) x rays (angles)
normal = torch.distributions.normal.Normal(torch.tensor(0.), 1.)
angles = torch.linspace(0., 2*math.pi, steps=num_angles + 1)[:-1]
radii = normal.icdf(torch.linspace(0.5, -eps + 1, steps=num_steps + 1))[1:]

# every (radius, angle) pair, ring-major / angle-minor (same order as meshgrid + flatten)
rho, theta = torch.cartesian_prod(radii, angles).unbind(dim=1)
y = rho * torch.sin(theta)
x = rho * torch.cos(theta)
# rows are (y, x): the center point first, then the ring points
coords = torch.cat((torch.zeros(1, 2),
                    torch.stack((y, x), dim=-1).view(-1, 2)))

# define a variety of full covariances
sigmas = [
        torch.tensor([1., 0., 0., 1.]).view(-1, 2),
        torch.tensor([0.1, 0., 0., 0.1]).view(-1, 2),
        torch.tensor([2., 0., 0., 2.]).view(-1, 2),
        torch.tensor([2., 0., 0., 0.5]).view(-1, 2),
        torch.tensor([0.5, 0., 0., 2.0]).view(-1, 2),
        torch.tensor([1., 0.5, 0.5, 1.]).view(-1, 2),
        torch.tensor([1., -0.5, -0.5, 1.]).view(-1, 2),
]

# transform standard coordinates by covariances. With the upper Cholesky factor
# U (Sigma = U'U), row vectors z @ U have covariance U'U = Sigma; mapping by
# z @ Sigma instead would give covariance Sigma @ Sigma and mis-scale the spread
lim = 5.
plt.figure(figsize=(3*len(sigmas), 3))
for i, s in enumerate(sigmas, 1):
    s_coords = coords @ torch.cholesky(s, upper=True)
    y, x = torch.unbind(s_coords, dim=1)

    plt.subplot(1, len(sigmas), i)
    # one plot call per point so each point gets its own color from the cycle;
    # colors therefore match across subplots and show the correspondences
    for px, py in zip(x, y):
        plt.plot(px.numpy(), py.numpy(), linestyle='none', marker='.', markersize=12)
    plt.xlim(-lim, lim)
    plt.ylim(-lim, lim)
    plt.axis('off')

To compose our approximate Gaussian coordinates with a filter, we must map from the Gaussian points to the filter taps.
Here we make a mask that pairs the center point with the center tap and groups the remaining points with taps by angle.
We normalize the mask such that each tap has equal weighting and the total weighting of the points is one.

In [None]:
# points x taps for 3x3 kernel
num_taps = 3 * 3
# tap for each angle index:
# - angles go in counter-clockwise order from center-right
# - taps go in row major order from top-left
# this table only describes 8 angles on a 3x3 kernel, so refuse other layouts
# instead of silently misassigning points
angle_taps = [5, 2, 1, 0, 3, 6, 7, 8]
if num_angles != len(angle_taps):
    raise ValueError(f"mask layout assumes {len(angle_taps)} angles, got {num_angles}")
mask = torch.zeros(num_steps * num_angles + 1, num_taps).float()
# assign center point to center tap
mask[0, num_taps // 2] = 1.
# rows after the center are ring-major / angle-minor, so every num_angles-th
# row starting at 1 + angle_idx belongs to the same angle
for angle_idx, tap in enumerate(angle_taps):
    mask[1 + angle_idx::num_angles, tap] = 1.
# normalize so every point-tap dot has equal weight,
# and all the points together sum to one, so that
# the input magnitude is unchanged
mask /= mask.sum(0).view(1, -1)
mask /= mask.sum()
mask

## Parameterizing $\Sigma$ for Learning

The covariance $\Sigma$ is not an arbitrary $d \times d$ matrix: it is symmetric positive definite, so we have to parameterize it properly for unconstrained optimization. For a primer on covariance parameterization, see

> Unconstrained Parameterizations for Variance-Covariance Matrices. Pinheiro & Bates 1996.
 
In our case, the log-Cholesky parameterization is a good choice because it's simple and quick:
$\Sigma = U'U$ for upper-triangular $U$ with positive diagonal.
We can keep the diagonal positive by representing the log of the diagonal (hence log-Cholesky), and exp'ing it.

In [None]:
# generate random covariances in log-Cholesky form: (log U00, U01, log U11)
params = torch.randn(3)

U = torch.zeros((2, 2))
# upper-triangular positions (0, 0), (0, 1), (1, 1) as a *tuple* of index lists;
# indexing with a nested list relies on deprecated sequence-as-tuple behavior
triu_indices = ([0, 0, 1], [0, 1, 1])
U[triu_indices] = params
# exponentiate the diagonal in place so it is positive for any params
U.diagonal().exp_()
U_inv = U.inverse()

# Sigma = U'U; its inverse is U^-1 U^-T
sigma = U.t() @ U
sigma_inv = U_inv @ U_inv.t()

print(f"parameters: {params}")
print(f"sigma:\n{sigma}")

Convert known sigma to log-Cholesky parameterization to check that the original and log-Cholesky kernels agree.

In [None]:
def sigma_filter_logchol(params):
    """Sample a normalized 2D Gaussian kernel from log-Cholesky parameters.

    The covariance is Sigma = U'U with U = [[exp(p0), p1], [0, exp(p2)]], so the
    diagonal of U stays positive for any real `params`.

    Args:
        params: 1D tensor of 3 values (log U00, U01, log U11), ordered (y, x).

    Returns:
        Kernel of shape (2*h_y + 1, 2*h_x + 1) whose entries sum to one, covering
        +/- 2 standard deviations per axis (at least 1); rows run from +y to -y
        and columns from -x to +x.
    """
    # make cholesky factor and inverse
    U = torch.zeros((2, 2))
    # scatter params into the upper triangle; a tuple of index tensors avoids
    # the deprecated nested-list indexing form
    rows, cols = torch.triu_indices(2, 2)
    U[rows, cols] = params
    U.diagonal().exp_()
    U_inv = U.inverse()
    sigma_inv = U_inv @ U_inv.t()
    # make filter
    half_size = torch.ceil((U.t() @ U).diag()**(0.5) * 2.).clamp(min=1.).int()
    h_y, h_x = int(half_size[0]), int(half_size[1])
    n_rows, n_cols = 2 * h_y + 1, 2 * h_x + 1
    y = torch.linspace(h_y, -h_y, steps=n_rows)
    x = torch.linspace(-h_x, h_x, steps=n_cols)
    # all (y, x) pairs in row-major order, i.e. meshgrid('ij') flattened
    coords = torch.cartesian_prod(y, x)
    # det(Sigma)^(-1/2) == det(U)^(-1) because det(Sigma) = det(U)^2
    filter_ = torch.det(U)**-1. * torch.exp(-0.5 * ((coords @ sigma_inv) * coords).sum(1))
    filter_ = filter_ / filter_.sum()
    return filter_.view(n_rows, n_cols)

# reference kernel computed directly from the covariance matrix
filter_ = sigma_filter(sigma_full)

# log-Cholesky parameters of sigma_full: with Sigma = U'U, keep log of U's
# diagonal and its off-diagonal entry (flattened U is [U00, U01, U10, U11])
full_chol = torch.cholesky(sigma_full, upper=True).contiguous().view(-1)
log_diag = full_chol[[0, -1]].log()
full_params = torch.stack((log_diag[0], full_chol[1], log_diag[1]))
# same kernel rebuilt from the log-Cholesky parameters
chol_filter_ = sigma_filter_logchol(full_params)

plt.figure(figsize=(10, 5))
panels = (('original filter', filter_), ('log-Cholesky filter', chol_filter_))
for pos, (label, kernel) in enumerate(panels, 1):
    plt.subplot(1, 2, pos)
    plt.title(label)
    plt.imshow(kernel)
    plt.axis('off')

print("filter MSE: ", torch.mean((filter_ - chol_filter_)**2).item())

## Toy Experiment: Optimize $\Sigma$ to Recover Full Covariance Blur in 2D

To illustrate receptive field optimization via sigma with a toy problem, let's recover the covariance of a Gaussian blur from smoothed data in 2D.

1. Generate a random 2D signal and smooth it with a reference sigma.
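2. Instantiate our filter with zero initialization of the sigma parameter (all-zero log-Cholesky parameters, which under the parameterization above give the identity covariance).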
3. Learn sigma by gradient descent.

In [None]:
x = torch.randn(1, 1, 64, 64)
true_sigma = sigma_full
# fixed target: blur the signal with the known covariance, detached from autograd
xf = blur2d_full(x, sigma2logchol(true_sigma), std_devs=2).detach()

plt.figure(figsize=(10, 2))
panels = (('signal', x), ('smoothed', xf))
for pos, (label, img) in enumerate(panels, 1):
    plt.subplot(1, 2, pos)
    plt.title(label)
    plt.imshow(img.squeeze().numpy())

In [None]:
def plot_recovery(xf, xf_hat, g, iter_):
    """Show the target blur, the current reconstruction, and the current kernel.

    Args:
        xf: target blurred signal.
        xf_hat: reconstruction from the current parameters.
        g: current Gaussian kernel.
        iter_: iteration number shown in the figure title.
    """
    fig = plt.figure(figsize=(5, 2))
    # title the figure, not an Axes: plt.title() before plt.subplot() creates a
    # full-figure Axes that the subplots either delete (losing the title) or overlap
    fig.suptitle("Recovery iter. {}".format(iter_))
    for pos, img in enumerate((xf, xf_hat, g), 1):
        plt.subplot(1, 3, pos)
        plt.imshow(img.squeeze().detach().numpy())
        plt.axis('off')
    # lay out once, after all panels exist
    plt.tight_layout()
    
# learn log-Cholesky parameters starting from all zeros
cov = torch.nn.Parameter(torch.zeros(3))
opt = torch.optim.Adamax([cov], lr=0.1)

max_iter = 100
for iter_ in range(max_iter):
    xf_hat = blur2d_full(x, cov, std_devs=2)
    diff = xf_hat - xf
    loss = (diff**2).mean()

    if iter_ % 10 == 0:
        print(f"iter {iter_:04d} loss {loss.item()}")
    if iter_ in (0, 4, 16):
        # kernel for the same parameters that produced xf_hat (before this
        # iteration's update); display only, so no autograd graph is needed
        with torch.no_grad():
            g = gauss2d_full(cov, std_devs=2)
        plot_recovery(xf, xf_hat, g, iter_)

    opt.zero_grad()
    loss.backward()
    opt.step()

# evaluate the final parameters so the reported loss, reconstruction, and
# kernel all describe the sigma printed below
with torch.no_grad():
    xf_hat = blur2d_full(x, cov, std_devs=2)
    loss = ((xf_hat - xf)**2).mean()
    g = gauss2d_full(cov, std_devs=2)
print(f"iter {max_iter:04d} loss {loss.item()}")
plot_recovery(xf, xf_hat, g, max_iter)


print("\ntrue sigma\n{}\nrecovered sigma\n{}".format(true_sigma.detach().numpy(), logchol2sigma(cov).detach().numpy()))

Let's check that optimizing over full covariance can recover a simpler, spherical covariance.

In [None]:
true_sigma = sigma_sphere
# fixed target: blur the same signal with the spherical covariance
xf = blur2d_full(x, sigma2logchol(true_sigma)).detach()

plt.figure(figsize=(10, 2))
for pos, (label, img) in enumerate((('signal', x), ('smoothed', xf)), 1):
    plt.subplot(1, 2, pos)
    plt.title(label)
    plt.imshow(img.squeeze().numpy())

In [None]:
def plot_recovery(xf, xf_hat, g, iter_):
    """Show the target blur, the current reconstruction, and the current kernel.

    Args:
        xf: target blurred signal.
        xf_hat: reconstruction from the current parameters.
        g: current Gaussian kernel.
        iter_: iteration number shown in the figure title.
    """
    fig = plt.figure(figsize=(5, 2))
    # title the figure, not an Axes: plt.title() before plt.subplot() creates a
    # full-figure Axes that the subplots either delete (losing the title) or overlap
    fig.suptitle("Recovery iter. {}".format(iter_))
    for pos, img in enumerate((xf, xf_hat, g), 1):
        plt.subplot(1, 3, pos)
        plt.imshow(img.squeeze().detach().numpy())
        plt.axis('off')
    # lay out once, after all panels exist
    plt.tight_layout()
    
# learn log-Cholesky parameters starting from all zeros
cov = torch.nn.Parameter(torch.zeros(3))
opt = torch.optim.Adamax([cov], lr=0.1)

max_iter = 100
for iter_ in range(max_iter):
    xf_hat = blur2d_full(x, cov)
    diff = xf_hat - xf
    loss = (diff**2).mean()

    if iter_ % 10 == 0:
        print(f"iter {iter_:04d} loss {loss.item()}")
    if iter_ in (0, 4, 16):
        # kernel for the same parameters that produced xf_hat (before this
        # iteration's update); display only, so no autograd graph is needed
        with torch.no_grad():
            g = gauss2d_full(cov)
        plot_recovery(xf, xf_hat, g, iter_)

    opt.zero_grad()
    loss.backward()
    opt.step()

# evaluate the final parameters so the reported loss, reconstruction, and
# kernel all describe the sigma printed below
with torch.no_grad():
    xf_hat = blur2d_full(x, cov)
    loss = ((xf_hat - xf)**2).mean()
    g = gauss2d_full(cov)
print(f"iter {max_iter:04d} loss {loss.item()}")
plot_recovery(xf, xf_hat, g, max_iter)


print("\ntrue sigma\n{}\nrecovered sigma\n{}".format(true_sigma.detach().numpy(), logchol2sigma(cov).detach().numpy()))

For a little more visual interest and historical reference, let's recover a full covariance blur on a real image: a portrait of Carl Friedrich Gauss himself.

In [None]:
# load the portrait of Gauss as a single-channel image, convert it to a
# 1x1xHxW float tensor, and rescale it to [0, 1]
with Image.open(root_dir / 'notebooks/gauss.jpg') as portrait:
    # convert('L') guarantees one channel so the 4D view below cannot fail on an
    # RGB file; the context manager closes the underlying file handle
    im_gauss = torch.tensor(np.array(portrait.convert('L'))).float()
im_gauss = im_gauss.view(1, 1, *im_gauss.size())
im_gauss = (im_gauss - im_gauss.min()) / (im_gauss.max() - im_gauss.min())
# shrink for convenience
im_gauss = F.interpolate(im_gauss, scale_factor=0.5, mode='bilinear', align_corners=False)

# blur Gauss with known Gaussian
true_sigma = sigma_full * 4
g = gauss2d_full(sigma2logchol(true_sigma))
blur_gauss = blur2d_full(im_gauss, sigma2logchol(true_sigma))

# crop to remove border effects
border = 20
im_gauss = im_gauss[..., border:-border, border:-border]
blur_gauss = blur_gauss[..., border:-border, border:-border]

plt.figure(figsize=(10, 5))
for pos, img in enumerate((im_gauss, blur_gauss, g), 1):
    plt.subplot(1, 3, pos)
    plt.imshow(img.squeeze().numpy())
    plt.axis('off')

x, xf = im_gauss, blur_gauss

In [None]:
def plot_recovery(xf, xf_hat, g, iter_):
    """Show the target blur, the current reconstruction, and the current kernel.

    Args:
        xf: target blurred signal.
        xf_hat: reconstruction from the current parameters.
        g: current Gaussian kernel.
        iter_: iteration number shown in the figure title.
    """
    fig = plt.figure(figsize=(5, 2))
    # title the figure, not an Axes: plt.title() before plt.subplot() creates a
    # full-figure Axes that the subplots either delete (losing the title) or overlap
    fig.suptitle("Recovery iter. {}".format(iter_))
    for pos, img in enumerate((xf, xf_hat, g), 1):
        plt.subplot(1, 3, pos)
        plt.imshow(img.squeeze().detach().numpy())
        plt.axis('off')
    # lay out once, after all panels exist
    plt.tight_layout()
    
# learn log-Cholesky parameters starting from all zeros
cov = torch.nn.Parameter(torch.zeros(3))
opt = torch.optim.Adamax([cov], lr=0.1)

# compare only the interior so border effects of re-blurring the cropped image
# do not bias the fit
border = 20

max_iter = 256
for iter_ in range(max_iter):
    xf_hat = blur2d_full(x, cov)
    diff = xf_hat[..., border:-border, border:-border] - xf[..., border:-border, border:-border]
    loss = (diff**2).mean()

    if iter_ % 10 == 0:
        print(f"iter {iter_:04d} loss {loss.item()}")
    if iter_ in (0, 4, 16):
        # kernel for the same parameters that produced xf_hat (before this
        # iteration's update); display only, so no autograd graph is needed
        with torch.no_grad():
            g = gauss2d_full(cov)
        plot_recovery(xf, xf_hat, g, iter_)

    opt.zero_grad()
    loss.backward()
    opt.step()

# evaluate the final parameters so the reported loss, reconstruction, and
# kernel all describe the sigma printed below
with torch.no_grad():
    xf_hat = blur2d_full(x, cov)
    diff = xf_hat[..., border:-border, border:-border] - xf[..., border:-border, border:-border]
    loss = (diff**2).mean()
    g = gauss2d_full(cov)
print(f"iter {max_iter:04d} loss {loss.item()}")
plot_recovery(xf, xf_hat, g, max_iter)


print("\ntrue sigma\n{}\nrecovered sigma\n{}".format(true_sigma.detach().numpy(), logchol2sigma(cov).detach().numpy()))