# Adaptive Blur by Mixing Min/Max

For a recklessly approximate approach to adaptivity, mix the outputs of a minimum blur and a maximum blur with a convex combination.
The weights can vary across locations to adapt the scale locally.

This approximation is *rough*.
An equal-weight mix of the blurs at $\sigma_{\text{min}}$ and $\sigma_{\text{max}}$ is not the same as a single blur with $\sigma = \frac{1}{2}(\sigma_{\text{min}} + \sigma_{\text{max}})$; in particular, the high frequencies passed by the minimum blur come through.
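
A short derivation makes the mismatch concrete. It assumes ideal, untruncated continuous Gaussians $G_\sigma$, which the discrete kernels used in this notebook only approximate, and the symbols $w$, $G_\sigma$, and $\hat{G}_\sigma$ are introduced here for illustration. By linearity of convolution, mixing the two outputs is the same as blurring once with a mixed kernel:

$$
w\,(G_{\sigma_{\text{min}}} * x) + (1 - w)\,(G_{\sigma_{\text{max}}} * x)
= \big(w\,G_{\sigma_{\text{min}}} + (1 - w)\,G_{\sigma_{\text{max}}}\big) * x
$$

With $\hat{G}_\sigma(\omega) = e^{-\sigma^2 \omega^2 / 2}$, the frequency response of the mixed kernel is

$$
w\,e^{-\sigma_{\text{min}}^2 \omega^2 / 2} + (1 - w)\,e^{-\sigma_{\text{max}}^2 \omega^2 / 2}
$$

For large $\omega$ and $w > 0$ the first term dominates, so the response decays only as fast as the minimum blur does, which is slower than any single Gaussian with $\sigma > \sigma_{\text{min}}$. The spatial spread differs too: the mixed kernel has variance

$$
w\,\sigma_{\text{min}}^2 + (1 - w)\,\sigma_{\text{max}}^2
\;\ge\;
\big(w\,\sigma_{\text{min}} + (1 - w)\,\sigma_{\text{max}}\big)^2
$$

with equality only when $\sigma_{\text{min}} = \sigma_{\text{max}}$ or $w \in \{0, 1\}$. For example, $w = \frac{1}{2}$, $\sigma_{\text{min}} = 1$, $\sigma_{\text{max}} = 5$ gives a mixed kernel with standard deviation $\sqrt{13} \approx 3.61$ rather than $3$.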

However, solving for the mixing weights does improve the fit over the plain average of the min/max extremes, so the optimization is doing something useful.

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn.functional as F


# work from the project root
import os
import sys
import subprocess
from pathlib import Path

# root here refers to the segmentron-master folder
root_dir = Path(subprocess.check_output(['git', 'rev-parse', '--show-toplevel']).strip().decode("utf-8"))
root_dir = root_dir / "segmentron-master"
os.chdir(root_dir)
sys.path.append(str(root_dir))

from sigma.blur import blur1d, blur2d_sphere

torch.manual_seed(1337)

## Toy Experiment: Optimize Mixing to Recover Blur Kernel in 2D

To illustrate the optimization of mixing scales with a toy problem, let's recover the size of the Gaussian blur when the blur is (1) global, with the same convolution applied everywhere, and (2) local, with the scale varying by location.

1. Generate a random 2D signal and smooth it with a reference sigma.
2. Instantiate our scale steering mixing with a minimum blur, maximum blur, and equal weighting.
3. Learn the weighting by gradient descent.

Note that the min/max blurs are themselves differentiable blurs, so the bounds could be tuned.

First we inspect the approximation when the true blur is the average of the min and max blurs.
Of course, when the true blur equals the min or the max, the mix is exact, but intermediate blurs are only (very) approximately matched.
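
The same comparison can be made on the kernels themselves, before any signal is involved. The sketch below is dependency-free and does not use `sigma.blur`: `gaussian_kernel`, `mix_kernels`, and `l2_distance` are hypothetical helpers written for this illustration, and the kernels are 1D, normalized, and truncated at a fixed radius. It assumes the same settings as the cell that follows (minimum 1, maximum 5, reference 3).

```python
import math


def gaussian_kernel(sigma, radius):
    """Normalized 1D Gaussian taps on the integer grid [-radius, radius]."""
    taps = [math.exp(-0.5 * (i / sigma) ** 2) for i in range(-radius, radius + 1)]
    total = sum(taps)
    return [t / total for t in taps]


def mix_kernels(k_min, k_max, mix):
    """Convex combination of two kernels, with weight `mix` on the minimum blur."""
    return [mix * a + (1 - mix) * b for a, b in zip(k_min, k_max)]


def l2_distance(a, b):
    """Euclidean distance between two equally sized kernels."""
    return math.sqrt(sum((p - q) ** 2 for p, q in zip(a, b)))


radius = 20  # four standard deviations of the widest kernel
k_min = gaussian_kernel(1.0, radius)
k_mid = gaussian_kernel(3.0, radius)
k_max = gaussian_kernel(5.0, radius)

k_half = mix_kernels(k_min, k_max, 0.5)     # equal weighting
k_all_min = mix_kernels(k_min, k_max, 1.0)  # all weight on the minimum blur

center = radius  # index of the central tap
print("center tap, sigma=3 kernel:", round(k_mid[center], 4))
print("center tap, 50/50 mix     :", round(k_half[center], 4))
print("distance, 50/50 mix vs sigma=3:", round(l2_distance(k_half, k_mid), 4))
print("distance, mix=1 vs sigma=1    :", l2_distance(k_all_min, k_min))
```

The equal mix keeps a sharp central peak inherited from the minimum blur, which is one way to see the high frequencies leaking through, while putting all the weight on one extreme reproduces that kernel.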

In [None]:
def blur2d_mix(x, min_sigma, max_sigma, mix=0.5):
    x_min = blur2d_sphere(x, min_sigma)
    x_max = blur2d_sphere(x, max_sigma)
    return mix * x_min + (1 - mix) * x_max
    
x = torch.randn(1, 1, 64, 64)
true_sigma = torch.tensor(3.)
xf = blur2d_sphere(x, true_sigma).detach()
xm = blur2d_mix(x, torch.tensor(1.0), torch.tensor(5.0), torch.tensor(0.5))

plt.figure(figsize=(10, 2))
plt.subplot(1, 3, 1)
plt.title('signal')
plt.imshow(x.squeeze().numpy())
plt.subplot(1, 3, 2)
plt.title('true')
plt.imshow(xf.squeeze().numpy())
plt.subplot(1, 3, 3)
plt.title('mixed')
plt.imshow(xm.squeeze().numpy())

Let's attempt to recover an intermediate scale by mixing.
The recovery isn't very close, but it's better than the equal-weight starting point.

In [None]:
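def plot_recovery(xf, xf_hat, iter_):
    """Show the reference blur (left) next to the current mixed estimate (right)."""
    plt.figure(figsize=(5, 2))
    # suptitle labels the whole figure; plt.title here would attach to a stray axes
    plt.suptitle("Recovery iter. {}".format(iter_))
    plt.subplot(1, 2, 1)
    plt.imshow(xf.squeeze().detach().numpy())
    plt.subplot(1, 2, 2)
    plt.imshow(xf_hat.squeeze().detach().numpy())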
    
true_sigma = torch.tensor(3.)

min_sigma = torch.tensor(3.)
max_sigma = torch.tensor(9.)

xf = blur2d_sphere(x, true_sigma, std_devs=2)

mix = torch.nn.Parameter(torch.tensor(0.0))
opt = torch.optim.Adamax([mix], lr=0.1)

max_iter = 1000
for iter_ in range(max_iter):
    xf_hat = blur2d_mix(x, min_sigma, max_sigma, torch.sigmoid(mix))
    diff = xf_hat - xf
    loss = (diff**2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    
    if iter_ % 100 == 0:
        print('loss ', loss.item())
    if iter_ in (0, 4, 16):
        plot_recovery(xf, xf_hat, iter_)
plot_recovery(xf, xf_hat, iter_ + 1)

weight = torch.sigmoid(mix)
# read out a nominal sigma as the same convex combination of the bounds (a rough proxy only)
approx_sigma = weight * min_sigma + (1. - weight) * max_sigma

print('\ntrue sigma {:0.2f} recovered sigma {:0.2f}'.format(true_sigma.item(), approx_sigma.item()))
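
As a sanity check on the gradient-descent result, note that for a single scalar weight the loss is quadratic in the weight, so the least-squares optimum has a closed form: $w^* = \langle x_f - x_{\text{max}},\, x_{\text{min}} - x_{\text{max}} \rangle / \lVert x_{\text{min}} - x_{\text{max}} \rVert^2$, clamped to $[0, 1]$. The sketch below is a dependency-free 1D stand-in, not the notebook's `blur2d_sphere`: `gaussian_blur1d`, `mix_signals`, `mse`, and `best_scalar_mix` are hypothetical helpers, the border handling (edge replication) is an assumption, and the sigmas (1, 3, 5) are chosen for illustration.

```python
import math
import random


def gaussian_blur1d(signal, sigma, std_devs=3):
    """Stand-in blur: normalized Gaussian taps, border samples replicated."""
    radius = max(1, math.ceil(std_devs * sigma))
    taps = [math.exp(-0.5 * (k / sigma) ** 2) for k in range(-radius, radius + 1)]
    total = sum(taps)
    last = len(signal) - 1
    return [
        sum(t * signal[min(max(i + k - radius, 0), last)] for k, t in enumerate(taps)) / total
        for i in range(len(signal))
    ]


def mix_signals(x_min, x_max, w):
    """Convex combination with weight w on the minimum blur."""
    return [w * a + (1 - w) * b for a, b in zip(x_min, x_max)]


def mse(a, b):
    """Mean squared error between two equally sized signals."""
    return sum((p - q) ** 2 for p, q in zip(a, b)) / len(a)


def best_scalar_mix(x_min, x_max, target):
    """Least-squares weight on x_min, clamped to [0, 1]."""
    d = [a - b for a, b in zip(x_min, x_max)]   # x_min - x_max
    r = [t - b for t, b in zip(target, x_max)]  # target - x_max
    w = sum(u * v for u, v in zip(r, d)) / sum(v * v for v in d)
    return min(max(w, 0.0), 1.0)


random.seed(1337)
x = [random.gauss(0.0, 1.0) for _ in range(256)]
x_min, x_max = gaussian_blur1d(x, 1.0), gaussian_blur1d(x, 5.0)
target = gaussian_blur1d(x, 3.0)

w_star = best_scalar_mix(x_min, x_max, target)
loss_equal = mse(mix_signals(x_min, x_max, 0.5), target)
loss_star = mse(mix_signals(x_min, x_max, w_star), target)
print("closed-form weight on min blur:", round(w_star, 3))
print("loss at equal weights     :", loss_equal)
print("loss at closed-form weight:", loss_star)
```

If the optimizer in the cell above has converged, `torch.sigmoid(mix)` should sit close to the corresponding closed-form value computed on the notebook's own tensors.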

The last blur was still global.
Let's experiment with mixing to recover a local blur: the left half of the reference is sharper and the right half is blurrier.

In [None]:
true_lil_sigma = torch.tensor(4.)
true_big_sigma = torch.tensor(7.)

min_sigma = torch.tensor(3.)
max_sigma = torch.tensor(9.)

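# ground-truth mix over columns: 1 on the left half (sharper blur), 0 on the right half (blurrier)
true_mix = torch.cat((torch.ones(x.size(-1) // 2), torch.zeros(x.size(-1) // 2)))
xf = blur2d_mix(x, true_lil_sigma, true_big_sigma, true_mix)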

mix = torch.nn.Parameter(torch.zeros_like(x))
opt = torch.optim.Adamax([mix], lr=0.1)

max_iter = 1000
for iter_ in range(max_iter):
    xf_hat = blur2d_mix(x, min_sigma, max_sigma, torch.sigmoid(mix))
    diff = xf_hat - xf
    loss = (diff**2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    
    if iter_ % 100 == 0:
        print('loss ', loss.item())
    if iter_ in (0, 4, 16):
        plot_recovery(xf, xf_hat, iter_)
plot_recovery(xf, xf_hat, iter_ + 1)
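
One property of the local setup is worth checking separately: the mix is applied elementwise after blurring, so each weight only affects the output at its own location and the squared-error loss decouples into one small problem per location. The sketch below illustrates this on a dependency-free 1D stand-in rather than the notebook's `blur2d_sphere`; `gaussian_blur1d`, `per_sample_mix`, and `mse` are hypothetical helpers, and the sigmas (bounds 1 and 5, reference 2 on the left and 4 on the right) are assumptions chosen for illustration.

```python
import math
import random


def gaussian_blur1d(signal, sigma, std_devs=3):
    """Stand-in blur: normalized Gaussian taps, border samples replicated."""
    radius = max(1, math.ceil(std_devs * sigma))
    taps = [math.exp(-0.5 * (k / sigma) ** 2) for k in range(-radius, radius + 1)]
    total = sum(taps)
    last = len(signal) - 1
    return [
        sum(t * signal[min(max(i + k - radius, 0), last)] for k, t in enumerate(taps)) / total
        for i in range(len(signal))
    ]


def per_sample_mix(x_min, x_max, target, eps=1e-12):
    """Independent least-squares weight at every sample, clamped to [0, 1]."""
    weights = []
    for a, b, t in zip(x_min, x_max, target):
        d = a - b
        # where the two blurs agree the weight has no effect, so keep it at 0.5
        w = 0.5 if abs(d) < eps else (t - b) / d
        weights.append(min(max(w, 0.0), 1.0))
    return weights


def mse(a, b):
    """Mean squared error between two equally sized signals."""
    return sum((p - q) ** 2 for p, q in zip(a, b)) / len(a)


random.seed(1337)
n = 256
x = [random.gauss(0.0, 1.0) for _ in range(n)]
x_min, x_max = gaussian_blur1d(x, 1.0), gaussian_blur1d(x, 5.0)

# reference: sharper blur on the left half, blurrier blur on the right half
sharp, soft = gaussian_blur1d(x, 2.0), gaussian_blur1d(x, 4.0)
target = sharp[: n // 2] + soft[n // 2:]

weights = per_sample_mix(x_min, x_max, target)
fit = [w * a + (1 - w) * b for w, a, b in zip(weights, x_min, x_max)]
equal = [0.5 * a + 0.5 * b for a, b in zip(x_min, x_max)]

loss_local = mse(fit, target)
loss_equal = mse(equal, target)
print("loss at equal weights     :", loss_equal)
print("loss at per-sample weights:", loss_local)
print("mean weight, left half :", sum(weights[: n // 2]) / (n // 2))
print("mean weight, right half:", sum(weights[n // 2:]) / (n // 2))
```

Because each weight is fit independently, a single location's value may say little about the local scale; averaging the weights over a region, as the last two lines do, is likely a more meaningful readout.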