In [1]:
from keras.layers import Input, ZeroPadding3D, Conv3D, BatchNormalization, Activation, SpatialDropout3D, MaxPooling3D, \
    TimeDistributed, Flatten, Bidirectional, GRU, Dense, AveragePooling3D
from keras import Model
import tensorflow as tf

# Enumerate the visible GPUs and ask TensorFlow to enable memory growth on
# each of them. An empty device list simply skips the loop body.
gpus = tf.config.experimental.list_physical_devices('GPU')
print(gpus)
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as err:
    # Report the problem and carry on with TensorFlow's default allocation.
    # NOTE(review): presumably raised when the devices were already
    # initialised before this cell ran -- confirm against the TF docs.
    print(err)
        
import keras
import numpy as np
from data_gen_stanford import DataGenerator

Using TensorFlow backend.


[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [3]:
def _conv_block(x, index, filters, kernel_size, strides):
    """Apply one Conv3D -> BatchNorm -> ReLU -> SpatialDropout3D -> MaxPooling3D stage.

    Args:
        x: Input tensor of the stage.
        index: Stage number; used as the suffix of the layer names
            (``conv<index>``, ``batc<index>``, ``actv<index>``, ``max<index>``).
        filters: Number of convolution filters.
        kernel_size: 3-tuple kernel size passed to ``Conv3D``.
        strides: 3-tuple strides passed to ``Conv3D``.

    Returns:
        The output tensor of the max-pooling layer.
    """
    x = Conv3D(filters, kernel_size, strides=strides, padding="same",
               kernel_initializer='he_normal', name='conv%d' % index)(x)
    x = BatchNormalization(name='batc%d' % index)(x)
    x = Activation('relu', name='actv%d' % index)(x)
    x = SpatialDropout3D(0.5)(x)
    # Pool size and stride are 1 along the first axis, so its length is kept.
    return MaxPooling3D(pool_size=(1, 2, 2), strides=(1, 2, 2), name='max%d' % index)(x)


def get_Lipnet(n_classes=51, summary=False, input_shape=(75, 50, 100, 3)):
    """Build and compile the 3D-convolution + bidirectional-GRU classifier.

    The network stacks three convolutional stages, flattens every step of the
    first axis independently, runs two bidirectional GRU layers over that
    axis, flattens the whole sequence and ends in a sigmoid ``Dense`` layer.
    It is compiled with Adam (learning rate 1e-4) and binary cross-entropy.

    Args:
        n_classes: Number of units in the final sigmoid layer.
        summary: When true, write an architecture diagram to ``network.png``
            and print the layer summary before compiling.
        input_shape: Shape of one sample, without the batch axis. Defaults to
            ``(75, 50, 100, 3)`` -- presumably (frames, height, width,
            channels); confirm against the data generator. Every dimension
            must be fixed because the GRU output sequence is flattened.

    Returns:
        The compiled Keras ``Model``.
    """
    input_layer = Input(name='the_input', shape=input_shape, dtype='float32')

    # (filters, kernel_size, strides) of the three convolutional stages.
    stages = (
        (32, (3, 5, 5), (1, 2, 2)),
        (64, (3, 5, 5), (1, 1, 1)),
        (96, (3, 3, 3), (1, 1, 1)),
    )
    x = input_layer
    for index, (filters, kernel_size, strides) in enumerate(stages, start=1):
        x = _conv_block(x, index, filters, kernel_size, strides)

    # Collapse everything but the first (sequence) axis into one feature vector.
    x = TimeDistributed(Flatten())(x)

    x = Bidirectional(GRU(128, return_sequences=True, kernel_initializer='Orthogonal', name='gru1'),
                      merge_mode='concat')(x)
    x = Bidirectional(GRU(128, return_sequences=True, kernel_initializer='Orthogonal', name='gru2'),
                      merge_mode='concat')(x)
    x = Flatten()(x)
    outputs = Dense(n_classes, kernel_initializer='he_normal', name='dense1', activation="sigmoid")(x)

    model = Model(inputs=input_layer, outputs=outputs)
    if summary:
        keras.utils.plot_model(model, 'network.png', show_shapes=True)
        # summary() prints by itself and returns None; wrapping it in print()
        # would emit a stray "None" line.
        model.summary()

    # The AUC metric is taken from the same `keras` package as the model and
    # optimizer instead of `tf.keras`, so compile() gets a native metric object.
    model.compile(optimizer=keras.optimizers.Adam(beta_1=0.9, beta_2=0.999, lr=1e-4),
                  loss='binary_crossentropy',
                  metrics=['accuracy', 'mse', keras.metrics.AUC()])
    return model

In [3]:
if __name__ == "__main__":
    # Build the compiled network without printing/plotting its architecture.
    model = get_Lipnet(summary=False, n_classes=51)

    # Batches of 10 samples; the meaning of val_split is defined by
    # DataGenerator (not visible here) -- TODO confirm whether 0.99 is the
    # training or the validation fraction.
    datagen = DataGenerator(val_split=0.99, batch_size=10)

    # NOTE(review): the captured log of this cell shows the loss/mse blowing
    # up (mse ~1e24, negative cross-entropy) although the model output is a
    # sigmoid -- presumably some targets are outside [0, 1]; verify the
    # labels produced by the generator.
    model.fit_generator(
        generator=datagen,
        validation_data=datagen.get_valid_data(),
        shuffle=True,
        epochs=1,
    )

    # The model is written to disk only after the epoch has finished.
    model.save('my_model.h5')

Epoch 1/1
[28516, 9013, 15724, 11100, 13956, 571, 26280, 21982, 20736, 26575]
max240.0
min0.0
[24665, 15378, 11097, 23836, 5893, 10035, 1633, 26485, 21176, 30988]
max247.0
min0.0
[14034, 27503, 21321, 25613, 332, 11153, 7727, 25895, 28201, 20629]
max237.0
min0.0
[24409, 9082, 10239, 5780, 737, 9301, 3031, 14959, 28623, 7571]
max246.0
min0.0
[31909, 110, 19784, 18601, 23149, 23746, 22005, 12631, 16621, 31459]
max240.0
min0.0
[9027, 31944, 29787, 17045, 20202, 11670, 26255, 13381, 876, 23936]
max226.0
min0.0
[13902, 19450, 17396, 4981, 25875, 10383, 15858, 12394, 21967, 21827]
max237.0
min0.0
[6043, 613, 31691, 3665, 22972, 23943, 10350, 20406, 32410, 20607]
max252.0
min0.0
[27549, 16821, 11346, 18444, 24070, 12396, 2678, 21313, 28205, 9279]
max255.0
min0.0
[27080, 24587, 21892, 667, 5962, 4785, 6534, 7044, 19093, 9524]
max249.0
min0.0
[11734, 11326, 10201, 6878, 1936, 30149, 3067, 5168, 9767, 19620]
max255.0
min0.0
[2715, 6056, 32299, 4874, 22271, 31626, 30088, 8716, 22237, 25042]
max25

  37/3281 [..............................] - ETA: 25:57 - loss: 0.3756 - accuracy: 0.8600 - mse: 0.1114 - auc: 0.6238[19205, 8820, 26620, 15447, 26218, 10488, 16322, 21621, 7587, 3996]
max240.0
min0.0
  38/3281 [..............................] - ETA: 25:45 - loss: 0.3747 - accuracy: 0.8605 - mse: 0.1111 - auc: 0.6250[20765, 31788, 1894, 9093, 18388, 10271, 7509, 27412, 18023, 20161]
max239.0
min0.0
  39/3281 [..............................] - ETA: 25:31 - loss: 0.3738 - accuracy: 0.8610 - mse: 0.1108 - auc: 0.6262[22453, 26031, 23197, 30643, 14263, 23397, 9084, 7590, 32175, 9183]
max251.0
min0.0
  40/3281 [..............................] - ETA: 25:18 - loss: 0.3729 - accuracy: 0.8615 - mse: 0.1104 - auc: 0.6273[2788, 14551, 32390, 26304, 31608, 16538, 19689, 24077, 414, 12778]
max240.0
min0.0
  41/3281 [..............................] - ETA: 25:06 - loss: 0.3727 - accuracy: 0.8618 - mse: 0.1102 - auc: 0.6284[11858, 19841, 14949, 17870, 29255, 30077, 5439, 97, 17539, 1296]
max244.0
min0

  76/3281 [..............................] - ETA: 21:29 - loss: 9236718861.8334 - accuracy: 0.8642 - mse: 4054796396895041795129344.0000 - auc: 0.6532[32132, 27540, 9159, 29896, 11518, 18898, 6661, 19315, 28807, 16140]
max236.0
min0.0
  77/3281 [..............................] - ETA: 21:26 - loss: 9116761474.0227 - accuracy: 0.8641 - mse: 4002136707172124059500544.0000 - auc: 0.6537[19643, 28707, 24738, 10411, 19230, 4906, 15227, 2962, 28426, 31727]
max249.0
min0.0
  78/3281 [..............................] - ETA: 21:23 - loss: 8999879916.6695 - accuracy: 0.8638 - mse: 3950827088531100941680640.0000 - auc: 0.6542[27744, 32292, 19683, 19021, 32543, 19859, 18872, 25151, 27169, 4442]
max231.0
min0.0
  79/3281 [..............................] - ETA: 21:22 - loss: 8885957386.0845 - accuracy: 0.8635 - mse: 3900816812425769740402688.0000 - auc: 0.6546[8257, 6062, 28506, 7067, 26680, 31011, 17345, 5001, 19704, 1622]
max248.0
min0.0
  80/3281 [..............................] - ETA: 21:19 - loss

 146/3281 [>.............................] - ETA: 19:38 - loss: 18252278293.4128 - accuracy: 0.8598 - mse: 3166073772597881444761600.0000 - auc: 0.6734[31687, 13204, 29664, 25495, 22722, 18425, 31114, 3392, 22041, 20342]
max248.0
min0.0
 147/3281 [>.............................] - ETA: 19:37 - loss: 18128113134.9571 - accuracy: 0.8598 - mse: 3144535757739944784691200.0000 - auc: 0.6736[7118, 26606, 23627, 23932, 6110, 25342, 12740, 22056, 14323, 1208]
max249.0
min0.0
 148/3281 [>.............................] - ETA: 19:35 - loss: 18005625884.0479 - accuracy: 0.8596 - mse: 3123289143792297505193984.0000 - auc: 0.6738[4461, 3481, 30133, 6758, 21681, 19945, 32130, 24669, 27184, 32518]
max249.0
min0.0
 149/3281 [>.............................] - ETA: 19:34 - loss: 17884782757.3120 - accuracy: 0.8596 - mse: 3102327301456288116899840.0000 - auc: 0.6739[72, 11572, 21536, 12762, 10410, 22899, 31775, 1335, 9848, 19713]
max255.0
min0.0
 150/3281 [>.............................] - ETA: 19:33 - lo

 216/3281 [>.............................] - ETA: 18:29 - loss: -4276925325.8071 - accuracy: 0.8563 - mse: 3566718895365155348021248.0000 - auc: 0.6824[15921, 2868, 29105, 4797, 6201, 11342, 7746, 11376, 12075, 17315]
max240.0
min0.0
 217/3281 [>.............................] - ETA: 18:29 - loss: -4257215992.5050 - accuracy: 0.8562 - mse: 3550282558165103985819648.0000 - auc: 0.6825[27761, 11935, 1317, 16612, 23318, 9682, 17095, 15446, 8122, 3441]
max241.0
min0.0
 218/3281 [>.............................] - ETA: 18:28 - loss: -4237687478.7747 - accuracy: 0.8562 - mse: 3533996677221403817148416.0000 - auc: 0.6826[27007, 17933, 12949, 11084, 5630, 18157, 18720, 17334, 17798, 18214]
max244.0
min0.0
 219/3281 [=>............................] - ETA: 18:28 - loss: -4218337307.6356 - accuracy: 0.8561 - mse: 3517859811382174083448832.0000 - auc: 0.6827[10725, 8596, 8269, 10721, 9256, 8163, 5796, 9484, 2253, 31151]
max252.0
min0.0
 220/3281 [=>............................] - ETA: 18:27 - loss: 

KeyboardInterrupt: 

In [None]:
# Reload the model written by the training cell. `load_model` is not imported
# anywhere in this notebook, so it is reached through the already-imported
# `keras` package to avoid a NameError on a fresh kernel.
model = keras.models.load_model('my_model.h5')