In [None]:
import numpy as np
import tensorflow as tf
from collections import OrderedDict

In [None]:
def normalization_parameters(fn):
    """Collect per-band statistics of the raster file ``fn``.

    The file is opened with GDAL and every band is visited in order. For each
    band the mean and standard deviation are taken from elements 2 and 3 of
    ``ComputeStatistics(1)`` and the maximum from ``GetMaximum()``.

    Args:
        fn: Path handed to ``gdal.Open``.

    Returns:
        A list ``[means, stds, maxima]`` holding three lists with one entry
        per raster band.
    """
    from osgeo import gdal

    dataset = gdal.Open(fn)
    means, stds, maxima = [], [], []

    # GDAL band numbers start at 1, hence the shifted range.
    for band_number in range(1, dataset.RasterCount + 1):
        band = dataset.GetRasterBand(band_number)
        band_mean, band_std = band.ComputeStatistics(1)[2:4]
        means.append(band_mean)
        stds.append(band_std)
        # Queried after ComputeStatistics, exactly as in the original order.
        maxima.append(band.GetMaximum())

    return [means, stds, maxima]

def sample_1s2(filebase, batchsize, tilesize=128, normalize=False, flattened=False):
    """Draw a batch of random square tiles from one or two rasters.

    Args:
        filebase: Either a single raster path (``str``) or a list/tuple
            ``[path_10, path_20]``. With two paths, offsets are drawn on the
            second raster with a tile edge of ``tilesize // 2``; the first
            raster is read at twice those offsets and twice that edge, i.e.
            it is treated as having twice the resolution of the second.
        batchsize: Number of tiles to collect.
        tilesize: Tile edge in pixels on the first raster.
        normalize: If true, every tile is divided by the per-band maxima
            returned by ``normalization_parameters`` for its file.
        flattened: If true, every tile is flattened to one dimension.

    Returns:
        For a single path, one array stacking ``batchsize`` tiles. For two
        paths, a list ``[tiles_10, tiles_20]`` of two such arrays. Each tile is
        the ``np.transpose`` (all axes reversed) of what ``ReadAsArray``
        returns, cast to float32 before the optional normalisation.

    Raises:
        TypeError: ``filebase`` is neither a string nor a list/tuple.
        ValueError: fewer than two paths in the list, or the tile does not fit
            into the raster that offsets are drawn from.
        IOError: GDAL could not open one of the files.

    Notes:
        * A draw is kept only if the minimum of every tile is > 0 (presumably
          0 marks no-data -- confirm). If no such tile exists the loop never
          terminates.
        * The normalisation relies on the band axis being last after the
          transpose (assumes GDAL returns band-first arrays -- TODO confirm
          for single-band files).
    """
    import numpy as np
    from osgeo import gdal

    if isinstance(filebase, str):
        paths = [filebase]
        scales = [1]
        base_edge = tilesize
    elif isinstance(filebase, (list, tuple)):
        # A list is interpreted as [high-resolution, half-resolution].
        if len(filebase) < 2:
            raise ValueError('filebase must contain two raster paths, got %d' % len(filebase))
        paths = [filebase[0], filebase[1]]
        scales = [2, 1]
        # Floor division keeps offsets and window sizes integral on Python 3.
        base_edge = tilesize // 2
    else:
        raise TypeError('filebase must be a path or a list of two paths, got %r'
                        % type(filebase).__name__)
    multires = len(paths) == 2

    datasets = []
    for path in paths:
        dataset = gdal.Open(path)
        if dataset is None:
            raise IOError('GDAL could not open raster: %r' % (path,))
        datasets.append(dataset)

    maxima = None
    if normalize:
        maxima = [np.array(normalization_parameters(path)[2]) for path in paths]

    # Offsets are drawn on the last (coarsest) raster and scaled for the others.
    reference = datasets[-1]
    max_x = reference.RasterXSize - base_edge
    max_y = reference.RasterYSize - base_edge
    if base_edge < 1 or max_x < 0 or max_y < 0:
        raise ValueError('tile edge %d does not fit into a %d x %d raster'
                         % (base_edge, reference.RasterXSize, reference.RasterYSize))

    batches = [[] for _ in datasets]
    while len(batches[0]) < batchsize:
        # randint's upper bound is exclusive; +1 keeps the last valid offset
        # reachable and makes a raster of exactly one tile usable.
        x_off = int(np.random.randint(max_x + 1))
        y_off = int(np.random.randint(max_y + 1))

        tiles = [
            np.transpose(
                dataset.ReadAsArray(x_off * scale, y_off * scale,
                                    base_edge * scale, base_edge * scale)
            ).astype(np.float32)
            for dataset, scale in zip(datasets, scales)
        ]

        # Reject the draw unless every tile is strictly positive everywhere.
        if not all(np.min(tile) > 0 for tile in tiles):
            continue

        for index, tile in enumerate(tiles):
            if normalize:
                tile = tile / maxima[index]
            if flattened:
                tile = tile.flatten()
            batches[index].append(tile)

    arrays = [np.array(batch) for batch in batches]
    return arrays if multires else arrays[0]

In [None]:
# Tile geometry and channel counts used when the input placeholders are built.
len_edge = 32          # tile edge in pixels for X_10 / X_S1 (X_20 uses half of it)
num_channels_10 = 4    # channels of the X_10 input
num_channels_20 = 6    # channels of the X_20 input
num_channels_S1 = 2    # channels of the X_S1 input / reconstruction target

# Raster paths. The original assignments had no right-hand side, which is a
# SyntaxError; None is an explicit "not configured" placeholder.
# TODO: set these to the raster files before running the notebook
# (presumably the two Sentinel-2 resolutions and Sentinel-1 -- confirm).
fn_10 = None
fn_20 = None
fn_S1 = None

In [None]:
def convolute(inp, name, kernel_size = 3, out_chans = 64, sz = 1, is_training = True):
    """Convolution -> batch normalisation -> bias add -> ReLU.

    Args:
        inp: Input tensor; its last dimension is taken as the number of input
            channels (the strides assume a 4-D NHWC layout).
        name: Variable scope that owns the layer's variables.
        kernel_size: Edge length of the square convolution kernel.
        out_chans: Number of output channels.
        sz: Stride applied along both spatial axes.
        is_training: Forwarded to ``tf.contrib.layers.batch_norm``. The
            default ``True`` reproduces the previous behaviour; pass ``False``
            (or a boolean tensor) to normalise with the moving statistics at
            inference time.

    Returns:
        The activated output tensor with ``out_chans`` channels.

    Notes:
        The kernel carries an L2 regulariser (0.0005).
        NOTE(review): the loss cell visible in this notebook does not add the
        collected regularisation losses to the objective -- confirm intent.
    """
    inp_chans = inp.get_shape().as_list()[-1]
    with tf.variable_scope(name) as scope:
        W = tf.get_variable(
            'weights',
            [kernel_size, kernel_size, inp_chans, out_chans],
            initializer=tf.contrib.layers.xavier_initializer_conv2d(),
            regularizer=tf.contrib.layers.l2_regularizer(0.0005))
        b = tf.get_variable(
            'biases',
            [out_chans],
            initializer=tf.constant_initializer(0.0),
            regularizer=None,
            dtype=tf.float32)
        conv = tf.nn.conv2d(inp, W, strides=[1, sz, sz, 1], padding='SAME')
        # Re-enter the layer's own scope so the batch-norm variables live
        # next to the weights and biases.
        conv = tf.contrib.layers.batch_norm(conv, is_training=is_training, scope=scope)
        conv = tf.nn.relu(conv + b)
    return conv

def pooling(inp, name, factor=2):
    """Max-pool ``inp`` with a ``factor`` x ``factor`` window and equal stride.

    Uses 'SAME' padding; ``name`` becomes the name of the pooling op.
    """
    window = [1, factor, factor, 1]
    return tf.nn.max_pool(inp, ksize=window, strides=window, padding='SAME', name=name)

In [None]:
#with tf.name_scope('Input'):
#    X = tf.placeholder(tf.float32, shape=([None, len_edge, len_edge, num_channels]))#bs
#    if num_channels == 2:
#        X_show = tf.concat([X, tf.expand_dims(X[:, :, :, 1], 3)], axis=3)
#        tf.summary.image('input_images', X_show, max_outputs=tb_imgs_to_display)
#    elif num_channels > 3:
#        tf.summary.image('input_images', X[:, :, :, 0:3], max_outputs=tb_imgs_to_display)
#    else:
#        tf.summary.image('input_images', X, max_outputs=tb_imgs_to_display)
        
        
# NOTE(review): `layer_name` and the filter counts `outer`, `middle`, `inner`
# and `innest` are used below but are not defined in the visible cells; they
# must exist in the kernel before this cell is run.


def _upsample2x(tensor):
    """Resize ``tensor`` to twice its static height and width (axes 1 and 2)."""
    height, width = tensor.get_shape().as_list()[1:3]
    return tf.image.resize_images(tensor, [height * 2, width * 2])


# Inputs: X_10 and X_S1 share the full tile size, X_20 has half the edge
# length. Floor division keeps the placeholder shape integral on Python 3.
X_10 = tf.placeholder(tf.float32, shape=([None, len_edge, len_edge, num_channels_10]))
X_20 = tf.placeholder(tf.float32, shape=([None, len_edge // 2, len_edge // 2, num_channels_20]))
X_S1 = tf.placeholder(tf.float32, shape=([None, len_edge, len_edge, num_channels_S1]))

dw_h_convs = OrderedDict()
up_h_convs = OrderedDict()

# Separate stems for the two inputs.
X_20_c = convolute(X_20, layer_name('conv'), 3, outer, sz = 1)
X_20_c = convolute(X_20_c, layer_name('conv'), 3, outer, sz = 1)

X_10_c = convolute(X_10, layer_name('conv'), 3, outer, sz = 1)
X_10_c = convolute(X_10_c, layer_name('conv'), 3, outer, sz = 1)
# Pool the full-size features down to the X_20 grid. (The original pooled
# dw_h_convs[0] before it existed, which raised a KeyError.)
dw_h_convs[0] = pooling(X_10_c, 'pool1')

# Fuse both inputs at the half-size grid.
dw_h_convs[1] = tf.concat([dw_h_convs[0], X_20_c], 3)

# Contracting path: two convolutions per level, then 2x max-pooling.
dw_h_convs[1] = convolute(dw_h_convs[1], layer_name('conv'), 3, outer)
dw_h_convs[1] = convolute(dw_h_convs[1], layer_name('conv'), 3, outer)
dw_h_convs[2] = pooling(dw_h_convs[1], 'pool2')

dw_h_convs[2] = convolute(dw_h_convs[2], layer_name('conv'), 3, middle)
dw_h_convs[2] = convolute(dw_h_convs[2], layer_name('conv'), 3, middle)
dw_h_convs[3] = pooling(dw_h_convs[2], 'pool3')

dw_h_convs[3] = convolute(dw_h_convs[3], layer_name('conv'), 3, inner)
dw_h_convs[3] = convolute(dw_h_convs[3], layer_name('conv'), 3, inner)
dw_h_convs[4] = pooling(dw_h_convs[3], 'pool4')

dw_h_convs[4] = convolute(dw_h_convs[4], layer_name('conv'), 3, innest)
dw_h_convs[4] = convolute(dw_h_convs[4], layer_name('conv'), 3, inner)

# Expanding path: upsample 2x, concatenate the skip connection of matching
# size, then two convolutions.
up_h_convs[0] = _upsample2x(dw_h_convs[4])
# The original concatenated up_h_convs[1] here, which does not exist yet
# (KeyError) and would have discarded the upsampled tensor.
up_h_convs[0] = tf.concat([up_h_convs[0], dw_h_convs[3]], 3)
up_h_convs[0] = convolute(up_h_convs[0], layer_name('conv'), 3, inner)
up_h_convs[0] = convolute(up_h_convs[0], layer_name('conv'), 3, middle)

up_h_convs[1] = _upsample2x(up_h_convs[0])
up_h_convs[1] = tf.concat([up_h_convs[1], dw_h_convs[2]], 3)
up_h_convs[1] = convolute(up_h_convs[1], layer_name('conv'), 3, middle)
up_h_convs[1] = convolute(up_h_convs[1], layer_name('conv'), 3, outer)

up_h_convs[2] = _upsample2x(up_h_convs[1])
up_h_convs[2] = tf.concat([up_h_convs[2], dw_h_convs[1]], 3)
up_h_convs[2] = convolute(up_h_convs[2], layer_name('conv'), 3, outer)
up_h_convs[2] = convolute(up_h_convs[2], layer_name('conv'), 3, outer)

# The original took the new height from axis 2 instead of axis 1 here.
up_h_convs[3] = _upsample2x(up_h_convs[2])
up_h_convs[3] = tf.concat([up_h_convs[3], X_10_c], 3)
up_h_convs[3] = convolute(up_h_convs[3], layer_name('conv'), 3, outer)
up_h_convs[3] = convolute(up_h_convs[3], layer_name('conv'), 3, outer)

# 1x1 convolution + sigmoid maps the features to num_channels_S1 outputs.
W_rec = tf.get_variable('weights_rec', [1, 1, outer, num_channels_S1], initializer=tf.contrib.layers.xavier_initializer_conv2d(), regularizer=None)
b_rec = tf.get_variable('biases_rec', [num_channels_S1], initializer=tf.constant_initializer(0.0), regularizer=None, dtype=tf.float32)
reconstruction = tf.nn.sigmoid(tf.nn.conv2d(up_h_convs[3], W_rec, strides=[1, 1, 1, 1], padding='SAME') + b_rec)

In [None]:
# Per-sample Bernoulli log-likelihood of the X_S1 target under the sigmoid
# reconstruction; 1e-9 keeps both logarithms finite.
X_flat = tf.contrib.layers.flatten(X_S1)
R_flat = tf.contrib.layers.flatten(reconstruction)
log_likelihood = tf.reduce_sum(
    X_flat * tf.log(R_flat + 1e-9) + (1 - X_flat) * tf.log(1 - R_flat + 1e-9),
    axis=1)  # `axis` replaces the deprecated `reduction_indices`

tf.summary.scalar('LogLike', tf.reduce_mean(log_likelihood))

# tf.contrib.layers.batch_norm registers its moving-average updates in
# UPDATE_OPS; they only run if the train op depends on them.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    # The objective is a per-sample vector; minimize() sums its gradients.
    optimizer_likeli = tf.train.AdamOptimizer(1e-4).minimize(-log_likelihood)

In [None]:
# Session and TensorBoard setup. Statement order matters: summaries and
# variables must already be defined when merge_all() / the initializer are
# created.
merged_summary = tf.summary.merge_all()
init = tf.global_variables_initializer()
sess = tf.InteractiveSession()
sess.run(init)
saver = tf.train.Saver()
# NOTE(review): `trainlogpath` and `testlogpath` are not defined in the
# visible cells; they must exist in the kernel before this cell is run.
train_writer = tf.summary.FileWriter(trainlogpath)
valid_writer = tf.summary.FileWriter(testlogpath)
#gen_writer = tf.summary.FileWriter(genlogpath)
# Only the training writer gets the graph definition.
train_writer.add_graph(sess.graph)

In [None]:
# NOTE(review): this cell does not match the graph built above. `sample`,
# `X`, `X_all`, `x_valid`, `optimizer`, `validator`, `bs`, `epochs`,
# `recording_interval` and `train_size` are not defined in the visible cells,
# while the graph exposes X_10 / X_20 / X_S1 and `optimizer_likeli`.
# TODO: rewire the feeds once the data loading is settled.
all_10, all_20, all_S1 = sample(...)
#X_all = sample_1s2(fn3, train_size, tilesize=len_edge, normalize=True, flattened=False)
#x_valid = sample_1s2(fn3, valid_size, tilesize=len_edge, normalize=True, flattened=False)

# `step` is the exclusive end index of the current mini-batch in X_all.
step = bs
ep = 0
for ep in range(epochs):
    x_batch = X_all[step-bs:step]

    record = (ep % recording_interval == 0)

    if record:
        # Validation summary is taken before the training step of this iteration.
        vvv, s = sess.run([validator, merged_summary], feed_dict={X: x_valid})
        valid_writer.add_summary(s, ep)

    _, s = sess.run([optimizer, merged_summary], feed_dict={X: x_batch})

    if record:
        train_writer.add_summary(s, ep)

    step += bs

    # Wrap around only once the next batch would run past the training set.
    # The original `== train_size` reset one batch early, so the last batch
    # was never used, and it never fired when train_size is not a multiple of
    # bs, after which the slices ran beyond the data.
    if step > train_size:
        step = bs