# Use Convolutional Neural Network to Identify Spiral Arms
Transfer Learning: use the EfficientNet Model. 
--
Reference:
Kalvankar et al. 2020
https://ui.adsabs.harvard.edu/abs/2020arXiv200813611K

We first try freezing all EfficientNet layers.... 

### Importing the libraries

In [None]:
# !pip install efficientnet

In [1]:
import numpy as np
from IPython.display import clear_output
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'

import tensorflow as tf
print(tf.__version__)
import keras
print(keras.__version__)

from keras.preprocessing.image import ImageDataGenerator

import efficientnet.keras as efn 
from efficientnet.keras import preprocess_input

2.4.0
2.4.3


## Data Preprocessing
### Preprocessing the Training set

In [2]:
# customised image sizes
szx = 200
szy = 200
szz = 3

train_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

training_set = train_datagen.flow_from_directory('dataset/training',
                                                 target_size=(szx, szy),
                                                 batch_size=32,
                                                 subset='training',
                                                 shuffle=False,
                                                 class_mode='categorical')
STEP_SIZE_TRAIN = training_set.n // training_set.batch_size

Found 31848 images belonging to 3 classes.


### Preprocessing the Validation set

In [3]:
valid_datagen = ImageDataGenerator(preprocessing_function=preprocess_input,
                                   validation_split=0.15)

valid_set = valid_datagen.flow_from_directory('dataset/training',
                                              target_size=(szx, szy),
                                              batch_size=32,
                                              subset='validation',
                                              shuffle=False,
                                              class_mode='categorical')
STEP_SIZE_VALID = valid_set.n // valid_set.batch_size

Found 4776 images belonging to 3 classes.


## Setting up the EfficientNet Model
Keep the trainable parameters minimal at the  moment. If we open more layers for training, we need
GPUs (e.g., SageMaker on AWS)

In [5]:
from keras.models import Model
base_model = efn.EfficientNetB4(weights='imagenet', include_top=False, 
                                input_shape=(szx, szy, szz), pooling='avg')
output = base_model.layers[-1].output
output = keras.layers.Flatten()(output)
model_enet = Model(base_model.input, output)       

model_enet.trainable = False
for layer in model_enet.layers:
    layer.trainable = False

import pandas as pd 
pd.set_option('max_colwidth', None)
layers = [(layer, layer.name, layer.trainable) for layer in model_enet.layers]
df_model_show = pd.DataFrame(layers, columns=['Layer Type', 'Layer Name', 'Trainable or not'])

In [6]:
# check that all layers are indeed frozen
df_model_show['Trainable or not'].describe()

count       469
unique        1
top       False
freq        469
Name: Trainable or not, dtype: object

In [7]:
display(df_model_show.tail(20))

Unnamed: 0,Layer Type,Layer Name,Trainable or not
449,<tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7ff48b09d970>,block7b_expand_conv,False
450,<tensorflow.python.keras.layers.normalization_v2.BatchNormalization object at 0x7ff48a77f2b0>,block7b_expand_bn,False
451,<tensorflow.python.keras.layers.core.Activation object at 0x7ff48b09dd90>,block7b_expand_activation,False
452,<tensorflow.python.keras.layers.convolutional.DepthwiseConv2D object at 0x7ff48a2c1520>,block7b_dwconv,False
453,<tensorflow.python.keras.layers.normalization_v2.BatchNormalization object at 0x7ff48b09de50>,block7b_bn,False
454,<tensorflow.python.keras.layers.core.Activation object at 0x7ff48b0aa910>,block7b_activation,False
455,<tensorflow.python.keras.layers.pooling.GlobalAveragePooling2D object at 0x7ff48b0b5ac0>,block7b_se_squeeze,False
456,<tensorflow.python.keras.layers.core.Reshape object at 0x7ff482ec1850>,block7b_se_reshape,False
457,<tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7ff48b0aa370>,block7b_se_reduce,False
458,<tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7ff48b0a3b50>,block7b_se_expand,False


In [8]:
print(model_enet.output_shape)

(None, 1792)


In [9]:
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, InputLayer
from keras.models import Sequential
from keras import optimizers

model = Sequential()
model.add(model_enet)
model.add(Dense(512, activation='relu', input_dim=model_enet.output_shape[1]))
model.add(Dropout(0.3))
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.3))
model.add(Dense(3, activation='softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.Adam(lr=1e-4),
              metrics=['accuracy'])

model.summary()


Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
model_1 (Functional)         (None, 1792)              17673816  
_________________________________________________________________
dense (Dense)                (None, 512)               918016    
_________________________________________________________________
dropout (Dropout)            (None, 512)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 512)               262656    
_________________________________________________________________
dropout_1 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 3)                 1539      
Total params: 18,856,027
Trainable params: 1,182,211
Non-trainable params: 17,673,816
____________________________________

## Training the Model

#### add visualisation to monitor the training and validation accuracy real-time:
* code block below adapted from https://github.com/kapil-varshney/utilities/blob/master/training_plot/training_plot_ex_with_cifar10.ipynb

In [10]:
class TrainingPlot(keras.callbacks.Callback):
    
    # This function is called when the training begins
    def on_train_begin(self, logs={}):
        # Initialize the lists for holding the logs, losses and accuracies
        self.losses = []
        self.acc = []
        self.val_losses = []
        self.val_acc = []
        self.logs = []
    
    # This function is called at the end of each epoch
    def on_epoch_end(self, epoch, logs={}):
        
        # Append the logs, losses and accuracies to the lists
        self.logs.append(logs)
        self.losses.append(logs.get('loss'))
        self.acc.append(logs.get('accuracy'))
        self.val_losses.append(logs.get('val_loss'))
        self.val_acc.append(logs.get('val_accuracy'))
        
        # Before plotting ensure at least 2 epochs have passed
        if len(self.losses) > 1:
            
            # Clear the previous plot
            clear_output(wait=True)
            N = np.arange(0, len(self.losses))
            
            # You can chose the style of your preference
            # print(plt.style.available) to see the available options
            plt.style.use("seaborn-talk")
            
            # Plot train loss, train acc, val loss and val acc against epochs passed
            plt.figure()
            plt.plot(N, self.losses, linestyle=':', label = "train_loss")
            plt.plot(N, self.acc, linestyle=':', label = "train_accuracy")
            plt.plot(N, self.val_losses, label = "val_loss")
            plt.plot(N, self.val_acc, label = "val_accuracy")
            plt.title("Training Loss and Accuracy [Epoch {}]".format(epoch))
            plt.xlabel("Epoch #")
            plt.ylabel("Loss/Accuracy")
            plt.legend()
            plt.show()
plot_losses = TrainingPlot()

In [11]:
from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, CSVLogger

checkpointer = ModelCheckpoint(
    filepath='/Users/tiantianyuan/work/learn_py/self/astro/dataset/wts_enet4_model_freeze.h5', verbose=2, save_best_only=True)

early_stopping = EarlyStopping(
    monitor='val_loss', patience=10, verbose=1, mode='auto')

reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                              patience=4)

csv_logger = CSVLogger('/Users/tiantianyuan/work/learn_py/self/astro/dataset/wts_enet4_model_freeze.csv')

In [None]:
import time
t1 = time.time()
results = model.fit(training_set,
                    steps_per_epoch=STEP_SIZE_TRAIN,
                    validation_data=valid_set,
                    validation_steps=STEP_SIZE_VALID,
                    epochs=10,
                    callbacks=[plot_losses, checkpointer, early_stopping, reduce_lr, csv_logger])
t2 = time.time()
print('Model running time is {:.2f}mins'.format((t2 - t1)/60))

### Save models

In [None]:
# Save the entire model as a SavedModel.
# !mkdir -p saved_model
model_enet.save('saved_model/enet4_model_freeze')

In [None]:
model_enet.save('saved_model/enet4_model_freeze.h5')

### Model evaluation using Confusion Matrix and F1 score

In [None]:
from keras.models import Model,load_model
model_check = load_model('saved_model/enet4_model_freeze.h5')

In [None]:
test_datagen = ImageDataGenerator(rescale=1./255)
test_set = test_datagen.flow_from_directory('dataset/test',
                                            target_size=(szx, szy),
                                            batch_size=32,
                                            shuffle=False,
                                            class_mode='categorical')

test_set.reset()

Y_pred = model_check.predict(
                            test_set,
                            steps=test_set.n / test_set.batch_size,
                            verbose=1)

y_pred = np.argmax(Y_pred, axis=1)

In [None]:
from sklearn.metrics import classification_report, confusion_matrix

cm = confusion_matrix(test_set.classes, y_pred)

print('The confusion matrix is \n{}\n'.format(cm))

f1 = classification_report(test_set.classes, y_pred, target_names=training_set.class_indices)
print('F1 score is {}\n'.format(f1))

### Comment