In [1]:
import numpy as np
np.set_printoptions(suppress=True)
import cv2
from glob import glob
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.io as sio
import scipy.ndimage as ndi
from scipy.interpolate import interp1d
from scipy.signal import convolve2d, convolve
from skimage.morphology import skeletonize, remove_small_objects
from skimage.measure import label, regionprops
from skimage.color import label2rgb
from IPython import display
from pathlib import Path
import os
from tqdm import tqdm

In [2]:
class TracerNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.features = nn.Sequential(
            nn.Conv2d(4, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(16),

            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2),
        )

        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(576, 1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 2)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.regressor(x)
        x = F.normalize(x, dim=1)
        return x

In [3]:
# ported from https://github.com/tsogkas/matlab-utils/blob/master/nms.m
def contour_orient(E, r):
    if r <= 1:
        p = 12 / r / (r + 2) - 2
        f = np.array([1, p, 1]) / (2 + p)
        r = 1
    else:
        f = np.concatenate([np.arange(1, r+1), [r+1], np.arange(r, 0, -1)]) / (r + 1)**2
    E2 = np.pad(E, [r, r], mode='symmetric')
    E2 = convolve(convolve(E2, f[np.newaxis], mode='valid'), f[:, np.newaxis], mode='valid')
    grad_f = np.array([-1, 2, -1])
    Dx = convolve2d(E2, grad_f[np.newaxis], mode='same')
    Dy = convolve2d(E2, grad_f[:, np.newaxis], mode='same')
    f = np.array([[1, 0, -1], [0, 0, 0], [-1, 0, 1]])
    F = convolve2d(E2, f, mode='same') > 0
    Dy[F] = -Dy[F]
    return np.mod(np.arctan2(Dy, Dx), np.pi)

def nms(E, r=3, s=1):
    E = E.copy()
    O = contour_orient(E, r)
    
    Dx = np.cos(O)
    Dy = np.sin(O)
    
    ht, wd = E.shape
    E1 = np.pad(E, r+1, mode='edge')
    
    cs, rs = np.meshgrid(np.arange(wd), np.arange(ht))

    for i in range(-r, r+1):
        if i == 0: continue
        cs0 = i * Dx + cs
        dcs = cs0 - np.floor(cs0)
        cs0 = np.floor(cs0).astype(int)
        
        rs0 = i * Dy + rs
        drs = rs0 - np.floor(rs0)
        rs0 = np.floor(rs0).astype(int)
        
        rs0_p = rs0 + r + 1
        cs0_p = cs0 + r + 1

        rs0_p = np.clip(rs0_p, 0, ht + 2*r)
        cs0_p = np.clip(cs0_p, 0, wd + 2*r)
        
        E2 = (1 - dcs) * (1 - drs) * E1[rs0_p + 0, cs0_p + 0]
        E2 += dcs * (1 - drs) * E1[rs0_p + 0, cs0_p + 1]
        E2 += (1 - dcs) * drs * E1[rs0_p + 1, cs0_p + 0]
        E2 += dcs * drs * E1[rs0_p + 1, cs0_p + 1]

        E[E * 1.01 < E2] = 0

    for i in range(s):
        scale = (i) / s
        E[i, :] *= scale
        E[-i-1, :] *= scale
        E[:, i] *= scale
        E[:, -i-1] *= scale

    return E 

In [4]:
def checkerboard(height, width, block_size):
    num_blocks_x = height // block_size 
    num_blocks_y = width // block_size
    checker_pattern = (-1) ** (np.add.outer(np.arange(num_blocks_x), np.arange(num_blocks_y)))
    result = np.kron(checker_pattern, np.ones((block_size, block_size)))
    return result[:height, :width]

In [5]:
# Adapted from https://gist.github.com/bmabey/4dd36d9938b83742a88b6f68ac1901a6
def bwmorph_endpoints(image):
    image = image.astype(np.int32)
    k = np.array([[1,1,1],[1,0,1],[1,1,1]])
    neighborhood_count = ndi.convolve(image,k, mode='constant', cval=1)
    neighborhood_count[~image.astype(np.bool)] = 0
    return neighborhood_count == 1

In [6]:
def get_neighbors(arr, y, x):
    h, w = arr.shape
    x1, x2 = max(x-1, 0), min(x+2, w)
    y1, y2 = max(y-1, 0), min(y+2, h)
    return arr[y1:y2, x1:x2]

In [7]:
def get_pixel_idx_list(labeled_image):
    labels = np.unique(labeled_image)
    labels = labels[labels != 0]  # Remove background if labeled as 0
    return [np.flatnonzero(labeled_image == label) for label in labels]

In [8]:
def tensor_cropper(image, cp, crop_wind):
    y, x = cp
    half = crop_wind // 2

    y1, y2 = y - half, y + half + (crop_wind % 2)
    x1, x2 = x - half, x + half + (crop_wind % 2)

    # Handle edge cases by clipping indices to valid ranges
    y1, y2 = max(0, y1), min(image.shape[0], y2)
    x1, x2 = max(0, x1), min(image.shape[1], x2)

    return image[y1:y2, x1:x2].astype(np.float32)

In [9]:
def vector_to_angle_deg(preds):
    cos_vals, sin_vals = preds[:, 0], preds[:, 1]
    angles_rad = torch.atan2(sin_vals, cos_vals)
    angles_deg = torch.rad2deg(angles_rad)
    return angles_deg

In [10]:
def compute_target_angles(p_step):
    w_size = 2 * p_step + 1

    # Create grid of coordinates relative to center
    y, x = np.mgrid[-p_step:p_step+1, -p_step:p_step+1]

    # Create mask with border set to 0
    center_mask = np.ones((w_size, w_size), dtype=bool)
    center_mask[0, :] = False
    center_mask[-1, :] = False
    center_mask[:, 0] = False
    center_mask[:, -1] = False

    # Compute angles, shift by pi, convert to degrees
    vectormatrix = x + 1j * y
    angle_deg = np.degrees(np.angle(vectormatrix) + np.pi)
    angle_deg = np.round(angle_deg)

    # Function to zero out specified angles in the inner region
    def zero_angles(angles, center_mask, target_angles):
        for t in target_angles:
            angles[(angles == t) & center_mask] = 0
        return angles

    # Target angles to zero
    targets = [0, 45, 90, 135, 180, 225, 270, 315, 360]

    # Remove angles where inner mask is True
    angle_deg = zero_angles(angle_deg, center_mask, targets)
    angle_deg_mirr = np.flipud(angle_deg) * -1

    # Get unique angles
    unique_angles = np.unique(angle_deg).astype(np.int32)
    unique_angles_mirr = np.unique(angle_deg_mirr).astype(np.int32)
    return vectormatrix, angle_deg, angle_deg_mirr, unique_angles, unique_angles_mirr

In [11]:
def rescale_values(img):
    return (img - img.min()) / (img.max() - img.min())

In [None]:
vectormatrix, angle_deg, angle_deg_mirr, unique_angles, unique_angles_mirr = compute_target_angles(p_step=3)

files = zip(glob("data/input/rgb/*"), glob("data/input/scm/*"))
data = [
    [
        cv2.imread(img_file)[...,::-1], 
        cv2.imread(cont_file, -1), 
        Path(img_file).stem
    ]
    for img_file, cont_file in files
]

# load model
device = torch.accelerator.current_accelerator(True) if torch.accelerator.is_available() else "cpu"

model = TracerNet()
checkpoint = torch.load("data/models/best.pth", weights_only=True, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()

plt.figure(figsize=(10, 10))

for img, cont, filename in data:
    img = img.astype(np.float32) * 256
    cont = cont.astype(np.float32)

    pad_width = [(50, 50), (50, 50)]
    img = np.pad(img, pad_width + [(0, 0)], mode='constant', constant_values=0)
    cont = np.pad(cont, pad_width, mode='constant', constant_values=0)
    accumulator = np.zeros_like(cont)

    for flipped in [True, False]:
        img = np.flip(img, axis=1)
        cont = np.flip(cont, axis=1)
        accumulator = np.flip(accumulator, axis=1)

        thresh = np.max(cont) * 0.5
        cont_nms = nms(cont, r=3, s=2)
        cont_nms[cont_nms < thresh] = 0
        
        cb_mask = checkerboard(*cont.shape, block_size=2)
        if flipped: cont_nms[cb_mask > 0] = 0
        else:       cont_nms[cb_mask < 0] = 0
        cont_nms[cont_nms > 0] = 1

        cont_nms = skeletonize(cont_nms.astype(bool))
        cont_nms = remove_small_objects(cont_nms, min_size=2, connectivity=8)

        labeled_image, segments = label(cont_nms, return_num=True)

        endpoints = np.zeros((2, segments - 1, 2), dtype=np.int32)
        endpoints_ang = np.zeros((2, segments - 1))
        for i in range(0, segments - 1):
            segment = labeled_image == i + 1
            segment = ndi.binary_fill_holes(segment)
            endpoints_mask = bwmorph_endpoints(segment)

            coords = np.array(np.nonzero(endpoints_mask)).T
            if len(coords) == 1: coords = np.repeat(coords, 2, axis=0)
            if len(coords) != 2:
                d = np.zeros(coords.size)
                props = regionprops(segment.astype(np.int32))
                x2, y2 = props[0].centroid
                for j, (y1, x1) in enumerate(coords):
                    d[j] = np.hypot(x2 - x1, y2 - y1)
                endpoints[:, i] = coords[np.argsort(d)[-2:][::-1]]
            else: endpoints[:, i] = coords

            for ep_idx in range(2):
                neighbors = get_neighbors(segment, *coords[ep_idx])
                neighbors[1, 1] = 0
                dy, dx = np.unravel_index(np.argmax(neighbors), neighbors.shape)
                angle_rad = np.atan2(dy - 1, dx - 1)
                endpoints_ang[ep_idx, i] = np.degrees(angle_rad) % 360

        center_points = endpoints.reshape(-1, 2)
        Y, X = center_points.T
        prio = cont[Y, X]
        angles = endpoints_ang.flatten()
        group = labeled_image[Y, X]
        running = np.ones(Y.shape, dtype=np.int32)
        pix_list = get_pixel_idx_list(labeled_image)

        interp_unique = interp1d(unique_angles, unique_angles, kind="nearest", bounds_error=False)
        interp_unique_mirr = interp1d(unique_angles_mirr, unique_angles_mirr, kind="nearest", bounds_error=False)

        inputs = np.concatenate([img, cont[...,None]], axis=2).astype(np.float32)
        crop_size = 13

        num_tracers = running.sum()
        with tqdm(total=num_tracers, desc=f"Running Tracers [flipped={flipped}]") as pbar:
            while num_tracers > 0:

                mask = running == 1
                X = X[mask]
                Y = Y[mask]
                prio = prio[mask]
                angles = angles[mask]
                group = group[mask]
                running = running[mask]

                center_points = np.column_stack([Y, X])

                crops = []
                for y, x, angle in zip(Y, X, angles):
                    crop = tensor_cropper(inputs, [y, x], crop_size * 2 + 1)
                    rot_crop = ndi.rotate(crop, angle, reshape=False, order=1)
                    final_crop = tensor_cropper(rot_crop, [crop_size+1, crop_size+1], crop_size)
                    crops.append(final_crop)
                with torch.no_grad():
                    tracer_input = torch.tensor(np.array(crops).transpose(0, 3, 2, 1)).to(device)
                    preds = model(tracer_input)
                    delta_angles = vector_to_angle_deg(preds).cpu().numpy()
                angles = (angles + delta_angles) % 360

                results = np.where(
                    angles >= 0,
                    interp_unique(angles),
                    interp_unique_mirr(angles)
                )

                complex_dir = np.array([
                    vectormatrix[angle_deg == res][0] if res > 0 else vectormatrix[angle_deg_mirr == res][0]
                    for res in results
                ])
                magnitudes = np.abs(complex_dir)
                step_size = np.abs(np.random.randn(num_tracers, 1))
                step_size[step_size < 1] = 1
                step_size[step_size >= 2.5] = 3

                dir = np.column_stack((np.imag(complex_dir), np.real(complex_dir))) / magnitudes[:, None]
                dir = np.round(dir * step_size)
                angles = np.degrees(np.atan2(*-dir.T))

                new_center_points = (center_points + dir).astype(np.int32)
                Y, X = new_center_points.T
                labeled_image[Y, X] = group
                accumulator[Y, X] += 0.05

                for i, (y, x, g) in enumerate(zip(Y, X, group)):
                    g = int(g)
                    prio[i] = cont[y, x]
                    i_hel = np.ravel_multi_index([y, x], dims=segment.shape)
                    if i_hel in pix_list[g] or prio[i] <= 15000:
                        running[i] = 0
                    else:
                        pix_list[g] = np.append(pix_list[g], i_hel)

                old_num_tracers = num_tracers
                num_tracers = running.sum()
                pbar.update(old_num_tracers - num_tracers)
                # colored_labels = label2rgb(labeled_image, bg_label=0, kind='overlay')
                # plt.clf()
                # plt.title(f"tracer: {num_tracers}")
                # plt.imshow(colored_labels)
                # #plt.imshow(accumulator, cmap='gray')
                # display.clear_output(wait=True)
                # display.display(plt.gcf())
    cv2.imwrite(f"data/output/tracer_walk/{filename}.png", (np.clip(accumulator, 0, 1) * 255).astype(np.uint8))
    # cv2.imwrite(f"data/output/tracer_walk/{filename}.png", (rescale_values(accumulator) * 255).astype(np.uint8))
plt.close("all")


Running Tracers [flipped=True]: 100%|██████████| 434/434 [00:07<00:00, 59.88it/s] 
Running Tracers [flipped=False]: 100%|██████████| 436/436 [00:06<00:00, 64.05it/s] 
