In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader  # Make sure to include Dataset here
from torchvision import transforms
from PIL import Image
import numpy as np
import copy


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define U-Net Model (same as in your training code)
class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super(UNet, self).__init__()

        def conv_block(in_channels, out_channels):
            block = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True)
            )
            return block

        def up_conv(in_channels, out_channels):
            return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

        # Down-sampling layers
        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)
        self.enc4 = conv_block(256, 512)

        # Max Pooling layer
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck layer
        self.bottleneck = conv_block(512, 1024)

        # Up-sampling layers
        self.upconv4 = up_conv(1024, 512)
        self.dec4 = conv_block(1024, 512)

        self.upconv3 = up_conv(512, 256)
        self.dec3 = conv_block(512, 256)

        self.upconv2 = up_conv(256, 128)
        self.dec2 = conv_block(256, 128)

        self.upconv1 = up_conv(128, 64)
        self.dec1 = conv_block(128, 64)

        # Output layer
        self.conv_last = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))

        # Bottleneck
        bottleneck = self.bottleneck(self.pool(enc4))

        # Decoder
        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.dec4(dec4)

        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.dec3(dec3)

        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.dec2(dec2)

        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.dec1(dec1)

        return torch.sigmoid(self.conv_last(dec1))

# Define the custom Dataset class
class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.image_names = os.listdir(image_dir)

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_names[idx])
        mask_path = os.path.join(self.mask_dir, self.image_names[idx])

        image = Image.open(image_path).convert('L')  # Convert to grayscale
        mask = Image.open(mask_path).convert('L')  # Binary mask in grayscale

        if self.transform is not None:
            image = self.transform(image)
            mask = self.transform(mask)

        mask = torch.where(mask > 0, torch.tensor(1.0), torch.tensor(0.0))  # Normalize mask to 0 and 1

        return image, mask

# Define evaluation metrics
def calculate_precision(pred, target):
    true_positive = torch.sum(pred * target)
    false_positive = torch.sum(pred * (1 - target))
    return (true_positive + 1e-6) / (true_positive + false_positive + 1e-6)

def calculate_recall(pred, target):
    true_positive = torch.sum(pred * target)
    false_negative = torch.sum((1 - pred) * target)
    return (true_positive + 1e-6) / (true_positive + false_negative + 1e-6)

def calculate_accuracy(pred, target):
    correct = torch.sum(pred == target)
    total = torch.numel(pred)
    return correct.float() / total

def calculate_iou(pred, target):
    intersection = torch.sum(pred * target)
    union = torch.sum(pred) + torch.sum(target) - intersection
    return (intersection + 1e-6) / (union + 1e-6)

def calculate_dice(pred, target):
    intersection = torch.sum(pred * target)
    return (2 * intersection + 1e-6) / (torch.sum(pred) + torch.sum(target) + 1e-6)

# Function to test the model on the test set
def test_unet(model, dataloader):
    model.eval()
    precision_total = 0.0
    recall_total = 0.0
    accuracy_total = 0.0
    iou_total = 0.0
    dice_total = 0.0

    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            outputs = torch.where(outputs > 0.5, torch.tensor(1.0).to(device), torch.tensor(0.0).to(device))

            # Calculate metrics
            precision_total += calculate_precision(outputs, masks).item()
            recall_total += calculate_recall(outputs, masks).item()
            accuracy_total += calculate_accuracy(outputs, masks).item()
            iou_total += calculate_iou(outputs, masks).item()
            dice_total += calculate_dice(outputs, masks).item()

    # Average metrics
    num_samples = len(dataloader)
    precision_avg = precision_total / num_samples
    recall_avg = recall_total / num_samples
    accuracy_avg = accuracy_total / num_samples
    iou_avg = iou_total / num_samples
    dice_avg = dice_total / num_samples

    print(f'Precision: {precision_avg:.4f}')
    print(f'Recall: {recall_avg:.4f}')
    print(f'Accuracy: {accuracy_avg:.4f}')
    print(f'IoU: {iou_avg:.4f}')
    print(f'Dice: {dice_avg:.4f}')


## 1. Baseline
#### Test Total

In [2]:

model_path = '/home/tim/Documents/04_Projekt_ConSim/camera_contamination/models/bestmodel_50.pth'


# Paths to test images and masks
rgb_path = '/home/tim/Documents/04_Projekt_ConSim/camera_contamination/data/test/rgb_total'
gt_path = '/home/tim/Documents/04_Projekt_ConSim/camera_contamination/data/test/gt_total'

# Data transformations
transform = transforms.Compose([
    transforms.Resize((640, 640)),
    transforms.ToTensor()
])

# Create test dataset and dataloader
test_dataset = SegmentationDataset(image_dir=rgb_path, mask_dir=gt_path, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

# Load the trained U-Net model
model = UNet(in_channels=1, out_channels=1).to(device)
model.load_state_dict(torch.load(model_path))

# Test the model
test_unet(model, test_loader)


  model.load_state_dict(torch.load(model_path))


Precision: 0.9561
Recall: 0.9510
Accuracy: 0.9537
IoU: 0.9111
Dice: 0.9527


#### Test Semi

In [3]:


# Paths to test images and masks
rgb_path = '/home/tim/Documents/04_Projekt_ConSim/camera_contamination/data/test/rgb_semi'
gt_path = '/home/tim/Documents/04_Projekt_ConSim/camera_contamination/data/test/gt_semi'

# Data transformations
transform = transforms.Compose([
    transforms.Resize((640, 640)),
    transforms.ToTensor()
])

# Create test dataset and dataloader
test_dataset = SegmentationDataset(image_dir=rgb_path, mask_dir=gt_path, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

# Load the trained U-Net model
model = UNet(in_channels=1, out_channels=1).to(device)
model.load_state_dict(torch.load(model_path))

# Test the model
test_unet(model, test_loader)


  model.load_state_dict(torch.load(model_path))


Precision: 0.8636
Recall: 0.9028
Accuracy: 0.8695
IoU: 0.7891
Dice: 0.8788


#### Test General

In [4]:



# Paths to test images and masks
rgb_path = '/home/tim/Documents/04_Projekt_ConSim/camera_contamination/data/dirty_cam/test_rgb'
gt_path = '/home/tim/Documents/04_Projekt_ConSim/camera_contamination/data/dirty_cam/test_gt'

# Data transformations
transform = transforms.Compose([
    transforms.Resize((640, 640)),
    transforms.ToTensor()
])

# Create test dataset and dataloader
test_dataset = SegmentationDataset(image_dir=rgb_path, mask_dir=gt_path, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

# Load the trained U-Net model
model = UNet(in_channels=1, out_channels=1).to(device)
model.load_state_dict(torch.load(model_path))

# Test the model
test_unet(model, test_loader)


  model.load_state_dict(torch.load(model_path))


Precision: 0.7703
Recall: 0.7945
Accuracy: 0.7772
IoU: 0.6544
Dice: 0.7782
