In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [2]:
class DoubleBlock(nn.Module):
    '''(conv => BN => ReLU) * 2'''
    def __init__(self, in_ch, out_ch):
        super(DoubleBlock, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

In [3]:
class InConv(nn.Module):

    def __init__(self, in_ch, out_ch):
        super(InConv, self).__init__()
        self.conv = DoubleBlock(in_ch, out_ch)

    def forward(self, x):
        return self.conv(x)

In [4]:
class Down(nn.Module):

    def __init__(self, in_ch, out_ch):
        super(Down, self).__init__()
        self.block = DoubleBlock(in_ch, out_ch)
        self.pool = nn.MaxPool3d(2, return_indices=True)

    def forward(self, x):
        x = self.block(x)
        x, indices = self.pool(x)
        return x, indices

In [5]:
class Up(nn.Module):

    def __init__(self, in_ch, out_ch):
        super(Up, self).__init__()
        self.block = DoubleBlock(in_ch, out_ch)
        self.unpool = nn.MaxUnpool3d(2)

    def forward(self, x, indices):
        x = self.unpool(x, indices)
        x = self.block(x)
        return x

In [6]:
class UpShape(nn.Module):

    def __init__(self, in_ch, out_ch):
        super(UpShape, self).__init__()
        self.block = DoubleBlock(in_ch, out_ch)
        self.unpool = nn.MaxUnpool3d(2)

    def forward(self, x, indices, output_size):
        x = self.unpool(x, indices, output_size)
        x = self.block(x)
        return x

In [7]:
class OutConv(nn.Module):

    def __init__(self, in_ch, out_ch):
        super(OutConv, self).__init__()
        self.conv = nn.Conv3d(in_ch, out_ch, 1)

    def forward(self, x):
        return self.conv(x)

### Random Input

In [8]:
img = torch.randn(1, 1, 50, 128, 128)

### Start of the Model

In [91]:
class Start(nn.Module):
    
    def __init__(self, n_channels, n_classes):
        super(Start, self).__init__()
        self.inconv = InConv(n_channels, 2)
        self.down1 = Down(2, 4)
        
    def forward(self, x):
        x = self.inconv(x)
        x1, indices1 = self.down1(x)
        return x1, indices1

In [92]:
start = Start(1, 2)

In [93]:
x1, indices1 = start(img)

In [94]:
print(f'x1 shape: {x1.shape}')
print(f'indices1 shape: {indices1.shape}')

x1 shape: torch.Size([1, 4, 25, 64, 64])
indices1 shape: torch.Size([1, 4, 25, 64, 64])


### Nan0Unet: InConv -> Down -> Up -> OutConv

#### Working model - sans unpooling output_size

In [9]:
class NanoUnet(nn.Module):
    
    def __init__(self, n_channels, n_classes):
        super(NanoUnet, self).__init__()
        self.inconv = InConv(n_channels, 2)
        self.down = Down(2, 4)
        self.up = Up(4, 2)
        self.outconv = OutConv(2, n_classes)
        
        
    def forward(self, x):
        x = self.inconv(x)
        x, indices = self.down(x)
        x = self.up(x, indices)
        x = self.outconv(x)
        return x

In [10]:
nanonet = NanoUnet(1, 2)

In [11]:
out_nano = nanonet(img)

In [12]:
out_nano.shape

torch.Size([1, 2, 50, 128, 128])

#### Model - avec unpooling output_size

In [13]:
class NanoUnetShape(nn.Module):
    
    def __init__(self, n_channels, n_classes):
        super(NanoUnetShape, self).__init__()
        self.inconv = InConv(n_channels, 2)
        self.down = Down(2, 4)
        self.up = UpShape(4, 2)
        self.outconv = OutConv(2, n_classes)
        
        
    def forward(self, x):
        print(f'x shape: {x.shape}')
        x1 = self.inconv(x)
        print(f'x1 shape: {x1.shape}')
        x2, indices = self.down(x1)
        print(f'x2 shape: {x2.shape}')
        x3 = self.up(x2, indices, x.shape) # why x.shape here and not x1.shape?
        x4 = self.outconv(x3)
        return x4

In [14]:
nanonet2 = NanoUnetShape(1, 2)

In [15]:
out_nano2 = nanonet2(img)

x shape: torch.Size([1, 1, 50, 128, 128])
x1 shape: torch.Size([1, 2, 50, 128, 128])
x2 shape: torch.Size([1, 4, 25, 64, 64])


In [16]:
out_nano2.shape

torch.Size([1, 2, 50, 128, 128])

### MicroUnet: InConv -> Down -> Down -> Up -> Up -> OutConv

In [26]:
class MicroUnet(nn.Module):
    
    def __init__(self, n_channels, n_classes):
        super(MicroUnet, self).__init__()
        self.inconv = InConv(n_channels, 2)
        self.down1 = Down(2, 4)
        self.down2 = Down(4, 8)
        self.up1 = UpShape(8, 4)
        self.up2 = UpShape(4, 2)
        self.outconv = OutConv(2, n_classes)
        
        
    def forward(self, x):
        x1 = self.inconv(x)
        x2, indices1 = self.down1(x1)
        x3, indices2 = self.down2(x2)
        
        # temprint
        print(f'x shape: {x.size()}')
        print(f'x1 shape: {x1.size()}')
        print(f'x2 shape: {x2.size()}')
        print(f'x3 shape: {x3.size()}')
        
        x4 = self.up1(x3, indices2, x2.size())
        x5 = self.up2(x4, indices1, x1.size())
        x6 = self.outconv(x5)
        return x6

In [27]:
micronet = MicroUnet(1,2)

In [28]:
out = micronet(img)

x shape: torch.Size([1, 1, 50, 128, 128])
x1 shape: torch.Size([1, 2, 50, 128, 128])
x2 shape: torch.Size([1, 4, 25, 64, 64])
x3 shape: torch.Size([1, 8, 12, 32, 32])
success


In [29]:
out.shape

torch.Size([1, 2, 50, 128, 128])

In [38]:
class CMUnet(nn.Module):
    """Centimeters, not CMU."""
    def __init__(self, n_channels, n_classes):
        super(CMUnet, self).__init__()
        self.inconv = InConv(n_channels, 2)
        self.down1 = Down(2, 4)
        self.down2 = Down(4, 8)
        self.down3 = Down(8, 16)
        self.up1 = UpShape(16, 8)
        self.up2 = UpShape(8, 4)
        self.up3 = UpShape(4, 2)
        self.outconv = OutConv(2, n_classes)
        
        
    def forward(self, x):
        x1 = self.inconv(x)
        x2, indices1 = self.down1(x1)
        x3, indices2 = self.down2(x2)
        x4, indices3 = self.down3(x3)
        
        # temprint
        print(f'x shape:  {x.size()}')
        print(f'x1 shape: {x1.size()}')
        print(f'x2 shape: {x2.size()}')
        print(f'x3 shape: {x3.size()}')
        print(f'x4 shape: {x4.size()}')
        
        x5 = self.up1(x4, indices3, x3.size())
        x6 = self.up2(x5, indices2, x2.size())
        x7 = self.up3(x6, indices1, x1.size())
        x7 = self.outconv(x7)
        return x7

In [39]:
cmunet = CMUnet(1, 2)

In [40]:
out = cmunet(img)

x shape:  torch.Size([1, 1, 50, 128, 128])
x1 shape: torch.Size([1, 2, 50, 128, 128])
x2 shape: torch.Size([1, 4, 25, 64, 64])
x3 shape: torch.Size([1, 8, 12, 32, 32])
x4 shape: torch.Size([1, 16, 6, 16, 16])


In [41]:
out.shape

torch.Size([1, 2, 50, 128, 128])