In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras import Sequential, layers
from tensorflow.keras.layers import Layer
from tensorflow.keras.callbacks import ModelCheckpoint

# tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

### Helper Functions

In [None]:
# an implementation of PyTorch CrossMapLRN2d with Keras
class LRN2D(Layer):
    """Cross-channel local response normalization for channels-last inputs.

    For every position, each channel value is divided by

        (k + alpha * sum(x_j ** 2)) ** beta

    where the sum runs over the ``n`` channels centred on the current one
    (the window is truncated at the first/last channel).

    The formula is the one from the pylearn2 code this layer was adapted from.
    License at: https://github.com/lisa-lab/pylearn2/blob/master/LICENSE.txt
    The earlier version relied on the Theano-era ``get_output`` hook, which
    tf.keras never calls, so the normalization was never applied; it is now
    implemented in ``call`` and works on the last (channel) axis.

    NOTE(review): ``alpha`` multiplies the raw sum of squares here. PyTorch's
    LRN divides ``alpha`` by the window size, so pass ``alpha / n`` if exact
    parity with PyTorch is required.

    Args:
        alpha: Scale applied to the windowed sum of squares.
        k: Additive bias inside the power term.
        beta: Exponent applied to the denominator.
        n: Number of adjacent channels in the window; must be odd.
        **kwargs: Standard ``Layer`` keyword arguments (``name``, ``dtype``, ...).

    Raises:
        NotImplementedError: If ``n`` is even.
    """

    def __init__(self, alpha=1e-4, k=2, beta=0.75, n=5, **kwargs):
        if n % 2 == 0:
            raise NotImplementedError("LRN2D only works with odd n. n provided: " + str(n))
        super(LRN2D, self).__init__(**kwargs)
        self.alpha = alpha
        self.k = k
        self.beta = beta
        self.n = n

    def call(self, inputs):
        """Normalize a 4-D ``(batch, rows, cols, channels)`` tensor across channels."""
        # depth_radius is the half-width of the window, so n = 2 * radius + 1.
        # The op's scalar attributes are floats, hence the explicit casts.
        return tf.nn.local_response_normalization(
            inputs,
            depth_radius=self.n // 2,
            bias=float(self.k),
            alpha=float(self.alpha),
            beta=float(self.beta),
        )

    def compute_output_shape(self, input_shape):
        """The normalization is element-wise, so the shape is unchanged."""
        return input_shape

    def get_config(self):
        """Return the base layer config extended with the LRN hyper-parameters."""
        # Start from the base config so ``name``/``dtype`` round-trip through
        # ``from_config`` instead of being overwritten with the class name.
        config = super(LRN2D, self).get_config()
        config.update({"alpha": self.alpha,
                       "k": self.k,
                       "beta": self.beta,
                       "n": self.n})
        return config

    

# another implementation of Pytorch CrossMapLRN2d with Keras
class LocalResponseNormalization(Layer):
    """Response normalization built from Keras backend ops (channels-last).

    For a 4-D input ``x`` the layer computes

        x / (k + alpha * s) ** beta

    where ``s`` is obtained by squaring ``x``, average-pooling it over an
    ``n x n`` spatial window (stride 1, "same" padding) and summing the result
    over the whole channel axis.

    NOTE(review): because ``s`` sums over *all* channels of a spatially pooled
    map, this is not the windowed cross-channel LRN that PyTorch's
    CrossMapLRN2d describes — confirm this is the intended behaviour before
    using it as a drop-in replacement.

    Args:
        n: Side length of the spatial averaging window.
        alpha: Scale applied to the pooled, channel-summed squares.
        beta: Exponent applied to the denominator.
        k: Additive bias inside the power term.
        **kwargs: Standard ``Layer`` keyword arguments.
    """

    def __init__(self, n=5, alpha=1e-4, beta=0.75, k=2, **kwargs):
        # Initialise the base Layer first so attribute tracking is set up
        # before any attributes are assigned.
        super(LocalResponseNormalization, self).__init__(**kwargs)
        self.n = n
        self.alpha = alpha
        self.beta = beta
        self.k = k

    def build(self, input_shape):
        # Kept for backward compatibility; ``call`` no longer depends on it.
        self.shape = input_shape
        super(LocalResponseNormalization, self).build(input_shape)

    def call(self, x):
        squared = K.square(x)
        pooled = K.pool2d(squared, (self.n, self.n), strides=(1, 1),
                          padding="same", pool_mode='avg')
        # keepdims=True leaves a size-1 channel axis, so the division below
        # broadcasts over channels without materialising repeated copies and
        # without needing the static channel count.
        summed = K.sum(pooled, axis=3, keepdims=True)
        denom = K.pow(self.k + self.alpha * summed, self.beta)
        return x / denom

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        """Return the base layer config extended with the hyper-parameters.

        Needed so a model containing this layer can be saved and re-created.
        """
        config = super(LocalResponseNormalization, self).get_config()
        config.update({"n": self.n,
                       "alpha": self.alpha,
                       "beta": self.beta,
                       "k": self.k})
        return config

### Base Net

In [9]:
# Shared convolutional feature extractor, declared as a single layer list.
# Layers are created in the same order as before, so auto-generated layer
# names are unchanged.
base = Sequential([
    # conv1: 11x11 kernels, stride 4, on 224x224 RGB input
    layers.Conv2D(filters=96, kernel_size=11, strides=4, padding='valid',
                  activation='relu', input_shape=(224, 224, 3)),
    layers.MaxPool2D(pool_size=3, strides=2),
    LRN2D(n=5, alpha=1e-4, beta=0.75, k=1.0),

    # conv2: explicit 2-pixel zero padding followed by a 'valid' 5x5 conv
    # (groups=2 is intentionally left disabled, as in the original cell)
    layers.ZeroPadding2D(padding=2),
    layers.Conv2D(filters=256, kernel_size=5, strides=1, padding='valid',
                  activation='relu'),
    layers.MaxPool2D(pool_size=3, strides=2),
    LRN2D(n=5, alpha=1e-4, beta=0.75, k=1.0),

    # conv3: 3x3 conv with 'same' padding
    layers.Conv2D(filters=384, kernel_size=3, strides=1, padding='same',
                  activation='relu'),

    # conv4: 1x1 conv reducing the feature maps to 64 channels
    layers.Conv2D(filters=64, kernel_size=1, strides=1, padding='valid',
                  activation='relu'),
])


### Eye Net

In [None]:
# Eye branch: currently just wraps the `base` feature extractor.
eye = Sequential([base])


### Face Net

In [11]:
# Face branch: `base` conv features followed by two fully connected layers.
face = Sequential()

face.add(base)
# `base` ends in a Conv2D, so its output is a 4-D feature map. Dense only acts
# on the last axis, so without flattening it would be applied independently at
# every spatial position and the branch would still emit a 4-D tensor.
# Flatten first so the Dense layers produce one feature vector per sample.
face.add(layers.Flatten())
face.add(layers.Dense(128, activation='relu'))
face.add(layers.Dense(64, activation='relu'))


### Grid Net

In [None]:
# Grid branch: flatten the input, then two fully connected layers.
# No input shape is given, so the model is built lazily on first call.
grid = Sequential([
    layers.Flatten(),
    layers.Dense(256, activation='relu'),
    layers.Dense(128, activation='relu'),
])


### Final Net

In [None]:
# Placeholder for the combined network: no layers are added to it yet.
# NOTE(review): presumably this should join the eye, face and grid branches.
# A plain Sequential stacks layers linearly and cannot merge several inputs,
# so the Keras functional API is likely needed here — TODO confirm the design.
model = Sequential()


