In [2]:
import torch
import torch.nn as nn

class UNet(nn.Module):
    """U-Net encoder/decoder for dense per-pixel prediction.

    The contracting path applies a double 3x3 conv block per level and
    max-pools between levels. The expansive path upsamples with a 2x2
    transposed conv, concatenates the skip feature map saved (before
    pooling) at the same level, and applies another double conv block.
    All 3x3 convs use ``padding=1``, so the output has the same height and
    width as the input.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the final 1x1 conv
            (presumably one per class for segmentation).
        features: Channel widths of the contracting levels, shallowest
            first. The bottleneck uses ``2 * features[-1]`` channels.
            Defaults to ``(64, 128, 256, 512)``.

    Raises:
        ValueError: If ``features`` is empty.

    Input height and width must each be at least ``2 ** len(features)``
    (16 by default); otherwise pooling shrinks a map to zero size. Sizes
    that are not multiples of that value are handled by zero-padding each
    upsampled map to its skip connection's size.
    """

    def __init__(self, in_channels, out_channels, features=(64, 128, 256, 512)):
        super().__init__()
        features = tuple(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        # Contracting Part: pooling is applied in forward() rather than inside
        # the blocks, so each level's pre-pooling output can be kept as a skip.
        self.pool = nn.MaxPool2d(2)
        self.encoder = nn.ModuleList()
        for channels in features:
            self.encoder.append(
                self.conv_block(in_channels, channels, has_maxpooling=False)
            )
            in_channels = channels

        # Connecting Part (bottleneck): doubles the deepest width, no pooling.
        bottleneck_channels = 2 * features[-1]
        self.connect = self.conv_block(
            features[-1], bottleneck_channels, has_maxpooling=False
        )

        # Expansive Part: each level upsamples to the skip's width c,
        # concatenates the c-channel skip (2c channels total), then convolves
        # back down to c.
        self.upconvs = nn.ModuleList()
        self.decoder = nn.ModuleList()
        in_channels = bottleneck_channels
        for channels in reversed(features):
            self.upconvs.append(self.upconv_block(in_channels, channels))
            self.decoder.append(
                self.conv_block(2 * channels, channels, has_maxpooling=False)
            )
            in_channels = channels

        # Final Part: 1x1 conv from the shallowest width to the output channels.
        self.final = nn.Conv2d(features[0], out_channels, 1)

    def conv_block(self, in_channels, out_channels, has_maxpooling=True):
        """Return two 3x3 conv + ReLU layers, optionally followed by 2x2 max-pool.

        ``padding=1`` keeps height and width unchanged through the convs;
        the optional ``MaxPool2d(2)`` halves them.
        """
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(),
        ]
        if has_maxpooling:
            layers.append(nn.MaxPool2d(2))
        return nn.Sequential(*layers)

    def upconv_block(self, in_channels, out_channels):
        """Return a 2x2, stride-2 transposed conv that doubles height and width.

        This is only the upsampling step. The double conv that follows must
        see the concatenation with the skip connection, so ``forward`` applies
        it separately.
        """
        return nn.ConvTranspose2d(in_channels, out_channels, 2, stride=2)

    def forward(self, x):
        """Map an (N, in_channels, H, W) batch to (N, out_channels, H, W)."""
        # Contracting Part: keep each level's output *before* pooling.
        skips = []
        for encoder in self.encoder:
            x = encoder(x)
            skips.append(x)
            x = self.pool(x)

        # Connecting Part
        x = self.connect(x)

        # Expansive Part
        for upconv, decoder, skip in zip(self.upconvs, self.decoder, reversed(skips)):
            x = upconv(x)
            # Pooling floors odd sizes, so the upsampled map can be one
            # row/column short of the skip; zero-pad it (never negative).
            dh = skip.shape[-2] - x.shape[-2]
            dw = skip.shape[-1] - x.shape[-1]
            if dh or dw:
                x = nn.functional.pad(
                    x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2]
                )
            x = torch.cat((skip, x), dim=1)
            x = decoder(x)

        # Final Part
        return self.final(x)

# Binary segmentation setup: a single input channel and two output
# channels, one score map per class.
model = UNet(
    in_channels=1,
    out_channels=2,
)

# Display the module hierarchy of the freshly built network.
print(model)


tensor([1, 2, 3, 4, 5])
