In [None]:
import math
import os

import numpy as np
from keras.applications.inception_v3 import InceptionV3
from keras.preprocessing import image
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D, Dropout
from keras import backend as K
from keras.optimizers import SGD
from keras.callbacks import (
    ReduceLROnPlateau,
    CSVLogger,
    EarlyStopping,
    ModelCheckpoint)

In [None]:
# parameters:

SZ = 224    # input image side length in pixels (model input and generator target_size)

BZ = 500    # batch size used by both the training and validation generators

LN = 249    # NOTE(review): not used in this notebook; presumably an InceptionV3 layer index for a fine-tuning freeze cut-off — confirm

LR = 0.001  # SGD learning rate

DO = 0.5    # dropout rate used in the classification head

RG = 50     # NOTE(review): not used in this notebook — confirm whether it is still needed

CLASS = 5   # number of output classes (units of the final softmax layer)

# training and validation folder
# The data root can be overridden with the ECHO_DATA_DIR environment variable;
# the default is the original cluster location.
DATA_DIR = os.environ.get("ECHO_DATA_DIR", "/ysm-gpfs/home/hc487/project/data_RGB3")
echo_train = os.path.join(DATA_DIR, "train")
echo_validation = os.path.join(DATA_DIR, "validation")

# create the base pre-trained model
# ImageNet-pretrained InceptionV3 without its top classifier. The input shape is
# tied to SZ so the network always matches the generators' target_size.
base_model = InceptionV3(weights="imagenet", input_shape=(SZ, SZ, 3), include_top=False)

# Classification head: dropout on the convolutional feature map, global average
# pooling, dropout again on the pooled vector, then a CLASS-way softmax.
x = base_model.output
x = Dropout(DO)(x)
x = GlobalAveragePooling2D()(x)
x = Dropout(DO)(x)
predictions = Dense(CLASS, activation='softmax')(x)

# this is the model we will train
# (no layer's `trainable` flag is changed in this notebook, so the base is trained too)
model = Model(inputs=base_model.input, outputs=predictions)

model.summary()

# List base-model layers with their indices (e.g. to choose a freeze cut-off).
for i, layer in enumerate(base_model.layers):
    print(i, layer.name)

In [None]:
# set model compile:
# SGD with momentum; LR comes from the parameters cell.
model.compile(optimizer=SGD(lr=LR, momentum=0.9),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Checkpoint: with save_best_only=True the file is only overwritten when the
# monitored metric (Keras default monitor: 'val_loss') improves.
checkpointer = ModelCheckpoint(
    filepath="model_checkpoint_{}_{}.h5".format("first", "title"),
    verbose=1,
    save_best_only=True)

# csvlogger: per-epoch metrics appended to a CSV file.
csv_logger = CSVLogger(
    'csv_logger_{}_{}.csv'.format("first", "title"))

# Sub-directory names of each data folder, in the order that defines the class
# indices. Shared by both generators so train/validation labels always agree.
class_names = ["ASD", "DCM", "HP", "MI", "NORM"]
assert len(class_names) == CLASS, "class list must match the Dense output size (CLASS)"

# image data generator (training): light geometric augmentation plus [0, 1] rescaling.
train_datagen = ImageDataGenerator(
    rotation_range=15.,  # rotation, in degrees
    width_shift_range=0.1,
    height_shift_range=0.1,
    rescale=1./255,
    shear_range=0.,
    zoom_range=0.1,
    channel_shift_range=0.,
    fill_mode='nearest',
    cval=0.,
    horizontal_flip=True,
    vertical_flip=False,
    preprocessing_function=None)

train_generator = train_datagen.flow_from_directory(
        echo_train,
        target_size=(SZ, SZ),
        batch_size=BZ,
        shuffle=True,
        classes=class_names,
        class_mode='categorical',
        seed=42)

# Validation images are only rescaled (no augmentation).
test_datagen = ImageDataGenerator(
    rescale=1./255)

test_generator = test_datagen.flow_from_directory(
        echo_validation,
        target_size=(SZ, SZ),
        batch_size=BZ,
        shuffle=True,
        classes=class_names,
        class_mode='categorical',
        seed=42)

In [None]:
# train 50 epochs
#
# Per-class loss weights, keyed by class *name* and mapped onto the integer
# indices the generator assigned (train_generator.class_indices), so they cannot
# drift out of sync with the class list. DCM and MI are weighted 2x. The previous
# literal also carried a key 5 for a sixth class that does not exist (CLASS = 5).
class_weight_by_name = {"ASD": 1., "DCM": 2., "HP": 1., "MI": 2., "NORM": 1.}
my_class_weight = {idx: class_weight_by_name[name]
                   for name, idx in train_generator.class_indices.items()}

# Training steps per epoch come from a fixed sample budget, not from the size of
# the training directory. Keras expects an integer step count; ceil keeps the
# same number of steps the previous float value (100000. / BZ) produced.
TRAIN_SAMPLES_PER_EPOCH = 100000
steps_per_epoch = math.ceil(TRAIN_SAMPLES_PER_EPOCH / BZ)

# Validate on exactly one pass over the validation directory:
# len(DirectoryIterator) == ceil(samples / batch_size).
history = model.fit_generator(train_generator,
                              steps_per_epoch=steps_per_epoch,
                              epochs=50,
                              validation_data=test_generator,
                              validation_steps=len(test_generator),
                              class_weight=my_class_weight,
                              callbacks=[csv_logger, checkpointer])