In [None]:
from tensorflow import keras
from keras.layers.core import Dense, Dropout, Activation, Flatten
from tensorflow.keras.models import Model
from keras.applications.vgg16 import VGG16
from tensorflow.keras import optimizers
from keras.callbacks import EarlyStopping, ModelCheckpoint, Callback
from keras.preprocessing.image import ImageDataGenerator

# Dataset root (on an external drive) and the train/test image folders under it.
data_dir = '/Volumes/Samsung_T5/siim-covid19-detection'
train_dir = f'{data_dir}/train/study_class'
test_dir = f'{data_dir}/test/original'

# Training hyper-parameters.
n_folds = 5  # NOTE(review): not used in this cell — presumably for a planned k-fold split
epochs = 20
# Images are resized to img_row x img_col when loaded by the generators.
# NOTE(review): 1024x1024 inputs with batch_size=64 through VGG16 need far more
# accelerator memory than a typical GPU has; consider a smaller size/batch.
img_row, img_col = 1024, 1024
batch_size = 64

def get_model(input_shape=None, n_classes=4):
    """Build and compile a VGG16-based image classifier.

    Args:
        input_shape: ``(rows, cols, channels)`` of the input images. Defaults
            to ``(img_row, img_col, 1)``, i.e. single-channel images at the
            module-level resolution.
        n_classes: Number of units in the softmax output layer.

    Returns:
        A compiled ``Model`` using SGD with Nesterov momentum, categorical
        cross-entropy loss, and an accuracy metric.
    """
    if input_shape is None:
        input_shape = (img_row, img_col, 1)

    # weights=None: the convolutional base starts from random initialisation
    # (the ImageNet weights would require 3-channel input).
    base_model = VGG16(include_top=False, weights=None, input_shape=input_shape)

    # Fully connected classification head on top of the convolutional base.
    # NOTE(review): Flatten keeps every spatial position. For 1024x1024 input
    # the VGG16 base outputs a 32x32x512 map, so the first Dense layer alone
    # holds ~1.07e9 weights; GlobalAveragePooling2D or a smaller input size
    # would cut this drastically.
    out = Flatten()(base_model.output)
    out = Dense(2048, activation='relu')(out)
    out = Dropout(0.5)(out)
    out = Dense(2048, activation='relu')(out)
    out = Dropout(0.5)(out)
    output = Dense(n_classes, activation='softmax')(out)
    model = Model(inputs=base_model.input, outputs=output)

    # The legacy ``decay`` argument applied lr = lr0 / (1 + decay * iterations).
    # Current Keras optimizers reject ``decay``; InverseTimeDecay with
    # decay_steps=1 reproduces exactly the same per-step schedule.
    lr_schedule = optimizers.schedules.InverseTimeDecay(
        initial_learning_rate=1e-4,
        decay_steps=1,
        decay_rate=1e-6,
    )
    sgd = optimizers.SGD(learning_rate=lr_schedule, momentum=0.9, nesterov=True)
    model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])
    return model

# Fraction of each class directory held out for validation.
val_split = 0.2

# Augmenting generator for the training subset: pixel values scaled to [0, 1],
# small rotations/shifts/zooms, and no horizontal flips.
datagen = ImageDataGenerator(
    rescale=1. / 255,
    rotation_range=3,
    width_shift_range=0.05,
    height_shift_range=0.05,
    zoom_range=0.05,
    horizontal_flip=False,
    validation_split=val_split,
)

# Validation images must not be randomly transformed; otherwise val_loss
# (monitored by the checkpoint/early-stopping callbacks) is noisy and not
# comparable across epochs. Keras derives the subset split from
# validation_split and the sorted file listing, so using the same split on
# the same directory keeps training and validation files disjoint.
valid_datagen = ImageDataGenerator(
    rescale=1. / 255,
    validation_split=val_split,
)

train_gen = datagen.flow_from_directory(
    train_dir,
    target_size=(img_row, img_col),
    batch_size=batch_size,
    class_mode='categorical',
    color_mode='grayscale',
    subset='training',
)

valid_gen = valid_datagen.flow_from_directory(
    train_dir,
    target_size=(img_row, img_col),
    batch_size=batch_size,
    class_mode='categorical',
    color_mode='grayscale',
    subset='validation',
)

# Checkpoint destination for the best weights seen so far.
# NOTE(review): no file extension — the on-disk format depends on the Keras
# version; confirm it matches whatever code later loads these weights.
weight_path = f'{data_dir}/weight_path'

# Save only the weights, and only when val_loss improves.
# (The original spelled the option `save_weight_only`, so it was not applied.)
cp_callback = ModelCheckpoint(
    filepath=weight_path,
    save_weights_only=True,
    save_best_only=True,
    monitor='val_loss',
    verbose=1,
)

# Stop training once val_loss has not improved for 3 consecutive epochs.
early_stopping = EarlyStopping(
    monitor='val_loss',
    min_delta=0.0,
    patience=3,
)

model = get_model()

# The generators already yield batches of `batch_size`. Keras infers the
# steps per epoch from their length, so neither batch_size nor steps are
# passed here. (The Keras docs say not to give batch_size for generator input.)
history = model.fit(
    train_gen,
    validation_data=valid_gen,
    epochs=epochs,
    callbacks=[cp_callback, early_stopping],
    shuffle=True,
    verbose=1,
)


Found 5070 images belonging to 4 classes.
Found 1265 images belonging to 4 classes.
Epoch 1/20
