In [1]:
%load_ext autoreload
%autoreload 2
import os
# GPU selection: order devices by PCI bus id and make only device "0" visible.
# These variables are set here, ahead of the TensorFlow import further down.
os.environ.update(CUDA_DEVICE_ORDER='PCI_BUS_ID', CUDA_VISIBLE_DEVICES="0")
import numpy as np
from spatial_net import *
import tensorflow as tf
from tensorflow.python.ops.gen_math_ops import *
from tf_dropblock.nets.dropblock import DropBlock2D
from readdata import InputData
from evaluation import *

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
# ---- model / run mode ----
network_type = "SAFA_8"          # numeric suffix is parsed by train() and passed to SAFA()
is_training = True               # forwarded to SAFA() when the graph is built

# ---- schedule ----
start_epoch = 0
number_of_epoch = 1000           # upper bound of the epoch loop in train()
batch_size = 64                  # also fixes the shape of the useful-pair placeholders

# ---- optimisation / regularisation ----
learning_rate_val = 5e-5         # fed to the Adam learning-rate placeholder
keep_prob_val = 0.8              # fed to the `keep_prob` placeholder while training
keep_prob_dropblock_val = 0.8    # fed to the DropBlock2D `keep_prob` placeholder

# -------------------------------------------------------- #

def train(start_epoch=0, radius=50):
    '''
    Build the SAFA graph, train it, and validate on every 10th epoch.

    Hyper-parameters (network_type, batch_size, is_training, number_of_epoch,
    learning_rate_val, keep_prob_val, keep_prob_dropblock_val) are read from
    module-level globals.

    :param start_epoch: the epoch id start to train. The first epoch is 0.
        With 0, weights are initialized from the initialization checkpoint;
        otherwise the checkpoint saved for epoch ``start_epoch - 1`` is restored.
    :param radius: value forwarded to ``InputData``
        (presumably the coarse localization prior radius -- confirm in readdata).
    :return: None. Progress and validation accuracy are printed.
    '''
    # import data
    print('radius', radius)
    input_data = InputData(radius)

    # define placeholders
    sat_x = tf.placeholder(tf.float32, [None, 256, 256, 3], name='sat_x')
    grd_x = tf.placeholder(tf.float32, [None, 154, 231, 3], name='grd_x')
    useful_pair_s2g_op = tf.placeholder(tf.float32, [batch_size, batch_size], name='useful_pair_s2g_op')
    useful_pair_g2s_op = tf.placeholder(tf.float32, [batch_size, batch_size], name='useful_pair_g2s_op')
    utms_x = tf.placeholder(tf.float32, [None, None], name='utms')

    keep_prob = tf.placeholder(tf.float32)
    keep_prob_dropblock = tf.placeholder(tf.float32)
    learning_rate = tf.placeholder(tf.float32)
    training = tf.placeholder(tf.bool)

    # DropBlock is applied directly to the input images of both branches.
    drop_block_sat = DropBlock2D(keep_prob=keep_prob_dropblock, block_size=15)
    drop_block_grd = DropBlock2D(keep_prob=keep_prob_dropblock, block_size=15)
    sat_x_drop = drop_block_sat(sat_x, training)
    grd_x_drop = drop_block_grd(grd_x, training)

    # build model
    # Parse the complete numeric suffix of network_type (e.g. "SAFA_8" -> 8,
    # "SAFA_16" -> 16) instead of only its last character.
    prefix_len = len(network_type.rstrip('0123456789'))
    dimension = int(network_type[prefix_len:])
    sat_global, grd_global = SAFA(sat_x_drop, grd_x_drop, keep_prob, dimension, is_training)

    # Host-side buffers that receive the descriptors computed during validation.
    out_channel = sat_global.get_shape().as_list()[-1]
    sat_global_descriptor = np.zeros([input_data.get_full_dataset_size(), out_channel])
    grd_global_descriptor = np.zeros([input_data.get_full_dataset_size(), out_channel])
    loss = compute_loss(sat_global, grd_global, utms_x, input_data.sig, useful_pair_s2g_op, useful_pair_g2s_op)

    # set training
    global_step = tf.Variable(0, trainable=False)
    with tf.name_scope('train'):
        train_step = tf.train.AdamOptimizer(learning_rate, 0.9, 0.999).minimize(loss, global_step=global_step)

    saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)

    # run model
    print('run model...')
    config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 1
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        print('load model...')

        if start_epoch == 0:
            load_model_path_init = '/local/zxia/checkpoints/safa/Model/Initialize/initial_model.ckpt' # replace the path with your path to the initialization model
            variables_to_restore_init = tf.contrib.framework.get_variables_to_restore(exclude=['spatial_grd','spatial_sat'])
            # ignore_missing_vars=True: the restore list also contains variables
            # created after the backbone (optimizer slots, global_step) which an
            # initialization checkpoint may not hold; those keep their fresh values.
            init_fn = tf.contrib.framework.assign_from_checkpoint_fn(
                load_model_path_init, variables_to_restore_init, ignore_missing_vars=True)
            if init_fn is None:
                # assign_from_checkpoint_fn returns None when nothing can be restored.
                print("   WARNING: no variables could be restored from: %s" % load_model_path_init)
            else:
                # Actually run the assignment; building init_fn alone loads nothing.
                init_fn(sess)
                print("   Model initialized from: %s" % load_model_path_init)
        else:
            load_model_path = '/local/zxia/checkpoints/safa/Model/Oxford/' + str(start_epoch - 1) + '/model.ckpt'

            saver.restore(sess, load_model_path)

            print("   Model loaded from: %s" % load_model_path)
        print('load model...FINISHED')

        # Train
        for epoch in range(start_epoch, number_of_epoch):
            iteration = 0
            while True:
                batch_sat, batch_grd, batch_dis, useful_pairs_s2g, useful_pairs_g2s = input_data.next_pair_batch(batch_size)

                # next_pair_batch signals the end of the epoch with None.
                if batch_sat is None:
                    break

                # Read before the update so the logged value is the step being executed.
                global_step_val = tf.train.global_step(sess, global_step)

                feed_dict = {sat_x: batch_sat, grd_x: batch_grd, utms_x: batch_dis,
                             useful_pair_s2g_op: useful_pairs_s2g,
                             useful_pair_g2s_op: useful_pairs_g2s,
                             learning_rate: learning_rate_val, keep_prob: keep_prob_val,
                             training: True, keep_prob_dropblock: keep_prob_dropblock_val}
                _, loss_val = sess.run([train_step, loss], feed_dict=feed_dict)

                if iteration % 20 == 0:
                    print('global %d, epoch %d, iter %d: loss : %.8f ' % (global_step_val, epoch, iteration, loss_val))

                iteration += 1

            # NOTE: checkpoint saving is disabled below. Resuming with
            # start_epoch > 0 expects exactly these files to exist.
#             model_dir = '/local/zxia/checkpoints/safa/Model/Oxford/' + str(epoch) + '/'

#             if not os.path.exists(model_dir):
#                 os.makedirs(model_dir)
#             save_path = saver.save(sess, model_dir + 'model.ckpt')
#             print("Model saved in file: %s" % save_path)

            # ---------------------- validation ----------------------
            if epoch % 10 == 0:
                print('validate...')
                print('   compute global descriptors')
                input_data.reset_scan()

                val_i = 0
                while True:
                    batch_sat, batch_grd, _ = input_data.next_batch_scan(batch_size)
                    if batch_sat is None:
                        break
                    # training=False and keep_prob=1.0 for inference.
                    feed_dict = {sat_x: batch_sat, grd_x: batch_grd, keep_prob: 1.0,
                                 training: False, keep_prob_dropblock: keep_prob_dropblock_val}
                    sat_global_val, grd_global_val = \
                        sess.run([sat_global, grd_global], feed_dict=feed_dict)

                    # Write this batch's descriptors into the next free rows.
                    sat_global_descriptor[val_i: val_i + sat_global_val.shape[0], :] = sat_global_val
                    grd_global_descriptor[val_i: val_i + grd_global_val.shape[0], :] = grd_global_val
                    val_i += sat_global_val.shape[0]

                # Only the first valNum ground descriptors are used for validation.
                grd_global_descriptor_validation = grd_global_descriptor[0:input_data.valNum,:]

                print('   compute accuracy')
                # 2 - 2 * <a, b>; equals the squared L2 distance if the descriptors
                # are unit-normalized (assumed -- confirm in SAFA).
                dist_array = 2 - 2 * np.matmul(sat_global_descriptor, np.transpose(grd_global_descriptor_validation))

                val_accuracy_global = validate(dist_array, 1, input_data)
                val_accuracy_local = validate_local(dist_array, 1, input_data)
                print('epoch',  epoch, 'val global accuracy =', val_accuracy_global * 100.0)
                print('epoch',  epoch, 'val local accuracy = ', val_accuracy_local * 100.0)
                

In [3]:
# Start from epoch 0 with a radius of 50.
train(start_epoch=0, radius=50)

radius 50
hypothesis coarse localization prior: 50
number of satellite images 23854
number of ground images in training set 17067
number of ground images in validation set 1698
number of ground images in test set 5089
Storing the index of nearby images for all satellite images. This might take a while
Instructions for updating:
Use `tf.cast` instead.


VGG16: trainable = True



The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
VGG16: trainable = True
Instructions for updating:
dim is deprecated, use axis instead

Instructions for updating:
Use tf.where in 2.0, whi

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/home/zxia/anaconda3/envs/tf_RAL/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3417, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-3-f84050fb6333>", line 1, in <module>
    train(0,50)
  File "<ipython-input-2-0a245d117f87>", line 101, in train
    _, loss_val = sess.run([train_step, loss], feed_dict=feed_dict)
  File "/home/zxia/anaconda3/envs/tf_RAL/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 950, in run
    run_metadata_ptr)
  File "/home/zxia/anaconda3/envs/tf_RAL/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1173, in _run
    feed_dict_tensor, options, run_metadata)
  File "/home/zxia/anaconda3/envs/tf_RAL/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1350, in _do_run
    run_metadata)
  File "/home/zxia/anaconda3/envs/tf_RAL/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 135

TypeError: object of type 'NoneType' has no len()