# Figure 3: Hierarchical Invariant Template Detection Demo

The code below demonstrates a pre-calibrated network for the hierarchical invariant template detection task (Algorithm 1 in the arXiv paper). It first extracts the motifs (as described in the appendix) and then runs detection.

The implementation is in PyTorch.

In [None]:
# Handle imports
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.image import imread
from torch import tensor

# Project sources live in ../src; the path must be added before importing them.
sys.path.append("../src")
from crab import hierarchy
from hierarchy import extract_hierarchical, detect_hierarchical
from registration_pt import device, precision

# Setup

In [None]:
# Extraction: only needs to be run once.
# Build the motif hierarchy graph G via crab.hierarchy(), then run motif
# extraction on it. NOTE(review): extract_hierarchical presumably fills in the
# per-node 'content' entries that the drawing helper later reads from
# G.nodes[motif] -- confirm against src/hierarchy.py.
# The returned `flag` is not used by any cell visible in this notebook.
G = hierarchy()
flag = extract_hierarchical(G)

## Transformation options
Edit the next cell to run detection on a scene in which the crab is rotated by a different angle.

In [None]:
# Select the test scene. The scenes in ../data are global rotations of the
# crab by the angles below (degrees); change ROTATION_DEG to pick another one.
AVAILABLE_ROTATIONS_DEG = (0.0, 7.5, 15.0, 22.5)
ROTATION_DEG = 7.5
if ROTATION_DEG not in AVAILABLE_ROTATIONS_DEG:
    raise ValueError('No scene for a rotation of %s deg; choose one of %s'
                     % (ROTATION_DEG, AVAILABLE_ROTATIONS_DEG))
# '%s' % 7.5 -> '7.5', '%s' % 0.0 -> '0.0', matching the file names on disk.
scene = imread('../data/crab_beach_%sdeg.png' % ROTATION_DEG)


In [None]:
# Various helper functions for testing, etc
def _hide_axes(ax):
    """Hide the x/y axes and all four spines of a matplotlib Axes."""
    ax.axes.get_xaxis().set_visible(False)
    ax.axes.get_yaxis().set_visible(False)
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)


def draw_boxes(G, spikes, scene):
    """Display the best bounding box of every motif, then the crab trace.

    For each motif node of ``G`` the detection with the smallest error (last
    column of ``'errors'``) is selected, and a box the size of the motif's
    content, rotated by the detected angle ``phi`` and placed at the
    detected spike location, is drawn over ``scene``. A second figure shows
    a stem plot of the crab detection trace.

    Parameters
    ----------
    G : motif hierarchy graph
        Nodes ``'claw_left'``, ``'claw_right'``, ``'eye_left'``,
        ``'eye_right'``, ``'eye_pair'`` and ``'crab'`` must each carry a
        ``'detection_dict'`` (with ``'errors'``, ``'spike_locs'`` and
        ``'phi'`` tensors) and a ``'content'`` entry whose first element is
        3-D, ``(C, M, N)``.
    spikes : torch.Tensor
        Detection output; ``spikes[0, ...]`` summed over its last axis is
        plotted as the trace.
    scene : array-like
        Image shown with ``plt.imshow`` underneath the boxes.

    Both figures are displayed with ``plt.show()``; nothing is written to
    disk and nothing is returned.
    """
    motifs = ['claw_left', 'claw_right', 'eye_left', 'eye_right', 'eye_pair',
              'crab']
    colors = {'claw_left': 'r',
              'claw_right': 'r',
              'eye_left': 'r',
              'eye_right': 'r',
              'eye_pair': 'g',
              'crab': 'b'}
    linewidths = {'claw_left': 1,
                  'claw_right': 1,
                  'eye_left': 1,
                  'eye_right': 1,
                  'eye_pair': 3,
                  'crab': 3}

    # show the scene
    plt.imshow(scene)

    # plot the bounding boxes specified in the detection results
    for motif in motifs:
        node = G.nodes[motif]
        det = node['detection_dict']

        # 1. Find the best detection index (smallest error).
        best_idx = torch.argmin(det['errors'][:, -1])

        # 2. Rotation matrix for the detected angle; float() accepts any
        # single-element tensor, so the matrix is built from plain scalars.
        spikeloc = det['spike_locs'][best_idx, :].to('cpu')
        phi = det['phi'][best_idx].to('cpu')
        c, s = float(torch.cos(phi)), float(torch.sin(phi))
        A = torch.tensor(((c, -s), (s, c)), device='cpu', dtype=precision())

        # Half-extents of the motif content along its two spatial axes.
        _, M, N = node['content'][0].shape
        dir_u = torch.tensor(((M - 1) / 2, 0), device='cpu',
                             dtype=precision())
        dir_v = torch.tensor((0, (N - 1) / 2), device='cpu',
                             dtype=precision())

        # Box corners relative to the box centre, in (u, v) order.
        ul = -dir_u - dir_v
        ur = -dir_u + dir_v
        ll = dir_u - dir_v
        lr = dir_u + dir_v

        # Rotate and translate the corners, walking them in order and
        # repeating the first one so the outline closes.
        ring = [(spikeloc + A @ corner).numpy() for corner in (ul, ur, lr, ll)]
        ring.append(ring[0])

        # matplotlib takes x = second coordinate, y = first coordinate.
        plt.plot([p[1] for p in ring], [p[0] for p in ring],
                 color=colors[motif], linewidth=linewidths[motif])

    # 3. Display the scene with all bounding boxes.
    _hide_axes(plt.gca())
    plt.show()
    plt.clf()

    # 4. Display the crab detection trace as a stem plot.
    trace = spikes[0, ...].to('cpu').numpy()
    markers, stems, base = plt.stem(np.sum(trace, -1))
    _hide_axes(plt.gca())
    plt.setp(stems, 'linewidth', 2)
    plt.setp(base, 'linewidth', 1)
    plt.setp(markers, 'markersize', 4)
    plt.show()

# This isn't used, but can be used to create more scenes
def make_beach_scene(offset_u=900, offset_v=500, phi=None, sz_u=384,
                     sz_v=512, margin=100):
    """Render a rotated crab on the beach and save a cropped PNG of it.

    ``data.crab_beach`` is called with the offset ``(offset_u, offset_v)``
    and the 2x2 rotation matrix for angle ``phi`` (radians).
    NOTE(review): presumably these set the crab's position and orientation
    in the beach image -- confirm in src/data.py.

    A ``sz_u`` x ``sz_v`` window starting ``margin`` pixels before the
    offset along both axes is clipped to [0, 1]. It is saved to the current
    working directory as ``crab_beach_<deg>deg.png``, where ``<deg>`` is
    ``phi`` in degrees rounded to one decimal. This is the naming used by
    the scenes in ``../data``.

    All defaults reproduce the original scene (``phi = pi/8``, 22.5 deg).

    Returns
    -------
    str
        The file name written.

    Raises
    ------
    ValueError
        If either offset is smaller than ``margin``. The crop would start at
        a negative index, which numpy slicing silently wraps.
    """
    from data import crab_beach
    from matplotlib.image import imsave

    if phi is None:
        phi = np.pi / 8
    if offset_u < margin or offset_v < margin:
        raise ValueError('offsets (%s, %s) must be at least margin=%s'
                         % (offset_u, offset_v, margin))

    b = np.array((offset_u, offset_v))
    A = np.array(((np.cos(phi), -np.sin(phi)), (np.sin(phi), np.cos(phi))))
    scene = crab_beach(b, A=A)

    top = offset_u - margin
    left = offset_v - margin
    scene_crop = scene[top:top + sz_u, left:left + sz_v]

    # Angle in degrees, rounded to one decimal place, for the file name.
    deg = np.round(10 * (phi / (np.pi / 180))) / 10
    fn = 'crab_beach_' + str(deg) + 'deg.png'
    imsave(fn, np.clip(scene_crop, 0, 1))
    return fn

# Simple timer class with context
class Timer(object):
    """Context manager that prints the time spent inside its ``with`` block.

    Usage::

        with Timer('label') as t:
            ...
        t.elapsed   # seconds; also printed on exit

    Parameters
    ----------
    name : str, optional
        Label printed in brackets before the elapsed time; omitted if falsy.

    Exceptions raised inside the block are not suppressed.
    """

    def __init__(self, name=None):
        self.name = name
        self.tstart = None
        self.elapsed = None

    def __enter__(self):
        # perf_counter is monotonic, so the measurement is unaffected by
        # system clock adjustments (unlike time.time()).
        self.tstart = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.elapsed = time.perf_counter() - self.tstart
        if self.name:
            # Keep the label on the same line as the elapsed time.
            print('[%s]' % self.name, end=' ')
        print('Elapsed: %s' % self.elapsed)
        # Implicitly returns None, so any exception propagates.


In [None]:
# Run detection on the selected scene.
# Keep only the first three channels (presumably dropping a PNG alpha
# channel), then move the last (channel) axis to the front.
# NOTE(review): assumes the image is (H, W, channels), giving a (C, H, W) tensor.
scene = scene[..., :3]
dev = device()
Y = torch.moveaxis(torch.tensor(scene, device=dev, dtype=precision()), -1, 0)

with Timer('timing detection...'):
    spikes = detect_hierarchical(Y, G)

# Overlay the best detections on the scene and plot the crab trace.
draw_boxes(G, spikes, scene)