In [1]:
# Interactive (ipympl) backend: required so that canvas 'button_press_event'
# clicks reach the point-selection callbacks defined further below.
%matplotlib widget

In [2]:
from torch.utils.data import DataLoader
import torch

from model.base.geometry import Geometry
from common.evaluation import Evaluator
from common.logger import AverageMeter
from common.logger import Logger
from data import download
from model import chmnet

from matplotlib import pyplot as plt
from matplotlib.patches import ConnectionPatch

from PIL import Image
import torchvision.transforms as transforms

from glob import glob
import numpy as np

from ipywidgets import interact, interactive, fixed

### CUDA Status 

In [3]:
# Report the CUDA setup: availability, active device index, number of visible
# devices and the name of device 0 (the model below is placed with `.cuda()`).
# NOTE(review): this only prints the values; it does not enforce a single GPU.
cuda_report = (
    torch.cuda.is_available,
    torch.cuda.current_device,
    torch.cuda.device_count,
    lambda: torch.cuda.get_device_name(0),
)
for query in cuda_report:
  print(query())

True
0
1
NVIDIA GeForce GTX 1080 Ti


## Model and Parameter Initialization

### Model Parameters

In [4]:
# Demo configuration. Keys read by the cells below:
#   ktype    -> kernel type passed to chmnet.CHMNet
#   load     -> checkpoint path given to torch.load
#   alpha    -> passed to Evaluator.initialize
#   img_size -> side length of the square model input (Resize / Geometry)
# 'benchmark', 'bsz', 'datapath' and 'thres' are not read by any cell shown here.
args = dict(
    alpha=[0.05, 0.1],
    benchmark='pfpascal',
    bsz=32,
    datapath='../Datasets_CHM',
    img_size=240,
    ktype='psi',
    load='pretrained/pas_psi.pt',
    thres='img',
)

### Model initialization

In [5]:
# Build CHMNet with the configured kernel type on the GPU and load the weights.
model = chmnet.CHMNet(args['ktype']).cuda()
# NOTE(review): torch.load unpickles the checkpoint -- only load trusted files.
# map_location='cpu' avoids materialising the checkpoint on whatever device it was
# saved from; load_state_dict then copies the tensors into the CUDA parameters.
model.load_state_dict(torch.load(args['load'], map_location='cpu'))
# This notebook only runs inference: eval() puts layers such as BatchNorm/Dropout
# (if the network contains any) into their inference behaviour.
model.eval()
Evaluator.initialize(args['alpha'])
Geometry.initialize(img_size=args['img_size'])

## Data Utils

### Transform

In [6]:
# Preprocessing for every image fed to CHMNet: resize to the square
# args['img_size'] input, convert to a tensor, then normalise per channel.
# These are the widely used ImageNet mean/std values.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
my_transform = transforms.Compose([
    transforms.Resize((args['img_size'],) * 2),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

## Keypoint Widget

In [7]:
import matplotlib.pyplot as plt
import IPython.display as Disp
from ipywidgets import widgets
import numpy as np

In [8]:
class point_selection_widget():
  """Click-to-select keypoints on a single image.

  Shows `im` in a new figure. Every mouse click inside the image axes appends
  ``[x, y]`` (data coordinates, i.e. image pixels) to ``selected_points`` and
  redraws the image with all selected points in red. The "End Point Selection"
  button stops listening for clicks. Needs an interactive backend such as
  ``%matplotlib widget``.

  Args:
    im: image accepted by ``imshow`` that provides ``.copy()`` (e.g. a NumPy array).
  """
  def __init__(self, im):
    self.im = im
    self.selected_points = []
    self.fig, self.ax = plt.subplots()
    self.img = self.ax.imshow(self.im.copy())
    self.ka = self.fig.canvas.mpl_connect('button_press_event', self.onclick)
    disconnect_button = widgets.Button(description="End Point Selection")
    Disp.display(disconnect_button)
    disconnect_button.on_click(self.disconnect_mpl)

  def update_dots(self, img, pts):
    """Redraw `img` and mark every ``[x, y]`` point in `pts` in red."""
    # Keep sub-pixel precision; reshape makes an empty selection a (0, 2) array.
    pts = np.asarray(pts, dtype=np.float64).reshape(-1, 2)
    # Clear first so repeated clicks do not stack one image artist per click.
    self.ax.clear()
    self.img = self.ax.imshow(img)
    self.ax.scatter(pts[:, 0], pts[:, 1], c='red')
    # Ask the interactive canvas to repaint after this event callback.
    self.fig.canvas.draw_idle()

  def onclick(self, event):
    """Matplotlib click handler: record and draw a point if the click hit the image."""
    # Clicks outside the axes have xdata/ydata None and are not keypoints.
    if event.inaxes is not self.ax or event.xdata is None:
      return
    self.selected_points.append([event.xdata, event.ydata])
    # Redraw on every click, including the very first point.
    self.update_dots(self.im.copy(), self.selected_points)

  def disconnect_mpl(self, _):
    """Button callback: stop receiving click events."""
    self.fig.canvas.mpl_disconnect(self.ka)

## Keypoint Colors

In [9]:
# One random RGB colour per keypoint slot, shared by the scatter points and the
# connection lines. A fixed seed keeps the colours identical across re-runs, and
# a local Generator leaves NumPy's global random state untouched.
N_COLORS = 40
COLOR_SEED = 0
colors = list(np.random.default_rng(COLOR_SEED).random((N_COLORS, 3)))

## Run CHM and Visualize the Output

In [10]:
def run_model(imageA_path, imgaeB_path, selected_points, plot_title='CHM Keypoint Transfer Output'):
  """Transfer keypoints from a source to a target image with CHM and plot the matches.

  Args:
    imageA_path: path of the source image.
    imgaeB_path: path of the target image (parameter name kept as-is so existing
      keyword callers keep working).
    selected_points: array-like of shape (N, 2) holding (x, y) pixel coordinates
      on the source image at its original resolution. It is copied, never
      modified in place.
    plot_title: figure title.

  Raises:
    ValueError: if ``selected_points`` is empty or not shaped (N, 2).
  """
  img_size = args['img_size']

  # Load Images
  src_pil_img = Image.open(imageA_path).convert('RGB')
  tgt_pil_img = Image.open(imgaeB_path).convert('RGB')
  # Convert to Tensor
  src_img_tnsr = my_transform(src_pil_img).unsqueeze(0)
  tgt_img_tnsr = my_transform(tgt_pil_img).unsqueeze(0)

  # SRC POINT PREPARATION
  # Work on a float copy so the caller's array is not modified and integer
  # input is not truncated by the rescaling below.
  src_pts_px = np.array(selected_points, dtype=np.float64)
  if src_pts_px.ndim != 2 or src_pts_px.shape[0] == 0 or src_pts_px.shape[1] != 2:
    raise ValueError('selected_points must be a non-empty (N, 2) array of (x, y) pixel coordinates')
  nkpts = src_pts_px.shape[0]

  # Original-image pixels -> pixels of the img_size x img_size resized model input;
  # uses args['img_size'] so it always matches my_transform and Geometry.
  src_w, src_h = src_pil_img.size
  resized_pts = src_pts_px * np.array([img_size / src_w, img_size / src_h])
  keypoints = torch.tensor(resized_pts.T).unsqueeze(0)   # (1, 2, N): row 0 = x, row 1 = y
  n_pts = torch.tensor(np.asarray([nkpts]))  # Must be an Integer Tensor

  # RUN CHM
  with torch.no_grad():
    corr_matrix = model(src_img_tnsr.cuda(), tgt_img_tnsr.cuda())
    prd_kps = Geometry.transfer_kps(corr_matrix, keypoints.cuda(), n_pts.cuda(), normalized=False)

  # VISUALIZATION
  # Source points are drawn exactly where they were given (no lossy round trip).
  src_points_converted = src_pts_px
  # NOTE(review): assumes transfer_kps returns target coordinates normalised to
  # [-1, 1] (x in row 0, y in row 1) -- confirm against Geometry.transfer_kps.
  tgt_points = prd_kps[0].squeeze(0).squeeze(0).cpu().numpy()
  tgt_w, tgt_h = tgt_pil_img.size
  tgt_points_converted = np.stack([(tgt_points[0, :nkpts] + 1) / 2.0 * tgt_w,
                                   (tgt_points[1, :nkpts] + 1) / 2.0 * tgt_h], axis=1)

  # Cycle through the palette so more keypoints than colours do not crash.
  point_colors = [colors[i % len(colors)] for i in range(nkpts)]

  # PLOT
  fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))
  panels = ((ax[0], src_pil_img, src_points_converted, 'Source'),
            (ax[1], tgt_pil_img, tgt_points_converted, 'Target'))
  for panel_ax, pil_img, pts, title in panels:
    panel_ax.imshow(pil_img)
    panel_ax.scatter(pts[:, 0], pts[:, 1], c=point_colors)
    panel_ax.set_title(title)
    panel_ax.set_xticks([])
    panel_ax.set_yticks([])

  # A line from every source keypoint to its predicted location in the target.
  for i in range(nkpts):
    con = ConnectionPatch(xyA=src_points_converted[i],
                          xyB=tgt_points_converted[i],
                          coordsA="data",
                          coordsB="data",
                          axesA=ax[0], axesB=ax[1], color=point_colors[i])
    ax[1].add_artist(con)

  fig.suptitle(plot_title, fontsize=16)
  plt.show()

## Load Images

In [11]:
# Pre-load the sample images. Each file is read fully and closed right away;
# PIL's Image.open is lazy and would otherwise keep one file handle per image open.
# Sorted so the list order does not depend on the filesystem.
image_paths = sorted(glob('sample_images/n02119022/*.jpeg'))
images = []
for path in image_paths:
  with Image.open(path) as pil_img:
    images.append(pil_img.copy())

## Choose Source Points

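An interactive widget for choosing keypoints on the source image. Pick a source and a target image from the dropdowns, click keypoints on the source image, then press **>> Transfer Points <<** to see where CHM maps them in the target image.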

In [12]:
class point_transfer_demo():
  """Interactive CHM keypoint-transfer demo over a folder of images.

  Shows a source and a target image side by side, each chosen from a dropdown
  listing ``{folder_path}/*/*.jpeg``. Clicks on the source panel are recorded in
  ``selected_points`` as ``[x, y]`` source-image pixel coordinates. The
  ">> Transfer Points <<" button runs CHM on the current pair and connects each
  source point to its predicted location in the target. "Clear Points" drops the
  recorded clicks, and "End Point Selection" stops listening for clicks.
  Needs an interactive matplotlib backend (``%matplotlib widget``).

  Args:
    folder_path: directory whose sub-directories contain the ``.jpeg`` images.

  Raises:
    ValueError: if no ``.jpeg`` image is found under ``folder_path``.
  """
  def __init__(self, folder_path):
    # Sorted so the dropdown order does not depend on the filesystem.
    self.image_paths = sorted(glob(f'{folder_path}/*/*.jpeg'))
    if not self.image_paths:
      raise ValueError(f'No .jpeg images found in the sub-directories of {folder_path!r}')
    self.fig, self.axes = plt.subplots(1, 2, figsize=(12, 6))
    self.selected_points = []
    self.ka = self.fig.canvas.mpl_connect('button_press_event', self.onclick)
    # Mirror the dropdowns' initial value. Observers only fire on *changes*, so
    # empty paths would make the first click or transfer fail on Image.open('').
    self.source_path = self.image_paths[0]
    self.target_path = self.image_paths[0]
    self._show_image(self.axes[0], self.source_path)
    self._show_image(self.axes[1], self.target_path)
    self.fig.tight_layout()

    # Drop Down
    source_dropdown = widgets.Dropdown(
        options=self.image_paths,
        value=self.source_path,
        description='Source',
    )
    target_dropdown = widgets.Dropdown(
        options=self.image_paths,
        value=self.target_path,
        description='Target',
    )
    source_dropdown.observe(self.on_change_source, names='value')
    target_dropdown.observe(self.on_change_target, names='value')
    Disp.display(widgets.HBox([source_dropdown, target_dropdown]))

    # Buttons
    transfer_btn = widgets.Button(description=">> Transfer Points <<")
    transfer_btn.on_click(self.transfer_points)
    clear_btn = widgets.Button(description="Clear Points")
    clear_btn.on_click(self.clear_selected_points)
    disconnect_button = widgets.Button(description="End Point Selection")
    disconnect_button.on_click(self.disconnect_mpl)
    Disp.display(widgets.HBox([transfer_btn, clear_btn, disconnect_button]))

  def on_change_source(self, change):
    """Source-dropdown observer: show the new source image and drop old clicks."""
    if change['type'] == 'change' and change['name'] == 'value':
      print("changed to %s" % change['new'])
      self.source_path = change['new']
      # Recorded clicks are pixels of the previous source image, so discard them.
      self.selected_points = []
      self._show_image(self.axes[0], self.source_path)
      self.fig.tight_layout()
      self.fig.canvas.draw_idle()

  def on_change_target(self, change):
    """Target-dropdown observer: show the newly selected target image."""
    if change['type'] == 'change' and change['name'] == 'value':
      print("changed to %s" % change['new'])
      self.target_path = change['new']
      self._show_image(self.axes[1], self.target_path)
      self.fig.tight_layout()
      self.fig.canvas.draw_idle()

  def update_dots(self, pts, event):
    """Redraw the source panel with every ``[x, y]`` in `pts` marked in red.

    Does nothing unless `event` happened on the source panel.
    """
    if event.inaxes is not self.axes[0]:
      return
    pts = np.asarray(pts, dtype=np.float64).reshape(-1, 2)
    self._show_image(self.axes[0], self.source_path)
    self.axes[0].scatter(pts[:, 0], pts[:, 1], c='red')
    self.fig.canvas.draw_idle()

  def onclick(self, event):
    """Matplotlib click handler: record a keypoint for clicks on the source panel."""
    # Clicks on the target panel or outside any axes (xdata is None) are not
    # source keypoints and must not end up in selected_points.
    if event.inaxes is not self.axes[0] or event.xdata is None:
      return
    self.selected_points.append([event.xdata, event.ydata])
    # Redraw on every click, including the very first point.
    self.update_dots(self.selected_points, event)

  def disconnect_mpl(self, _):
    """Button callback: stop receiving click events."""
    self.fig.canvas.mpl_disconnect(self.ka)

  def transfer_points(self, _):
    """Button callback: run CHM on the current pair and draw source -> target matches."""
    if not self.selected_points:
      print('No keypoints selected: click on the source image first.')
      return
    img_size = args['img_size']

    # Load Images
    src_pil_img = self._load_rgb(self.source_path)
    tgt_pil_img = self._load_rgb(self.target_path)
    # Convert to Tensor
    src_img_tnsr = my_transform(src_pil_img).unsqueeze(0)
    tgt_img_tnsr = my_transform(tgt_pil_img).unsqueeze(0)

    # SRC POINT PREPARATION: original-image pixels -> pixels of the
    # img_size x img_size resized model input (matches my_transform / Geometry).
    src_pts_px = np.asarray(self.selected_points, dtype=np.float64)   # (N, 2) of (x, y)
    nkpts = src_pts_px.shape[0]
    src_w, src_h = src_pil_img.size
    resized_pts = src_pts_px * np.array([img_size / src_w, img_size / src_h])
    keypoints = torch.tensor(resized_pts.T).unsqueeze(0)   # (1, 2, N): row 0 = x, row 1 = y
    n_pts = torch.tensor(np.asarray([nkpts]))  # Must be an Integer Tensor

    # RUN CHM
    with torch.no_grad():
      corr_matrix = model(src_img_tnsr.cuda(), tgt_img_tnsr.cuda())
      prd_kps = Geometry.transfer_kps(corr_matrix, keypoints.cuda(), n_pts.cuda(), normalized=False)

    # VISUALIZATION
    # NOTE(review): assumes transfer_kps returns target coordinates normalised to
    # [-1, 1] (x in row 0, y in row 1) -- confirm against Geometry.transfer_kps.
    tgt_points = prd_kps[0].squeeze(0).squeeze(0).cpu().numpy()
    tgt_w, tgt_h = tgt_pil_img.size
    tgt_points_converted = np.stack([(tgt_points[0, :nkpts] + 1) / 2.0 * tgt_w,
                                     (tgt_points[1, :nkpts] + 1) / 2.0 * tgt_h], axis=1)
    # Cycle through the palette so more keypoints than colours do not crash.
    point_colors = [colors[i % len(colors)] for i in range(nkpts)]

    panels = ((self.axes[0], src_pil_img, src_pts_px, 'Source'),
              (self.axes[1], tgt_pil_img, tgt_points_converted, 'Target'))
    for ax, pil_img, pts, title in panels:
      ax.clear()
      ax.imshow(pil_img)
      ax.scatter(pts[:, 0], pts[:, 1], c=point_colors)
      ax.set_title(title)
      ax.set_xticks([])
      ax.set_yticks([])

    # A line from every source keypoint to its predicted location in the target.
    for i in range(nkpts):
      con = ConnectionPatch(xyA=src_pts_px[i],
                            xyB=tgt_points_converted[i],
                            coordsA="data",
                            coordsB="data",
                            axesA=self.axes[0], axesB=self.axes[1], color=point_colors[i])
      self.axes[1].add_artist(con)

    self.fig.suptitle('CHM TRANSFER DEMO \n (ImageNetV2)', fontsize=16)
    self.fig.canvas.draw_idle()

  def clear_selected_points(self, _):
    """Button callback: forget all clicks and redraw the bare source image."""
    self.selected_points = []
    self._show_image(self.axes[0], self.source_path)
    self.fig.canvas.draw_idle()

  @staticmethod
  def _load_rgb(path):
    """Open the image at `path`, return an RGB copy and close the file."""
    with Image.open(path) as pil_img:
      return pil_img.convert('RGB')

  def _show_image(self, ax, path):
    """Clear `ax` and show the image at `path`, titled with its file name, without ticks."""
    import os  # local: only needed to build the title
    ax.clear()
    ax.imshow(self._load_rgb(path))
    ax.set_title(f'Selected image: {os.path.basename(path)}')
    ax.set_xticks([])
    ax.set_yticks([])

In [13]:
# Launch the interactive demo over every ./sample_images/<sub-folder>/*.jpeg file.
demo = point_transfer_demo(folder_path='./sample_images/')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

HBox(children=(Dropdown(description='Source', options=('./sample_images/n02119022/59e819081634950f15287f6d584e…

HBox(children=(Button(description='>> Transfer Points <<', style=ButtonStyle()), Button(description='Clear Poi…

In [14]:
# Inspect the [x, y] click coordinates the demo has recorded so far; the list
# is empty until points are clicked in the figure above.
demo.selected_points

[]