In [1]:
import numpy as np #matrix math
import math
import tensorflow as tf #machine learning
import matplotlib.pyplot as plt #plotting
%matplotlib inline

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz


In [2]:
num_channels = 1
len_edge = 28
bs = 32
n_filters = [1, 16, 16]
h_dim = 133
latent_dim = 16

In [3]:
def weight_variable(shape, name):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial, name=name)

def bias_variable(shape, name):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial, name=name)
    
def fc_layer(inp, channels_in, channels_out, name='fc'):
    with tf.name_scope(name):
        w = tf.Variable(tf.zeros([channels_in, channels_out]), name='W')
        b = tf.Variable(tf.zeros([channels_out]), name='B')
        return tf.nn.relu(tf.matmul(inp, w) + b)

In [4]:
#Use TensorBoard to visualize: this code fails to create meaningful latent variables
#(within a batch, each looks the same), and therefore also fails at reconstruction.
#Both log-likelihood error and KL divergence appear to develop "normally", yet KL drops to 0 very quickly?

with tf.name_scope('Input'):
    X = tf.placeholder(tf.float32, shape=([None, len_edge, len_edge, num_channels]))#bs
    tf.summary.image('input_images', X, max_outputs=5)

cur_input = X

Ws = []    
shapes = []

for l, n_out in enumerate(n_filters[1:]):
    cur_name = 'conv' + str(l)    
    n_input = cur_input.get_shape().as_list()[3]
    shapes.append(cur_input.get_shape().as_list())
    with tf.name_scope('conv_indecon_' + str(l)):
        W = tf.Variable(tf.random_uniform([3, 3, n_input, n_out], -1.0/math.sqrt(n_input), 1.0/math.sqrt(n_input)), name='weights')
        b = tf.Variable(tf.zeros([n_out]), name='bias')
        Ws.append(W)   
        conv = tf.nn.conv2d(cur_input, W, strides=[1, 2, 2, 1], padding='SAME')
        act = tf.nn.sigmoid(conv+b)#tf.nn.relu(conv,b)    
    cur_input = act

with tf.name_scope('Dense'):
    with tf.name_scope('Fully_Encode'):
        flattened = tf.reshape(cur_input, [-1, 7 * 7 * n_filters[-1]])# ...
        full1 = fc_layer(flattened, 7 * 7 * n_filters[-1], h_dim, 'fc1')
    with tf.name_scope('Mu'):
        W_mu = weight_variable([h_dim, latent_dim], 'W_mu')
        b_mu = bias_variable([latent_dim], 'b_mu')
        mu = tf.matmul(full1, W_mu) + b_mu
    with tf.name_scope('Logstd'):
        W_logstd = weight_variable([h_dim, latent_dim], 'W_logstd')
        b_logstd = bias_variable([latent_dim], 'b_logstd')
        logstd = tf.matmul(full1, W_logstd) + b_logstd
    with tf.name_scope('VAE_final'):
        noise = tf.random_normal([1, latent_dim])
        z = mu + tf.multiply(noise, tf.exp(.5*logstd))
    with tf.name_scope('Fully_Decode'):
        full2 = fc_layer(z, latent_dim, h_dim, 'fc2')
        full3 = fc_layer(full2, h_dim, 7 * 7 * n_filters[-1], 'fc3')
    reshaped = tf.reshape(full3, [-1, 7, 7, n_filters[-1]])
        
z_visual = tf.reshape(z, [-1, 4, 4, 1])
tf.summary.image('latents', z_visual, max_outputs=5)
tf.summary.histogram('Latent', z)

Ws.reverse()
shapes.reverse()
cur_input = reshaped

for l, shape in enumerate(shapes):
    cur_name = 'deconv' + str(l)  
    W = Ws[l]    
    with tf.name_scope('conv_indecon_' + str(len(Ws)-(l+1))):
        b = tf.Variable(tf.zeros([W.get_shape().as_list()[2]]), name='bias_dec_'+str(l))
        dec = tf.nn.conv2d_transpose(cur_input, W, tf.stack([tf.shape(X)[0], shape[1], shape[2], shape[3]]), strides=[1,2,2,1], padding='SAME')
        if l+1 < len(shapes):
            act = tf.nn.sigmoid(dec+b)
            cur_input = act

with tf.name_scope('reconst'):
    reconstruction = tf.nn.sigmoid(dec + b)

tf.summary.image('reconstructed_images', reconstruction, max_outputs=5)

<tf.Tensor 'reconstructed_images:0' shape=() dtype=string>

In [5]:
log_likelihood = tf.reduce_sum(X*tf.log(reconstruction + 1e-9)+(1 - X)*tf.log(1 - reconstruction + 1e-9), reduction_indices=1)
tf.summary.scalar('LogLike', tf.reduce_mean(log_likelihood))

KL_term = -.5*tf.reduce_sum(1 + 2*logstd - tf.pow(mu,2) - tf.exp(2*logstd), reduction_indices=1)
tf.summary.scalar('KL', tf.reduce_mean(KL_term))

variational_lower_bound = tf.reduce_mean(log_likelihood - KL_term)
tf.summary.scalar('cost', variational_lower_bound)

#optimizer = tf.train.AdadeltaOptimizer().minimize(-variational_lower_bound)
optimizer = tf.train.AdamOptimizer(1e-4).minimize(-variational_lower_bound)

In [6]:
merged_summary = tf.summary.merge_all()
init = tf.global_variables_initializer()
sess = tf.InteractiveSession()
sess.run(init)
saver = tf.train.Saver()
writer = tf.summary.FileWriter('./vae_logs/2')
writer.add_graph(sess.graph)

In [7]:
num_iterations = 1000000
recording_interval = 10
variational_lower_bound_array = []
log_likelihood_array = []
KL_term_array = []
iteration_array = [i*recording_interval for i in range(num_iterations/recording_interval)]

for i in range(num_iterations):
    mn_l = mnist.train.next_batch(bs)[0]
    x_batch = np.reshape(mn_l, [bs, 28, 28, 1])
#    x_batch = sample_a_batch(filename, bs, 64, sb=2, normalize=stats, flattened=False)
    _, s = sess.run([optimizer, merged_summary], feed_dict={X: x_batch})
    writer.add_summary(s, i)
    if (i%recording_interval == 0):
        #every 1K iterations record these values
        vlb_eval = variational_lower_bound.eval(feed_dict={X: x_batch})
        print "Iteration: {}, Loss: {}".format(i, vlb_eval)

Iteration: 0, Loss: -22.8667087555
Iteration: 10, Loss: -22.5926856995
Iteration: 20, Loss: -22.4780864716
Iteration: 30, Loss: -22.1377544403
Iteration: 40, Loss: -22.0238437653
Iteration: 50, Loss: -21.8365535736
Iteration: 60, Loss: -21.6610164642
Iteration: 70, Loss: -21.4929962158
Iteration: 80, Loss: -21.2866859436
Iteration: 90, Loss: -21.0990543365
Iteration: 100, Loss: -20.9184150696
Iteration: 110, Loss: -20.7674388885
Iteration: 120, Loss: -20.5729026794
Iteration: 130, Loss: -20.3891487122
Iteration: 140, Loss: -20.2836494446
Iteration: 150, Loss: -20.0570163727
Iteration: 160, Loss: -19.878156662
Iteration: 170, Loss: -19.8642692566
Iteration: 180, Loss: -19.6327972412
Iteration: 190, Loss: -19.5402317047
Iteration: 200, Loss: -19.3276462555
Iteration: 210, Loss: -19.1888599396
Iteration: 220, Loss: -19.0823688507
Iteration: 230, Loss: -18.9490509033
Iteration: 240, Loss: -18.807806015
Iteration: 250, Loss: -18.7744865417
Iteration: 260, Loss: -18.6289558411
Iteration: 270

KeyboardInterrupt: 

**NEXT CELLS TO BE IGNORED**

In [None]:
#filename = '/run/media/ron/silver_small/twelve_months/3d/S1A_IW_GRDH_1SDV_20160325T083601_20160325T083630_010523_00FA23_6F51.tif'
#stats = normalization_parameters(filename)

def conv_layer(inp, channels_in, channels_out, vscope, name='conv'):
    with tf.name_scope(name):
        with tf.variable_scopec(vscope):
            tf.Variable(tf.zeros([3, 3, channels_in, channels_out]), name='W')
        b = tf.Variable(tf.zeros([channels_out]), name='B')
        conv = tf.nn.conv2d(inp, w, strides=[1, 1, 1, 1], padding='SAME')
        act = tf.nn.relu(conv + b)
        tf.summary.histogram('weights', w)
        tf.summary.histogram('biases', b)
        tf.summary.histogram('activations', act)
        return tf.nn.max_pool(act, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    
def conv_t_layer(inp, channels_in, channels_out, name='conv_t'):#, activation=tf.nn.relu):
    with tf.name_scope(name):
        #w = tf.Variable(tf.zeros([3, 3, channels_out, channels_in]), name='W')
        w = conv1
        b = tf.Variable(tf.zeros([conv1.get_shape().as_list()[2]]))
        
        tf.nn.conv2d_transpose(inp, W, tf.stack([]))
        
        
        
        b = tf.Variable(tf.zeros([channels_out]), name='B')
        batch_size = 32 #tf.shape(inp)[0]
        deconv_shape = tf.stack([batch_size, inp.shape[2].value * 2, inp.shape[2].value * 2, channels_out])
        conv_t = tf.nn.conv2d_transpose(inp, w, deconv_shape, strides=[1, 2, 2, 1], padding='SAME')
        #act = activation(conv_t + b)
        act = tf.nn.relu(conv_t + b)
        tf.summary.histogram('weights', w)
        tf.summary.histogram('biases', b)
        tf.summary.histogram('activations', act)
        return act#tf.nn.max_pool(act, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    
def sample_a_batch(filebase, batchsize, tilesize=128, sb=0, normalize=False, flattened=True):
    # can sample for spatio-temporal (single_file = False), and spatial-only case (single_file = True).
    import numpy as np
    from osgeo import gdal
    
    if '.' in filebase:
        single_file = True
        S = gdal.Open(filebase)
    else:
        single_file = False
        S = gdal.Open(filebase + '_1.vrt')
        
    samples = []
    
    if single_file:
        while len(samples) < batchsize:
            RX = np.random.randint(S.RasterXSize-tilesize,size=1)
            RY = np.random.randint(S.RasterYSize-tilesize,size=1)
            
            #A = np.transpose(S.ReadAsArray(RX[0], RY[0], tilesize, tilesize))           
            
            if sb:
                B = S.GetRasterBand(sb)
                A = np.transpose(B.ReadAsArray(RX[0], RY[0], tilesize, tilesize))                
                #print np.min(A)
                if np.min(A) > 0:
                    if normalize:
                        A = A / normalize[2][sb-1]                        
                    A = np.expand_dims(A, 2)                    
                    if flattened:
                        A = A.flatten()                        
                    samples.append(A)
            else:
                A = np.transpose(S.ReadAsArray(RX[0], RY[0], tilesize, tilesize))
                if np.min(A) > 0:
                    samples.append(A)
        
    else: # must be overhauled
        while len(samples) < batchsize:
            RX = np.random.randint(S.RasterXSize-tilesize,size=1)
            RY = np.random.randint(S.RasterYSize-tilesize,size=1) 
            
            skip_loc = False
            months = []
            
            for m in range(1,13):
                S = gdal.Open(filebase + '_' + str(m) + '.vrt')
                A = np.transpose(S.ReadAsArray(RX[i], RY[i], tilesize, tilesize))
                if np.min(A) == 0.0:
                    skip_loc = True
                    break
                else:
                    months.append(A)                
            if not skip_loc:
                months = np.array(months)
                samples.append(months)
        
    return np.array(samples)

def normalization_parameters(fn):
    from osgeo import gdal
    S = gdal.Open(fn)
    mns = []
    sds = []
    maxs = []
    
    for b in range(S.RasterCount):
        B = S.GetRasterBand(b+1)
        mn, sd = B.ComputeStatistics(1)[2:4]
        mns.append(mn)
        sds.append(sd)
        maxs.append(B.GetMaximum())
        
    return mns, sds, maxs

In [None]:
with tf.name_scope('Input'):
    X = tf.placeholder(tf.float32, shape=([None, len_edge, len_edge, num_channels]))#bs
    tf.summary.image('input_images', X, max_outputs=3)

conv1 = conv_layer(X, num_channels, filters[0], 'conv1')
conv2 = conv_layer(conv1, conv1.shape[-1].value, filters[1], 'conv2')
conv3 = conv_layer(conv2, conv2.shape[-1].value, filters[2], 'conv3')
conv4 = conv_layer(conv3, conv3.shape[-1].value, filters[3], 'conv4')

#with tf.name_scope('Dense_Encode'):
flattened = tf.reshape(conv4, [-1, 2 * 2 * filters[-1]])
full1 = fc_layer(flattened, 2 * 2 * filters[-1], h_dim, 'fc1') # channels_in ?????

W_mu = weight_variable([h_dim, latent_dim], 'W_mu')
b_mu = bias_variable([latent_dim], 'b_mu')
mu = FC_layer(full1, W_mu, b_mu)
W_logstd = weight_variable([h_dim, latent_dim], 'W_logstd')
b_logstd = bias_variable([latent_dim], 'b_logstd')
logstd = FC_layer(full1, W_logstd, b_logstd)

noise = tf.random_normal([1, latent_dim])
z = mu + tf.multiply(noise, tf.exp(.5*logstd))

z_visual = tf.reshape(z, [-1, 6, 6, 1])
tf.summary.image('latents', z_visual, max_outputs=3)

tf.summary.histogram('Latent', z)

full2 = fc_layer(z, latent_dim, h_dim, 'fc2')
full3 = fc_layer(full2, h_dim, 2 * 2 * filters[-1], 'fc3')# ???????

reshaped = tf.reshape(full3, [-1, 2, 2, filters[-1]])

conv_t1 = conv_t_layer(reshaped, filters[1], filters[1], 'conv_t1')
conv_t2 = conv_t_layer(conv_t1, filters[1], filters[1], 'conv_t2')
conv_t3 = conv_t_layer(conv_t2, filters[1], filters[1], 'conv_t3')
reconstruction = conv_t_layer(conv_t3, filters[1], 1, 'conv_t4')#, tf.nn.sigmoid)

tf.summary.image('reconstructed_images', reconstruction, max_outputs=3) #............