In [1]:
import numpy as np
import load_utils
import spine_augmentation as aug
import confidence_map as cmap
import spine_model
import torch.optim as optim
import torch.nn as nn
import torch
import os.path as path
import torchvision
import matplotlib.pyplot as plt
import cv2
import torch.nn.functional as F
from PIL import Image
import folders as f
import os
import argparse



In [2]:
# Restore the PAF spine network from the checkpoint stored at
# f.checkpoint_heat_path, then put it in inference mode on the GPU.
save_path = f.checkpoint_heat_path
net = spine_model.SpineModelPAF()
checkpoint_state = torch.load(save_path)
net.load_state_dict(checkpoint_state)
net.eval()
net.cuda()
# NOTE(review): the argument 2 is presumably the batch size -- confirm in load_utils.
test_data_loader = load_utils.test_loader(2)
device = torch.device("cuda")

In [3]:
# Run a single test batch through the network with gradient tracking disabled
# and bring both output maps back to numpy.
with torch.no_grad():
    test_imgs, test_labels = next(test_data_loader)
    # Cast to float32 and insert a channel axis right after the batch axis.
    test_imgs = np.asarray(test_imgs, dtype=np.float32)[:, None, :, :]
    # Divide by 255 (assumes 8-bit pixel values -- TODO confirm against the loader).
    test_imgs_01 = test_imgs / 255.0
    test_imgs_tensor = torch.from_numpy(test_imgs_01).cuda()
    out_pcm, out_paf, _, _ = net(test_imgs_tensor)  # NCHW
    np_pcm, np_paf = (
        out.detach().cpu().numpy() for out in (out_pcm, out_paf)
    )

## Gaussian heatmap to point coordinates

In [4]:


# lc = np_pcm[0,4]
# n

In [5]:
def cvshow(img):
    """Show a single-channel image, enlarged 4x, in an OpenCV window.

    Blocks until a key is pressed, then closes all OpenCV windows.

    Args:
        img: 2-D image array (anything else fails the assertion).
    """
    assert len(img.shape) == 2
    scale = 4
    # Nearest-neighbour keeps individual pixels crisp when enlarging.
    enlarged = cv2.resize(
        img, dsize=None, fx=scale, fy=scale, interpolation=cv2.INTER_NEAREST
    )
    cv2.imshow("img", enlarged)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

In [6]:
def centeroid(heat, threshold=0.99):
    """Find the centre point of every connected blob in a heatmap.

    The heatmap is binarised (values strictly above ``threshold`` are kept),
    split into connected components, and the centre of mass of each component
    is computed from its image moments.

    Args:
        heat: 2-D heatmap handed to ``cv2.threshold``.
            NOTE(review): assumed to be a float32 map with peaks close to 1.0
            -- confirm against the network output.
        threshold: binarisation threshold. Defaults to 0.99, the value that
            was previously hard-coded, so existing callers are unaffected.

    Returns:
        A list ``[p][xy]`` of integer (truncated) centre coordinates, one entry
        per connected component, in component-label order. Empty when nothing
        exceeds the threshold.
    """
    _, binary = cv2.threshold(heat, threshold, 1., cv2.THRESH_BINARY)
    binary = np.array(binary * 255., np.uint8)
    # num counts the detected components plus one background label (label 0).
    num, labels = cv2.connectedComponents(binary)
    coords = []
    for label in range(1, num):
        mask = np.zeros_like(labels, dtype=np.uint8)
        mask[labels == label] = 255
        M = cv2.moments(mask)
        # m00 is never zero here: every label owns at least one pixel.
        cX = int(M["m10"] / M["m00"])
        cY = int(M["m01"] / M["m00"])
        coords.append([cX, cY])
    return coords
    

In [7]:
def line_mask(pt1, pt2_list, hw):
    """Draw one line mask per target point.

    Args:
        pt1: start point ``(x, y)`` shared by every line.
        pt2_list: sequence of end points, each ``(x, y)``.
        hw: ``(height, width)`` of each mask.

    Returns:
        float32 array of shape ``[len(pt2_list), h, w]``; pixels on the line
        from ``pt1`` to the corresponding end point are 1.0, all others 0.0.
        An empty ``pt2_list`` yields an empty ``(0, h, w)`` array so the result
        always has three dimensions.
    """
    assert len(hw)==2
    height, width = int(hw[0]), int(hw[1])
    if len(pt2_list) == 0:
        # Without this guard np.array([]) would collapse to shape (0,).
        return np.zeros((0, height, width), dtype=np.float32)
    # cv2.line wants plain Python ints; numpy scalars are rejected by some
    # OpenCV builds, so convert explicitly (fractional values are truncated).
    start = (int(pt1[0]), int(pt1[1]))
    canvases = np.zeros((len(pt2_list), height, width), dtype=np.uint8)
    masks_with_line = [
        cv2.line(canvases[i_pt2], start, (int(pt2[0]), int(pt2[1])), 255)
        for i_pt2, pt2 in enumerate(pt2_list)
    ]
    masks_01 = np.array(masks_with_line, dtype=np.float32) / 255.

    return masks_01

In [8]:
def line_dist(pt1, pt2_list):
    """Euclidean distance from one point to each point in a list.

    Args:
        pt1: reference point, e.g. ``[x, y]``.
        pt2_list: sequence of points with the same dimensionality as ``pt1``.

    Returns:
        numpy array with one distance per entry of ``pt2_list``. An empty
        ``pt2_list`` yields an empty 1-D array instead of a broadcast error.
    """
    origin = np.asarray(pt1)
    targets = np.asarray(pt2_list)
    if targets.size == 0:
        # np.array([]) has shape (0,), which cannot broadcast against pt1.
        return np.zeros(0, dtype=np.float64)
    offsets = targets - origin
    return np.linalg.norm(offsets, axis=-1)

In [17]:
def coincidence_rate(paf, line_masks, distances):
    """Score how strongly the PAF supports each candidate line.

    For every mask the PAF values under the line are summed and divided by the
    corresponding point-to-point distance.

    Args:
        paf: 2-D field, shape ``[h][w]``.
        line_masks: 3-D stack of line masks, shape ``[p2_len][h][w]``.
        distances: 1-D array of line lengths, one per mask.

    Returns:
        1-D numpy array of scores, one per mask. Values are not confined to
        0-1 because the number of pixels OpenCV draws for a line is not equal
        to its Euclidean length.
    """
    assert paf.ndim == 2
    assert line_masks.ndim == 3
    assert distances.ndim == 1
    overlap = line_masks * paf  # [p2_len][h][w]
    per_line_total = overlap.sum(axis=(1, 2))
    return per_line_total / distances

In [18]:
def coincidence_rate_from_pcm_paf(lcrc_pcm, hw, paf):
    """Score every left/right point pairing using the PAF.

    Points are extracted from the two confidence maps with ``centeroid``; for
    each left point a line is drawn to every right point and scored with
    ``coincidence_rate``.

    Args:
        lcrc_pcm: confidence maps of shape ``(2, h, w)`` -- index 0 is the
            left map, index 1 the right map.
        hw: two-element size. It is only validated; the mask size is taken
            from ``lcrc_pcm.shape[1:3]`` (kept as-is for compatibility).
        paf: 2-D field the lines are scored against.

    Returns:
        numpy array of shape ``[p1_len][p2_len]`` with one score per
        (left point, right point) pair. When either side has no detected
        points an empty matrix of the matching shape is returned, so the
        result is always two-dimensional.
    """
    assert len(lcrc_pcm.shape)==3, "pcm shape should be: (lr, h, w)"
    assert lcrc_pcm.shape[0]==2, "1st dim of pcm should have 2 elements: l and r"
    assert len(hw)==2, "hw shape length should be 2"
    assert len(paf.shape)==2, "paf shape length should be 2"
    lc_coord, rc_coord = [centeroid(c) for c in lcrc_pcm[:]]  # lc_coord: [p][xy]
    if len(lc_coord) == 0 or len(rc_coord) == 0:
        # Nothing to pair: avoid feeding empty point lists to the helpers and
        # keep the documented [p1_len][p2_len] layout.
        return np.zeros((len(lc_coord), len(rc_coord)), dtype=np.float64)
    mask_hw = lcrc_pcm.shape[1:3]
    coins = []  # coincidence rate list, shape: [pt1_len][pt2_len]
    for lc_pt in lc_coord:
        p1_masks = line_mask(lc_pt, rc_coord, mask_hw)  # [p2_len][h][w]
        p1_dist = line_dist(lc_pt, rc_coord)
        coins.append(coincidence_rate(paf, p1_masks, p1_dist))
    return np.array(coins)

In [19]:
# Take channels 4 and 5 of the first sample as the (left, right) confidence-map
# pair, and channel 0 of its PAF output as the field to score against.
# NOTE(review): the meaning of these channel indices comes from the model -- confirm.
lcrc = np_pcm[0, 4:6]
coins = coincidence_rate_from_pcm_paf(
    lcrc_pcm=lcrc,
    hw=lcrc.shape[1:3],
    paf=np_paf[0, 0],
)

In [25]:
# Greedy one-to-one matching of left and right points by descending score.
# pair_l[k] and pair_r[k] form a pair, e.g. [3,4,5] and [3,4,6] means
# l3->r3, l4->r4, l5->r6.
pair_l, pair_r = [], []
args_1d = np.argsort(coins, axis=None)
lc_args, rc_args = np.unravel_index(args_1d, coins.shape)
# argsort orders min->max, so walk the indices backwards to visit the
# strongest candidates first.
for i_arg in range(len(lc_args) - 1, -1, -1):
    al, ar = lc_args[i_arg], rc_args[i_arg]
    if al in pair_l or ar in pair_r:
        # One of the two points already has a better partner.
        continue
    pair_l.append(al)
    pair_r.append(ar)
print(pair_l, pair_r)

[3, 4, 5, 0, 1, 6, 12, 7, 11, 13, 2, 14, 8, 16, 15, 9, 10] [3, 4, 5, 0, 1, 6, 12, 7, 11, 13, 2, 14, 8, 16, 15, 9, 10]
