# Figure 3: Hierarchical Invariant Template Detection Demo

The code below demonstrates a (pre-calibrated) network for the hierarchical invariant template detection task, i.e. Algorithm 1 in the arXiv paper. It extracts motifs (as described in the appendix), then performs detections.

The implementation is in PyTorch.

In [None]:
# Handle imports
import sys
sys.path.append("../src")
from crab import hierarchy
from hierarchy import extract_hierarchical, detect_hierarchical
from matplotlib.image import imread
import torch
from torch import tensor
from registration_pt import device, precision
import time
import matplotlib.pyplot as plt
import numpy as np

# Setup
# Directory containing the images used by this notebook. It is passed to
# `hierarchy(base_path=...)` and is also the prefix from which the helper
# functions below read the crab part PNGs and the beach background JPEG.
data_path = '../data/detection_images'

In [None]:
# Extraction: only needs to be run once.
# G is the hierarchy object built from the images under `data_path`; the cells
# below index it as `G.nodes[motif]` and pass it to `detect_hierarchical`.
# NOTE(review): presumably `extract_hierarchical` stores the extracted motifs
# on G in place (its return value is kept only in `ext_spikes`) -- confirm.
G = hierarchy(base_path = data_path)
ext_spikes = extract_hierarchical(G)

## Helper functions: observation generation and plotting...

In [None]:
# Code for generating an articulated crab observation
# First some helper functions.
# Various helper functions for testing, etc
def _hide_axes_frame(ax):
    """Hide both axes and all four spines of a matplotlib Axes."""
    ax.axes.get_xaxis().set_visible(False)
    ax.axes.get_yaxis().set_visible(False)
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)


def _show_or_save_figure(out_path):
    """Show the current figure (out_path is None) or save it, then clear it."""
    if out_path is None:
        plt.show()
    else:
        import os

        # savefig does not create missing directories, so make sure the
        # destination folder exists before writing.
        parent = os.path.dirname(out_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        plt.savefig(out_path)
    plt.clf()


def draw_boxes(G, spikes, scene, out_fn=None):
    """Draw bounding boxes for a detection output, then plot the spike trace.

    Two figures are produced one after the other on the current matplotlib
    figure: the scene with one rotated bounding box per motif, and a stem
    plot of the detection trace.

    Inputs:
    ---------
    G - hierarchy object
        For every motif name below, `G.nodes[motif]` must provide
        'detection_dict' (with 'errors', 'spike_locs' and 'phi') and
        'content'. If `G.nodes[motif]['params']['articulation_pt']` exists it
        is used as the rotation centre of the box; otherwise the box centre
        is used.
    spikes - torch tensor
        Detection output; `spikes[0, ...]` is summed over its last axis and
        shown as a stem plot against a dashed reference line at 1.
    scene - image accepted by `plt.imshow`
        The observation the detections refer to.
    out_fn - str or None
        If None both figures are shown interactively. Otherwise they are
        saved to `out_fn + '_boxes.png'` and `out_fn + '_trace.png'`; the
        parent directory is created if it does not exist.

    """

    motifs = ['claw_left', 'claw_right', 'eye_left', 'eye_right', 'eye_pair',
            'crab']
    colors = {'claw_left': 'r',
            'claw_right': 'r',
            'eye_left': 'r',
            'eye_right': 'r',
            'eye_pair': 'g',
            'crab': 'b'}
    linewidths = {'claw_left': 1,
            'claw_right': 1,
            'eye_left': 1,
            'eye_right': 1,
            'eye_pair': 3,
            'crab': 3}
    # Extra half-height / half-width (in pixels) added to each motif's box.
    vertical_pads = {'claw_left': 0,
                     'claw_right': 0,
                     'eye_left': 0,
                     'eye_right': 0,
                     'eye_pair': 15,
                     'crab': 80}
    horizontal_pads = {'claw_left': 0,
                     'claw_right': 0,
                     'eye_left': 0,
                     'eye_right': 0,
                     'eye_pair': 15,
                     'crab': 40}

    # show the scene
    plt.imshow(scene)

    # plot the bounding boxes specified in the detection results
    for motif in motifs:
        node = G.nodes[motif]
        detection = node['detection_dict']

        # 1. find the best detection index (smallest error in the last column)
        errs = detection['errors'][:,-1]
        best_idx = torch.argmin(errs)

        # 2. Plot the bounding box using parameters for this index
        spikeloc_uncorr = detection['spike_locs'][best_idx, :].to('cpu')
        phi = detection['phi'][best_idx].to('cpu')
        A = torch.tensor(((torch.cos(phi), -torch.sin(phi)), (torch.sin(phi),
            torch.cos(phi))), device='cpu', dtype=precision())
        _, M, N = node['content'][0].shape

        # Half-extent vectors of the (padded) box along rows (u) and
        # columns (v); their sum is the box centre in motif coordinates.
        dir_u = torch.tensor(((M-1)/2 + vertical_pads[motif], 0), device='cpu',
                dtype=precision())
        dir_v = torch.tensor((0, (N-1)/2 + horizontal_pads[motif]), device='cpu',
                dtype=precision())
        ctr = dir_u + dir_v
        try:
            articulation_pt = node['params']['articulation_pt'][0,...].to('cpu')
        except KeyError:
            # No articulation point stored for this motif: rotate about the
            # box centre, which makes the correction below vanish.
            articulation_pt = ctr

        # Shift the detected location from the articulation point to the box
        # centre, accounting for the detected rotation.
        corr_term = A @ (articulation_pt - ctr)
        spikeloc = spikeloc_uncorr - corr_term

        # Rotated corners in drawing order: upper-left, upper-right,
        # lower-right, lower-left, and back to upper-left to close the box.
        offsets = (- dir_u - dir_v, - dir_u + dir_v, dir_u + dir_v,
                dir_u - dir_v)
        corners = [spikeloc + A @ offset for offset in offsets]
        outline = torch.stack(corners + corners[:1]).numpy()

        # Locations are (row, col); matplotlib wants x = col, y = row.
        plt.plot(outline[:, 1], outline[:, 0], color=colors[motif],
                linewidth=linewidths[motif])

    # 3. Show or save the scene with all bounding boxes
    _hide_axes_frame(plt.gca())
    _show_or_save_figure(None if out_fn is None else out_fn + '_boxes.png')

    # 4. Show or save the crab detection trace too
    trace = spikes[0,...].to('cpu').numpy()
    trace_sum = np.sum(trace,-1)
    markers, stems, base = plt.stem(trace_sum)
    plt.plot(np.ones_like(trace_sum), 'r--')
    ax = plt.gca()
    ax.set_ylim([0.0, 1.5])
    _hide_axes_frame(ax)
    plt.setp(stems, 'linewidth', 2)
    plt.setp(base, 'linewidth', 1)
    plt.setp(markers, 'markersize', 4)
    _show_or_save_figure(None if out_fn is None else out_fn + '_trace.png')

# This isn't used, but can be used to create more scenes
def make_beach_scene():
    """Render a rotated crab on the beach and save a cropped PNG of it.

    The crab is placed by `crab_beach` at offset (900, 500) with a rotation
    of pi/8 radians. A 384 x 512 window whose top-left corner sits 100 pixels
    above and to the left of that offset is cut out, clipped to [0, 1], and
    written to 'crab_beach_<deg>deg.png' in the working directory, where
    <deg> is the rotation in degrees rounded to one decimal place.

    """

    from data import crab_beach
    from matplotlib.image import imsave

    # Where the crab goes in the full scene, and how much it is rotated.
    row0, col0 = 900, 500
    angle = np.pi/8
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rotation = np.array(((cos_a, -sin_a), (sin_a, cos_a)))

    full_scene = crab_beach(np.array((row0, col0)), A=rotation)

    # Crop window: 384 rows by 512 columns, starting 100 px before the offset.
    n_rows, n_cols = 384, 512
    top = row0 - 100
    left = col0 - 100
    window = full_scene[top:top + n_rows, left:left + n_cols]

    # Rotation in degrees, rounded to one decimal, for the file name.
    angle_deg = np.round(10 * (angle / (np.pi / 180))) / 10
    out_name = 'crab_beach_{}deg.png'.format(str(angle_deg))
    imsave(out_name, np.clip(window, 0, 1))

# Simple timer class with context
class Timer(object):
    """Context manager that prints the wall-clock time spent in its body.

    Usage:
        with Timer('label') as t:
            ...
        # prints "[label] Elapsed: <seconds>"; t.elapsed holds the value

    Attributes:
    name    : optional label printed in square brackets before the time
    tstart  : `time.time()` value recorded on entry (None before entry)
    elapsed : seconds between entry and exit (None until the block exits)

    Exceptions raised inside the block are not suppressed; the elapsed time
    is still recorded and printed.
    """

    def __init__(self, name=None):
        self.name = name
        self.tstart = None
        self.elapsed = None

    def __enter__(self):
        self.tstart = time.time()
        # Return the timer so `with Timer() as t` gives access to t.elapsed.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.elapsed = time.time() - self.tstart
        if self.name:
            # The original trailing comma was the Python 2 way of suppressing
            # the newline; in Python 3 that needs `end`.
            print('[%s]' % self.name, end=' ')
        print('Elapsed: %s' % self.elapsed)


def cconv_fourier(x, y):
    """Circular (periodic) 2D convolution of two images via the FFT.

    Written by hand because scipy.signal.fftconvolve's 'same' mode did not
    seem to restrict the output the way we need.

    Only the *first two* axes are convolved; any further axes (channels,
    etc.) are carried along elementwise. Permute the inputs beforehand if
    the spatial axes are not leading.

    Requires:
    x and y must have the same shape or be broadcastable against each other.
    No padding is applied here.

    Returns:
    The real part of the inverse transform of the product of the two
    spectra.

    """

    spatial_axes = (0, 1)

    # Convolution theorem: multiply the spectra, then transform back.
    spectrum_product = (np.fft.fft2(x, axes=spatial_axes, norm='backward')
                        * np.fft.fft2(y, axes=spatial_axes, norm='backward'))
    convolved = np.fft.ifft2(spectrum_product, axes=spatial_axes)

    return convolved.real

def dsp_flip(X):
    """Return the 'dsp flip' of the numpy array X.

    For an input X[i1, ..., ik] the result is X[-i1, ..., -ik], with every
    index taken modulo the size of its axis. Unlike flipud/fliplr, the
    element at index 0 along each axis stays where it is.

    Inputs:
    X : numpy array of any size

    Outputs:
    The flipped array. Because it is computed through FFTs, only the real
    part is returned and the dtype may be float rather than X's dtype.

    """

    every_axis = tuple(range(X.ndim))

    # Applying the orthonormal DFT twice along an axis reverses it modulo
    # its length, so list every axis twice. (Costs a log factor compared to
    # an index flip, which we accept.)
    doubled_axes = every_axis + every_axis
    flipped = np.fft.fft2(X, axes=doubled_axes, norm='ortho')

    return flipped.real

def ncc(w, x, W=None):
    """Compute the normalized cross correlation between filter w and scene x

    Both w and the patch-summing window are zero-padded at the bottom/right
    to the size of x, and all correlations are circular (computed with
    `cconv_fourier`). Entry [i, j] of the output therefore corresponds to
    placing the top-left corner of w at pixel (i, j) of x, wrapping around
    the image border.

    Assumptions:
    1. w is no larger than x along the two spatial axes and has the same
       number of channels as x.
    2. If a mask W is given, it is a 2D array with the same spatial size as
       w (it is tiled across the channels of x).

    Inputs:
    ---------
    w - (m, n, C) numpy array
        Filter. Smaller than x
    x - (M, N, C) numpy array
        Scene. larger than w
    W - (m, n) numpy array, optional
        Weights used when summing x**2 over a patch. If None, an all-ones
        (m, n) window is used.

    Outputs:
    ---------
    (M, N) numpy array. Negative correlations are clamped to 0 before
    normalization, and positions whose patch norm is below 1e-6 are set
    to -1.

    """

    m, n, _ = w.shape
    M, N, C = x.shape

    # zero padding that grows an (m, n, C) array to (M, N, C)
    pad_spec = ((0, M-m), (0, N-n), (0, 0))

    # patch summing filter
    if W is None:
        pad_W = np.zeros((M, N, C))
        pad_W[:m, :n, :] = 1.0
    else:
        # passed a mask: replicate it over all C channels (this used to be
        # hard-coded to 3 and broke for scenes with another channel count)
        pad_W = np.pad(np.tile(W[...,None], (1, 1, C)), pad_spec)

    # pad w
    pad_w = np.pad(w, pad_spec)

    # get scene patch norms
    norms = np.maximum(0, np.sum(cconv_fourier(dsp_flip(pad_W), x**2), -1))**0.5

    # get xcorr
    xcorr = np.maximum(0, np.sum(cconv_fourier(dsp_flip(pad_w), x), -1))

    # output; zero-norm patches produce inf/nan here, which are overwritten
    # just below, so silence the corresponding numpy warnings
    with np.errstate(divide='ignore', invalid='ignore'):
        output = (1 / np.sum(w**2)**0.5) * (xcorr / norms)

    # output checking: put -1 where the patch norm is (numerically) zero
    tol = 1e-6
    output[norms < tol] = -1

    return output

def articulate_crab(xform_dict):
    """Make the crab's parts move independently!

    Loads the crab parts (left/right claw, left/right eye, body) and the full
    crab template from `data_path`, locates each part inside the padded
    template with `ncc`, resamples every part onto the template canvas under
    its own transformation, sums the parts, and finally applies a global
    transformation to the assembled crab.

    Inputs:
    ---------
    xform_dict - dict
        Keys read here: 'lc_A', 'rc_A', 'le_A', 're_A' (per-part matrices for
        the left claw, right claw, left eye, right eye), 'template_A' and
        'template_b' (global matrix and offset). The matrices are transposed
        and matrix-multiplied with torch tensors, so they must be 2D torch
        tensors. The body always uses the identity.

    Outputs:
    ---------
    instance - torch tensor, channels first, values clamped to [0, 1] and
        zeroed wherever the mask is 0
    instance_mask - torch tensor, channels first, clamped to [0, 1] and
        rounded

    TODO: Might make more sense to pass transformations as a hierarchical
    structure or something

    """

    from matplotlib.image import imread
    from registration_pt import resample_chunked
    # NOTE(review): `imagesc` is imported but not used in this function.
    from images import imagesc, get_affine_grid

    # overhead
    dev = device()
    prec = precision()
    # interp: numpy (H, W, C) image -> tensor with channels moved to the
    # front and a leading singleton axis added, resampled on grid `tau`,
    # leading axis dropped again.
    # NOTE(review): 128 is presumably the chunk size of resample_chunked --
    # confirm against registration_pt.
    interp = lambda Y, tau: resample_chunked(torch.moveaxis(torch.tensor(Y,
        device=dev, dtype=prec), -1, 0)[None, ...], tau, 128)[0, ...]
    # NOTE(review): `rot_mtx` and `pi` are defined but not used below.
    rot_mtx = lambda phi: torch.tensor(((torch.cos(phi),
        -torch.sin(phi)),(torch.sin(phi), torch.cos(phi))), device=dev,
        dtype=prec)
    pi = torch.tensor(np.pi, device=dev, dtype=prec)

    # Load template and motifs
    # For every image the last channel is split off as the mask and the
    # remaining channels are kept as the picture; the mask gets a trailing
    # singleton channel axis so it can go through `interp` like an image.
    data = imread(data_path + '/left_claw.png').astype('float64')
    lc, lc_mask = (data[..., :-1], data[..., -1])
    lc_mask = lc_mask[..., None]
    data = imread(data_path + '/right_claw.png').astype('float64')
    rc, rc_mask = (data[..., :-1], data[..., -1])
    rc_mask = rc_mask[..., None]
    data = imread(data_path + '/left_eye.png').astype('float64')
    le, le_mask = (data[..., :-1], data[..., -1])
    le_mask = le_mask[..., None]
    data = imread(data_path + '/right_eye.png').astype('float64')
    re, re_mask = (data[..., :-1], data[..., -1])
    re_mask = re_mask[..., None]
    data = imread(data_path + '/body.png').astype('float64')
    body, body_mask = (data[..., :-1], data[..., -1])
    body_mask = body_mask[..., None]
    data = imread(data_path + '/crab.png').astype('float64')
    crab, crab_mask = (data[..., :-1], data[..., -1])
    crab_mask = crab_mask[..., None]

    # Pad the template for convenience while articulating
    # (25 pixels of zeros on every side of the two spatial axes; the channel
    # axis is not padded)
    pad_sz = 25
    template = np.pad(crab, ((pad_sz,)*2,)*2 + ((0,0),))
    # NOTE(review): `template_mask` is computed but not used below.
    template_mask = np.pad(crab_mask, ((pad_sz,)*2,)*2 + ((0,0),))

    # Locate the motifs in the template using ncc
    # Each *_idx is the (row, col) position of the NCC maximum in the padded
    # template.
    lc_ncc = ncc(lc, template)
    lc_idx = np.unravel_index(np.argmax(lc_ncc), lc_ncc.shape)
    rc_ncc = ncc(rc, template)
    rc_idx = np.unravel_index(np.argmax(rc_ncc), rc_ncc.shape)
    le_ncc = ncc(le, template)
    le_idx = np.unravel_index(np.argmax(le_ncc), le_ncc.shape)
    re_ncc = ncc(re, template)
    re_idx = np.unravel_index(np.argmax(re_ncc), re_ncc.shape)
    body_ncc = ncc(body, template)
    body_idx = np.unravel_index(np.argmax(body_ncc), body_ncc.shape)

    # Set up part transformation parameters
    # Transformation grid mode: need to fix locs and centers (centers is for
    # linked rigid body motion...)
    # note: centers need to be 1x2
    # note: use inverse parameterization?
    # For each part: *_A is its matrix (from xform_dict), *_b its located
    # position in the template, *_c a hard-coded 1x2 centre.
    # NOTE(review): the *_c constants are presumably the pivot points of the
    # parts in their own pixel coordinates -- confirm.
    # NOTE(review): `c` (channel count of the template) is not used below.
    m, n, c = template.shape
    #lc_A = torch.eye(2, device=dev, dtype=prec)
    lc_A = xform_dict['lc_A']
    lc_b = torch.tensor(lc_idx, device=dev, dtype=prec)
    lc_c = torch.tensor(((51, 38),), device=dev, dtype=prec)

    # rc_A = torch.eye(2, device=dev, dtype=prec)
    rc_A = xform_dict['rc_A']
    rc_b = torch.tensor(rc_idx, device=dev, dtype=prec)
    rc_c = torch.tensor(((45, 3),), device=dev, dtype=prec)

    # le_A = torch.eye(2, device=dev, dtype=prec)
    le_A = xform_dict['le_A']
    le_b = torch.tensor(le_idx, device=dev, dtype=prec)
    le_c = torch.tensor(((82, 4),), device=dev, dtype=prec)

    # re_A = torch.eye(2, device=dev, dtype=prec)
    re_A = xform_dict['re_A']
    re_b = torch.tensor(re_idx, device=dev, dtype=prec)
    re_c = torch.tensor(((81, 4),), device=dev, dtype=prec)

    # The body is never articulated: identity matrix, centre at the origin.
    body_A = torch.eye(2, device=dev, dtype=prec)
    body_b = torch.tensor(body_idx, device=dev, dtype=prec)
    body_c = torch.zeros((1,2), device=dev, dtype=prec)

    template_A = xform_dict['template_A']
    template_b = xform_dict['template_b']
    # Make grids
    # M, N, C are rebound to each part's own shape immediately before its
    # grid is built, while m, n stay the padded template size; statement
    # order matters here. Each part grid receives the transposed matrix and
    # the negated, element-reversed product of that transpose with the
    # part's template position.
    # NOTE(review): the meaning of get_affine_grid's positional arguments is
    # not visible in this file -- check `images.get_affine_grid` before
    # changing any of these calls.
    loc = torch.zeros((1,2), device=dev, dtype=prec)
    M, N, C = lc.shape
    lc_grid = get_affine_grid((lc_A.T)[None,...], -torch.flip(lc_A.T @
        lc_b, (-1,))[None,...], M, N, m, n, locs=loc, ctrs=lc_c)
    M, N, C = rc.shape
    rc_grid = get_affine_grid((rc_A.T)[None,...], -torch.flip(rc_A.T @
        rc_b, (-1,))[None,...], M, N, m, n, locs=loc, ctrs=rc_c)
    M, N, C = le.shape
    le_grid = get_affine_grid((le_A.T)[None,...], -torch.flip(le_A.T @
        le_b, (-1,))[None,...], M, N, m, n, locs=loc, ctrs=le_c)
    # NOTE(review): `le_grid_2` is computed but not used below.
    le_grid_2 = get_affine_grid((le_A.T)[None,...], torch.zeros((2,),
        device=dev, dtype=prec)[None,...], M, N, M, M, locs=loc, ctrs=loc)
    M, N, C = re.shape
    re_grid = get_affine_grid((re_A.T)[None,...], -torch.flip(re_A.T @
        re_b, (-1,))[None,...], M, N, m, n, locs=loc, ctrs=re_c)
    M, N, C = body.shape
    body_grid = get_affine_grid((body_A.T)[None,...], 
        -torch.flip(body_A.T @ body_b, (-1,))[None,...], M, N, m, n, locs=loc,
        ctrs=body_c)
    # Global grid: template size in, template size out, no locs/ctrs given.
    template_grid = get_affine_grid((template_A.T)[None, ...], template_b[None,
        ...], m, n, m, n)

    # Test to see if we can regenerate the template from these locations
    # (the resampled parts, and likewise their masks, are simply added up)
    instance = (interp(lc, lc_grid) + interp(rc, rc_grid) + interp(le, le_grid)
            + interp(re, re_grid) + interp(body, body_grid))
    instance_mask = (interp(lc_mask, lc_grid) + interp(rc_mask, rc_grid) +
            interp(le_mask, le_grid) + interp(re_mask, re_grid) +
            interp(body_mask, body_grid))
    # Global transformation of the frame
    instance = resample_chunked(instance[None, ...], template_grid, 128)[0,
            ...]
    instance_mask = resample_chunked(instance_mask[None, ...], template_grid,
            128)[0, ...]
    # clamping
    # Summed parts can leave [0, 1]; clamp both outputs, round the mask, and
    # zero the image wherever the rounded mask is 0.
    instance = torch.clamp(instance, min=0.0, max=1.0)
    instance_mask = torch.clamp(instance_mask, min=0.0, max=1.0)
    instance_mask = torch.round(instance_mask)
    instance *= instance_mask

    return instance, instance_mask

def _pad_to_canvas(img, target_H, target_W):
    """Zero-pad a (C, H, W) tensor to exactly (C, target_H, target_W).

    The content is centred; when the amount of padding is odd the extra
    pixel goes on the bottom/right. For the even target sizes used in this
    notebook this matches the previous inline arithmetic.
    """
    extra_h = target_H - img.shape[1]
    extra_w = target_W - img.shape[2]
    top = extra_h // 2
    left = extra_w // 2
    # F.pad takes the padding of the last axis first: (left, right, top,
    # bottom).
    return torch.nn.functional.pad(
        img, (left, extra_w - left, top, extra_h - top))


def create_observation(xform_dict):
    """
    Demo the articulated crab

    Builds an articulated crab with `articulate_crab(xform_dict)`, pads it to
    a 384 x 512 canvas, crops a window of the same size out of the beach
    background (`data_path + '/beach_bg.jpg'`, window located at
    (950, 600) via `get_affine_grid`), and composites the crab over the
    background using the crab's mask.

    Inputs:
    ---------
    xform_dict - dict
        Passed unchanged to `articulate_crab`.

    Outputs:
    ---------
    scene - torch tensor, channels first, values clamped to [0, 1]
    """
    from images import get_affine_grid
    from matplotlib.image import imread
    from registration_pt import resample_chunked

    dev = device()
    prec = precision()

    # Get beach background: rescale by 1/255, channels first
    beach_bg = imread(data_path + '/beach_bg.jpg')
    beach_bg = 1/255 * beach_bg.astype('float64')
    beach_bg = torch.moveaxis(torch.tensor(beach_bg, device=dev, dtype=prec),
            -1, 0)

    # Where the window is taken from the background
    embed_u = 950
    embed_v = 600

    # Size of the observation
    target_W = 512
    target_H = 384

    # Get template instance and centre it on the observation canvas
    instance, instance_mask = articulate_crab(xform_dict)
    instance = _pad_to_canvas(instance, target_H, target_W)
    instance_mask = _pad_to_canvas(instance_mask, target_H, target_W)

    # Sampling grid for the background window: identity matrix, zero offset
    embed_coords = torch.tensor((embed_u, embed_v), device=dev, dtype=prec)
    grid_scene = get_affine_grid(torch.eye(2, device=dev,
        dtype=prec)[None,...], torch.zeros((2,), device=dev,
            dtype=prec)[None, :], beach_bg.shape[1], beach_bg.shape[2],
        target_H, target_W, locs=embed_coords[None, :])
    bg_crop = resample_chunked(beach_bg[None, ...], grid_scene, 128)[0, ...]

    # Composite: crab where the mask is 1, background elsewhere
    scene = torch.clamp(instance_mask * instance + (1.0 - instance_mask) *
            bg_crop, max=1.0, min=0.0)

    return scene

## Evaluating the calibrated model 

In [None]:
# Create some observations
# This cell renders articulated crab scenes over a range of global rotations,
# runs hierarchical detection on each, records the summed detection output,
# and draws/saves the detection boxes and trace for every scene.
dev = device()
prec = precision()
# 2x2 rotation matrix for an angle `phi` given as a torch scalar (radians)
rot_mtx = lambda phi: torch.tensor(((torch.cos(phi),
    -torch.sin(phi)),(torch.sin(phi), torch.cos(phi))), device=dev,
    dtype=prec)
pi = torch.tensor(np.pi, device=dev, dtype=prec)

# Initialize
# All parts and the template start with no rotation and no offset.
xform_dict = {}
xform_dict['lc_A'] = rot_mtx(0*pi)
xform_dict['rc_A'] = rot_mtx(0*pi)
xform_dict['le_A'] = rot_mtx(0*pi)
xform_dict['re_A'] = rot_mtx(0*pi)
xform_dict['template_A'] = rot_mtx(0*pi)
xform_dict['template_b'] = torch.zeros((2,), device=dev, dtype=prec)

# index to transform dict keys... scan over 5 keys
# NOTE(review): `idx_to_keys` is not referenced anywhere else in this cell.
idx_to_keys = ['lc_A', 'rc_A', 'le_A', 're_A', 'template_A']

# scan parameters
# Global rotation: 2*num_side + 1 evenly spaced angles in
# [-max_deg, max_deg] degrees (7 angles from -30 to 30 here).
max_deg = 30
num_side = 3
rot_sizes = torch.linspace(-max_deg, max_deg, 2*num_side + 1, device=dev, dtype=prec)
# With a single sample linspace would return -max_deg, so use 0 instead.
if num_side == 0:
    rot_sizes = torch.tensor((0*pi,), device=dev, dtype=prec)
rot_sizes_rads = rot_sizes * pi / 180
# Per-part rotation scan: num_side_parts = 0 gives a single sample, so each
# of the four part loops below runs exactly once.
max_deg_parts = 15
num_side_parts = 0
rot_sizes_parts = torch.linspace(-max_deg_parts, max_deg_parts,
                                 2*num_side_parts + 1, device=dev, dtype=prec)
if num_side_parts == 0:
    rot_sizes_parts = torch.tensor((0*pi,), device=dev, dtype=prec)
rot_sizes_rads_parts = rot_sizes_parts * pi / 180


# logging
# One entry per (lc, rc, le, re, global rotation) combination.
fn_prefix = 'detection_results/'
detection_results = torch.zeros((2*num_side_parts+1,)*4 + (2*num_side+1,),
                               device=dev, dtype=prec)

# Perform scan experiment.
# The part rotations are currently fixed (claws at +/- pi/12, eyes at
# +/- pi/36); the commented-out lines are the scanned alternative that would
# use rot_sizes_rads_parts. Only the global template rotation is scanned.
for lc_idx in range(2*num_side_parts+1):
    #xform_dict['lc_A'] = rot_mtx(rot_sizes_rads_parts[lc_idx])
    xform_dict['lc_A'] = rot_mtx(pi/12)
    for rc_idx in range(2*num_side_parts+1):
        #xform_dict['rc_A'] = rot_mtx(rot_sizes_rads_parts[rc_idx])
        xform_dict['rc_A'] = rot_mtx(-pi/12)
        for le_idx in range(2*num_side_parts+1):
            #xform_dict['le_A'] = rot_mtx(rot_sizes_rads_parts[le_idx])
            xform_dict['le_A'] = rot_mtx(pi/36)
            for re_idx in range(2*num_side_parts+1):
                #xform_dict['re_A'] = rot_mtx(rot_sizes_rads_parts[re_idx])
                xform_dict['re_A'] = rot_mtx(-pi/36)
                for c_idx in range(2*num_side+1):
                    xform_dict['template_A'] = rot_mtx(rot_sizes_rads[c_idx])
                    
                    # Render the scene for this configuration
                    scene_art = create_observation(xform_dict)

                    # Time only the detection call itself
                    with Timer('timing detection...'):
                        spikes_art = detect_hierarchical(scene_art, G)
                    detection_results[lc_idx,rc_idx,le_idx,re_idx,c_idx] = spikes_art.sum()

                    # Visualize results
                    # (scene is moved to channels-last numpy for imshow; the
                    # output prefix encodes the five loop indices)
                    draw_boxes(G, spikes_art,
                               torch.moveaxis(scene_art, 0, -1).to('cpu').numpy(),
                               out_fn=fn_prefix + f'run{lc_idx}{rc_idx}{le_idx}{re_idx}{c_idx}')