# Figure 4: Rate of Convergence of Affine Spike Registration Formulation

The code below tests the affine spike registration formulation, i.e., equation (10) in the paper (arXiv v1).

The implementation is in PyTorch.

In [None]:
# Handle imports
import sys
sys.path.append("../src")

import numpy as np
import torch
from torch import tensor
from torch.nn.functional import pad

from images import (imagesc, show_false_color)
from images import show_false_color, get_affine_grid, gaussian_cov
from registration_pt import reg_l2_spike, device, precision


# convenient helpers (plot shortcut and matrix/vector norms)
def sfc(X):
    """Display X via show_false_color with normalization switched on."""
    return show_false_color(X, normalize=True)


def op_norm_22(A):
    """Return torch.linalg.norm(A, ord=2) (spectral norm for a matrix)."""
    return torch.linalg.norm(A, ord=2)


def op_norm_12(A):
    """Return the largest Euclidean norm among the columns of A."""
    squared_column_norms = (A**2).sum(0)
    return squared_column_norms.max().sqrt()


def fro_norm(A):
    """Return the Frobenius norm of A."""
    return torch.linalg.norm(A, ord='fro')


def vecnorm_2(A):
    """Return torch.linalg.norm(A, ord=2) (Euclidean norm for a vector)."""
    return torch.linalg.norm(A, ord=2)


def cond_num(A):
    """Return the condition number of A as given by torch.linalg.cond."""
    return torch.linalg.cond(A)

In [None]:
# Helper function: make the test data
def make_test_data(M,N,C,B, show_plots=False,
        MAX_ANGLE=torch.tensor(np.pi,device=device(),dtype=precision()),
        MAX_SCALE=0.3, MAX_TRANSLATION=10, SPIKE_HALFWIDTH=0, sigma0=3):
    """Make a random spike scene and B affinely transformed copies of it.

    C spike locations are drawn uniformly at random, shifted so that their
    centroid sits at the center of a padded canvas, and rendered one spike per
    channel with `gaussian_cov` (presumably a 2-D Gaussian bump with covariance
    sigma0**2 * I centered at the given offsets -- confirm against `images`).
    For each of the B motifs a random 2x2 matrix A = U diag(s1, s2) V^T and a
    random translation b are drawn, the spike locations are mapped through
    that affine transform taken about the canvas center, and re-rendered.

    Args:
        M, N: Extent of the unpadded frame along the first and second
            coordinate; spike locations are drawn in [0, M-1] x [0, N-1].
        C: Number of spikes; each spike gets its own channel.
        B: Number of transformed motifs to generate.
        show_plots: If True, show the scene and the first motif side by side.
        MAX_ANGLE: Both rotation angles are drawn uniformly from
            [0, MAX_ANGLE).
        MAX_SCALE: Both scale factors are drawn uniformly from
            [1 - MAX_SCALE, 1 + MAX_SCALE); it also sets the padding size.
        MAX_TRANSLATION: Each translation component is drawn uniformly from
            [-MAX_TRANSLATION, MAX_TRANSLATION).
        SPIKE_HALFWIDTH: Currently unused; kept for interface compatibility.
        sigma0: Standard deviation used for the rendered spikes.

    Returns:
        A tuple (scene, motifs, A, b + corr, locs_pad, b) where
        scene has shape (1, C, M + 2*pad_M, N + 2*pad_N),
        motifs has shape (B, C, M + 2*pad_M, N + 2*pad_N),
        A has shape (B, 2, 2),
        b + corr has shape (B, 2) and is the translation expressed about the
        coordinate origin (corr = (I - A) @ center),
        locs_pad has shape (1, C, 2) and holds the centered spike locations in
        the padded frame, and
        b has shape (B, 2) and is the translation about the canvas center.
    """

    from images import show_false_color, gaussian_cov

    dev = device()
    dtype = precision()

    def _rotation(angle):
        """Build the 2x2 rotation matrix for a one-element angle tensor."""
        return torch.tensor(((torch.cos(angle), -torch.sin(angle)),
            (torch.sin(angle), torch.cos(angle))), device=dev, dtype=dtype)

    def _spike_image(loc, height, width):
        """Render a single spike located at `loc` on a height x width canvas."""
        return gaussian_cov(height, width,
                Sigma=sigma0**2*torch.eye(2, device=dev, dtype=dtype),
                offset_u=loc[0], offset_v=loc[1])

    # Make transforms
    A = torch.zeros((B,2,2), device=dev, dtype=dtype)
    b = torch.zeros((B,2), device=dev, dtype=dtype)
    for idx in range(B):
        angle1 = MAX_ANGLE * torch.rand((1,), device=dev, dtype=dtype)
        angle2 = MAX_ANGLE * torch.rand((1,), device=dev, dtype=dtype)
        scale1 = 1 + MAX_SCALE * (2*torch.rand((1,), device=dev,
            dtype=dtype) - 1)
        scale2 = 1 + MAX_SCALE * (2*torch.rand((1,), device=dev,
            dtype=dtype) - 1)
        U = _rotation(angle1)
        V = _rotation(angle2)
        scale = torch.block_diag(scale1, scale2)
        # torch.chain_matmul is deprecated; multi_dot is its replacement.
        A[idx, ...] = torch.linalg.multi_dot([U, scale, V.T])
        b[idx, :] = MAX_TRANSLATION*(2*torch.rand((2,), device=dev,
            dtype=dtype) - 1)

    # Make spike scene
    # Spike locs uniformly at random on the grid
    # Compute pad parameters so that we don't have spikes truncated at the edges
    pad_M = int(np.ceil(M * (1+MAX_SCALE)/2))
    pad_N = int(np.ceil(N * (1+MAX_SCALE)/2))
    height = M + 2*pad_M
    width = N + 2*pad_N

    # locs indexes into the unpadded frame; locs_pad adds the left/top pad
    # offset so that it indexes into the padded frame.
    locs = torch.rand((1, C, 2), device=dev, dtype=dtype)
    locs = locs * torch.tensor((M-1, N-1), device=dev, dtype=dtype)[None,
            None, :]
    locs_pad = locs + torch.tensor((pad_M, pad_N), device=dev,
            dtype=dtype)[None, None, :]
    centroid = torch.mean(locs_pad, 1)[0, ...]
    ctr = torch.tensor(((M-1)/2 + pad_M, (N-1)/2 + pad_N), device=dev,
            dtype=dtype)
    # Shift the spikes so that their centroid coincides with the canvas center
    locs_pad -= (centroid - ctr)[None, None, :]

    scene = torch.zeros((1, C, height, width), device=dev, dtype=dtype)
    for chan, loc in enumerate(locs_pad[0, ...]):
        scene[0, chan, :, :] = _spike_image(loc, height, width)

    # Create the motif (by transforming the spike locs with a sigma0 gaussian)
    # Also transform relative to the center of the scene: x -> A(x - ctr) +
    # ctr + b, i.e. A x + (I - A) ctr + b, hence the correction term below.
    corr = torch.bmm(torch.eye(2, device=dev, dtype=dtype)[None, ...] - A,
            ctr[None, :, None].expand(B, -1, -1))[..., 0]
    motif_locs = torch.einsum('bij,bkj->bki', A,
        locs_pad.expand(B, -1, -1).to(dtype)) + (corr+b)[:, None, :]
    motifs = torch.zeros((B, C, height, width), device=dev, dtype=dtype)
    for motif_idx in range(B):
        for chan, loc in enumerate(motif_locs[motif_idx, ...]):
            motifs[motif_idx, chan, :, :] = _spike_image(loc, height, width)

    if show_plots:
        # NOTE(review): `plt` is not imported in the visible cells -- confirm
        # matplotlib.pyplot is brought into scope elsewhere before enabling.
        f, (ax1, ax2) = plt.subplots(1,2)
        show_false_color(scene[0], ax=ax1)
        show_false_color(motifs[0], ax=ax2)
        plt.show()

    return scene, motifs, A, b+corr, locs_pad, b


In [None]:
# Convergence experiment for the affine spike registration formulation:
# generate one scene/motif pair, run reg_l2_spike with step sizes derived from
# the theoretical parameters, then compare the per-iteration error in (A, b)
# against the predicted linear-rate bound.

# set to false to not plot anything
# NOTE(review): the plotting code below uses `plt` and
# `show_spike_reg_results`, neither of which is imported in the visible
# cells -- confirm they are in scope before running with do_plots = True.
do_plots = True

dev = device()

# Data params
M = 61
N = 81
C = 8
B = 1
sigma0 = 3
spike_halfwidth = 0
crop_thresh = 1e-6

scene, motif, A, b, locs_raw, b_uncorr = make_test_data(M, N, C, B,
        sigma0=sigma0, MAX_ANGLE=0.5)

# Calculate scene centroid
centroid_scene = torch.mean(locs_raw[0], 0)

# Calculate motif location (all zeros: one 2-vector per motif)
Bm, Cm, Mm, Nm = motif.shape
locs = torch.zeros((Bm, 2), device=dev, dtype=precision())


# Compute theoretical parameters
# U is the 2 x C matrix of centered spike locations.
U = locs_raw[0].transpose(0,1)
U = U - torch.mean(U, 1).unsqueeze(1)
# Projector onto the orthogonal complement of the all-ones vector
P1perp = torch.eye(C,device=dev,dtype=precision()) - 1/C * torch.ones((C,C),device=dev,dtype=precision())
eye2 = torch.eye(2,device=dev,dtype=precision())

kappa = cond_num(torch.matmul(torch.matmul(U, P1perp), U.transpose(0,1)))

# Lower bound on the squared smoothing level, then the smoothing level itself
sigma2_lb = 2*kappa * (op_norm_12(U)**2 * fro_norm(A[0]-eye2)**2 +
                       C * op_norm_12(U)**2 * vecnorm_2(b_uncorr[0])**2 / op_norm_22(U)**2)
sigma = torch.sqrt(sigma2_lb).item()

pi = torch.tensor(np.pi,device=dev,dtype=precision())
tA = C / op_norm_22(torch.matmul(U, P1perp))**2
tb = 1
v = 4 * pi**2 * torch.tensor(sigma**4, device=dev, dtype=precision())

# Opt params
# v*tA/5 and v*tb/5 are already tensors with the right device/dtype, so
# copy them with detach().clone() instead of re-wrapping them in
# torch.tensor (which emits a copy-construct warning).
params = {'sigma': sigma,
        'sigma0': sigma0,
        'step_A': (v*tA/5).detach().clone(),
        'step_b': (v*tb/5).detach().clone(),
        'max_iter': 30,
        'rejection_thresh': 0}
A_opt, b_opt, error, scene_opt, motif_opt, A_list, b_list, Rvals = reg_l2_spike(
        scene[0,...], motif[0,...], locs, **params, external_smoothing=True,
        record_process=True, quiet=False)

# Plot the results as a subplot
if do_plots:
    show_spike_reg_results(error, scene_opt, motif_opt)

# Plot optimization error path for A,b
# The ground-truth parameters are flipped along every axis before being
# compared with the optimizer's iterates (presumably because reg_l2_spike
# uses the opposite coordinate ordering -- confirm against registration_pt).
A_flip = torch.flip(A[0], [0,1])
b_flip = torch.flip(b_uncorr[0], [0])
max_iter = params['max_iter']

# Diagnostic quantities; not used further in this cell.
center_mid = torch.tensor([scene.shape[2]//2, scene.shape[3]//2], device=dev,
            dtype=precision())
center_mass = torch.mean(locs_raw[0], 0)
corr_diff = torch.bmm(torch.eye(2,device=dev,dtype=precision())[None,...] - A,
        (center_mid - center_mass)[None,:,None].expand(B,-1,-1))[...,0]

# Per-iteration log-errors, stored in the working precision so that no
# accuracy is lost when precision() is wider than torch's default dtype.
A_err = torch.zeros(max_iter, device=dev, dtype=precision())
b_err = torch.zeros(max_iter, device=dev, dtype=precision())
position_err = torch.zeros(max_iter, device=dev, dtype=precision())
locs_rectr = locs_raw - torch.mean(locs_raw, 1)[None, ...]
locs_rectr_flip = torch.flip(locs_rectr[0,...], (1,))
positions_true = torch.matmul(A_flip, locs_rectr_flip.T).T + b_flip
for i in range(max_iter):
    # log of the Euclidean/Frobenius distance to the ground truth
    A_err[i] = torch.log(torch.sqrt(torch.sum((A_list[i] - A_flip) ** 2)))
    b_err[i] = torch.log(torch.sqrt(torch.sum((b_list[i] - b_flip) ** 2)))
    positions_opt = torch.matmul(A_list[i], locs_rectr_flip.T).T \
        + b_list[i][None,...]
    position_err[i] = torch.log(torch.sqrt(torch.sum((positions_opt-
        positions_true) ** 2)))

if do_plots:
    plt.figure()
    plt.plot(A_err.cpu()[1:])
    plt.title('Error in A')
    plt.show()
    plt.figure()
    plt.plot(b_err.cpu()[1:])
    plt.title('Error in b')
    plt.show()
    plt.figure()
    plt.plot(position_err.cpu()[1:])
    plt.title('Cumulative positional error')
    plt.show()

    plt.figure()
    plt.plot(Rvals.detach().cpu().numpy())
    plt.show()

    # A_err and b_err hold logs of (unsquared) norms, whereas the bound below
    # is stated for squared norms: undo the log and square before combining,
    # so that log(Ab_err) is directly comparable with log_err_bound.
    Ab_err = torch.exp(2*A_err) / tA + torch.exp(2*b_err)
    log_err_bound = 2*torch.arange(max_iter,device=dev,dtype=precision()) * torch.log(1-1/(2*kappa)) + \
                        torch.log(fro_norm(eye2-A_flip)**2 / tA + vecnorm_2(b_flip)**2)

    plt.figure()
    plt.plot(torch.log(Ab_err[:20]).cpu().numpy())
    plt.plot(log_err_bound[:20].cpu().numpy())
    plt.legend(['log(LHS)', 'log(RHS)'])
    plt.show()

