# Galaxy Classification using the EFIGI reference dataset

## Authors: Avi Vajpeyi, Dr. Rahul Remanan
### This project was conceived as part of the 2018 Summer Internship, [@Moad Computer](https://www.moad.computer)

### [EFIGI data](https://www.astromatic.net/projects/efigi)


[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/rahulremanan/akash_ganga/blob/master/EFIGI_Galaxy_Classification.ipynb)


## Setting notebook behavior

In [0]:
# Notebook-wide switches; the `if ...:` guards in the cells below read them.
download_raw_data = False  # Set True to download the raw data from https://www.astromatic.net/download/efigi/
setup = True  # Set True to mount Google Drive and copy the prepared data and checkpoint into this session
upload_data = False  # Set True to zip the data folders and copy the archives to Google Drive

## Copy preprocessed data to Drive


## Connect Google Drive to this session

This code mounts your Google Drive home folder at `/content/gdrive` for the current session. You cannot cd into this but can access the files on Google Drive this way.

Code obtained from [Google Colab Free GPU Tutorial](https://medium.com/deep-learning-turkey/google-colab-free-gpu-tutorial-e113627b9f5d)

In [0]:
if setup:
  # Imported inside the guard so the cell only touches google.colab when
  # `setup` is enabled.
  from google.colab import drive
  # Makes the Drive contents available under /content/gdrive.
  drive.mount('/content/gdrive')

### Copy and unzip sorted data from drive to local

Data accessible from ./data/train and ./data/validation

In [0]:
if not download_raw_data and setup:
  ! cp /content/gdrive/My\ Drive/sorted_data_EFIGI.zip ./
! mkdir ./drive
! mkdir ./drive/EFIGI_Galaxy_Classification/
! mkdir ./drive/EFIGI_Galaxy_Classification/output/
! mkdir ./drive/EFIGI_Galaxy_Classification/output/checkpoint
! cp /content/gdrive/My\ Drive/Transfer_learn_299_299_EFIGI.h5 ./drive/EFIGI_Galaxy_Classification/output//checkpoint/
! mv ./drive/EFIGI_Galaxy_Classification/output//checkpoint/Transfer_learn_299_299_EFIGI.h5 ./drive/EFIGI_Galaxy_Classification/output//checkpoint/Transfer_learn_299_299_.h5
if setup:
  ! unzip -q sorted_data_EFIGI.zip

## Define python function to run linux commands

In [0]:
import subprocess
 
def execute_in_shell(command=None,
                     verbose=False):
    """
    Run a list of shell commands one after another and collect their output.

    Example usage:
    execute_in_shell(command = ['ls ./some/folder/',
                                'ls ./some/folder/  -1 | wc -l'],
                     verbose = True )

    Each command string is handed to the shell as-is (shell=True), so it must
    only be built from trusted values.

    :param command: a list of shell command strings; any other type prints a
        notice and yields empty results
    :param verbose: when True, print a success line per command and the
        exception details if a command cannot be launched
    :return: dictionary with two lists, 'Output' and 'Error', holding the
        captured stdout and stderr bytes of every command that was launched
    """

    error = []
    output = []

    if not isinstance(command, list):
        print('The argument command takes a list input ...')
        return {'Output': output, 'Error': error}

    for cmd in command:
        try:
            # stderr is piped as well so that 'Error' really carries the
            # error stream instead of always being None.
            process = subprocess.Popen(cmd, shell=True,
                                       stdout=subprocess.PIPE,
                                       stderr=subprocess.PIPE)
            # communicate() drains the pipes while waiting; calling wait()
            # first can block forever once a pipe buffer fills up.
            out, err = process.communicate()
        except Exception as e:
            print('Failed running shell command: {}'.format(cmd))
            if verbose:
                print(type(e))
                print(e.args)
                print(e)
            continue

        error.append(err)
        output.append(out)

        if process.returncode != 0:
            # A non-zero exit status is a failure even though the process
            # itself was started without raising.
            print('Failed running shell command: {}'.format(cmd))
        elif verbose:
            print('Success running shell command: {}'.format(cmd))

    return {'Output': output, 'Error': error}

## Download and sort data
Currently only the processed colour images are used, rather than the separate single-band images that the final image is composed of.



### Download data
The EFIGI dataset files we are using come in 7 separate compressed archives (gzipped
tar format):
- efigi_tables-1.6.tgz: 6 ASCII tables, including morphological information
- efigi_png_gri-1.6.tgz: 4458 PNG images in the SDSS g,r and i bands
- efigi_ima_u-1.6.tgz: 4458 galaxy images in the SDSS u-band (FITS format)
- efigi_ima_g-1.6.tgz: 4458 galaxy images in the SDSS g-band (FITS format)
- efigi_ima_r-1.6.tgz: 4458 galaxy images in the SDSS r-band (FITS format)
- efigi_ima_i-1.6.tgz: 4458 galaxy images in the SDSS i-band (FITS format)
- efigi_ima_z-1.6.tgz: 4458 galaxy images in the SDSS z-band (FITS format)


In [0]:
if download_raw_data:
  # Create the download folder (plain mkdir reports an error when the folder
  # already exists, e.g. on a second run of this cell).
  ! mkdir ./data
  ! mkdir ./data/raw

 
  # Fetch the tables, the colour PNGs and the five single-band FITS archives.
  # -O stores each archive under a shorter local name inside ./data/raw/.
  ! wget  -O ./data/raw/efigi-1.6.tgz "https://www.astromatic.net/download/efigi/efigi_tables-1.6.2.tgz"
  ! wget  -O ./data/raw/efigi_pics.tgz "https://www.astromatic.net/download/efigi/efigi_png_gri-1.6.tgz"
  ! wget  -O ./data/raw/efigi_u_pics.tgz "https://www.astromatic.net/download/efigi/efigi_ima_u-1.6.tgz"
  ! wget  -O ./data/raw/efigi_g_pics.tgz "https://www.astromatic.net/download/efigi/efigi_ima_g-1.6.tgz"
  ! wget  -O ./data/raw/efigi_r_pics.tgz "https://www.astromatic.net/download/efigi/efigi_ima_r-1.6.tgz"
  ! wget  -O ./data/raw/efigi_i_pics.tgz "https://www.astromatic.net/download/efigi/efigi_ima_i-1.6.tgz"
  ! wget  -O ./data/raw/efigi_z_pics.tgz "https://www.astromatic.net/download/efigi/efigi_ima_z-1.6.tgz"

### Zip and move raw files to the object drive
DO ONLY ONCE

In [0]:
if upload_data:
  # Bundle the raw downloads into one quiet (-q), recursive (-r) archive ...
  ! zip -q -r raw_data_EFIGI.zip ./data/raw/
  # ... and copy it to the mounted Google Drive folder.
  ! cp raw_data_EFIGI.zip /content/gdrive//My\ Drive/

### Unpack data from tgz
Data stored in

* Tables:  ` ./data/raw/efigi-1.6/ `
* Colored Images:   ` ./data/raw/efigi-1.6/png ` 
* FITS: `./data/raw/efigi-1.6/ima_g`, `ima_i`, `ima_r`, `ima_u`, `ima_z`


In [0]:
if download_raw_data:
  import glob
  tgz_files = glob.glob("./data/raw/*tgz")
  for tgz_file in tgz_files:
    # Chain with && so the archive is removed only after tar succeeded;
    # as two separate commands a failed extraction still deleted the download.
    command = ["tar xzf " + tgz_file + " -C ./data/raw/ && rm " + tgz_file]
    execute_in_shell(command, verbose=True)

### Convert fits to png

In [0]:
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
from astropy.utils.data import get_pkg_data_filename
from astropy.io import fits
import numpy as np
import glob
import os
import cv2

def fits_to_png(fits_fn):
    """
    Render the image held in a FITS file to a PNG written next to it.

    The figure is one inch tall, width/height inches wide and saved with
    dpi=height, so the PNG gets one pixel per data sample. Axes and margins
    are removed.

    :param fits_fn: path of the FITS file to convert
    :return: None; the PNG is written to the same path with a .png extension
    """
    # Generally the image information is located in the Primary HDU (ext 0)
    data = fits.getdata(fits_fn, ext=0)

    sizes = np.shape(data)
    height = float(sizes[0])
    width = float(sizes[1])

    # Build the PNG file name from the FITS file name. Only a trailing
    # ".fits" is stripped, so a ".fits" inside a directory name no longer
    # truncates the path; other names keep the previous behaviour.
    if fits_fn.endswith(".fits"):
        png_fn = fits_fn[:-len(".fits")] + ".png"
    else:
        png_fn = fits_fn.split(".fits")[0] + ".png"

    fig = plt.figure()
    try:
        fig.set_size_inches(width / height, 1, forward=False)
        # An axes object spanning the whole figure, with no ticks or frame.
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.set_axis_off()
        fig.add_axes(ax)

        ax.imshow(data, cmap="binary")

        fig.savefig(png_fn, dpi=height)
    finally:
        # Always release the figure; this function is called once per file
        # and open figures would otherwise pile up after an error.
        plt.close(fig)


def fits_folder_to_png(dir, verbose):
    """
    Convert every ``*.fits`` file of a folder to PNG.

    :param dir: folder prefix; it is concatenated directly with the file
        pattern, so it has to end with a path separator
    :param verbose: when True, print a progress line roughly every 10% of
        the files
    :return: None
    """
    pending = glob.glob( dir+"*.fits")
    total = len(pending)
    step = total * 0.1
    next_report = step

    for index, fits_path in enumerate(pending):
        fits_to_png(fits_path)

        if not (verbose and index > next_report):
            continue

        # Push the threshold one step further and report integer percent.
        next_report += step
        percent_done = (index * 100// total)
        print(str(percent_done)+"% processed")

def delete_fits_from_folder(dir):
    """
    Delete every ``*.fits`` file directly inside a folder.

    :param dir: folder prefix; concatenated directly with the file pattern,
        so it has to end with a path separator
    :return: None
    """
    for fits_path in glob.glob(dir+"*.fits"):
        os.remove(fits_path)


def make_movie_from_png(video_name, dir):
    """
    Write all PNG images of a folder into one video file at 25 frames/s.

    Frames are written in sorted file-name order; the frame size is taken
    from the first image.

    :param video_name: path of the video file to create
    :param dir: folder prefix; concatenated directly with the file pattern,
        so it has to end with a path separator
    :return: None
    :raises IndexError: if the folder contains no PNG files
    """
    # glob returns files in arbitrary order; sort so the frame sequence is
    # deterministic.
    images = sorted(glob.glob(dir + "*.png"))
    if not images:
        raise IndexError("No PNG files found in " + dir)

    frame = cv2.imread(images[0])
    height, width = frame.shape[:2]

    # NOTE(review): fourcc -1 leaves the codec choice to the OpenCV backend;
    # confirm this works on the target platform.
    video = cv2.VideoWriter(video_name, -1, 25, (width, height))
    try:
        for image in images:
            video.write(cv2.imread(image))
    finally:
        # Release the writer even if a frame fails, so the output file is
        # closed. No windows are opened here, hence no destroyAllWindows().
        video.release()

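Execute the following cell to convert the FITS images into PNGs.
This can take a while. If the session times out, run the cell again: folders that were fully converted have already had their FITS files deleted, so only the remaining folders are processed.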

In [0]:
import glob
import os
import gc

if setup:
  FITS_folders = ["ima_g", "ima_i","ima_u","ima_z","ima_r"]
  for fits_folder in FITS_folders:
    # Named fits_dir rather than dir: a global called `dir` would shadow the
    # builtin dir() for every later cell of the notebook.
    fits_dir = "./data/raw/efigi-1.6/"+fits_folder+"/"
    print("Processing "+fits_dir)
    fits_folder_to_png(fits_dir, verbose=True)
    # The FITS sources are removed only after the whole folder was converted.
    delete_fits_from_folder(fits_dir)
    gc.collect()

### Make Galaxy Type Enum 

In [0]:
from enum import Enum, auto, unique
@unique
class T(Enum):
  '''Galaxy classes that the images are sorted into.

  The member names are also used as the names of the class folders.
  '''
  ELLIPTICAL = auto()
  LENTICULAR = auto()
  SPIRAL = auto()
  IRREGULAR = auto()
  DWARF = auto()

  def __str__(self):
    '''Return the bare member name (e.g. SPIRAL) rather than T.SPIRAL.
    '''
    return self.name

def check_class(t_val):
  '''Map a morphological type value to the matching galaxy class enum.

  Values below -3 are elliptical, -3..-1 lenticular, 0..9 spiral,
  10 irregular and 11 dwarf.

  :param t_val: the type value, as an int or as a string holding an int
  :return: the corresponding member of T
  :raises ValueError: if t_val is not an integer or is larger than 11
  '''
  try:
    t_val = int(t_val)
  except (TypeError, ValueError):
    # Fail with a clear message here instead of a TypeError from comparing
    # a string with a number further down.
    raise ValueError(
        "Morphological type must be an integer, got {!r}".format(t_val)
    ) from None

  if t_val < -3:
    return T.ELLIPTICAL
  if t_val < 0:
    return T.LENTICULAR
  if t_val < 10:
    return T.SPIRAL
  if t_val == 10:
    return T.IRREGULAR
  if t_val == 11:
    return T.DWARF

  # Previously this path hit `return null`, which is a NameError in Python.
  raise ValueError("Unknown morphological type: {}".format(t_val))

### Make Organisational Folders

In [0]:
import os
train_dir = './data/train/'
val_dir = './data/validation'
if setup and download_raw_data:
  # One folder per galaxy class, named after the enum members.
  galaxy_classes = list(T.__members__)
  for galaxy_class in galaxy_classes:
    for split_dir in (train_dir, val_dir):
      # os.path.join supplies the separator that val_dir lacks; plain string
      # concatenation created "./data/validationSPIRAL"-style folders.
      # makedirs also creates the parent folders and tolerates re-runs.
      os.makedirs(os.path.join(split_dir, galaxy_class), exist_ok=True)

  print("Folders in {}: ".format(train_dir))
  print (os.listdir(train_dir))

### Move files from original folder to their classes folder

The table `data/raw/efigi-1.6/EFIGI_attributes.txt` has several attributes. We need the "PGC_name" and "T" (the file name and EFIGI morphological type). Based on this, we will move the file from `./raw` to `./train/{type}`

In [0]:
import shutil


def row_generator(filepath):
  ''' Yield the data rows of a text table one at a time.

  Lines starting with # (comments) and blank lines are skipped wherever they
  occur, so callers can index into row.split() without hitting an empty row.

  :param filepath: path of the text file to read
  :return: generator of the remaining lines, line endings included
  '''
  with open(filepath) as fp:
      for row in fp:
          if row.startswith('#') or not row.strip():
              continue
          yield row



def move_file_by_class(filename, type, image_foldername):
  '''Move one image from the raw data folder into its class training folder.

  :param filename: file name of the image, without any folder part
  :param type: object whose .name is the class folder name (a T member)
  :param image_foldername: sub-folder of data/raw/efigi-1.6/ holding the file
  '''
  source_path = "/".join(["data/raw/efigi-1.6", image_foldername, filename])
  target_path = "/".join(["data/train", type.name, filename])
  shutil.move(source_path, target_path)


def move_files_according_to_txt(txt_filepath, img_folder, extension, verbose):
  '''Sort the images of one raw folder into the class training folders.

  Every row of the attribute table supplies the image name (first column)
  and the morphological type (second column).

  :param txt_filepath: path of the attribute table
  :param img_folder: raw sub-folder the images are taken from
  :param extension: band suffix in the file names ("g", "i", ...), or None
      for files named just <name>.png
  :param verbose: when True, print a line for every 100th image
  '''
  print("Moving files from "+img_folder)

  for count, line in enumerate(row_generator(txt_filepath), start=1):
    fields = line.split()

    # File name is <PGC_name>.png or <PGC_name>_<band>.png
    image_file_name = (fields[0]+".png" if extension is None
                       else fields[0]+"_"+extension+".png")

    # Class according to the dataset's type column
    image_class = check_class(fields[1])

    move_file_by_class(image_file_name, image_class, img_folder)

    if verbose and count % 100 == 0:
      print ("Image Num"+str(count)+": " +image_file_name+ " is a " + image_class.name)

  print("Done moving from "+img_folder+ " to data/train/")



if setup and download_raw_data:
  # Each raw folder is paired with the band suffix used in its file names;
  # the colour PNG folder has no suffix.
  image_folders = ["png","ima_g", "ima_i", "ima_u", "ima_z", "ima_r"]
  extensions =[None, "g","i","u","z","r"]

  for folder_name, band_suffix in zip(image_folders, extensions):
    move_files_according_to_txt(txt_filepath = "data/raw/efigi-1.6/EFIGI_attributes.txt",
                                img_folder = folder_name,
                                extension = band_suffix,
                                verbose = True)

### Shuffle some files from training folder to validation folder


In [0]:
import os, glob, random 
if setup and download_raw_data:
  subfolders = [f.path for f in os.scandir(train_dir) if f.is_dir()]

  # For each training class folder
  for train_class_dir in subfolders:

    # Get total number of images in the folder
    images = glob.glob(train_class_dir+"/*.png")
    total_num = len(images)
    print (train_class_dir +" has " + str(total_num)+" images.")

    # Pick a random 20% of the images for validation
    number_of_validation = int(0.2*float(total_num)) # 20% validation
    files_to_move = random.sample(images, number_of_validation)

    # Validation folder of the same class; create it if it is missing so
    # that shutil.move cannot fail on a non-existent destination folder.
    class_name = os.path.basename(train_class_dir)
    val_class_dir = os.path.join(val_dir, class_name)
    os.makedirs(val_class_dir, exist_ok=True)

    # Move the picked images to the validation folder of the same class
    for file_dir in files_to_move:
      destination_dir = os.path.join(val_class_dir, os.path.basename(file_dir))
      shutil.move(file_dir, destination_dir)

    num_images_remaining = len(glob.glob(train_class_dir+"/*.png"))
    print ("After transfer " + str(num_images_remaining)+" images will remain as training data.\n")

### Remove Raw Files

In [0]:
import os, shutil
# The raw download is only present (and only needs deleting) when it was
# fetched in this session.
if download_raw_data and setup:
  shutil.rmtree("./data/raw")

### Zip and move sorted data to drive

In [0]:
if upload_data:
  # Archive the whole ./data folder quietly (-q) and recursively (-r) ...
  ! zip -q -r sorted_data_EFIGI.zip ./data
  # ... and copy it to the mounted Google Drive folder, listing the file (-v).
  ! cp -v sorted_data_EFIGI.zip /content/gdrive/My\ Drive/

## Train a deep convolutional neural network classifier

- Galaxy classifier using Inception-ResNet version 2.


### Imports for ML that will be needed

In [0]:
import argparse
import os
import time
import sys
import glob
try:
    import h5py
except:
    print ('Package h5py needed for saving model weights ...')
    sys.exit(1)
import json
import matplotlib
import matplotlib.pyplot as plt
try:
    import tensorflow
    import keras
except:
    print ('This code uses tensorflow deep-learning framework and keras api ...')
    print ('Install tensorflow and keras to train the classifier ...')
    sys.exit(1)
    
import PIL # Python Imaging Library
from collections import defaultdict
from keras.applications.inception_v3 import InceptionV3,    \
                                            preprocess_input as preprocess_input_inceptionv3
from keras.applications.inception_resnet_v2 import InceptionResNetV2,    \
                                            preprocess_input as preprocess_input_inceptionv4
from keras.models import Model,                             \
                         model_from_json,                    \
                         load_model
from keras.layers import Dense,                             \
                         GlobalAveragePooling2D,            \
                         Dropout,                           \
                         BatchNormalization,\
                         Concatenate
from keras.layers.merge import concatenate
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD,                           \
                             RMSprop,                       \
                             Adagrad
from keras.callbacks import EarlyStopping,   \
                            ModelCheckpoint, \
                            ReduceLROnPlateau


### Fetch saved weights from Google drive storage object

In [0]:
import argparse
import os
import time
import sys
import glob

try:
    import h5py
except:
    print('Package h5py needed for saving model weights ...')
    sys.exit(1)
import json
import matplotlib
import matplotlib.pyplot as plt

try:
    import tensorflow
    import keras
except:
    print(
        'This code uses tensorflow deep-learning framework and keras api ...')
    print('Install tensorflow and keras to train the classifier ...')
    sys.exit(1)

import PIL  # Python Imaging Library
from collections import defaultdict
from keras.applications.inception_v3 import InceptionV3, \
    preprocess_input as preprocess_input_inceptionv3
from keras.applications.inception_resnet_v2 import InceptionResNetV2, \
    preprocess_input as preprocess_input_inceptionv4
from keras.models import Model, \
    model_from_json, \
    load_model
from keras.layers import Dense, \
    GlobalAveragePooling2D, \
    Dropout, \
    BatchNormalization,\
    merge
from keras.layers.merge import concatenate
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD, \
    RMSprop, \
    Adagrad
from keras.callbacks import EarlyStopping, \
    ModelCheckpoint, \
    ReduceLROnPlateau


def generate_timestamp():
    """
    Build a time stamp of the current local time and print it.

    The layout is year_month_day-hour_minute_second.

    :return: the time stamp string
    """
    stamp_layout = "%Y_%m_%d-%H_%M_%S"
    timestring = time.strftime(stamp_layout)
    print("Time stamp generated: {}".format(timestring))
    return timestring


def is_valid_file(parser, arg):
    """
    Validate that a command line argument names an existing file.

    :param parser: object whose error(message) method is called when the
        file is missing (an ArgumentParser reports the message and exits)
    :param arg: the file path to check
    :return: arg unchanged if the file exists; None if parser.error returns
    """
    if os.path.isfile(arg):
        return arg
    parser.error("The file %s does not exist ..." % arg)


def is_valid_dir(parser, arg):
    """
    Validate that a command line argument names an existing directory.

    :param parser: object whose error(message) method is called when the
        directory is missing (an ArgumentParser reports the message and exits)
    :param arg: the directory path to check
    :return: arg unchanged if the directory exists; None if parser.error
        returns
    """
    if os.path.isdir(arg):
        return arg
    parser.error("The folder %s does not exist ..." % arg)


def string_to_bool(val):
    """
    Convert a yes/no style string into a bool.

    Recognised, case-insensitively and ignoring surrounding whitespace:
    yes, true, t, y, 1, yeah -> True; no, false, f, n, 0, none -> False.
    Values that already are bools are returned unchanged, so defaults given
    as True/False do not fail.

    :param val: a string (or a bool)
    :return: the associated bool
    :raises argparse.ArgumentTypeError: if the string is not recognised
    """
    if isinstance(val, bool):
        return val

    token = val.strip().lower()
    if token in ('yes', 'true', 't', 'y', '1', 'yeah'):
        return True
    if token in ('no', 'false', 'f', 'n', '0', 'none'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected ...')


def get_nb_files(directory):
    """
    Count the entries found inside the sub-folders of a directory.

    For every folder met while walking the tree, the non-hidden entries of
    each of its sub-folders are counted; entries lying directly in
    ``directory`` itself are not.

    :param directory: the top folder, e.g. one holding one folder per class
    :return: the total count (0 if the directory does not exist)
    """
    if not os.path.exists(directory):
        return 0
    return sum(
        len(glob.glob(os.path.join(parent, child + "/*")))
        for parent, children, _files in os.walk(directory)
        for child in children
    )


def setup_to_transfer_learn(model, base_model, optimizer):
    """
    Prepare a model for transfer learning.

    Every layer of the base model is marked as not trainable, then the full
    model is compiled. The two models are not joined here; presumably
    ``model`` was already built on top of ``base_model`` by the caller.

    :param model: the model to be trained
    :param base_model: the pre-trained base model whose layers are frozen
    :param optimizer: the optimizer used to compile the model
    :return: the compiled model (the same object that was passed in)
    """
    for frozen_layer in base_model.layers:
        frozen_layer.trainable = False

    compile_options = {
        'optimizer': optimizer,
        'loss': 'categorical_crossentropy',
        'metrics': ['accuracy'],
    }
    model.compile(**compile_options)
    return model





## Dense attention layer for the top decision layer

The implementation of the dense attention layer is a modified fork of the one described [here](https://github.com/philipperemy/keras-attention-mechanism/blob/master/attention_dense.py).

In [0]:
def add_top_layer(base_model, nb_classes):
    """
    Build the classification head on top of a base network.

    The base output goes through dropout, global average pooling and batch
    normalisation (x). A softmax Dense layer over x gives the attention
    vector (xattn). x and xattn are concatenated and fed to a fully connected
    layer with dropout, followed by the final softmax prediction layer.

    The dropout rate and activation are read from the global ``args``; if
    they cannot be read, DEFAULT_DROPOUT and 'relu' are used instead.

    :param base_model: the current model
    :param nb_classes: the number of classes that we need to predict for
    :return: a new model from the base model input to the predictions
    """

    # Dropout rate (drops units to reduce overfitting).
    # `except Exception` keeps the fall-back for a missing or malformed
    # `args` without swallowing KeyboardInterrupt / SystemExit.
    try:
        dropout = float(args.dropout[0])
    except Exception:
        dropout = DEFAULT_DROPOUT
        print('Invalid input for dropout ...')

    # Activation function of the fully connected layer
    try:
        activation = str(args.activation[0]).lower()
        # This is the user-selected activation, not the default one.
        print('Building model using activation function: ' + str(activation))
    except Exception:
        activation = 'relu'
        print('Invalid input for activation function ...')
        print("Choice of activation functions: hard_sigmoid, elu, linear, relu, "
              "selu, sigmoid, softmax, softplus, softsign, tanh ...")
        print('Building model using default activation function: relu')

    bm = base_model.output

    x = Dropout(dropout, name='dropout_1')(bm)
    x = GlobalAveragePooling2D()(x)
    x = BatchNormalization()(x)

    # Layer names are kept unchanged so saved weights still match by name.
    xattn = Dense(FC_SIZE, name='dense_attention_1', activation='softmax')(x)

    x_concat = concatenate([x, xattn], name='concatenate_1')

    xout = Dense(FC_SIZE, name='fc_dense_1', activation=activation)(x_concat)
    xout = Dropout(dropout, name='fc_dropout_1')(xout)

    predictions = Dense(nb_classes,
                        activation='softmax',
                        name='prediction')(xout)  # New softmax layer

    model = Model(inputs=base_model.input, outputs=predictions)

    return model

In [0]:
def setup_to_finetune(model, optimizer, NB_FROZEN_LAYERS):
    """
    Freeze the bottom layers of a model, unfreeze the rest and recompile.

    :param model: the current ML model
    :param optimizer: optimizer used to compile the model
    :param NB_FROZEN_LAYERS: the first NB_FROZEN_LAYERS layers are frozen,
        the layers from that index on are set trainable
    :return: the compiled model (the same object that was passed in)
    """
    layer_groups = (
        (model.layers[:NB_FROZEN_LAYERS], False),
        (model.layers[NB_FROZEN_LAYERS:], True),
    )
    for group, is_trainable in layer_groups:
        for layer in group:
            layer.trainable = is_trainable

    model.compile(optimizer=optimizer,
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model


def save_model(args, name, model):
    """
    Write a model's weights and its configuration to the output folder.

    Two files are created, both prefixed with "trained_" and the global
    time stamp ``timestr``:
    1) <prefix>_weights<name>.model with the weights
    2) <prefix>_config<name>.json with the model configuration as JSON

    :param args: holds the output folder in args.output_dir[0]
    :param name: suffix appended to both file names
    :param model: the model to be saved
    :return: None
    """
    prefix = args.output_dir[0] + "//trained_" + timestr

    weights_path = prefix + "_weights" + str(name) + ".model"
    model.save_weights(weights_path)

    # Serialize the configuration first, then open the target file
    serialized_config = model.to_json()
    config_path = prefix + "_config" + str(name) + ".json"
    with open(config_path, "w") as json_file:
        json_file.write(serialized_config)

    print("Saved the trained model weights to: " + str(weights_path))
    print("Saved the trained model configuration as a json file to: " +
          str(config_path))


def generate_labels(args):
    """
    Derive the class labels from the training and validation folder names.

    The label of a file is the first path component below the respective
    top folder. If both folders give the same label set, the sorted training
    labels are written to trained_labels.json in the output folder and
    returned; otherwise the program exits with status 1.

    :param args: provides args.output_dir[0], args.train_dir[0] and
        args.val_dir[0]
    :return: the sorted list of labels
    """

    def collect_by_label(top_dir):
        # Group every file below top_dir by its first-level folder name.
        grouped = defaultdict(list)
        for root, _subdirs, files in os.walk(top_dir):
            for filename in files:
                file_path = os.path.join(root, filename)
                assert file_path.startswith(top_dir)
                relative = file_path[len(top_dir):].lstrip("/")
                grouped[relative.split("/")[0]].append(file_path)
        return grouped

    labels_file = os.path.join(args.output_dir[0] + "//trained_labels") + ".json"
    data_dir = args.train_dir[0]
    val_dir_ = args.val_dir[0]

    labels = sorted(collect_by_label(data_dir).keys())
    val_labels = sorted(collect_by_label(val_dir_).keys())

    print("Training labels: " + str(labels))
    print("Validation labels: " + str(val_labels))

    if set(labels) != set(val_labels):
        print("Mismatched training and validation data labels ...")
        print(
            "Sub-folder names do not match between training and validation "
            "directories ...")
        sys.exit(1)

    with open(labels_file, "w") as json_file:
        json.dump(labels, json_file)

    return labels


def generate_plot(args, name, model_train):
    """
    Create the training summary plots if plotting was requested.

    :param args: holds the plot switch in args.plot[0]
    :param name: suffix used in the plot file names
    :param model_train: the training result passed on to plot_training
    :return: None
    """
    # Kept as an equality test: only a value equal to True enables plotting.
    plots_requested = args.plot[0] == True
    if not plots_requested:
        print("No training summary plots generated ...")
        print("Set: --plot True for creating training summary plots")
        return
    plot_training(args, name, model_train)


def plot_training(args, name, history):
    """
    Plot accuracy vs epoch and loss vs epoch and save both as PNG files.

    ``history`` is the object returned by training, whose ``history``
    attribute maps metric names to per-epoch values. The accuracy series is
    read from 'acc'/'val_acc' when present and from
    'accuracy'/'val_accuracy' otherwise, since the key depends on the Keras
    version.

    The file names contain the global time stamp ``timestr``.

    :param args: holds the output folder in args.output_dir[0]
    :param name: suffix appended to the plot file names
    :param history: the training history object
    :return: None
    """
    output_loc = args.output_dir[0]
    metrics = history.history

    def save_curve(train_key, val_key, title, ylabel, output_file,
                   done_message):
        # Draw the training and validation series of one metric and save it.
        fig = plt.figure()
        try:
            plt.plot(metrics[train_key])
            plt.plot(metrics[val_key])
            plt.title(title)
            plt.ylabel(ylabel)
            plt.xlabel('epoch')
            plt.legend(['train', 'test'], loc='upper left')
            fig.savefig(output_file, dpi=fig.dpi)
            print(done_message + str(output_file))
        finally:
            # Close the figure even if a metric is missing or saving fails.
            plt.close(fig)

    output_file_acc = os.path.join(output_loc +
                                   "//training_plot_acc_" +
                                   timestr + str(name) + ".png")
    output_file_loss = os.path.join(output_loc +
                                    "//training_plot_loss_" +
                                    timestr + str(name) + ".png")

    # Accept both spellings of the accuracy metric.
    acc_key = 'acc' if 'acc' in metrics else 'accuracy'
    val_acc_key = 'val_acc' if 'val_acc' in metrics else 'val_accuracy'

    save_curve(acc_key, val_acc_key, 'model accuracy', 'accuracy',
               output_file_acc,
               "Successfully created the training accuracy plot: ")
    save_curve('loss', 'val_loss', 'model loss', 'loss',
               output_file_loss,
               "Successfully created the loss function plot: ")


def _uses_inception_resnet(base_model_name):
    """Return True when the name selects the InceptionResNetV2 code path."""
    return str(base_model_name).lower() in ('inceptionv4',
                                            'inception_v4',
                                            'inception_resnet')


def _build_optimizer(optimizer_val, lr, decay):
    """Return the Keras optimizer selected by the ``optimizer_val`` keyword."""
    choice = optimizer_val.lower()
    if choice == 'sgd':
        # NOTE(review): momentum=1 means the velocity term never decays;
        # kept as in the original, but confirm this is intended.
        optimizer = SGD(lr=lr, decay=decay, momentum=1, nesterov=True)
        print("Using SGD as the optimizer ...")
    elif choice in ('rms', 'rmsprop'):
        optimizer = RMSprop(lr=lr, rho=0.9, epsilon=1e-08, decay=decay)
        print("Using RMSProp as the optimizer ...")
    elif choice == 'ada':
        optimizer = Adagrad(lr=lr, epsilon=1e-08, decay=decay)
        print("Using Adagrad as the optimizer ...")
    else:
        # The fallback ignores the requested lr/decay, so say so explicitly
        # instead of switching optimizers silently.
        optimizer = DEFAULT_OPTIMIZER
        print("Unrecognized optimizer '" + str(optimizer_val) +
              "'; using the default optimizer ...")
    return optimizer


def _build_image_datagen(preprocess_input, augment):
    """Return an ImageDataGenerator, with random augmentation if requested."""
    # Equality test (not truthiness) so that only a real True enables
    # augmentation; a non-empty string such as 'False' must not.
    if augment == True:
        return ImageDataGenerator(
            preprocessing_function=preprocess_input,
            rotation_range=30,
            width_shift_range=0.2,
            height_shift_range=0.2,
            shear_range=0.2,
            zoom_range=0.2,
            horizontal_flip=True)
    return ImageDataGenerator(preprocessing_function=preprocess_input)


def train(args):
    """
    Build (or restore) an Inception-based classifier and train it.

    Every option is read as the first element of a one-item list.

    :param args: namespace of options; the ones read here and example
                 values are
            args.config_file = ['./model/model_efc.json']
            args.weights_file = [...]   # only read when load_weights is set
            args.output_dir = ['./output/']
            args.train_dir = ['./data/train/']
            args.val_dir = ['./data/validation/']
            args.epoch = [10]
            args.batch = [4]
            args.load_weights = [False]
            args.load_checkpoint = [False]
            args.fine_tune = [True]
            args.test_aug = [False]
            args.train_aug = [False]
            args.plot = [False]         # consumed by generate_plot
            args.model_summary = [False]
            args.learning_rate = [1e-8]
            args.decay = [0.0]
            args.optimizer_val = ['rms'] # 'rms', 'sgd', 'ada'
            args.frozen_layers = [150]
            args.base_model = ['inceptionv4']
            args.saved_chkpnt = [...]   # fallback checkpoint file
    :return: None. Calls ``sys.exit(1)`` when the train and validation
             directories contain a different number of class sub-folders.
    """

    # Get output dir ##########################################################
    os.makedirs(args.output_dir[0], exist_ok=True)

    # Get optimizer, learning rate, decay parameters ##########################
    lr = args.learning_rate[0]
    decay = args.decay[0]
    optimizer = _build_optimizer(args.optimizer_val[0], lr, decay)

    # Get number training samples and classes #################################
    nb_train_samples = get_nb_files(args.train_dir[0])
    nb_classes = len(glob.glob(args.train_dir[0] + "/*"))
    print("Total number of training samples = " + str(nb_train_samples))
    print("Number of training classes = " + str(nb_classes))

    # Get number validation samples and classes ###############################
    nb_val_samples = get_nb_files(args.val_dir[0])
    nb_val_classes = len(glob.glob(args.val_dir[0] + "/*"))
    print("Total number of validation samples = " + str(nb_val_samples))
    print("Number of validation classes = " + str(nb_val_classes))

    # START TRAINING if train labels == valid labels ##########################
    if nb_val_classes != nb_classes:
        print("Mismatched number of training and validation data classes ...")
        print("Unequal number of sub-folders found between train and "
              "validation directories ...")
        print("Each sub-folder in train and validation directroies are "
              "treated as a separate class ...")
        print("Correct this mismatch and re-run ...")
        print("Now exiting ...")
        sys.exit(1)
    print("Initiating training session ...")

    # Get num epochs, batch size ##############################################
    nb_epoch = int(args.epoch[0])
    batch_size = int(args.batch[0])

    # Pick the preprocessing that matches the chosen base model ###############
    use_inception_resnet = _uses_inception_resnet(args.base_model[0])
    if use_inception_resnet:
        preprocess_input = preprocess_input_inceptionv4
    else:
        preprocess_input = preprocess_input_inceptionv3

    # Image generators, optionally with random augmentation ###################
    train_datagen = _build_image_datagen(preprocess_input, args.train_aug[0])
    test_datagen = _build_image_datagen(preprocess_input, args.test_aug[0])

    # Getting training data ###################################################
    print("Generating training data: ... ")
    train_generator = train_datagen.flow_from_directory(
        args.train_dir[0],
        target_size=(IM_WIDTH, IM_HEIGHT),
        batch_size=batch_size,
        class_mode='categorical')

    # Getting validation data #################################################
    print("Generating validation data: ... ")
    validation_generator = test_datagen.flow_from_directory(
        args.val_dir[0],
        target_size=(IM_WIDTH, IM_HEIGHT),
        batch_size=batch_size,
        class_mode='categorical')

    # Instantiate the pretrained base network #################################
    # include_top=False drops the final fully connected layer so a new
    # classification head can be attached.
    if use_inception_resnet:
        base_model = InceptionResNetV2(weights='imagenet', include_top=False)
        base_model_name = 'Inception version 4'
    else:
        base_model = InceptionV3(weights='imagenet', include_top=False)
        base_model_name = 'Inception version 3'
    print('Base model: ' + str(base_model_name))

    # Add a new layer to the base model #######################################
    model = add_top_layer(base_model, nb_classes)
    print("New top layer added to: " + str(base_model_name))

    # get classification labels, if to load checkpoints, previous weights,etc #
    labels = generate_labels(args)
    load_weights_ = args.load_weights[0]
    fine_tune_model = args.fine_tune[0]
    load_checkpoint = args.load_checkpoint[0]
    checkpointer_savepath = os.path.join(args.output_dir[0] +
                                         '/checkpoint/Transfer_learn_' +
                                         str(IM_WIDTH) + '_' +
                                         str(IM_HEIGHT) + '_' + '.h5')

    # Getting previous weights from checkpoint, else new model ################
    # Loading is best-effort: on failure the freshly built model is kept.
    # ``except Exception`` (rather than a bare except) lets Ctrl-C and
    # SystemExit propagate.
    if load_weights_ == True and load_checkpoint == False:
        try:
            with open(args.config_file[0]) as json_file:
                model_json = json_file.read()
            model = model_from_json(model_json)
        except Exception as err:
            print("Could not load model config (" + str(err) +
                  "); keeping the newly built model ...")
        try:
            model.load_weights(args.weights_file[0])
            print("Loaded model weights from: " + str(args.weights_file[0]))
        except Exception:
            print("Error loading model weights ...")
            print("Loaded default model weights ...")
    elif load_checkpoint == True:
        try:
            model = load_model(checkpointer_savepath)
            print(
                "Loaded model from checkpoint: " + str(checkpointer_savepath))
        except Exception:
            # Fall back to the user-supplied checkpoint file, if it exists.
            if os.path.exists(args.saved_chkpnt[0]):
                model = load_model(args.saved_chkpnt[0])
                print('Loaded saved checkpoint file ...')
            else:
                print("Error loading model checkpoint ...")
                print("Loaded default model weights ...")
    else:
        print("Tabula rasa ...")

    # Checking and freezing certain layers during training ####################
    try:
        nb_frozen_layers = args.frozen_layers[0]
    except (AttributeError, IndexError, TypeError):
        nb_frozen_layers = DEFAULT_NB_LAYERS_TO_FREEZE
    if fine_tune_model == True:
        print("Fine tuning Inception architecture ...")
        print("Frozen layers: " + str(nb_frozen_layers))
        setup_to_finetune(model, optimizer, nb_frozen_layers)
    else:
        print("Transfer learning using Inception architecture ...")
        setup_to_transfer_learn(model, base_model, optimizer)

    # START TRAINING ##########################################################
    print("Initializing training with  class labels: " + str(labels))

    # checking and printing current model summary prior to training
    if args.model_summary[0] == True:
        # summary() prints by itself; wrapping it in print() only added a
        # trailing "None" line.
        model.summary()
    else:
        print(
            "Successfully loaded deep neural network classifier for training ")

    # getting checkpoint file prepared
    os.makedirs(os.path.join(args.output_dir[0] + '/checkpoint/'),
                exist_ok=True)

    # setting up checkpoint, learning rate
    earlystopper = EarlyStopping(patience=6, verbose=1)
    checkpointer = ModelCheckpoint(checkpointer_savepath,
                                   verbose=1,
                                   save_best_only=True)
    # Accuracy is a quantity to maximise, so the plateau detector must run
    # in 'max' mode; 'min' would cut the learning rate whenever accuracy
    # failed to *decrease*.
    # NOTE(review): newer Keras releases call ``epsilon`` ``min_delta``.
    learning_rate_reduction = ReduceLROnPlateau(monitor='val_acc',
                                                patience=2,
                                                mode='max',
                                                epsilon=1e-4,
                                                cooldown=1,
                                                verbose=1,
                                                factor=0.5,
                                                min_lr=lr * 1e-2)

    # training command
    # NOTE(review): steps_per_epoch / validation_steps are fixed at 2000 and
    # do not depend on nb_train_samples, nb_val_samples or batch_size.
    # NOTE(review): class_weight='auto' is presumably not a supported value
    # in Keras 2.x (a dict is expected) -- confirm it has any effect.
    model_train = model.fit_generator(train_generator,
                                      epochs=nb_epoch,
                                      steps_per_epoch=2000,
                                      validation_data=validation_generator,
                                      validation_steps=2000,
                                      class_weight='auto',
                                      callbacks=[earlystopper,
                                                 learning_rate_reduction,
                                                 checkpointer])

    # saving model and training plots
    run_tag = "_ft_" if fine_tune_model == True else "_tl_"
    save_model(args, run_tag, model)
    generate_plot(args, run_tag, model_train)


###############################################################################
########################### DEFAULT PARAMETERS ################################

# Fixed input image size for Inception; train() passes this pair as the
# target_size of both directory generators.
IM_WIDTH = 299
IM_HEIGHT = 299

DEFAULT_EPOCHS = 100
DEFAULT_BATCHES = 20
FC_SIZE = 4096
DEFAULT_DROPOUT = 0.1
# Used by train() when args.frozen_layers cannot be read.
DEFAULT_NB_LAYERS_TO_FREEZE = 169

# Ready-made optimizer instances.
sgd = SGD(nesterov=True, momentum=1, decay=0.5, lr=1e-7)
rms = RMSprop(decay=0.0, epsilon=1e-08, rho=0.9, lr=1e-7)
ada = Adagrad(decay=0.0, epsilon=1e-08, lr=1e-3)

# Chosen by train() when the optimizer keyword is not recognised.
DEFAULT_OPTIMIZER = ada

# Stamp embedded in the training-plot file names.
timestr = generate_timestamp()
###############################################################################

In [0]:
import os
import types

# Working directories for this session.
model_dir = "./drive/EFIGI_Galaxy_Classification/model/"
output_dir = "./drive/EFIGI_Galaxy_Classification/output/"
checkpoint_dir = "./drive/EFIGI_Galaxy_Classification/output/checkpoint/"


# Stand-in for parsed command-line arguments: each option is a one-item
# list because train() reads every value as ``args.<name>[0]``.
args = types.SimpleNamespace()
args.config_file = [model_dir+'model_efc.json']
args.output_dir = [output_dir]
args.train_dir = ['./data/train/']
args.val_dir = ['./data/validation/']
args.epoch = [100]# at least 100 epochs
args.batch = [60]
args.train_model = [True]
args.load_weights = [False]
args.load_checkpoint = [True] # Set it to true for using a saved checkpoint
args.fine_tune = [True]
args.test_aug = [False]
args.train_aug = [False]
args.plot = [True]
args.model_summary = [False]
args.dropout = [0.95]
args.learning_rate = [1e-3]
args.decay = [0.0]
# train() only recognises 'rms'/'rmsprop', 'sgd' and 'ada'. The previous
# value 'adam' silently fell through to the default optimizer, which is
# Adagrad(lr=1e-3, decay=0.0) -- exactly what 'ada' builds from the
# learning_rate/decay above, so this is the same setup made explicit.
args.optimizer_val = ['ada'] # 'rms', 'sgd', 'ada'
args.frozen_layers = [140]
args.base_model = ['inceptionv4']
# Must match the on-disk name, which starts with a capital 'T'
# ('Transfer_learn_299_299_.h5'); paths are case-sensitive on Linux.
args.saved_chkpnt = [checkpoint_dir+'Transfer_learn_299_299_.h5']
args.activation = ['elu']


def ensure_dir(file_path):
    """
    Create the directory part of ``file_path`` if it does not exist yet.

    The directory is obtained with ``os.path.dirname``, so pass either a
    file path or a directory path that ends with a separator.

    :param file_path: path whose parent directory should exist
    :return: None
    """
    directory = os.path.dirname(file_path)
    # A bare file name has an empty dirname; os.makedirs('') would raise,
    # and the current directory needs no creation anyway.
    if directory:
        # exist_ok avoids failing if the directory appears between a
        # separate existence check and the creation call.
        os.makedirs(directory, exist_ok=True)
        
# Create the model, output and checkpoint folders up front.
for _required_dir in (model_dir, output_dir, checkpoint_dir):
    ensure_dir(_required_dir)


In [0]:
# Entry point of the notebook: run training only when it was requested.
train_model = args.train_model[0]

if train_model != True:
    # Nothing was requested: explain how to enable training, then stop.
    print("Nothing to do here ...")
    print("Try setting the --train_model flag to True ...")
    print("For more help, run with -h flag ...")
    sys.exit(1)

print("Training sesssion initiated ...")
train(args)

In [0]:
# Move the local checkpoint file to the mounted Google Drive folder under its
# EFIGI-specific name (mv removes the local copy).
! mv ./drive/EFIGI_Galaxy_Classification/output//checkpoint/Transfer_learn_299_299_.h5 /content/gdrive/My\ Drive/Transfer_learn_299_299_EFIGI.h5