### Architecture:
* Something like Google Le Net. We use Inception V3.
* SGD with momentum, lr = 0.01 and 5000 iterations.
* L2 norm on higher level feature maps. We take this as last 4.

### Weird Stuff:
* Feed in 4 examples at once to compute the loss
* Oversample the rare classes
* Based on the above two points, we should iterate through sub classes, and then randomly sample the other 3 points
* Other points: Sample randomly a different point from within class, sample random intra class and take random sample from this, sample random from inter class.
* What to set m1, m2 to?

### Augmentation:
1. intensity variation between −0.1 to 0.1
2. rotation with −90° to 90°
3. flip with level and vertical direction
4. translation with ±20 pixels

In [1]:
import os
import glob
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
from matplotlib.pyplot import imshow
from PIL import Image

import sys
sys.path.insert(0, '/Users/rb/Google_Drive/Waterloo/projects/breakHis/src')
from models import*
% matplotlib inline

Using TensorFlow backend.


In [12]:
# make the batch generator:
# this must return 4 images, from the correct classes in the correct order:
# [current element, element in same sub class, element in same class, element in different class]
# this should also sample even numbers from each class (same number of subclasses for malignant and benign)

def data_gen_CSD(file_loc, batch_size, magnification=0, im_size=None, 
    square_rot_p=.3, translate=0, flips=False, rotate=False):
    # square_rot_p is the prob of using a 90x rotation, otherwise sample from 360. Possibly not useful
    # translate is maximum number of pixels to translate by. Make it close the doctor's variance in annotation
    label_list = ['B_A', 'B_F', 'B_PT', 'B_TA', 'M_DC', 'M_LC', 'M_MC', 'M_PC']
    square_rot_p = float(square_rot_p)
    translate = int(translate)
    num_samples = int(batch_size/4)
    
    all_files=glob.glob(os.path.join(file_loc, '*'))
    if int(magnification)!=0:
        all_files = [loc for loc in all_files if loc.rsplit('/', 1)[1].rsplit('-', 1)[0].rsplit('-', 1)[1] == str(magnification)]
    num_batches = int(np.floor(len(all_files)/batch_size))

    while 1:
        random.shuffle(all_files) # randomize after every epoch
        # now get the right files for the batch:
        for index, label in enumerate(label_list):
            x=[]
            y=[]
            batch_files = []
            # get imgs from the current class
            img_locs = [loc for loc in all_files if loc.rsplit('/', 1)[1].split('_', 1)[1].split('-', 1)[0] == str(label)]
            batch_files.append(random.choice(img_locs))

            # get imgs from the same sub class, but not duplicates:
            pi_pos = [x for x in img_locs if x not in batch_files]
            batch_files.append(random.choice(pi_pos))

            # get imgs from the same intra class, but not duplicates:
            pi_neg = [loc for loc in all_files if loc.rsplit('/', 1)[1].split('_', 1)[1].split('-', 1)[0][0] == str(label)[0]]
            pi_neg = [x for x in pi_neg if x not in batch_files]
            batch_files.append(random.choice(pi_neg))

            # get imgs from the same intra class, but not duplicates:
            n_ = [loc for loc in all_files if loc.rsplit('/', 1)[1].split('_', 1)[1].split('-', 1)[0][0] != str(label)[0]]
            batch_files.append(random.choice(n_))
            
            # now get the images and augment:
            for image_loc in batch_files:
                image = Image.open(image_loc)

                # APPLY AUGMENTATION:
                # flips
                if flips:
                    flip_vert = random.randint(0, 1)
                    flip_hor = random.randint(0, 1)
                    if flip_vert:
                        image = image.transpose(Image.FLIP_TOP_BOTTOM)
                    if flip_hor:
                        image = image.transpose(Image.FLIP_LEFT_RIGHT)

                # rotation
                if rotate:
                    square_rot =  bool((np.random.uniform(0, 1, 1)<square_rot_p))
                    if square_rot:  # maybe this is dumb, but it cant hurt
                        angle = random.randint(0, 4)
                        if(angle ==0):
                            image = image.transpose(Image.ROTATE_90)
                        elif(angle ==1):
                            image = image.transpose(Image.ROTATE_180)
                        elif(angle ==2):
                            image = image.transpose(Image.ROTATE_270)
                    else:
                        angle = np.random.uniform(0, 360,1)
                        image=image.rotate(angle)

                if(im_size != 0):
                    image_shape = (im_size, im_size)
                    image = image.resize(image_shape)

                # translate
                ts_sz_row = randint(-1*translate, translate)
                ts_sz_col = randint(-1*translate, translate)

                image = image.transform(image.size, Image.AFFINE, (1, 0, ts_sz_row, 0, 1, ts_sz_col))

                image = np.reshape(np.array(image.getdata()), (im_size, im_size, 3))
                image = image/255.0 
                
                label = image_loc.rsplit('/', 1)[1].split('_', 1)[1].split('-', 1)[0]
                y_temp = label_list.index(label)
                y_temp = np.eye(8)[y_temp]
                
                x.append(image)
                y.append(y_temp)
            
            x=np.array(x)
            y=np.array(y)

            yield (x, y)

In [13]:
import os
import sys
import glob
import random
from random import randint
import numpy as np 
from PIL import Image
from PIL import ImageFilter


import keras
from keras.models import Sequential, Model, load_model
from keras.layers import Dropout, Flatten, Reshape, Input
from keras.layers.core import Activation, Dense, Lambda
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import MaxPooling2D
from keras.layers.pooling import GlobalAveragePooling2D
from keras.layers.normalization import BatchNormalization
from keras import backend as K
from keras.applications.vgg16 import VGG16
from keras.applications.inception_v3 import InceptionV3

def conv_bn_dp(x, filters, dropout):
    a = Conv2D(16, (3, 3), padding='same', kernel_initializer='he_normal')(x)
    b = BatchNormalization()(a)
    c = Activation('relu')(b)
    d = Dropout(dropout)(c)
    return d

def conv_6L_CSD(im_size, learning_rate = .01, dropout = .1):
    def CSD_loss(y_true, y_pred):
        # always assume there will be batch size of 4, with the elements in the following order:
        # [current element, element in same sub class, element in same class, element in different class]
        m1 = .05 # no idea what this should be set to
        m2 = .1

        # Normal loss function for CNN:
        loss_classification = K.categorical_crossentropy(y_true, y_pred)

        # Loss function based on feature distances:
        # after the flatten dims is [batches, features]
        d_x_pos = K.sqrt(K.sum(K.square(dense2[0, :] - dense2[1, :])))

        d_x_neg = K.sqrt(K.sum(K.square(dense2[1, :] - dense2[2, :])))
        d_x_n = K.sqrt(K.sum(K.square(dense2[2, :] - dense2[3, :])))

        # loss_feat_dist = (.5*K.mean(K.max(0, d_x_pos - d_x_neg + m1 - m2)) + .5*K.mean(K.max(0, d_x_neg - d_x_n + m2)))
        # Can't use the above because it ids for the case with multiple samples of each class
        loss_feat_dist = .5*K.maximum(0.0, d_x_pos - d_x_neg + m1 - m2) + .5*K.maximum(0.0, d_x_neg - d_x_n + m2)    
        loss = 0.5*loss_classification + 0.5*loss_feat_dist
        return loss
    
    input_shape = (im_size, im_size, 3)
    img_input = Input(shape=input_shape)

    L1 = conv_bn_dp(img_input, filters=16, dropout=dropout)

    L2 = conv_bn_dp(L1, filters=16, dropout=dropout)
    L2_pool = MaxPooling2D(pool_size=(2, 2))(L2)

    L3 = conv_bn_dp(L2, filters=32, dropout=dropout)
    
    L4 = conv_bn_dp(L3, filters=32, dropout=dropout)
    L4_pool = MaxPooling2D(pool_size=(2, 2))(L4)

    L5 = conv_bn_dp(L4_pool, filters=64, dropout=dropout)
    
    L6 = conv_bn_dp(L5, filters=64, dropout=dropout)
    L5_pool = MaxPooling2D(pool_size=(2, 2))(L5)
    
    flatten = Flatten(input_shape=(None, None, 64) )(L5_pool)
    
    dense1 = Dense(512, kernel_initializer='he_normal')(flatten)
    dense1 = BatchNormalization()(dense1)
    dense1 = Activation('relu')(dense1)
    dense1 = Dropout(dropout*4)(dense1)

    dense2 = Dense(256, kernel_initializer='he_normal')(dense1)
    dense2 = BatchNormalization()(dense2)
    dense2 = Activation('relu')(dense2)
    dense2 = Dropout(dropout*4)(dense2)
    
    predictions = Dense(8, activation='softmax')(dense2)
    model = Model(outputs=predictions, inputs=img_input)

    SGD = keras.optimizers.SGD(lr=learning_rate, momentum=0.3, decay=0.0, nesterov=False)
    model.compile(optimizer=SGD, loss=CSD_loss, metrics=['accuracy'])
    return model


def conv_6L_CSD2(im_size, learning_rate = .01, dropout = .1):
    def CSD_loss(y_true, y_pred):
        # always assume there will be batch size of 4, with the elements in the following order:
        # [current element, element in same sub class, element in same class, element in different class]
        m1 = .05 # no idea what this should be set to
        m2 = .1

        # Normal loss function for CNN:
        loss_classification = K.categorical_crossentropy(y_true, y_pred)

        # Loss function based on feature distances:
        # after the flatten dims is [batches, features]
        d_x_pos = K.sqrt(K.sum(K.square(dense2[0, :] - dense2[1, :])))

        d_x_neg = K.sqrt(K.sum(K.square(dense2[1, :] - dense2[2, :])))
        d_x_n = K.sqrt(K.sum(K.square(dense2[2, :] - dense2[3, :])))

        # loss_feat_dist = (.5*K.mean(K.max(0, d_x_pos - d_x_neg + m1 - m2)) + .5*K.mean(K.max(0, d_x_neg - d_x_n + m2)))
        # Can't use the above because it ids for the case with multiple samples of each class
        loss_feat_dist = .5*K.maximum(0.0, d_x_pos - d_x_neg + m1 - m2) + .5*K.maximum(0.0, d_x_neg - d_x_n + m2)    
        loss = 0.5*loss_classification + 0.5*loss_feat_dist
        return loss
    
    input_shape = (im_size, im_size, 3)
    img_input = Input(shape=input_shape)

    L1 = conv_bn_dp(img_input, filters=16, dropout=dropout)

    L2 = conv_bn_dp(L1, filters=16, dropout=dropout)
    L2_pool = MaxPooling2D(pool_size=(2, 2))(L2)

    L3 = conv_bn_dp(L2, filters=32, dropout=dropout)
    
    L4 = conv_bn_dp(L3, filters=32, dropout=dropout)
    L4_pool = MaxPooling2D(pool_size=(2, 2))(L4)

    L5 = conv_bn_dp(L4_pool, filters=64, dropout=dropout)
    
    L6 = conv_bn_dp(L5, filters=64, dropout=dropout)
    L5_pool = MaxPooling2D(pool_size=(2, 2))(L5)
    
    flatten = Flatten(input_shape=(None, None, 64) )(L5_pool)
    
    dense1 = Dense(512, kernel_initializer='he_normal')(flatten)
    dense1 = BatchNormalization()(dense1)
    dense1 = Activation('relu')(dense1)
    dense1 = Dropout(dropout*4)(dense1)

    dense2 = Dense(256, kernel_initializer='he_normal')(dense1)
    dense2 = BatchNormalization()(dense2)
    dense2 = Activation('relu')(dense2)
    dense2 = Dropout(dropout*4)(dense2)
    
    predictions = Dense(8, activation='softmax')(dense2)
    model = Model(outputs=predictions, inputs=img_input)

    Adam = keras.optimizers.Adam(lr=learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
    model.compile(loss="categorical_crossentropy", optimizer=Adam, metrics=['accuracy'])
    return model


def InceptionV3_CSD(im_size, learning_rate=.001, dropout =.1):

    def CSD_loss(y_true, y_pred):
        # always assume there will be batch size of 4, with the elements in the following order:
        # [current element, element in same sub class, element in same class, element in different class]
        m1 = .05 # no idea what this should be set to
        m2 = .01

        # Normal loss function for CNN:
        loss_classification = K.categorical_crossentropy(y_true, y_pred)

        # Loss function based on feature distances:
        # after the flatten dims is [batches, features]
        d_x_pos = K.sqrt(K.sum(K.square(dense2_d[0, :] - dense2_d[1, :])))

        d_x_neg = K.sqrt(K.sum(K.square(dense2_d[1, :] - dense2_d[2, :])))
        d_x_n = K.sqrt(K.sum(K.square(dense2_d[2, :] - dense2_d[3, :])))

        # loss_feat_dist = (.5*K.mean(K.max(0, d_x_pos - d_x_neg + m1 - m2)) + .5*K.mean(K.max(0, d_x_neg - d_x_n + m2)))
        # Can't use the above because it ids for the case with multiple samples of each class
        loss_feat_dist = .5*K.maximum(0.0, d_x_pos - d_x_neg + m1 - m2) + .5*K.maximum(0.0, d_x_neg - d_x_n + m2)    
        loss = 0.5*loss_classification + 0.5*loss_feat_dist
        return loss

    inp = Input(shape=(int(im_size), int(im_size), 3))

    try:
        base_model = InceptionV3(weights='imagenet', include_top=False)
    except:
        InceptionV3  = load_model('project/rbbidart/models/InceptionV3')
        base_model = InceptionV3(inp)
        
    dense1 = Flatten(input_shape= (None, None, 2048) )(base_model)
    dense1 = Dense(1024, kernel_initializer='he_normal')(dense1)
    dense1 = keras.layers.normalization.BatchNormalization()(dense1)
    dense1 = Activation('relu')(dense1)
    dense1 = Dropout(dropout*4)(dense1)

    dense2 = Dense(256, kernel_initializer='he_normal')(dense2)
    dense2 = keras.layers.normalization.BatchNormalization()(dense2)
    dense2 = Activation('relu')(dense2)
    dense2_d = Dropout(dropout*4)(dense2)

    # Output layer
    predictions = Dense(8, activation='softmax')(x)

    # add everything together to get a model.
    # input not inputs because using Keras 1
    model = Model(outputs=predictions, inputs=inp)

    # This freezes the convolutional layers, so only the added FC layers will be trained
    # Can still adjust the base model because it hasn't been compiled yet

    Adam = keras.optimizers.Adam(lr=learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
    model.compile(loss="categorical_crossentropy", optimizer=Adam, metrics=['accuracy'])
    return model
                           

In [14]:
# params


out_loc = '/Users/rb/Google_Drive/Waterloo/projects/breakHis/output/csd_test'

data_loc = '/Users/rb/Documents/waterloo/projects/breakHis/patient_det'
out_loc = '/Users/rb/Google_Drive/Waterloo/projects/breakHis/output/csd_test'
epochs = 5
batch_size = 4
im_size = 256
model_str = 'conv_6L_CSD'
magnification = 100

In [15]:
import sys

import os
import glob
import random
import numpy as np 
import pandas as pd
import keras
import pickle
from keras import backend as K
from keras.engine.topology import Layer
from keras.layers import Dropout, Flatten, Reshape, Input
from keras.layers.core import Activation, Dense, Lambda
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.preprocessing.image import ImageDataGenerator


# Get the function:
functionList = {
'conv_6L_CSD': conv_6L_CSD
# 'InceptionV3_ft' : InceptionV3_ft
}

parameters = {
'learning_rate': .001,
'dropout': .1,
'im_size' : int(im_size)
}

# Locations
train_loc = os.path.join(str(data_loc),'train')
valid_loc = os.path.join(str(data_loc),'valid')

if(int(magnification)!=0):
    num_train = len([loc for loc in glob.glob(train_loc + '/**/*.png', recursive=True) if loc.rsplit('/', 1)[1].rsplit('-', 1)[0].rsplit('-', 1)[1] == str(magnification)])
    num_valid = len([loc for loc in glob.glob(valid_loc + '/**/*.png', recursive=True) if loc.rsplit('/', 1)[1].rsplit('-', 1)[0].rsplit('-', 1)[1] == str(magnification)])
else:
    num_train = len(glob.glob(train_loc + '/**/*.png', recursive=True))
    num_valid = len(glob.glob(valid_loc + '/**/*.png', recursive=True))

print('train_loc', train_loc)
print('valid_loc', valid_loc)

print('num_train', num_train)
print('num_valid', num_valid)

# Params for all models
epochs=int(epochs)
batch_size=int(batch_size)   # make this divisible by len(x_data)
im_size = int(im_size)

steps_per_epoch = np.floor(num_train/batch_size) # num of batches from generator at each epoch. (make it full train set)
validation_steps = np.floor(num_valid/batch_size)# size of validation dataset divided by batch size
print('steps_per_epoch', steps_per_epoch)
print('validation_steps', validation_steps)
print('batch_size', batch_size)

model = functionList[model_str](**parameters)
print(model_str)
name = model_str+'noaug'
out_file=os.path.join(str(out_loc), name)
callbacks = [EarlyStopping(monitor='val_loss', patience=15, verbose=0),
        ModelCheckpoint(filepath=os.path.join(out_loc, name+'_'+str(magnification)+'_.{epoch:02d}-{val_acc:.2f}.hdf5'), 
                        verbose=1, monitor='val_loss', save_best_only=True)]

hist = model.fit_generator(data_gen_CSD(train_loc, batch_size, magnification=magnification, im_size=im_size, 
                                  square_rot_p=.5, translate=30, flips=True, rotate=True),
                                  validation_data=data_gen_CSD(train_loc, batch_size, magnification=magnification, im_size=im_size),
                                  steps_per_epoch=steps_per_epoch,
                                  epochs=epochs,
                                  validation_steps=validation_steps,
                                  callbacks=callbacks)

pickle.dump(hist.history, open(out_file, 'wb'))

train_loc /Users/rb/Documents/waterloo/projects/breakHis/patient_det/train
valid_loc /Users/rb/Documents/waterloo/projects/breakHis/patient_det/valid
num_train 1123
num_valid 497
steps_per_epoch 280.0
validation_steps 124.0
batch_size 4
conv_6L_CSD
Epoch 1/5
x:  [[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.  

x:  [[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  ..., 
  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   

x:  [[[[ 0.68627451  0.38431373  0.75294118]
   [ 0.72156863  0.43529412  0.78823529]
   [ 0.71764706  0.44705882  0.76078431]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.75686275  0.42745098  0.81176471]
   [ 0.71764706  0.41960784  0.77647059]
   [ 0.7254902   0.43529412  0.76862745]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.81568627  0.4745098   0.85490196]
   [ 0.79607843  0.46666667  0.84313725]
   [ 0.76078431  0.43921569  0.83137255]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  ..., 
  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   

x:  [[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  ..., 
  [[ 0.89411765  0.89019608  0.88235294]
   [ 0.87058824  0.87058824  0.87843137]
   [ 0.85490196  0.85882353  0.87843137]
   ..., 
   [ 0.92941176  0.96078431  0.90196078]
   [ 0.95294118  0.98431373  0.94117647]
   

  1/280 [..............................] - ETA: 1936s - loss: 7.8980 - acc: 0.0000e+00x:  [[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.96862745  0.87843137  0.90196078]
   [ 0.95294118  0.90980392  0.93333333]
   [ 1.          1.          0.98431373]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.79607843  0.70980392  0.79215686]
   [ 0.77254902  0.72941176  0.72156863]
   [ 0.96862745  1.          0.95686275]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.55294118  0.49411765  0.58431373]
   [ 0.84313725  0.80784314  0.78823529]
   [ 0.95686275  0.98823529  0.9372549 ]]

  ..., 
  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 

x:  [[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.85882353  0.87058824  0.89019608]
   [ 0.87058824  0.88627451  0.89019608]
   [ 0.82352941  0.87843137  0.88235294]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.76078431  0.77647059  0.81176471]
   [ 0.7254902   0.74509804  0.76862745]
   [ 0.76078431  0.81568627  0.81568627]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.91764706  0.94509804  0.96862745]
   [ 0.87058824  0.90196078  0.91372549]
   [ 0.79215686  0.84705882  0.84705882]]

  ..., 
  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   

  2/280 [..............................] - ETA: 1354s - loss: 7.9801 - acc: 0.0000e+00x:  [[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  ..., 
  [[ 0.90588235  0.85882353  0.8745098 ]
   [ 0.89019608  0.84313725  0.85882353]
   [ 0.91764706  0.85882353  0.88627451]
   ..., 

  4/280 [..............................] - ETA: 1051s - loss: 7.5108 - acc: 0.1875x:  [[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  ..., 
  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   

  6/280 [..............................] - ETA: 966s - loss: 7.5165 - acc: 0.1667 x:  [[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  ..., 
  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   

  8/280 [..............................] - ETA: 918s - loss: 7.5353 - acc: 0.1562x:  [[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  ..., 
  [[ 0.83137255  0.87058824  0.83921569]
   [ 0.82745098  0.86666667  0.83529412]
   [ 0.81568627  0.83137255  0.82745098]
   ..., 
   [

 10/280 [>.............................] - ETA: 881s - loss: 7.6026 - acc: 0.1500x:  [[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  ..., 
  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [

 12/280 [>.............................] - ETA: 847s - loss: 7.5151 - acc: 0.1667x:  [[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 1.          0.90588235  1.        ]
   [ 1.          0.8627451   0.98039216]
   [ 0.96078431  0.8         0.93333333]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.99607843  0.92941176  1.        ]
   [ 1.          0.91372549  1.        ]
   [ 1.          0.8745098   0.98823529]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  ..., 
  [[ 0.91372549  0.70196078  0.90588235]
   [ 0.95294118  0.78823529  0.94509804]
   [ 1.          0.82352941  1.        ]
   ..., 
   [

 14/280 [>.............................] - ETA: 822s - loss: 7.5070 - acc: 0.1429x:  [[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.89803922  0.83137255  0.80392157]
   [ 0.90588235  0.80784314  0.79215686]
   [ 0.90588235  0.77254902  0.77647059]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.88627451  0.81568627  0.80784314]
   [ 0.8627451   0.75294118  0.8       ]
   [ 0.85882353  0.7254902   0.79215686]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.85882353  0.76470588  0.76470588]
   [ 0.7254902   0.64313725  0.71764706]
   [ 0.52941176  0.45882353  0.52156863]]

  ..., 
  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [

 16/280 [>.............................] - ETA: 796s - loss: 7.5575 - acc: 0.1250x:  [[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  ..., 
  [[ 0.90588235  0.86666667  0.8627451 ]
   [ 0.91372549  0.90980392  0.89019608]
   [ 0.99215686  0.99215686  0.98431373]
   ..., 
   [

 18/280 [>.............................] - ETA: 774s - loss: 7.6112 - acc: 0.1111x:  [[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  ..., 
  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [

 20/280 [=>............................] - ETA: 758s - loss: 7.6096 - acc: 0.1125x:  [[[[ 0.74509804  0.62352941  0.68235294]
   [ 0.76862745  0.64705882  0.70588235]
   [ 0.77254902  0.64705882  0.72941176]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.90588235  0.78431373  0.84313725]
   [ 0.83921569  0.71372549  0.76470588]
   [ 0.79215686  0.6627451   0.72941176]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.78431373  0.65882353  0.70980392]
   [ 0.81176471  0.68627451  0.72941176]
   [ 0.83921569  0.71372549  0.75686275]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  ..., 
  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [

 22/280 [=>............................] - ETA: 743s - loss: 7.5533 - acc: 0.1136x:  [[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  ..., 
  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [

 24/280 [=>............................] - ETA: 729s - loss: 7.5114 - acc: 0.1146x:  [[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  ..., 
  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [

 26/280 [=>............................] - ETA: 716s - loss: 7.4154 - acc: 0.1346x:  [[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.86666667  0.85882353  0.8627451 ]
   [ 0.87843137  0.87058824  0.8745098 ]
   [ 0.89019608  0.88235294  0.88627451]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.88627451  0.87058824  0.85882353]
   [ 0.90588235  0.88235294  0.89019608]
   [ 0.91764706  0.89411765  0.90196078]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.90588235  0.88627451  0.8745098 ]
   [ 0.90196078  0.88627451  0.8745098 ]
   [ 0.90980392  0.89411765  0.88235294]]

  ..., 
  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [

 28/280 [==>...........................] - ETA: 699s - loss: 7.3902 - acc: 0.1339x:  [[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  ..., 
  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [

 30/280 [==>...........................] - ETA: 686s - loss: 7.3437 - acc: 0.1333x:  [[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  ..., 
  [[ 0.61960784  0.46666667  0.63921569]
   [ 0.70196078  0.54901961  0.72156863]
   [ 0.69411765  0.54901961  0.71764706]
   ..., 
   [

 32/280 [==>...........................] - ETA: 678s - loss: 7.3884 - acc: 0.1250x:  [[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  ..., 
  [[ 0.84313725  0.81568627  0.78431373]
   [ 0.82352941  0.79607843  0.77254902]
   [ 0.84313725  0.82745098  0.79215686]
   ..., 
   [

 34/280 [==>...........................] - ETA: 667s - loss: 7.4606 - acc: 0.1176x:  [[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  ..., 
  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [

 36/280 [==>...........................] - ETA: 660s - loss: 7.4790 - acc: 0.1111x:  [[[[ 1.          1.          1.        ]
   [ 1.          1.          1.        ]
   [ 1.          1.          1.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 1.          1.          1.        ]
   [ 1.          1.          1.        ]
   [ 1.          1.          1.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 1.          1.          1.        ]
   [ 1.          1.          1.        ]
   [ 1.          1.          1.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  ..., 
  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [

 38/280 [===>..........................] - ETA: 650s - loss: 7.4859 - acc: 0.1053x:  [[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  ..., 
  [[ 0.82352941  0.50196078  0.68235294]
   [ 0.85098039  0.52156863  0.71372549]
   [ 0.85882353  0.55294118  0.74117647]
   ..., 
   [

 40/280 [===>..........................] - ETA: 640s - loss: 7.4957 - acc: 0.1125x:  [[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  ..., 
  [[ 0.62745098  0.66666667  0.70196078]
   [ 0.67058824  0.75294118  0.73333333]
   [ 0.79215686  0.83137255  0.78823529]
   ..., 
   [

 42/280 [===>..........................] - ETA: 640s - loss: 7.4545 - acc: 0.1190x:  [[[[ 0.93333333  0.51372549  0.96078431]
   [ 0.89019608  0.44705882  0.90980392]
   [ 0.99215686  0.49803922  0.88627451]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.92156863  0.5254902   0.94117647]
   [ 0.92156863  0.49803922  0.96470588]
   [ 0.91372549  0.45882353  0.85490196]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.87058824  0.48235294  0.89803922]
   [ 0.91372549  0.49803922  0.93333333]
   [ 1.          0.66666667  1.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  ..., 
  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [

 44/280 [===>..........................] - ETA: 636s - loss: 7.4589 - acc: 0.1193x:  [[[[ 0.79215686  0.71764706  0.82745098]
   [ 0.69019608  0.60392157  0.75686275]
   [ 0.57647059  0.51372549  0.65882353]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.63921569  0.58039216  0.70196078]
   [ 0.61176471  0.52941176  0.68235294]
   [ 0.6         0.53333333  0.69803922]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.64313725  0.57647059  0.70980392]
   [ 0.67058824  0.60784314  0.75294118]
   [ 0.56862745  0.50980392  0.67058824]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  ..., 
  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [

 46/280 [===>..........................] - ETA: 645s - loss: 7.4410 - acc: 0.1196x:  [[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  ..., 
  [[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   ..., 
   [

KeyboardInterrupt: 

In [10]:
def customized_loss(args):
    #A is from the training data
    #S is the internal state
    A, A', S, S' = args 
    #customize your own loss components
    loss1 = K.mean(K.square(A - A'), axis=-1)
    loss2 = K.mean(K.square(S - S'), axis=-1)
    #adjust the weight between loss components
    return 0.5 * loss1 + 0.5 * loss2

 def model():
     #define other inputs
     A = Input(...) # define input A
     #construct your model 
     cnn_model = Sequential()
     ...
     # get true internal state
     S = cnn_model(prev_layer_output0)
     # get predicted internal state output
     S' = Dense(...)(prev_layer_output1)
     # get predicted A output
     A' = Dense(...)(prev_layer_output2)
     # customized loss function
     loss_out = Lambda(customized_loss, output_shape(1,), name='joint_loss')([A, A', S, S'])
     model = Model(input=[...], output=[loss_out])
     return model

  def train():
      m = model()
      opt = 'adam'
      model.compile(loss={'joint_loss': lambda y_true, y_pred:y_pred}, optimizer = opt)
      # train the model 
                            

IndentationError: unindent does not match any outer indentation level (<tokenize>, line 11)