In [None]:
# pip install optuna

In [None]:
# pip install datasets

In [None]:
import optuna
import torch
from torch import nn
from torch.utils.data import DataLoader, IterableDataset
from transformers import AutoImageProcessor, AutoModelForImageClassification
from sklearn.metrics import average_precision_score, precision_recall_fscore_support
from tqdm import tqdm
from torchvision import transforms
import numpy as np
from PIL import Image

In [None]:
# Custom Model with Dropout
class CustomDINOv2WithDropout(nn.Module):
    """Backbone + dropout + linear head producing one logit per label.

    Args:
        base_model: Backbone module (a DINOv2 model in this notebook). It must
            expose ``config.hidden_size`` and, when called with
            ``pixel_values=...``, return an object with a ``pooler_output``
            attribute.
        num_labels: Number of logits produced by the linear head.
        dropout_rate: Dropout probability applied to the pooled features.
    """

    def __init__(self, base_model, num_labels, dropout_rate=0.3):
        super().__init__()
        self.base_model = base_model
        self.dropout = nn.Dropout(p=dropout_rate)
        # The head's input width is taken from the backbone's own config.
        hidden_dim = self.base_model.config.hidden_size
        self.classifier = nn.Linear(hidden_dim, num_labels)

    def forward(self, pixel_values):
        """Return the raw (pre-sigmoid) logits for ``pixel_values``."""
        backbone_out = self.base_model(pixel_values=pixel_values, output_hidden_states=False)
        features = self.dropout(backbone_out.pooler_output)
        return self.classifier(features)


# Multi-label Metrics
def multi_label_metrics(logits, y_true, labels, threshold=0.5):
    """Summarise multi-label predictions with weighted and per-label scores.

    Args:
        logits: Tensor of raw scores; a sigmoid turns them into probabilities.
        y_true: Ground-truth indicator array, passed straight to scikit-learn.
        labels: Label names; ``labels[i]`` names the ``i``-th per-label score.
        threshold: A label is predicted positive when its probability is
            strictly greater than this value.

    Returns:
        Dict with the weighted ``mean_ap``, ``mean_precision``, ``mean_recall``
        and ``mean_f1``, plus ``label_aps`` and ``label_f1s`` mapping each label
        name to its average precision and F1 respectively.
    """
    scores = torch.sigmoid(logits).cpu().numpy()
    predictions = scores > threshold

    def _prf(average):
        # Support counts (fourth return value) are not needed here.
        precision, recall, f1, _support = precision_recall_fscore_support(
            y_true=y_true, y_pred=predictions, average=average, zero_division=np.nan
        )
        return precision, recall, f1

    # Aggregate ("weighted") scores first ...
    weighted_precision, weighted_recall, weighted_f1 = _prf("weighted")
    weighted_ap = average_precision_score(y_true, scores, average="weighted")

    # ... then the per-label breakdown (per-label precision/recall are not reported).
    _, _, per_label_f1 = _prf(None)
    per_label_ap = average_precision_score(y_true, scores, average=None)

    summary = {}
    summary["mean_ap"] = weighted_ap
    summary["mean_precision"] = weighted_precision
    summary["mean_recall"] = weighted_recall
    summary["mean_f1"] = weighted_f1
    summary["label_aps"] = {name: per_label_ap[idx] for idx, name in enumerate(labels)}
    summary["label_f1s"] = {name: per_label_f1[idx] for idx, name in enumerate(labels)}
    return summary


# Dataset class
class StreamDataset(IterableDataset):
    """Iterable dataset yielding ``(transformed_image, label_tensor)`` pairs.

    Args:
        dataset: Mapping from split name to an iterable of records.
        split_name: Key of the split to iterate over.
        label_keys: Record keys whose values form the label vector, in order.
        image_transforms: Callable applied to each image.
    """

    def __init__(self, dataset, split_name, label_keys, image_transforms):
        self.dataset, self.split_name = dataset, split_name
        self.label_keys, self.image_transforms = label_keys, image_transforms

    def process_item(self, item):
        """Convert one record into ``(image, float32 label tensor)``."""
        raw_image = item["image"]
        label_values = [int(item[name]) for name in self.label_keys]
        # NumPy arrays are wrapped as PIL images before the transforms run;
        # anything else is handed to the transforms as-is.
        prepared = Image.fromarray(raw_image) if isinstance(raw_image, np.ndarray) else raw_image
        transformed = self.image_transforms(prepared)
        return transformed, torch.tensor(label_values, dtype=torch.float32)

    def __iter__(self):
        """Lazily walk the configured split, processing one record at a time."""
        yield from (self.process_item(record) for record in self.dataset[self.split_name])


# Dataset processing
def process_dataset(
    model, dataset, split_name, label_keys, image_transforms, optimizer=None, train=False, batch_size=8
):
    """Run one pass over a dataset split, optionally updating the model.

    Relies on two notebook-level globals: ``device`` (where batches are moved)
    and, only when ``train`` is true, ``scaler`` (the AMP gradient scaler).

    Args:
        model: Module called as ``model(batch_images)`` and returning logits.
        dataset: Mapping from split name to an iterable of records.
        split_name: Split to iterate over.
        label_keys: Record keys forming the label vector.
        image_transforms: Callable applied to every image.
        optimizer: Optimizer stepped after every batch; required when ``train``.
        train: If true, run in training mode and update weights; otherwise
            evaluate with gradients disabled.
        batch_size: Number of records per batch.

    Returns:
        Tuple ``(mean_loss, logits, labels)``: the loss averaged over batches,
        and the logits / labels of all batches concatenated on the CPU.

    Raises:
        ValueError: If ``train`` is true without an ``optimizer``, or if the
            split yields no batches (nothing to average or concatenate).
    """
    # Fail before touching the model or the data rather than at the first batch.
    if train and optimizer is None:
        raise ValueError("process_dataset(train=True) requires an optimizer")

    if train:
        model.train()
    else:
        model.eval()

    criterion = torch.nn.BCEWithLogitsLoss()  # built once, reused for every batch
    running_loss = 0.0
    all_logits, all_labels = [], []

    processed_dataset = StreamDataset(dataset, split_name, label_keys, image_transforms)
    # The collate function only transposes the batch; tensors are stacked below.
    loader = DataLoader(processed_dataset, batch_size=batch_size, collate_fn=lambda x: tuple(zip(*x)))

    batch_count = 0
    for batch in tqdm(loader, desc="Training" if train else "Validation"):
        batch_count += 1
        batch_images, batch_labels = map(torch.stack, batch)
        batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)

        if train:
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                logits = model(batch_images)
                loss = criterion(logits, batch_labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            with torch.no_grad():
                logits = model(batch_images)
                loss = criterion(logits, batch_labels)
        running_loss += loss.item()

        all_logits.append(logits.detach().cpu())
        all_labels.append(batch_labels.detach().cpu())

    if batch_count == 0:
        raise ValueError(f"Split {split_name!r} produced no batches; nothing to evaluate")

    all_logits = torch.cat(all_logits, dim=0)
    all_labels = torch.cat(all_labels, dim=0)

    return running_loss / batch_count, all_logits, all_labels


# Objective Function for Optuna
def _without_random_augmentations(pipeline):
    """Return ``pipeline`` minus its random flip / colour-jitter steps.

    Objects that do not expose a ``transforms`` list are returned unchanged.
    """
    steps = getattr(pipeline, "transforms", None)
    if steps is None:
        return pipeline
    random_ops = (transforms.RandomHorizontalFlip, transforms.ColorJitter)
    return transforms.Compose([step for step in steps if not isinstance(step, random_ops)])


def objective(trial):
    """Optuna objective: fine-tune for a few epochs and score on validation.

    Samples ``lr``, ``dropout_rate`` and ``batch_size`` from ``trial``, trains
    with gradient accumulation (targeting an effective batch size of 64),
    reports the validation ``mean_ap`` after each epoch, and saves the best
    weights of the trial to ``best_model_trial_<n>.pth``.

    Reads the notebook globals ``base_model``, ``label_keys``, ``ds``,
    ``image_transforms`` and ``device``; rebinds the global ``model``.

    Args:
        trial: The ``optuna`` trial supplying hyperparameters.

    Returns:
        The best validation ``mean_ap`` seen during the trial.

    Raises:
        optuna.exceptions.TrialPruned: When the study's pruner stops the trial.
    """
    import copy  # local import keeps this block self-contained

    global model  # the trial's model stays reachable from the notebook, as before

    # Drop the previous trial's model first so two models never share the GPU.
    model = None

    # Hyperparameters to tune
    lr = trial.suggest_float("lr", 1e-6, 1e-3, log=True)
    dropout_rate = trial.suggest_float("dropout_rate", 0.1, 0.5)
    batch_size = trial.suggest_int("batch_size", 16, 64, step=16)

    # Every trial gets its own copy of the pretrained backbone. Wrapping the
    # shared `base_model.base_model` object directly would let each trial
    # continue from the weights the previous trial fine-tuned.
    backbone = copy.deepcopy(base_model.base_model)
    model = CustomDINOv2WithDropout(backbone, num_labels=len(label_keys), dropout_rate=dropout_rate).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scaler = torch.cuda.amp.GradScaler()
    criterion = nn.BCEWithLogitsLoss()

    # Validation must be deterministic: score on un-augmented images.
    eval_transforms = _without_random_augmentations(image_transforms)

    # Path to save the best model for this trial
    trial_model_path = f"best_model_trial_{trial.number}.pth"

    # Training Loop
    num_epochs = 5  # Shorter for tuning
    patience, no_improvement = 3, 0
    best_val_map = 0.0
    accumulation_steps = max(1, 64 // batch_size)  # Simulate effective batch size of 64

    def _apply_accumulated_gradients():
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    for epoch in range(num_epochs):
        print(f"Trial {trial.number}, Epoch {epoch + 1}/{num_epochs}")

        # Training
        model.train()
        train_dataset = StreamDataset(ds, "train", label_keys, image_transforms)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=lambda x: tuple(zip(*x)))

        train_loss = 0.0
        num_batches = 0
        pending = 0  # micro-batches accumulated since the last optimizer step
        optimizer.zero_grad()
        for batch_images, batch_labels in tqdm(train_loader, desc="Training"):
            batch_images, batch_labels = torch.stack(batch_images).to(device), torch.stack(batch_labels).to(device)

            with torch.cuda.amp.autocast():
                logits = model(batch_images)
                loss = criterion(logits, batch_labels)

            # Normalize for gradient accumulation; the un-normalized loss is logged.
            scaler.scale(loss / accumulation_steps).backward()
            train_loss += loss.item()
            num_batches += 1
            pending += 1

            if pending == accumulation_steps:
                _apply_accumulated_gradients()
                pending = 0

        # The stream has no known length, so the trailing partial group can only
        # be detected after the loop; step on it instead of discarding it.
        if pending:
            _apply_accumulated_gradients()

        if num_batches:
            print(f"Trial {trial.number}, Epoch {epoch + 1}: train_loss = {train_loss / num_batches:.4f}")

        # Validation
        val_loss, val_logits, val_labels = process_dataset(
            model, ds, "validation", label_keys, eval_transforms, batch_size=batch_size
        )
        val_metrics = multi_label_metrics(val_logits, val_labels.numpy(), label_keys)
        print(f"Validation Metrics (Trial {trial.number}, Epoch {epoch + 1}): {val_metrics}")

        # Log mean_ap for the trial
        trial.report(val_metrics["mean_ap"], epoch)
        print(f"Trial {trial.number}, Epoch {epoch + 1}: mean_ap = {val_metrics['mean_ap']:.4f}")

        # Save the best model of the trial
        if val_metrics["mean_ap"] > best_val_map:
            best_val_map = val_metrics["mean_ap"]
            no_improvement = 0
            torch.save(model.state_dict(), trial_model_path)
            print(f"Best model saved for trial {trial.number} with mAP: {best_val_map:.4f}")
        else:
            no_improvement += 1

        # Prune trial if mean_ap is not improving
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

        # Early stopping
        if no_improvement >= patience:
            print(f"Trial {trial.number}: Early stopping triggered.")
            break

        torch.cuda.empty_cache()  # Clear memory to prevent fragmentation

    return best_val_map

# Load dataset
from datasets import load_dataset

# Seed the Python, NumPy and PyTorch RNGs (augmentations, dropout, head init)
# to reduce run-to-run variance of the study.
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

ds = load_dataset("MITLL/LADI-v2-dataset", streaming=True)

# Define label keys
label_keys = [
    'bridges_any', 'buildings_any', 'buildings_affected_or_greater',
    'buildings_minor_or_greater', 'debris_any', 'flooding_any',
    'flooding_structures', 'roads_any', 'roads_damage', 'trees_any',
    'trees_damage', 'water_any'
]

# Load DINOv2 model
model_name = "facebook/dinov2-base"
# NOTE: `processor` is loaded but not used by the code visible in this notebook.
processor = AutoImageProcessor.from_pretrained(model_name)
base_model = AutoModelForImageClassification.from_pretrained(
    model_name, ignore_mismatched_sizes=True
)

# Define transformations (includes random augmentation: flip and colour jitter)
image_transforms = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Device and scaler (both are read as globals by `process_dataset`)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
scaler = torch.cuda.amp.GradScaler()

# Run Optuna Study
N_TRIALS = 2  # number of hyperparameter configurations to try
study = optuna.create_study(
    direction="maximize",
    sampler=optuna.samplers.TPESampler(seed=SEED),  # seeded so suggestions are repeatable
)
study.optimize(objective, n_trials=N_TRIALS)

# Print best hyperparameters
print("Best hyperparameters:", study.best_params)


Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Some weights of Dinov2ForImageClassification were not initialized from the model checkpoint at facebook/dinov2-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  scaler = torch.cuda.amp.GradScaler()
[I 2024-12-05 16:41:12,475] A new study created in memory with name: no-name-12853c83-60cb-4795-98bb-f0774facc515
  scaler = torch.cuda.amp.GradScaler()


Trial 0, Epoch 1/5


  with torch.cuda.amp.autocast():
Training: 126it [25:05, 11.95s/it]
Validation: 14it [03:04, 13.20s/it]


Validation Metrics (Trial 0, Epoch 1): {'mean_ap': 0.9272823797298744, 'mean_precision': 0.9190118049561841, 'mean_recall': 0.820244849208719, 'mean_f1': 0.8500201271211056, 'label_aps': {'bridges_any': 0.6939329526936089, 'buildings_any': 0.9905536989689687, 'buildings_affected_or_greater': 0.8083271453088966, 'buildings_minor_or_greater': 0.6975624959455269, 'debris_any': 0.7065526540863172, 'flooding_any': 0.8460277478991011, 'flooding_structures': 0.8088699287882628, 'roads_any': 0.9852268111565917, 'roads_damage': 0.35830053829217046, 'trees_any': 0.9948153397282482, 'trees_damage': 0.8041611675642122, 'water_any': 0.9653076216387735}, 'label_f1s': {'bridges_any': 0.5074626865671642, 'buildings_any': 0.9590865842055185, 'buildings_affected_or_greater': 0.6793650793650794, 'buildings_minor_or_greater': 0.5111111111111111, 'debris_any': 0.6258503401360545, 'flooding_any': 0.6536964980544747, 'flooding_structures': 0.5542168674698795, 'roads_any': 0.9333333333333333, 'roads_damage': 

  with torch.cuda.amp.autocast():
Training: 126it [24:56, 11.88s/it]
Validation: 14it [02:54, 12.47s/it]


Validation Metrics (Trial 0, Epoch 2): {'mean_ap': 0.9420277979530379, 'mean_precision': 0.9294120505796343, 'mean_recall': 0.8474171394446103, 'mean_f1': 0.8736462171954757, 'label_aps': {'bridges_any': 0.8292579475154118, 'buildings_any': 0.9938641967645012, 'buildings_affected_or_greater': 0.84736902922533, 'buildings_minor_or_greater': 0.7429515446135834, 'debris_any': 0.7721596568894278, 'flooding_any': 0.8712689620028654, 'flooding_structures': 0.8364074109819027, 'roads_any': 0.9879511059052098, 'roads_damage': 0.47897109161271706, 'trees_any': 0.9957507698600025, 'trees_damage': 0.846608950896181, 'water_any': 0.968516558903249}, 'label_f1s': {'bridges_any': 0.7619047619047619, 'buildings_any': 0.9576669802445907, 'buildings_affected_or_greater': 0.6986301369863014, 'buildings_minor_or_greater': 0.5979381443298969, 'debris_any': 0.6131386861313869, 'flooding_any': 0.7854785478547854, 'flooding_structures': 0.7083333333333334, 'roads_any': 0.9419680403700589, 'roads_damage': 0.0

  with torch.cuda.amp.autocast():
Training: 126it [25:03, 11.93s/it]
Validation: 14it [02:58, 12.74s/it]


Validation Metrics (Trial 0, Epoch 3): {'mean_ap': 0.9382066249160568, 'mean_precision': 0.9373147799976579, 'mean_recall': 0.8136757240967453, 'mean_f1': 0.8466338036285604, 'label_aps': {'bridges_any': 0.8469698802919127, 'buildings_any': 0.9905049200073197, 'buildings_affected_or_greater': 0.8394498268895761, 'buildings_minor_or_greater': 0.7301693504755356, 'debris_any': 0.7563395670620069, 'flooding_any': 0.8621265059654168, 'flooding_structures': 0.854004859196522, 'roads_any': 0.9885057249132737, 'roads_damage': 0.4782821355920136, 'trees_any': 0.9962558095903778, 'trees_damage': 0.8165845739474568, 'water_any': 0.9657363895806585}, 'label_f1s': {'bridges_any': 0.7674418604651163, 'buildings_any': 0.9577464788732394, 'buildings_affected_or_greater': 0.5791505791505791, 'buildings_minor_or_greater': 0.5, 'debris_any': 0.5511811023622047, 'flooding_any': 0.7272727272727273, 'flooding_structures': 0.627906976744186, 'roads_any': 0.9433333333333334, 'roads_damage': 0.117647058823529

  with torch.cuda.amp.autocast():
Training: 126it [25:03, 11.93s/it]
Validation: 14it [03:00, 12.86s/it]


Validation Metrics (Trial 0, Epoch 4): {'mean_ap': 0.9382827315648609, 'mean_precision': 0.9278012392137862, 'mean_recall': 0.8253209913406987, 'mean_f1': 0.8498447450887752, 'label_aps': {'bridges_any': 0.8538227352307669, 'buildings_any': 0.9918613104374011, 'buildings_affected_or_greater': 0.8499211324624866, 'buildings_minor_or_greater': 0.7314006625857704, 'debris_any': 0.7679901295990688, 'flooding_any': 0.8523672893553609, 'flooding_structures': 0.8520456768957153, 'roads_any': 0.9885286167622683, 'roads_damage': 0.462445848926618, 'trees_any': 0.9962250447260019, 'trees_damage': 0.8087551161797981, 'water_any': 0.966567672003711}, 'label_f1s': {'bridges_any': 0.7741935483870968, 'buildings_any': 0.9612109744560076, 'buildings_affected_or_greater': 0.5962264150943396, 'buildings_minor_or_greater': 0.4523809523809524, 'debris_any': 0.6, 'flooding_any': 0.7285714285714285, 'flooding_structures': 0.6966292134831461, 'roads_any': 0.945332211942809, 'roads_damage': 0.1153846153846153

  with torch.cuda.amp.autocast():
Training: 126it [25:02, 11.93s/it]
Validation: 14it [02:59, 12.81s/it]


Validation Metrics (Trial 0, Epoch 5): {'mean_ap': 0.9461098084486654, 'mean_precision': 0.912728269460341, 'mean_recall': 0.8781725888324873, 'mean_f1': 0.8881511030554424, 'label_aps': {'bridges_any': 0.8507296408440272, 'buildings_any': 0.9926761494725256, 'buildings_affected_or_greater': 0.8784421407848622, 'buildings_minor_or_greater': 0.7625016818403271, 'debris_any': 0.7801431032188103, 'flooding_any': 0.8704667012530217, 'flooding_structures': 0.8871840849137663, 'roads_any': 0.9889979605088767, 'roads_damage': 0.5023095129434999, 'trees_any': 0.9963490335952285, 'trees_damage': 0.8527042143561635, 'water_any': 0.9674185697263653}, 'label_f1s': {'bridges_any': 0.7155963302752294, 'buildings_any': 0.9591642924976258, 'buildings_affected_or_greater': 0.7643312101910829, 'buildings_minor_or_greater': 0.5684210526315789, 'debris_any': 0.7361963190184049, 'flooding_any': 0.7722772277227723, 'flooding_structures': 0.723404255319149, 'roads_any': 0.948073701842546, 'roads_damage': 0.2

[I 2024-12-05 19:01:24,725] Trial 0 finished with value: 0.9461098084486654 and parameters: {'lr': 7.243499683207418e-06, 'dropout_rate': 0.1535643673944425, 'batch_size': 64}. Best is trial 0 with value: 0.9461098084486654.
  scaler = torch.cuda.amp.GradScaler()


Best model saved for trial 0 with mAP: 0.9461
Trial 1, Epoch 1/5


  with torch.cuda.amp.autocast():
Training: 502it [25:09,  3.01s/it]
Validation: 56it [02:59,  3.21s/it]


Validation Metrics (Trial 1, Epoch 1): {'mean_ap': 0.9420091799358838, 'mean_precision': 0.9094239552886978, 'mean_recall': 0.8769782024484921, 'mean_f1': 0.8864572937900483, 'label_aps': {'bridges_any': 0.8172732293891378, 'buildings_any': 0.9888126154460668, 'buildings_affected_or_greater': 0.867889537101222, 'buildings_minor_or_greater': 0.7432429356415677, 'debris_any': 0.7744057242317043, 'flooding_any': 0.8719815111327316, 'flooding_structures': 0.8500748112337257, 'roads_any': 0.9885828930757194, 'roads_damage': 0.4877165993091563, 'trees_any': 0.9966626232283229, 'trees_damage': 0.8263539225218912, 'water_any': 0.9711113053704146}, 'label_f1s': {'bridges_any': 0.6842105263157895, 'buildings_any': 0.957345971563981, 'buildings_affected_or_greater': 0.7459807073954984, 'buildings_minor_or_greater': 0.6542056074766355, 'debris_any': 0.7239263803680982, 'flooding_any': 0.782051282051282, 'flooding_structures': 0.7111111111111111, 'roads_any': 0.9422750424448217, 'roads_damage': 0.2

  with torch.cuda.amp.autocast():
Training: 502it [25:13,  3.01s/it]
Validation: 56it [03:03,  3.29s/it]


Validation Metrics (Trial 1, Epoch 2): {'mean_ap': 0.9447671599425089, 'mean_precision': 0.899382710790785, 'mean_recall': 0.8907136458644371, 'mean_f1': 0.8906948105314657, 'label_aps': {'bridges_any': 0.8409913981809246, 'buildings_any': 0.9914172370016581, 'buildings_affected_or_greater': 0.8711467989185279, 'buildings_minor_or_greater': 0.7561484400340922, 'debris_any': 0.7913495178021519, 'flooding_any': 0.8787862482387786, 'flooding_structures': 0.8785589261895516, 'roads_any': 0.9884411719543293, 'roads_damage': 0.4497623022593617, 'trees_any': 0.9971561745023876, 'trees_damage': 0.8380520112537067, 'water_any': 0.9708829998049604}, 'label_f1s': {'bridges_any': 0.7358490566037735, 'buildings_any': 0.9612109744560076, 'buildings_affected_or_greater': 0.7492447129909365, 'buildings_minor_or_greater': 0.5894736842105263, 'debris_any': 0.6951219512195121, 'flooding_any': 0.7724550898203593, 'flooding_structures': 0.803921568627451, 'roads_any': 0.9444444444444444, 'roads_damage': 0.

  with torch.cuda.amp.autocast():
Training: 502it [25:20,  3.03s/it]
Validation: 56it [02:59,  3.20s/it]


Validation Metrics (Trial 1, Epoch 3): {'mean_ap': 0.9410091707370805, 'mean_precision': 0.8937241511097336, 'mean_recall': 0.8966855777844133, 'mean_f1': 0.8911220877952073, 'label_aps': {'bridges_any': 0.7831663969092, 'buildings_any': 0.9916689983151098, 'buildings_affected_or_greater': 0.8687148975245242, 'buildings_minor_or_greater': 0.7527988899094158, 'debris_any': 0.7752574350884568, 'flooding_any': 0.8710479675052797, 'flooding_structures': 0.8309885766868328, 'roads_any': 0.9891705843020282, 'roads_damage': 0.4592882295490415, 'trees_any': 0.9959456596407743, 'trees_damage': 0.8368729658378681, 'water_any': 0.9647280853567688}, 'label_f1s': {'bridges_any': 0.7083333333333334, 'buildings_any': 0.9577735124760077, 'buildings_affected_or_greater': 0.7522388059701492, 'buildings_minor_or_greater': 0.6105263157894737, 'debris_any': 0.6918238993710691, 'flooding_any': 0.774869109947644, 'flooding_structures': 0.8305084745762712, 'roads_any': 0.9464882943143813, 'roads_damage': 0.4,

  with torch.cuda.amp.autocast():
Training: 502it [25:16,  3.02s/it]
Validation: 56it [03:03,  3.27s/it]


Validation Metrics (Trial 1, Epoch 4): {'mean_ap': 0.9439633020934596, 'mean_precision': 0.8824050422791313, 'mean_recall': 0.9098238280083607, 'mean_f1': 0.8920493432167201, 'label_aps': {'bridges_any': 0.8342723712466737, 'buildings_any': 0.9926156637739627, 'buildings_affected_or_greater': 0.8780435926382222, 'buildings_minor_or_greater': 0.7527124715066782, 'debris_any': 0.7804759117358534, 'flooding_any': 0.8666195262157244, 'flooding_structures': 0.8462178705570731, 'roads_any': 0.9874138948489448, 'roads_damage': 0.477330926231008, 'trees_any': 0.9966841771317309, 'trees_damage': 0.8467721648420558, 'water_any': 0.9681891005736539}, 'label_f1s': {'bridges_any': 0.7912087912087912, 'buildings_any': 0.9642184557438794, 'buildings_affected_or_greater': 0.7849462365591398, 'buildings_minor_or_greater': 0.6542056074766355, 'debris_any': 0.73224043715847, 'flooding_any': 0.7861271676300579, 'flooding_structures': 0.8037383177570093, 'roads_any': 0.9385382059800664, 'roads_damage': 0.2

  with torch.cuda.amp.autocast():
Training: 502it [25:11,  3.01s/it]
Validation: 56it [03:00,  3.22s/it]
[I 2024-12-05 21:22:44,949] Trial 1 finished with value: 0.9447671599425089 and parameters: {'lr': 5.895255252186952e-06, 'dropout_rate': 0.23042276808728862, 'batch_size': 16}. Best is trial 0 with value: 0.9461098084486654.


Validation Metrics (Trial 1, Epoch 5): {'mean_ap': 0.9428816703327637, 'mean_precision': 0.9045614029679381, 'mean_recall': 0.8883248730964467, 'mean_f1': 0.8889461676947279, 'label_aps': {'bridges_any': 0.819940124125014, 'buildings_any': 0.989742920070334, 'buildings_affected_or_greater': 0.86634560141319, 'buildings_minor_or_greater': 0.7694528424883759, 'debris_any': 0.7742338845621968, 'flooding_any': 0.8698350543044159, 'flooding_structures': 0.8572596373529813, 'roads_any': 0.9902560797852431, 'roads_damage': 0.5173924466689943, 'trees_any': 0.9956682593716479, 'trees_damage': 0.8355532871327611, 'water_any': 0.9661332693464975}, 'label_f1s': {'bridges_any': 0.7294117647058823, 'buildings_any': 0.9611374407582939, 'buildings_affected_or_greater': 0.7513812154696132, 'buildings_minor_or_greater': 0.6851851851851852, 'debris_any': 0.75, 'flooding_any': 0.7707006369426752, 'flooding_structures': 0.74, 'roads_any': 0.9431034482758621, 'roads_damage': 0.12, 'trees_any': 0.97564022485

In [None]:
# Final Training with Best Hyperparameters
# best_params = study.best_params
# Manually set best parameters (the values logged for trial 0 of the study)
best_params = {
    'lr': 7.243499683207418e-06,
    'dropout_rate': 0.1535643673944425,
    'batch_size': 64
}

# Reload the pretrained backbone: the `base_model.base_model` object may already
# have been fine-tuned in place by the tuning trials, which would give the final
# run a contaminated starting point.
final_backbone = AutoModelForImageClassification.from_pretrained(
    model_name, ignore_mismatched_sizes=True
).base_model
final_model = CustomDINOv2WithDropout(final_backbone, num_labels=len(label_keys), dropout_rate=best_params["dropout_rate"]).to(device)
optimizer = torch.optim.AdamW(final_model.parameters(), lr=best_params["lr"], weight_decay=1e-4)
final_scaler = torch.cuda.amp.GradScaler()  # fresh scaler; do not reuse the study's state
criterion = nn.BCEWithLogitsLoss()

# Validation must be deterministic: drop the random augmentation steps.
final_eval_transforms = transforms.Compose([
    step for step in image_transforms.transforms
    if not isinstance(step, (transforms.RandomHorizontalFlip, transforms.ColorJitter))
])

# Path to save the best model for final training
final_model_path = "/content/final_best_model.pth"

NUM_EPOCHS = 15
best_val_map = 0.0
patience, no_improvement = 3, 0
accumulation_steps = max(1, 64 // best_params["batch_size"])  # Simulate effective batch size of 64

for epoch in range(NUM_EPOCHS):
    print(f"Final Training, Epoch {epoch + 1}/{NUM_EPOCHS}")

    # Training
    final_model.train()
    train_loss = 0.0
    num_batches = 0
    pending = 0  # micro-batches accumulated since the last optimizer step
    train_dataset = StreamDataset(ds, "train", label_keys, image_transforms)
    train_loader = DataLoader(train_dataset, batch_size=best_params["batch_size"], collate_fn=lambda x: tuple(zip(*x)))

    optimizer.zero_grad()
    for batch_images, batch_labels in tqdm(train_loader, desc="Training"):
        batch_images, batch_labels = torch.stack(batch_images).to(device), torch.stack(batch_labels).to(device)

        with torch.cuda.amp.autocast():
            logits = final_model(batch_images)
            loss = criterion(logits, batch_labels)

        # Normalize for gradient accumulation; the un-normalized loss is logged.
        final_scaler.scale(loss / accumulation_steps).backward()
        train_loss += loss.item()
        num_batches += 1
        pending += 1

        if pending == accumulation_steps:
            final_scaler.step(optimizer)
            final_scaler.update()
            optimizer.zero_grad()
            pending = 0

    # A DataLoader over an IterableDataset has no len(), so the last partial
    # accumulation group is flushed here rather than detected inside the loop.
    if pending:
        final_scaler.step(optimizer)
        final_scaler.update()
        optimizer.zero_grad()

    if num_batches:
        print(f"Training loss: {train_loss / num_batches:.4f}")

    # Validation
    val_loss, val_logits, val_labels = process_dataset(
        final_model, ds, "validation", label_keys, final_eval_transforms, batch_size=best_params["batch_size"]
    )
    val_metrics = multi_label_metrics(val_logits, val_labels.numpy(), label_keys)
    print(f"Validation Metrics: {val_metrics}")

    # Save the best model during final training
    if val_metrics["mean_ap"] > best_val_map:
        best_val_map = val_metrics["mean_ap"]
        no_improvement = 0
        torch.save(final_model.state_dict(), final_model_path)
        print(f"New best model saved with mAP: {best_val_map:.4f}")
    else:
        no_improvement += 1

    # Early stopping
    if no_improvement >= patience:
        print("Early stopping triggered during final training.")
        break

    torch.cuda.empty_cache()  # Clear memory to prevent fragmentation





Final Training, Epoch 1/15


  with torch.cuda.amp.autocast():
Training: 126it [25:16, 12.03s/it]
Validation: 8it [01:59, 14.82s/it]'(ReadTimeoutError("HTTPSConnectionPool(host='cdn-lfs-us-1.hf.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 72dfea20-35ee-490f-8351-58b3aa6eaafe)')' thrown while requesting GET https://huggingface.co/datasets/MITLL/LADI-v2-dataset/resolve/5f2dbfe8c466d32edafd1bab847ec5252309acdb/data/validation/data-00025-of-00040.arrow
Retrying in 1s [Retry 1/5].
Validation: 14it [03:37, 15.53s/it]


Validation Metrics: {'mean_ap': 0.9310317999815452, 'mean_precision': 0.9222227458364122, 'mean_recall': 0.8181546730367274, 'mean_f1': 0.8444613237279097, 'label_aps': {'bridges_any': 0.772843638506266, 'buildings_any': 0.9919956946001534, 'buildings_affected_or_greater': 0.8112341181192707, 'buildings_minor_or_greater': 0.7151842344671382, 'debris_any': 0.7637105409883608, 'flooding_any': 0.8472072540069087, 'flooding_structures': 0.8347202445581658, 'roads_any': 0.9867970574677236, 'roads_damage': 0.3318138258402927, 'trees_any': 0.9954129750060572, 'trees_damage': 0.8058696436384956, 'water_any': 0.9635515373780983}, 'label_f1s': {'bridges_any': 0.5373134328358209, 'buildings_any': 0.9585687382297552, 'buildings_affected_or_greater': 0.6219081272084805, 'buildings_minor_or_greater': 0.5625, 'debris_any': 0.6621621621621622, 'flooding_any': 0.6533864541832669, 'flooding_structures': 0.35294117647058826, 'roads_any': 0.9343308395677473, 'roads_damage': 0.0, 'trees_any': 0.97313797313

  with torch.cuda.amp.autocast():
Training: 126it [24:59, 11.90s/it]
Validation: 14it [03:07, 13.39s/it]


Validation Metrics: {'mean_ap': 0.9423650867933403, 'mean_precision': 0.9380473292425361, 'mean_recall': 0.8175574798447298, 'mean_f1': 0.8548390817039973, 'label_aps': {'bridges_any': 0.8079305298252851, 'buildings_any': 0.9936194504266922, 'buildings_affected_or_greater': 0.8532150141654973, 'buildings_minor_or_greater': 0.7670894160950684, 'debris_any': 0.7851277872386315, 'flooding_any': 0.88437043779595, 'flooding_structures': 0.8647716017376099, 'roads_any': 0.9902096033857772, 'roads_damage': 0.3991635752930742, 'trees_any': 0.9970866040555326, 'trees_damage': 0.8328968172402949, 'water_any': 0.9672304967527258}, 'label_f1s': {'bridges_any': 0.6923076923076923, 'buildings_any': 0.9633113828786454, 'buildings_affected_or_greater': 0.5533596837944664, 'buildings_minor_or_greater': 0.5333333333333333, 'debris_any': 0.6713286713286714, 'flooding_any': 0.7293233082706767, 'flooding_structures': 0.55, 'roads_any': 0.9471919530595139, 'roads_damage': 0.04081632653061224, 'trees_any': 0

  with torch.cuda.amp.autocast():
Training: 126it [25:15, 12.03s/it]
Validation: 14it [03:28, 14.87s/it]


Validation Metrics: {'mean_ap': 0.9435797858835431, 'mean_precision': 0.9431479546280177, 'mean_recall': 0.8196476560167214, 'mean_f1': 0.8572191172589526, 'label_aps': {'bridges_any': 0.8461262629939503, 'buildings_any': 0.9931104128362316, 'buildings_affected_or_greater': 0.8530847330076461, 'buildings_minor_or_greater': 0.7691537352321396, 'debris_any': 0.8023044149586069, 'flooding_any': 0.877142215689506, 'flooding_structures': 0.8406164045389043, 'roads_any': 0.9895520588706571, 'roads_damage': 0.41391934212826664, 'trees_any': 0.9974059953458457, 'trees_damage': 0.8434748349133119, 'water_any': 0.9679172893548745}, 'label_f1s': {'bridges_any': 0.8, 'buildings_any': 0.9686013320647003, 'buildings_affected_or_greater': 0.5758754863813229, 'buildings_minor_or_greater': 0.5057471264367817, 'debris_any': 0.6330935251798561, 'flooding_any': 0.7153846153846154, 'flooding_structures': 0.55, 'roads_any': 0.9451114922813036, 'roads_damage': 0.04081632653061224, 'trees_any': 0.974874371859

  with torch.cuda.amp.autocast():
Training: 126it [25:02, 11.92s/it]
Validation: 14it [03:06, 13.31s/it]


Validation Metrics: {'mean_ap': 0.9466104554910971, 'mean_precision': 0.9360499527419766, 'mean_recall': 0.8441325768886234, 'mean_f1': 0.8753384018280576, 'label_aps': {'bridges_any': 0.8249449273156533, 'buildings_any': 0.993647658336823, 'buildings_affected_or_greater': 0.8883045401801188, 'buildings_minor_or_greater': 0.7758717156969628, 'debris_any': 0.7964488335429845, 'flooding_any': 0.8842321340878166, 'flooding_structures': 0.874072866706727, 'roads_any': 0.9905582807150661, 'roads_damage': 0.43192364688428786, 'trees_any': 0.9978514866465127, 'trees_damage': 0.8362588754450035, 'water_any': 0.9705032322949123}, 'label_f1s': {'bridges_any': 0.7555555555555555, 'buildings_any': 0.9640831758034026, 'buildings_affected_or_greater': 0.7315436241610739, 'buildings_minor_or_greater': 0.6326530612244898, 'debris_any': 0.6883116883116883, 'flooding_any': 0.7808219178082192, 'flooding_structures': 0.7422680412371134, 'roads_any': 0.946817785527463, 'roads_damage': 0.18518518518518517, 

  with torch.cuda.amp.autocast():
Training: 126it [25:15, 12.03s/it]
Validation: 14it [03:30, 15.01s/it]


Validation Metrics: {'mean_ap': 0.9456178063223948, 'mean_precision': 0.9263088223694105, 'mean_recall': 0.8707076739325171, 'mean_f1': 0.8920907385849205, 'label_aps': {'bridges_any': 0.822794489210805, 'buildings_any': 0.9940706152364824, 'buildings_affected_or_greater': 0.8832875590496276, 'buildings_minor_or_greater': 0.7969479740665214, 'debris_any': 0.7898724526183072, 'flooding_any': 0.8764825490013433, 'flooding_structures': 0.8696776285871005, 'roads_any': 0.9907041570471742, 'roads_damage': 0.43396419874816844, 'trees_any': 0.9969954166166157, 'trees_damage': 0.8353653933834709, 'water_any': 0.9685836651224594}, 'label_f1s': {'bridges_any': 0.7789473684210526, 'buildings_any': 0.9707822808671065, 'buildings_affected_or_greater': 0.78419452887538, 'buildings_minor_or_greater': 0.64, 'debris_any': 0.725, 'flooding_any': 0.7766323024054983, 'flooding_structures': 0.7755102040816326, 'roads_any': 0.9494097807757167, 'roads_damage': 0.27586206896551724, 'trees_any': 0.983302411873

  with torch.cuda.amp.autocast():
Training: 126it [24:59, 11.90s/it]
Validation: 14it [03:07, 13.42s/it]


Validation Metrics: {'mean_ap': 0.9461413234605608, 'mean_precision': 0.9018397260969916, 'mean_recall': 0.9089280382203643, 'mean_f1': 0.9031433382627045, 'label_aps': {'bridges_any': 0.8443010655826265, 'buildings_any': 0.9945456035595165, 'buildings_affected_or_greater': 0.8881345401836201, 'buildings_minor_or_greater': 0.7926062824302396, 'debris_any': 0.8066491194164893, 'flooding_any': 0.8824800432541298, 'flooding_structures': 0.8887046227361773, 'roads_any': 0.9902739553838897, 'roads_damage': 0.48911176779634113, 'trees_any': 0.9967400470387731, 'trees_damage': 0.822589560368953, 'water_any': 0.9634498498362033}, 'label_f1s': {'bridges_any': 0.8043478260869565, 'buildings_any': 0.9639468690702088, 'buildings_affected_or_greater': 0.815028901734104, 'buildings_minor_or_greater': 0.7256637168141593, 'debris_any': 0.7692307692307693, 'flooding_any': 0.7831325301204819, 'flooding_structures': 0.8431372549019608, 'roads_any': 0.9411764705882353, 'roads_damage': 0.3939393939393939, 

  with torch.cuda.amp.autocast():
Training: 126it [25:19, 12.06s/it]
Validation: 14it [03:28, 14.86s/it]


Validation Metrics: {'mean_ap': 0.9471715927896598, 'mean_precision': 0.8832848579130733, 'mean_recall': 0.924753657808301, 'mean_f1': 0.9012046229173155, 'label_aps': {'bridges_any': 0.8339063003634797, 'buildings_any': 0.9947738129019914, 'buildings_affected_or_greater': 0.8908268131204438, 'buildings_minor_or_greater': 0.7778876402977605, 'debris_any': 0.805658212598173, 'flooding_any': 0.8918129218042105, 'flooding_structures': 0.8831731416309045, 'roads_any': 0.9894211225205913, 'roads_damage': 0.4997551066171238, 'trees_any': 0.996089383682472, 'trees_damage': 0.8249659837376476, 'water_any': 0.9689979409618606}, 'label_f1s': {'bridges_any': 0.7474747474747475, 'buildings_any': 0.969639468690702, 'buildings_affected_or_greater': 0.8068181818181818, 'buildings_minor_or_greater': 0.6846846846846847, 'debris_any': 0.7391304347826086, 'flooding_any': 0.7857142857142857, 'flooding_structures': 0.8288288288288288, 'roads_any': 0.947107438016529, 'roads_damage': 0.5185185185185185, 'tre

  with torch.cuda.amp.autocast():
Training: 126it [25:09, 11.98s/it]
Validation: 14it [03:07, 13.39s/it]


Validation Metrics: {'mean_ap': 0.9448811915028206, 'mean_precision': 0.8878501825721049, 'mean_recall': 0.923260674828307, 'mean_f1': 0.9018363808829445, 'label_aps': {'bridges_any': 0.8548660733329855, 'buildings_any': 0.9909169160915078, 'buildings_affected_or_greater': 0.8992919000300477, 'buildings_minor_or_greater': 0.7641878387074408, 'debris_any': 0.7652019081740548, 'flooding_any': 0.8931566858349234, 'flooding_structures': 0.9044120387022913, 'roads_any': 0.9890762377893575, 'roads_damage': 0.46053419265697526, 'trees_any': 0.9950188790248704, 'trees_damage': 0.8202010437647077, 'water_any': 0.9669786457503465}, 'label_f1s': {'bridges_any': 0.7142857142857143, 'buildings_any': 0.9690140845070423, 'buildings_affected_or_greater': 0.8184438040345822, 'buildings_minor_or_greater': 0.62, 'debris_any': 0.75, 'flooding_any': 0.8103448275862069, 'flooding_structures': 0.8703703703703703, 'roads_any': 0.9405052974735126, 'roads_damage': 0.4166666666666667, 'trees_any': 0.982153846153

  with torch.cuda.amp.autocast():
Training: 126it [25:31, 12.15s/it]
Validation: 14it [03:29, 14.95s/it]


Validation Metrics: {'mean_ap': 0.9438831386005319, 'mean_precision': 0.890344535848109, 'mean_recall': 0.9089280382203643, 'mean_f1': 0.8958749526903993, 'label_aps': {'bridges_any': 0.8174687482861946, 'buildings_any': 0.9842003007841924, 'buildings_affected_or_greater': 0.8847486651034834, 'buildings_minor_or_greater': 0.7723879949961351, 'debris_any': 0.7686787067493558, 'flooding_any': 0.8884252650953833, 'flooding_structures': 0.8979922751434377, 'roads_any': 0.9900542991958773, 'roads_damage': 0.5009532675382895, 'trees_any': 0.9948776817504971, 'trees_damage': 0.8258053568345683, 'water_any': 0.969220481288845}, 'label_f1s': {'bridges_any': 0.7130434782608696, 'buildings_any': 0.967741935483871, 'buildings_affected_or_greater': 0.7654320987654321, 'buildings_minor_or_greater': 0.66, 'debris_any': 0.7272727272727273, 'flooding_any': 0.7800586510263929, 'flooding_structures': 0.8380952380952381, 'roads_any': 0.9432098765432099, 'roads_damage': 0.4383561643835616, 'trees_any': 0.9

  with torch.cuda.amp.autocast():
Training: 126it [25:22, 12.08s/it]
Validation: 14it [03:08, 13.48s/it]

Validation Metrics: {'mean_ap': 0.9450591180584627, 'mean_precision': 0.9020170771533885, 'mean_recall': 0.9020603165123917, 'mean_f1': 0.8998931978925744, 'label_aps': {'bridges_any': 0.8707753685493205, 'buildings_any': 0.983355809386273, 'buildings_affected_or_greater': 0.8884763562213003, 'buildings_minor_or_greater': 0.7695566513929748, 'debris_any': 0.7642905467741997, 'flooding_any': 0.8903583035532204, 'flooding_structures': 0.9145644863564291, 'roads_any': 0.9871233698219407, 'roads_damage': 0.5209478527875925, 'trees_any': 0.9962796627208002, 'trees_damage': 0.8269515480186858, 'water_any': 0.9688784245869783}, 'label_f1s': {'bridges_any': 0.7884615384615384, 'buildings_any': 0.9631728045325779, 'buildings_affected_or_greater': 0.774390243902439, 'buildings_minor_or_greater': 0.6355140186915887, 'debris_any': 0.7262569832402235, 'flooding_any': 0.7988980716253443, 'flooding_structures': 0.8256880733944955, 'roads_any': 0.9491525423728814, 'roads_damage': 0.4444444444444444, '




In [None]:
# Run the final model over the held-out "test" split, reusing the batch size
# stored in best_params. process_dataset hands back a (loss, logits, labels)
# triple, which is unpacked for the metric cells that follow.
# NOTE(review): the progress bar printed during this call reads "Validation"
# even though the split is "test" — presumably a fixed label inside
# process_dataset; confirm there.
test_batch_size = best_params["batch_size"]
test_loss, test_logits, test_labels = process_dataset(
    final_model,
    ds,
    "test",
    label_keys,
    image_transforms,
    batch_size=test_batch_size,
)


Validation: 17it [03:59, 14.08s/it]


In [None]:
# Turn the raw test logits into aggregate and per-label scores.
# multi_label_metrics applies the sigmoid itself and thresholds at its default
# of 0.5, so the logits go in untouched; the targets are converted to a NumPy
# array first because they are handed to the scikit-learn metric functions.
test_targets = test_labels.numpy()
test_metrics = multi_label_metrics(test_logits, test_targets, label_keys)

# Dump the whole metrics dictionary in one line (aggregate + per-label entries).
print(f"Test Metrics: {test_metrics}")

Test Metrics: {'mean_ap': 0.9358319232425576, 'mean_precision': 0.8943476012954142, 'mean_recall': 0.8860208816705336, 'mean_f1': 0.8873895620341293, 'label_aps': {'bridges_any': 0.7043559862041444, 'buildings_any': 0.9842972066618617, 'buildings_affected_or_greater': 0.7072538123012451, 'buildings_minor_or_greater': 0.725720862696994, 'debris_any': 0.6954187820375122, 'flooding_any': 0.6206667841562463, 'flooding_structures': 0.3521433648082837, 'roads_any': 0.9688433380835724, 'roads_damage': 0.1193418948805612, 'trees_any': 0.9957081619718413, 'trees_damage': 0.5814604099602392, 'water_any': 0.9196310651592996}, 'label_f1s': {'bridges_any': 0.6119402985074627, 'buildings_any': 0.9494418910045962, 'buildings_affected_or_greater': 0.6136363636363636, 'buildings_minor_or_greater': 0.5454545454545454, 'debris_any': 0.6153846153846154, 'flooding_any': 0.5454545454545454, 'flooding_structures': 0.47058823529411764, 'roads_any': 0.9155609167671894, 'roads_damage': 0.11764705882352941, 'tre

In [None]:
# Per-label average precision on the test split, one line per label, rounded
# to four decimals. Each value is the AP of a single label (computed with
# average=None in multi_label_metrics), not a mean over labels, despite the
# "mAP" wording in the heading.
per_label_ap = test_metrics["label_aps"]
print("Label-wise mAP:")
for label in per_label_ap:
    ap = per_label_ap[label]
    print(f"{label}: {ap:.4f}")

Label-wise mAP:
bridges_any: 0.7044
buildings_any: 0.9843
buildings_affected_or_greater: 0.7073
buildings_minor_or_greater: 0.7257
debris_any: 0.6954
flooding_any: 0.6207
flooding_structures: 0.3521
roads_any: 0.9688
roads_damage: 0.1193
trees_any: 0.9957
trees_damage: 0.5815
water_any: 0.9196


In [None]:
# Per-label F1 on the test split, one line per label, rounded to four
# decimals. The leading "\n" in the heading leaves a blank line above the
# listing in the cell output.
per_label_f1 = test_metrics["label_f1s"]
print("\nLabel-wise F1 Scores:")
for label in per_label_f1:
    f1 = per_label_f1[label]
    print(f"{label}: {f1:.4f}")


Label-wise F1 Scores:
bridges_any: 0.6119
buildings_any: 0.9494
buildings_affected_or_greater: 0.6136
buildings_minor_or_greater: 0.5455
debris_any: 0.6154
flooding_any: 0.5455
flooding_structures: 0.4706
roads_any: 0.9156
roads_damage: 0.1176
trees_any: 0.9657
trees_damage: 0.5363
water_any: 0.8379
