# Nature Conservancy Fish Classification - End-to-End Model

### Imports & environment

In [None]:
import os
from PIL import Image

from keras.callbacks import ModelCheckpoint
from keras.layers import GlobalAveragePooling2D, Activation, Input, Flatten
from keras.optimizers import Adam
from keras.models import Sequential

from utils import * 
from models import Vgg16BN, Inception, Resnet50
from glob import iglob

# Project root is whatever directory the notebook process is currently running in;
# every path configured below is derived from it.
ROOT_DIR = os.getcwd()
# Top-level data directory under the project root.
DATA_HOME_DIR = ROOT_DIR + '/data'
%matplotlib inline

### Config & Hyperparameters

In [None]:
# paths
data_path = DATA_HOME_DIR
# Training / validation image directories handed to `fit_val` in `train`.
# NOTE(review): "nof_excl" presumably means the split with the "NoF" class removed — confirm.
split_train_path = data_path + '/nof_excl/train/'
valid_path = data_path + '/nof_excl/valid/'
# Test images scored by both models in `generate_preds`.
test_path = DATA_HOME_DIR + '/test/'
# Prefix under which `train` writes checkpoints and later globs for '*.h5' files.
saved_model_path = ROOT_DIR + '/models/end_to_end/'
# Weights file loaded into the two-class fish detector in `generate_preds`.
fish_detector_path = ROOT_DIR + '/models/fish_detector_480x270/0.03-loss_2epoch_480x270_0.3-dropout_0.001-lr_vggbn.h5'
# Prefix for the CSV written by `write_submission`.
submission_path = ROOT_DIR + '/submissions/end_to_end/'

# data
batch_size = 8
im_size = (270, 480)  # ht, wt (only 299x299 for inception)
# Sample counts passed to `fit_val` (train / valid) and `test` (test).
# NOTE(review): presumably these must match the number of images on disk — confirm.
nb_split_train_samples = 2944
nb_valid_samples = 395
# Also fixes the row count of the prediction arrays allocated in `generate_preds`.
nb_test_samples = 1000
classes = ["ALB", "BET", "DOL", "LAG", "OTHER", "SHARK", "YFT"]  # excluding "NoF"
nb_classes = len(classes)

# model
# Number of independent train / predict runs; predictions are averaged across them.
nb_runs = 5
nb_epoch = 30
# Used three ways in this file: as a truthy flag when naming checkpoints ("aug" / "no-aug"),
# as the number of prediction passes in `generate_preds`, and as the `aug=` argument
# forwarded to `fit_val` / `test`.
nb_aug = 5
dropout = 0.4
# Predictions are clipped to [clip, 1 - clip] before being written to the submission file.
clip = 0.01
# Keys of `models` that are actually trained and used for prediction.
archs = ["vggbn"]

# One wrapper instance per architecture. All three are constructed eagerly here, including
# those not listed in `archs`, and the same instance is reused for every run in
# `train_all` and `generate_preds`.
models = {
    "vggbn": Vgg16BN(size=im_size, n_classes=nb_classes, lr=0.001,
                           batch_size=batch_size, dropout=dropout),
    # Inception is pinned to 299x299 rather than `im_size`.
    "inception": Inception(size=(299, 299), n_classes=nb_classes,
                           lr=0.001, batch_size=batch_size),
    "resnet": Resnet50(size=im_size, n_classes=nb_classes, lr=0.001,
                    batch_size=batch_size, dropout=dropout)
} 

### Build & Train Species Classifier

This classifier looks at fish classes only (it excludes the "NoF" class). When we make predictions on the test set later, we'll first use our fish detection model to estimate the probability of "NoF", and then weight the species predictions by the probability that the image contains a fish.

In [None]:
def train(parent_model, model_str):
    """Build and fit one model wrapper, returning the newest checkpoint path.

    Args:
        parent_model: model wrapper exposing `build()` and `fit_val(...)`.
        model_str: suffix appended to the checkpoint filename template.

    Returns:
        The path of the '*.h5' file under `saved_model_path` with the largest
        `os.path.getctime` value once fitting has finished.
    """
    parent_model.build()

    # Only improvements in validation loss are written, and only the weights.
    checkpoint = ModelCheckpoint(
        filepath=saved_model_path + '{val_loss:.2f}-loss_{epoch}epoch_' + model_str,
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=True,
    )

    parent_model.fit_val(
        split_train_path,
        valid_path,
        nb_trn_samples=nb_split_train_samples,
        nb_val_samples=nb_valid_samples,
        nb_epoch=nb_epoch,
        callbacks=[checkpoint],
        aug=nb_aug,
    )

    saved_weights = iglob(saved_model_path + '*.h5')
    return max(saved_weights, key=os.path.getctime)

In [None]:
def train_all():
    """Train every architecture in `archs` once per run.

    Returns:
        A dict keyed by architecture name ("vggbn", "inception", "resnet")
        whose values are lists of checkpoint paths, one per run, in run order.
        Architectures not listed in `archs` keep an empty list.
    """
    trained_paths = {name: [] for name in ("vggbn", "inception", 'resnet')}

    # The augmentation tag only depends on the global setting, so compute it once.
    aug_tag = "aug" if nb_aug else "no-aug"

    for run_idx in range(nb_runs):
        print("Starting Training Run {0} of {1}...\n".format(run_idx+1, nb_runs))

        for arch_name in archs:
            print("Training {} model...\n".format(arch_name))
            wrapper = models[arch_name]
            height, width = wrapper.size[0], wrapper.size[1]
            checkpoint_suffix = "{0}x{1}_{2}_{3}lr_run{4}_{5}.h5".format(
                height, width, aug_tag, wrapper.lr, run_idx, arch_name)
            best_path = train(wrapper, checkpoint_suffix)
            trained_paths[arch_name].append(best_path)

    print("Done.")
    return trained_paths
        
# Runs the full training loop when this cell executes; keeps the per-architecture
# checkpoint paths for the prediction step.
model_paths = train_all()

In [None]:
def generate_preds(model_paths):
    """Average species and "NoF" probabilities over runs, passes and architectures.

    For every run, prediction pass and architecture in `archs`, the species
    classifier is rebuilt and loaded with the weights at
    `model_paths[arch][run]`, the two-class fish detector is rebuilt and
    loaded from `fish_detector_path`, and both score the images under
    `test_path`. Column 1 of the detector output is inserted into the species
    probabilities as column 4, and the result is averaged over architectures,
    then passes, then runs.

    Args:
        model_paths: mapping of architecture name to a list of weight-file
            paths indexed by run, as returned by `train_all`.

    Returns:
        A `(predictions, filenames)` tuple: an array of shape
        `(nb_test_samples, nb_classes + 1)` and the filenames reported by the
        last species-classifier `test` call.

    Raises:
        ValueError: if `archs` is empty or `nb_runs` is less than 1 (either
            would otherwise end in a divide-by-zero and an unbound
            `filenames`).
    """
    if not archs:
        raise ValueError("archs must name at least one model architecture")
    if nb_runs < 1:
        raise ValueError("nb_runs must be at least 1")

    # Position of "NoF" among the output columns; matches the header order
    # written by `write_submission` (ALB, BET, DOL, LAG, NoF, ...).
    nof_column = 4

    # `nb_aug == 0` is a supported "no-aug" setting at training time. It must
    # still produce one prediction pass here; previously it produced zero
    # passes, a division by zero and an unbound `filenames`.
    nb_passes = max(nb_aug, 1)

    out_shape = (nb_test_samples, nb_classes+1)
    predictions_full = np.zeros(out_shape)

    for run in range(nb_runs):
        print("\nStarting Prediction Run {0} of {1}...\n".format(run+1, nb_runs))
        predictions_aug = np.zeros(out_shape)

        for aug in range(nb_passes):
            print("\n--Predicting on Augmentation {0} of {1}...\n".format(aug+1, nb_passes))
            predictions_mod = np.zeros(out_shape)

            for arch in archs:
                print("----Predicting on {} model...".format(arch))

                parent = models[arch]
                model = parent.build()
                model.load_weights(model_paths[arch][run])

                # NOTE(review): the detector is rebuilt and its weights reloaded on
                # every pass even though its configuration never changes; it could
                # probably be hoisted out of the loops — confirm `build()` has no
                # cross-model side effects first.
                fish_detector = Vgg16BN(size=im_size, n_classes=2, lr=0.001,
                                       batch_size=batch_size, dropout=dropout)
                fish_detector.build()
                fish_detector.model.load_weights(fish_detector_path)

                nofish_prob, _ = fish_detector.test(test_path, nb_test_samples, aug=nb_aug)
                # Presumably column 1 of the detector output is the "NoF" class — TODO confirm.
                nofish_prob = nofish_prob[:, 1]

                species_prob, filenames = parent.test(test_path, nb_test_samples, aug=nb_aug)

                pred = np.insert(species_prob, nof_column, nofish_prob, axis=1)
                predictions_mod += pred

            predictions_mod /= len(archs)
            predictions_aug += predictions_mod

        predictions_aug /= nb_passes
        predictions_full += predictions_aug

    predictions_full /= nb_runs
    return predictions_full, filenames
    
# Runs prediction over the whole test set when this cell executes.
predictions, filenames = generate_preds(model_paths)

In [None]:
def weight_predictions(predictions):
    """Scale species probabilities by the detector's probability that a fish is present.

    Column 4 of `predictions` is read as the "NoF" probability. Every other
    column is multiplied row-wise by `1 - NoF`, and the "NoF" column is put
    back unchanged at index 4. The input array is not modified.

    Args:
        predictions: 2-D array with one row per image and the "NoF"
            probability in column 4.

    Returns:
        A new array of the same shape with the re-weighted species columns.
    """
    nof_column = 4
    no_fish = predictions[:, nof_column]
    fish = np.delete(predictions, nof_column, axis=1)

    # Probability that a fish is present, as a column vector so it broadcasts
    # across the species columns. The row count is inferred from the input
    # rather than hard-coded to 1000, so any number of images works.
    weights = (1. - no_fish).reshape(-1, 1)

    fish = weights * fish
    return np.insert(fish, nof_column, no_fish, axis=1)

# Rebinds `predictions` to the re-weighted array; the raw averages are no longer kept.
predictions = weight_predictions(predictions)

### Write Predictions to File

In [None]:
def write_submission(predictions, filenames):
    """Clip the predictions and write them to a submission CSV.

    The output name is built from `submission_path`, the epoch / augmentation /
    clip / run settings and each entry of `archs`, with a '.csv' extension.
    Each row holds the image's base filename followed by its probabilities
    formatted to six decimal places.

    Args:
        predictions: 2-D array with one row per entry in `filenames`.
        filenames: image paths, in the same order as the prediction rows.
    """
    clipped = np.clip(predictions, clip, 1-clip)

    arch_suffix = "".join("_{}".format(arch) for arch in archs)
    csv_stem = submission_path + '{0}epoch_{1}aug_{2}clip_{3}runs'.format(nb_epoch, nb_aug, clip, nb_runs)
    csv_stem = csv_stem + arch_suffix

    with open(csv_stem + '.csv', 'w') as out_file:
        print("Writing Predictions to CSV...")
        out_file.write('image,ALB,BET,DOL,LAG,NoF,OTHER,SHARK,YFT\n')
        for row_idx, image_path in enumerate(filenames):
            image_id = os.path.basename(image_path)
            fields = ','.join('%.6f' % prob for prob in clipped[row_idx, :])
            out_file.write('%s,%s\n' % (image_id, fields))
        print("Done.")

# Writes the submission CSV under `submission_path` when this cell executes.
write_submission(predictions, filenames)