# Adaptive Blur by Local Filtering

For exact adaptive blurring, instantiate a separate Gaussian kernel for every window (that is, every output position), so that each location gets its own blur.
This local filtering is carried out by multiplying the `im2col` matrix of the input with the stacked blur kernels.

This is not efficient in time or memory: every blur kernel has to fit the same maximum size, and each one is sampled separately.

In [None]:
import torch
import torch.nn.functional as F

import matplotlib.pyplot as plt
%matplotlib inline

# set display defaults
plt.rcParams['figure.figsize'] = (10, 10)        # large images
plt.rcParams['image.interpolation'] = 'nearest'  # don't interpolate: show square pixels
plt.rcParams['image.cmap'] = 'gray'  # use grayscale output rather than a (potentially misleading) color heatmap

# work from project root for local imports
import os
import sys
import subprocess
from pathlib import Path

# root here refers to the segmentron-master folder
root_dir = Path(subprocess.check_output(['git', 'rev-parse', '--show-toplevel']).strip().decode("utf-8"))
root_dir = root_dir / "segmentron-master"
os.chdir(root_dir)
sys.path.append(str(root_dir))

from sigma.blur import blur2d_full, blur2d_local_full, blur2d_local_sphere
from sigma.blur import gauss2d_full, sigma2logchol, logchol2sigma

torch.manual_seed(1337)

Let's make a Gaussian blur through the log-Cholesky parameterization, and then convolve with it.

In [None]:
sigma_sphere = torch.diag(torch.Tensor([2., 2.]))
sigma_diag = torch.diag(torch.Tensor([2., 4.]))
sigma_full = torch.Tensor([[1., 1.],
                           [1., 4.]])

In [None]:
chol_filter_ = gauss2d_full(sigma2logchol(sigma_full))
kh, kw = chol_filter_.size()[-2:]

plt.figure(figsize=(5, 5))
plt.title('log-Cholesky filter')
plt.imshow(chol_filter_.squeeze().numpy())
plt.axis('off')

Let's review convolution by matrix multiplication:

1. unroll the image by extracting patches, flattening, and stacking into the `im2col` matrix.
2. consider the filter weights as an out x in x height x width tensor, and flatten the trailing input and spatial dimensions
3. convolve by matrix multiplication of the image matrix and weight matrix
4. roll the output matrix back into an image by restoring the spatial dimensions.

Why convolve this way? It takes advantage of highly-tuned matrix multiplication routines, and we can generalize to local *non-convolutional* filtering by adding a further dimension to the weight matrix.

In [None]:
im = torch.randn(2, 1, 64, 64)
in_mat = F.unfold(im, kernel_size=(kh, kw), padding=(kh // 2, kw // 2))
weight_mat = chol_filter_.view(1, -1, 1).repeat(2, 1, 1)
out_mat = torch.bmm(in_mat.permute(0, 2, 1), weight_mat)
out = out_mat.view(im.size())

print(f"image matrix size {tuple(in_mat.size())} weight matrix size {tuple(weight_mat.size())}")

plt.figure()
plt.imshow(im[1].squeeze().numpy())
plt.figure()
plt.imshow(out[1].squeeze().numpy())
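
As a sanity check on the steps above, the unfold-and-multiply route can be compared against `F.conv2d`. This is a minimal sketch with a random 5x5 stand-in filter (an assumption for the example, not the log-Cholesky filter from the cell above).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
im = torch.randn(2, 1, 16, 16)
kernel = torch.randn(1, 1, 5, 5)  # hypothetical stand-in filter: out x in x height x width
kh, kw = kernel.shape[-2:]

# 1. unroll the image into the im2col matrix: (batch, kh * kw, positions)
cols = F.unfold(im, kernel_size=(kh, kw), padding=(kh // 2, kw // 2))
# 2. flatten the filter into a (out, in * kh * kw) weight matrix
weight = kernel.view(1, -1)
# 3. convolve by matrix multiplication, broadcasting over the batch
out_mat = weight @ cols  # (batch, out, positions)
# 4. roll the result back into an image
out = out_mat.view(2, 1, 16, 16)

ref = F.conv2d(im, kernel, padding=(kh // 2, kw // 2))
max_err = (out - ref).abs().max().item()
```

The two results should agree up to floating-point error, since `F.conv2d` computes the same cross-correlation.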

Instantiate many sigmas at once through the log-Cholesky parameterization.

In [None]:
params = torch.randn(9, 3) * 0.5
blurs = gauss2d_full(params)
plt.figure()
for i, b in enumerate(blurs, 1):
    plt.subplot(3, 3, i)
    plt.imshow(b.squeeze().numpy())
    plt.axis('off')
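
For intuition about what each instantiated blur looks like, here is a minimal sketch that samples a normalized Gaussian kernel on a fixed grid from a 2x2 covariance. The function `gaussian_kernel_2d` and its `half_size` argument are hypothetical; `gauss2d_full` may choose the kernel size and normalization differently.

```python
import torch

def gaussian_kernel_2d(sigma, half_size):
    """Sample exp(-0.5 * p^T Sigma^{-1} p) on an integer grid and normalize."""
    coords = torch.arange(-half_size, half_size + 1, dtype=sigma.dtype)
    size = coords.numel()
    yy = coords.view(-1, 1).expand(size, size)
    xx = coords.view(1, -1).expand(size, size)
    pts = torch.stack([yy, xx], dim=-1)  # (size, size, 2) grid of offsets
    precision = torch.linalg.inv(sigma)
    # squared Mahalanobis distance of every grid offset
    dist2 = torch.einsum('ijk,kl,ijl->ij', pts, precision, pts)
    kernel = torch.exp(-0.5 * dist2)
    return kernel / kernel.sum()  # weights sum to one, so blurring preserves the mean

sigma = torch.tensor([[1., 1.],
                      [1., 4.]])
kernel = gaussian_kernel_2d(sigma, half_size=4)
```

With a full covariance like this one, the kernel is elongated and tilted rather than round or axis-aligned.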

To make an exact, adaptive blur we go from convolutional filtering to local filtering.
In local filtering, the filter varies with position, so that different locations have different kernels.
Note that this requires instantiating a kernel for every position and stacking them all into a matrix, so computation and memory scale with the *maximum* kernel size, which is clearly not ideal.
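
To put rough numbers on that cost, the sketch below counts matrix elements for local filtering versus a single shared convolution kernel. The helper `local_filter_cost` and the 9x9 kernel size are assumptions chosen for illustration; the actual size depends on the sigmas and on `std_devs`.

```python
def local_filter_cost(batch, channels, height, width, kh, kw):
    """Count elements in the matrices needed for local filtering."""
    positions = batch * channels * height * width
    return {
        # one flattened patch per output value
        "im2col_elements": positions * kh * kw,
        # one flattened kernel per output value (after repeating across channels)
        "local_weight_elements": positions * kh * kw,
        # ordinary convolution shares a single kernel everywhere
        "shared_conv_weight_elements": kh * kw,
    }

cost = local_filter_cost(batch=2, channels=3, height=32, width=32, kh=9, kw=9)
```

Both matrices grow with the product of the number of positions and the kernel area, so a single large kernel anywhere in the image raises the cost for every position.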

In [None]:
im = torch.randn(2, 3, 32, 32)
# instantiate a blur for each position
batch_size, channel_size, spatial_size = im.size(0), im.size(1), im.size(2) * im.size(3)
params = torch.randn(batch_size * spatial_size, 3) * 0.2
blurs = gauss2d_full(params)
kh, kw = blurs.size()[-2:]

# form input matrix by unrolling image, pulling out channel, 
# then stacking everything but the kernel into the batch 
in_mat = F.unfold(im, kernel_size=(kh, kw), padding=(kh // 2, kw // 2))
in_mat = in_mat.view(batch_size, channel_size, kh * kw, spatial_size)
in_mat = in_mat.permute(0, 1, 3, 2).contiguous().view(im.numel(), 1, kh * kw)
# form weight matrix by pulling out batch and flattening kernel,
# repeating across channels, and then absorbing channels into the batch.
weight_mat = blurs.view(batch_size, 1, spatial_size, -1)
weight_mat = weight_mat.repeat(1, channel_size, 1, 1).view(-1, kh * kw, 1)

# matmul + reshape takes the product of each input-filter pair
# and restores the spatial dimensions
out_mat = torch.matmul(in_mat, weight_mat)
out = out_mat.view(im.size())

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title('Image Input')
plt.imshow(im[1, 2].squeeze().numpy())
plt.subplot(1, 2, 2)
plt.title('Locally Blurred Output')
plt.imshow(out[1, 2].squeeze().numpy())
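
A quick way to check this kind of local filtering is to confirm that it reduces to ordinary convolution when every position shares the same kernel. The sketch below does this for the single-channel case; `local_filter` is a hypothetical helper written for this example, not the `blur2d_local_full` implementation.

```python
import torch
import torch.nn.functional as F

def local_filter(im, kernels):
    """im: (N, 1, H, W); kernels: (N, H * W, kh, kw), one kernel per position."""
    n, _, h, w = im.shape
    kh, kw = kernels.shape[-2:]
    cols = F.unfold(im, kernel_size=(kh, kw), padding=(kh // 2, kw // 2))  # (N, kh*kw, H*W)
    weights = kernels.reshape(n, h * w, kh * kw)
    # dot product of each patch with its own kernel
    out = torch.einsum('nkl,nlk->nl', cols, weights)
    return out.view(n, 1, h, w)

torch.manual_seed(0)
im = torch.randn(2, 1, 8, 8)
kernel = torch.randn(3, 3)

# the same kernel at every position should match F.conv2d
shared = kernel.expand(2, 64, 3, 3)
out_local = local_filter(im, shared)
out_conv = F.conv2d(im, kernel.view(1, 1, 3, 3), padding=1)

# changing the kernel at one position changes only that output pixel
varied = shared.clone()
varied[:, 0] = 0.
out_varied = local_filter(im, varied)
```

Here zeroing the kernel at the first position zeroes only the top-left output pixel, which is the sense in which the filter is local rather than convolutional.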

## Toy Experiment: Optimize $\Sigma$ to Recover Local Full Covariance Blurs in 2D

To illustrate adaptive blurring via optimizing sigma with a toy problem, let's recover the covariances of local blurs from smoothed data in 2D.

1. Generate a random 2D signal and smooth it with reference sigmas, a different one for each quarter of the image.
2. Instantiate our filters with zero initialization of the covariance parameters, which is equivalent to identity covariance.
3. Learn sigmas by gradient descent.
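
The same recipe can be sketched in its simplest form: recover a single, global spherical scale by gradient descent. This is a deliberately reduced version of the experiment (one parameter rather than one covariance per position), and `sphere_blur`, with its fixed `half_size`, is a hypothetical helper rather than part of `sigma.blur`.

```python
import torch
import torch.nn.functional as F

def sphere_blur(x, log_scale, half_size=5):
    """Blur with an isotropic Gaussian whose standard deviation is exp(log_scale)."""
    coords = torch.arange(-half_size, half_size + 1, dtype=x.dtype)
    r2 = coords.view(-1, 1) ** 2 + coords.view(1, -1) ** 2
    kernel = torch.exp(-0.5 * r2 / log_scale.exp() ** 2)
    kernel = kernel / kernel.sum()
    return F.conv2d(x, kernel.view(1, 1, *kernel.shape), padding=half_size)

torch.manual_seed(0)
x = torch.randn(1, 1, 32, 32)
# 1. smooth the signal with a reference scale
target = sphere_blur(x, torch.tensor(0.5)).detach()

# 2. zero initialization of the log-scale, i.e. unit standard deviation
log_scale = torch.nn.Parameter(torch.zeros(()))
opt = torch.optim.Adamax([log_scale], lr=0.1)

# 3. learn the scale by gradient descent on the reconstruction error
losses = []
for _ in range(100):
    loss = ((sphere_blur(x, log_scale) - target) ** 2).mean()
    losses.append(loss.item())
    opt.zero_grad()
    loss.backward()
    opt.step()
```

The kernel is differentiable with respect to `log_scale`, which is what lets the blur size be learned from the smoothed data alone.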

In [None]:
x = torch.randn(1, 1, 32, 32)
quarter_spatial = (x.size(2) * x.size(3)) // 4
sphere_params = sigma2logchol(sigma_sphere)
diag_params = sigma2logchol(sigma_diag)
full_params = sigma2logchol(sigma_full)
true_cov = torch.cat((sphere_params.repeat(quarter_spatial, 1),
                      diag_params.repeat(quarter_spatial, 1),
                      full_params.repeat(quarter_spatial, 1),
                      torch.randn(quarter_spatial, 3) * 0.2), dim=0)
xf = blur2d_local_full(x, true_cov, std_devs=2).detach()

plt.figure(figsize=(10, 2))
plt.subplot(1, 2, 1)
plt.title('signal')
plt.imshow(x.squeeze().numpy())
plt.subplot(1, 2, 2)
plt.title('smoothed')
plt.imshow(xf.squeeze().numpy())

In [None]:
def plot_recovery(xf, xf_hat, iter_):
    plt.figure(figsize=(5, 2))
    plt.suptitle("Recovery iter. {}".format(iter_))  # figure-level title, since the subplots are created below
    plt.subplot(1, 2, 1)
    plt.imshow(xf.squeeze().detach().numpy())
    plt.subplot(1, 2, 2)
    plt.imshow(xf_hat.squeeze().detach().numpy())
    
cov = torch.nn.Parameter(torch.zeros(x.numel() // x.size(1), 3))
opt = torch.optim.Adamax([cov], lr=0.1)

max_iter = 100
for iter_ in range(max_iter):
    xf_hat = blur2d_local_full(x, cov, std_devs=2)
    diff = xf_hat - xf
    loss = (diff**2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    
    if iter_ % 10 == 0:
        print('loss ', loss.item())
    if iter_ in (0, max_iter // 4, max_iter // 2):
        plot_recovery(xf, xf_hat, iter_)
plot_recovery(xf, xf_hat, iter_ + 1)

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title("Recovered - Reference Squared Error")
plt.imshow(((xf - xf_hat)**2).detach().squeeze().numpy())
plt.subplot(1, 2, 2)
plt.title('Covariance Parameter MSE')
plt.imshow((true_cov - cov).pow(2.).mean(-1).view(x.size()[2:]).detach().numpy())

Let's simplify by restricting the true covariance to spherical, and recovering it by spherical local blur and full local blur.

In [None]:
quarter_spatial = (x.size(2) * x.size(3)) // 4
true_scales = torch.cat((torch.full((quarter_spatial,), -1.),
                         torch.full((quarter_spatial,), -0.5),
                         torch.full((quarter_spatial,), 0.),
                         torch.full((quarter_spatial,), 0.5)), dim=0)
xf = blur2d_local_sphere(x, true_scales.exp(), std_devs=2).detach()

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title('signal')
plt.imshow(x.squeeze().numpy())
plt.axis('off')
plt.subplot(1, 2, 2)
plt.title('smoothed')
plt.imshow(xf.squeeze().numpy())
plt.axis('off')

In [None]:
def plot_recovery(xf, xf_hat, iter_):
    plt.figure(figsize=(5, 2))
    plt.suptitle("Recovery iter. {}".format(iter_))  # figure-level title, since the subplots are created below
    plt.subplot(1, 2, 1)
    plt.imshow(xf.squeeze().detach().numpy())
    plt.subplot(1, 2, 2)
    plt.imshow(xf_hat.squeeze().detach().numpy())
    
scales = torch.nn.Parameter(torch.zeros(x.numel() // x.size(1)))
opt = torch.optim.Adamax([scales], lr=0.1)

max_iter = 100
for iter_ in range(max_iter):
    xf_hat = blur2d_local_sphere(x, scales.exp(), std_devs=2)
    diff = xf_hat - xf
    loss = (diff**2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    
    if iter_ % 10 == 0:
        print('loss ', loss.item())
    if iter_ in (0, max_iter // 4, max_iter // 2):
        plot_recovery(xf, xf_hat, iter_)
plot_recovery(xf, xf_hat, iter_ + 1)

plt.figure(figsize=(5, 5))
plt.title("Recovered Spherical - Smoothed Squared Error")
plt.imshow(((xf - xf_hat)**2).detach().squeeze().numpy())

In [None]:
def plot_recovery(xf, xf_hat, iter_):
    plt.figure(figsize=(5, 2))
    plt.suptitle("Recovery iter. {}".format(iter_))  # figure-level title, since the subplots are created below
    plt.subplot(1, 2, 1)
    plt.imshow(xf.squeeze().detach().numpy())
    plt.subplot(1, 2, 2)
    plt.imshow(xf_hat.squeeze().detach().numpy())
    
cov = torch.nn.Parameter(torch.zeros(x.numel() // x.size(1), 3))
opt = torch.optim.Adamax([cov], lr=0.1)

max_iter = 100
for iter_ in range(max_iter):
    xf_hat = blur2d_local_full(x, cov, std_devs=2)
    diff = xf_hat - xf
    loss = (diff**2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    
    if iter_ % 10 == 0:
        print('loss ', loss.item())
    if iter_ in (0, max_iter // 4, max_iter // 2):
        plot_recovery(xf, xf_hat, iter_)
plot_recovery(xf, xf_hat, iter_ + 1)

plt.figure(figsize=(5, 5))
plt.title("Recovered Full - Smoothed Squared Error")
plt.imshow(((xf - xf_hat)**2).detach().squeeze().numpy())

Let's have a look at the learned covariances, first by averaging the parameters across each row (a given y coordinate).
From left to right, the panels go from the top to the bottom of the image, and the covariance grows with y, as expected given the true blur.

In [None]:
im_cov = cov.view(32, 32, 3)
plt.figure(figsize=(10, 10))
for i, y_ in enumerate([0, 8, 16, 31], 1):
    f = gauss2d_full(im_cov[y_, :].mean(0), std_devs=2)
    plt.subplot(1, 4, i)
    plt.title(f"y = {y_}, ks = {tuple(f.size())}")
    plt.imshow(f.detach().squeeze().numpy())
    plt.axis('off')
    plt.tight_layout()

Now let's examine individual covariances more closely.
There is a perhaps surprising amount of variation from position to position, suggesting that more data or regularization could help.
Granted, this is a tiny toy experiment.

In [None]:
im_cov = cov.view(32, 32, 3)
plt.figure(figsize=(10, 10))
for i, y_ in enumerate([0, 8, 16, 31], 1):
    for j, x_ in enumerate([0, 8, 16, 31], 1):
        f = gauss2d_full(im_cov[y_, x_], std_devs=2)
        plt.subplot(4, 4, (i-1)*4 + j)
        plt.title(f"y, x = {y_}, {x_} ks = {tuple(f.size())}")
        plt.imshow(f.detach().squeeze().numpy())
        plt.axis('off')
        plt.tight_layout()