# Setup

In [19]:
import torch
import torch.nn as nn
from torchvision import models
from torchvision.transforms import v2
from torchvision.models.resnet import ResNet18_Weights
from torchvision.transforms.functional import InterpolationMode
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import glob
import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Dataloader

In [None]:
class DMCDataset(Dataset):
    def __init__(self, image_paths, mask_paths, transform=None, mask_transform=None):
        '''
        '''
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform
        self.mask_transform = mask_transform
        # print('Dataloader initialized')

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load the image and mask
        # print('Loading image:', self.image_paths[idx])
        image = Image.open(self.image_paths[idx])
        mask = Image.open(self.mask_paths[idx])
        # print('Images loaded')

        # Convert both image and mask to tensors, apply any specified transformations
        if self.transform:
            image = self.transform(image)
            mask = self.mask_transform(mask)

        # print('Images transformed to tensors')

        # Ensure mask is a binary tensor (0 or 1 values)
        # mask = (mask > 0.5).float() # Threshold mask to binary if not already

        # print('Mask thresholded')

        return image, mask

In [3]:
transform = v2.Compose([
    v2.Resize((512, 512), interpolation=InterpolationMode.BILINEAR), # Resize to match model input
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ResNet normalization
])

mask_transform = v2.Compose([
    v2.Resize((512, 512), interpolation=InterpolationMode.BILINEAR), # Resize to match model input
    v2.Grayscale(num_output_channels=1), # ensure mask is grayscale
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True)
])

In [4]:
train_size = 1000
test_size = 100
batch_size = 4

# Get paths to all images and masks
train_image_paths = glob.glob('../data/synth_data/noisy/*.png')[:train_size]
train_mask_paths = glob.glob('../data/synth_data/ground_truth/*.png')[:train_size]

test_image_paths = glob.glob('../data/synth_data/noisy/*.png')[-test_size:]
test_mask_paths = glob.glob('../data/synth_data/ground_truth/*.png')[-test_size:]

# Create dataset and dataloader
train_dataset = DMCDataset(train_image_paths, train_mask_paths, transform=transform, mask_transform=mask_transform)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = DMCDataset(test_image_paths, test_mask_paths, transform=transform, mask_transform=mask_transform)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [5]:
# train dataset example
for i, (image, mask) in enumerate(train_dataset):
    print(image.shape, mask.shape)
    break

# train Dataloader example
for i, (image, mask) in enumerate(train_dataloader):
    print(image.shape, mask.shape)
    image = v2.ToPILImage()(image[0].squeeze(0))
    # image.show()

    mask = v2.ToPILImage()(mask[0].squeeze(0))
    # mask.show()
    break

torch.Size([3, 512, 512]) torch.Size([1, 512, 512])
torch.Size([4, 3, 512, 512]) torch.Size([4, 1, 512, 512])


In [6]:
# test dataset example
for i, (image, mask) in enumerate(test_dataset):
    print(image.shape, mask.shape)
    break

# test Dataloader example
for i, (image, mask) in enumerate(test_dataloader):
    print(image.shape, mask.shape)
    image = v2.ToPILImage()(image[0].squeeze(0))
    image.show()

    mask = v2.ToPILImage()(mask[0].squeeze(0))
    mask.show()
    break

torch.Size([3, 512, 512]) torch.Size([1, 512, 512])
torch.Size([4, 3, 512, 512]) torch.Size([4, 1, 512, 512])


# Modifying Architecture

In [7]:
class binarizer(nn.Module):
    def __init__(self):
        super(binarizer, self).__init__()

        # Load the pretrained ResNet-18 model
        resnet18 = models.resnet18(weights=ResNet18_Weights.DEFAULT)

        # Use the ResNet layers up to the last layer (remove the fully connected layer)
        self.encoder = nn.Sequential(*list(resnet18.children())[:-2]) # output size = 512 x 8 x 8

        # Define the decoder part to upsample back to 512 x 512
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid() # Use sigmoid for binary output in range [0, 1]
        )

    def forward(self, x):
        '''Pass input through the resnet-18 encoder and custom decoder'''
        # print('FORWARD A')
        x = self.encoder(x)
        # print('FORWARD B')
        x = self.decoder(x)
        return x

# Training Loop

In [8]:
# setup
model = binarizer().to(device)

# params
lr = 0.0001
num_epochs = 10

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [None]:
# Loss function - balances between two methods
class DiceLoss(nn.Module):
    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, output, target):
        smooth = 1.0  # To avoid division by zero
        output = output.contiguous().view(-1)
        target = target.contiguous().view(-1)
        intersection = (output * target).sum()
        dice = (2. * intersection + smooth) / (output.sum() + target.sum() + smooth)
        return 1 - dice

class BCEDiceLoss(nn.Module):
    def __init__(self, bce_weight=1, dice_weight=1):
        super(BCEDiceLoss, self).__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()

    def forward(self, output, target):
        bce_loss = self.bce(output, target)
        dice_loss = self.dice(output, target)
        return self.bce_weight * bce_loss + self.dice_weight * dice_loss

In [10]:
criterion = BCEDiceLoss(bce_weight=0.5, dice_weight=1.0)

consider:
- adding early stopping so that we can increase epochs to a huge amount and leave running
- use more of the dataset (90/10 train/test split?)
- tweaking params (weight DMC restoration more than background denoising?)
- replace nn.BCEWithLogitsLoss with BCELoss as sigmoid layer redundant
- Look into weight argument of BCELoss, can we weigh dark pixels more?

In [11]:
# Actual training loop
for epoch in range(num_epochs):
    # print('EPOCH', epoch)
    model.train()
    # print('ok')
    running_loss = 0.0

    count = 0
    for inputs, masks in train_dataloader: # consider renaming masks to targets
        inputs, masks = inputs.to(device), masks.to(device)

        # print('forward pass')
        # Forward pass
        outputs = model(inputs) # Add batch dimension to input

        # print('loss calculate')
        # Calculate loss
        loss = criterion(outputs, masks)

        # print('backwards pass')
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        count += batch_size
        print(f'{count}/{train_size}', end='\r')

        # if count >= 10:
            # break
    # break

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_dataloader):.4f}')

Epoch [1/10], Loss: 0.2312
Epoch [2/10], Loss: 0.1757
Epoch [3/10], Loss: 0.1752
Epoch [4/10], Loss: 0.1734
Epoch [5/10], Loss: 0.1721
Epoch [6/10], Loss: 0.1714
Epoch [7/10], Loss: 0.1711
Epoch [8/10], Loss: 0.1708
Epoch [9/10], Loss: 0.1705
Epoch [10/10], Loss: 0.1703


# Evaluation

In [26]:
# Assuming you have a test DataLoader providing (images, ground_truth_masks)
def evaluate_model(model, test_loader, threshold=0.5):
    model.eval()  # Set model to evaluation mode

    iou_total = 0
    dice_total = 0
    accuracy_total = 0
    count = 0

    with torch.no_grad():  # No need to compute gradients for evaluation
        for images, ground_truths in test_loader:
            # Move data to the same device as the model
            images = images.to(device)
            ground_truths = ground_truths.to(device)

            # Forward pass
            outputs = model(images)

            # Binarize the outputs
            predictions = (outputs > threshold).float() # Binarize predictions

            # predictions = outputs # inspect the output pre binarization

            # Calculate IoU
            intersection = (predictions * ground_truths).sum((1, 2, 3))
            union = (predictions + ground_truths).sum((1, 2, 3)) - intersection
            iou = (intersection / union).mean()  # Mean IoU for batch
            iou_total += iou.item()

            # Calculate Dice Coefficient
            dice = (2 * intersection / (predictions.sum((1, 2, 3)) + ground_truths.sum((1, 2, 3)))).mean()
            dice_total += dice.item()

            # Calculate Pixel Accuracy
            correct = (predictions == ground_truths).float().sum()
            accuracy = correct / ground_truths.numel()
            accuracy_total += accuracy.item()

            count += batch_size
            print(f'{count}/{test_size}', end='\r')

            if count >= 100:
                break

    # debug print
    print(f'Output range: min={outputs.min().item()}, max={outputs.max().item()}')
    print(f'Output range after binarization: min={predictions.min().item()}, max={predictions.max().item()}')

    # Average the metrics over the whole test set
    mean_iou = iou_total / count
    mean_dice = dice_total / count
    mean_accuracy = accuracy_total / count

    print(f'Mean IoU: {mean_iou:.4f}')
    print(f'Mean Dice Coefficient: {mean_dice:.4f}')
    print(f'Mean Pixel Accuracy: {mean_accuracy:.4f}')

    # show a few examples
    for i in range(1):
        image = v2.ToPILImage()(images[i].squeeze(0))
        mask = v2.ToPILImage()(ground_truths[i].squeeze(0))
        pred = v2.ToPILImage()(predictions[i].squeeze(0))
        # image.show()
        # time.sleep(1)
        # mask.show()
        # time.sleep(1)
        # pred.show()

        image.save('../figures/binarization/failure_noisy.png')
        mask.save('../figures/binarization/failure_ground_truth.png')
        pred.save('../figures/binarization/failure_prediction.png')

    return mean_iou, mean_dice, mean_accuracy

evaluate_model(model, test_dataloader, threshold=0.5)

Output range: min=3.484792654383367e-24, max=1.0
Output range after binarization: min=0.0, max=1.0
Mean IoU: 0.2467
Mean Dice Coefficient: 0.2483
Mean Pixel Accuracy: 0.2445


(0.24670674264431, 0.2483392882347107, 0.24445390701293945)

In [27]:
# Save the model
torch.save(model.state_dict(), 'model.pth')