In [None]:
import torch # For model
import torch.nn as nn # For model
from torchvision import models
import torchvision.transforms.functional as TF # For resizing if not multiple of 16
from torchvision.models import ResNet18_Weights
import torch.nn.functional as F

import os
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
import random
import albumentations as A # For training
from albumentations.pytorch import ToTensorV2 # For training
from tqdm import tqdm # For training
import torch.optim as optim # For training
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import accuracy_score
import cv2

In [None]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Padding of 1 with stride 1 keeps the spatial size unchanged; only the
    channel count changes, from ``in_channels`` to ``out_channels``. The convs
    have no bias because the following BatchNorm supplies the shift.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in->out channels, second stage keeps out channels.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

In [None]:
class UNET(nn.Module):
    """U-Net encoder/decoder that outputs a per-pixel probability map.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of channels of the output map.
        features: channel widths of the encoder stages, shallowest first. The
            decoder mirrors them and the bottleneck uses ``2 * features[-1]``
            channels. Any sequence of ints is accepted; the default is a tuple
            so no mutable object is shared between calls.

    ``forward`` applies a sigmoid, so outputs lie in [0, 1] and pair with
    ``nn.BCELoss`` (not ``nn.BCEWithLogitsLoss``).
    """

    def __init__(self, in_channels=3, out_channels=1, features=(32, 64, 128, 256)):
        super(UNET, self).__init__()
        features = list(features)

        # ModuleLists so a variable number of stages is registered as submodules
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()

        # Shared 2x2 max-pool between encoder stages (halves H and W, flooring odd sizes)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of UNET: per stage, a 2x transposed conv followed by a
        # DoubleConv over [skip, upsampled] (hence feature * 2 input channels).
        # Modules are appended in (upsample, conv) pairs; forward relies on this order.
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(feature * 2, feature))

        # Bottom part of UNET
        self.bottom = DoubleConv(features[-1], features[-1] * 2)

        # Final 1x1 conv to the requested number of output channels
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Return sigmoid probabilities with the same spatial size as ``x``."""
        skip_connections = []

        # Encoder: keep each stage's output for the matching decoder stage
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottom(x)

        # Decoder: walk (upsample, conv) pairs alongside the skips, deepest first
        upsamplers = self.ups[0::2]
        convs = self.ups[1::2]
        for upsample, conv, skip_connection in zip(upsamplers, convs, reversed(skip_connections)):
            x = upsample(x)
            # Max-pooling floors odd sizes (e.g. 51 -> 25 -> 50), so the upsampled
            # map can be smaller than the skip; only spatial dims can differ.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = TF.resize(x, size=list(skip_connection.shape[2:]))
            x = conv(torch.cat((skip_connection, x), dim=1))

        return torch.sigmoid(self.final_conv(x))  # To keep values between 0 & 1

In [None]:
class UNET_with_Classifier(nn.Module):
    """U-Net segmenter followed by a ResNet-18 classifier on the predicted region.

    ``forward(x)`` returns ``(mask, logits)``:
      - ``mask``: the U-Net probability map;
      - ``logits``: ``num_classes`` classifier outputs for the crop of ``x``
        bounded by the thresholded mask (the whole image when the mask is empty).
    """

    def __init__(self, in_channels=3, out_channels=1, num_classes=10):
        super(UNET_with_Classifier, self).__init__()

        self.unet = UNET(in_channels=in_channels, out_channels=out_channels)

        # Pretrained ResNet-18 with its head replaced for `num_classes`.
        # `weights` is keyword-only in torchvision >= 0.13; positional use only
        # works through a deprecation shim that may be removed.
        resnet = models.resnet18(weights=ResNet18_Weights.DEFAULT)
        resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)
        self.classifier = resnet

    def bounding_box_and_cropping(self, x, mask, threshold=0.5, output_size=(224, 224)):
        """Crop each image to its mask's bounding box and resize the crop.

        Args:
            x: image batch ``[B, C, H, W]``.
            mask: mask batch ``[B, 1+, H, W]``; only channel 0 is used.
            threshold: pixels with ``mask > threshold`` belong to the region.
            output_size: ``(height, width)`` every crop is bilinearly resized to.

        Returns:
            Tensor ``[B, C, *output_size]``. An image whose mask has no pixel
            above ``threshold`` is resized whole.
        """
        if len(mask) == 0:
            return x.new_empty((0, x.shape[1], *output_size))

        height, width = mask.shape[2], mask.shape[3]
        crops = []
        for img, m in zip(x, mask[:, 0]):
            rows, cols = (m > threshold).nonzero(as_tuple=True)
            if rows.numel() > 0:
                # nonzero gives inclusive indices; +1 makes the slice end exclusive
                # so the last mask row/column is kept (and 1-pixel masks work).
                y_min, y_max = int(rows.min()), int(rows.max()) + 1
                x_min, x_max = int(cols.min()), int(cols.max()) + 1
            else:
                y_min, x_min, y_max, x_max = 0, 0, height, width
            cropped = img[:, y_min:y_max, x_min:x_max]
            resized = F.interpolate(cropped.unsqueeze(0), size=output_size, mode='bilinear', align_corners=False)
            crops.append(resized.squeeze(0))
        return torch.stack(crops)

    def forward(self, x):
        mask = self.unet(x)
        cropped_image = self.bounding_box_and_cropping(x, mask)
        flood_level_logits = self.classifier(cropped_image)
        return mask, flood_level_logits

In [None]:
def load_image(path):
    """Load an image file as a float32 RGB array of shape (H, W, 3) with values 0-255."""
    # The context manager releases the file handle deterministically instead of
    # relying on Pillow closing it as a side effect of loading the pixel data.
    with Image.open(path) as image:
        return np.array(image.convert("RGB"), dtype=np.float32)  # (H, W, 3)

def load_mask(path):
    """Load a mask file as a float32 grayscale array of shape (H, W, 1) with values 0-255."""
    # Close the file handle deterministically (see load_image).
    with Image.open(path) as mask:
        return np.array(mask.convert("L"), dtype=np.float32)[..., np.newaxis]  # (H, W, 1)

class FloodDataset(Dataset):
    """Dataset of (image, segmentation mask, flood level) triples read from disk.

    Args:
        image_paths: paths of the RGB images.
        mask_paths: paths of the grayscale masks, aligned with ``image_paths``.
        levels: integer flood level per image, aligned with ``image_paths``.
        transforms: optional albumentations-style callable taking
            ``image=`` and ``mask=`` keywords and returning a dict.

    Each item is a dict:
      - ``'image'``: the transformed image, or a float32 ``[3, H, W]`` tensor of
        raw 0-255 pixels when the transform leaves a NumPy array (or is None);
      - ``'mask'``: a ``[1, H, W]`` tensor; values are not rescaled here
        (callers divide by 255);
      - ``'level'``: the flood level as a ``torch.long`` scalar.
    """

    def __init__(self, image_paths, mask_paths, levels, transforms=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.levels = levels
        self.transforms = transforms

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = load_image(self.image_paths[idx])
        mask = load_mask(self.mask_paths[idx])
        level = torch.tensor(self.levels[idx], dtype=torch.long)

        # Ensure mask is always [H, W] before augmentation
        if mask.ndim == 3 and mask.shape[2] == 1:
            mask = mask.squeeze(-1)
        elif mask.ndim != 2:
            raise ValueError(f"Unexpected mask shape at index {idx}: {mask.shape}")
        # Nearest-neighbour keeps mask values unblended when resolutions differ
        if image.shape[:2] != mask.shape[:2]:
            mask = cv2.resize(mask, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)

        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Without a ToTensorV2-style step the arrays are still NumPy; convert them so
        # the tensor-only calls below and default batch collation work.
        if isinstance(image, np.ndarray):
            image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))
        if isinstance(mask, np.ndarray):
            mask = torch.from_numpy(np.ascontiguousarray(mask))

        # After transforms, enforce [1, H, W]
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        elif mask.ndim == 3 and mask.shape[0] != 1:
            raise ValueError(f"Mask has unexpected shape after transform at index {idx}: {mask.shape}")

        return {
            'image': image,
            'mask': mask,
            'level': level
        }



In [None]:
def train_fn(loader, model, optimizer, segmentation_loss_fn, classification_loss_fn, device, cls_weight=5.0):
    """Run one training epoch over `loader` and print the average losses.

    Args:
        loader: yields dicts with 'image', 'mask' and 'level' tensors.
        model: returns ``(pred_masks, pred_levels)`` for a batch of images.
        optimizer: optimizer stepping ``model``'s parameters.
        segmentation_loss_fn: loss on (pred_masks, masks scaled to [0, 1]).
        classification_loss_fn: loss on (pred_levels, level labels).
        device: device batches are moved to.
        cls_weight: weight of the classification loss in the total loss
            (5.0 reproduces the original objective).

    Returns:
        ``(avg_loss, avg_seg_loss, avg_cls_loss)`` averaged over batches
        (all 0.0 for an empty loader).
    """
    model.train()
    total_loss = 0.0
    total_seg_loss = 0.0
    total_cls_loss = 0.0

    loop = tqdm(loader)

    for batch in loop:
        images = batch['image'].to(device)
        true_masks = batch['mask'].to(device)
        flood_levels = batch['level'].to(device)

        # Bring masks to [B, 1, H, W] (same handling as in evaluation)
        if true_masks.ndim == 3:
            true_masks = true_masks.unsqueeze(1)
        if true_masks.shape[1] != 1:
            true_masks = true_masks.permute(0, 3, 1, 2)
        # Assumes 8-bit 0-255 masks; BCE-style losses need targets in [0, 1]
        true_masks = true_masks / 255.0

        # Forward pass
        pred_masks, pred_levels = model(images)

        # Compute losses
        seg_loss = segmentation_loss_fn(pred_masks, true_masks)
        cls_loss = classification_loss_fn(pred_levels, flood_levels)
        loss = seg_loss + cls_weight * cls_loss

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_seg_loss += seg_loss.item()
        total_cls_loss += cls_loss.item()

        loop.set_postfix(loss=loss.item(), seg_loss=seg_loss.item(), cls_loss=cls_loss.item())

    num_batches = max(len(loader), 1)  # avoid ZeroDivisionError on an empty loader
    avg_loss = total_loss / num_batches
    avg_seg_loss = total_seg_loss / num_batches
    avg_cls_loss = total_cls_loss / num_batches

    print(f"Train Loss: {avg_loss:.4f} | Segmentation Loss: {avg_seg_loss:.4f} | Classification Loss: {avg_cls_loss:.4f}")
    return avg_loss, avg_seg_loss, avg_cls_loss


In [None]:
def test_fn(loader, model, segmentation_loss_fn, classification_loss_fn, saved_samples, device,
            cls_weight=5.0, max_samples=10):
    """Evaluate `model` on `loader` without gradient updates and print average losses.

    Args:
        loader: yields dicts with 'image', 'mask' and 'level' tensors.
        model: returns ``(pred_masks, pred_levels)`` for a batch of images.
        segmentation_loss_fn: loss on (pred_masks, masks scaled to [0, 1]).
        classification_loss_fn: loss on (pred_levels, level labels).
        saved_samples: list appended to in place with up to ``max_samples``
            ``(image_hwc, true_mask, binarized_pred_mask)`` NumPy tuples.
        device: device batches are moved to.
        cls_weight: weight of the classification loss in the total loss
            (5.0 reproduces the original objective).
        max_samples: cap on ``len(saved_samples)``.

    Returns:
        ``(all_preds, all_labels)``: predicted and true level for every sample.

    The model's train/eval mode is restored to its state on entry, even on error.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_seg_loss = 0.0
    total_cls_loss = 0.0
    all_preds = []
    all_labels = []

    try:
        with torch.no_grad():
            loop = tqdm(loader)

            for batch in loop:
                images = batch['image'].to(device)
                true_masks = batch['mask'].to(device)
                flood_levels = batch['level'].to(device)

                # Bring masks to [B, 1, H, W]
                if true_masks.ndim == 3:
                    true_masks = true_masks.unsqueeze(1)
                if true_masks.shape[1] != 1:
                    true_masks = true_masks.permute(0, 3, 1, 2)

                # Assumes 8-bit 0-255 masks; unscaled masks are kept for plotting
                true_masks_normalized = true_masks / 255.0

                pred_masks, pred_levels = model(images)

                seg_loss = segmentation_loss_fn(pred_masks, true_masks_normalized)
                cls_loss = classification_loss_fn(pred_levels, flood_levels)
                loss = seg_loss + cls_weight * cls_loss

                total_loss += loss.item()
                total_seg_loss += seg_loss.item()
                total_cls_loss += cls_loss.item()

                loop.set_postfix(loss=loss.item(), seg_loss=seg_loss.item(), cls_loss=cls_loss.item())

                # Save samples for visualization
                for i in range(images.size(0)):
                    if len(saved_samples) >= max_samples:
                        break
                    img = images[i].cpu().permute(1, 2, 0).numpy()
                    true_mask = true_masks[i].cpu().squeeze().numpy()
                    pred_mask = pred_masks[i].cpu().squeeze().numpy()
                    pred_mask = (pred_mask > 0.5).astype(float)
                    saved_samples.append((img, true_mask, pred_mask))

                preds = torch.argmax(pred_levels, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(flood_levels.cpu().numpy())
    finally:
        model.train(was_training)

    num_batches = max(len(loader), 1)  # avoid ZeroDivisionError on an empty loader
    avg_loss = total_loss / num_batches
    avg_seg_loss = total_seg_loss / num_batches
    avg_cls_loss = total_cls_loss / num_batches

    print(f"\nTest Loss: {avg_loss:.4f} | Segmentation Loss: {avg_seg_loss:.4f} | Classification Loss: {avg_cls_loss:.4f}")
    return all_preds, all_labels


In [None]:
# Seed every RNG the run touches (Python, NumPy, torch) so weight init and the
# DataLoader shuffle order are reproducible across Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = UNET_with_Classifier().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# The U-Net output already goes through a sigmoid, so plain BCELoss is the matching loss.
segmentation_loss_fn = nn.BCELoss()
classification_loss_fn = nn.CrossEntropyLoss()
epochs = 25
number_of_images = 2000  # only used by the optional subsampling lines below (disabled)

image_dir = "/kaggle/input/flooded-cars-and-masks/cropped_car_small/cropped_car_small" # Path to images
mask_dir = "/kaggle/input/flooded-cars-and-masks/cropped_mask_small/cropped_mask_small" # Path to masks
# sorted(): os.listdir order is filesystem-dependent, which would otherwise make the
# train_test_split below (despite random_state) differ between machines.
image_filenames = sorted(os.listdir(image_dir))
# random.shuffle(image_filenames)
# selected_images = image_filenames[:number_of_images]

image_paths = []
mask_paths = []
levels = []
for filename in image_filenames:
    # Image path
    image_path = os.path.join(image_dir, filename)

    # Mask has the same stem with a .png extension (splitext only touches the real extension)
    mask_filename = os.path.splitext(filename)[0] + ".png"
    mask_path = os.path.join(mask_dir, mask_filename)

    # Extract the level from filename
    level = int(filename.split('_')[2])  # The third part, i.e. after second '_' (index 2)

    image_paths.append(image_path)
    mask_paths.append(mask_path)
    levels.append(level)

# Fail fast with a clear message instead of inside a DataLoader worker mid-epoch
missing_masks = [path for path in mask_paths if not os.path.exists(path)]
if missing_masks:
    raise FileNotFoundError(f"{len(missing_masks)} mask file(s) not found, e.g. {missing_masks[0]}")

# Split 80% train, 20% test
train_img_paths, test_img_paths, train_mask_paths, test_mask_paths, train_levels, test_levels = train_test_split(
    image_paths,
    mask_paths,
    levels,
    test_size=0.2,        # 20% test
    random_state=SEED,    # For reproducibility
    shuffle=True
)

# Define transformations (deterministic: resize + normalize to [-1, 1], no augmentation)
train_transforms = A.Compose([
    A.Resize(320, 320),
    A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ToTensorV2(),
])
train_dataset = FloodDataset(train_img_paths, train_mask_paths, train_levels, transforms=train_transforms)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)

for epoch in range(epochs):
    print(f"Epoch [{epoch+1}/{epochs}]")
    train_fn(train_loader, model, optimizer, segmentation_loss_fn, classification_loss_fn, device)

In [None]:
# Evaluate on the held-out 20% split created in the training cell.
# train_transforms is deterministic (Resize + Normalize), so it is reused as-is.
eval_transforms = train_transforms

# Dataset and DataLoader (no shuffling: stable order -> reproducible saved samples)
test_dataset = FloodDataset(test_img_paths, test_mask_paths, test_levels, transforms=eval_transforms)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

# Sanity-check one batch's shapes before the full evaluation pass
sample_batch = next(iter(test_loader))
print(f"Image shape: {sample_batch['image'].shape}")
print(f"Mask shape: {sample_batch['mask'].shape}")
print(f"Flood levels shape: {sample_batch['level'].shape}")

# Run test
saved_samples = []
all_preds, all_labels = test_fn(test_loader, model, segmentation_loss_fn, classification_loss_fn, saved_samples, device)

# Plot sample predictions
for idx, (img, true_mask, pred_mask) in enumerate(saved_samples):
    # Undo A.Normalize(mean=0.5, std=0.5) so imshow gets RGB in [0, 1] instead of clipping [-1, 1]
    img_display = np.clip(img * 0.5 + 0.5, 0.0, 1.0)
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    axs[0].imshow(img_display)
    axs[0].set_title(f'Sample {idx+1}: Input Image')
    axs[1].imshow(true_mask, cmap='gray')
    axs[1].set_title('True Mask')
    axs[2].imshow(pred_mask, cmap='gray')
    axs[2].set_title('Predicted Mask')
    plt.show()

# Confusion Matrix
cm = confusion_matrix(all_labels, all_preds)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap="Blues")
plt.title("Confusion Matrix")
plt.show()

# Accuracy
accuracy = accuracy_score(all_labels, all_preds)
print(f"Classification Accuracy: {accuracy * 100:.2f}%")
