In [None]:
import torch
from torch import nn

In [None]:
# real image segmentation U-Net model
# for image size of 288x288
class imgsegUnet(nn.Module):
    """Encoder-decoder segmentation network with three concatenated skip links.

    Layout:
      * four down stages, each two 3x3 conv + ReLU pairs followed by a
        max-pool (factors 2, 2, 2, 3; 24x reduction overall);
      * four up stages, each two 3x3 conv + ReLU pairs followed by a
        transposed conv that upsamples by 3, 2, 2, 2 respectively
        (``step_up_4`` has no ReLU after its transposed conv, the others do);
      * the pooled outputs of down stages 3, 2 and 1 are concatenated on the
        channel axis onto the outputs of up stages 4, 3 and 2;
      * a two-convolution head mapping 64 -> 16 -> ``out_channels``.

    Spatial sizes that are multiples of 24 (e.g. 288 = 24 * 12) come back out
    at the same size; other sizes may fail at a concatenation or produce a
    different output size.

    Args:
        in_channels: number of channels of the input image tensor.
        out_channels: number of channels produced by the prediction head.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def conv_relu(n_in, n_out):
            # 3x3 convolution (padding 1 keeps H and W) plus in-place ReLU.
            return (
                nn.Conv2d(n_in, n_out, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )

        def down_stage(n_in, n_out, pool):
            # conv(n_in -> n_out), conv(n_out -> n_out), then pool by `pool`.
            return nn.Sequential(
                *conv_relu(n_in, n_out),
                *conv_relu(n_out, n_out),
                nn.MaxPool2d(kernel_size=(pool, pool), stride=pool),
            )

        def up_stage(n_in, n_mid, n_out, scale, final_relu=True):
            # conv(n_in -> n_mid), conv(n_mid -> n_out), then upsample by
            # `scale` with a transposed conv whose kernel equals its stride.
            layers = [
                *conv_relu(n_in, n_mid),
                *conv_relu(n_mid, n_out),
                nn.ConvTranspose2d(n_out, n_out, kernel_size=scale, stride=scale),
            ]
            if final_relu:
                layers.append(nn.ReLU(inplace=True))
            return nn.Sequential(*layers)

        # Attribute names and positions inside each Sequential determine the
        # state_dict keys, so they must stay stable for saved checkpoints.
        self.step_down_1 = down_stage(in_channels, 64, pool=2)
        self.step_down_2 = down_stage(64, 64, pool=2)
        self.step_down_3 = down_stage(64, 128, pool=2)
        self.step_down_4 = down_stage(128, 64, pool=3)

        self.step_up_4 = up_stage(64, 64, 128, scale=3, final_relu=False)
        # Input widths below are "up-stage output + skip" channel counts.
        self.step_up_3 = up_stage(128 + 128, 128 + 128, 128, scale=2)
        self.step_up_2 = up_stage(128 + 64, 128 + 64, 64, scale=2)
        self.step_up_1 = up_stage(64 + 64, 64, 64, scale=2)

        # No activation between or after these two convolutions.
        self.class_predictions_block = nn.Sequential(
            nn.Conv2d(64, 16, kernel_size=3, padding=1),
            nn.Conv2d(16, out_channels, kernel_size=3, padding=1),
        )

        # One Dropout2d instance, applied at two points in forward().
        self.drop = nn.Dropout2d(p=0.1)

    def forward(self, x):
        """Run the network on ``x`` and return the prediction-head output.

        ``x`` is expected to be a batched 4-D tensor (N, in_channels, H, W),
        since the skip tensors are concatenated along dim 1. The result has
        ``out_channels`` channels and no final activation applied.
        """
        # Encoder: keep the pooled output of the first three stages as skips.
        enc_1 = self.step_down_1(x)
        enc_2 = self.step_down_2(enc_1)
        enc_3 = self.step_down_3(enc_2)

        # Deepest stage, with dropout on its output.
        decoded = self.step_up_4(self.drop(self.step_down_4(enc_3)))

        # Decoder: join each skip on the channel axis, then run the up stage.
        for stage, skip in (
            (self.step_up_3, enc_3),
            (self.step_up_2, enc_2),
            (self.step_up_1, enc_1),
        ):
            decoded = stage(torch.cat((decoded, skip), dim=1))

        # Second dropout right before the prediction head.
        return self.class_predictions_block(self.drop(decoded))

In [None]:
# real image segmentation U-Net model mach 2
# for image size of 288x288
class imgSegUnet_m2(nn.Module):
    """Second variant of the segmentation encoder-decoder.

    Structure: four down stages (two 3x3 conv + ReLU pairs, then max-pool by
    2, 2, 2 and 3), four up stages (two 3x3 conv + ReLU pairs, then a
    transposed conv upsampling by 3, 2, 2 and 2), and a 64 -> 16 ->
    ``out_channels`` convolutional head. In this variant the deepest down
    stage and the first up stage use 128 channels throughout.

    The pooled outputs of down stages 3, 2 and 1 are concatenated along the
    channel axis onto the outputs of up stages 4, 3 and 2. Total downsampling
    is 24x, so sizes that are multiples of 24 (such as 288) are restored
    exactly; other sizes may fail at a concatenation or change output size.

    Args:
        in_channels: channel count of the input tensor.
        out_channels: channel count produced by the prediction head.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Attribute names and Sequential positions define the state_dict
        # keys; keep them stable so saved weights continue to load.
        self.step_down_1 = self._make_down_stage(in_channels, 64, pool=2)
        self.step_down_2 = self._make_down_stage(64, 64, pool=2)
        self.step_down_3 = self._make_down_stage(64, 128, pool=2)
        self.step_down_4 = self._make_down_stage(128, 128, pool=3)

        # First up stage has no skip input and no ReLU after the upsampling.
        self.step_up_4 = self._make_up_stage(128, 128, 128, factor=3, relu_after=False)

        # Remaining up stages take "previous up output + skip" channels.
        self.step_up_3 = self._make_up_stage(128 + 128, 128 + 128, 128, factor=2)
        self.step_up_2 = self._make_up_stage(128 + 64, 128 + 64, 64, factor=2)
        self.step_up_1 = self._make_up_stage(64 + 64, 64, 64, factor=2)

        # Two plain convolutions, no activation in between or afterwards.
        self.class_predictions_block = nn.Sequential(
            nn.Conv2d(64, 16, kernel_size=3, padding=1),
            nn.Conv2d(16, out_channels, kernel_size=3, padding=1),
        )

        # Shared dropout module, used twice in forward().
        self.drop = nn.Dropout2d(p=0.1)

    @staticmethod
    def _make_down_stage(c_in, c_out, pool):
        """Build conv(c_in->c_out)+ReLU, conv(c_out->c_out)+ReLU, max-pool."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(c_out, c_out, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(pool, pool), stride=pool),
        )

    @staticmethod
    def _make_up_stage(c_in, c_mid, c_out, factor, relu_after=True):
        """Build two conv+ReLU pairs and a transposed conv upsampling by `factor`."""
        layers = [
            nn.Conv2d(c_in, c_mid, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(c_mid, c_out, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # kernel == stride, so each input pixel maps to its own patch.
            nn.ConvTranspose2d(c_out, c_out, kernel_size=factor, stride=factor),
        ]
        if relu_after:
            layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the head output for a batched 4-D input (N, C, H, W).

        The result has ``out_channels`` channels; no activation is applied
        after the final convolution.
        """
        features = x
        skips = []

        # First three down stages; remember each pooled output for later.
        for down in (self.step_down_1, self.step_down_2, self.step_down_3):
            features = down(features)
            skips.append(features)

        # Deepest stage followed by dropout, then the first upsampling.
        features = self.drop(self.step_down_4(features))
        features = self.step_up_4(features)

        # Pop skips deepest-first and join them on the channel axis.
        for up in (self.step_up_3, self.step_up_2, self.step_up_1):
            features = up(torch.cat((features, skips.pop()), dim=1))

        # Dropout once more before the prediction head.
        features = self.drop(features)
        return self.class_predictions_block(features)

In [None]:
# Not the right implementation: the final layer always outputs 1 channel, so the 'out_channels' parameter never sizes the output.
# Kept only because one trained model depends on this class, and it works there because that model needs out_channels=1.

class image_seg_UNet(nn.Module):
    """Encoder-decoder segmentation network with a fixed 1-channel output.

    Structure: four down stages (two 3x3 conv + ReLU pairs, then max-pool by
    2, 2, 2 and 3), four up stages (two 3x3 conv + ReLU pairs, then a
    transposed conv upsampling by 3, 2, 2 and 2; ``step_up_4`` has no ReLU
    after its transposed conv), and a 64 -> 16 -> 1 convolutional head.
    The pooled outputs of down stages 3, 2 and 1 are concatenated along the
    channel axis onto the outputs of up stages 4, 3 and 2. No dropout is used.

    Total downsampling is 24x, so spatial sizes that are multiples of 24 are
    restored exactly; other sizes may fail at a concatenation or change the
    output size.

    Args:
        in_channels: channel count of the input tensor.
        out_channels: accepted for signature compatibility only. The head
            always emits 1 channel; a ``UserWarning`` is issued when a value
            other than 1 is passed so the mismatch is not silent.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        if out_channels != 1:
            # The 1-channel head is kept on purpose (existing trained weights
            # rely on it), but tell the caller instead of ignoring the value.
            import warnings

            warnings.warn(
                "image_seg_UNet ignores out_channels and always predicts 1 "
                "channel (got out_channels=%r)." % (out_channels,),
                UserWarning,
                stacklevel=2,
            )

        def conv_relu(n_in, n_out):
            # 3x3 convolution (padding 1 keeps H and W) plus in-place ReLU.
            return (
                nn.Conv2d(n_in, n_out, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )

        def down_stage(n_in, n_out, pool):
            # conv(n_in -> n_out), conv(n_out -> n_out), then pool by `pool`.
            return nn.Sequential(
                *conv_relu(n_in, n_out),
                *conv_relu(n_out, n_out),
                nn.MaxPool2d(kernel_size=(pool, pool), stride=pool),
            )

        def up_stage(n_in, n_mid, n_out, scale, final_relu=True):
            # conv(n_in -> n_mid), conv(n_mid -> n_out), then upsample by
            # `scale` with a transposed conv whose kernel equals its stride.
            layers = [
                *conv_relu(n_in, n_mid),
                *conv_relu(n_mid, n_out),
                nn.ConvTranspose2d(n_out, n_out, kernel_size=scale, stride=scale),
            ]
            if final_relu:
                layers.append(nn.ReLU(inplace=True))
            return nn.Sequential(*layers)

        # Attribute names and Sequential positions define the state_dict
        # keys; they are kept stable so existing checkpoints still load.
        self.step_down_1 = down_stage(in_channels, 64, pool=2)
        self.step_down_2 = down_stage(64, 64, pool=2)
        self.step_down_3 = down_stage(64, 128, pool=2)
        self.step_down_4 = down_stage(128, 64, pool=3)

        self.step_up_4 = up_stage(64, 64, 128, scale=3, final_relu=False)
        # Input widths below are "up-stage output + skip" channel counts.
        self.step_up_3 = up_stage(128 + 128, 128 + 128, 128, scale=2)
        self.step_up_2 = up_stage(128 + 64, 128 + 64, 64, scale=2)
        self.step_up_1 = up_stage(64 + 64, 64, 64, scale=2)

        # Head is hard-wired to a single output channel (see class docstring).
        self.class_predictions_block = nn.Sequential(
            nn.Conv2d(64, 16, kernel_size=3, padding=1),
            nn.Conv2d(16, 1, kernel_size=3, padding=1),
        )

    def forward(self, x):
        """Return the 1-channel head output for a batched input (N, C, H, W).

        No activation is applied after the final convolution.
        """
        # Encoder: keep the pooled output of the first three stages as skips.
        skip_1 = self.step_down_1(x)
        skip_2 = self.step_down_2(skip_1)
        skip_3 = self.step_down_3(skip_2)

        # Deepest stage and first upsampling.
        output = self.step_up_4(self.step_down_4(skip_3))

        # Decoder: join each skip on the channel axis, then run the up stage.
        for stage, skip in (
            (self.step_up_3, skip_3),
            (self.step_up_2, skip_2),
            (self.step_up_1, skip_1),
        ):
            output = stage(torch.cat((output, skip), dim=1))

        return self.class_predictions_block(output)