# Imports and Definitions

In [None]:
pip install -q wandb

In [None]:
import numpy as np
import json
import glob

import tensorflow as tf
ks = tf.keras

import wandb
from wandb.keras import WandbCallback as WandbCallback

In [None]:
# Mount Google Drive so the dataset paths under /content/drive resolve.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

# Reading iNaturalist Data

In [None]:
# Dataset locations on the mounted Google Drive.
path_train = "/content/drive/MyDrive/Sem 8/DL/inaturalist_12K/train"
path_test  = "/content/drive/MyDrive/Sem 8/DL/inaturalist_12K/val"

batch_size = 32
img_size = (800, 800)          # size every image is resized to on load
img_shape = img_size + (3,)    # model input shape (3 colour channels)

# Loader options shared by every split: one-hot labels, common size and batch.
_common_ds_kwargs = dict(label_mode='categorical',
                         image_size=img_size,
                         batch_size=batch_size)

# Carve a 10% validation subset out of the training directory. Both subsets
# use the same seed so they split the same shuffled file list.
_split_ds_kwargs = dict(validation_split=0.1, seed=123, **_common_ds_kwargs)

train_ds = ks.preprocessing.image_dataset_from_directory(
    path_train, subset="training", **_split_ds_kwargs)
val_ds = ks.preprocessing.image_dataset_from_directory(
    path_train, subset="validation", **_split_ds_kwargs)

# The iNaturalist "val" directory is used as the held-out test set.
test_ds = ks.preprocessing.image_dataset_from_directory(
    path_test, **_common_ds_kwargs)

# Read the class list from the object returned by image_dataset_from_directory,
# before the names below are rebound to prefetching wrappers.
class_names = train_ds.class_names
num_classes = len(class_names)

# Let tf.data tune the prefetch buffer so input loading overlaps training.
AUTOTUNE = tf.data.AUTOTUNE
train_ds, val_ds, test_ds = [split.prefetch(buffer_size=AUTOTUNE)
                             for split in (train_ds, val_ds, test_ds)]

# Functions

In [None]:
def _compile_classifier(model, learning_rate):
    """Compile *model* with Adam, categorical cross-entropy and accuracy.

    Categorical cross-entropy matches the one-hot labels produced by
    ``label_mode='categorical'`` in the dataset loaders.
    """
    model.compile(optimizer=ks.optimizers.Adam(learning_rate=learning_rate),
                  loss=ks.losses.categorical_crossentropy,
                  metrics=['accuracy'])


def finetune_model(hyp):
    """Fine-tune an ImageNet-pretrained ``ks.applications`` model on the dataset.

    Builds ``data_aug -> preprocess_input -> frozen CNN -> Dropout -> Dense``.
    It then trains the new head with the CNN frozen. If ``num_unfrozen > 0``,
    it unfreezes the last ``num_unfrozen`` CNN layers and continues training
    at a 10x lower learning rate. Each ``fit`` logs to Weights & Biases through
    a fresh ``WandbCallback``.

    Reads the module-level globals ``train_ds``, ``val_ds``, ``img_shape`` and
    ``num_classes``.

    Args:
        hyp: dict with keys
            ``"model"``: indexable pair ``(ConstructorName, module_name)``.
                Element 0 names a constructor in ``ks.applications``
                (e.g. ``"Xception"``). Element 1 names the submodule that
                provides its ``preprocess_input`` (e.g. ``"xception"``).
            ``"eta"``: Adam learning rate for the head-only phase.
            ``"epochs"``: total epochs across both phases.
            ``"dropout"``: dropout rate in front of the new Dense layer.
            ``"include_top"``: truthy to keep the application's own classifier
                (inputs resized to its square input width); falsy to use the
                average-pooled feature extractor at ``img_shape``.
            ``"num_unfrozen"``: number of trailing CNN layers to fine-tune;
                0 skips the fine-tuning phase.

    Raises:
        AttributeError: if a name in ``hyp["model"]`` is not an attribute of
            ``ks.applications``.
    """
    # 1. Resolve the backbone constructor and its preprocessing module by
    #    attribute lookup. The previous exec() + locals() trick breaks on
    #    Python 3.13+ (PEP 667: locals() is a per-call snapshot there) and
    #    executed arbitrary text from the sweep config.
    model_to_use = hyp["model"]
    ks_app_CNN = getattr(ks.applications, model_to_use[0])
    ks_app_cnn = getattr(ks.applications, model_to_use[1])

    # 2. Scalar hyperparameters.
    eta = hyp["eta"]
    epochs = hyp["epochs"]
    dropout = hyp["dropout"]
    num_unfrozen = hyp["num_unfrozen"]

    # 3. Backbone and augmentation pipeline.
    if hyp["include_top"]:
        # Keep the application's built-in ImageNet classifier.
        # NOTE(review): the new Dense head below is then stacked on that
        # classifier's output rather than on pooled features - confirm intended.
        cnn_model = ks_app_CNN(include_top=True,
                               weights='imagenet')
        # Assumes the input layer shape is (None, H, W, 3) with H == W.
        img_width = cnn_model.layers[0].input_shape[0][1]

        # Augment, then resize to the square input size the classifier expects.
        data_aug = ks.Sequential([
            ks.layers.experimental.preprocessing.RandomFlip('horizontal'),
            ks.layers.experimental.preprocessing.RandomRotation(0.2),
            ks.layers.experimental.preprocessing.Resizing(img_width, img_width, interpolation='bilinear')
        ])
    else:
        data_aug = ks.Sequential([
            ks.layers.experimental.preprocessing.RandomFlip('horizontal'),
            ks.layers.experimental.preprocessing.RandomRotation(0.2)
        ])

        # Headless feature extractor at the dataset's img_shape, average-pooled.
        cnn_model = ks_app_CNN(input_shape=img_shape,
                               include_top=False,
                               weights='imagenet',
                               pooling='avg')

    # 4. Assemble the full model with every backbone layer frozen.
    cnn_model.trainable = False

    inputs = ks.Input(shape=img_shape)
    outputs = data_aug(inputs)
    outputs = ks_app_cnn.preprocess_input(outputs)
    # training=False runs the backbone in inference mode (e.g. BatchNorm
    # statistics are not updated), also during the fine-tuning phase.
    outputs = cnn_model(outputs, training=False)
    outputs = ks.layers.Dropout(dropout)(outputs)
    outputs = ks.layers.Dense(num_classes, activation='softmax')(outputs)

    model = ks.Model(inputs, outputs)
    _compile_classifier(model, eta)

    # Phase 1: train only the new head. Use every epoch when nothing will be
    # unfrozen, otherwise only the first half.
    model.fit(train_ds,
              epochs=epochs if num_unfrozen == 0 else epochs // 2,
              validation_data=val_ds,
              callbacks=[WandbCallback()])
    if num_unfrozen == 0:
        return

    # 5. Phase 2: make only the last `num_unfrozen` backbone layers trainable.
    cnn_model.trainable = True
    for layer in cnn_model.layers[:-num_unfrozen]:
        layer.trainable = False

    # Recompile so the changed trainable flags take effect, with a smaller
    # learning rate as is usual when updating pretrained weights.
    _compile_classifier(model, eta / 10)

    # Resume the epoch counter where phase 1 stopped.
    model.fit(train_ds,
              epochs=epochs,
              validation_data=val_ds,
              initial_epoch=epochs // 2,
              callbacks=[WandbCallback()])
        


def runSweep():
    """Sweep-agent entry point: start a W&B run and fine-tune with its config.

    Copies the swept hyperparameters from ``wandb.config`` into a plain dict
    and hands it to ``finetune_model``.
    """
    wandb.init()

    # Pull each swept hyperparameter off the run config, in a fixed order.
    hyperparam_names = ("eta", "epochs", "dropout",
                        "model", "include_top", "num_unfrozen")
    hyp = {name: getattr(wandb.config, name) for name in hyperparam_names}

    finetune_model(hyp)

# Sweep

In [None]:
# Grid search over: reuse of the application's own classifier, how many
# backbone layers to fine-tune, and which pretrained backbone to use.
# Each "model" value is (ks.applications constructor, preprocessing module).
sweepCfg = {
    "name": "Fine Tuning Pretrained Sweep 7",
    "method": "grid",
    "parameters": {
        "include_top": {
            "values": [0, 1]
        },
        "num_unfrozen": {
            "values": [0, 5]
        },
        "model": {
            "values": [('Xception', 'xception'),
                       ('InceptionResNetV2', 'inception_resnet_v2'),
                       ('ResNet50', 'resnet50'),
                       ('InceptionV3', 'inception_v3'),
                       ]
        },
        "eta": {
            "values": [1e-3]
        },
        "epochs": {
            "values": [10]
        },
        "dropout": {
            "values": [0.2]
        },
    },
}

# Run the agent on the sweep just created, so the configuration above is what
# actually executes. Previously the agent used a hard-coded, unrelated sweep ID
# and this sweep was never run. To resume an existing sweep instead, pass its
# "entity/project/sweep_id" string here rather than calling wandb.sweep().
sweepId = wandb.sweep(sweepCfg)
wandb.agent(sweepId, function=runSweep)