In [2]:
import torch
from torch import nn
import torch.nn.functional as F

In [87]:
class DoubleBlock(nn.Module):
    '''(conv => BN => ReLU) * 2'''
    def __init__(self, in_ch, out_ch):
        super(DoubleBlock, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

In [12]:
class InConv(nn.Module):

    def __init__(self, in_ch, out_ch):
        super(InConv, self).__init__()
        self.conv = DoubleBlock(in_ch, out_ch)

    def forward(self, x):
        return self.conv(x)

In [66]:
class Down(nn.Module):

    def __init__(self, in_ch, out_ch):
        super(Down, self).__init__()
        self.block = DoubleBlock(in_ch, out_ch)
        self.pool = nn.MaxPool3d(2, return_indices=True)

    def forward(self, x):
        x = self.block(x)
        x, indices = self.pool(x)
        return x, indices

In [88]:
class Up(nn.Module):

    def __init__(self, in_ch, out_ch):
        super(Up, self).__init__()
        self.block = DoubleBlock(in_ch, out_ch)
        self.unpool = nn.MaxUnpool3d(2)

    def forward(self, x, indices):
        x = self.unpool(x, indices)
        x = self.block(x)
        return x

In [89]:
class OutConv(nn.Module):

    def __init__(self, in_ch, out_ch):
        super(OutConv, self).__init__()
        self.conv = nn.Conv3d(in_ch, out_ch, 1)

    def forward(self, x):
        return self.conv(x)

In [90]:
img = torch.randn(1, 1, 50, 128, 128)

### Start of the Model

In [91]:
class Start(nn.Module):
    
    def __init__(self, n_channels, n_classes):
        super(Start, self).__init__()
        self.inconv = InConv(n_channels, 2)
        self.down1 = Down(2, 4)
        
    def forward(self, x):
        x = self.inconv(x)
        x1, indices1 = self.down1(x)
        return x1, indices1

In [92]:
start = Start(1, 2)

In [93]:
x1, indices1 = start(img)

In [94]:
print(f'x1 shape: {x1.shape}')
print(f'indices1 shape: {indices1.shape}')

x1 shape: torch.Size([1, 4, 25, 64, 64])
indices1 shape: torch.Size([1, 4, 25, 64, 64])


In [116]:
class NanoUnet(nn.Module):
    
    def __init__(self, n_channels, n_classes):
        super(NanoUnet, self).__init__()
        self.inconv = InConv(n_channels, 2)
        self.down = Down(2, 4)
        self.up = Up(4, 2)
        self.outconv = OutConv(2, n_classes)
        
        
    def forward(self, x):
        x = self.inconv(x)
        x, indices = self.down(x)
        x = self.up(x, indices)
        x = self.outconv(x)
        return x

In [117]:
nanonet = NanoUnet(1, 2)

In [118]:
out_nano = nanonet(img)

In [119]:
out_nano.shape

torch.Size([1, 2, 50, 128, 128])

In [110]:
class MicroUnet(nn.Module):
    
    def __init__(self, n_channels, n_classes):
        super(MicroUnet, self).__init__()
        self.inconv = InConv(n_channels, 2)
        self.down1 = Down(2, 4)
        self.down2 = Down(4, 8)
        self.up1 = Up(8, 4)
        self.up2 = Up(4, 2)
        self.outconv = OutConv(2, n_classes)
        
        
    def forward(self, x):
        x = self.inconv(x)
        x1, indices1 = self.down1(x)
        x2, indices2 = self.down2(x1)
        x3 = self.up1(x2, indices2)
        x4 = self.up2(x3, indices1)
        x5 = self.outconv(x4)
        return x5

In [111]:
micronet = MicroUnet(1,2)

In [112]:
out = micronet(img)

RuntimeError: input and indices shapes do not match: input [1 x 4 x 24 x 64 x 64], indices [1 x 4 x 25 x 64 x 64] at /Users/yngtodd/src/checkout/pytorch/aten/src/THNN/generic/VolumetricMaxUnpooling.c:23

In [113]:
out.shape

torch.Size([1, 2, 50, 128, 128])