# Learning Filter Scale by $\partial\sigma$ (with a PyTorch Module)

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn.functional as F


# work from the project root
import os
import sys
import subprocess
from pathlib import Path

# root here refers to the segmentron-master folder
root_dir = Path(subprocess.check_output(['git', 'rev-parse', '--show-toplevel']).strip().decode("utf-8"))
root_dir = root_dir / "segmentron-master"
os.chdir(root_dir)
sys.path.append(str(root_dir))

from sigma.blur import blur1d, blur2d_sphere

torch.set_default_dtype(torch.float64)

torch.manual_seed(1337)

## Toy Experiment: Optimize $\sigma$ to Recover Blur Kernel in 1D

To illustrate the optimization of kernel size via sigma with a toy problem, let's recover the size of a Gaussian blur from smoothed data in 1D.

1. Generate a random 1D signal and smooth it with a reference sigma.
2. Parameterize the filter by $\log\sigma$ and initialize that parameter to zero (i.e. $\sigma = e^0 = 1$); optimizing in log space keeps $\sigma$ positive.
3. Learn sigma by gradient descent.

In [None]:
# 1D toy data: a random signal `x` and its blur `xf` with the reference sigma,
# which serves as the regression target for the optimization below.
x = torch.randn(1, 1, 100)
true_sigma = torch.tensor(3.)
xf = blur1d(x, true_sigma, std_devs=2).detach()

fig, (ax_signal, ax_smoothed) = plt.subplots(1, 2, figsize=(10, 2))
ax_signal.set_title('signal')
ax_signal.plot(x.squeeze().numpy())
ax_smoothed.set_title('smoothed')
# trailing `;` hides the [<Line2D>] repr; the figure itself is still displayed
ax_smoothed.plot(xf.squeeze().numpy());

In [None]:
def plot_recovery_1d(xf, xf_hat, iter_):
    """Overlay the reference blurred signal and the current reconstruction.

    Args:
        xf: reference (target) blurred signal, plotted in blue.
        xf_hat: reconstruction from the current sigma estimate, plotted in red.
        iter_: optimization iteration shown in the title.
    """
    _, ax = plt.subplots(figsize=(5, 2))
    ax.set_title("Recovery iter. {}".format(iter_))
    ax.plot(xf.squeeze().detach().numpy(), 'b', label='ref.')
    ax.plot(xf_hat.squeeze().detach().numpy(), 'r', label='rec.')
    ax.legend()


# Optimize log(sigma) rather than sigma: sigma = exp(scale) is always positive,
# and scale = 0 starts the search at sigma = 1.
scale = torch.nn.Parameter(torch.tensor(0.))
opt = torch.optim.SGD([scale], lr=1.0)

max_iter = 100
for iter_ in range(max_iter):
    xf_hat = blur1d(x, scale.exp(), std_devs=2)
    diff = xf_hat - xf
    loss = 0.5 * (diff**2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()

    if iter_ % 10 == 0:
        print('loss ', loss.item())
    if iter_ in (0, 4, 16):
        plot_recovery_1d(xf, xf_hat, iter_)

# The loop's last xf_hat was computed before the final optimizer step; recompute it
# with the final sigma so the plot matches its iteration label and the printed sigma.
with torch.no_grad():
    xf_hat = blur1d(x, scale.exp(), std_devs=2)
plot_recovery_1d(xf, xf_hat, max_iter)

print('\ntrue sigma {:0.2f} recovered sigma {:0.2f}'.format(true_sigma.item(), scale.exp().item()))

Check the autograd gradient with respect to $\log\sigma$ against a finite-difference estimate at several random scales.

In [None]:
# Compare the autograd gradient of the loss w.r.t. log-sigma (`scale`) with a
# finite-difference estimate at several random scales.
eps = torch.tensor(1e-4)
# The analytic and the numerical evaluations must use the same kernel support
# (std_devs=2, as in training); otherwise they differentiate different functions.
std_devs = 2
for _trial in range(10):
    scale = torch.nn.Parameter(torch.randn(1)[0])
    # forward-backward: analytic gradient
    xf_hat = blur1d(x, scale.exp(), std_devs=std_devs)
    loss = 0.5 * ((xf_hat - xf)**2).mean()
    loss.backward()
    grad = scale.grad

    # central difference: O(eps^2) truncation error vs. O(eps) for a one-sided
    # difference; no graph is needed for these evaluations
    with torch.no_grad():
        xf_plus = blur1d(x, (scale + eps).exp(), std_devs=std_devs)
        xf_minus = blur1d(x, (scale - eps).exp(), std_devs=std_devs)
        loss_plus = 0.5 * ((xf_plus - xf)**2).mean()
        loss_minus = 0.5 * ((xf_minus - xf)**2).mean()
    grad_eps = (loss_plus - loss_minus) / (2 * eps)
    err = torch.abs(grad - grad_eps)
    print('analytic {: 09.5f} numerical {: 09.5f} error {:0.8f}'.format(grad.item(), grad_eps.item(), err.item()))
    assert err < 10 * eps, 'gradient check failed at scale {:0.4f}'.format(scale.item())

## Toy Experiment: Optimize $\sigma$ to Recover Blur Kernel in 2D

To illustrate the optimization of kernel size via sigma with a toy problem, let's recover the size of a Gaussian blur from smoothed data in 2D.

1. Generate a random 2D signal and smooth it with a reference sigma.
2. Parameterize the filter by $\log\sigma$ and initialize that parameter to zero (i.e. $\sigma = e^0 = 1$); optimizing in log space keeps $\sigma$ positive.
3. Learn sigma by gradient descent.

In [None]:
# 2D toy data: a random image `x` and its blur `xf` with the reference sigma,
# which serves as the regression target for the optimization below.
x = torch.randn(1, 1, 32, 32)
true_sigma = torch.tensor(3.)
xf = blur2d_sphere(x, true_sigma).detach()

fig, (ax_signal, ax_smoothed) = plt.subplots(1, 2, figsize=(10, 2))
ax_signal.set_title('signal')
ax_signal.imshow(x.squeeze().numpy())
ax_smoothed.set_title('smoothed')
# trailing `;` hides the AxesImage repr; the figure itself is still displayed
ax_smoothed.imshow(xf.squeeze().numpy());

In [None]:
def plot_recovery(xf, xf_hat, iter_):
    """Show the reference blurred image next to the current reconstruction.

    Both panels use the reference's color range, so intensity differences between
    them stay visible instead of being hidden by per-image autoscaling.

    Args:
        xf: reference (target) blurred image.
        xf_hat: reconstruction from the current sigma estimate.
        iter_: optimization iteration shown in the figure title.
    """
    ref = xf.squeeze().detach().numpy()
    rec = xf_hat.squeeze().detach().numpy()
    vmin, vmax = ref.min(), ref.max()
    fig, (ax_ref, ax_rec) = plt.subplots(1, 2, figsize=(5, 2.5))
    # suptitle rather than plt.title: a title set before the subplots exist lands on a
    # separate full-figure axes instead of above the two panels
    fig.suptitle("Recovery iter. {}".format(iter_))
    ax_ref.set_title('ref.')
    ax_ref.imshow(ref, vmin=vmin, vmax=vmax)
    ax_rec.set_title('rec.')
    ax_rec.imshow(rec, vmin=vmin, vmax=vmax)


# Optimize log(sigma) rather than sigma: sigma = exp(scale) is always positive,
# and scale = 0 starts the search at sigma = 1.
scale = torch.nn.Parameter(torch.tensor(0.))
opt = torch.optim.SGD([scale], lr=1.0)

max_iter = 500
for iter_ in range(max_iter):
    xf_hat = blur2d_sphere(x, scale.exp())
    diff = xf_hat - xf
    loss = 0.5 * (diff**2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()

    if iter_ % 50 == 0:
        print('loss ', loss.item())
    if iter_ in (0, max_iter // 16, max_iter // 4):
        plot_recovery(xf, xf_hat, iter_)

# The loop's last xf_hat was computed before the final optimizer step; recompute it
# with the final sigma so the plot matches its iteration label and the printed sigma.
with torch.no_grad():
    xf_hat = blur2d_sphere(x, scale.exp())
plot_recovery(xf, xf_hat, max_iter)

print('\ntrue sigma {:0.2f} recovered sigma {:0.2f}'.format(true_sigma.item(), scale.exp().item()))

Check the autograd gradient with respect to $\log\sigma$ against a finite-difference estimate at several random scales.

In [None]:
# Compare the autograd gradient of the loss w.r.t. log-sigma (`scale`) with a
# finite-difference estimate at several random scales.
eps = torch.tensor(1e-4)
for _trial in range(10):
    scale = torch.nn.Parameter(torch.randn(1)[0])
    # forward-backward: analytic gradient
    xf_hat = blur2d_sphere(x, scale.exp())
    loss = 0.5 * ((xf_hat - xf)**2).mean()
    loss.backward()
    grad = scale.grad

    # central difference: O(eps^2) truncation error vs. O(eps) for a one-sided
    # difference; no graph is needed for these evaluations
    with torch.no_grad():
        xf_plus = blur2d_sphere(x, (scale + eps).exp())
        xf_minus = blur2d_sphere(x, (scale - eps).exp())
        loss_plus = 0.5 * ((xf_plus - xf)**2).mean()
        loss_minus = 0.5 * ((xf_minus - xf)**2).mean()
    grad_eps = (loss_plus - loss_minus) / (2 * eps)
    err = torch.abs(grad - grad_eps)
    print('analytic {: 09.5f} numerical {: 09.5f} error {:0.8f}'.format(grad.item(), grad_eps.item(), err.item()))
    assert err < 10 * eps, 'gradient check failed at scale {:0.4f}'.format(scale.item())