# In [7]:
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages without padding.

    With ``padding=0`` each 3x3 conv trims one pixel from every border, so the
    block maps an (N, in_ch, H, W) input to (N, out_ch, H - 4, W - 4).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # First stage maps in_ch -> out_ch, second stage keeps out_ch.
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=0),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # Attribute name and layer order are what state_dict keys are built
        # from (double_conv.0.weight, double_conv.1.running_mean, ...).
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)
    
class Down(nn.Module):
    """Encoder step: 2x2 max-pool followed by a ``DoubleConv``.

    ``nn.MaxPool2d(2)`` halves H and W (flooring odd sizes, as its default
    ``ceil_mode=False`` does); the ``DoubleConv`` then trims 4 more pixels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_ch, out_ch)
        # Keep the attribute name so state_dict keys stay maxpool_conv.1.*.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)

def crop_tensor(tensor, target_tensor):
    """Center-crop the last two (spatial) dims of ``tensor`` to match ``target_tensor``.

    Used by ``Up`` to trim the encoder skip connection so it can be
    concatenated with the smaller upsampled decoder feature map.

    Height and width are cropped independently, so non-square maps work. The
    crop window is always exactly the target size: when the size difference
    is odd, the extra row/column is dropped from the bottom/right.

    Args:
        tensor: tensor to crop; its last two dims are treated as (H, W).
        target_tensor: tensor whose last two dims give the desired (H, W).

    Returns:
        A view of ``tensor`` with the same leading dims and the target's (H, W).

    Raises:
        ValueError: if the target is larger than ``tensor`` in either dim.
    """
    target_h, target_w = target_tensor.shape[-2:]
    height, width = tensor.shape[-2:]
    if target_h > height or target_w > width:
        raise ValueError(
            f"cannot crop spatial size {tuple(tensor.shape[-2:])} "
            f"to larger size {tuple(target_tensor.shape[-2:])}"
        )
    top = (height - target_h) // 2
    left = (width - target_w) // 2
    # Slice exactly target_h / target_w elements from the offset; a symmetric
    # margin of delta // 2 on both sides would leave one extra row/column
    # whenever the size difference is odd.
    return tensor[..., top:top + target_h, left:left + target_w]
  
class Up(nn.Module):
    """Decoder step: upsample, merge with a cropped skip connection, DoubleConv.

    ``forward(x1, x2)``: ``x1`` is the deeper feature map to upsample and
    ``x2`` the encoder skip. ``x2`` is center-cropped (via ``crop_tensor``) to
    the upsampled size and placed *before* the upsampled map along the channel
    dim. The ``DoubleConv`` expects ``in_ch`` channels after concatenation, so
    the skip is assumed to carry ``in_ch - out_ch`` channels (as it does in
    ``Unet``).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # kernel_size=2, stride=2 doubles H and W; channels go in_ch -> out_ch.
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        skip = crop_tensor(x2, upsampled)
        merged = torch.cat((skip, upsampled), dim=1)
        return self.conv(merged)
    
class OutConv(nn.Module):
    """Final 1x1 convolution mapping ``in_ch`` feature maps to ``out_ch`` outputs.

    A 1x1 kernel leaves H and W unchanged.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected

class Unet(nn.Module):
    """U-Net built from unpadded ("valid") convolutions with center-cropped skips.

    Four ``Down`` steps double the channel width at each level
    (base_ch, 2x, 4x, 8x, 16x). Four ``Up`` steps mirror them, and a final
    1x1 ``OutConv`` produces ``n_classes`` channels. Because no padding is
    used, the output is spatially smaller than the input: the recorded run
    of this file mapped a 572x572 input to 388x388.

    Args:
        in_ch: number of channels of the input image.
        n_classes: number of output channels of the final 1x1 conv.
        base_ch: channel width of the first stage. The default 64 gives the
            64/128/256/512/1024 widths.
        verbose: when True (the default, matching the historical behaviour),
            ``forward`` prints the shape of the input and of every
            intermediate feature map. Pass False to silence it, e.g. during
            training.
    """

    def __init__(self, in_ch, n_classes, *, base_ch=64, verbose=True):
        super().__init__()
        self.verbose = verbose
        c1, c2, c3, c4, c5 = (base_ch * 2 ** i for i in range(5))
        # Submodule names and creation order are kept so state_dict keys and
        # the random initialisation sequence match earlier versions.
        self.inc = DoubleConv(in_ch, c1)
        self.down1 = Down(c1, c2)
        self.down2 = Down(c2, c3)
        self.down3 = Down(c3, c4)
        self.down4 = Down(c4, c5)

        self.up1 = Up(c5, c4)
        self.up2 = Up(c4, c3)
        self.up3 = Up(c3, c2)
        self.up4 = Up(c2, c1)
        self.outc = OutConv(c1, n_classes)

    def _trace(self, t):
        """Print ``t.shape`` when verbose; return ``t`` unchanged for chaining."""
        if self.verbose:
            print(t.shape)
        return t

    def forward(self, x):
        self._trace(x)

        # Encoder: keep every level's output for the skip connections.
        x1 = self._trace(self.inc(x))
        x2 = self._trace(self.down1(x1))
        x3 = self._trace(self.down2(x2))
        x4 = self._trace(self.down3(x3))
        x5 = self._trace(self.down4(x4))
        if self.verbose:
            print("====" * 20)

        # Decoder: each Up consumes the matching encoder output as its skip.
        up = self._trace(self.up1(x5, x4))
        up = self._trace(self.up2(up, x3))
        up = self._trace(self.up3(up, x2))
        up = self._trace(self.up4(up, x1))

        return self.outc(up)
    
if __name__ == '__main__':
    # Smoke test: push one random single-channel 572x572 image through a
    # 1-in / 1-out network and print the output shape (the recorded run
    # printed torch.Size([1, 1, 388, 388])).
    sample = torch.randn([1, 1, 572, 572])
    model = Unet(1, 1)
    prediction = model(sample)
    print(prediction.shape)


# Output of the cell above:
# torch.Size([1, 1, 572, 572])
# torch.Size([1, 64, 568, 568])
# torch.Size([1, 128, 280, 280])
# torch.Size([1, 256, 136, 136])
# torch.Size([1, 512, 64, 64])
# torch.Size([1, 1024, 28, 28])
# torch.Size([1, 512, 52, 52])
# torch.Size([1, 256, 100, 100])
# torch.Size([1, 128, 196, 196])
# torch.Size([1, 64, 388, 388])
# torch.Size([1, 1, 388, 388])
