In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> BatchNorm -> ReLU) stages.

    Each convolution uses ``padding=1`` with a 3x3 kernel, so height and width
    are preserved; only the channel count changes, from ``in_channels`` (first
    stage input) to ``out_channels`` (both stage outputs).

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by both stages.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv-BN-ReLU stages and return the result."""
        return self.double_conv(x)

class UNet(nn.Module):
    """U-Net encoder/decoder using average pooling for downsampling.

    Encoder: ``inc`` followed by four (2x2 AvgPool -> DoubleConv) stages that
    widen the features 64 -> 128 -> 256 -> 512 -> 1024 channels while halving
    the spatial size each time. Decoder: four stages that upsample with a
    stride-2 transposed convolution, concatenate the matching encoder output
    (skip connection) and refine with a DoubleConv. A 1x1 convolution maps to
    ``out_channels`` and a sigmoid squashes the output into (0, 1).

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super(UNet, self).__init__()

        self.inc = DoubleConv(in_channels, 64)
        self.down1 = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2),
            DoubleConv(64, 128)
        )
        self.down2 = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2),
            DoubleConv(128, 256)
        )
        self.down3 = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2),
            DoubleConv(256, 512)
        )
        self.down4 = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2),
            DoubleConv(512, 1024)
        )

        # Each up block is (ConvTranspose2d, DoubleConv). forward() applies the
        # two halves separately so the skip connection is concatenated between
        # them (upsampled C/2 channels + skip C/2 channels = C into DoubleConv).
        # The Sequential container is kept so parameter names stay up<k>.0 / up<k>.1.
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2),
            DoubleConv(1024, 512)
        )
        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            DoubleConv(512, 256)
        )
        self.up3 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            DoubleConv(256, 128)
        )
        self.up4 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            DoubleConv(128, 64)
        )

        self.outc = nn.Conv2d(64, out_channels, kernel_size=1)

    @staticmethod
    def _match_size(x, ref):
        """Zero-pad ``x`` (N, C, H, W) so its H and W equal those of ``ref``.

        Average pooling floors odd sizes, so a transposed-conv output can be
        up to one pixel smaller than its skip tensor; the padding is split as
        evenly as possible between the two sides of each axis.
        """
        diff_y = ref.size(2) - x.size(2)
        diff_x = ref.size(3) - x.size(3)
        if diff_x == 0 and diff_y == 0:
            return x
        # F.pad order for the last two dims: [left, right, top, bottom].
        return F.pad(x, [diff_x // 2, diff_x - diff_x // 2,
                         diff_y // 2, diff_y - diff_y // 2])

    def _up(self, block, x, skip):
        """Upsample ``x`` with ``block[0]``, join it with ``skip``, refine with ``block[1]``."""
        upsample, conv = block[0], block[1]
        x = self._match_size(upsample(x), skip)
        return conv(torch.cat([x, skip], dim=1))

    def forward(self, x):
        """Map an (N, in_channels, H, W) tensor to (N, out_channels, H, W) in (0, 1).

        H and W must be at least 16 so that four 2x poolings leave a non-empty
        bottleneck; they need not be multiples of 16.
        """
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        x = self._up(self.up1, x5, x4)
        x = self._up(self.up2, x, x3)
        x = self._up(self.up3, x, x2)
        x = self._up(self.up4, x, x1)

        # torch.sigmoid is the functional form; nn.Sigmoid is a module class
        # and cannot be called with a tensor as its constructor argument.
        final = torch.sigmoid(self.outc(x))
        return final