In [24]:
import tensorflow as tf
import data_gen
import os
import numpy as np
import h5py
import random

In [25]:
def conv2d(x,W):
    """Convolve a batch of images with a filter bank, keeping the spatial size.

    x: 4-D input tensor; with tf.nn.conv2d's default NHWC layout this is
       [batch, height, width, channels].
    W: 4-D filter tensor [filter_height, filter_width, channels_in, channels_out].

    Returns the result of tf.nn.conv2d with a stride of 1 along every axis and
    'SAME' padding.
    """
    unit_strides = [1, 1, 1, 1]
    convolved = tf.nn.conv2d(x, W, strides=unit_strides, padding='SAME')
    return convolved

def relu(x):
    """Apply the rectified-linear activation (tf.nn.relu) to the feature tensor x."""
    activated = tf.nn.relu(x)
    return activated

In [26]:
# Side length of the (square) input images, in pixels; the same value is used
# for both spatial dimensions of the model's input placeholders.
image_size = 128
# Number of channels per image (3 presumably means RGB -- TODO confirm against the dataset).
color_channels = 3

In [37]:
class SRCNN:
    """Super-resolution CNN (SRCNN) built with the TensorFlow 1.x graph API.

    The default layout follows https://arxiv.org/abs/1501.00092: three
    convolutions with 9x9, 1x1 and 5x5 kernels mapping
    color_channels -> 64 -> 32 -> color_channels feature maps.

    Usage: construct the object, call initialize() once to build the graph
    (variables, placeholders, prediction and loss), then call train().
    """

    def __init__(self,
                 sess,
                 image_size,
                 color_channels,
                 num_convolutions = 3,
                 filters=[[9,9,3,64],[1,1,64,32],[5,5,32,3]],
                 bias_shapes=[[64],[32],[3]],
                 train_data_dir="D:/TestData/train/"):
        """Store the network configuration; no TensorFlow ops are created here.

        sess:             session object used by train() to run ops.
        image_size:       height and width of the square input/label images.
        color_channels:   number of channels of the input/label images.
        num_convolutions: number of convolutional layers; must equal len(filters).
        filters:          one [filter_height, filter_width, channels_in, channels_out]
                          shape per layer.
        bias_shapes:      one bias shape per layer; must have as many entries as filters.
        train_data_dir:   directory that train() lists to find its training file.

        Raises AssertionError when the layer counts are inconsistent.
        """
        self.sess = sess
        self.image_size = image_size
        self.color_channels = color_channels
        assert num_convolutions == len(filters), "Number of convs does not match number of filters"
        assert len(filters) == len(bias_shapes), "Number of filters does not match number of biases"
        self.num_convolutions = num_convolutions
        # Copy the shapes so instances never alias the shared default lists
        # (or a list the caller keeps mutating).
        self.filters = [list(shape) for shape in filters]
        self.bias_shapes = [list(shape) for shape in bias_shapes]
        self.train_data_dir = train_data_dir

    def test(self):
        """Smoke test: print a greeting."""
        print("Hello world")

    def init_bias_zero(self,shape,name):
        """Return a tf.Variable called `name`, initialised to zeros of `shape`."""
        return tf.Variable(tf.zeros(shape),name=name)

    def init_weights_normal(self,shape,name,stddev=0.0001):
        """Return a tf.Variable called `name`, drawn from tf.random_normal(shape, stddev=stddev)."""
        return tf.Variable(tf.random_normal(shape,stddev=stddev),name=name)

    def initialize(self):
        """Build the graph: variables, placeholders, prediction and MSE loss.

        Creates self.weights, self.biases, self.images, self.labels,
        self.prediction and self.loss. Must be called before train().
        """
        # from https://arxiv.org/abs/1501.00092 there are filter sizes of 9-1-5 in 3 different convolutions
        # weights shapes: [9,9,1?3,64] --> [1,1,64,32] --> [5,5,32,1?3]
        # biases shaped [64] --> [32] --> [1?3]
        # Best results when trained on Y-channel (c=1) only or RGB jointly (c=3)

        # All weights are created first, then all biases; variables are named w0.. and b0..
        self.weights = [self.init_weights_normal(self.filters[i], "w" + str(i))
                        for i in range(self.num_convolutions)]
        self.biases = [self.init_bias_zero(self.bias_shapes[j], "b" + str(j))
                       for j in range(self.num_convolutions)]

        # Placeholders for images (input) and labels (target output). Their shape comes
        # from this instance's own configuration, not from notebook-level globals, so
        # the constructor arguments are honoured and the class works in a fresh kernel.
        batch_shape = [None, self.image_size, self.image_size, self.color_channels]
        self.images = tf.placeholder('float32', batch_shape, name='images')
        self.labels = tf.placeholder('float32', batch_shape, name='labels')

        self.prediction = self.model(self.num_convolutions)

        # mse loss (TODO: implement SSIM)
        self.loss = tf.reduce_mean(tf.square(self.labels - self.prediction))

    def model(self,conv):
        """Stack the convolutional layers on self.images and return the output tensor.

        Every layer computes conv2d(input, weight) + bias. All layers except the
        last are followed by a ReLU; the last layer is left linear.

        conv: kept for interface compatibility and not used; every layer in
              self.weights / self.biases is always applied.
        """
        final_index = len(self.weights) - 1
        layer = self.images
        for index, (kernel, bias) in enumerate(zip(self.weights, self.biases)):
            # The bias belongs on the convolution OUTPUT; adding it to the kernel
            # (conv2d(x, W + b)) would merely shift the filter taps.
            layer = conv2d(layer, kernel) + bias
            if index != final_index:
                layer = relu(layer)
        return layer

    def train(self, batch_size=64, epochs=500):
        """Train with plain gradient descent on the first training file.

        batch_size: number of samples per step; a trailing partial batch is dropped.
        epochs:     number of passes over the loaded file.

        Requires initialize() to have been called (it creates self.loss and the
        placeholders). Note that every call re-runs the global variable
        initializer, so previously trained weights are reset.
        """
        # os.listdir returns entries in arbitrary order; sort so the same file
        # is selected on every run and platform.
        train_sets = sorted(os.listdir(self.train_data_dir))
        #TODO: reprocess data, remember to MAKE BACKUP
        current_train_set = train_sets[0]
        # os.path.join also works when train_data_dir has no trailing separator.
        images, labels = data_gen.read_data(os.path.join(self.train_data_dir, current_train_set))
        # see optimizers at https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/train
        self.train_op = tf.train.GradientDescentOptimizer(learning_rate=0.00001).minimize(self.loss)

        init = tf.global_variables_initializer()

        self.sess.run(init)

        steps = len(images) // batch_size

        # each epoch is a single iteration over (almost) the entire dataset batch
        for epoch in range(epochs):
            for step in range(steps):
                start = step * batch_size
                stop = start + batch_size
                feed = {self.images: images[start:stop], self.labels: labels[start:stop]}

                self.sess.run(self.train_op, feed_dict=feed)

                if step % 5 == 0:
                    # Loss is evaluated on the same batch, after the update step.
                    loss = self.sess.run(self.loss, feed_dict=feed)
                    print("Epoch: "+str(epoch)+", Step: "+str(step)+", Loss: "+str(loss))
    


In [28]:
#train_sets = os.listdir("D:/TestData/train/")

In [29]:
#train_sets

['train_00.h5',
 'train_01.h5',
 'train_02.h5',
 'train_03.h5',
 'train_04.h5',
 'train_05.h5',
 'train_06.h5',
 'train_07.h5']

In [36]:
#print("D:/TestData/train/"+train_sets[0])
#with h5py.File("D:/TestData/train/"+train_sets[2], 'r') as file:
#    print(file.keys())

D:/TestData/train/train_00.h5
<KeysViewHDF5 ['images', 'labels']>
