In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [None]:
# CIFAR-100 splits stored under the local 'data' directory (downloaded there
# if missing) with every image converted by ToTensor().
# The generator yields the training split first (train=True), then the
# held-out test split (train=False); each gets its own ToTensor instance.
training_data, test_data = (
    datasets.CIFAR100(root='data', train=is_train, download=True, transform=ToTensor())
    for is_train in (True, False)
)


In [None]:
# Shared, stateless activation. Kept at module level so any other cell that
# already refers to `relu` keeps working.
relu = nn.ReLU()


class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by a ReLU.

    Both convolutions use ``padding=1``, so the height and width of the input
    are preserved and only the channel count changes. Keeping the spatial size
    is what lets encoder outputs be concatenated with upsampled decoder
    tensors without cropping.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by each convolution.
    """

    def __init__(self, in_channels, out_channels):
        # nn.Module has to be initialised before sub-modules are assigned;
        # otherwise the attribute assignments below raise AttributeError.
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # The second convolution consumes the output of the first one, so its
        # input width is out_channels (not in_channels).
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Return ``relu(conv2(relu(conv1(x))))``.

        Args:
            x: tensor of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, out_channels, H, W).
        """
        x = relu(self.conv1(x))
        x = relu(self.conv2(x))
        return x


In [None]:
class DecoderConv(nn.Module):
    """One decoder stage: upsample, merge with a skip tensor, then convolve.

    Args:
        in_channels: channels of the tensor coming from the deeper stage.
        out_channels: channels produced by this stage. The skip tensor passed
            to ``forward`` is expected to carry ``in_channels - out_channels``
            channels (``out_channels`` in the usual halving layout), so that
            the concatenated tensor has ``in_channels`` channels again.
    """

    def __init__(self, in_channels, out_channels):
        # Required before any sub-module can be registered on an nn.Module.
        super().__init__()
        # output_size = (input_size - 1) * stride + kernel_size - 2 * padding
        # With kernel_size=2, stride=2 and no padding this is 2 * input_size.
        self.up_conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        # The convolution runs on the concatenation of the upsampled tensor
        # and the skip tensor, hence in_channels (not out_channels) inputs.
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x, x1):
        """Upsample ``x``, concatenate it with ``x1`` on the channel axis, convolve.

        Args:
            x: tensor from the deeper stage, shape (N, in_channels, H, W).
            x1: skip tensor from the encoder; its batch and spatial sizes must
                match the upsampled ``x``, i.e. (N, C_skip, 2H, 2W).

        Returns:
            The output of ``self.conv`` applied to the merged tensor.
        """
        y1 = self.up_conv(x)
        # dim=1 is the channel axis of an (N, C, H, W) tensor; the default
        # dim=0 would stack the two tensors along the batch axis instead.
        y2 = torch.cat((y1, x1), dim=1)
        return self.conv(y2)

In [None]:
class UNet(nn.Module):
    """U-Net style encoder/decoder that maps an image to a multi-channel image.

    The encoder has four DoubleConv stages (32, 64, 128, 256 channels), each
    followed by a 2x2 max-pool. A 512-channel bottleneck is followed by four
    DecoderConv stages that upsample and merge the matching encoder output.
    A final 3x3 convolution produces the output image.

    Args:
        input_channels: number of channels of the input image (default 1).
        output_channels: number of channels of the reconstructed image
            (default 3).

    Note:
        The input is pooled four times, so its height and width should be
        divisible by 16 for the skip connections to line up (assuming the
        convolution blocks preserve spatial size).
    """

    def __init__(self, input_channels=1, output_channels=3):
        super().__init__()

        # non trainable layers
        self.pool = nn.MaxPool2d(kernel_size=2)

        # trainable layers
        ## encoder
        self.en_conv1 = DoubleConv(input_channels, 32)
        self.en_conv2 = DoubleConv(32, 64)
        self.en_conv3 = DoubleConv(64, 128)
        self.en_conv4 = DoubleConv(128, 256)

        ## bottleneck
        self.conv = DoubleConv(256, 512)

        ## decoder
        self.de_conv4 = DecoderConv(512, 256)
        self.de_conv3 = DecoderConv(256, 128)
        self.de_conv2 = DecoderConv(128, 64)
        self.de_conv1 = DecoderConv(64, 32)

        ## reconstruct layer
        # padding=1 keeps the output the same height/width as the input.
        self.reconstruct = nn.Conv2d(32, output_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Run the full encoder -> bottleneck -> decoder pass.

        Args:
            x: tensor of shape (N, input_channels, H, W).

        Returns:
            The reconstructed image produced by ``self.reconstruct``.
        """
        # encoder inference: keep each pre-pool output for the skip connections
        e1 = self.en_conv1(x)
        p1 = self.pool(e1)

        e2 = self.en_conv2(p1)
        p2 = self.pool(e2)

        e3 = self.en_conv3(p2)
        p3 = self.pool(e3)

        # fourth stage uses its own block (en_conv4), not en_conv3 again
        e4 = self.en_conv4(p3)
        p4 = self.pool(e4)

        # bottleneck inference: runs on the pooled tensor so that the first
        # decoder upsampling brings it back to e4's spatial size
        b = self.conv(p4)

        # decoder inference
        d4 = self.de_conv4(b, e4)
        d3 = self.de_conv3(d4, e3)
        d2 = self.de_conv2(d3, e2)
        d1 = self.de_conv1(d2, e1)

        # reconstruct image
        colored_image = self.reconstruct(d1)
        return colored_image