In [None]:
from turtle import shape
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
import numpy as np

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.BatchNorm3d(),
            nn.Conv3d(in_channels, out_channels, 1, 1, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.BatchNorm3d(),
            nn.Conv3d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.ReLU(inplace=True),
        ).to('cuda')

    def forward(self, x):
        return self.conv(x)

def ConvBlock(in_channels, out_channels):
    block = nn.Sequential()
    block.add_module(
        nn.BatchNorm3d(),
        nn.ReLU(inplace=True),
        nn.Conv3d(in_channels, out_channels, 1),
        nn.BatchNorm3d(),
        nn.ReLU(inplace=True),
        nn.Conv3d(in_channels, out_channels, 3)
    )

def TransitionBlock(in_channels, out_channels):
    block = nn.Sequential()
    block.add_module(
        nn.BatchNorm3d(),
        nn.ReLU(),
        nn.Conv3d(in_channels, out_channels, 1),
        nn.AvgPool3d(2, 2)
    )

    return block

class DenseBlock(nn.Module):
    def __init__(self, num_conva, in_channels, out_channels, **kwargs) -> None:
        super(DenseBlock, self).__init__(**kwargs)
        self.net = nn.Sequential()

        for _ in range(num_conva):
            self.net.add_module(ConvBlock(in_channels, out_channels))

    def forward(self, x):
        for block in self.net:
            y = block(x)
            x = torch.concat(x, y, dim=1)
        
        return x

class UNET(nn.Module):
    def __init__(
            self, in_channels=3, out_channels=1, features=[96, 128, 256, 512],
    ):
        super(UNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)

        # Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of UNET
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose3d(
                    feature*2, feature, kernel_size=2, stride=2,
                )
            )
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv3d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)

        return self.final_conv(x)

def test():
    x = torch.randn((3, 1, 161, 161))
    model = UNET(in_channels=1, out_channels=1)
    preds = model(x)
    assert preds.shape == x.shape

if __name__ == "__main__":
    test()

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
from easydict import EasyDict

train_type = EasyDict({
    "TRAIN" : 0,
    "VALIDATION" : 1,
    "TEST" : 2
})

data_type = EasyDict({
    "FLAIR" : 0,
    "MPRAGE" : 1,
    "PDW" : 2,
    "T2W" : 3,
    "MASK" : 4
})

class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, growthRate = 12, dropRate = 0.0):
        #input dimsnsion을 정하고, output dimension을 정하고(growh_rate임), dropRate를 정함.
        super(BasicBlock, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size = 7, stride = 2, padding = 0, bias = False)
        self.bn = nn.BatchNorm3d(in_channels)
        self.relu = nn.ReLU(inplace = True) # inplace 하면 input으로 들어온 것 자체를 수정하겠다는 뜻. 메모리 usage가 좀 좋아짐. 하지만 input을 없앰.
        self.max_pool = nn.MaxPool3d(3, 2)
        self.droprate = dropRate
        
    def forward(self,x):
        out = self.max_pool(self.conv(self.relu(self.bn(x))))
        if self.droprate > 0:
            out = F.dropout (out, p = self.droprate, training = self.training)
        
        return out
        # return torch.cat([x, out], dim=1)
        
class BottleneckBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dropRate = 0.0):
        #out_channels => growh_rate를 입력으로 받게 된다.
        super(BottleneckBlock, self).__init__()
        inter_planes = out_channels * 4 # bottleneck layer의 conv 1x1 filter chennel 수는 4*growh_rate이다.
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace = True)
        self.conv1 = nn.Conv2d(in_channels, inter_planes, kernel_size=1,stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(inter_planes)
        self.conv2 = nn.Conv2d(inter_planes, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.droprate = dropRate
        
    def forward(self, x):
        out = self.conv1(self.relu(self.bn1(x)))
        if self.droprate > 0:
            out = F.dropout(out, p = self.droprate, inplace = False, training = self.training)
        out = self.conv2(self.relu(self.bn2(out)))
        if self.droprate > 0:
            out = F.dropout(out, p = self.droprate, inplace = False, training = self.training)
        return torch.cat([x, out], 1) # 입력으로 받은 x와 새로 만든 output을 합쳐서 내보낸다

class DenseBlock(nn.Module):
    def __init__(self, nb_layers, in_channels, growh_rate, block,dropRate = 0.0):
        super(DenseBlock, self).__init__()
        self.layer = self._make_layer(block, in_channels, growh_rate, nb_layers, dropRate)
    
    def _make_layer(self, block, in_channels, growh_rate, nb_layers, dropRate):
        layers = []
        for i in range(nb_layers):
            layers.append(block(in_channels + i*growh_rate, growh_rate, dropRate))
            
        return nn.Sequential(*layers)
    
    def forward(self, x):
        return self.layer(x)

class TransitionBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dropRate=0.0):
        super(TransitionBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace = True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size = 1, stride = 1, padding = 0, bias = False)
        self.droprate = dropRate
        
    def forward(self, x):
        out = self.conv1(self.relu(self.bn1(x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace = False, training = self.training)
        return F.avg_pool2d(out, 2)

class DenseNet(nn.Module):
    def __init__(self, depth, num_classes, growh_rate = 12, reduction = 0.5, bottleneck = True, dropRate = 0.0):
        super(DenseNet, self).__init__()
        num_of_blocks = 3
        in_channels = 16 # 2 * growh_rate
        n = (depth - num_of_blocks - 1) / num_of_blocks # 총 depth에서 첫 conv , 2개의 transit , 마지막 linear 빼고 / num_of_blocks
        if reduction != 1 :
            in_channels = 2 * growh_rate
        if bottleneck == True:
            in_channels = 2 * growh_rate #논문에서 Bottleneck + Compression 할 경우 first layer은 2*growh_rate라고 했다.
            n = n/2 # conv 1x1 레이어가 추가되니까 !
            block = BottleneckBlock 
        else:
            block = BasicBlock
        
        n = int(n) #n = DenseBlock에서 block layer 개수를 의미한다.
        self.conv1 = nn.Conv2d(3, in_channels, kernel_size = 3, stride = 1, padding = 1, bias = False) # input:RGB -> output:growhR*2
        
        #1st block
        # nb_layers,in_planes,growh_rate,block,dropRate
        self.block1 = DenseBlock(n, in_channels, growh_rate, block, dropRate)
        in_planes = int(in_planes + n*growh_rate) # 입력 + 레이어 만큼의 growh_rate
        
        # in_planes,out_planes,dropRate
        self.trans1 = TransitionBlock(in_channels, int(math.floor(in_channels*reduction)), dropRate = dropRate)
        in_planes = int(math.floor(in_channels*reduction))
        
        
        #2nd block
        # nb_layers,in_planes,growh_rate,block,dropRate
        self.block2 = DenseBlock(n, in_channels, growh_rate, block,dropRate)
        in_planes = int(in_channels + n*growh_rate) # 입력 + 레이어 만큼의 growh_rate
        
        # in_planes,out_planes,dropRate
        self.trans2 = TransitionBlock(in_channels, int(math.floor(in_channels*reduction)), dropRate = dropRate)
        in_planes = int(math.floor(in_channels*reduction))
        
        
        #3rd block
        # nb_layers,in_planes,growh_rate,block,dropRate
        self.block3 = DenseBlock(n, in_channels, growh_rate, block, dropRate)
        in_planes = int(in_channels + n*growh_rate) # 입력 + 레이어 만큼의 growh_rate
        
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace = True)
        
        self.fc = nn.Linear(in_channels, num_classes) # 마지막에 ave_pool 후에 1x1 size의 결과만 남음.
        
        self.in_channels = in_channels
        
        # module 초기화
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Conv layer들은 필터에서 나오는 분산 root(2/n)로 normalize 함
                # mean = 0 , 분산 = sqrt(2/n) // 이게 무슨 초기화 방법이었는지 기억이 안난다.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d): # shifting param이랑 scaling param 초기화(?)
                m.weight.data.fill_(1) # 
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):# linear layer 초기화.
                m.bias.data.zero_()
        
    def forward(self, x):
        #x : 32*32
        out = self.conv1(x) # 32*32
        out = self.block1(out) # 32*32
        out = self.trans1(out) # 16*16
        out = self.block2(out) # 16*16
        out = self.trans2(out) # 8*8
        out = self.block3(out) # 8*8
        out = self.relu(self.bn1(out)) #8*8
        out = F.avg_pool2d(out, 8) #1*1
        out = out.view(-1, self.in_planes) #channel수만 남기 때문에 Linear -> in_planes

        return self.fc(out)

In [None]:
import torch.nn as nn

class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dropRate = 0.0):
        #input dimsnsion을 정하고, output dimension을 정하고(growh_rate임), dropRate를 정함.
        super(BasicBlock, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size = 7, stride = 2, padding = 3, bias = False)
        self.bn = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace = True) # inplace 하면 input으로 들어온 것 자체를 수정하겠다는 뜻. 메모리 usage가 좀 좋아짐. 하지만 input을 없앰.
        self.max_pool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.droprate = dropRate
        
    def forward(self,x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)

        if self.droprate > 0:
            out = F.dropout (out, p = self.droprate, training = self.training)
            
        return out, self.max_pool(out)

class DenseBlock(nn.Module):
    def __init__(self, num_conva, in_channels, growthRate = 48, **kwargs) -> None:
        super(DenseBlock, self).__init__(**kwargs)
        self.net = nn.Sequential()

        for idx in range(num_conva):
            self.net.add_module(name=f'Dense + {idx}', module=self.doubleConvBlock(in_channels + idx*growthRate, growthRate))
    
    def doubleConvBlock(self, in_channels, out_channels):
        inter_planes = out_channels * 4
        block = nn.Sequential(
            nn.BatchNorm3d(in_channels),
            nn.Conv3d(in_channels, inter_planes, kernel_size=1, stride=1, padding=0, bias=False),
            nn.ReLU(inplace=True),
            nn.BatchNorm3d(inter_planes),
            nn.Conv3d(inter_planes, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(inplace=True)
        )

        return block
        
    def forward(self, x):
        for block in self.net:
            out = block(x)
            x = torch.concat((x, out), dim=1)
            print(x.shape)
        
        return x

class TransitionBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dropRate=0.0):
        super(TransitionBlock, self).__init__()
        self.bn = nn.BatchNorm3d(in_channels)
        self.relu = nn.ReLU(inplace = True)
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.avg_pool = nn.AvgPool3d(kernel_size=2, stride=2)
        self.droprate = dropRate
        
    def forward(self, x):
        out = self.conv(self.relu(self.bn(x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)

        out = self.avg_pool(out)
        print(out.shape)

        return out

class UpSampleBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dropRate=0.0):
        super(UpSampleBlock, self).__init__()
        self.bn = nn.BatchNorm3d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.transConv = nn.ConvTranspose3d(in_channels, out_channels, 3, stride=2, padding=0, bias=False)


In [None]:
import numpy as np
import torch

in_channels = [64, 128, 256, 512]
num_blocks = [6, 12, 36, 24]
grows_rate = 48

x = torch.rand(1,1,208,224,208).cuda()
basic = BasicBlock(1, in_channels[0]).cuda()
dense = DenseBlock(num_blocks[0], in_channels[0], grows_rate).cuda()
trans = TransitionBlock(in_channels[0] + grows_rate*num_blocks[0], in_channels[1]).cuda()
test1 = basic(x)
test2 = dense(test1[1])
test3 = trans(test2)
dense = DenseBlock(num_blocks[1], in_channels[1], grows_rate).cuda()
trans = TransitionBlock(in_channels[1] + grows_rate*num_blocks[1], in_channels[2]).cuda()
test4 = dense(test3)
test5 = trans(test4)
dense = DenseBlock(num_blocks[2], in_channels[2], grows_rate).cuda()
trans = TransitionBlock(in_channels[2] + grows_rate*num_blocks[2], in_channels[3]).cuda()
test6 = dense(test5)
test7 = trans(test6)
# dense = DenseBlock(num_blocks[3], in_channels[2], grows_rate)
# trans = TransitionBlock(in_channels[2] + grows_rate*num_blocks[3], in_channels[3])
# test8 = dense(test7)
# test9 = trans(test8)

In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.BatchNorm3d(in_channels),
            nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=False),
            nn.ReLU(inplace=True),
            nn.BatchNorm3d(out_channels),
            nn.Conv3d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.ReLU(inplace=True),
        ).to('cuda')

    def forward(self, x):
        return self.conv(x)

class DenseBlock(nn.Module):
    def __init__(self, num_conva, in_channels, growthRate = 48, **kwargs) -> None:
        super(DenseBlock, self).__init__(**kwargs)
        self.net = nn.Sequential()

        for idx in range(num_conva):
            self.net.add_module(name=f'Dense + {idx}', module=self.doubleConvBlock(in_channels + idx*growthRate, growthRate))
    
    def doubleConvBlock(self, in_channels, out_channels):
        inter_planes = out_channels * 4
        block = nn.Sequential(
            nn.BatchNorm3d(in_channels),
            nn.Conv3d(in_channels, inter_planes, kernel_size=1, stride=1, padding=0, bias=False),
            nn.ReLU(inplace=True),
            nn.BatchNorm3d(inter_planes),
            nn.Conv3d(inter_planes, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(inplace=True)
        )

        return block
        
    def forward(self, x):
        for block in self.net:
            out = block(x)
            x = torch.concat([x, out], dim=1)

        return x

class TransitionBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dropRate=0.0):
        super(TransitionBlock, self).__init__()
        self.bn = nn.BatchNorm3d(in_channels)
        self.relu = nn.ReLU(inplace = True)
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.avg_pool = nn.AvgPool3d(kernel_size=2, stride=2)
        self.droprate = dropRate
        
    def forward(self, x):
        out = self.conv(self.relu(self.bn(x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)

        out = self.avg_pool(out)
        return out

in_channels_dense = [64, 128, 256, 512]
num_blocks = [5, 10, 30, 20]
grows_rate = 48

class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dropRate = 0.0):
        #input dimsnsion을 정하고, output dimension을 정하고(growh_rate임), dropRate를 정함.
        super(BasicBlock, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size = 7, stride = 2, padding = 3, bias = False)
        self.bn = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace = True) # inplace 하면 input으로 들어온 것 자체를 수정하겠다는 뜻. 메모리 usage가 좀 좋아짐. 하지만 input을 없앰.
        self.max_pool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.droprate = dropRate
        
    def forward(self,x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)

        if self.droprate > 0:
            out = F.dropout (out, p = self.droprate, training = self.training)
        
        return out, self.max_pool(out)

class TransposeBlock(nn.Module):
    def __init__(self, in_channels, out_channels) -> None:
        super(TransposeBlock, self).__init__()
        self.convTrans = nn.Sequential(
            nn.BatchNorm3d(in_channels),
            nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=False),
            nn.ReLU(inplace=True),
            nn.BatchNorm3d(out_channels),
            nn.ConvTranspose3d(out_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.convTrans(x)

class NestedDenseUNet3D(nn.Module):
    def __init__(
            self, in_channels=3, out_channels=1, features_encoder=[32, 64, 128, 256], features_decoder=[1216, 512, 256, 128], features_double=[1824, 800, 464, 154]
    ):
        super(NestedDenseUNet3D, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.downs_dense = nn.ModuleList()
        self.downs_trans = nn.ModuleList()
        self.basic_block = BasicBlock(1, 32)
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)
        self.sigmoid = nn.Sigmoid()
        self.up = nn.Upsample(scale_factor=2)

        # Down part of UNET
        for idx, (feature, num_block) in enumerate(zip(features_encoder, num_blocks)):
            self.downs.append(DoubleConv(in_channels, feature))
            self.downs_dense.append(DenseBlock(num_block, feature, grows_rate))
            in_channels = feature

            if idx < 3:
                self.downs_trans.append(TransitionBlock(in_channels + num_block*grows_rate, features_encoder[idx+1]))

        # Up part of UNET
        for idx, (feature, double, encoder) in enumerate(zip(features_decoder, features_double, reversed(features_encoder))):
            self.ups.append(TransposeBlock(feature, encoder))
            self.ups.append(DoubleConv(double, encoder*2))

        self.final_convTrans = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2)
        self.final_conv = nn.Conv3d(32, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.basic_block(x)
        x0_0 = x[0]
        x1_0 = x[1]
        x1_0_dense = self.downs_dense[0](x1_0)
        x0_1 = DoubleConv(304, 30)
        x0_1 = x0_1(torch.cat([x0_0, self.up(x1_0_dense)], 1))

        x1_0 = self.downs_trans[0](x1_0_dense)

        x2_0_dense = self.downs_dense[1](x1_0)
        x1_1 = DoubleConv(816, 64)
        x2_0_trans = TransposeBlock(x2_0_dense.shape[1], x2_0_dense.shape[1]).to('cuda')
        x1_1 = x1_1(torch.cat([x1_0_dense, x2_0_trans(x2_0_dense)], 1))
        x0_2 = DoubleConv(126, 30)
        x1_1_trans = TransposeBlock(x1_1.shape[1], x1_1.shape[1]).to('cuda')
        x0_2 = x0_2(torch.cat([x0_0, x0_1, x1_1_trans(x1_1)], 1))

        x2_0 = self.downs_trans[1](x2_0_dense)

        x3_0_dense = self.downs_dense[2](x2_0)
        x2_1 = DoubleConv(2112, 128)
        x3_0_trans = TransposeBlock(x3_0_dense.shape[1], x3_0_dense.shape[1]).to('cuda')
        x2_1 = x2_1(torch.cat([x2_0_dense, x3_0_trans(x3_0_dense)], 1))
        x1_2 = DoubleConv(464, 64)
        x2_1_trans = TransposeBlock(x2_1.shape[1], x2_1.shape[1]).to('cuda')
        x1_2 = x1_2(torch.cat([x1_0_dense, x1_1, x2_1_trans(x2_1)], 1))
        x0_3 = DoubleConv(156, 30)
        x1_2_trans = TransposeBlock(x1_2.shape[1], x1_2.shape[1]).to('cuda')
        x0_3 = x0_3(torch.cat([x0_0, x0_1, x0_2, x1_2_trans(x1_2)], 1))

        x3_0 = self.downs_trans[2](x3_0_dense)

        x4_0 = self.downs_dense[3](x3_0)
        print(x4_0.shape)
        p3d = (0, 1, 0, 0, 0, 1)
        x4_0_up = self.ups[0](x4_0)
        x4_0_up = F.pad(x4_0_up, p3d)
        print(x4_0_up.shape)
        x3_1 = self.ups[1](torch.cat([x3_0_dense, x4_0_up], 1))
        print(x3_1.shape)
        x3_1_up = self.ups[2](x3_1)
        print(x3_1_up.shape)
        x2_2 = self.ups[3](torch.cat([x2_0_dense, x2_1, x3_1_up], 1))
        x2_2_up = self.ups[4](x2_2)
        print(x2_2.shape)
        print(x2_2_up.shape)
        x1_3 = self.ups[5](torch.cat([x1_0_dense, x1_1, x1_2, x2_2_up], 1))
        x1_3_up = self.ups[6](x1_3)
        print(x1_3.shape)
        print(x1_3_up.shape)
        x0_4 = self.ups[7](torch.cat([x0_0, x0_1, x0_2, x0_3, x1_3_up], 1))

        ouput = self.final_convTrans(x0_4)
        ouput = self.final_conv(ouput)
        ouput = self.sigmoid(ouput)

        return ouput

def test():
    x = torch.randn((1, 1, 208, 224, 208)).to('cuda')
    model = NestedDenseUNet3D(in_channels=1, out_channels=1).to('cuda')
    preds = model(x)
    # assert preds.shape == x.shape

if __name__ == "__main__":
    test()

torch.Size([1, 1216, 6, 7, 6])
torch.Size([1, 256, 13, 14, 13])
torch.Size([1, 512, 13, 14, 13])
torch.Size([1, 128, 26, 28, 26])
torch.Size([1, 256, 26, 28, 26])
torch.Size([1, 64, 52, 56, 52])
torch.Size([1, 128, 52, 56, 52])
torch.Size([1, 32, 104, 112, 104])


RuntimeError: running_mean should contain 154 elements not 160

In [None]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.BatchNorm3d(in_channels),
            nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=False),
            nn.ReLU(inplace=True),
            nn.BatchNorm3d(out_channels),
            nn.Conv3d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        print(x.shape)
        print(self.conv(x).shape)
        return self.conv(x)

def test():
    x = torch.randn((1, 1, 208, 224, 208))
    model = DoubleConv(in_channels=1, out_channels=32)
    preds = model(x)
    # assert preds.shape == x.shape

if __name__ == "__main__":
    test()