In [None]:
import tensorflow as tf

In [None]:
# Create the TF1 session used by the explicit sess.run(...) / saver calls in
# the cells below, and register it as the session of the standalone Keras
# backend.
# NOTE(review): the layers used in this notebook are imported from
# tensorflow.python.keras, not from the standalone `keras` package - confirm
# that registering the session with this backend is still required.
sess = tf.Session()
from keras import backend as K
K.set_session(sess)

from time import sleep

import keras
from tensorflow.python.keras.layers import *
from tensorflow.python.keras.layers import ELU
from keras.losses import binary_crossentropy
from tensorflow.python.ops import array_ops
from tensorflow.python.keras.layers import Conv2D, Lambda, Dense, Multiply, Add
from tensorflow.initializers import glorot_normal, lecun_normal
from scipy.ndimage import median_filter
from skimage.transform import resize

import pandas as pd
import numpy as np
from random import shuffle
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
import os
import random
import itertools
from tensorflow.contrib.framework import arg_scope
from keras.regularizers import l1
from tensorflow.layers import batch_normalization
# Turn off TensorFlow's deprecation messages by flipping a private module-level
# flag (underscore-prefixed, so it may not exist in every TF version).
from tensorflow.python.util import deprecation as deprecation
deprecation._PRINT_DEPRECATION_WARNINGS = False

In [None]:
%run ../../src/layers/zoneout.py
%run ../../src/layers/adabound.py
%run ../../src/layers/convgru.py
%run ../../src/layers/dropblock.py
%run ../../src/layers/extra_layers.py

In [None]:
def conv_relu(inp, is_training, scope, kernel_size, filters,
              stride = (1, 1), activation = True, use_bias = True):
    '''Apply a square, unpadded 2D convolution followed by an optional ReLU.

    The convolution layer is created inside the variable scope
    ``scope + "_conv"`` with a He-uniform kernel initializer and
    ``padding = 'valid'``; callers that want to keep the spatial size have to
    pad ``inp`` themselves before calling. No normalisation or dropout is
    applied in this block.

    Args:
        inp: tensor fed to the convolution.
        is_training: accepted for interface compatibility; not used here.
        scope: prefix of the variable scope that holds the layer.
        kernel_size: side length of the square kernel.
        filters: number of output filters.
        stride: strides passed to the convolution.
        activation: when truthy, ReLU is applied to the convolution output.
        use_bias: whether the convolution layer adds a bias term.

    Returns:
        The convolution output, passed through ReLU when ``activation`` is
        truthy.
    '''
    square_kernel = (kernel_size, kernel_size)

    with tf.variable_scope(scope + "_conv"):
        conv_layer = Conv2D(filters = filters,
                            kernel_size = square_kernel,
                            strides = stride,
                            padding = 'valid',
                            activation = None,
                            use_bias = use_bias,
                            kernel_initializer = tf.keras.initializers.he_uniform())
        features = conv_layer(inp)

    # The non-linearity is created outside the variable scope, as before.
    return tf.nn.relu(features) if activation else features

class ReflectionPadding2D(Layer):
    """Reflection-pad the two middle axes of a 4D, channels-last tensor.

    Args:
        padding: pair ``(pad_axis1, pad_axis2)``. Axis 1 is padded by
            ``padding[0]`` and axis 2 by ``padding[1]`` on both sides; the
            first (batch) and last (channel) axes are never padded.
        **kwargs: forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        # Initialise the base layer before setting attributes so nothing set
        # here can be reset by the base constructor.
        # NOTE(review): the Keras base Layer initialises `input_spec` itself;
        # confirm for the TF version in use that assigning it afterwards is
        # what makes the ndim=4 check effective.
        super(ReflectionPadding2D, self).__init__(**kwargs)
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]

    def compute_output_shape(self, s):
        """Return the padded shape for a channels-last input shape ``s``."""
        return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])

    def call(self, x, mask=None):
        """Pad ``x`` with mirrored values; ``mask`` is accepted but unused."""
        # padding[0] goes to axis 1 and padding[1] to axis 2 so the tensor
        # produced here agrees with compute_output_shape. (The two were
        # swapped before, which only showed up with non-square padding.)
        pad_axis1, pad_axis2 = self.padding
        paddings = [[0, 0], [pad_axis1, pad_axis1], [pad_axis2, pad_axis2], [0, 0]]
        return tf.pad(x, paddings, 'REFLECT')

    def get_config(self):
        """Return the layer config including ``padding`` so it can be re-created."""
        config = super(ReflectionPadding2D, self).get_config()
        config['padding'] = self.padding
        return config
    
def resblock(inp, is_training, scope, filters, residual_scale = 0.10):
    '''Residual block: two reflection-padded 3x3 convolutions plus a scaled skip.

    The input is reflection-padded and convolved with ReLU, then padded and
    convolved again without activation. The result is multiplied by
    ``residual_scale`` and added to the unmodified input.

    Args:
        inp: input tensor; it is also the skip branch of the addition, so it
            must be addable to a tensor with ``filters`` channels.
        is_training: forwarded to ``conv_relu``.
        scope: scope prefix; the two convolutions use ``scope + "1"`` and
            ``scope + "2"``.
        filters: number of filters of both convolutions.
        residual_scale: Python scalar that scales the convolution branch
            before the addition. Defaults to 0.10, the value that used to be
            hard-coded.

    Returns:
        ``inp + residual_scale * conv_branch(inp)``.
    '''
    padded_inp = ReflectionPadding2D()(inp)
    conv = conv_relu(padded_inp, is_training, scope + "1", 3, filters, activation = True, use_bias = True)
    padded_conv = ReflectionPadding2D()(conv)
    conv2 = conv_relu(padded_conv, is_training, scope + "2", 3, filters, activation = False, use_bias = True)
    # Down-weight the convolution branch before merging it with the skip path.
    conv2 = tf.multiply(conv2, tf.constant(residual_scale))
    return tf.add(inp, conv2)
    
    

In [None]:
# Graph inputs. The batch and both spatial dimensions are left dynamic (None);
# only the channel counts are fixed.
# Network input with 10 channels.
inp = tf.placeholder(tf.float32, shape=(None, None, None, 10))
# 6-channel tensor that the model cell adds to the network output (skip
# connection). Presumably a bilinearly upsampled version of the target
# channels - TODO confirm against the data pipeline.
bilinear_input = tf.placeholder(tf.float32, shape = (None, None, None, 6))
# 6-channel target that the loss compares against the prediction.
labels =  tf.placeholder(tf.float32, shape = (None, None, None, 6))
# Scalar flag that defaults to False when not fed; it is passed to
# conv_relu / resblock when the graph is built.
is_training = tf.placeholder_with_default(False, (), 'is_training')

In [None]:
# Network size. Only these scalars are used to build the graph. The lists
# [2, 3, 4, 5, 6] (depth) and [32, 48, 64, 80, 96] (width) used to be assigned
# to the same names and were immediately overwritten; presumably they are the
# candidate values that were explored for this model.
depth = 2
width = 32

# Stem: pad by one pixel so the unpadded 3x3 convolution keeps the spatial size.
inp_pad = ReflectionPadding2D()(inp)
conv = conv_relu(inp_pad, is_training, "in", 3, width, activation = True)

# Body: `depth` residual blocks; each block's scopes are prefixed by its index.
for d in range(depth):
    conv = resblock(conv, is_training, str(d), width)
    print(d, conv.shape)

# Head: padded 3x3 convolution down to 6 channels, squashed with tanh.
conv = ReflectionPadding2D()(conv)
outconv = conv_relu(conv, is_training, "out", 3, 6, activation = False)
outconv = tf.nn.tanh(outconv)
# The network output is a correction that is added to bilinear_input.
# NOTE(review): the freezing cell selects its output by the auto-generated op
# name 'Add_2', which appears to be this op (after one tf.add per residual
# block). Changing `depth` or adding tf.add calls above shifts that name.
skipconnect = tf.add(bilinear_input, outconv)

In [None]:
# Adam optimiser with a learning rate of 5e-4.
optimizer = tf.train.AdamOptimizer(5e-4)
# Mean absolute error between the target and the skip-connected prediction.
# NOTE(review): tf.keras.losses.MAE is documented to average over the last
# axis only, so loss_fn is presumably a per-pixel tensor rather than a scalar -
# confirm this is the intended quantity to minimise.
loss_fn = tf.keras.losses.MAE(labels, skipconnect)

# Ops registered in the UPDATE_OPS collection; the control dependency below
# makes them run before every training step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss_fn)   
        

# Initialise every variable created so far, then build the Saver so that it
# covers all of them; up to 150 checkpoints are kept.
init_op = tf.global_variables_initializer()
sess.run(init_op)
saver = tf.train.Saver(max_to_keep = 150)
# Informational message only; this cell does not call Graph.finalize().
print("The graph has been finalized")

In [None]:
# Reload the newest checkpoint from `path` into the notebook session and write
# it back under the fixed name "model". The freezing cell reads "model.meta"
# from this same directory.
path = "../../models/supres/nov-40k-swir/"

latest_ckpt = tf.train.latest_checkpoint(path)
if latest_ckpt is None:
    # latest_checkpoint returns None when the directory has no checkpoint;
    # fail here with the directory in the message instead of inside restore().
    raise ValueError(f"No checkpoint found in {path}")

saver.restore(sess, latest_ckpt)
save_path = saver.save(sess, path + "model")

In [None]:
# Directory that holds model.meta and the checkpoint, and that receives the
# frozen graph.
meta_path = '../../models/supres/nov-40k-swir/'
# Ops to keep as outputs of the frozen graph. To list the available names,
# evaluate [n.name for n in freeze_graph.as_graph_def().node] inside the
# session block below.
output_node_names = ['Add_2']

# Freeze inside a fresh graph and a separately named session:
# - importing the meta graph into the notebook's default graph would mix the
#   imported nodes with the model built above, so 'Add_2' would no longer
#   unambiguously name the restored op;
# - `with tf.Session() as sess` would rebind the notebook-wide `sess` and
#   leave it closed once the block exits.
with tf.Graph().as_default() as freeze_graph:
    with tf.Session(graph = freeze_graph) as freeze_sess:
        # Restore the graph structure
        freeze_saver = tf.train.import_meta_graph(meta_path + "model.meta")

        # Load the weights of the newest checkpoint
        freeze_ckpt = tf.train.latest_checkpoint(meta_path)
        if freeze_ckpt is None:
            raise ValueError(f"No checkpoint found in {meta_path}")
        freeze_saver.restore(freeze_sess, freeze_ckpt)

        # Replace variables by constants, keeping only what the outputs need
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            freeze_sess,
            freeze_sess.graph_def,
            output_node_names)

# Save the frozen graph as a serialized GraphDef
with open(meta_path + 'superresolve_graph.pb', 'wb') as f:
    f.write(frozen_graph_def.SerializeToString())