In [4]:
from keras.models import Model
from keras.layers import Input
from keras.layers import Activation
from keras.layers import Conv2D, Dense
from keras.layers import add, BatchNormalization, Flatten
from keras.losses import CategoricalCrossentropy,MeanSquaredError
import tensorflow as tf
from board_conversion import *
from keras.utils.generic_utils import get_custom_objects

def residual_module(layer_in, n_filters):
    """Build one residual block on top of ``layer_in``.

    The block applies a 3x3 ReLU convolution followed by batch
    normalization, adds the result to a shortcut branch, and passes the
    sum through a final ReLU.

    The shortcut is ``layer_in`` itself when the size of its last axis
    already equals ``n_filters``; otherwise it is ``layer_in`` passed through
    a 1x1 ReLU convolution so that both branches have ``n_filters`` on the
    last axis before being added.

    Args:
        layer_in: Keras tensor to build the block on.
        n_filters: Number of filters for the convolutions in this block.

    Returns:
        The Keras tensor produced by the final ReLU activation.
    """
    # The shortcut layer (if any) is created before the main-branch
    # convolution so that auto-generated layer names keep their order.
    if layer_in.shape[-1] == n_filters:
        shortcut = layer_in
    else:
        shortcut = Conv2D(
            n_filters, (1, 1), padding='same', activation='relu'
        )(layer_in)

    main_branch = Conv2D(
        n_filters, (3, 3), padding='same', activation='relu'
    )(layer_in)
    main_branch = BatchNormalization()(main_branch)

    summed = add([main_branch, shortcut])
    return Activation('relu')(summed)

# Network input: each sample has shape (17, 8, 8, 12).
visible = Input(shape=(17, 8, 8, 12))

# Shared trunk: two residual blocks with 64 filters each, then a flattened
# copy of the trunk output that feeds the value head.
layer1 = residual_module(visible, 64)
layer2 = residual_module(layer1, 64)
flatten = Flatten()(layer2)

# Value head, first stage: a 256-unit dense layer with no activation.
pre_v = Dense(256)(flatten)

# Policy head 'p': 1x1 convolution with 73 filters applied to the
# (unflattened) trunk output.
# NOTE(review): softmax here normalizes along the last axis only (the 73
# channels), i.e. separately for each position of the leading axes --
# confirm this matches the layout of the 'p' training targets.
p = Conv2D(73, (1, 1), activation='softmax', name='p')(layer2)

# Value head 'v': a single tanh unit, so the output lies in [-1, 1].
v = Dense(1, activation='tanh', name='v')(pre_v)

model = Model(inputs=visible, outputs=[p, v])

# One loss per named output; the dict keys must match the layer names above.
model.compile(
    optimizer='adam',
    loss={
        'p': CategoricalCrossentropy(),
        'v': MeanSquaredError(),
    },
)
model.summary()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            [(None, 17, 8, 8, 12 0                                            
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 17, 8, 8, 64) 6976        input_4[0][0]                    
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 17, 8, 8, 64) 256         conv2d_10[0][0]                  
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 17, 8, 8, 64) 832         input_4[0][0]                    
____________________________________________________________________________________________