# Gravity Spy

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import random
from sklearn.metrics import accuracy_score, confusion_matrix

import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import models, layers, optimizers
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

In [None]:
# Root folder of the dataset; each split sits in a doubly nested folder
# (train/train/, test/test/, validation/validation/) beneath it.
data_root = '../input/gravity-spy-gravitational-waves/'

# The trailing slash is kept on purpose: other cells build file paths by plain
# string concatenation (train_dir + class_folder + '/' + file_name).
train_dir = data_root + 'train/train/'
test_dir = data_root + 'test/test/'
validation_dir = data_root + 'validation/validation/'

In [None]:
# One class per sub-folder of the training directory.
# os.listdir() returns entries in arbitrary order and also lists plain files,
# so keep directories only and sort them: the position of a name in this list
# is the label index (the list is passed as `classes=` to every generator and
# is used to label the plots), so it must be reproducible between runs.
class_names = sorted(
    entry for entry in os.listdir(train_dir)
    if os.path.isdir(os.path.join(train_dir, entry))
)

# Preprocessing

In [None]:
# Preview one image per class (the first file os.listdir() reports for the
# class folder) in a grid of subplots.
# The grid has 6 columns and at least 4 rows; extra rows are added when there
# are more than 24 classes so that plt.subplot() never receives an index
# beyond the grid size.
n_cols = 6
n_rows = max(4, -(-len(class_names) // n_cols))  # -(-a // b) is ceil(a / b)

plt.figure(figsize=(20,14))
for i, fold in enumerate(class_names):
    class_dir = train_dir + fold

    plt.subplot(n_rows, n_cols, i+1)
    img_read = plt.imread(class_dir + '/' + os.listdir(class_dir)[0])
    plt.imshow(img_read)
    plt.title(fold)

plt.show()

In [None]:
# Number of images per batch, shared by the train/validation/test iterators.
Batch_size = 128

# The three splits use the same per-image preprocessing: pixel values are
# rescaled by 1/255, then each image is centred and divided by its own
# standard deviation.  Three separate generator objects are created, one per
# split (train, validation, test).
tr_gen, val_gen, ts_gen = (
    ImageDataGenerator(rescale=1./255.,
                       samplewise_center=True,
                       samplewise_std_normalization=True)
    for _ in range(3)
)

In [None]:
# Replace the training ImageDataGenerator by the batch iterator it builds over
# the training folder: shuffled, one-hot ('categorical') labels, images
# resized to 300x300, label indices following the order of class_names.
# NOTE(review): the name tr_gen is rebound here, so re-running this cell
# without re-creating the generator first will presumably fail.
tr_gen = tr_gen.flow_from_directory(
    directory=train_dir,
    target_size=(300,300),
    classes=class_names,
    class_mode='categorical',
    batch_size=Batch_size,
    shuffle=True,
    seed=42,
)

In [None]:
# Replace the validation ImageDataGenerator by the batch iterator it builds
# over the validation folder, configured like the training iterator
# (shuffled, one-hot labels, 300x300 images, class order from class_names).
# NOTE(review): the name val_gen is rebound here, so re-running this cell
# without re-creating the generator first will presumably fail.
val_gen = val_gen.flow_from_directory(
    directory=validation_dir,
    target_size=(300,300),
    classes=class_names,
    class_mode='categorical',
    batch_size=Batch_size,
    shuffle=True,
    seed=42,
)

In [None]:
# Replace the test ImageDataGenerator by the batch iterator it builds over the
# test folder.  shuffle=False keeps the batches in file order, which the
# prediction cell relies on when it compares predictions with ts_gen.classes.
# NOTE(review): the name ts_gen is rebound here, so re-running this cell
# without re-creating the generator first will presumably fail.
ts_gen = ts_gen.flow_from_directory(
    directory=test_dir,
    target_size=(300,300),
    classes=class_names,
    class_mode='categorical',
    batch_size=Batch_size,
    shuffle=False,
    seed=42,
)

# CNN Model

## Architecture

In [None]:
# Small sequential CNN: three convolutional stages, each followed by
# dropout -> max-pooling -> batch normalisation, then two dense layers and a
# softmax output with one unit per class.
model = models.Sequential()

# Input and Conv1: strided 3x3 convolution on the 300x300 RGB input
# (must match the target_size used by the directory iterators).
model.add(layers.Conv2D(filters=4, kernel_size=(3,3), strides=(2,2),
                        activation = 'relu', padding='valid',
                        input_shape = (300,300,3)))
model.add(layers.Dropout(0.2))
model.add(layers.MaxPool2D(strides=2))
model.add(layers.BatchNormalization())

# Conv2: second strided 3x3 convolution with more filters.
model.add(layers.Conv2D(filters=16, kernel_size=(3,3), strides=(2,2),
                       activation='relu', padding='valid'))
model.add(layers.Dropout(0.2))
model.add(layers.MaxPool2D(strides=2))
model.add(layers.BatchNormalization())

# Conv3: a dilated 3x3 convolution (padding='same') followed by a plain 3x3
# convolution, then pooling with stride 3.
model.add(layers.Conv2D(filters=32, kernel_size=(3,3),dilation_rate=(2,2),
                       activation='relu', padding='same'))
model.add(layers.Conv2D(filters=32, kernel_size=(3,3),
                       activation='relu'))
model.add(layers.Dropout(0.2))
model.add(layers.MaxPool2D(strides=3))
model.add(layers.BatchNormalization())

# Flattening and fully connected head.
model.add(layers.Flatten())

model.add(layers.Dense(200, activation= 'relu'))
model.add(layers.Dense(100, activation= 'relu'))
model.add(layers.Dropout(0.3))
model.add(layers.BatchNormalization())

# Output: one softmax unit per class.  The width is taken from class_names
# (the same list handed to the generators as `classes=`) instead of a literal,
# so the layer always matches the width of the one-hot labels.
model.add(layers.Dense(len(class_names), activation='softmax'))

model.summary()

## Fitting

In [None]:
# Compile with Adam and categorical cross-entropy (labels are one-hot).
# `learning_rate` is the supported argument name; `lr` is a deprecated alias
# that newer Keras releases reject.
model.compile(optimizer=optimizers.Adam(learning_rate=0.01),
             loss='categorical_crossentropy',
             metrics=['accuracy'])

# Stop when val_loss has not improved for 3 epochs, and divide the learning
# rate by 10 (down to 1e-6) when it has not improved by 1e-3 for 2 epochs.
es = EarlyStopping(monitor='val_loss',mode='min',patience=3,verbose=1)
RLr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience = 2, verbose = 1, min_delta=1e-3,min_lr=1e-6)

# len(iterator) is the whole number of batches per pass over the data, so the
# step counts are integers rather than the fractional n / Batch_size.
history= model.fit(tr_gen,
                 epochs=5,
                 steps_per_epoch=len(tr_gen),
                 validation_data=val_gen,
                 validation_steps=len(val_gen),
                 callbacks=[es,RLr])

### Evaluation

In [None]:
# Two stacked panels: training vs. validation accuracy on top, training vs.
# validation loss below, both plotted against the epoch index.
fig, ax = plt.subplots(2, 1, figsize=(12,10))
fig.suptitle('Train evaluation')

with plt.style.context('fivethirtyeight'):

    for panel, metric, heading in zip(ax, ('accuracy', 'loss'), ('Accuracy', 'Loss')):
        epoch_idx = np.arange(0, len(history.history[metric]))

        # Training curve first, validation curve second, so the legend
        # entries below line up with the plotted lines.
        for series in (metric, 'val_' + metric):
            sns.lineplot(ax=panel, x=epoch_idx, y=history.history[series])

        panel.legend(['Train','Validation'])
        panel.set_title(heading)
        panel.grid()


plt.show()

# Predictions 

In [None]:
# Predict class probabilities for the whole test set; len(ts_gen) is the
# integer number of batches needed to cover it (the previous np.ceil(...)
# produced a float step count).
Y_pred = model.predict(ts_gen, steps=len(ts_gen))
# Collapse each probability vector to the index of its most likely class.
Y_pred = np.argmax(Y_pred, axis=1)
# True class indices, taken from the test iterator (created with
# shuffle=False, so predictions are compared position by position).
Y_ts = ts_gen.classes

In [None]:
# Confusion matrix over every class index.  Passing `labels` explicitly keeps
# the matrix at len(class_names) x len(class_names) even if some class never
# appears in Y_ts or Y_pred, so the tick labels below always line up.
conf_mat = confusion_matrix(Y_ts, Y_pred, labels=np.arange(len(class_names)))
sns.set_style(style='dark')
plt.figure(figsize=(16,10))
# Colour scale spans the real minimum..maximum of the counts.
# (conf_mat.all() is a single boolean, not the matrix, so it must not be used
# to derive vmin.)
heatmap = sns.heatmap(conf_mat,vmin=np.min(conf_mat), vmax=np.max(conf_mat), annot=True,fmt='d', annot_kws={"fontsize":12},cmap='Spectral')
heatmap.set_title('Confusion Matrix Heatmap\nGravitational Wave Type', fontdict={'fontsize':15}, pad=12)
heatmap.set_xlabel('Predicted',fontdict={'fontsize':14})
heatmap.set_ylabel('Actual',fontdict={'fontsize':14})
heatmap.set_xticklabels(class_names, fontdict={'fontsize':12,'rotation': 75})
heatmap.set_yticklabels(class_names, fontdict={'fontsize':12,'rotation': 15})
plt.show()

# Test accuracy next to the last-epoch training and validation accuracy.
print('-Accuracy achieved: {:.2f}%\n-Accuracy by model was: {:.2f}%\n-Accuracy by validation was: {:.2f}%'.
      format(accuracy_score(Y_ts,Y_pred)*100,(history.history['accuracy'][-1])*100,(history.history['val_accuracy'][-1])*100))

# Error Samples

In [None]:
# Positions in the test set where the predicted class differs from the true
# class; these positions index Y_ts, Y_pred and ts_gen.filepaths alike.
index_errors = [
    position
    for position, (label, predict) in enumerate(zip(Y_ts, Y_pred))
    if label != predict
]

In [None]:
# Show a random selection of misclassified test images in a 4x5 grid, titled
# with the true class, the predicted class and the test-set index.
# random.sample() raises ValueError when asked for more items than exist, so
# the sample size is capped at the number of misclassifications.
n_shown = min(20, len(index_errors))

plt.figure(figsize=(20,20))

for i, img_index in enumerate(random.sample(index_errors, k=n_shown), start=1):
    plt.subplot(4,5,i)
    img_read = plt.imread(ts_gen.filepaths[img_index])
    plt.imshow(img_read)
    plt.title('Actual: '+str(class_names[Y_ts[img_index]])+'\nPredict: '+str(class_names[Y_pred[img_index]])+'\nImage_index:'+str(img_index))
plt.show()