In [None]:
# pip install datasets

In [None]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader, IterableDataset
from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoModel
from sklearn.metrics import average_precision_score, precision_recall_fscore_support
from tqdm import tqdm
from torchvision import transforms
import numpy as np
import cv2
from PIL import Image
from datasets import load_dataset
import matplotlib.pyplot as plt

In [None]:
from google.colab import drive

# Mount Google Drive at Colab's standard mount point so that files under
# /content/drive/MyDrive (such as the classifier checkpoint) become readable.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Labels whose images we want to inspect; each one must also appear in label_keys.
classes_of_interest = [
    "flooding_any",
    "trees_damage",
    "buildings_minor_or_greater",
]

# Every label column, in classifier-output order: logit i corresponds to label_keys[i].
label_keys = [
    'bridges_any',
    'buildings_any',
    'buildings_affected_or_greater',
    'buildings_minor_or_greater',
    'debris_any',
    'flooding_any',
    'flooding_structures',
    'roads_any',
    'roads_damage',
    'trees_any',
    'trees_damage',
    'water_any',
]

In [None]:
# Define the dehazing function
def haze_removal(image, omega=0.85, radius=10, epsilon=0.002):
    """Reduce haze in an 8-bit colour image using a simplified dark-channel-style estimate.

    The transmission map is estimated from a grayscale copy of the image. The
    normalised intensity is eroded (a local minimum over a ``radius`` x ``radius``
    window), scaled by ``omega`` and inverted. It is then clipped to [0.1, 1] and
    box-blurred. The first three channels are recovered as ``(I - A) / t + A``,
    where ``A`` is the brightest gray value.

    Args:
        image: uint8 array of shape (H, W, 3). Presumably RGB when it comes from
            ``np.array(PIL.Image)`` — TODO confirm for other callers.
        omega: Fraction of the estimated haze to remove; a larger value lowers
            the transmission and strengthens the correction.
        radius: Side length in pixels of the erosion and blur windows.
        epsilon: Currently unused. It is kept for signature compatibility and
            looks intended for a guided-filter refinement that is not implemented.

    Returns:
        uint8 array with the same shape as ``image``.
    """
    # NOTE(review): images coming from PIL are RGB, but BGR luminance weights are
    # used here. This is kept unchanged so preprocessing matches what the
    # checkpoint was presumably trained with; confirm before switching to RGB2GRAY.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY).astype(np.float32)
    atmospheric_light = np.max(gray)
    if atmospheric_light <= 0:
        # Completely black image: there is no haze to remove, and dividing by
        # A == 0 below would fill the result with NaNs.
        return np.clip(image, 0, 255).astype(np.uint8)

    transmission = 1 - omega * cv2.erode(gray / atmospheric_light, np.ones((radius, radius), np.uint8))
    # Lower bound keeps the division below from amplifying noise without limit.
    transmission = np.clip(transmission, 0.1, 1.0)
    refined_transmission = cv2.blur(transmission, (radius, radius))

    dehazed = np.zeros_like(image, dtype=np.float32)
    for i in range(3):
        dehazed[:, :, i] = (image[:, :, i] - atmospheric_light) / refined_transmission + atmospheric_light
    return np.clip(dehazed, 0, 255).astype(np.uint8)

In [None]:
# Load dataset: stream the LADI-v2 test split, so examples are fetched lazily
# while iterating instead of the whole split being downloaded up front.
dataset = load_dataset("MITLL/LADI-v2-dataset", split="test", streaming=True)

# Map each label name to its position in label_keys (the classifier's output index).
label_to_index = dict(zip(label_keys, range(len(label_keys))))

# Extract images corresponding to the classes of interest
def get_images_for_classes(dataset, classes, num_images=10):
    """Collect up to ``num_images`` images whose labels include any of ``classes``.

    Args:
        dataset: Iterable of example dicts (e.g. a streaming Hugging Face split).
            Each example needs an ``"image"`` entry and one entry per name in ``classes``.
        classes: Label names to match. An example is kept when any of them equals 1.
        num_images: Maximum number of images to return. Values <= 0 yield ``[]``.

    Returns:
        List of the matching ``example["image"]`` values, in dataset order.

    Raises:
        KeyError: If an example lacks one of the requested label names.
    """
    # Return early when nothing can be wanted or matched. Otherwise this would
    # scan the whole (possibly streaming) dataset or return a surplus image.
    if num_images <= 0 or not classes:
        return []

    selected_images = []
    for example in dataset:
        # Look labels up by name. This avoids requiring every label column to be
        # present just to test a few of them.
        if any(example[class_name] == 1 for class_name in classes):
            selected_images.append(example["image"])
            if len(selected_images) >= num_images:
                break  # stop consuming the stream as soon as we have enough
    return selected_images

In [None]:
# Get up to 20 test images labelled with any of the classes of interest
# ("flooding_any", "trees_damage" or "buildings_minor_or_greater").
test_images = get_images_for_classes(dataset, classes_of_interest, num_images=20)

# Define transformations: resize to a fixed square and normalise with the
# ImageNet channel statistics. Presumably this must match the preprocessing
# used when training the checkpoint; confirm before changing.
IMAGE_SIZE = 384
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
image_transforms = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Define the Custom DINOv2 Model with Dropout
class CustomDINOv2WithDropout(nn.Module):
    """Multi-label classification head on top of a DINOv2 backbone.

    The backbone's ``pooler_output`` goes through dropout and then a single
    linear layer that produces one logit per label.

    Args:
        base_model: Backbone module. It must expose ``config.hidden_size`` and
            return an object with a ``pooler_output`` attribute.
        num_labels: Number of output logits.
        dropout_rate: Dropout probability applied to the pooled features.
    """

    def __init__(self, base_model, num_labels, dropout_rate=0.3):
        super().__init__()
        # These attribute names become the state_dict key prefixes. Keep them
        # stable so previously saved checkpoints still load.
        self.base_model = base_model
        self.dropout = nn.Dropout(dropout_rate)
        feature_dim = base_model.config.hidden_size
        self.classifier = nn.Linear(feature_dim, num_labels)

    def forward(self, pixel_values):
        """Return raw logits (no activation applied) for a batch of pixel tensors."""
        backbone_out = self.base_model(pixel_values=pixel_values, output_hidden_states=False)
        features = self.dropout(backbone_out.pooler_output)
        return self.classifier(features)

# Load the DINOv2 base model from Hugging Face
base_model = AutoModel.from_pretrained("facebook/dinov2-base")
# NOTE(review): `processor` is not referenced elsewhere in this notebook section.
# Inference uses the torchvision `image_transforms` pipeline instead.
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")

# Initialize the Custom Model with one output per label in label_keys.
num_labels = len(label_keys)
dropout_rate = 0.4
model = CustomDINOv2WithDropout(base_model=base_model, num_labels=num_labels, dropout_rate=dropout_rate)

# Load the state dictionary from your saved checkpoint.
# map_location lets a checkpoint saved in a GPU session load on a CPU-only
# runtime; without it, torch.load tries to restore the tensors onto CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = "/content/drive/MyDrive/Classification/partially_dehazed_final_best_model.pth"
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval();  # inference mode (disables dropout); `;` hides the long module repr

In [None]:
# Define the number of images per row
images_per_row = 2

# Preprocess and perform inference
def preprocess_and_infer(image, model, apply_dehaze=False):
    """Run the multi-label classifier on one image and return per-label probabilities.

    Args:
        image: PIL image in any mode; it is converted to RGB first.
        model: Module mapping a (1, 3, H, W) tensor to (1, num_labels) logits.
        apply_dehaze: If True, run ``haze_removal`` on the pixels before transforming.

    Returns:
        1-D numpy array of sigmoid probabilities, one per label, indexed like
        ``label_to_index``.
    """
    # Both the 3-channel Normalize in image_transforms and haze_removal need RGB.
    # Grayscale ("L") or RGBA images would otherwise fail there. For images that
    # are already RGB this is just a copy.
    image = image.convert("RGB")
    if apply_dehaze:
        image = Image.fromarray(haze_removal(np.array(image)))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    image_tensor = image_transforms(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(image_tensor)
        probs = torch.sigmoid(logits).cpu().numpy()[0]

    return probs

# Define a probability threshold
probability_threshold = 0.4

# Classes whose images are dehazed before inference ("buildings_minor_or_greater" is not).
dehaze_classes = {"flooding_any", "trees_damage"}

# Filter images based on predictions.
# The selection step only guarantees that an image carries *some* class of
# interest, so each image is scored against every class. An image therefore
# appears once per class whose predicted probability reaches the threshold.
# (Previously the class was assigned by position, i % len(classes), which had
# nothing to do with the image's content.)
filtered_images = []
for image in test_images:
    # Cache probabilities per dehaze setting, so each image goes through the
    # model at most twice (raw and dehazed) however many classes are checked.
    probs_by_dehaze = {}
    for class_name in classes_of_interest:
        apply_dehaze = class_name in dehaze_classes
        if apply_dehaze not in probs_by_dehaze:
            probs_by_dehaze[apply_dehaze] = preprocess_and_infer(image, model=model, apply_dehaze=apply_dehaze)

        class_prob = probs_by_dehaze[apply_dehaze][label_to_index[class_name]]
        if class_prob >= probability_threshold:
            filtered_images.append((image, class_name, class_prob))

# Display filtered images in a grid with images_per_row columns.
num_images = len(filtered_images)
num_rows = (num_images + images_per_row - 1) // images_per_row  # ceiling division

if num_images == 0:
    # Avoid drawing an empty, zero-height figure with no explanation.
    print(f"No images reached the probability threshold of {probability_threshold}.")
else:
    fig, axes = plt.subplots(
        num_rows, images_per_row, figsize=(15, num_rows * 5), squeeze=False
    )
    for ax, (image, class_name, class_prob) in zip(axes.flat, filtered_images):
        ax.imshow(image)
        ax.set_title(f"{class_name}: {class_prob:.2f}")
    for ax in axes.flat:
        # Hide ticks everywhere; this also blanks unused slots in a short last row.
        ax.axis("off")
    fig.tight_layout()
    plt.show()
