In [None]:
import tensorflow as tf

# Expose only the first physical GPU to TensorFlow and cap it at a
# 7000 MB logical device. TensorFlow rejects these settings with a
# RuntimeError once the GPUs have already been initialised (e.g. if this
# cell is re-run in a live kernel); that error is printed, not raised.
gpus = tf.config.experimental.list_physical_devices('GPU')
if len(gpus) > 0:
    first_gpu = gpus[0]
    memory_cap = [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=7000)]
    try:
        tf.config.experimental.set_visible_devices(first_gpu, 'GPU')
        tf.config.experimental.set_virtual_device_configuration(first_gpu, memory_cap)
    except RuntimeError as err:
        print(err)

In [None]:
# -------------------------------- Xception ---------------------------------------------
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2, Xception
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Dense, Dropout, Flatten, BatchNormalization
from tensorflow.keras.optimizers import RMSprop, Adam, SGD
import matplotlib.pyplot as plt


# Dataset root; each split directory holds one sub-folder per class label
# ('0' and '1'). The "8_2" suffix presumably denotes an 80/20 split -- TODO confirm.
DATA_ROOT = '/home/workspace/user-workspace/crack_data_full_8_2'
train_dir = DATA_ROOT + '/train'
validation_dir = DATA_ROOT + '/validation'

# Random augmentations applied to training images only.
train_augmentation = dict(
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
)

# NOTE(review): rescale=1/255 feeds pixels in [0, 1], whereas the ImageNet
# Xception weights were trained on inputs scaled to [-1, 1]
# (tf.keras.applications.xception.preprocess_input). Switching would also
# require matching preprocessing wherever the saved model is used -- confirm.
train_datagen = ImageDataGenerator(rescale=1/255, **train_augmentation)
validation_datagen = ImageDataGenerator(rescale=1/255)

# Options shared by both splits: fixed class order ('0' -> 0, '1' -> 1),
# Xception's minimum 71x71 input size, and scalar binary labels.
shared_flow_options = dict(classes=['0', '1'],
                           target_size=(71, 71),
                           class_mode='binary')

train_generator = train_datagen.flow_from_directory(
    train_dir, batch_size=30, **shared_flow_options)

validation_generator = validation_datagen.flow_from_directory(
    validation_dir, batch_size=32, **shared_flow_options)


# Training schedule shared by both phases. The step counts are hard-coded;
# presumably they equal len(train_generator) / len(validation_generator) for
# this dataset (batch sizes 30 / 32) -- TODO confirm, or pass those lengths.
EPOCHS_PER_PHASE = 20
TRAIN_STEPS = 375
VAL_STEPS = 88

with tf.device('/device:GPU:0'):
    # ---- Phase 1: train a new head on a frozen ImageNet Xception base ----
    model_base = Xception(weights='imagenet',
                          include_top=False,
                          input_shape=(71, 71, 3))
    model_base.trainable = False

    model = Sequential()
    model.add(model_base)
    # Flatten infers its input shape from the base's feature map; an
    # input_shape argument on a non-first layer is ignored, so none is given.
    model.add(Flatten())
    model.add(Dense(256, activation='relu'))
    model.add(BatchNormalization())
    # Single sigmoid unit: probability of class '1' for the binary labels.
    model.add(Dense(1, activation='sigmoid'))

    model.summary()

    model.compile(optimizer=Adam(learning_rate=2e-4),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    # Keep the head-only run's history separately; `history` is reassigned
    # by the fine-tuning run below.
    history_head = model.fit(train_generator,
                             steps_per_epoch=TRAIN_STEPS,
                             epochs=EPOCHS_PER_PHASE,
                             validation_data=validation_generator,
                             validation_steps=VAL_STEPS,
                             verbose=2)

    # ---- Phase 2: fine-tune the top of the base at a lower learning rate ----
    # Number of trailing base layers to unfreeze; all earlier layers stay frozen.
    fine_tune_at = 10

    model_base.trainable = True
    n_frozen = max(0, len(model_base.layers) - fine_tune_at)
    for layer in model_base.layers[:n_frozen]:
        layer.trainable = False

    # BatchNormalization layers with trainable=False run in inference mode and
    # keep their pretrained moving statistics; the Keras transfer-learning
    # guide recommends this when fine-tuning.
    for layer in model_base.layers:
        if isinstance(layer, BatchNormalization):
            layer.trainable = False

    # Changes to `trainable` only take effect after re-compiling.
    model.compile(optimizer=Adam(learning_rate=2e-5),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    history = model.fit(train_generator,
                        steps_per_epoch=TRAIN_STEPS,
                        epochs=EPOCHS_PER_PHASE,
                        validation_data=validation_generator,
                        validation_steps=VAL_STEPS,
                        verbose=2)
    
# Persist the trained model first so that a plotting failure cannot lose it;
# create the target directory if it does not exist yet.
MODEL_PATH = '/home/workspace/user-workspace/JSI/model_save/Xception_2.h5'
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
model.save(MODEL_PATH)

# `history` holds only the most recent model.fit call (the fine-tuning phase).
train_acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
train_loss = history.history['loss']
val_loss = history.history['val_loss']


def _plot_train_val(train_values, val_values, metric_name):
    """Plot per-epoch training (red points) vs. validation (blue line) values.

    train_values, val_values: sequences of one value per epoch.
    metric_name: lowercase metric name used in the legend, y label and title.
    """
    epochs = range(1, len(train_values) + 1)  # 1-based epoch numbers
    fig, ax = plt.subplots()
    # Colour is given only via `color=`; a colour letter in the fmt string as
    # well makes matplotlib warn that the colour is redundantly defined.
    ax.plot(epochs, train_values, 'o', color='r', label=f'training {metric_name}')
    ax.plot(epochs, val_values, '-', color='b', label=f'validation {metric_name}')
    ax.set(title=f'Training and Validation {metric_name.title()}',
           xlabel='epoch',
           ylabel=metric_name)
    ax.legend()
    plt.show()


_plot_train_val(train_acc, val_acc, 'accuracy')
_plot_train_val(train_loss, val_loss, 'loss')