In [None]:
# pip install optuna

In [None]:
# pip install datasets

In [None]:
import copy
import os

import cv2
import numpy as np
import optuna
import torch
from PIL import Image
from sklearn.metrics import average_precision_score, precision_recall_fscore_support
from torch import nn
from torch.utils.data import DataLoader, IterableDataset
from torchvision import transforms
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoModelForImageClassification

In [None]:
from google.colab import drive

# Mount Google Drive
# Checkpoints are written under the mounted drive (see checkpoint_dir below).
drive.mount('/content/drive')

# Define the checkpoint directory and path
# checkpoint_dir is also where objective() writes the per-trial best-model files.
checkpoint_dir = "/content/drive/My Drive/Colab Checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)  # Create directory if it doesn't exist
# Single per-epoch checkpoint file that objective() saves to and resumes from.
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.pth")

In [None]:
# Custom Model with Dropout
class CustomDINOv2WithDropout(nn.Module):
    """Backbone followed by dropout and a linear classification head.

    Args:
        base_model: Module exposing ``config.hidden_size`` and whose forward
            output has a ``pooler_output`` attribute.
        num_labels: Number of output logits produced by the linear head.
        dropout_rate: Dropout probability applied to the pooled features.
    """

    def __init__(self, base_model, num_labels, dropout_rate=0.3):
        super().__init__()
        self.base_model = base_model
        self.dropout = nn.Dropout(p=dropout_rate)
        # Width of the head's input comes from the backbone configuration.
        feature_dim = self.base_model.config.hidden_size
        self.classifier = nn.Linear(feature_dim, num_labels)

    def forward(self, pixel_values):
        """Return raw (pre-sigmoid) logits for ``pixel_values``."""
        backbone_out = self.base_model(
            pixel_values=pixel_values, output_hidden_states=False
        )
        features = self.dropout(backbone_out.pooler_output)
        return self.classifier(features)

# Multi-label Metrics
def multi_label_metrics(logits, y_true, labels, threshold=0.5):
    """Compute weighted and per-label metrics for multi-label predictions.

    Args:
        logits: Tensor of raw model outputs; a sigmoid is applied here.
        y_true: Ground-truth indicator array passed straight to scikit-learn.
        labels: Label names, indexed in the same order as the logit columns.
        threshold: Probability above which a label counts as predicted.

    Returns:
        dict with ``mean_ap``, ``mean_precision``, ``mean_recall``,
        ``mean_f1`` (all weighted averages) plus ``label_aps`` and
        ``label_f1s`` mapping each label name to its per-label score.
    """
    scores = torch.sigmoid(logits).cpu().numpy()
    predicted = scores > threshold

    def _prf(average):
        # Shared call so both averaging modes use identical arguments.
        return precision_recall_fscore_support(
            y_true=y_true, y_pred=predicted, average=average, zero_division=np.nan
        )

    # Weighted (overall) scores first, then the per-label breakdown.
    weighted_prec, weighted_rec, weighted_f1, _ = _prf("weighted")
    weighted_ap = average_precision_score(y_true, scores, average="weighted")

    _, _, per_label_f1, _ = _prf(None)
    per_label_ap = average_precision_score(y_true, scores, average=None)

    return {
        "mean_ap": weighted_ap,
        "mean_precision": weighted_prec,
        "mean_recall": weighted_rec,
        "mean_f1": weighted_f1,
        "label_aps": {name: per_label_ap[pos] for pos, name in enumerate(labels)},
        "label_f1s": {name: per_label_f1[pos] for pos, name in enumerate(labels)},
    }


# Haze Removal Function with Tunable Parameters
def haze_removal(image, omega, radius, epsilon):
    """Reduce haze in an RGB image using a grayscale-based transmission estimate.

    Steps: convert to grayscale, take its maximum as the atmospheric light,
    estimate transmission as ``1 - omega * erode(gray / light)``, clip it to
    ``[0.1, 1.0]``, smooth it with a box filter, then invert the haze model on
    the first three channels.

    Args:
        image: PIL image or NumPy array with at least three channels. Channel
            order is treated as RGB, which is what ``np.array(PIL image)``
            yields for the arrays this notebook passes in.
        omega: Strength of the haze subtraction in the transmission estimate.
        radius: Side length (pixels) of the square erosion and blur kernels.
        epsilon: Accepted for interface compatibility; not used by the current
            implementation.

    Returns:
        ``uint8`` array with the same shape as ``image``.
    """
    # Ensure the image is a NumPy array
    if not isinstance(image, np.ndarray):
        image = np.array(image)

    # RGB (not BGR) so the luminance weights land on the right channels.
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY).astype(np.float32)
    atmospheric_light = float(np.max(gray))

    # A completely black image has no haze to remove; bail out before the
    # division below would produce NaN/inf.
    if atmospheric_light <= 0:
        return np.clip(image, 0, 255).astype(np.uint8)

    kernel = np.ones((radius, radius), np.uint8)
    transmission = 1 - omega * cv2.erode(gray / atmospheric_light, kernel)
    transmission = np.clip(transmission, 0.1, 1.0)

    # Refine the transmission map using a basic box filter
    refined_transmission = cv2.blur(transmission, (radius, radius))

    # Dehaze the three colour channels at once; any extra channel stays zero,
    # matching the zero-initialised output buffer.
    dehazed = np.zeros_like(image, dtype=np.float32)
    colour = image[:, :, :3].astype(np.float32)
    dehazed[:, :, :3] = (
        (colour - atmospheric_light) / refined_transmission[:, :, None]
        + atmospheric_light
    )
    return np.clip(dehazed, 0, 255).astype(np.uint8)

# Image is needed by the dataset class below to convert dehazed arrays back to PIL images.
from PIL import Image

# Dataset Class with Selective Dehazing
class StreamDatasetWithSelectiveEnhancement(IterableDataset):
    """Iterable dataset that dehazes only images carrying selected labels.

    Each yielded sample is ``(transformed_image, float32 label tensor)``.
    Haze removal runs when any of ``roads_damage``, ``flooding_structures``,
    ``flooding_any`` or ``trees_damage`` is truthy for the item.

    NOTE(review): the dehazing decision is made from the item's ground-truth
    labels for whichever split is iterated, so preprocessing depends on the
    target — confirm this is intended for validation/test splits.
    """

    def __init__(self, dataset, split_name, label_keys, image_transforms, omega, radius, epsilon):
        self.dataset = dataset
        self.split_name = split_name
        self.label_keys = label_keys
        self.image_transforms = image_transforms
        # Haze-removal settings forwarded unchanged to haze_removal().
        self.omega = omega
        self.radius = radius
        self.epsilon = epsilon

        # Positions (within label_keys) of the labels that trigger dehazing.
        trigger_names = (
            "roads_damage",
            "flooding_structures",
            "flooding_any",
            "trees_damage",
        )
        self.dehaze_labels = [self.label_keys.index(name) for name in trigger_names]

    def process_item(self, item):
        """Turn one raw dataset record into ``(image_tensor, label_tensor)``."""
        pixels = item["image"]
        label_values = [int(item[key]) for key in self.label_keys]

        # Work on a NumPy array regardless of the incoming image type.
        if not isinstance(pixels, np.ndarray):
            pixels = np.array(pixels)

        # Dehaze only when at least one trigger label is set.
        should_dehaze = False
        for position in self.dehaze_labels:
            if label_values[position]:
                should_dehaze = True
                break
        if should_dehaze:
            pixels = haze_removal(pixels, self.omega, self.radius, self.epsilon)

        # The transform pipeline is fed a PIL image, so convert back.
        pil_image = Image.fromarray(pixels)

        return (
            self.image_transforms(pil_image),
            torch.tensor(label_values, dtype=torch.float32),
        )

    def __iter__(self):
        split = self.dataset[self.split_name]
        for record in split:
            yield self.process_item(record)

# Updated process_dataset Function
def process_dataset(
    model, dataset, split_name, label_keys, image_transforms, optimizer=None, train=False, batch_size=8, omega=None, radius=None, epsilon=None
):
    """Run one full pass over a split, either training or evaluating ``model``.

    Relies on the module-level ``device`` (batches are moved to it) and, when
    ``train=True``, on the module-level ``scaler`` for loss scaling.

    Args:
        model: Module mapping an image batch to logits.
        dataset: Mapping of split name to an iterable of records.
        split_name: Key of the split to iterate.
        label_keys: Record keys that make up the label vector, in order.
        image_transforms: Callable applied to each PIL image.
        optimizer: Optimizer stepped after each batch; required when
            ``train=True``.
        train: Train (with autocast and gradient scaling) when True, otherwise
            evaluate under ``torch.no_grad()``.
        batch_size: Number of samples per batch.
        omega, radius, epsilon: Haze-removal settings forwarded to the dataset.

    Returns:
        ``(mean_loss_per_batch, all_logits, all_labels)`` with logits and
        labels concatenated on the CPU.

    Raises:
        ValueError: If ``train=True`` without an optimizer, or if the split
            yields no batches.
    """
    # Fail early instead of hitting an AttributeError on the first batch.
    if train and optimizer is None:
        raise ValueError("process_dataset(train=True) requires an optimizer")

    model.train(train)
    criterion = torch.nn.BCEWithLogitsLoss()  # built once, reused for every batch
    running_loss = 0.0
    all_logits, all_labels = [], []
    batch_count = 0

    # Use the enhanced dataset with selective dehazing
    processed_dataset = StreamDatasetWithSelectiveEnhancement(
        dataset, split_name, label_keys, image_transforms, omega=omega, radius=radius, epsilon=epsilon
    )
    # The collate function transposes the batch into (images, labels) tuples;
    # they are stacked into tensors inside the loop.
    loader = DataLoader(processed_dataset, batch_size=batch_size, collate_fn=lambda x: tuple(zip(*x)))

    for batch_images, batch_labels in tqdm(loader, desc="Training" if train else "Validation"):
        batch_count += 1
        batch_images, batch_labels = map(torch.stack, (batch_images, batch_labels))
        batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)

        if train:
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                logits = model(batch_images)
                loss = criterion(logits, batch_labels)
            # Scaled backward/step through the module-level GradScaler.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            running_loss += loss.item()
        else:
            with torch.no_grad():
                logits = model(batch_images)
                loss = criterion(logits, batch_labels)
                running_loss += loss.item()

        all_logits.append(logits.detach().cpu())
        all_labels.append(batch_labels.detach().cpu())

    # An empty split would otherwise crash in torch.cat / divide by zero.
    if batch_count == 0:
        raise ValueError(f"Split {split_name!r} produced no batches; nothing to process")

    all_logits = torch.cat(all_logits, dim=0)
    all_labels = torch.cat(all_labels, dim=0)

    return running_loss / batch_count, all_logits, all_labels

def objective(trial):
    """Optuna objective: fine-tune for a few epochs and return the best validation mAP.

    Samples learning rate, dropout rate, batch size and haze-removal settings,
    trains a fresh model for up to five epochs with early stopping, and saves
    both the trial's best weights and a per-epoch checkpoint.

    Side effects: rebinds the module-level ``model`` and ``scaler`` (the latter
    is the scaler ``process_dataset`` reads) and writes files under
    ``checkpoint_dir``.

    Resume behaviour: the checkpoint at ``checkpoint_path`` is only resumed
    when it was written for exactly the same hyperparameters as this trial.
    To continue an interrupted trial after a restart, its saved
    ``trial_params`` can be fed back through ``study.enqueue_trial``.

    NOTE(review): validation here uses the same ``image_transforms`` as
    training, including its random augmentations. ``epsilon`` is sampled but
    ``haze_removal`` does not use it.

    Args:
        trial: ``optuna.trial.Trial`` supplying the hyperparameters.

    Returns:
        Best weighted validation average precision seen in this trial (float).
    """
    global model, scaler

    # Hyperparameters to tune
    lr = trial.suggest_float("lr", 1e-6, 1e-3, log=True)
    dropout_rate = trial.suggest_float("dropout_rate", 0.1, 0.5)
    batch_size = trial.suggest_int("batch_size", 16, 64, step=16)

    # Haze removal parameters
    omega = trial.suggest_float("omega", 0.8, 1.0)
    radius = trial.suggest_int("radius", 10, 20)
    epsilon = trial.suggest_float("epsilon", 0.0001, 0.01)
    trial_params = dict(trial.params)

    # Deep-copy the pretrained backbone so every trial starts from the same
    # weights instead of inheriting the previous trial's fine-tuning.
    backbone = copy.deepcopy(base_model.base_model)
    model = CustomDINOv2WithDropout(backbone, num_labels=len(label_keys), dropout_rate=dropout_rate).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    # Bound globally so the scaler that is checkpointed is the one
    # process_dataset actually trains with.
    scaler = torch.cuda.amp.GradScaler()

    # Warm start: resume only a checkpoint written for these hyperparameters.
    # Resuming another trial's checkpoint would load its weights, optimizer
    # state (including lr) and epoch counter, so this trial would skip training.
    start_epoch, best_val_map, no_improvement = 0, 0.0, 0
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        if checkpoint.get("trial_params") == trial_params:
            print(f"Resuming training from checkpoint: {checkpoint_path}")
            model.load_state_dict(checkpoint["model_state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            scaler.load_state_dict(checkpoint["scaler_state_dict"])
            start_epoch = checkpoint["epoch"] + 1
            best_val_map = float(checkpoint["best_val_map"])
            no_improvement = checkpoint.get("no_improvement", 0)
        else:
            print(
                f"Ignoring checkpoint {checkpoint_path}: "
                "it was saved for different hyperparameters."
            )

    # Training Loop
    num_epochs = 5  # Shorter for tuning
    patience = 3

    for epoch in range(start_epoch, num_epochs):
        print(f"Trial {trial.number}, Epoch {epoch + 1}/{num_epochs}")

        # Training
        train_loss, _, _ = process_dataset(
            model, ds, "train", label_keys, image_transforms, optimizer, train=True,
            batch_size=batch_size, omega=omega, radius=radius, epsilon=epsilon
        )
        print(f"Trial {trial.number}, Epoch {epoch + 1}: Train Loss = {train_loss:.4f}")

        # Validation
        val_loss, val_logits, val_labels = process_dataset(
            model, ds, "validation", label_keys, image_transforms, batch_size=batch_size,
            omega=omega, radius=radius, epsilon=epsilon
        )
        val_metrics = multi_label_metrics(val_logits, val_labels.numpy(), label_keys)
        print(f"Validation Metrics (Trial {trial.number}, Epoch {epoch + 1}): {val_metrics}")

        # Display label-wise metrics for better visibility
        print(f"Label-wise APs (Trial {trial.number}, Epoch {epoch + 1}):")
        for label, ap in val_metrics["label_aps"].items():
            print(f"  {label}: {ap:.4f}")
        print(f"Label-wise F1 Scores (Trial {trial.number}, Epoch {epoch + 1}):")
        for label, f1 in val_metrics["label_f1s"].items():
            print(f"  {label}: {f1:.4f}")

        # Log mean_ap for the trial
        trial.report(val_metrics["mean_ap"], epoch)
        print(f"Trial {trial.number}, Epoch {epoch + 1}: mean_ap = {val_metrics['mean_ap']:.4f}")

        # Save the best model of the trial
        if val_metrics["mean_ap"] > best_val_map:
            # Plain float keeps the checkpoint free of NumPy scalar objects.
            best_val_map = float(val_metrics["mean_ap"])
            no_improvement = 0
            torch.save(model.state_dict(), f"{checkpoint_dir}/partially_dehazed_best_model_trial_{trial.number}.pth")
            print(f"Best model saved for trial {trial.number} with mAP: {best_val_map:.4f}")
        else:
            no_improvement += 1

        # Save a checkpoint after every epoch, tagged with the trial's
        # hyperparameters so only a matching trial can resume it.
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scaler_state_dict": scaler.state_dict(),
            "best_val_map": best_val_map,
            "no_improvement": no_improvement,
            "trial_number": trial.number,
            "trial_params": trial_params,
        }, checkpoint_path)
        print(f"Checkpoint saved at: {checkpoint_path}")

        # Early stopping
        if no_improvement >= patience:
            print(f"Trial {trial.number}: Early stopping triggered.")
            break

        torch.cuda.empty_cache()  # Clear memory to prevent fragmentation

    return best_val_map


# Load dataset
# streaming=True is requested, so splits are consumed by iteration
# (see StreamDatasetWithSelectiveEnhancement.__iter__).
from datasets import load_dataset
ds = load_dataset("MITLL/LADI-v2-dataset", streaming=True)

# Define label keys
# The list order fixes the order of label vectors and of per-label metrics.
label_keys = [
    'bridges_any', 'buildings_any', 'buildings_affected_or_greater',
    'buildings_minor_or_greater', 'debris_any', 'flooding_any',
    'flooding_structures', 'roads_any', 'roads_damage', 'trees_any',
    'trees_damage', 'water_any'
]

# Load DINOv2 model
# NOTE(review): `processor` is not used by any cell shown in this notebook;
# preprocessing is done by `image_transforms` below.
model_name = "facebook/dinov2-base"
processor = AutoImageProcessor.from_pretrained(model_name)
# Only `base_model.base_model` is handed to CustomDINOv2WithDropout.
base_model = AutoModelForImageClassification.from_pretrained(
    model_name, ignore_mismatched_sizes=True
)

# Define transformations
# NOTE(review): objective() and train_with_best_params() pass this same
# pipeline, random flip and colour jitter included, for the validation split.
image_transforms = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Device and scaler
# Both names are read as module-level globals inside process_dataset().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
scaler = torch.cuda.amp.GradScaler()

# Run Optuna Study
# direction="maximize" because objective() returns a validation mAP.
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=2)

# Print best hyperparameters
print("Best hyperparameters:", study.best_params)



In [None]:
def train_with_best_params(best_params):
    """Train a fresh model with the tuned hyperparameters and keep the best weights.

    Trains for up to 15 epochs with early stopping (patience 3) on weighted
    validation mAP, saving the best weights to
    ``/content/partially_dehazed_final_best_model.pth``. When at least one
    epoch improved on the initial score, those best weights are loaded back
    into the model before returning.

    Side effects: rebinds the module-level ``model`` and ``scaler`` (the
    scaler is the one ``process_dataset`` reads).

    NOTE(review): validation here uses the same ``image_transforms`` as
    training, including its random augmentations.

    Args:
        best_params: Mapping with keys ``lr``, ``dropout_rate``,
            ``batch_size``, ``omega``, ``radius`` and ``epsilon``.
    """
    global model, scaler

    # Extract best hyperparameters
    lr = best_params["lr"]
    dropout_rate = best_params["dropout_rate"]
    batch_size = best_params["batch_size"]
    omega = best_params["omega"]
    radius = best_params["radius"]
    epsilon = best_params["epsilon"]

    # Deep-copy the pretrained backbone so this run does not start from
    # weights already fine-tuned by earlier runs that shared the same object.
    backbone = copy.deepcopy(base_model.base_model)
    model = CustomDINOv2WithDropout(backbone, num_labels=len(label_keys), dropout_rate=dropout_rate).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    # Bound globally so process_dataset trains with a fresh scaler.
    scaler = torch.cuda.amp.GradScaler()

    # Training loop
    num_epochs = 15
    patience = 3     # Number of epochs to wait for improvement
    no_improvement = 0
    best_val_map = 0.0
    best_model_saved = False
    best_model_path = "/content/partially_dehazed_final_best_model.pth"

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")

        # Training
        train_loss, _, _ = process_dataset(
            model, ds, "train", label_keys, image_transforms, optimizer, train=True,
            batch_size=batch_size, omega=omega, radius=radius, epsilon=epsilon
        )
        print(f"Epoch {epoch + 1}: Train Loss = {train_loss:.4f}")

        # Validation
        val_loss, val_logits, val_labels = process_dataset(
            model, ds, "validation", label_keys, image_transforms, batch_size=batch_size,
            omega=omega, radius=radius, epsilon=epsilon
        )
        val_metrics = multi_label_metrics(val_logits, val_labels.numpy(), label_keys)
        print(f"Validation Metrics (Epoch {epoch + 1}): {val_metrics}")

        # Display label-wise metrics for better visibility
        print(f"Label-wise APs (Epoch {epoch + 1}):")
        for label, ap in val_metrics["label_aps"].items():
            print(f"  {label}: {ap:.4f}")
        print(f"Label-wise F1 Scores (Epoch {epoch + 1}):")
        for label, f1 in val_metrics["label_f1s"].items():
            print(f"  {label}: {f1:.4f}")

        # Save the best model
        if val_metrics["mean_ap"] > best_val_map:
            best_val_map = val_metrics["mean_ap"]
            no_improvement = 0
            torch.save(model.state_dict(), best_model_path)
            best_model_saved = True
            print(f"Best model saved for epoch {epoch + 1} with mAP: {best_val_map:.4f}")
        else:
            no_improvement += 1

        # Early stopping
        if no_improvement >= patience:
            print(f"Early stopping triggered after {patience} epochs without improvement.")
            break

        # Clear GPU cache
        torch.cuda.empty_cache()

    # The loop may end several epochs after the best one; put the best weights
    # back so the global `model` used afterwards is the one that was selected.
    if best_model_saved:
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        print(f"Restored best weights from: {best_model_path}")

    print("Training complete.")
    print(f"Best validation mAP: {best_val_map:.4f}")
    print(f"Best model saved to: {best_model_path}")

# Retrain with the best hyperparameters found by the Optuna study above.
train_with_best_params(study.best_params)

In [None]:
# Evaluate on the test set using the final trained model
print("Evaluating on the test set with the final trained model...")

# Test metrics must be deterministic, so derive an evaluation pipeline from
# the training one with its random augmentations removed. Resize, tensor
# conversion and normalisation stay in sync with `image_transforms`.
eval_transforms = transforms.Compose([
    step for step in image_transforms.transforms
    if not isinstance(step, (transforms.RandomHorizontalFlip, transforms.ColorJitter))
])

# Ensure the model is in evaluation mode
model.eval()

# Process the test dataset with the same batch size and haze-removal
# settings that were used for training.
test_loss, test_logits, test_labels = process_dataset(
    model,
    ds,
    "test",
    label_keys,
    eval_transforms,
    batch_size=study.best_params["batch_size"],
    omega=study.best_params["omega"],
    radius=study.best_params["radius"],
    epsilon=study.best_params["epsilon"]
)

# Compute test metrics
test_metrics = multi_label_metrics(test_logits, test_labels.numpy(), label_keys)

# Print overall test metrics
print(f"Test Loss: {test_loss:.4f}")
print("Test Metrics:")
print(test_metrics)

# Print label-wise metrics for detailed evaluation
print("Label-wise APs on Test Set:")
for label, ap in test_metrics["label_aps"].items():
    print(f"  {label}: {ap:.4f}")

print("Label-wise F1 Scores on Test Set:")
for label, f1 in test_metrics["label_f1s"].items():
    print(f"  {label}: {f1:.4f}")
