In [None]:
import os
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
import cv2
import random
import numpy as np
from scipy.spatial import KDTree
import copy
import scipy
from spatial_net import *
import tensorflow as tf
from tensorflow.python.ops.gen_math_ops import *
import tensorflow_probability as tfp


In [None]:
class InputData:
    """Batch provider for the train / validation / test image lists.

    The three list files are read from the current working directory. Each
    line is split on single spaces; field 0 is used as the ground-image path,
    field 1 as the satellite-image path, and the remaining fields are parsed
    as float coordinates (rows 0 and 1 of the resulting ``*UTM`` arrays are
    the ones read when batches are built).

    The combined list is ordered validation + training + test, so the first
    ``valNum`` entries of ``fullList`` are the validation entries.

    Attributes:
        radius: neighbourhood radius used for the KD-tree ball query and for
            the in-batch "useful pair" mask.
        image_root: prefix prepended to every image path from the lists.
        trainList / valList / testList / fullList: parsed list entries.
        trainUTM / valUTM / testUTM / fullUTM: coordinate arrays of shape
            (n_coordinate_fields, n_entries).
        testlUTM: historical (misspelled) alias of ``testUTM``, kept so
            existing callers keep working.
        neighbors: dict mapping ``str(index)`` to the list of ``fullList``
            indices returned by the ball query around that entry.
    """

    # Passed by train() to compute_loss() as the scale of the Gaussian weight.
    sig = 25

    def __init__(self, radius):
        """Read the three lists and index, for every entry, its neighbours.

        :param radius: radius of the local neighbourhood used during
            training and testing (same unit as the list coordinates).
        """
        print('radius of the circle =', radius)
        self.radius = radius

        # put your path to the dataset here
        self.image_root = '/local/zxia/datasets'

        # load the training, validation, and test set
        trainlist = self._read_list('./trainlist.txt')
        vallist = self._read_list('./vallist.txt')
        testlist = self._read_list('./testlist.txt')

        self.trainList = trainlist
        self.trainNum = len(trainlist)
        self.trainUTM = self._utm_of(trainlist)

        self.valList = vallist
        self.valNum = len(vallist)
        self.valUTM = self._utm_of(vallist)

        self.testList = testlist
        self.testNum = len(testlist)
        self.testUTM = self._utm_of(testlist)
        # Backward-compatible alias for the original misspelled attribute.
        self.testlUTM = self.testUTM

        # Order matters: validation entries come first in the full list.
        fulllist = vallist + trainlist + testlist
        self.fullList = fulllist
        self.fullNum = len(fulllist)
        self.fullUTM = self._utm_of(fulllist)

        # Membership flags computed once, instead of a linear
        # `entry in list` scan for every image that is loaded.
        train_keys = {tuple(entry) for entry in trainlist}
        val_keys = {tuple(entry) for entry in vallist}
        self._in_train = [tuple(entry) in train_keys for entry in fulllist]
        self._in_val = [tuple(entry) in val_keys for entry in fulllist]

        self.fullIdList = [*range(0, self.fullNum, 1)]
        self.IdList_to_use = []
        self.__cur_id = 0
        self.__cur_test_id = 0

        print('For each satellite image, storing index of nearby images with the given radius. This takes some time')
        fullUTM_transposed = np.transpose(self.fullUTM)
        UTMTree = KDTree(fullUTM_transposed)
        self.neighbors = {}
        for i in range(self.fullNum):
            # NOTE(review): the stored list is the full ball query result, so
            # it contains the entry itself and any entry at exactly the same
            # location. The original code computed a filtered copy without
            # those but never stored it; behaviour is kept as it was because
            # loss_on_val_set() relies on the list being non-empty.
            self.neighbors[str(i)] = UTMTree.query_ball_point(
                fullUTM_transposed[i, :], r=self.radius, p=2)
        print('number of satellite images', self.fullNum)
        print('number of ground images in training set', self.trainNum)
        print('number of ground images in validation set', self.valNum)
        print('number of ground images in test set', self.testNum)

    @staticmethod
    def _read_list(path):
        """Parse one list file into a list of space-separated field lists.

        Only the line terminator is stripped (the original dropped the last
        character unconditionally, which truncated the final coordinate when
        the file did not end with a newline). Blank lines are skipped.
        """
        entries = []
        with open(path, 'r') as filehandle:
            for line in filehandle:
                content = line.rstrip('\r\n')
                if not content.strip():
                    continue
                entries.append(content.split(" "))
        return entries

    @staticmethod
    def _utm_of(entries):
        """Return the coordinate fields (index 2 onwards) as a float64 array
        of shape (n_fields, n_entries)."""
        if not entries:
            # assumes two coordinate fields per entry, as used by the batch code
            return np.zeros((2, 0), dtype=np.float64)
        return np.transpose(np.array(entries)[:, 2:].astype(np.float64))

    def _load_image(self, rel_path, resize=False):
        """Read an image and subtract the per-channel means.

        :param rel_path: path appended to ``image_root``.
        :param resize: when True the image is resized to 616x112 (w x h)
            before conversion; used for ground images.
        :return: float32 array, or None (after printing a message) when the
            image cannot be read.
        """
        img = cv2.imread(self.image_root + rel_path)
        if img is None:
            print('InputData::next_pair_batch: read fail: %s' % (rel_path))
            return None
        if resize:
            img = cv2.resize(img, (616, 112), interpolation=cv2.INTER_AREA)
        img = img.astype(np.float32)
        # Mean subtraction per channel (cv2 loads images in BGR order).
        img[:, :, 0] -= 103.939  # Blue
        img[:, :, 1] -= 116.779  # Green
        img[:, :, 2] -= 123.6  # Red
        return img

    def next_batch_scan(self, batch_size):
        """Return the next sequential batch over the full list.

        Satellite images are loaded for every entry; ground images only for
        entries that are in the validation list (other ground slots stay
        zero). The last batch may be smaller than ``batch_size``.

        :return: (batch_sat, batch_grd, batch_utm), or (None, None, None)
            once every entry has been scanned (the cursor is then reset).
        """
        if self.__cur_test_id >= self.fullNum:
            self.__cur_test_id = 0  # return none and reset current index to zero after scanned all elements
            return None, None, None
        batch_size = min(batch_size, self.fullNum - self.__cur_test_id)

        batch_sat = np.zeros([batch_size, 112, 616, 3], dtype=np.float32)
        batch_grd = np.zeros([batch_size, 112, 616, 3], dtype=np.float32)
        batch_utm = np.zeros([batch_size, 2], dtype=np.float32)

        for i in range(batch_size):
            img_idx = self.__cur_test_id + i

            # Coordinates are known even if an image fails to load, so fill
            # them first instead of leaving zeros on a read failure.
            batch_utm[i, 0] = self.fullUTM[0, img_idx]
            batch_utm[i, 1] = self.fullUTM[1, img_idx]

            # satellite, load all satellite images
            sat = self._load_image(self.fullList[img_idx][1])
            if sat is None:
                continue
            batch_sat[i, :, :, :] = sat

            # ground, load ground image if it is in the validation set
            if self._in_val[img_idx]:
                grd = self._load_image(self.fullList[img_idx][0], resize=True)
                if grd is None:
                    continue
                batch_grd[i, :, :, :] = grd

        self.__cur_test_id += batch_size

        return batch_sat, batch_grd, batch_utm

    def next_pair_batch(self, batch_size):
        """Build one training batch of a centre image plus nearby images.

        The id list is reshuffled at the start of every epoch. A centre id is
        taken at the cursor; the ids after it are scanned, and those found in
        the centre's neighbour list are moved into the batch (and removed
        from the id list). If the remaining ids cannot fill the batch, the
        attempt is abandoned and the next id becomes the centre.

        :return: (batch_sat, batch_grd, squared_batch_dis, useful_pairs_s2g,
            useful_pairs_g2s), or five ``None`` values when the epoch is
            finished (the cursor is then reset).
        """
        if self.__cur_id == 0:
            self.IdList_to_use = list(self.fullIdList)
            for _ in range(20):
                random.shuffle(self.IdList_to_use)  # Only shuffle the id at the beginning of every epoch

        batch_sat = np.zeros([batch_size, 112, 616, 3], dtype=np.float32)
        batch_grd = np.zeros([batch_size, 112, 616, 3], dtype=np.float32)
        batch_utm = np.zeros([batch_size, 2], dtype=np.float32)

        batch_idx = 0
        i = 1
        empty_grd = []
        candidates = set()
        while True:
            if self.__cur_id + batch_size >= self.fullNum:
                self.__cur_id = 0  # return none and reset current index to zero after every epoch
                print('This epoch finished')
                return None, None, None, None, None

            if batch_idx >= batch_size:
                break

            if self.__cur_id + i >= len(self.IdList_to_use):
                # go to next center image if cannot find enough nearby images in remaining images
                self.__cur_id += 1
                batch_idx = 0
                i = 1
                continue

            if batch_idx == 0:
                # Load the center image
                img_idx = self.IdList_to_use[self.__cur_id]

                # Get the indexes of nearby images for current center image
                nearby = self.neighbors[str(img_idx)]

                # If number of nearby images is not enough to form a batch, then move to the next center image
                if len(nearby) < batch_size - 1:
                    self.__cur_id += 1
                    continue

                # satellite
                sat = self._load_image(self.fullList[img_idx][1])
                if sat is None:
                    self.__cur_id += 1
                    continue

                # ground: only loaded for entries of the training set
                has_grd = self._in_train[img_idx]
                grd = None
                if has_grd:
                    grd = self._load_image(self.fullList[img_idx][0], resize=True)
                    if grd is None:
                        self.__cur_id += 1
                        continue

                # A new batch starts here. Clear everything a previously
                # abandoned attempt may have left behind; otherwise stale
                # "empty ground" slots would mask valid pairs of this batch.
                batch_sat.fill(0.0)
                batch_grd.fill(0.0)
                batch_utm.fill(0.0)
                empty_grd = []
                candidates = set(nearby)  # set: O(1) membership while scanning

                batch_sat[0, :, :, :] = sat
                if has_grd:
                    batch_grd[0, :, :, :] = grd
                else:
                    # do not load ground images if they are not in the training set
                    empty_grd.append(batch_idx)
                # coordinates
                batch_utm[batch_idx, 0] = self.fullUTM[0, img_idx]
                batch_utm[batch_idx, 1] = self.fullUTM[1, img_idx]

                batch_idx += 1
            else:
                # Load other nearby images into the batch: if the id under
                # inspection is a neighbour of the centre, consume it.
                position = self.__cur_id + i
                img_idx = self.IdList_to_use[position]
                if img_idx not in candidates:
                    i += 1
                    continue
                del self.IdList_to_use[position]

                # After the deletion the next unseen id already sits at
                # `position`, so `i` must not advance on a read failure.
                sat = self._load_image(self.fullList[img_idx][1])
                if sat is None:
                    continue

                has_grd = self._in_train[img_idx]
                grd = None
                if has_grd:
                    grd = self._load_image(self.fullList[img_idx][0], resize=True)
                    if grd is None:
                        continue

                batch_sat[batch_idx, :, :, :] = sat
                if has_grd:
                    batch_grd[batch_idx, :, :, :] = grd
                else:
                    empty_grd.append(batch_idx)
                batch_utm[batch_idx, 0] = self.fullUTM[0, img_idx]
                batch_utm[batch_idx, 1] = self.fullUTM[1, img_idx]

                batch_idx += 1

        self.__cur_id += 1

        distance_matrix = scipy.spatial.distance_matrix(batch_utm, batch_utm)
        # check the distance between every two images in current batch
        # (builtin int: the np.int alias no longer exists in recent NumPy)
        useful_pairs = (distance_matrix <= self.radius).astype(int)
        # diagonal elements are the positive samples, so they are never used as negatives
        np.fill_diagonal(useful_pairs, 0)
        useful_pairs_s2g = useful_pairs.copy()
        useful_pairs_g2s = useful_pairs.copy()
        # We do not use pairs which contains empty ground images
        for slot in empty_grd:
            useful_pairs_s2g[:, slot] = 0
            useful_pairs_s2g[slot, :] = 0
            useful_pairs_g2s[:, slot] = 0

        # We use squared distance in our Gaussian weighting term
        squared_batch_dis = np.square(distance_matrix)

        return batch_sat, batch_grd, squared_batch_dis, useful_pairs_s2g, useful_pairs_g2s

    def get_dataset_size(self):
        """Number of entries in the training list."""
        return self.trainNum

    def get_test_dataset_size(self):
        """Number of entries in the validation list."""
        return self.valNum

    def get_full_dataset_size(self):
        """Number of entries in the combined (val + train + test) list."""
        return self.fullNum

    def reset_scan(self):
        """Restart next_batch_scan() from the first entry."""
        self.__cur_test_id = 0


    

In [None]:
# --- model ---
network_type = "SAFA_8"  # train() parses the last character as the SAFA dimension
is_training = True  # forwarded to SAFA() when the graph is built

# --- schedule ---
# NOTE: train() has its own `start_epoch` parameter; this module-level value
# is not read by it.
start_epoch = 0
number_of_epoch = 100
batch_size = 4

# --- optimisation ---
learning_rate_val = 1e-5
keep_prob_val = 0.8  # fed to the keep_prob placeholder during training
loss_weight = 10.0  # multiplies the triplet distance inside the soft-margin loss

# -------------------------------------------------------- #

def validate(dist_array, top_k):
    """Top-k recall over all rows of ``dist_array``.

    Column ``q`` holds the distances of query ``q`` to every reference row;
    the matching reference is row ``q``. A query counts as a hit when fewer
    than ``top_k`` rows are strictly closer than its matching row.

    :param dist_array: 2-D array with at least as many rows as columns.
    :param top_k: rank threshold.
    :return: fraction of queries that are hits (Python float).
    """
    num_queries = dist_array.shape[1]
    hits = 0.0
    for query in range(num_queries):
        column = dist_array[:, query]
        closer = np.sum(column < dist_array[query, query])
        hits += 1.0 if closer < top_k else 0.0

    return hits / float(num_queries)

def local_validation(dist_array, top_k, input_data):
    """Top-k recall where only each query's stored neighbours compete.

    Same criterion as a global top-k check, but for query ``q`` only the rows
    listed in ``input_data.neighbors[str(q)]`` are compared against the
    matching row ``q``.

    :param dist_array: 2-D distance array (rows: references, columns: queries).
    :param top_k: rank threshold.
    :param input_data: object exposing a ``neighbors`` dict keyed by
        ``str(index)``.
    :return: fraction of queries that are hits (Python float).
    """
    num_queries = dist_array.shape[1]
    hits = 0.0
    for query in range(num_queries):
        neighbor_rows = np.asarray(input_data.neighbors[str(query)], dtype=int)
        positive = dist_array[query, query]
        closer = np.sum(dist_array[neighbor_rows, query] < positive)
        if closer < top_k:
            hits += 1.0

    return hits / float(num_queries)

def loss_on_val_set(dist_array, input_data):
    """Report soft-margin triplet loss statistics on the validation queries.

    For every query column the soft-margin term
    ``log(1 + exp((d_pos - d_neg) * loss_weight))`` is evaluated against each
    row in the query's stored neighbour list. Per-query mean, median, 90th and
    99th percentile and maximum are then averaged over all queries and
    printed.

    :param dist_array: 2-D distance array (rows: references, columns: queries).
    :param input_data: object exposing a ``neighbors`` dict keyed by
        ``str(index)``.
    :return: the averaged per-query mean loss.
    """
    num_queries = dist_array.shape[1]
    data_amount = float(num_queries)

    mean_terms = []
    median_terms = []
    max_terms = []
    q90_terms = []
    q99_terms = []
    for query in range(num_queries):
        neighbor_rows = input_data.neighbors[str(query)]
        positive = dist_array[query, query]
        negatives = [dist_array[row, query] for row in neighbor_rows]

        # Evaluate the soft-margin term once and reuse it for every statistic.
        soft_margin = np.log(1+np.exp((positive - negatives) * loss_weight))

        mean_terms.append(np.sum(soft_margin) / len(negatives))
        median_terms.append(np.median(soft_margin))
        max_terms.append(np.max(soft_margin))
        q90_terms.append(np.percentile(soft_margin, 90))
        q99_terms.append(np.percentile(soft_margin, 99))

    loss = np.sum(mean_terms) / data_amount
    loss_median = np.sum(median_terms) / data_amount
    loss_90quantile = np.sum(q90_terms) / data_amount
    loss_99quantile = np.sum(q99_terms) / data_amount
    loss_max = np.sum(max_terms) / data_amount

    print('on validation set:')
    print('loss: ', loss, ' loss_median: ', loss_median, ' loss_90quantile: ', loss_90quantile, ' loss_99quantile: ', loss_99quantile, ' loss_max: ', loss_max)

    return loss

def compute_loss(sat_global, grd_global, utms_x, UTMthres, useful_pairs_s2g, useful_pairs_g2s):
    """Build the Gaussian-weighted soft-margin triplet loss tensor.

    :param sat_global: satellite descriptors, rank-2 tensor (the static shape
        is unpacked into exactly two values below).
    :param grd_global: ground descriptors; row i is treated as the positive
        match of satellite row i (the diagonal of the distance matrix).
    :param utms_x: pairwise values fed to the Gaussian weight; train() feeds
        the *squared* pairwise coordinate distances of the batch.
    :param UTMthres: scale (standard deviation) of the zero-mean Normal used
        for the weight.
    :param useful_pairs_s2g: mask selecting the pairs that contribute to the
        satellite-to-ground term; its shape also defines the shape of the
        weight tensor.
    :param useful_pairs_g2s: mask for the ground-to-satellite term.
    :return: scalar tensor, the mean of the two directional losses.
    """
    
    with tf.name_scope('weighted_soft_margin_triplet_loss'):
        # Get the Gaussian weight term for every triplet
        # weight = (p(0) - p(x)) / p(0) with p the pdf of Normal(0, UTMthres),
        # i.e. 1 - exp(-x^2 / (2 * UTMthres^2)): 0 at x == 0, approaching 1 for large x.
        # NOTE(review): utms_x is already a squared distance when called from
        # train(), so the exponent is quartic in the metric distance - confirm
        # this is intended.
        tfd = tfp.distributions
        zeros = tf.fill(tf.shape(useful_pairs_s2g), 0.0)
        sig = tf.fill(tf.shape(useful_pairs_s2g), tf.constant(UTMthres,dtype=tf.float32)) 
        dist = tfd.Normal(loc=zeros, scale=sig)
        Gaussian_weights = (-dist.prob(utms_x)+dist.prob(zeros))/dist.prob(zeros)

       
        # batch_size / channels are not used afterwards; the unpacking only
        # succeeds for a rank-2 static shape.
        batch_size, channels = sat_global.get_shape().as_list()
        # dist_array[r, c] = 2 - 2 * <sat_r, grd_c>
        # (presumably descriptors are L2-normalised, which would make this the
        # squared Euclidean distance - TODO confirm in SAFA()).
        dist_array = 2 - 2 * tf.matmul(sat_global, grd_global, transpose_b=True) #[S1G1, S1G2; S2G1, S2G2]
        pos_dist = tf.diag_part(dist_array)
        
        # ground to satellite
        # The small constant keeps the division finite when the mask is all zero.
        pair_n_g2s = tf.reduce_sum(useful_pairs_g2s) + 0.001
        # pos_dist broadcasts along the last axis: entry [r, c] = pos_dist[c] - dist_array[r, c]
        triplet_dist_g2s = (pos_dist - dist_array)
        loss_g2s = tf.reduce_sum(tf.log(1 + tf.exp(triplet_dist_g2s * loss_weight))*\
                                 tf.multiply(Gaussian_weights, useful_pairs_g2s)) / pair_n_g2s

        # satellite to ground
        pair_n_s2g = tf.reduce_sum(useful_pairs_s2g) + 0.001
        # pos_dist as a column: entry [r, c] = pos_dist[r] - dist_array[r, c]
        triplet_dist_s2g = (tf.expand_dims(pos_dist, 1) - dist_array)
        loss_s2g = tf.reduce_sum(tf.log(1 + tf.exp(triplet_dist_s2g * loss_weight))*\
                                 tf.multiply(Gaussian_weights, useful_pairs_s2g)) / pair_n_s2g

        loss = (loss_g2s + loss_s2g) / 2.0
        
    return loss

def train(start_epoch=0, radius=500):
    '''
    Train the network and periodically evaluate it on the validation set.

    Builds the graph, restores a checkpoint, then for every epoch runs
    training batches until the data provider signals the end of the epoch,
    saves a checkpoint, and every 10th epoch computes global / local top-1..10
    accuracy and the validation loss.

    :param start_epoch: the epoch id start to train. The first epoch is 0.
        With 0 the initial model is restored, otherwise the checkpoint saved
        for epoch ``start_epoch - 1`` of the same radius.
    :param radius: neighbourhood radius handed to InputData; also part of the
        checkpoint directory and of the accuracy log file name.
    '''
    # put your path to the model here
    model_root = '/local/zxia/checkpoints/safa/Model'
    
    # import data
    print('radius', radius)
    input_data = InputData(radius)

    # define placeholders
    sat_x = tf.placeholder(tf.float32, [None, 112, 616, 3], name='sat_x')
    grd_x = tf.placeholder(tf.float32, [None, 112, 616, 3], name='grd_x')
    # The pair masks have a fixed shape taken from the module-level batch_size.
    useful_pair_s2g_op = tf.placeholder(tf.float32, [batch_size, batch_size], name='useful_pair_s2g_op')
    useful_pair_g2s_op = tf.placeholder(tf.float32, [batch_size, batch_size], name='useful_pair_g2s_op')
    utms_x = tf.placeholder(tf.float32, [None, None], name='utms')
    
    keep_prob = tf.placeholder(tf.float32)
    learning_rate = tf.placeholder(tf.float32)

    # build model
    # The last character of network_type is used as the SAFA dimension.
    dimension = int(network_type[-1])
    sat_global, grd_global = SAFA(sat_x, grd_x, keep_prob, dimension, is_training)

    # Host-side buffers holding one descriptor per entry of the full list.
    out_channel = sat_global.get_shape().as_list()[-1]
    sat_global_descriptor = np.zeros([input_data.get_full_dataset_size(), out_channel])
    grd_global_descriptor = np.zeros([input_data.get_full_dataset_size(), out_channel])
    loss = compute_loss(sat_global, grd_global, utms_x, input_data.sig, useful_pair_s2g_op, useful_pair_g2s_op)

    # set training
    global_step = tf.Variable(0, trainable=False)
    with tf.name_scope('train'):
        train_step = tf.train.AdamOptimizer(learning_rate, 0.9, 0.999).minimize(loss, global_step=global_step)

    # Saver over every global variable; max_to_keep=None keeps all checkpoints.
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)
    
    global_vars = tf.global_variables()

    # Variables whose op name contains 'VGG' but not 'Adam'.
    var_list = []
    for var in global_vars:
        if 'VGG' in var.op.name and 'Adam' not in var.op.name:
            var_list.append(var)

    # NOTE(review): saver_to_restore is created but never used below - the
    # restore call goes through `saver` (all global variables). Confirm which
    # one the initial checkpoint is meant to be loaded with.
    saver_to_restore = tf.train.Saver(var_list)

    # run model
    print('run model...')
    config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 1
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        print('load model...')

        if start_epoch == 0:
            load_model_path = model_root + '/Initialize/initial_model.ckpt'
        else:
            load_model_path = model_root + '/CVACT/' + str(radius)+'/' + str(start_epoch - 1) + '/model.ckpt'

        saver.restore(sess, load_model_path)

        print("   Model loaded from: %s" % load_model_path)
        print('load model...FINISHED')

        # Train
        for epoch in range(start_epoch, number_of_epoch):
            # NOTE(review): `iter` shadows the builtin of the same name inside this function.
            iter = 0
            while True:
                # train
                batch_sat, batch_grd, squared_batch_dis, useful_pairs_s2g, useful_pairs_g2s = input_data.next_pair_batch(batch_size)

                # next_pair_batch returns None values once the epoch is finished.
                if batch_sat is None:
                    break

                # Read before the optimisation step runs, so the printed value
                # is the step count prior to this update.
                global_step_val = tf.train.global_step(sess, global_step)

                # utms_x receives the squared pairwise distances of the batch.
                feed_dict = {sat_x: batch_sat, grd_x: batch_grd, utms_x:squared_batch_dis, useful_pair_s2g_op: useful_pairs_s2g,
                             useful_pair_g2s_op: useful_pairs_g2s,
                             learning_rate: learning_rate_val, keep_prob: keep_prob_val}
                _, loss_val = sess.run([train_step, loss], feed_dict=feed_dict)

                if iter % 20 == 0:
                    print('global %d, epoch %d, iter %d: loss : %.8f ' % (global_step_val, epoch, iter, loss_val))

                iter += 1

            # One checkpoint directory per radius and epoch.
            model_dir = model_root + '/CVACT/' +str(radius)+'/' + str(epoch) + '/'

            if not os.path.exists(model_dir):
                os.makedirs(model_dir)
            save_path = saver.save(sess, model_dir + 'model.ckpt')
            print("Model saved in file: %s" % save_path)

            # ---------------------- validation ----------------------
            # Runs on epochs 0, 10, 20, ...
            if epoch % 10 == 0:
                print('validate...')
                print('   compute global descriptors')
                input_data.reset_scan()

                # Scan the full list sequentially and store the descriptors in order.
                val_i = 0
                while True:
                    batch_sat, batch_grd, _ = input_data.next_batch_scan(batch_size)
                    if batch_sat is None:
                        break
                    # keep_prob 1.0: no dropout while computing descriptors.
                    feed_dict = {sat_x: batch_sat, grd_x: batch_grd, keep_prob: 1.0}
                    sat_global_val, grd_global_val = \
                        sess.run([sat_global, grd_global], feed_dict=feed_dict)


                    sat_global_descriptor[val_i: val_i + sat_global_val.shape[0], :] = sat_global_val
                    grd_global_descriptor[val_i: val_i + grd_global_val.shape[0], :] = grd_global_val
                    val_i += sat_global_val.shape[0]

                # The full list starts with the validation entries, so the
                # first valNum ground descriptors are the validation queries.
                grd_global_descriptor_to_use = grd_global_descriptor[0:input_data.valNum,:]

                print('   compute accuracy')
                # Rows: every satellite image of the full list; columns: validation ground images.
                dist_array = 2 - 2 * np.matmul(sat_global_descriptor, np.transpose(grd_global_descriptor_to_use))

                # Column k-1 holds the top-k accuracy for k = 1..10.
                val_accuracy_global = np.zeros((1, 10))
                val_accuracy_local = np.zeros((1, 10))
                print('start')
                for i in range(1,11):
                    print(i)
                    val_accuracy_global[0, i-1] = validate(dist_array, i)
                    val_accuracy_local[0, i-1] = local_validation(dist_array, i, input_data)
                print('epoch',  epoch, 'global accuracy on validation set =', val_accuracy_global * 100.0)
                print('epoch',  epoch, 'local accuracy on validation set = ', val_accuracy_local * 100.0)
                
                # loss_on_val_set prints its statistics; the returned value is
                # not used afterwards in this function.
                loss_val = loss_on_val_set(dist_array, input_data)
                
                # Results are appended to a per-radius text file in the working
                # directory; each row's header is prefixed with the epoch number.
                file = './' + str(radius) + '_accuracy.txt'
                with open(file, 'a') as file:
                    np.savetxt(file, val_accuracy_global, fmt='%4f', delimiter=',',newline='\n', header='val_global', comments=str(epoch)+'_')
                    np.savetxt(file, val_accuracy_local, fmt='%4f', delimiter=',',newline='\n', header='val_local', comments=str(epoch)+'_')
            



In [None]:
# Start training from the initial model with a neighbourhood radius of 100.
train(start_epoch=0, radius=100)