# SOLAR: Second-Order Loss and Attention for Image Retrieval

## Image matching visulisation of `solar_local` with `opencv`

Below we show you a quick example of image matching performance, and compare with the baseline `SOSNet` 

In [None]:
import os
%matplotlib notebook
from matplotlib import pyplot as plt
import numpy as np
import cv2
import torch
torch.no_grad()

## Initialise SOSNet and SOLAR

First download [SOSNet](https://github.com/scape-research/SOSNet) weights

In [None]:
!wget https://github.com/scape-research/SOSNet/raw/master/sosnet-weights/sosnet-32x32-liberty.pth -O /media/anlab/0e731fe3-5959-4d40-8958-e9f6296b38cb/home/anlab/songuyen/SOLAR/solar_local/weights/sosnet-32x32-liberty.pth

Now load the networks and send them to cuda

In [None]:
import sys
sys.path.insert(0,'..')

from solar_local.models.model import SOLAR_LOCAL, SOSNet32x32
from solar_local.utils import describe_opencv

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# load the baseline SOSNet
sosnet = SOSNet32x32()
sosnet.load_state_dict(torch.load('/media/anlab/0e731fe3-5959-4d40-8958-e9f6296b38cb/home/anlab/songuyen/SOLAR/solar_local/weights/sosnet-32x32-liberty.pth'))
sosnet = sosnet.to(device).eval()

# load SOLAR
solar_local = SOLAR_LOCAL(soa=True, soa_layers='345')
solar_local.load_state_dict(torch.load('/media/anlab/0e731fe3-5959-4d40-8958-e9f6296b38cb/home/anlab/songuyen/SOLAR/solar_local/weights/local-solar-345-liberty.pth'))
solar_local = solar_local.to(device).eval()

In [None]:
print(solar_local)

## Read images, detect keypoints and describe using SOSNet & SOLAR

In [None]:
# Load the images and detect BRISK keypoints using openCV
img1 = cv2.imread('/media/anlab/0e731fe3-5959-4d40-8958-e9f6296b38cb/home/anlab/songuyen/SOLAR/demo/22.jpg',0)
img1 = cv2.resize(img1, (600, 600))
img2 = cv2.imread('/media/anlab/0e731fe3-5959-4d40-8958-e9f6296b38cb/home/anlab/songuyen/SOLAR/demo/22_23_24.jpg',0)
img2 = cv2.resize(img2, (600, 600))

brisk = cv2.BRISK_create(100)
kp1 = brisk.detect(img1, None)
kp2 = brisk.detect(img2, None)

# We use the tfeat_utils methods that rectify patches around openCV keypoints and 
sosnet_desc_1 = describe_opencv(sosnet, img1, kp1, patch_size=32, mag_factor=3)
sosnet_desc_2 = describe_opencv(sosnet, img2, kp2, patch_size=32, mag_factor=3)

solar_desc_1 = describe_opencv(solar_local, img1, kp1, patch_size=32, mag_factor=3)
solar_desc_2 = describe_opencv(solar_local, img2, kp2, patch_size=32, mag_factor=3)

In [None]:
solar_desc_2.shape

In [None]:
sco = np.dot(solar_desc_1, solar_desc_2.T)
sco


## Brute-force matching with `opencv` 

In [None]:
bf = cv2.BFMatcher(cv2.NORM_L2)
matches_sosnet = bf.knnMatch(sosnet_desc_1, sosnet_desc_2, k=2)
matches_solar = bf.knnMatch(solar_desc_1, solar_desc_2, k=2)

Here, we use the same ratio for both descriptors (0.8, same as that from [here](https://github.com/scape-research/SOSNet/blob/master/SOSNet-demo.ipynb)) in the ratio test

In [None]:
# Apply SIFT's ratio test, notice that 0.8 may not be the best ratio for SOSNet
good_sosnet = []
good_solar = []
ratio = 0.8

for m,n in matches_sosnet:
    if m.distance < ratio*n.distance:
        good_sosnet.append([m])
        
for m,n in matches_solar:
    if m.distance < ratio*n.distance:
        good_solar.append([m])
       
print('Number of matches: SOSNet {} | SOLAR {}'.format(len(good_sosnet), len(good_solar)))

In [None]:
good_solar

You can see that SOLAR gives 15% more matches in this simple example

## Visualising

Now we plot the matches between the two images, for both descriptors, and save them as images

In [None]:
sosnet_img_matches = cv2.drawMatchesKnn(img1, kp1, img2, kp2, good_sosnet, 0, flags=2)
solar_img_matches = cv2.drawMatchesKnn(img1, kp1, img2, kp2, good_solar, 0, flags=2)

# save images
cv2.imwrite("sosnet-matches.png", sosnet_img_matches)
cv2.imwrite("solar-matches.png", solar_img_matches)

We can also see the results with `matplotlib` here

In [None]:
fig = plt.figure(dpi=200)
ax_sosnet = fig.add_subplot(211)
ax_solar = fig.add_subplot(212)

for ax in [ax_sosnet, ax_solar]:
    ax.set_xticks([])
    ax.set_yticks([])

# ax_sosnet.imshow(sosnet_img_matches)
# ax_sosnet.set_title("SOSNET: Num Matches={}".format(len(good_sosnet)))
ax_solar.imshow(solar_img_matches)
ax_solar.set_title("SOLAR: Num Matches={}".format(len(good_solar)))

plt.tight_layout()
plt.show()

In [None]:
cd /media/anlab/0e731fe3-5959-4d40-8958-e9f6296b38cb/home/anlab/songuyen/data_pill_1912/SOLAR

If you look closely, SOLAR actually avoids an obvious wrong match that is present in SOSNet's example (top right of the left image matched to the right-end of the building in the right image)

In [10]:
import argparse
import os
import time
import pickle
import pdb
from tqdm import tqdm
import math

import cv2

import numpy as np

from PIL import Image
import matplotlib
matplotlib.use('Agg')

import torch
import torch.nn.functional as F

from solar_global.networks.imageretrievalnet import init_network
from solar_global.datasets.datahelpers import default_loader 
from solar_global.utils.networks import load_network
from solar_global.utils.plots import draw_soa_map


MODEL = 'resnet101-solar-best.pth'
IMSIZE = 1024

parser = argparse.ArgumentParser()
parser.add_argument("-i", "--image", default="assets/グループ1①_1_3.png", help="Path to the image")
args = parser.parse_args()

def nothing(x):
    pass


# initialize the list of reference points and boolean indicating
# whether cropping is being performed or not
refPt = []
cropping = False
 

def click_and_crop(event, x, y, flags, param):
	# grab references to the global variables
	global refPt, cropping
 
	# if the left mouse button was clicked, record the starting
	# (x, y) coordinates and indicate that cropping is being
	# performed
	if event == cv2.EVENT_LBUTTONDOWN:
		refPt = [(x, y)]
		cropping = True
 
	# check to see if the left mouse button was released
	elif event == cv2.EVENT_LBUTTONUP:
		# record the ending (x, y) coordinates and indicate that
		# the cropping operation is finished
		refPt.append((x, y))
		cropping = False
 
		# draw a rectangle around the region of interest
		cv2.rectangle(image, refPt[0], refPt[1], (0, 255, 0), 2)
		cv2.imshow("image", image)


# loading network
net = load_network(network_name=MODEL)

print(">>>> loaded network: ")
print(net.meta_repr())

# moving network to gpu and eval mode
net.cuda() 
net.eval() 

# load the image, clone it, and setup the mouse callback function
image = cv2.imread(args.image)
h, w = image.shape[0], image.shape[1]
if (h <= w):
    resize = (int(w * IMSIZE/h), IMSIZE)
else:
    resize = (IMSIZE, int(h * IMSIZE/w))
image = cv2.resize(image, resize)
clone = image.copy()
# cv2.namedWindow("image")
cv2.setMouseCallback("image", click_and_crop)


while True: 
    # display the image and wait for a keypress
    cv2.imshow("image", image)
    key = cv2.waitKey(20) #& 0xFF
 
	# if the 'r' key is pressed, reset the cropping region
    if key == ord("r"):
        image = clone.copy()
	# if the 'c' key is pressed, break from the loop       
    elif key == ord("q"):
        print("Exit")
        cv2.destroyAllWindows
        break
    if cv2.getWindowProperty('image', cv2.WND_PROP_VISIBLE) < 1:
        break
 
    # if there are two reference points, then crop the region of interest
    # from the image and display it
    if len(refPt) == 2:
        # display soa
        soa = draw_soa_map(default_loader(args.image), net, refPt)
        cv2.imshow("Second order attention", soa)
        cv2.waitKey(20)

# close all open windows
cv2.destroyAllWindows()

usage: ipykernel_launcher.py [-h] [-i IMAGE]
ipykernel_launcher.py: error: unrecognized arguments: --ip=127.0.0.1 --stdin=9003 --control=9001 --hb=9000 --Session.signature_scheme="hmac-sha256" --Session.key=b"3f119f99-7f0f-4a75-a418-9bee49352c4d" --shell=9002 --transport="tcp" --iopub=9004 --f=/tmp/tmp-3319rBQhB2FdQLof.json


SystemExit: 2

In [None]:
!pip install --ignore-installed --upgrade jupyter

In [None]:
!pip install ipykernel
!python -m ipykernel install

!conda install notebook ipykernel
!ipython kernelspec install-self