In [None]:
# ---- NN options ----
n_ensemble = 10   # number of NNs in the ensemble
reg = 'anc'       # regularisation mode: 'anc' (anchored), 'reg' (regularised), 'free' (unconstrained)
n_hidden = 512    # number of hidden units in the NN
#activation_in = 'relu' # tanh relu sigmoid

# ---- optimisation options ----
epochs = 5        # training epochs; author's note: running 'reg' for 15+ epochs seems to mess them up
l_rate = 0.005    # learning rate

# ---- data options ----
n_data = len(fileList_train)  # number of training data points (fileList_train must already be defined)
n_classes = 1     # number of classes being predicted
n_Xdim = 512      # number of features for X
seed_in = 0       # random seed used to produce data blobs - change it to see how results look with different data

# ---- variance of the priors ----
W1_var = 15 / n_Xdim       # 1st layer weights and biases
W_mid_var = 1 / n_hidden   # 2nd layer weights and biases
W_last_var = 5 / n_hidden  # 3rd layer weights

In [None]:
from keras.backend import sigmoid
def swish(x, beta = 1):
    """Swish activation: ``x * sigmoid(beta * x)``.

    Args:
        x: input value or tensor.
        beta: multiplier applied to ``x`` inside the sigmoid (default 1).

    Returns:
        ``x`` multiplied by ``sigmoid(beta * x)``.
    """
    gate = sigmoid(beta * x)
    return x * gate

from keras.utils.generic_utils import get_custom_objects
from keras.layers import Activation
# Register the activation under the name 'swish' so that layers can request
# it by string, e.g. Activation('swish').
get_custom_objects()['swish'] = Activation(swish)

In [None]:
# NN loss
#def fn_my_loss(y_true,y_pred):
#	return K.binary_crossentropy(y_true, y_pred, from_logits=True)

# NN object
def fn_make_NN(reg='anc'):
  """Build and compile one CNN member of the ensemble.

  The network is a stack of swish-activated Conv2D / BatchNormalization
  blocks with max-pooling, followed by two dense layers and a sigmoid output
  layer. The first 128-filter convolution and the dense layers are
  initialised from draws of their Gaussian priors and carry a custom
  regulariser selected by ``reg``; the remaining layers use Keras defaults.

  Reads the module-level settings ``W1_var``, ``W_mid_var``, ``W_last_var``,
  ``n_Xdim``, ``n_hidden``, ``n_classes`` and ``n_data``.

  Args:
    reg: regularisation mode.
      'anc'  - penalise the squared distance to a fixed random anchor,
      'reg'  - penalise the squared distance to zero,
      'free' - no penalty.

  Returns:
    A compiled Keras ``Sequential`` model (binary cross-entropy loss, SGD
    with Nesterov momentum, accuracy metric).

  Raises:
    ValueError: if ``reg`` is not one of the modes listed above.
  """
  valid_modes = ('anc', 'reg', 'free')
  if reg not in valid_modes:
    # Without this check the regularisers would silently return None and the
    # failure would only surface later, inside Keras.
    raise ValueError('reg must be one of ' + str(valid_modes) + ', got ' + repr(reg))

  def _draw_prior(var, shape):
    """Return (anchor, initial value, lambda) for one parameter tensor."""
    scale = np.sqrt(var)
    anchor = np.random.normal(loc=0, scale=scale, size=shape)
    init = np.random.normal(loc=0, scale=scale, size=shape)
    return anchor, init, 1/(2*var)

  def _make_reg(name, anchor, lam):
    """Build the regulariser callable for one parameter tensor."""
    def _regulariser(weight_matrix):
      if reg == 'free':
        return 0.
      if reg == 'reg':
        return K.sum(K.square(weight_matrix)) * lam/n_data
      # reg == 'anc' (guaranteed by the validation above)
      return K.sum(K.square(weight_matrix - anchor)) * lam/n_data
    _regulariser.__name__ = name
    return _regulariser

  # Draws are made anchor-then-init, layer by layer, so the sequence of
  # np.random calls is the same as in the hand-written version.
  # Biases share the prior variance of the weights in the same layer.
  # Anchored conv kernel: 3x3 window, 64 input channels, 128 filters.
  W_c1_anc, W_c1_init, W_c1_lambda = _draw_prior(W1_var, [3, 3, 64, 128])
  b_c1_anc, b_c1_init, b_c1_lambda = _draw_prior(W1_var, [128])
  W1_anc, W1_init, W1_lambda = _draw_prior(W1_var, [n_Xdim, n_hidden])
  b1_anc, b1_init, b1_lambda = _draw_prior(W1_var, [n_hidden])
  W_mid_anc, W_mid_init, W_mid_lambda = _draw_prior(W_mid_var, [n_hidden, n_hidden])
  b_mid_anc, b_mid_init, b_mid_lambda = _draw_prior(W_mid_var, [n_hidden])
  W_last_anc, W_last_init, W_last_lambda = _draw_prior(W_last_var, [n_hidden, n_classes])

  custom_reg_W_c1 = _make_reg('custom_reg_W_c1', W_c1_anc, W_c1_lambda)
  custom_reg_b_c1 = _make_reg('custom_reg_b_c1', b_c1_anc, b_c1_lambda)
  custom_reg_W1 = _make_reg('custom_reg_W1', W1_anc, W1_lambda)
  custom_reg_b1 = _make_reg('custom_reg_b1', b1_anc, b1_lambda)
  custom_reg_W_mid = _make_reg('custom_reg_W_mid', W_mid_anc, W_mid_lambda)
  custom_reg_b_mid = _make_reg('custom_reg_b_mid', b_mid_anc, b_mid_lambda)
  custom_reg_W_last = _make_reg('custom_reg_W_last', W_last_anc, W_last_lambda)

  model = Sequential()

  def _add_conv(filters, **conv_kwargs):
    """Append Conv2D(3x3) -> swish -> BatchNormalization to the model."""
    model.add(Conv2D(filters, (3, 3), **conv_kwargs))
    model.add(Activation('swish'))
    model.add(BatchNormalization())

  def _add_pool():
    """Append a 2x2 max-pooling layer with stride 2."""
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))

  # input (90,90,225) -> (45,45,32)
  _add_conv(32, input_shape=(90, 90, 225), padding='same')
  _add_conv(32, padding='same')
  _add_pool()

  # -> (20,20,64)
  _add_conv(64)
  _add_conv(64)
  _add_pool()

  # -> (8,8,128); the first convolution of this stage is the anchored one
  _add_conv(
    128,
    kernel_initializer=keras.initializers.Constant(value=W_c1_init),
    bias_initializer=keras.initializers.Constant(value=b_c1_init),
    kernel_regularizer=custom_reg_W_c1,
    bias_regularizer=custom_reg_b_c1,
  )
  _add_conv(128)
  _add_pool()

  # -> (2,2,256)
  _add_conv(256)
  _add_conv(256)
  _add_pool()

  # -> (1,1,512); no batch normalisation after this last convolution
  model.add(Conv2D(512, (2, 2)))
  model.add(Activation('swish'))

  # Dense head. Layer widths come from n_hidden / n_classes so that they
  # always agree with the shapes of the constant initialisers drawn above.
  model.add(Flatten())
  model.add(Dense(
    n_hidden,
    input_shape=(n_Xdim,),
    kernel_initializer=keras.initializers.Constant(value=W1_init),
    bias_initializer=keras.initializers.Constant(value=b1_init),
    kernel_regularizer=custom_reg_W1,
    bias_regularizer=custom_reg_b1,
  ))
  model.add(Activation('swish'))
  #model.add(Dropout(0.5))#,seed=seed_value))
  model.add(Dense(
    n_hidden,
    kernel_initializer=keras.initializers.Constant(value=W_mid_init),
    bias_initializer=keras.initializers.Constant(value=b_mid_init),
    kernel_regularizer=custom_reg_W_mid,
    bias_regularizer=custom_reg_b_mid,
  ))
  model.add(Activation('swish'))
  #model.add(Dropout(0.5))#,seed=seed_value))
  model.add(Dense(
    n_classes,
    use_bias=False,
    kernel_initializer=keras.initializers.Constant(value=W_last_init),
    kernel_regularizer=custom_reg_W_last,
  ))
  model.add(Activation('sigmoid'))

  # NOTE(review): the learning rate is a literal here; the module-level
  # l_rate setting is not used by this function - confirm which is intended.
  sgd = SGD(lr=0.0001, momentum=0.9, nesterov=True)
  model.compile(loss='binary_crossentropy', optimizer=sgd, metrics=['accuracy'])

  return model

def fn_predict_ensemble(NNs, x_test):
    """Predict with every network of an ensemble and average the results.

    Args:
        NNs: list of models; each is called as ``predict(x_test, verbose=0)``.
        x_test: inputs handed unchanged to every model.

    Returns:
        A pair ``(y_prob_preds, y_prob_final)``: the per-model predictions
        stacked into one array along a new leading axis, and their mean over
        that axis.
    """
    member_probs = np.array([net.predict(x_test, verbose=0) for net in NNs])
    return member_probs, member_probs.mean(axis=0)

In [None]:
# create some data
#x_train, y_train, x_test, y_test = fn_make_data(seed_in=seed_in)

# create the NNs
NNs = [fn_make_NN(reg=reg) for _ in range(n_ensemble)]
# summary() prints the table itself and returns None, so it is not wrapped
# in print() (that would emit an extra "None" line).
NNs[-1].summary()

batch_size = 10
# Step counts must be whole numbers: round the batch count up so a trailing
# partial batch is still covered.
# NOTE(review): the factor 4 presumably reflects how many samples
# imageLoaderNew produces per file - confirm against the loader.
steps_train = int(np.ceil(4 * len(fileList_train) / batch_size))
steps_valid = int(np.ceil(4 * len(fileList_valid) / batch_size))

# do the actual training: one fit per ensemble member, weights saved after each
NNs_hist_train = []
NNs_hist_val = []
for m in range(n_ensemble):
  print('-- training: ' + str(m+1) + ' of ' + str(n_ensemble) + ' NNs --')
  hist = NNs[m].fit(
    imageLoaderNew(fileList_train, batch_size),
    steps_per_epoch=steps_train,
    epochs=epochs,  # taken from the optimisation options instead of a literal
    verbose=0,
    validation_data=imageLoaderNew(fileList_valid, batch_size),
    validation_steps=steps_valid,
  )
  NNs[m].save_weights('/content/drive/My Drive/Colab Notebooks/swri_research/code/flux_emerge_ensemble_model_'+str(m+1)+'.h5')
  print('-- NN: ' + str(m+1) + ' weights saved --')
  NNs_hist_train.append(hist.history['loss'])
  NNs_hist_val.append(hist.history['val_loss'])