In [None]:
%load_ext autoreload
%autoreload 2
import os
# Order GPUs by PCI bus ID and expose only device "0" to this process.
# Both variables are set before `import tensorflow` further down in this cell;
# presumably they must be in place before the framework first touches CUDA,
# so keep these lines above that import.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
import numpy as np
from spatial_net import *
import tensorflow as tf
from tensorflow.python.ops.gen_math_ops import *
from tf_dropblock.nets.dropblock import DropBlock2D
from readdata import InputData
from evaluation import *


In [None]:
# Dataset reader used by the evaluation cells in this notebook (scanned
# batch-by-batch via reset_scan() / next_batch_scan()).
# NOTE(review): the meaning of the argument 50 is defined in readdata.InputData
# and is not visible here -- confirm before changing it.
input_data = InputData(50)

### our model on 3 test traversals

In [None]:
# Evaluate the trained checkpoint on the three test traversals:
#   1. rebuild the SAFA graph in inference mode,
#   2. restore the checkpoint weights,
#   3. scan the dataset once to collect satellite / ground descriptors,
#   4. score each traversal with validate() and validate_local().

# ---- settings ---------------------------------------------------------------
network_type = "SAFA_8"
is_training = False
batch_size = 32

# Start from an empty default graph so that re-running this cell does not add
# a second copy of the network to the graph left behind by a previous run.
tf.reset_default_graph()

# ---- graph ------------------------------------------------------------------
# Input placeholders; the leading None leaves the batch dimension open.
sat_x = tf.placeholder(tf.float32, [None, 256, 256, 3], name='sat_x')
grd_x = tf.placeholder(tf.float32, [None, 154, 231, 3], name='grd_x')

keep_prob = tf.placeholder(tf.float32)

# The trailing digit of network_type ("8" here) is handed to SAFA as `dimension`.
dimension = int(network_type[-1])
sat_global, grd_global = SAFA(sat_x, grd_x, keep_prob, dimension, is_training)

# Host-side buffers, one row per dataset sample, as wide as the static last
# dimension of the SAFA output.
out_channel = sat_global.get_shape().as_list()[-1]
sat_global_descriptor = np.zeros([input_data.get_full_dataset_size(), out_channel])
grd_global_descriptor = np.zeros([input_data.get_full_dataset_size(), out_channel])

saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)

# ---- run --------------------------------------------------------------------
print('run model...')
config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 1
with tf.Session(config=config) as sess:
    print('load model...')

    load_model_path = '/local/zxia/checkpoints/safa/Model/Oxford/sigma/10_noL2_bs64/980/model.ckpt' # the path to the model

    # No tf.global_variables_initializer() beforehand: `saver` was built over
    # tf.global_variables(), so restore() assigns every one of them and any
    # initial values would be overwritten immediately.
    saver.restore(sess, load_model_path)

    print("   Model loaded from: %s" % load_model_path)
    print('load model...FINISHED')

    print('validate...')
    print('   compute global descriptors')
    input_data.reset_scan()

    # val_i is the number of rows written so far, i.e. the next free row.
    val_i = 0
    while True:
        print('      progress %d' % val_i)
        batch_sat, batch_grd, _ = input_data.next_batch_scan(batch_size)
        if batch_sat is None:
            # next_batch_scan signals the end of the scan with a None batch.
            break
        # keep_prob is fed as 1.0 (presumably this disables dropout inside
        # SAFA -- confirm in spatial_net).
        feed_dict = {sat_x: batch_sat, grd_x: batch_grd, keep_prob: 1.0}
        sat_global_val, grd_global_val = sess.run([sat_global, grd_global], feed_dict=feed_dict)

        # Slice by the returned batch size so a short final batch still fits.
        sat_global_descriptor[val_i: val_i + sat_global_val.shape[0], :] = sat_global_val
        grd_global_descriptor[val_i: val_i + grd_global_val.shape[0], :] = grd_global_val
        val_i += sat_global_val.shape[0]

    # Row offsets of the three test traversals inside the scanned arrays. The
    # first valNum rows come before them (presumably the validation split --
    # confirm against readdata).
    tr1_start = input_data.valNum
    tr2_start = tr1_start + input_data.testNum_tr1
    tr3_start = tr2_start + input_data.testNum_tr2
    test_end = input_data.valNum + input_data.testNum

    grd_global_descriptor_tr1 = grd_global_descriptor[tr1_start:tr2_start, :]
    grd_global_descriptor_tr2 = grd_global_descriptor[tr2_start:tr3_start, :]
    grd_global_descriptor_tr3 = grd_global_descriptor[tr3_start:test_end, :]

    print('   compute accuracy')
    # 2 - 2 * <sat, grd> for every (satellite row, ground column) pair.
    # NOTE(review): this equals the squared Euclidean distance only if both
    # descriptors are unit-norm -- confirm that SAFA normalises its outputs.
    dist_array_tr1 = 2 - 2 * np.matmul(sat_global_descriptor, grd_global_descriptor_tr1.T)
    dist_array_tr2 = 2 - 2 * np.matmul(sat_global_descriptor, grd_global_descriptor_tr2.T)
    dist_array_tr3 = 2 - 2 * np.matmul(sat_global_descriptor, grd_global_descriptor_tr3.T)

    # validate / validate_local come from `evaluation`; the second argument (1)
    # and the trailing row offset are passed through as-is (their exact
    # semantics are defined there, not in this notebook).
    test_tr1_accuracy_global = validate(dist_array_tr1, 1, input_data, tr1_start)
    test_tr1_accuracy_local = validate_local(dist_array_tr1, 1, input_data, tr1_start)
    print( 'test traversal 1 global accuracy =', test_tr1_accuracy_global * 100.0)
    print( 'test traversal 1 local accuracy = ', test_tr1_accuracy_local * 100.0)

    test_tr2_accuracy_global = validate(dist_array_tr2, 1, input_data, tr2_start)
    test_tr2_accuracy_local = validate_local(dist_array_tr2, 1, input_data, tr2_start)
    print( 'test traversal 2 global accuracy =', test_tr2_accuracy_global * 100.0)
    print( 'test traversal 2 local accuracy = ', test_tr2_accuracy_local * 100.0)

    test_tr3_accuracy_global = validate(dist_array_tr3, 1, input_data, tr3_start)
    test_tr3_accuracy_local = validate_local(dist_array_tr3, 1, input_data, tr3_start)
    print( 'test traversal 3 global accuracy =', test_tr3_accuracy_global * 100.0)
    print( 'test traversal 3 local accuracy = ', test_tr3_accuracy_local * 100.0)

### baseline model on 3 test traversals

In [None]:
# Evaluate the baseline checkpoint on the three test traversals:
#   1. rebuild the SAFA graph in inference mode,
#   2. restore the baseline checkpoint weights,
#   3. scan the dataset once to collect satellite / ground descriptors,
#   4. score each traversal with validate() and validate_local().

# ---- settings ---------------------------------------------------------------
network_type = "SAFA_8"
is_training = False
batch_size = 32

# Start from an empty default graph so that this cell does not build on top of
# whatever graph an earlier cell (or an earlier run of this cell) left behind.
tf.reset_default_graph()

# ---- graph ------------------------------------------------------------------
# Input placeholders; the leading None leaves the batch dimension open.
sat_x = tf.placeholder(tf.float32, [None, 256, 256, 3], name='sat_x')
grd_x = tf.placeholder(tf.float32, [None, 154, 231, 3], name='grd_x')

keep_prob = tf.placeholder(tf.float32)

# The trailing digit of network_type ("8" here) is handed to SAFA as `dimension`.
dimension = int(network_type[-1])
sat_global, grd_global = SAFA(sat_x, grd_x, keep_prob, dimension, is_training)

# Host-side buffers, one row per dataset sample, as wide as the static last
# dimension of the SAFA output.
out_channel = sat_global.get_shape().as_list()[-1]
sat_global_descriptor = np.zeros([input_data.get_full_dataset_size(), out_channel])
grd_global_descriptor = np.zeros([input_data.get_full_dataset_size(), out_channel])

saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)

# ---- run --------------------------------------------------------------------
print('run model...')
config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 1
with tf.Session(config=config) as sess:
    print('load model...')

    load_model_path = '/local/zxia/checkpoints/safa/Model/Oxford/sigma/0_noL2_baseline_bs64/960/model.ckpt'

    # No tf.global_variables_initializer() beforehand: `saver` was built over
    # tf.global_variables(), so restore() assigns every one of them and any
    # initial values would be overwritten immediately.
    saver.restore(sess, load_model_path)

    print("   Model loaded from: %s" % load_model_path)
    print('load model...FINISHED')

    print('validate...')
    print('   compute global descriptors')
    input_data.reset_scan()

    # val_i is the number of rows written so far, i.e. the next free row.
    val_i = 0
    while True:
        print('      progress %d' % val_i)
        batch_sat, batch_grd, _ = input_data.next_batch_scan(batch_size)
        if batch_sat is None:
            # next_batch_scan signals the end of the scan with a None batch.
            break
        # keep_prob is fed as 1.0 (presumably this disables dropout inside
        # SAFA -- confirm in spatial_net).
        feed_dict = {sat_x: batch_sat, grd_x: batch_grd, keep_prob: 1.0}
        sat_global_val, grd_global_val = sess.run([sat_global, grd_global], feed_dict=feed_dict)

        # Slice by the returned batch size so a short final batch still fits.
        sat_global_descriptor[val_i: val_i + sat_global_val.shape[0], :] = sat_global_val
        grd_global_descriptor[val_i: val_i + grd_global_val.shape[0], :] = grd_global_val
        val_i += sat_global_val.shape[0]

    # Row offsets of the three test traversals inside the scanned arrays. The
    # first valNum rows come before them (presumably the validation split --
    # confirm against readdata).
    tr1_start = input_data.valNum
    tr2_start = tr1_start + input_data.testNum_tr1
    tr3_start = tr2_start + input_data.testNum_tr2
    test_end = input_data.valNum + input_data.testNum

    grd_global_descriptor_tr1 = grd_global_descriptor[tr1_start:tr2_start, :]
    grd_global_descriptor_tr2 = grd_global_descriptor[tr2_start:tr3_start, :]
    grd_global_descriptor_tr3 = grd_global_descriptor[tr3_start:test_end, :]

    print('   compute accuracy')
    # 2 - 2 * <sat, grd> for every (satellite row, ground column) pair.
    # NOTE(review): this equals the squared Euclidean distance only if both
    # descriptors are unit-norm -- confirm that SAFA normalises its outputs.
    dist_array_tr1 = 2 - 2 * np.matmul(sat_global_descriptor, grd_global_descriptor_tr1.T)
    dist_array_tr2 = 2 - 2 * np.matmul(sat_global_descriptor, grd_global_descriptor_tr2.T)
    dist_array_tr3 = 2 - 2 * np.matmul(sat_global_descriptor, grd_global_descriptor_tr3.T)

    # validate / validate_local come from `evaluation`; the second argument (1)
    # and the trailing row offset are passed through as-is (their exact
    # semantics are defined there, not in this notebook).
    test_tr1_accuracy_global = validate(dist_array_tr1, 1, input_data, tr1_start)
    test_tr1_accuracy_local = validate_local(dist_array_tr1, 1, input_data, tr1_start)
    print( 'test traversal 1 global accuracy =', test_tr1_accuracy_global * 100.0)
    print( 'test traversal 1 local accuracy = ', test_tr1_accuracy_local * 100.0)

    test_tr2_accuracy_global = validate(dist_array_tr2, 1, input_data, tr2_start)
    test_tr2_accuracy_local = validate_local(dist_array_tr2, 1, input_data, tr2_start)
    print( 'test traversal 2 global accuracy =', test_tr2_accuracy_global * 100.0)
    print( 'test traversal 2 local accuracy = ', test_tr2_accuracy_local * 100.0)

    test_tr3_accuracy_global = validate(dist_array_tr3, 1, input_data, tr3_start)
    test_tr3_accuracy_local = validate_local(dist_array_tr3, 1, input_data, tr3_start)
    print( 'test traversal 3 global accuracy =', test_tr3_accuracy_global * 100.0)
    print( 'test traversal 3 local accuracy = ', test_tr3_accuracy_local * 100.0)