In [1]:
# Environment setup. The hard-coded paths below assume the working directory is /kaggle/working.
# Clone the SSD_MobileNetV3_ADIS repository into the working directory.
# NOTE(review): re-running this cell in the same session will likely make the clone fail,
# because the target directory already exists - confirm if re-runs matter.
!git clone https://github.com/sathishkumar67/SSD_MobileNetV3_ADIS.git
# Move the repository contents (the non-hidden entries matched by `*`) up into /kaggle/working
# so that requirements.txt and the ssdlite_mobnetv3_adis package are found from the working directory.
!mv /kaggle/working/SSD_MobileNetV3_ADIS/* /kaggle/working/
# Upgrade pip itself before installing the project requirements.
!pip install --upgrade pip
# Install the packages listed in requirements.txt; `--upgrade --upgrade-strategy eager`
# also upgrades their dependencies, so versions are not pinned by this cell.
!pip install  -r requirements.txt --upgrade --upgrade-strategy eager

In [None]:
# necessary imports
import os
import optuna
import joblib
from tqdm import tqdm
import random
import numpy as np
from tqdm import tqdm
from huggingface_hub import hf_hub_download
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, RandomSampler
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
from ssdlite_mobnetv3_adis.utils import unzip_file, replace_activation_function
from ssdlite_mobnetv3_adis.dataset import collate_fn, SSDLITEOBJDET_DATASET, CachedSSDLITEOBJDET_DATASET
from ssdlite_mobnetv3_adis.model import SSDLITE_MOBILENET_V3_Large
from ssdlite_mobnetv3_adis.epu import EPU
from ssdlite_mobnetv3_adis.trainer import bohb_tunner, train


# One seed for every random number generator used in this notebook:
# Python's `random`, NumPy, PyTorch's default generator and all CUDA generators.
RANDOM_SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(RANDOM_SEED)
del _seed_rng  # do not leave the loop variable behind in the notebook namespace

# Turn off cuDNN benchmark mode and request deterministic cuDNN algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
# --- Hugging Face Hub coordinates of the dataset archive ---
REPO_ID = "pt-sk/ADIS"
REPO_TYPE = "dataset"
DATASET_NAME = "balanced_dataset"
FILENAME_IN_REPO = f"{DATASET_NAME}.zip"

# --- local paths, all rooted at the current working directory ---
LOCAL_DIR = os.getcwd()
# where the downloaded archive is expected to land
DATASET_PATH = f"{LOCAL_DIR}/{FILENAME_IN_REPO}"
# dataset root handed to the dataset classes (presumably the folder the archive extracts to)
DATASET_FOLDER_PATH = f"{LOCAL_DIR}/{DATASET_NAME}"

# --- label set ---
CLASSES = [
    'Cat', 'Cattle', 'Chicken', 'Deer', 'Dog',
    'Squirrel', 'Eagle', 'Goat', 'Rodents', 'Snake',
]
NUM_CLASSES = len(CLASSES)
# one extra slot for the background class
NUM_CLASSES_WITH_BG = NUM_CLASSES + 1

# fetch the archive into LOCAL_DIR, then unzip it there
hf_hub_download(
    repo_id=REPO_ID,
    filename=FILENAME_IN_REPO,
    repo_type=REPO_TYPE,
    local_dir=LOCAL_DIR,
)
unzip_file(DATASET_PATH, LOCAL_DIR)

In [None]:
# device name passed as `pin_memory_device` when CUDA is available
PIN_MEMORY_DEVICE = "cuda:0"
# os.cpu_count() may return None; DataLoader needs an integer worker count, and
# persistent_workers / prefetch_factor require at least one worker
NUM_CORES = os.cpu_count() or 1
BATCH_SIZE = 64
# pin host memory only when CUDA is present, so this cell also runs on the CPU
# fallback that the training device selection allows
_PIN_MEMORY = torch.cuda.is_available()


def _build_dataset(split):
    """Instantiate CachedSSDLITEOBJDET_DATASET over SSDLITEOBJDET_DATASET for one split.

    Args:
        split: split name forwarded to the dataset ("train", "val" or "test").
    """
    return CachedSSDLITEOBJDET_DATASET(
        dataset_class=SSDLITEOBJDET_DATASET,
        root_dir=DATASET_FOLDER_PATH,
        split=split,
        num_classes=NUM_CLASSES_WITH_BG)


def _build_sampler(dataset):
    """Return a RandomSampler over `dataset` with its own generator seeded with RANDOM_SEED."""
    return RandomSampler(dataset, generator=torch.Generator().manual_seed(RANDOM_SEED))


def _build_loader(dataset, sampler):
    """Wrap `dataset` in a DataLoader using the settings shared by all three splits."""
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        sampler=sampler,
        num_workers=NUM_CORES,
        collate_fn=collate_fn,
        pin_memory=_PIN_MEMORY,
        persistent_workers=True,
        prefetch_factor=2,
        # "" is DataLoader's default and is used when nothing is pinned
        pin_memory_device=PIN_MEMORY_DEVICE if _PIN_MEMORY else "")


# prepare the datasets
train_dataset = _build_dataset("train")
val_dataset = _build_dataset("val")
test_dataset = _build_dataset("test")

# samplers for reproducibility (each split gets an independent, identically seeded generator)
train_sampler = _build_sampler(train_dataset)
val_sampler = _build_sampler(val_dataset)
test_sampler = _build_sampler(test_dataset)

# prepare the dataloaders
train_loader = _build_loader(train_dataset, train_sampler)
val_loader = _build_loader(val_dataset, val_sampler)
test_loader = _build_loader(test_dataset, test_sampler)

In [None]:
# Settings shared by every tuning trial.
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
NUM_EPOCHS = 100
WARMUP_EPOCHS = 10
PATIENCE = 10
END_FACTOR = 1.0

# define the objective function
def objective(trial):
    """Run one Optuna trial: sample hyperparameters, build model and optimizer, tune.

    Args:
        trial: the Optuna trial used to sample hyperparameters and to report
            intermediate scores for pruning.

    Returns:
        The value returned by ``bohb_tunner``, which this notebook treats as
        the best validation loss of the trial.

    Raises:
        optuna.TrialPruned: raised by the epoch callback when the pruner
            decides the trial should stop.
    """
    # callback that reports intermediate results to Optuna
    # NOTE(review): presumably invoked by bohb_tunner once per epoch with
    # (score, epoch) - confirm in ssdlite_mobnetv3_adis.trainer
    def on_train_epoch_end(score, epoch):
        trial.report(score, step=epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

    # suggest hyperparameters for the model (log-uniform except MOMENTUM)
    INITIAL_LR = trial.suggest_float("INITIAL_LR", 1e-4, 1e-1, log=True)
    LR_FACTOR = trial.suggest_float("LR_FACTOR", 1e-4, 1e-1, log=True)
    START_FACTOR = trial.suggest_float("START_FACTOR", 1e-4, 1e-1, log=True)
    WEIGHT_DECAY = trial.suggest_float("WEIGHT_DECAY", 1e-4, 1e-1, log=True)
    MOMENTUM = trial.suggest_float("MOMENTUM", 0.7, 0.99)

    # create the model
    model = SSDLITE_MOBILENET_V3_Large(num_classes_with_bg=NUM_CLASSES_WITH_BG)
    # move the model to device
    model.to(DEVICE)

    # create the optimizer; MOMENTUM is used as AdamW's first beta coefficient
    optimizer = optim.AdamW(
        model.parameters(),
        lr=INITIAL_LR,
        betas=(MOMENTUM, 0.999),
        weight_decay=WEIGHT_DECAY,
        eps=1e-8,
        # DEVICE may fall back to CPU, and PyTorch releases without CPU support
        # for the fused implementation reject fused=True on non-CUDA
        # parameters, so request it only on CUDA
        fused=(DEVICE.type == "cuda")
    )

    # tune the model
    best_val_loss = bohb_tunner(
        args={
            "device": DEVICE,
            "warmup_epochs": WARMUP_EPOCHS,
            "num_epochs": NUM_EPOCHS,
            "patience": PATIENCE,
            "initial_lr": INITIAL_LR,
            "lr_factor": LR_FACTOR,
            "start_factor": START_FACTOR,
            "end_factor": END_FACTOR
        },
        model=model,
        optimizer=optimizer,
        dataloaders={"train":train_loader, "val":val_loader},
        callback=on_train_epoch_end
    )
    # return the best validation loss
    return best_val_loss

In [None]:
# define the number of trials
NUM_TRIALS = 5

# create the study: TPE sampler + Hyperband pruner, minimizing the objective.
# The sampler is seeded with RANDOM_SEED so the hyperparameter suggestions are
# reproducible like the rest of the notebook.
# No `storage` is passed, so the study is kept in memory and starts empty on
# every run; `load_if_exists` only matters once a persistent storage is added.
study = optuna.create_study(direction='minimize',
                            sampler=optuna.samplers.TPESampler(seed=RANDOM_SEED),
                            pruner=optuna.pruners.HyperbandPruner(),
                            study_name="ssd_mobnetv3_adis_epu_bohbtune",
                            load_if_exists=True)

# run NUM_TRIALS trials in total (pruned trials count towards n_trials)
study.optimize(objective, n_trials=NUM_TRIALS)

# save the study
joblib.dump(study, f"{LOCAL_DIR}/ssd_mobnetv3_adis_bohbtune_study1.pkl")