In [None]:
import os
from pprint import pprint

import numpy as np
import skimage.color
import skimage.io
import tensorflow as tf

In [None]:
#Read CIFAR images

import read_cifar10 as cf10

# NOTE: this is a plain generator function (the restartable wrapper is disabled),
# so each generator object in `cifar10_dataset_generators` can be iterated once.
def cifar10_dataset_generator(dataset_name, batch_size, restrict_size=1000):
    '''
    Yield (images, labels) batches from CIFAR-10 as uint8 arrays.

    Inputs:
    dataset_name (str) : 'train' or 'test'
    batch_size (int)   : Batch size, or -1 to yield the whole (restricted) set as one batch
    restrict_size (int): Number of training samples to keep (ignored for 'test')

    Yields:
    X_batch (np.uint8 array): batch_size images multiplied by 255
    y_batch (np.uint8 array): the matching labels, flattened

    The last batch is completed with samples from the start of the dataset, so
    every batch has exactly batch_size elements. Nothing is yielded for an
    empty dataset.
    '''
    assert dataset_name in ['train', 'test']
    assert batch_size > 0 or batch_size == -1  # -1 for entire dataset

    # Index instead of unpacking, so loaders returning extra items also work.
    data = cf10.load_training_data() if dataset_name == 'train' else cf10.load_test_data()
    X_all_unrestricted, y_all_unrestricted = data[0], data[1]

    # Only the training set is truncated; None keeps the whole test set.
    actual_restrict_size = restrict_size if dataset_name == 'train' else None
    # Truncate images AND labels together: the padding below indexes both with
    # the same offsets, so they must have the same length.
    X_all = X_all_unrestricted[:actual_restrict_size]
    y_all = y_all_unrestricted[:actual_restrict_size]
    data_len = X_all.shape[0]
    if data_len == 0:
        return
    batch_size = batch_size if batch_size > 0 else data_len

    # Wrap around: append the first batch_size samples so the final slice is full.
    X_all_padded = np.concatenate([X_all, X_all[:batch_size]], axis=0)
    y_all_padded = np.concatenate([y_all, y_all[:batch_size]], axis=0)

    for idx in range(0, data_len, batch_size):
        # Presumably the loader returns images scaled to [0, 1] -- TODO confirm.
        X_batch = X_all_padded[idx:idx + batch_size] * 255  # bugfix: thanks Zezhou Sun!
        y_batch = np.ravel(y_all_padded[idx:idx + batch_size])
        yield X_batch.astype(np.uint8), y_batch.astype(np.uint8)


cifar10_dataset_generators = {
    'train': cifar10_dataset_generator('train', 1000),
    'test': cifar10_dataset_generator('test', -1)
}


In [None]:
# Load the CIFAR-10 train and test splits; item 0 of each result is kept as
# the images (*_img) and item 1 as the labels (*_label).
cf10_tr = cf10.load_training_data()
cf10_tr_img, cf10_tr_label = cf10_tr[0], cf10_tr[1]

cf10_test = cf10.load_test_data()
cf10_test_img, cf10_test_label = cf10_test[0], cf10_test[1]

In [None]:


import skimage.io
def img2block(im, patch_size=32):
    '''
    Image patching code. Cuts an image of shape [rows, cols, channels] (a 2D
    grayscale image is treated as one channel) into non-overlapping
    patch_size x patch_size blocks and returns a float32 array of shape
    [number_of_patches, patch_size, patch_size, channels], scaled by 1/255.

    Blocks are ordered row-major (left to right, then top to bottom). Border
    rows/columns that do not fill a complete block are dropped, so
    number_of_patches == (rows // patch_size) * (cols // patch_size).
    '''
    im = np.asarray(im, dtype=np.float32)
    if im.ndim == 2:
        im = im[:, :, np.newaxis]
    row, col, color = im.shape
    n_rows, n_cols = row // patch_size, col // patch_size
    # Crop to whole blocks, split each spatial axis into (block index, offset),
    # then move both block indices to the front and flatten them.
    cropped = im[:n_rows * patch_size, :n_cols * patch_size, :]
    im_bl = (cropped
             .reshape(n_rows, patch_size, n_cols, patch_size, color)
             .transpose(0, 2, 1, 3, 4)
             .reshape(n_rows * n_cols, patch_size, patch_size, color))
    return im_bl / 255.

def block2img(img_blocks, img_size):
    '''
    Function for reconstructing the image back from patches (inverse of img2block).

    Inputs:
    img_blocks : array of shape [n, k, l, c] in img2block's row-major block order
    img_size   : (rows, cols) of the image to rebuild

    Returns a float32 array of shape (rows, cols, c). Pixels in the partial
    border that img2block drops are left at zero; extra trailing blocks are ignored.

    Raises ValueError if there are fewer blocks than (rows // k) * (cols // l).
    '''
    row, col = img_size
    img_blocks = np.asarray(img_blocks)
    n, k, l, c = img_blocks.shape
    n_rows, n_cols = row // k, col // l
    if n < n_rows * n_cols:
        raise ValueError('expected at least %d blocks for image size %s, got %d'
                         % (n_rows * n_cols, (row, col), n))
    img = np.zeros((row, col, c), dtype=np.float32)
    # Unflatten the block index into (block row, block col) and interleave it
    # with the in-block offsets to get back pixel coordinates.
    img[:n_rows * k, :n_cols * l, :] = (img_blocks[:n_rows * n_cols]
                                         .reshape(n_rows, n_cols, k, l, c)
                                         .transpose(0, 2, 1, 3, 4)
                                         .reshape(n_rows * k, n_cols * l, c))
    return img

# Read the Lena test image and split it into 32x32 patches with img2block.
lena_img = skimage.io.imread(
    '../test_img/lena512color.tiff'
)
lena_32 = img2block(im=lena_img)

        
        

In [None]:
#Function for converting a floating-point image to uint8
def convert2uint8(img):
    '''
    Convert a floating-point image to uint8.

    Values are clipped to [0, 255] and then truncated by the cast
    (e.g. 254.9 -> 254). The input array is left unmodified.
    NaN values are not handled.
    '''
    return np.clip(img, 0, 255).astype(np.uint8)

# Part 2 - CNN autoencoder

In [None]:
# Network inputs: float32 copies of the CIFAR-10 images, used at their loaded scale.
x_tr = cf10_tr_img.astype(dtype=np.float32)
# Only the first 200 test images are kept for the test-set MSE during training.
x_test = cf10_test_img.astype(dtype=np.float32)[:200, :, :, :]
# 32x32 patches of the Lena image, reconstructed periodically by train_model.
img = skimage.io.imread('../test_img/lena512color.tiff')
img_32 = img2block(im=img)

In [None]:
def _decoder_conv(inputs, filters, kernel_size, name):
    '''
    Decoder convolution: square kernel_size kernel, "same" padding, ReLU.
    Variables are created or reused according to the enclosing variable scope.
    '''
    return tf.layers.conv2d(inputs=inputs,
                            filters=filters,
                            kernel_size=[kernel_size, kernel_size],
                            padding="same",
                            activation=tf.nn.relu,
                            name=name)


def _upsample(t, factor):
    '''
    Resize dims 1 and 2 (height, width) of tensor t by `factor` with
    tf.image.resize_images (bilinear by default); returns t unchanged when factor == 1.
    '''
    if factor == 1:
        return t
    sh = t.get_shape().as_list()
    return tf.image.resize_images(t, [sh[1] * factor, sh[2] * factor])


def cnn_autoencoder(x_,kernels1=[5,7],kernels2=[7,5],filters1=[16,128],filters2=[128,3],pool_size=[1,2,2,1]):
    '''
    Autoencoder network

    Inputs:
    x_ (tf.Tensor)      : Input tensor (e.g. a tf.placeholder of images)
    kernels1 (1D array) : Size of the encoder kernels (assumed square kernels)
    kernels2 (1D array) : Size of the decoder kernels (assumed square kernels)
    filters1 (1D array) : Number of filters in encoder layers
    filters2 (1D array) : Number of filters in decoder layers
    pool_size (1D array): Pooling size in each layer. Its length must be equal to len(kernels1)+len(kernels2).
                          The first len(kernels1) terms are the max-pooling sizes of the encoder;
                          the remaining terms are the (bilinear) upsampling factors of the decoder.

    Returns:
    out_ (tf.Tensor)     : Output of the autoencoder without quantization in the middle
    out_quant (tf.Tensor): Output of the autoencoder with quantization in the middle.
                           Both decoder paths share the same weights.

    Raises:
    ValueError: if filters1/filters2/pool_size are too short for the given kernels.
    '''
    n_enc, n_dec = len(kernels1), len(kernels2)
    if len(filters1) < n_enc or len(filters2) < n_dec or len(pool_size) < n_enc + n_dec:
        raise ValueError('need len(filters1) >= len(kernels1), len(filters2) >= len(kernels2) '
                         'and len(pool_size) >= len(kernels1) + len(kernels2)')
    out_ = x_

    # Encoder: conv + ReLU, optionally followed by max pooling.
    for k in range(n_enc):
        conv = tf.layers.conv2d(inputs=out_,
                                filters=filters1[k],
                                kernel_size=[kernels1[k], kernels1[k]],
                                padding="same",
                                activation=tf.nn.relu,
                                name='conv' + str(k))
        pool_now = pool_size[k]
        if pool_now == 1:
            out_ = conv
        else:
            out_ = tf.layers.max_pooling2d(inputs=conv,
                                           pool_size=[pool_now, pool_now],
                                           strides=pool_now,
                                           name='pool' + str(k))

    # Quantization of the code: round activations to multiples of 1/256.
    out_quant = tf.round(out_ * 256.) / 256.

    # Decoder: the first conv in each scope creates the weights; after
    # reuse_variables() the quantized path reuses them.
    for k in range(n_dec):
        with tf.variable_scope("deconv") as var_scope:
            pool_now = pool_size[k + n_enc]
            layer_name = 'deconv' + str(k)
            out_ = _decoder_conv(_upsample(out_, pool_now), filters2[k], kernels2[k], layer_name)
            var_scope.reuse_variables()
            out_quant = _decoder_conv(_upsample(out_quant, pool_now), filters2[k], kernels2[k], layer_name)

    return out_, out_quant


In [None]:
def apply_classification_loss_mse(kernels1=[5,7],kernels2=[7,5],
                                 filters1=[16,128],filters2=[128,3],
                                pool_size=[1,2,2,1],learning_rate=1.,FT=False,
                                device="/gpu:0"):
    '''
    MSE based autoencoder optimizer (builds a fresh graph with cnn_autoencoder).

    Inputs:
    kernels1 (1D array) : Size of the encoder kernels (assumed square kernels)
    kernels2 (1D array) : Size of the decoder kernels (assumed square kernels)
    filters1 (1D array) : Number of filters in encoder layers
    filters2 (1D array) : Number of filters in decoder layers
    pool_size (1D array): Pooling size in each layer. Its length must be equal to len(kernels1)+len(kernels2).
                          The first len(kernels1) terms are used as pooling layers of the encoder;
                          the remaining terms are used as unpooling layers of the decoder.
    learning_rate(float): Learning rate of the Adam optimizer
    FT (boolean)        : If True, only the decoder ('deconv' scope) weights are trained (fine-tuning)
    device (str or None): Device the graph is pinned to (default "/gpu:0"); None leaves placement to TensorFlow

    Returns:
    model_dict          : Dictionary with keys
                          'graph', 'inputs' (32x32x3 image placeholder),
                          'outputs' (reconstruction without quantization),
                          'outputs_quant' (reconstruction with the quantized code),
                          'train_op' (minimizes 'loss1'),
                          'loss1' (MSE without quantization), 'loss2' (MSE with quantization)
    '''
    with tf.Graph().as_default() as g:
        with tf.device(device):
            x_ = tf.placeholder(tf.float32, [None, 32, 32, 3])
            x_out, x_out_quant = cnn_autoencoder(x_, pool_size=pool_size, kernels1=kernels1, filters1=filters1,
                                                 kernels2=kernels2, filters2=filters2)

            mse_loss1 = tf.reduce_mean(tf.square(x_ - x_out))
            mse_loss2 = tf.reduce_mean(tf.square(x_ - x_out_quant))

            trainer = tf.train.AdamOptimizer(learning_rate=learning_rate)
            if FT:
                # Decoder variables live under the "deconv" variable scope; the trailing
                # slash avoids matching unrelated names that merely start with "deconv".
                var_list = [v for v in tf.global_variables() if v.name.startswith('deconv/')]
                train_op = trainer.minimize(mse_loss1, var_list=var_list)
            else:
                train_op = trainer.minimize(mse_loss1)

    model_dict = {'graph': g, 'inputs': x_, 'outputs': x_out, 'outputs_quant': x_out_quant,
                  'train_op': train_op, 'loss1': mse_loss1, 'loss2': mse_loss2}

    return model_dict

In [None]:
def train_model(model_dict,x_32=img_32, train_every=100, test_every=200, load=False,
                learning_rate=1.,fname='cifar10_recon',outname='/tmp/cnn_autoencoder',ftname='/tmp/cnn_autoencoder',
                n_iters=20001, batch_size=100, img_size=(512, 512)):
    '''
    Train (or fine-tune) the autoencoder in model_dict, periodically log the MSE
    and write reconstructions, then save the weights.

    Training batches are drawn cyclically from the global `x_tr`; the test MSE is
    computed on the global `x_test`.

    Inputs:
    model_dict   : Output of apply_classification_loss_mse
    x_32         : 32x32 patches of a big image, reconstructed at every test step.
                   An empty value (e.g. []) falls back to the global img_32.
    train_every  : Print the training-batch MSE every this many iterations
    test_every   : Evaluate on x_test and write reconstructions every this many iterations
    load         : Boolean for loading the weights from pre-trained network
    learning_rate: Unused (the learning rate is fixed when the graph is built); kept for compatibility
    fname        : Output directory name; images go to ../<fname>, created if missing
    outname      : Checkpoint path to save (load=False) or load (load=True) weights
    ftname       : Checkpoint path to save the new weights to when load=True
    n_iters      : Number of training iterations
    batch_size   : Training batch size
    img_size     : (rows, cols) of the image that x_32 was cut from

    Returns:
    save_path (str): Path the checkpoint was written to
    '''
    if x_32 is None or len(x_32) == 0:
        x_32 = img_32
    out_dir = os.path.join('..', fname)
    os.makedirs(out_dir, exist_ok=True)

    # Soft placement lets graphs pinned to /gpu:0 run on CPU-only machines.
    config = tf.ConfigProto(allow_soft_placement=True)
    with model_dict['graph'].as_default(), tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        if load:
            saver.restore(sess, outname)
            print("Model loaded")

        inputs = model_dict['inputs']
        n_train = x_tr.shape[0]
        start = 0
        for iter_i in range(n_iters):
            # Cycle through x_tr in consecutive batches, wrapping at the end.
            ids = (start + np.arange(batch_size)) % n_train
            start = (start + batch_size) % n_train
            batch_xs = x_tr[ids, :, :, :]
            sess.run(model_dict['train_op'], feed_dict={inputs: batch_xs})

            if iter_i % train_every == 0:
                loss_val = sess.run(model_dict['loss1'], feed_dict={inputs: batch_xs})
                print('iteration %d\t train mse: %.3E\t'%(iter_i,loss_val))

            # Evaluated independently of train_every, so it also runs when
            # test_every is not a multiple of train_every.
            if iter_i % test_every == 0:
                loss_val1, loss_val2 = sess.run([model_dict['loss1'], model_dict['loss2']],
                                                feed_dict={inputs: x_test})
                print('iteration %d\t TEST MSE: %.3E\t TEST MSE(Quantized): %.3E\t'%(iter_i,loss_val1,loss_val2))

                img_block = sess.run(model_dict['outputs'], feed_dict={inputs: x_32})
                x_from_test = sess.run(model_dict['outputs'], feed_dict={inputs: x_test[:5]})
                step = iter_i // test_every

                img_recon = convert2uint8(block2img(img_block, img_size) * 255.)
                skimage.io.imsave(os.path.join(out_dir, 'img32_recon_' + str(step) + '.tiff'), img_recon)

                for i, recon in enumerate(x_from_test):
                    skimage.io.imsave(os.path.join(out_dir, 'test%d_%d.tiff' % (i, step)),
                                      convert2uint8(255 * recon))

        save_path = saver.save(sess, ftname if load else outname)
        print("Model saved in file: %s" % save_path)
    return save_path
                


## CNN-AE with the best params

In [None]:
# Train the CNN autoencoder (all layers) with the best hyper-parameters and save
# it to /tmp/cnnx4_test0.
# YOU NEED TO CREATE A FOLDER NAMED 'cifar10_recon0' BEFORE RUNNING THAT CODE
tf.reset_default_graph()
model_dict = apply_classification_loss_mse(
    kernels1=[5, 7, 9, 9], filters1=[128, 64, 16, 4],
    kernels2=[9, 7, 7, 5], filters2=[8, 8, 3, 3],
    pool_size=[1, 2, 2, 1, 1, 2, 2, 1],
    learning_rate=7e-5,
)
saver = train_model(
    model_dict,
    x_32=[],
    train_every=100,
    test_every=2000,
    load=False,
    fname='cifar10_recon0',
    outname='/tmp/cnnx4_test0',
)

## CNN-AE-FT with the best params: train the base model, then fine-tune its decoder on the target image

In [None]:
# Train the base model for fine-tuning; its checkpoint /tmp/cnnx4_test1 is
# loaded by the fine-tuning cell.
# YOU NEED TO CREATE A FOLDER NAMED 'cifar10_recon1' BEFORE RUNNING THAT CODE
tf.reset_default_graph()
model_dict = apply_classification_loss_mse(
    kernels1=[5, 7, 9, 9], filters1=[128, 64, 16, 4],
    kernels2=[9, 7, 7, 5], filters2=[8, 8, 3, 3],
    pool_size=[1, 2, 2, 1, 1, 2, 2, 1],
    learning_rate=7e-5,
)
saver = train_model(
    model_dict,
    x_32=[],
    train_every=100,
    test_every=2000,
    load=False,
    fname='cifar10_recon1',
    outname='/tmp/cnnx4_test1',
)

In [None]:
#Code for fine-tuning
#YOU NEED TO CREATE A FOLDER NAMED 'cifar10_recon2' BEFORE RUNNING THAT CODE
# Fine-tune only the decoder of the model saved in /tmp/cnnx4_test1 on
# shuffled 32x32 patches of the Lena image.
img = skimage.io.imread('../test_img/lena512color.tiff')
img_32 = img2block(img)

# train_model draws its batches from the global x_tr, so temporarily swap in the
# shuffled Lena patches and always restore the CIFAR-10 training set afterwards.
x_tr_cifar = x_tr
x_tr = img_32[np.random.permutation(img_32.shape[0])]
try:
    tf.reset_default_graph()
    # The architecture must match the one /tmp/cnnx4_test1 was trained with
    # (the CNN-AE-FT base cell), otherwise saver.restore fails on shape mismatches.
    model_dict = apply_classification_loss_mse(kernels1=[5,7,9,9],kernels2=[9,7,7,5],
                                               filters1=[128,64,16,4],filters2=[8,8,3,3],
                                               pool_size=[1,2,2,1,1,2,2,1],learning_rate=1e-6,FT=True)
    saver = train_model(model_dict, [], train_every=100, test_every=1000,load=True,
                        fname='cifar10_recon2',outname='/tmp/cnnx4_test1',ftname='/tmp/cnnx4_test1_lion')
finally:
    x_tr = x_tr_cifar

In [None]:
# Apply the pre-trained network to another image: cut it into 32x32 patches, run
# them through the autoencoder in chunks, and stitch the output back together.
# NOTE(review): this reuses `model_dict` from the previous cell and restores the
# checkpoint '/tmp/cnnx4_test0' into it; that only works if both were built with
# the same architecture -- confirm this pairing is intended.

tfsave = '/tmp/cnnx4_test0'
imgpath = '../test_img/lion.tiff'
outpath = '../test_img/lion_recon2_convrealFTpx8.tiff'
chunk_size = 2000  # patches per sess.run call

# Soft placement lets graphs pinned to /gpu:0 run on CPU-only machines.
config = tf.ConfigProto(allow_soft_placement=True)
with model_dict['graph'].as_default(), tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver()
    saver.restore(sess, tfsave)
    print("Model loaded")

    img = skimage.io.imread(imgpath)
    img_rows, img_cols = img.shape[:2]
    img_32 = img2block(img)
    img_block = np.zeros(img_32.shape, dtype=np.float32)

    # Step through the patches in fixed-size chunks; the last chunk may be
    # shorter, and an empty batch is never fed to the network.
    for chunk_i, start in enumerate(range(0, img_32.shape[0], chunk_size)):
        print(str(chunk_i + 1) + 'th slice')
        stop = start + chunk_size
        img_block[start:stop] = sess.run(model_dict['outputs'],
                                         feed_dict={model_dict['inputs']: img_32[start:stop]})

    img_recon = block2img(img_block, (img_rows, img_cols))
    img_recon = convert2uint8(img_recon * 255.)
    skimage.io.imsave(outpath, img_recon)
