<a href="https://colab.research.google.com/github/tungnhitran/Anomaly-Detection-in-Smart-Manufacturing/blob/main/UNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [19]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as F


In [20]:
class Encoder_block(nn.Module):
    """Two stacked ``Conv2d -> ReLU`` stages with ``'same'`` padding.

    Used by ``UNet`` for both its contracting and expansive paths. Spatial
    size is preserved (stride-1 convolutions with ``padding='same'``); only
    the channel count changes, from ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
        kernel_size: Kernel size shared by both convolutions.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        # Build the convolutions in the same order as before so parameter
        # initialisation consumes the RNG identically.
        first_conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding='same')
        second_conv = nn.Conv2d(out_channels, out_channels, kernel_size, padding='same')
        # Attribute name and layer positions (0, 1, 2, 3) define state_dict keys.
        self.application = nn.Sequential(
            first_conv,
            nn.ReLU(inplace=True),
            second_conv,
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply conv -> ReLU -> conv -> ReLU to ``x``."""
        features = self.application(x)
        return features

In [21]:
class UNet(nn.Module):
    """U-Net encoder/decoder producing a per-pixel output map.

    Four ``Encoder_block`` + max-pool stages contract the input, a fifth
    block forms the bottleneck, and four transposed-conv + ``Encoder_block``
    stages expand it again. Each expansive stage concatenates the matching
    contracting-stage feature map (skip connection) with its upsampled input.
    All 3x3 convolutions use ``'same'`` padding, so the output has the same
    height and width as the input. When a spatial dimension is odd at some
    pooling step, the upsampled map is one pixel short; it is then
    bilinearly resized to the skip connection's size before concatenation.

    Input is ``(N, in_channels, H, W)`` (channels on dim 1, as required by the
    ``torch.cat(..., dim=1)`` skip concatenation); output is
    ``(N, out_channels, H, W)``. H and W must each be at least 16 so that the
    fourth pooling step still leaves a non-empty feature map.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the final 1x1 convolution.
        base_channels: Width of the first stage; each deeper stage doubles it
            (the default 64 gives 64-128-256-512-1024). Keyword added for
            lighter variants; the default reproduces the original network.
    """

    def __init__(self, in_channels=1, out_channels=2, base_channels=64):
        super().__init__()
        c1, c2, c3, c4, c5 = (base_channels * 2 ** i for i in range(5))

        # Contracting path. Attribute names and registration order are kept
        # unchanged so state_dict keys and default initialisation match.
        self.down_b1 = Encoder_block(in_channels, out_channels=c1, kernel_size=3)
        self.down_b2 = Encoder_block(c1, c2, 3)
        self.down_b3 = Encoder_block(c2, c3, 3)
        self.down_b4 = Encoder_block(c3, c4, 3)
        self.down_b5 = Encoder_block(c4, c5, 3)
        self.pooling = nn.MaxPool2d(kernel_size=2, stride=2)

        # Expansive path: each transposed conv doubles H/W and halves channels.
        self.up_deconv1 = nn.ConvTranspose2d(c5, c4, kernel_size=2, stride=2)
        self.up_deconv2 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2)
        self.up_deconv3 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.up_deconv4 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)

        # Input width is doubled by concatenation with the skip connection.
        self.up_b1 = Encoder_block(c5, c4, 3)
        self.up_b2 = Encoder_block(c4, c3, 3)
        self.up_b3 = Encoder_block(c3, c2, 3)
        self.up_b4 = Encoder_block(c2, c1, 3)

        self.out = nn.Conv2d(c1, out_channels, kernel_size=1, padding='same')

    def forward(self, x):
        """Run the contracting path, bottleneck, expansive path and 1x1 head."""
        skips = []
        # Contracting path: keep each stage's pre-pooling output for its skip.
        for block in (self.down_b1, self.down_b2, self.down_b3, self.down_b4):
            x = block(x)
            skips.append(x)
            x = self.pooling(x)

        x = self.down_b5(x)

        # Expansive path: skips are consumed deepest-first (LIFO).
        stages = (
            (self.up_deconv1, self.up_b1),
            (self.up_deconv2, self.up_b2),
            (self.up_deconv3, self.up_b3),
            (self.up_deconv4, self.up_b4),
        )
        for deconv, block in stages:
            x = self._up_step(x, skips.pop(), deconv, block)

        return self.out(x)

    @staticmethod
    def _up_step(x, skip, deconv, block):
        """Upsample ``x``, match ``skip``'s H/W, concatenate ``[skip, x]``, refine."""
        x = deconv(x)
        if x.shape[-2:] != skip.shape[-2:]:
            # A stride-2 pool floors odd sizes, so 2 * floor(s / 2) can be one
            # pixel short of the skip; resize back with bilinear interpolation
            # (core PyTorch op, no torchvision image-transform dependency).
            x = nn.functional.interpolate(
                x, size=tuple(skip.shape[-2:]), mode='bilinear', align_corners=False
            )
        return block(torch.cat((skip, x), dim=1))




In [22]:
from torchsummary import summary

model = UNet()
# UNet() lives on the CPU. torchsummary's summary() defaults to device="cuda",
# which builds CUDA probe inputs whenever a GPU is available and then fails
# against the CPU weights. Ask for CPU inputs explicitly so this cell works on
# both CPU and GPU runtimes.
summary(model, (1, 572, 572), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 572, 572]             640
              ReLU-2         [-1, 64, 572, 572]               0
            Conv2d-3         [-1, 64, 572, 572]          36,928
              ReLU-4         [-1, 64, 572, 572]               0
     Encoder_block-5         [-1, 64, 572, 572]               0
         MaxPool2d-6         [-1, 64, 286, 286]               0
            Conv2d-7        [-1, 128, 286, 286]          73,856
              ReLU-8        [-1, 128, 286, 286]               0
            Conv2d-9        [-1, 128, 286, 286]         147,584
             ReLU-10        [-1, 128, 286, 286]               0
    Encoder_block-11        [-1, 128, 286, 286]               0
        MaxPool2d-12        [-1, 128, 143, 143]               0
           Conv2d-13        [-1, 256, 143, 143]         295,168
             ReLU-14        [-1, 256, 1