In [None]:
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
import random 
from PIL import Image

# Pytorch
import torch
import torch.nn as nn 
import torchvision.transforms as tf
import torchvision.transforms.functional as tf_f
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Dataset, random_split
from torch.utils.data.sampler import SubsetRandomSampler

# UNET

- Encoder Klasse: channels_in = 1, channels_out=64 
- Decoder Klasse: channels= 64 
- Bottleneck Klasse: eigene Klasse, damit man sie leichter auswechseln kann (nicht hartcodiert)

In [None]:
def convBlock(in_channels, out_channels, kernel_size=3, padding=1):
    """Build a double-convolution block.

    The block is two consecutive ``Conv2d -> BatchNorm2d -> ReLU(inplace)``
    stages. The first stage maps ``in_channels`` to ``out_channels``; the
    second keeps ``out_channels``. Both convolutions share ``kernel_size``
    and ``padding``.

    Returns:
        nn.Sequential: the six layers, in the order listed above.
    """
    layers = []
    stage_in = in_channels
    for _ in range(2):
        layers += [
            nn.Conv2d(stage_in, out_channels, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Only the first convolution changes the channel count.
        stage_in = out_channels
    return nn.Sequential(*layers)

In [None]:
#TODO maybe add depth 

class Bottleneck(nn.Module):
    """Bottleneck stage, kept as its own module so it can be swapped easily.

    It wraps a single double-convolution block that maps ``in_channels``
    to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = convBlock(in_channels=in_channels, out_channels=out_channels)

    def forward(self, x):
        """Run ``x`` through the double-convolution block and return the result."""
        return self.conv(x)


In [None]:
class Encoder(nn.Module):
    """Contracting path: four double-conv stages with 2x2 max pooling.

    The stages output ``out_channels``, ``2x``, ``4x`` and ``8x`` that many
    channels respectively.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        base = out_channels
        self.e1 = convBlock(in_channels, base)
        self.e2 = convBlock(base, base * 2)
        self.e3 = convBlock(base * 2, base * 4)
        self.e4 = convBlock(base * 4, base * 8)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return ``(p, x1, x2, x3, x4)``.

        ``x1`` to ``x4`` are the un-pooled stage outputs used as skip
        connections; ``p`` is the pooled output of the last stage and feeds
        the bottleneck.
        """
        skips = []
        for stage in (self.e1, self.e2, self.e3, self.e4):
            x = stage(x)
            skips.append(x)      # keep the un-pooled feature map for the decoder
            x = self.pool(x)     # halve the spatial size before the next stage
        return (x, *skips)

In [None]:
class DecoderInterpolated(nn.Module):
    """Expanding path that upsamples with bilinear interpolation.

    Bilinear interpolation does not change the channel count, so after each
    upsampling step the tensor is concatenated with its skip connection at
    full width. Every conv block must therefore accept
    ``previous_channels + skip_channels`` input channels:

        d4: 16c + 8c -> 8c
        d3:  8c + 4c -> 4c
        d2:  4c + 2c -> 2c
        d1:  2c +  c ->  c

    where ``c`` is ``channels``.

    Args:
        classes: number of output channels of the final 1x1 convolution.
        channels: base channel width ``c``; the input ``x`` to ``forward`` is
            expected to carry ``16 * c`` channels and the skips ``c``, ``2c``,
            ``4c`` and ``8c`` channels.
    """

    def __init__(self, classes: int, channels: int = 64):
        super().__init__()
        # d4 -> d3
        self.d4 = convBlock(channels * 16 + channels * 8, channels * 8)
        # d3 -> d2
        self.d3 = convBlock(channels * 8 + channels * 4, channels * 4)
        # d2 -> d1
        self.d2 = convBlock(channels * 4 + channels * 2, channels * 2)
        # d1 -> output
        self.d1 = convBlock(channels * 2 + channels, channels)

        self.output = nn.Conv2d(channels, classes, kernel_size=1)

    @staticmethod
    def _upsample_and_merge(x, skip):
        """Resize ``x`` to the spatial size of ``skip`` and concatenate on channels."""
        # Targeting the skip's size (instead of a fixed scale_factor=2) gives the
        # same output size when the sizes are exact multiples, and still lines up
        # when an odd size was floored by the encoder's pooling.
        x = nn.functional.interpolate(
            x, size=skip.shape[-2:], mode='bilinear', align_corners=True
        )
        return torch.cat((x, skip), dim=1)

    def forward(self, x, skip1, skip2, skip3, skip4):
        """Decode ``x`` using the skips from deepest (``skip4``) to shallowest (``skip1``).

        Returns:
            Tensor with ``classes`` channels produced by the final 1x1 convolution.
        """
        stages = (
            (self.d4, skip4),
            (self.d3, skip3),
            (self.d2, skip2),
            (self.d1, skip1),
        )
        for refine, skip in stages:
            x = refine(self._upsample_and_merge(x, skip))
        return self.output(x)

In [None]:
class Decoder(nn.Module):
    """Expanding path that upsamples with transposed convolutions.

    Each stage halves the channel count with a stride-2 ``ConvTranspose2d``,
    concatenates the matching skip connection on the channel axis (restoring
    the pre-upsampling width) and refines the result with a double-conv block.

    Args:
        classes: number of output channels of the final 1x1 convolution.
        channels: base channel width; the stages work at 16x, 8x, 4x, 2x and
            1x this value.
    """

    def __init__(self, classes: int, channels: int = 64):
        super().__init__()
        c1, c2, c4, c8, c16 = (channels * factor for factor in (1, 2, 4, 8, 16))

        # d4 -> d3
        self.d4 = convBlock(c16, c8)
        self.up4 = nn.ConvTranspose2d(c16, c8, kernel_size=2, stride=2)
        # d3 -> d2
        self.d3 = convBlock(c8, c4)
        self.up3 = nn.ConvTranspose2d(c8, c4, kernel_size=2, stride=2)
        # d2 -> d1
        self.d2 = convBlock(c4, c2)
        self.up2 = nn.ConvTranspose2d(c4, c2, kernel_size=2, stride=2)
        # d1 -> output
        self.d1 = convBlock(c2, c1)
        self.up1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)

        self.output = nn.Conv2d(c1, classes, kernel_size=1)

    def forward(self, x, skip1, skip2, skip3, skip4):
        """Decode ``x`` using the skips from deepest (``skip4``) to shallowest (``skip1``).

        Returns:
            Tensor with ``classes`` channels produced by the final 1x1 convolution.
        """
        stages = (
            (self.up4, self.d4, skip4),
            (self.up3, self.d3, skip3),
            (self.up2, self.d2, skip2),
            (self.up1, self.d1, skip1),
        )
        for upsample, refine, skip in stages:
            merged = torch.cat((upsample(x), skip), dim=1)
            x = refine(merged)
        return self.output(x)

In [None]:
class UNet(nn.Module):
    """U-Net assembled from an encoder, a bottleneck and a decoder.

    Args:
        in_channels: channel count of the input passed to the encoder.
        classes: channel count of the decoder's final 1x1 convolution.
        channels: base channel width; the bottleneck maps ``8 * channels``
            to ``16 * channels``.
        bilinear: if true use ``DecoderInterpolated``, otherwise ``Decoder``.
    """

    def __init__(self, in_channels=3, classes=1, channels=64, bilinear=False):
        super().__init__()
        self.encoder = Encoder(in_channels, channels)
        self.bottleneck = Bottleneck(channels * 8, channels * 16)
        if bilinear:
            self.decoder = DecoderInterpolated(classes, channels)
        else:
            self.decoder = Decoder(classes, channels)

    def forward(self, x):
        """Encode ``x``, run the bottleneck and decode with the four skip tensors."""
        pooled, *skips = self.encoder(x)
        return self.decoder(self.bottleneck(pooled), *skips)

Preprocessing

In [None]:
class Preprocessing:
    """Seed the RNGs, load an image folder as grayscale tensors and split it.

    Args:
        path: root directory handed to ``ImageFolder``.
        size: target size passed to ``tf.Resize``.
        seed: seed for ``random``, NumPy, PyTorch and the split generator.
        normalize: if true, ``tf.Normalize([0.5], [0.5])`` is appended to the
            transform pipeline.

    Attributes:
        path, size, seed: the constructor arguments.
        dataset: the full ``ImageFolder`` dataset.
        train_ds, val_ds, test_ds: the subsets returned by ``split_dataset``.
        g: the ``torch.Generator`` used for the most recent split.
    """

    def __init__(self, path, size=(256, 256), seed=42, normalize=True):
        self.path = path
        self.size = size
        self.seed = seed

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        # GPU
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True

        # Split dataset; keep the subsets on the instance so callers can use
        # them, and forward the constructor's seed so it is actually honoured.
        self.dataset = self.get_dataset(normalize)
        self.train_ds, self.val_ds, self.test_ds = self.split_dataset(
            self.dataset, seed=seed
        )

        # TODO: masking datasets

    def split_dataset(self, dataset, ratio=(0.75, 0.15), seed=42):
        """Randomly split ``dataset`` into train / validation / test subsets.

        Args:
            dataset: any sized dataset accepted by ``random_split``.
            ratio: ``(train_fraction, val_fraction)``; the test subset receives
                the remaining ``1 - train_fraction - val_fraction``.
            seed: seed for the generator that drives the shuffle.

        Returns:
            The three subsets produced by ``random_split``, in the order
            train, validation, test.
        """
        test_fraction = 1 - ratio[0] - ratio[1]
        # torch.Generator() only accepts `device`; seeding happens through the
        # manual_seed() method, which returns the generator itself.
        self.g = torch.Generator().manual_seed(seed)
        return random_split(
            dataset, [ratio[0], ratio[1], test_fraction], generator=self.g
        )

    def get_dataset(self, normalize):
        """Return an ``ImageFolder`` over ``self.path`` with the image transform attached.

        The pipeline is grayscale (1 channel) -> resize to ``self.size`` ->
        ``ToTensor``, optionally followed by ``Normalize([0.5], [0.5])``.
        """
        # Transformations applied to every image
        steps = [
            tf.Grayscale(num_output_channels=1),
            tf.Resize(self.size),
            tf.ToTensor(),
        ]
        # Normalize if requested
        if normalize:
            steps.append(tf.Normalize([0.5], [0.5]))
        return ImageFolder(self.path, transform=tf.Compose(steps))