In [1]:
# Environment setup: fetch the project sources and install the dependencies.
# NOTE(review): this cell is not idempotent — `git clone` refuses to run once the
# target directory exists, so execute it only once per fresh session.
# clone the ADIS repository
!git clone https://github.com/sathishkumar67/SSD_MobileNetV3_ADIS.git
# move the repository contents into /kaggle/working so that the `ssd_mobnetv3_adis`
# package and requirements.txt sit next to the notebook
!mv /kaggle/working/SSD_MobileNetV3_ADIS/* /kaggle/working/
# upgrade pip
!pip install --upgrade pip
# install the required packages, eagerly upgrading their dependencies as well
!pip install  -r requirements.txt --upgrade --upgrade-strategy eager
# install torch/torchvision/torchaudio from the cu126 wheel index
# NOTE(review): the torch packages are not version-pinned here, so the environment
# depends on the install date — consider pinning for reproducibility.
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126

Cloning into 'SSD_MobileNetV3_ADIS'...
remote: Enumerating objects: 176, done.[K
remote: Counting objects: 100% (176/176), done.[K
remote: Compressing objects: 100% (130/130), done.[K
remote: Total 176 (delta 106), reused 108 (delta 46), pack-reused 0 (from 0)[K
Receiving objects: 100% (176/176), 77.63 KiB | 3.38 MiB/s, done.
Resolving deltas: 100% (106/106), done.
Collecting pip
  Downloading pip-25.0.1-py3-none-any.whl.metadata (3.7 kB)
Downloading pip-25.0.1-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m45.8 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-25.0.1
Collecting ultralytics (from -r requirements.txt (line 1))
  Downloading ultralytics-8.3.110-py3-none-any.whl.metadata (37 kB)
Collecting albumentations==2.0.5 

In [2]:
# necessary imports
import os
import optuna
import joblib
from typing import Tuple
from tqdm import tqdm
from ssd_mobnetv3_adis import unzip_file
from huggingface_hub import hf_hub_download
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, RandomSampler
from torchmetrics.detection import MeanAveragePrecision
from ssd_mobnetv3_adis import collate_fn, SSDLITEOBJDET_DATASET, CachedSSDLITEOBJDET_DATASET, SSD_MOBILENET_V3_Large

In [3]:
# set constants
REPO_ID = "pt-sk/ADIS"                                # Hugging Face repository that hosts the dataset
DATASET_NAME = "balanced_dataset"
REPO_TYPE = "dataset"
FILENAME_IN_REPO = f"{DATASET_NAME}.zip"
LOCAL_DIR = os.getcwd()
DATASET_PATH = f"{LOCAL_DIR}/{FILENAME_IN_REPO}"      # where the downloaded archive lands
DATASET_FOLDER_PATH = f"{LOCAL_DIR}/{DATASET_NAME}"   # where the archive is extracted to
CLASSES = ['Cat', 'Cattle', 'Chicken', 'Deer', 'Dog', 'Squirrel', 'Eagle', 'Goat', 'Rodents', 'Snake']
NUM_CLASSES = len(CLASSES)
MODEL_NUM_CLASSES = NUM_CLASSES + 1    # 1 for background class

# download the dataset and unzip it
# The archive is removed after extraction, so without this guard every re-run of the
# cell would download the multi-gigabyte zip again. The extracted folder is used as
# the "already done" marker.
# NOTE(review): a partially extracted folder would also satisfy this check — delete
# the folder to force a clean download.
if os.path.isdir(DATASET_FOLDER_PATH):
    print(f"Dataset already present at {DATASET_FOLDER_PATH}; skipping download")
else:
    hf_hub_download(repo_id=REPO_ID, filename=FILENAME_IN_REPO, repo_type=REPO_TYPE, local_dir=LOCAL_DIR)
    unzip_file(DATASET_PATH, LOCAL_DIR)

# number of cores (used later as the DataLoader worker count)
num_cores = os.cpu_count()
print(f"Number of CPU cores: {num_cores}")

balanced_dataset.zip:   0%|          | 0.00/7.04G [00:00<?, ?B/s]

Unzipping: 100%|██████████| 7.07G/7.07G [00:42<00:00, 166MB/s]


Unzipped /kaggle/working/balanced_dataset.zip to /kaggle/working
Removed zip file: /kaggle/working/balanced_dataset.zip
Number of CPU cores: 4


In [4]:
# device whose pinned-memory pool the DataLoaders should use
PIN_MEMORY_DEVICE = "cuda:0"

# build the cached train / val datasets; they differ only in the split name
train_dataset, val_dataset = (
    CachedSSDLITEOBJDET_DATASET(
        dataset_class=SSDLITEOBJDET_DATASET,
        root_dir=DATASET_FOLDER_PATH,
        split=split_name,
        num_classes=MODEL_NUM_CLASSES)
    for split_name in ("train", "val")
)

# fixed-seed samplers so the sample order is reproducible across runs
train_sampler = RandomSampler(
    train_dataset,
    generator=torch.Generator().manual_seed(42))
val_sampler = RandomSampler(
    val_dataset,
    generator=torch.Generator().manual_seed(42))

# options shared by every DataLoader in this notebook
_loader_options = dict(
    batch_size=128,
    num_workers=num_cores,
    collate_fn=collate_fn,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=2,
    pin_memory_device=PIN_MEMORY_DEVICE,
)

# prepare the dataloaders
train_loader = DataLoader(train_dataset, sampler=train_sampler, **_loader_options)
val_loader = DataLoader(val_dataset, sampler=val_sampler, **_loader_options)

# The test split is currently disabled; enable it by building it the same way:
# test_dataset = CachedSSDLITEOBJDET_DATASET(
#     dataset_class=SSDLITEOBJDET_DATASET,
#     root_dir=DATASET_FOLDER_PATH,
#     split="test",
#     num_classes=MODEL_NUM_CLASSES)
# test_sampler = RandomSampler(test_dataset, generator=torch.Generator().manual_seed(42))
# test_loader = DataLoader(test_dataset, sampler=test_sampler, **_loader_options)

Preprocessing dataset and caching to /kaggle/working/balanced_dataset/train_cache...


100%|██████████| 18139/18139 [03:35<00:00, 84.09it/s] 


Preprocessing dataset and caching to /kaggle/working/balanced_dataset/val_cache...


100%|██████████| 2390/2390 [00:26<00:00, 91.29it/s] 


In [12]:
def train(warmup_epochs: int, num_epochs: int, patience: int, initial_lr: float, betas: Tuple[float, float], weight_decay: float, dataloaders: dict[str, torch.utils.data.DataLoader]) -> None:
    """Train the SSD-MobileNetV3 detector with linear LR warmup and mAP-based early stopping.

    A fresh ``SSD_MOBILENET_V3_Large`` (with ``MODEL_NUM_CLASSES`` outputs) is built on
    ``cuda:0`` when available, otherwise on the CPU. After every epoch the validation
    mAP is computed with ``MeanAveragePrecision``; the weights and optimizer state of
    the best epoch are snapshotted and written to ``{LOCAL_DIR}/best_model.pth`` when
    training ends, whether through early stopping or by reaching ``num_epochs``.

    Args:
        warmup_epochs: Number of initial epochs over which the learning rate is ramped
            linearly up to ``initial_lr``. ``0`` disables warmup.
        num_epochs: Maximum number of epochs to run.
        patience: Number of consecutive epochs without a validation-mAP improvement
            after which training stops.
        initial_lr: Target learning rate reached at the end of warmup.
        betas: Beta coefficients forwarded to ``model.configure_optimizers``.
        weight_decay: Weight decay forwarded to ``model.configure_optimizers``.
        dataloaders: Mapping with the keys ``'train'`` and ``'val'``.

    Returns:
        None. The result of the run is the checkpoint file and the printed log.
    """
    import copy  # local import: only needed to snapshot the optimizer state

    checkpoint_path = f"{LOCAL_DIR}/best_model.pth"

    # early stopping bookkeeping
    best_map = float('-inf')
    patience_counter = 0
    best_model_state_dict = None
    best_optimizer_state_dict = None

    # get the dataloaders
    train_loader, val_loader = dataloaders['train'], dataloaders['val']

    # Set device
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load the model
    model = SSD_MOBILENET_V3_Large(num_classes_with_bg=MODEL_NUM_CLASSES)
    model.to(device)

    # Optimizer
    optimizer = model.configure_optimizers(lr=initial_lr, betas=betas, weight_decay=weight_decay, eps=1e-08, fused=True)

    for epoch in range(num_epochs):
        # Warmup phase: linearly increase the learning rate over the first `warmup_epochs` epochs
        if epoch < warmup_epochs:
            warmup_lr = initial_lr * (epoch + 1) / warmup_epochs
            for param_group in optimizer.param_groups:
                param_group['lr'] = warmup_lr
        # Read the rate back from the optimizer so the log line below is valid even
        # when warmup_epochs == 0 (no local warmup value exists in that case).
        current_lr = float(optimizer.param_groups[0]['lr'])

        # Training phase
        model.train()
        total_loss = 0.0
        num_batches = len(train_loader)

        # Progress bar
        train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")

        for images, targets in train_bar:
            # Move data to device
            images = images.to(device)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            # Forward pass: in training mode the model returns a dict of losses
            loss_dict = model(images, targets)
            losses = sum(loss_dict.values())

            # Backward pass and optimization
            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

            batch_loss = losses.detach().item()
            total_loss += batch_loss

            # Update progress bar
            train_bar.set_postfix(loss=batch_loss)

        # max(..., 1) avoids a ZeroDivisionError on an empty training loader
        avg_loss = total_loss / max(num_batches, 1)
        print(f"Epoch {epoch+1}/{num_epochs} | Learning Rate: {current_lr:.6f} | Avg Train Loss: {avg_loss:.4f}")

        # Validation phase
        model.eval()
        metric = MeanAveragePrecision()
        eval_bar = tqdm(val_loader, desc=f"Validating...", unit="batch")
        with torch.no_grad():
            for images, targets in eval_bar:
                # Move data to device
                images = images.to(device)
                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
                # Forward pass: in eval mode the model returns predictions
                predictions = model(images)
                metric.update(predictions, targets)

        map_result = metric.compute()
        current_map = float(map_result['map'])
        print(f"Epoch {epoch+1} | Val mAP: {current_map:.4f}")

        # Early stopping logic
        if current_map > best_map:
            best_map = current_map
            # Snapshot real copies: state_dict() hands out references to the live
            # tensors, and .cpu() alone is a no-op when the model already is on the CPU.
            best_model_state_dict = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            # The optimizer updates its state tensors in place, so without a deep copy
            # the "best" optimizer state would silently track the latest step instead.
            best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered at epoch", epoch + 1)
                break

    # Save the best checkpoint whether training stopped early or ran every epoch
    # (previously nothing was written unless early stopping triggered).
    if best_model_state_dict is not None:
        torch.save({"model_state_dict" : best_model_state_dict,
            "optimizer_state_dict" : best_optimizer_state_dict,
        }, checkpoint_path)
        print(f"Best model saved with mAP: {best_map:.4f}")

# launch training: 4 warmup epochs, at most 50 epochs, and stop once the
# validation mAP has not improved for 5 consecutive epochs
train(
    warmup_epochs=4,
    num_epochs=50,
    patience=5,
    initial_lr=0.0001,
    betas=(0.9, 0.999),
    weight_decay=0.001,
    dataloaders={'train': train_loader, 'val': val_loader},
)

Using device: cuda:0


Epoch 1/50: 100%|██████████| 142/142 [02:05<00:00,  1.13batch/s, loss=9.95]


Epoch 1/50 | Learning Rate: 0.000025 | Avg Train Loss: 12.0772


Validating...: 100%|██████████| 19/19 [00:17<00:00,  1.06batch/s]


Epoch 1 | Val mAP: 0.0015


Epoch 2/50: 100%|██████████| 142/142 [02:01<00:00,  1.17batch/s, loss=5.92]


Epoch 2/50 | Learning Rate: 0.000050 | Avg Train Loss: 7.4961


Validating...: 100%|██████████| 19/19 [00:16<00:00,  1.12batch/s]


Epoch 2 | Val mAP: 0.0567


Epoch 3/50: 100%|██████████| 142/142 [02:01<00:00,  1.17batch/s, loss=4.62]


Epoch 3/50 | Learning Rate: 0.000075 | Avg Train Loss: 4.9224


Validating...: 100%|██████████| 19/19 [00:17<00:00,  1.11batch/s]


Epoch 3 | Val mAP: 0.2651


Epoch 4/50: 100%|██████████| 142/142 [02:02<00:00,  1.16batch/s, loss=3.67]


Epoch 4/50 | Learning Rate: 0.000100 | Avg Train Loss: 3.9472


Validating...: 100%|██████████| 19/19 [00:16<00:00,  1.14batch/s]


Epoch 4 | Val mAP: 0.3453


Epoch 5/50: 100%|██████████| 142/142 [02:02<00:00,  1.16batch/s, loss=3.21]


Epoch 5/50 | Learning Rate: 0.000100 | Avg Train Loss: 3.3725


Validating...: 100%|██████████| 19/19 [00:16<00:00,  1.14batch/s]


Epoch 5 | Val mAP: 0.3909


Epoch 6/50: 100%|██████████| 142/142 [01:59<00:00,  1.19batch/s, loss=2.76]


Epoch 6/50 | Learning Rate: 0.000100 | Avg Train Loss: 2.9845


Validating...: 100%|██████████| 19/19 [00:17<00:00,  1.11batch/s]


Epoch 6 | Val mAP: 0.4180


Epoch 7/50: 100%|██████████| 142/142 [02:01<00:00,  1.17batch/s, loss=2.87]


Epoch 7/50 | Learning Rate: 0.000100 | Avg Train Loss: 2.6901


Validating...: 100%|██████████| 19/19 [00:17<00:00,  1.11batch/s]


Epoch 7 | Val mAP: 0.4350


Epoch 8/50: 100%|██████████| 142/142 [02:02<00:00,  1.16batch/s, loss=2.4] 


Epoch 8/50 | Learning Rate: 0.000100 | Avg Train Loss: 2.4486


Validating...: 100%|██████████| 19/19 [00:16<00:00,  1.16batch/s]


Epoch 8 | Val mAP: 0.4504


Epoch 9/50: 100%|██████████| 142/142 [02:01<00:00,  1.17batch/s, loss=2.16]


Epoch 9/50 | Learning Rate: 0.000100 | Avg Train Loss: 2.2475


Validating...: 100%|██████████| 19/19 [00:16<00:00,  1.14batch/s]


Epoch 9 | Val mAP: 0.4593


Epoch 10/50: 100%|██████████| 142/142 [02:03<00:00,  1.15batch/s, loss=2.03]


Epoch 10/50 | Learning Rate: 0.000100 | Avg Train Loss: 2.0836


Validating...: 100%|██████████| 19/19 [00:16<00:00,  1.13batch/s]


Epoch 10 | Val mAP: 0.4666


Epoch 11/50: 100%|██████████| 142/142 [02:02<00:00,  1.16batch/s, loss=1.76]


Epoch 11/50 | Learning Rate: 0.000100 | Avg Train Loss: 1.9324


Validating...: 100%|██████████| 19/19 [00:16<00:00,  1.16batch/s]


Epoch 11 | Val mAP: 0.4714


Epoch 12/50: 100%|██████████| 142/142 [02:01<00:00,  1.16batch/s, loss=1.67]


Epoch 12/50 | Learning Rate: 0.000100 | Avg Train Loss: 1.8107


Validating...: 100%|██████████| 19/19 [00:16<00:00,  1.14batch/s]


Epoch 12 | Val mAP: 0.4719


Epoch 13/50: 100%|██████████| 142/142 [02:03<00:00,  1.15batch/s, loss=1.82]


Epoch 13/50 | Learning Rate: 0.000100 | Avg Train Loss: 1.6941


Validating...: 100%|██████████| 19/19 [00:16<00:00,  1.16batch/s]


Epoch 13 | Val mAP: 0.4730


Epoch 14/50: 100%|██████████| 142/142 [02:03<00:00,  1.15batch/s, loss=1.61]


Epoch 14/50 | Learning Rate: 0.000100 | Avg Train Loss: 1.5961


Validating...: 100%|██████████| 19/19 [00:16<00:00,  1.15batch/s]


Epoch 14 | Val mAP: 0.4704


Epoch 15/50: 100%|██████████| 142/142 [02:02<00:00,  1.16batch/s, loss=1.22]


Epoch 15/50 | Learning Rate: 0.000100 | Avg Train Loss: 1.5048


Validating...: 100%|██████████| 19/19 [00:17<00:00,  1.11batch/s]


Epoch 15 | Val mAP: 0.4754


Epoch 16/50: 100%|██████████| 142/142 [02:01<00:00,  1.16batch/s, loss=1.41]


Epoch 16/50 | Learning Rate: 0.000100 | Avg Train Loss: 1.4111


Validating...: 100%|██████████| 19/19 [00:16<00:00,  1.14batch/s]


Epoch 16 | Val mAP: 0.4716


Epoch 17/50: 100%|██████████| 142/142 [02:02<00:00,  1.16batch/s, loss=1.17]


Epoch 17/50 | Learning Rate: 0.000100 | Avg Train Loss: 1.3351


Validating...: 100%|██████████| 19/19 [00:16<00:00,  1.15batch/s]


Epoch 17 | Val mAP: 0.4738


Epoch 18/50: 100%|██████████| 142/142 [02:02<00:00,  1.16batch/s, loss=1.37]


Epoch 18/50 | Learning Rate: 0.000100 | Avg Train Loss: 1.2689


Validating...: 100%|██████████| 19/19 [00:16<00:00,  1.12batch/s]


Epoch 18 | Val mAP: 0.4762


Epoch 19/50: 100%|██████████| 142/142 [02:03<00:00,  1.15batch/s, loss=1.1] 


Epoch 19/50 | Learning Rate: 0.000100 | Avg Train Loss: 1.2093


Validating...: 100%|██████████| 19/19 [00:15<00:00,  1.25batch/s]


Epoch 19 | Val mAP: 0.4706


Epoch 20/50: 100%|██████████| 142/142 [02:03<00:00,  1.15batch/s, loss=1.11] 


Epoch 20/50 | Learning Rate: 0.000100 | Avg Train Loss: 1.1479


Validating...: 100%|██████████| 19/19 [00:15<00:00,  1.24batch/s]


Epoch 20 | Val mAP: 0.4707


Epoch 21/50: 100%|██████████| 142/142 [02:02<00:00,  1.16batch/s, loss=1.22] 


Epoch 21/50 | Learning Rate: 0.000100 | Avg Train Loss: 1.0934


Validating...: 100%|██████████| 19/19 [00:15<00:00,  1.19batch/s]


Epoch 21 | Val mAP: 0.4703


Epoch 22/50: 100%|██████████| 142/142 [02:03<00:00,  1.15batch/s, loss=1.18] 


Epoch 22/50 | Learning Rate: 0.000100 | Avg Train Loss: 1.0462


Validating...: 100%|██████████| 19/19 [00:15<00:00,  1.23batch/s]


Epoch 22 | Val mAP: 0.4696


Epoch 23/50: 100%|██████████| 142/142 [02:02<00:00,  1.16batch/s, loss=0.912]


Epoch 23/50 | Learning Rate: 0.000100 | Avg Train Loss: 0.9947


Validating...: 100%|██████████| 19/19 [00:15<00:00,  1.19batch/s]


Epoch 23 | Val mAP: 0.4696
Early stopping triggered at epoch 23
Best model saved with mAP: 0.4762


In [28]:
# Rebuild the architecture and restore the weights from the best checkpoint.
model = SSD_MOBILENET_V3_Large(num_classes_with_bg=MODEL_NUM_CLASSES)
# map_location="cpu": the checkpoint also contains the optimizer state, which was
# saved from the GPU; loading onto the CPU avoids allocating GPU memory for tensors
# this cell never uses and keeps the load working on a machine without CUDA.
checkpoint = torch.load(f"{LOCAL_DIR}/best_model.pth", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
# a freshly built module starts in training mode; switch to inference behaviour
# (BatchNorm running statistics) before the model is used for evaluation
model.eval()
model.to("cuda:0")

SSD_MOBILENET_V3_Large(
  (model): SSD(
    (backbone): SSDLiteFeatureExtractorMobileNet(
      (features): Sequential(
        (0): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
            (1): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (2): Hardswish()
          )
          (1): InvertedResidual(
            (block): Sequential(
              (0): Conv2dNormActivation(
                (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
                (1): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
                (2): ReLU(inplace=True)
              )
              (1): Conv2dNormActivation(
                (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_runn

In [29]:
# Evaluate the restored model on the validation loader.
# NOTE(review): the recorded output of this cell reports mAP ~1.8e-05, while the
# training log reports a best validation mAP of 0.4762 on the same loader, and the
# execution counts are out of order (In [12] -> In [28]/[29]). Re-run from a fresh
# kernel and compare evaluate() with the validation loop in train() before trusting
# either number.
model.evaluate(
    {"val_loader": val_loader},
    device="cuda:0",
)

Starting evaluation...
Evaluating val_loader set


Evaluating val_loader set: 100%|██████████| 19/19 [00:16<00:00,  1.17batch/s]


Evaluation complete.


{'val_loader': {'map': tensor(1.7881e-05),
  'map_50': tensor(0.0001),
  'map_75': tensor(1.8398e-07),
  'map_small': tensor(0.),
  'map_medium': tensor(9.7113e-06),
  'map_large': tensor(1.9839e-05),
  'mar_1': tensor(0.),
  'mar_10': tensor(0.0005),
  'mar_100': tensor(0.0170),
  'mar_small': tensor(0.),
  'mar_medium': tensor(0.0051),
  'mar_large': tensor(0.0180),
  'ious': {(0, 1): [],
   (0, 2): [],
   (0, 3): [],
   (0, 4): [],
   (0, 5): [],
   (0, 6): [],
   (0, 7): [],
   (0, 8): [],
   (0, 9): [],
   (0, 10): [],
   (1, 1): [],
   (1, 2): [],
   (1, 3): [],
   (1, 4): [],
   (1, 5): [],
   (1, 6): [],
   (1, 7): [],
   (1, 8): [],
   (1, 9): [],
   (1,
    10): tensor([[0.2351],
           [0.1444],
           [0.1014],
           [0.0875],
           [0.1562],
           [0.0941],
           [0.0446],
           [0.0905],
           [0.0000],
           [0.0461],
           [0.2976],
           [0.0475],
           [0.2691],
           [0.0519],
           [0.1110],
       

In [13]:
# list the working directory (presumably to confirm that best_model.pth was written)
!ls

balanced_dataset  LICENSE    requirements.txt	   ssd_mobnetv3_adis
best_model.pth	  README.md  SSD_MobileNetV3_ADIS  testing.ipynb


In [None]:
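# # training function for hyperparameter tuning: same loop as train(), but it reports the per-epoch mAP through `callback` and returns the best mAP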
# def train(warmup_epochs: int, num_epochs: int, patience: int, initial_lr: float, betas: Tuple[float, float], weight_decay: float, dataloaders: dict[str, torch.utils.data.DataLoader], callback):
#     # early stopping parameters
#     best_map = float('-inf')
#     patience_counter = 0
    
#     # get the dataloaders
#     train_loader, val_loader = dataloaders['train'], dataloaders['val']
    
#     # Set device
#     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#     print(f"Using device: {device}")
    
#     # Load the model
#     model = SSD_MOBILENET_V3_Large(num_classes_with_bg=MODEL_NUM_CLASSES)
#     model.to(device)
    
#     # Optimizer
#     optimizer = model.configure_optimizers(lr=initial_lr, betas=betas, weight_decay=weight_decay, eps=1e-08, fused=True)
    
#     for epoch in range(num_epochs):
#         # Warmup phase: linearly increase learning rate for the first `warmup_epochs` epochs
#         if epoch < warmup_epochs:
#             lr = initial_lr * (epoch + 1) / warmup_epochs
#             for param_group in optimizer.param_groups:
#                 param_group['lr'] = lr
#         # Training phase
#         model.train()
#         total_loss = 0.0
#         num_batches = len(train_loader)
        
#         # Progress bar
#         train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
        
#         for _, (images, targets) in enumerate(train_bar):
#             # Move data to device
#             images = images.to(device)
#             targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
#             # Forward pass
#             loss_dict = model(images, targets)
#             losses = sum(loss for loss in loss_dict.values())
            
#             # Backward pass and optimization
#             optimizer.zero_grad()
#             losses.backward()
#             optimizer.step()
            
#             batch_loss = losses.detach().item()
#             total_loss += batch_loss
            
#             # Update progress bar
#             train_bar.set_postfix(loss=batch_loss)
        
#         avg_loss = total_loss / num_batches
#         print(f"Epoch {epoch+1}/{num_epochs} | Learning Rate: {lr:.6f} | Avg Train Loss: {avg_loss:.4f}")
        
#         # Validation phase
#         model.eval()
#         metric = MeanAveragePrecision()
#         eval_bar = tqdm(val_loader, desc=f"Validating...", unit="batch")
#         with torch.no_grad():
#             for images, targets in eval_bar:
#                 # Move data to device
#                 images = images.to(device)
#                 targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
#                 # Forward pass    
#                 predictions = model(images)
#                 metric.update(predictions, targets)
        
#         map_result = metric.compute()
#         print(f"Epoch {epoch+1} | Val mAP: {map_result['map']:.4f}")
        
#         # Report the validation mAP
#         callback(map_result['map'], epoch+1)
        
#         # Early stopping logic
#         if map_result['map'] > best_map:
#             best_map = map_result['map']
#             patience_counter = 0
#         else:
#             patience_counter += 1
#             if patience_counter >= patience:
#                 print("Early stopping triggered at epoch", epoch + 1)
#                 break
            
#     return best_map

In [None]:
# # constants
# WARMUP_EPOCHS = 3
# NUM_EPOCHS = 15
# PATIENCE = 3

# # define the dataloaders
# dataloaders = {"train":train_loader, "val":val_loader}

# # define the objective function
# def objective(trial):
#     # define callback to report intermediate results
#     def on_train_epoch_end(score, epoch):
#         trial.report(score, step=epoch)  
#         if trial.should_prune():
#             raise optuna.TrialPruned()
        
#     # suggest hyperparameters for the model
#     lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
#     weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-2, log=True)
#     momentum = trial.suggest_float("momentum", 0.7, 0.99)
    
#     # train the model
#     best_map = train(warmup_epochs=WARMUP_EPOCHS, num_epochs=NUM_EPOCHS, patience=PATIENCE, initial_lr=lr, betas=(momentum, 0.999), weight_decay=weight_decay,
#         dataloaders=dataloaders, callback=on_train_epoch_end)
    
#     # return the best mAP
#     return best_map

In [None]:
# # define the number of trials
# NUM_TRIALS = 5

# # create the study
# study = optuna.create_study(direction='maximize', 
#                             sampler=optuna.samplers.TPESampler(), 
#                             pruner=optuna.pruners.HyperbandPruner(),
#                             study_name="ssd_mobnetv3_adis_tuning",
#                             load_if_exists=True)

# # Run the optimization for NUM_TRIALS trials
# study.optimize(objective, n_trials=NUM_TRIALS)

In [None]:
# joblib.dump(study, f"{LOCAL_DIR}/optuna_study.pkl")