In [None]:
%load_ext autoreload
%autoreload 2
import os
# GPU selection. These variables are set here, ahead of the `import tensorflow`
# further down in this cell, so they are already in place when the GPU runtime
# starts up. PCI_BUS_ID ordering makes index "0" refer to the device with the
# lowest PCI bus id.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
import cv2
import random
import numpy as np
import copy
from scipy.spatial import KDTree
import scipy.io as sio
from scipy.spatial.distance import cdist
import scipy
import scipy.stats as stats
from spatial_net import *
import tensorflow as tf
from tensorflow.python.ops.gen_math_ops import *
from tf_dropblock.nets.dropblock import DropBlock2D
from readdata import InputData
from evaluation import *
import matplotlib.pyplot as plt
import seaborn as sns


In [None]:
# Dataset wrapper shared by every cell below (batch scanning, UTM coordinates,
# nearest-neighbour lookup, validation/test split sizes).
# NOTE(review): the meaning of the positional argument 50 is defined in
# readdata.InputData and is not visible here -- it presumably matches the
# radius-50 circles drawn in the plots below; confirm.
input_data = InputData(50)

In [None]:
# Evaluate "our" model: rebuild the SAFA graph, restore the checkpoint, embed
# every satellite/ground pair, then score the validation queries.
# Produces `dist_array`, which later cells use for the heat maps.
network_type = "SAFA_8"
is_training = False
batch_size = 32
tf.reset_default_graph()  # start from an empty graph each time the cell runs

# Graph inputs: satellite images, ground images, and the dropout keep rate.
sat_x = tf.placeholder(tf.float32, [None, 256, 256, 3], name='sat_x')
grd_x = tf.placeholder(tf.float32, [None, 154, 231, 3], name='grd_x')
keep_prob = tf.placeholder(tf.float32)

# The trailing digit of network_type is forwarded to SAFA as `dimension`.
dimension = int(network_type[-1])
sat_global, grd_global = SAFA(sat_x, grd_x, keep_prob, dimension, is_training)

# One descriptor row per dataset entry, filled batch by batch further down.
out_channel = sat_global.get_shape().as_list()[-1]
sat_global_descriptor = np.zeros((input_data.get_full_dataset_size(), out_channel))
grd_global_descriptor = np.zeros((input_data.get_full_dataset_size(), out_channel))

saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)

print('run model...')
config = tf.ConfigProto(
    log_device_placement=False,
    allow_soft_placement=True,
    gpu_options=tf.GPUOptions(allow_growth=True,
                              per_process_gpu_memory_fraction=1),
)
with tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer())

    print('load model...')

    load_model_path = './model.ckpt'  # path to our model's checkpoint

    saver.restore(sess, load_model_path)

    print("   Model loaded from: %s" % load_model_path)
    print('load model...FINISHED')

    print('validate...')
    print('   compute global descriptors')
    input_data.reset_scan()

    # val_i is the row at which the next batch of descriptors is written.
    val_i = 0
    while True:
        print('      progress %d' % val_i)
        batch_sat, batch_grd, _ = input_data.next_batch_scan(batch_size)
        if batch_sat is None:
            # The scanner signals the end of the dataset with None.
            break

        feed_dict = {sat_x: batch_sat, grd_x: batch_grd, keep_prob: 1.0}
        sat_global_val, grd_global_val = sess.run(
            [sat_global, grd_global], feed_dict=feed_dict)

        sat_global_descriptor[val_i: val_i + len(sat_global_val)] = sat_global_val
        grd_global_descriptor[val_i: val_i + len(grd_global_val)] = grd_global_val
        val_i += len(sat_global_val)

    # The first valNum ground descriptors are the validation queries; the rows
    # after the validation and test entries are kept as the training-set part.
    grd_global_descriptor_to_use = grd_global_descriptor[:input_data.valNum]

    grd_global_descriptor_to_use_trainingset = \
        grd_global_descriptor[input_data.valNum + input_data.testNum:]

    print('   compute accuracy')
    # Rows index satellite images, columns index validation ground queries.
    # NOTE(review): 2 - 2*a.b is a squared Euclidean distance only if the
    # descriptors are L2-normalised -- presumably SAFA does this; confirm.
    dist_array = 2 - 2 * np.matmul(sat_global_descriptor,
                                   grd_global_descriptor_to_use.T)
    print('dist_array shape', dist_array.shape)

    # Column i holds the score for validate(..., i, ...); column 0 stays zero.
    val_accuracy_global = np.zeros((1, 11))
    val_accuracy_local = np.zeros((1, 11))
    print('start')
    for i in range(1, 11):
        val_accuracy_global[0, i] = validate(dist_array, i, input_data)
        val_accuracy_local[0, i] = validate_local(dist_array, i, input_data)
    print( 'val global accuracy =', val_accuracy_global * 100.0)
    print( 'val local accuracy = ', val_accuracy_local * 100.0)

In [None]:
# Evaluate the baseline model with the same pipeline as the previous cell:
# rebuild the SAFA graph, restore a checkpoint, embed the whole dataset and
# score the validation queries. Produces `dist_array_baseline`.
network_type = "SAFA_8"
is_training = False
batch_size = 32
tf.reset_default_graph()  # start from an empty graph each time the cell runs

# `input_data` is reused from the data-loading cell at the top of the notebook.

# define placeholders
sat_x = tf.placeholder(tf.float32, [None, 256, 256, 3], name='sat_x')
grd_x = tf.placeholder(tf.float32, [None, 154, 231, 3], name='grd_x')

keep_prob = tf.placeholder(tf.float32)

# build model; the trailing digit of network_type is forwarded as `dimension`
dimension = int(network_type[-1])
sat_global, grd_global = SAFA(sat_x, grd_x, keep_prob, dimension, is_training)

# One descriptor row per dataset entry, filled batch by batch further down.
out_channel = sat_global.get_shape().as_list()[-1]
sat_global_descriptor = np.zeros([input_data.get_full_dataset_size(), out_channel])
grd_global_descriptor = np.zeros([input_data.get_full_dataset_size(), out_channel])


saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)

# run model
print('run model...')
config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 1
with tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer())

    print('load model...')

    # NOTE(review): this is the same path the previous cell restores for "our"
    # model. Point it at the baseline checkpoint, otherwise both cells evaluate
    # identical weights and the comparison plots below are meaningless.
    load_model_path = './model.ckpt' # path to baseline model

    saver.restore(sess, load_model_path)

    print("   Model loaded from: %s" % load_model_path)
    print('load model...FINISHED')

    print('validate...')
    print('   compute global descriptors')
    input_data.reset_scan()

    # val_i is the row at which the next batch of descriptors is written.
    val_i = 0
    while True:
        print('      progress %d' % val_i)
        batch_sat, batch_grd, _ = input_data.next_batch_scan(batch_size)
        if batch_sat is None:
            # The scanner signals the end of the dataset with None.
            break
        feed_dict = {sat_x: batch_sat, grd_x: batch_grd, keep_prob: 1.0}
        sat_global_val, grd_global_val = \
            sess.run([sat_global, grd_global], feed_dict=feed_dict)

        sat_global_descriptor[val_i: val_i + sat_global_val.shape[0], :] = sat_global_val
        grd_global_descriptor[val_i: val_i + grd_global_val.shape[0], :] = grd_global_val
        val_i += sat_global_val.shape[0]

    # The first valNum ground descriptors are the validation queries.
    grd_global_descriptor_to_use = grd_global_descriptor[0:input_data.valNum,:]

    print('   compute accuracy')
    # Rows index satellite images, columns index validation ground queries.
    dist_array_baseline = 2 - 2 * np.matmul(sat_global_descriptor, np.transpose(grd_global_descriptor_to_use))
    # Report the matrix computed in THIS cell; printing `dist_array` would show
    # the other model's matrix (or fail if that cell has not been run).
    print('dist_array shape', np.shape(dist_array_baseline))

    # Column i holds the score for validate(..., i, ...); column 0 stays zero.
    val_accuracy_global = np.zeros((1, 11))
    val_accuracy_local = np.zeros((1, 11))
    print('start')
    for i in range(1,11):
        val_accuracy_global[0, i] = validate(dist_array_baseline, i, input_data)
        val_accuracy_local[0, i] = validate_local(dist_array_baseline, i, input_data)
    print( 'val global accuracy =', val_accuracy_global * 100.0)
    print( 'val local accuracy = ', val_accuracy_local * 100.0)

In [None]:
# Shared inputs for the heat-map cells below.
# `dist_array` comes from the "our model" evaluation cell and
# `dist_array_baseline` from the baseline evaluation cell.
dist_array_ours = dist_array
fullUTM = input_data.fullUTM  # indexed below as fullUTM[0, :] and fullUTM[1, :]

# exp(-d) turns a distance into a similarity: the smaller the distance, the
# larger (more saturated in the plots) the value.
similarities_ours, similarities_baseline = (
    np.exp(np.negative(distances))
    for distances in (dist_array_ours, dist_array_baseline)
)

In [None]:
query = 936 # Give an index of the ground query


def _plot_heat_map(points, values, center, title, marker, size,
                   crosshair=False, radius=50):
    """Draw one similarity heat map on a new 10x10 figure.

    Parameters
    ----------
    points : array indexed as points[0, :] (x) and points[1, :] (y)
        Coordinates of the locations to draw.
    values : 1-D array, one entry per plotted point
        Similarity used to colour each point ('Reds' colormap).
    center : (x, y)
        Position of the query; a blue circle of `radius` is drawn around it.
    title : str
        Figure title.
    marker, size
        Forwarded to ``Axes.scatter`` as ``marker`` and ``s``.
    crosshair : bool
        When True, dashed horizontal/vertical lines are drawn through `center`.
    radius : float
        Radius of the circle, in the same units as `points`.

    Returns
    -------
    (fig, ax) : the created Matplotlib figure and axes.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.scatter(points[0, :], points[1, :], marker=marker, s=size,
               c=values, cmap='Reds', alpha=1)
    if crosshair:
        ax.axhline(center[1], linestyle='--', color='k', linewidth=1)
        ax.axvline(center[0], linestyle='--', color='k', linewidth=1)
    ax.add_artist(plt.Circle((center[0], center[1]), radius, color='b', fill=False))
    # Coordinates are only meaningful relative to each other: hide tick labels
    # and keep the aspect ratio square so the circle is not distorted.
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.axis('equal')
    ax.set_title(title)
    return fig, ax


# The query is located by indexing fullUTM with the query index.
query_xy = (fullUTM[0, query], fullUTM[1, query])

# Global view: every location in the dataset, coloured by similarity to the query.
fig, ax = _plot_heat_map(fullUTM, similarities_ours[:, query], query_xy,
                         'Our Model Global Heat Map', marker='+', size=2)
fig, ax = _plot_heat_map(fullUTM, similarities_baseline[:, query], query_xy,
                         'Baseline Global Heat Map', marker='+', size=2)

# Local view: only the entries listed for this query in
# input_data.nearest_neighbor (keyed by the query index as a string).
neighbor_idx = input_data.nearest_neighbor[str(query)]
nearby_UTM = fullUTM[:, neighbor_idx]

dist_nearby = dist_array_ours[neighbor_idx, query]
similarities_nearby = similarities_ours[neighbor_idx, query]
fig, ax = _plot_heat_map(nearby_UTM, similarities_nearby, query_xy,
                         'Our Model Local Heat Map', marker='o', size=10,
                         crosshair=True)

dist_nearby_baseline = dist_array_baseline[neighbor_idx, query]
similarities_nearby_baseline = similarities_baseline[neighbor_idx, query]
fig, ax = _plot_heat_map(nearby_UTM, similarities_nearby_baseline, query_xy,
                         'Baseline Local Heat Map', marker='o', size=10,
                         crosshair=True)

plt.show()

In [None]:
# ours
# Zoomed global heat map for our model: all dataset locations, coloured by
# similarity to `query`, shown in a 400x400 window centred on the query.
fig, axes = plt.subplots(figsize=(5,5))

query_x, query_y = fullUTM[0,query], fullUTM[1,query]

axes.scatter(fullUTM[0,:], fullUTM[1,:], marker='o',s=10, c=similarities_ours[:,query], cmap='Reds', alpha=1)
# Blue circle of radius 50 around the query, plus a dashed crosshair through it.
circle = plt.Circle((query_x, query_y), 50, color='b', fill=False)
axes.axhline(query_y, linestyle='--', color='k', linewidth=0.5)
axes.axvline(query_x, linestyle='--', color='k', linewidth=0.5)
axes.add_artist(circle)
axes.set_yticklabels([])
axes.set_xticklabels([])
axes.axis('equal')
axes.set_xlim(query_x-200, query_x+200)
axes.set_ylim(query_y-200, query_y+200)

plt.show()


In [None]:
# baseline
# Zoomed global heat map for the baseline: all dataset locations, coloured by
# similarity to `query`, shown in a 400x400 window centred on the query.
fig, axes = plt.subplots(figsize=(5,5))

query_x, query_y = fullUTM[0,query], fullUTM[1,query]

axes.scatter(fullUTM[0,:], fullUTM[1,:], marker='o',s=10, c=similarities_baseline[:,query], cmap='Reds', alpha=1)
# Blue circle of radius 50 around the query, plus a dashed crosshair through it.
circle = plt.Circle((query_x, query_y), 50, color='b', fill=False)
axes.axhline(query_y, linestyle='--', color='k', linewidth=0.5)
axes.axvline(query_x, linestyle='--', color='k', linewidth=0.5)
axes.add_artist(circle)
axes.set_yticklabels([])
axes.set_xticklabels([])
axes.axis('equal')
axes.set_xlim(query_x-200, query_x+200)
axes.set_ylim(query_y-200, query_y+200)

plt.show()


In [None]:
# ours
# Local heat map for our model: only the entries listed for `query` in
# input_data.nearest_neighbor, shown in a 100x100 window centred on the query.
fig, axes = plt.subplots(figsize=(5,5))
nearby_UTM = fullUTM[:,input_data.nearest_neighbor[str(query)]]
similarities_nearby = similarities_ours[input_data.nearest_neighbor[str(query)],query]

query_x, query_y = fullUTM[0,query], fullUTM[1,query]

axes.scatter(nearby_UTM[0,:], nearby_UTM[1,:], marker='o',s=20, c=similarities_nearby, cmap='Reds', alpha=1)
# Blue circle of radius 50 around the query, plus a dashed crosshair through it.
circle = plt.Circle((query_x, query_y), 50, color='b', fill=False)
axes.axhline(query_y, linestyle='--', color='k', linewidth=0.5)
axes.axvline(query_x, linestyle='--', color='k', linewidth=0.5)
axes.add_artist(circle)
axes.set_yticklabels([])
axes.set_xticklabels([])
axes.axis('equal')
axes.set_xlim(query_x-50, query_x+50)
axes.set_ylim(query_y-50, query_y+50)
plt.show()


In [None]:
# baseline
# Local heat map for the baseline: only the entries listed for `query` in
# input_data.nearest_neighbor, shown in a 100x100 window centred on the query.
fig, axes = plt.subplots(figsize=(5,5))
# Computed here rather than reused from an earlier cell, so this cell gives the
# right picture even if `query` was changed and only this cell is re-run.
nearby_UTM = fullUTM[:,input_data.nearest_neighbor[str(query)]]
similarities_nearby_baseline = similarities_baseline[input_data.nearest_neighbor[str(query)],query]

query_x, query_y = fullUTM[0,query], fullUTM[1,query]

axes.scatter(nearby_UTM[0,:], nearby_UTM[1,:], marker='o',s=20, c=similarities_nearby_baseline, cmap='Reds', alpha=1)
# Blue circle of radius 50 around the query, plus a dashed crosshair through it.
circle = plt.Circle((query_x, query_y), 50, color='b', fill=False)
axes.axhline(query_y, linestyle='--', color='k', linewidth=0.5)
axes.axvline(query_x, linestyle='--', color='k', linewidth=0.5)
axes.add_artist(circle)
axes.set_yticklabels([])
axes.set_xticklabels([])
axes.axis('equal')
axes.set_xlim(query_x-50, query_x+50)
axes.set_ylim(query_y-50, query_y+50)
plt.show()
