In [20]:
import numpy as np
import keras
from keras import backend as K
from keras.callbacks import ModelCheckpoint
from keras.datasets import mnist
from keras.engine.topology import Layer
from keras.layers import Dense
from keras.models import Sequential

from explain import L2X, generate_model_preds

In [6]:
class Sample_Concrete(Layer):
    """
    Layer for sampling Concrete / Gumbel-Softmax variables.

    While training, it draws ``k`` relaxed one-hot samples over the ``d``
    input features and returns their element-wise maximum, a continuous
    approximation of a k-hot feature mask.  Outside training it returns a
    hard 0/1 mask with ones at the ``k`` largest logits of each row.

    Args:
        tau0: Softmax temperature; smaller values push samples towards one-hot.
        k: Number of features to select per instance.
        **kwargs: Forwarded to the Keras ``Layer`` constructor.

    Input shape:
        ``(batch_size, d)`` logits; ``d`` must be statically known.

    Output shape:
        ``(batch_size, d)`` -- identical to the input shape.
    """
    def __init__(self, tau0, k, **kwargs):
        self.tau0 = tau0
        self.k = k
        super(Sample_Concrete, self).__init__(**kwargs)

    def call(self, logits):
        # logits: [batch_size, d]
        logits_ = K.expand_dims(logits, -2)  # -> [batch_size, 1, d]

        d = int(logits_.get_shape()[2])  # d = 784 in this case
        # Use the runtime batch size instead of a global constant: the noise
        # must broadcast against logits_ even for a smaller final batch or
        # when predicting with a different batch size.
        batch = K.shape(logits_)[0]

        # K.random_uniform draws fresh noise on every forward pass. A variable
        # (K.random_uniform_variable) is sampled once and then stays fixed,
        # which would freeze the Gumbel noise for the whole training run.
        # The lower bound `tiny` keeps log(uniform) finite.
        uniform = K.random_uniform(shape=(batch, self.k, d),
                                   minval=np.finfo(np.float32).tiny,
                                   maxval=1.0)
        gumbel = -K.log(-K.log(uniform))
        noisy_logits = (gumbel + logits_) / self.tau0  # [batch, k, d]
        samples = K.softmax(noisy_logits)  # softmax over the last (d) axis
        samples = K.max(samples, axis=1)  # max over the k draws -> [batch, d]

        # Hard mask: 1 for every feature whose logit is >= the k-th largest
        # logit of its row (ties can produce more than k ones).
        logits = tf.reshape(logits, [-1, d])
        threshold = tf.expand_dims(
            tf.nn.top_k(logits, self.k, sorted=True)[0][:, -1], -1)  # [batch, 1]
        discrete_logits = tf.cast(tf.greater_equal(logits, threshold), tf.float32)

        # Relaxed samples while training, the hard mask otherwise.
        return K.in_train_phase(samples, discrete_logits)

    def compute_output_shape(self, input_shape):
        # Output is a per-feature mask with the same shape as the logits.
        return input_shape

    def get_config(self):
        """Return the layer config including ``tau0`` and ``k`` so it can be re-created."""
        config = {'tau0': self.tau0, 'k': self.k}
        base_config = super(Sample_Concrete, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

In [38]:
def create_mnist_model(train=True):
    """
    Build a small dense MNIST classifier for the binary task "digit > 4",
    optionally train it, and load the checkpointed weights.

    The network is 784 -> 100 -> 50 -> 25 -> 2 (softmax) and is fit on the
    2-class one-hot labels (``yy_train`` / ``yy_val``) returned by ``load_data``.

    Args:
        train: If True, fit the model and checkpoint the best epoch (by
            validation accuracy) to ``models/original.hdf5``. If False, that
            checkpoint is assumed to exist from a previous run.

    Returns:
        The compiled Keras model with weights loaded from the checkpoint.

    Note:
        When ``train`` is True this reads the notebook-level globals
        ``epochs`` and ``batch_size``.
    """
    import os

    filepath = os.path.join('models', 'original.hdf5')

    model = Sequential()
    model.add(Dense(100, activation='relu', input_shape=(784,)))
    model.add(Dense(50, activation='relu'))
    model.add(Dense(25, activation='relu'))
    model.add(Dense(2, activation='softmax'))

    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])

    if train:
        # load_data returns six arrays; the 2-unit softmax needs the binary
        # labels (yy_*), not the 10-class ones.
        x_train, _, x_val, _, yy_train, yy_val = load_data()

        # ModelCheckpoint does not create missing directories itself.
        checkpoint_dir = os.path.dirname(filepath)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)

        # NOTE(review): 'val_acc' is the metric name used by older Keras;
        # newer Keras versions report 'val_accuracy' instead.
        checkpoint = ModelCheckpoint(filepath, monitor='val_acc',
                                     verbose=1, save_best_only=True, mode='max')
        model.fit(x_train, yy_train,
                  validation_data=(x_val, yy_val),
                  callbacks=[checkpoint],
                  epochs=epochs,
                  batch_size=batch_size)

    # With train=False we assume a previously trained checkpoint exists.
    model.load_weights(filepath, by_name=True)

    return model

In [33]:
def load_data():
    """
    Load MNIST from the Keras dataset, flatten and rescale the images, and
    build both 10-class and binary one-hot label sets.

    Returns:
        Tuple ``(x_train, y_train, x_val, y_val, yy_train, yy_val)``:
          * ``x_*``: float32 arrays of shape ``(n, 784)`` (flattened 28x28
            images) with pixel values scaled from [0, 255] to [0, 1];
          * ``y_*``: one-hot digit labels with 10 classes;
          * ``yy_*``: one-hot binary labels with 2 classes, where class 1
            means the digit is greater than 4 and class 0 means 0-4.
    """
    (x_train, y_train), (x_val, y_val) = mnist.load_data()

    # Flatten each image; the sample count is taken from the data rather
    # than hard-coded.
    x_train = x_train.reshape(len(x_train), -1).astype('float32') / 255
    x_val = x_val.reshape(len(x_val), -1).astype('float32') / 255

    # Binary task labels (digit > 4), computed vectorised from the raw digits.
    yy_train = keras.utils.to_categorical((y_train > 4).astype('int'), 2)
    yy_val = keras.utils.to_categorical((y_val > 4).astype('int'), 2)

    y_train = keras.utils.to_categorical(y_train, 10)
    y_val = keras.utils.to_categorical(y_val, 10)

    return x_train, y_train, x_val, y_val, yy_train, yy_val

In [34]:
# Train/validation images plus two label sets: 10-class digits (y_*) and
# the binary "digit > 4" task (yy_*).
(x_train, y_train,
 x_val, y_val,
 yy_train, yy_val) = load_data()