In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
#from torch_summary import summary
from torchinfo import summary

class Swish(nn.Module):
    """Swish / SiLU activation: ``x * sigmoid(x)``.

    Args:
        inplace: if True, overwrite the input tensor with the result (saves
            memory; the caller must not need the input afterwards). If False,
            a new tensor is returned and the input is left untouched.
    """

    def __init__(self, inplace=False):
        super(Swish, self).__init__()
        # Honour the caller's choice; this was previously hard-coded to True,
        # which made Swish() mutate its input despite inplace=False.
        self.inplace = inplace

    def forward(self, x):
        if self.inplace:
            # mul_ returns x itself, now holding x * sigmoid(x).
            return x.mul_(torch.sigmoid(x))
        return x * torch.sigmoid(x)

class Sigmoid(nn.Module):
    """Logistic sigmoid activation with an optional in-place mode.

    Args:
        inplace: if True, apply ``sigmoid_`` to the input tensor and return it;
            if False, return a new tensor and leave the input untouched.
    """

    def __init__(self, inplace=False):
        super(Sigmoid, self).__init__()
        # Honour the caller's choice; this was previously hard-coded to True.
        self.inplace = inplace

    def forward(self, x):
        if self.inplace:
            return x.sigmoid_()
        return torch.sigmoid(x)
    
# Default activation class, instantiated as ACTV(inplace=True) by ConvBnActv
# and GhostModule. ConvBnActv captures it as a default argument when its cell
# runs, whereas GhostModule looks it up each time a module is constructed, so
# reassign it before running the following cells if you want to change it.
ACTV = Swish  # alternative: nn.ReLU

In [None]:
def ConvBnActv(in_chs, out_chs, kernel_size=3, stride=1, actv=ACTV):
    """Build a Conv3d -> BatchNorm3d -> activation stack.

    The convolution has no bias (BatchNorm supplies the shift) and uses
    padding ``kernel_size // 2``. For odd kernels with stride 1, this keeps
    the spatial size unchanged.

    Args:
        in_chs: input channels.
        out_chs: output channels.
        kernel_size: cubic kernel size.
        stride: convolution stride.
        actv: activation class, instantiated as ``actv(inplace=True)``.

    Returns:
        nn.Sequential of the three layers, in that order.
    """
    padding = kernel_size // 2
    conv = nn.Conv3d(in_chs, out_chs, kernel_size, stride, padding, bias=False)
    norm = nn.BatchNorm3d(out_chs)
    layers = [conv, norm, actv(inplace=True)]
    return nn.Sequential(*layers)

def Upsample(factor=2):
    """Nearest-neighbour upsampling by ``factor`` along every spatial axis.

    Args:
        factor: scale factor passed to nn.Upsample. It was previously ignored,
            and the scale was always 2.

    Returns:
        nn.Upsample module.
    """
    return nn.Upsample(scale_factor=factor, mode='nearest', align_corners=None)

    
def FCBlock(in_chs, out_chs, droprate=0.1, actv=None):
    """Fully connected block: [Dropout] -> Linear -> [activation].

    Args:
        in_chs: input features.
        out_chs: output features.
        droprate: dropout probability applied to the block input; 0 disables it.
        actv: optional activation class, instantiated as ``actv(inplace=True)``.

    Returns:
        nn.Sequential with the Linear layer at index 1. Parameter names
        ('1.weight', '1.bias') are therefore the same whether or not dropout
        and the activation are enabled.
    """
    return nn.Sequential(
        # Element-wise dropout for (batch, features) inputs. nn.Dropout3d is
        # meant for 4-D/5-D volumes. Recent PyTorch versions treat a 2-D input
        # as a single unbatched sample and zero whole rows (entire samples)
        # rather than individual features, with a deprecation warning.
        nn.Dropout(droprate) if droprate > 0 else nn.Identity(),
        nn.Linear(in_chs, out_chs),
        actv(inplace=True) if actv is not None else nn.Identity(),
    )

In [None]:
class GhostBlock(nn.Module):
    """Residual Ghost bottleneck: two GhostModules plus an identity shortcut.

    ``ghost1`` maps ``in_chs`` -> ``mid_chs`` with activation, and ``ghost2``
    maps ``mid_chs`` -> ``out_chs`` without activation. The block input is
    then added to the result.

    Args:
        in_chs: input channels.
        mid_chs: hidden (expanded) channels.
        out_chs: output channels. Must equal ``in_chs``, because the shortcut
            is the identity.
        stride: only 1 is supported, because no layer in the block downsamples.

    Raises:
        ValueError: if ``in_chs != out_chs`` or ``stride != 1``.
    """

    def __init__(self, in_chs, mid_chs, out_chs, stride=1):
        super(GhostBlock, self).__init__()
        # Fail fast at construction time. Otherwise a channel mismatch only
        # surfaces as a shape error in forward(), and stride != 1 would be
        # silently ignored (no downsampling).
        if in_chs != out_chs:
            raise ValueError(
                f"GhostBlock uses an identity shortcut, so in_chs ({in_chs}) "
                f"must equal out_chs ({out_chs})"
            )
        if stride != 1:
            raise ValueError(f"GhostBlock only supports stride=1, got {stride}")
        self.stride = stride
        self.ghost1 = GhostModule(in_chs, mid_chs, actv=True)
        self.ghost2 = GhostModule(mid_chs, out_chs, actv=False)
        # Registered but not applied in forward(); kept so that the module
        # structure (repr / summaries) stays unchanged.
        self.dropout = nn.Dropout3d(0.1, inplace=True)
        # Identity shortcut (an empty Sequential returns its input).
        self.shortcut = nn.Sequential()

    def forward(self, x):
        res = x
        x = self.ghost1(x)
        x = self.ghost2(x)
        x += self.shortcut(res)
        return x

In [None]:
class GhostModule(nn.Module):
    """Ghost module for 3D inputs: cheap generation of extra feature maps.

    A primary convolution produces ``ceil(oup / ratio)`` intrinsic channels.
    A depthwise 3x3x3 convolution then derives ``ratio - 1`` additional
    'ghost' channels from each intrinsic one. The two sets are concatenated
    along dim 1 and trimmed to exactly ``oup`` channels.

    Args:
        inp: input channels.
        oup: output channels.
        kernel_size: kernel of the primary convolution (padding kernel_size // 2).
        stride: stride of the primary convolution.
        ratio: total channels produced per intrinsic channel before trimming.
        actv: if truthy, an ACTV(inplace=True) follows each BatchNorm;
            otherwise an empty Sequential (no-op) is used instead.
    """

    def __init__(self, inp, oup, kernel_size=1, stride=1, ratio=2, actv=True):
        super(GhostModule, self).__init__()
        self.oup = oup
        n_intrinsic = math.ceil(oup / ratio)
        n_ghost = n_intrinsic * (ratio - 1)

        def activation():
            # ACTV is only looked up when an activation is actually requested.
            return ACTV(inplace=True) if actv else nn.Sequential()

        self.primary_conv = nn.Sequential(
            nn.Conv3d(inp, n_intrinsic, kernel_size, stride, kernel_size // 2, bias=False),
            nn.BatchNorm3d(n_intrinsic),
            activation(),
        )
        # Depthwise (groups == in_channels) 3x3x3 convolution with padding 1.
        self.cheap_operation = nn.Sequential(
            nn.Conv3d(n_intrinsic, n_ghost, 3, 1, 1, bias=False, groups=n_intrinsic),
            nn.BatchNorm3d(n_ghost),
            activation(),
        )

    def forward(self, x):
        intrinsic = self.primary_conv(x)
        ghost = self.cheap_operation(intrinsic)
        stacked = torch.cat((intrinsic, ghost), dim=1)
        return stacked[:, : self.oup]
    

In [None]:
class NetCAE(nn.Module):
    """3D convolutional autoencoder for single-channel MRI volumes.

    Encoder:
      - a conv stem whose first conv has stride 2;
      - two stages of (GhostBlock -> widening conv -> MaxPool3d(3, 2, 1) ->
        Dropout3d), where the stem also ends with pool + dropout;
      - three GhostBlocks and a 1x1x1 projection to 120 channels.

    Decoder: mirrors this with four x2 nearest-neighbour upsamplings and ends
    in a single-channel conv followed by a sigmoid.
    """

    def __init__(self):
        super(NetCAE, self).__init__()

        # MRI encoder -------------------------------------------------------
        encoder_layers = [
            ConvBnActv(1, 16, 3, 2),
            ConvBnActv(16, 16, 3, 1),
            ConvBnActv(16, 24, 3, 1),
            nn.MaxPool3d(3, 2, 1),
            nn.Dropout3d(0.3),
        ]
        for width, hidden, wider in ((24, 72, 40), (40, 120, 80)):
            encoder_layers += [
                GhostBlock(width, hidden, width),
                ConvBnActv(width, wider, 3, 1),
                nn.MaxPool3d(3, 2, 1),
                nn.Dropout3d(0.3),
            ]
        encoder_layers += [GhostBlock(80, 480, 80) for _ in range(3)]
        encoder_layers.append(ConvBnActv(80, 120, 1, 1))
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder -----------------------------------------------------------
        decoder_layers = [ConvBnActv(120, 80, 1, 1)]
        for c_in, c_out in ((80, 40), (40, 24), (24, 16)):
            decoder_layers += [Upsample(2), ConvBnActv(c_in, c_out, 3, 1)]
        decoder_layers += [
            ConvBnActv(16, 16, 3, 1),
            Upsample(2),
            ConvBnActv(16, 1, 3, 1, actv=Sigmoid),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, mri):
        """Encode ``mri`` and return the decoder's reconstruction."""
        return self.decoder(self.encoder(mri))
    

In [None]:
class NetMTL(nn.Module):
    """Two-task prediction head on MRI encoder features and/or clinical data.

    Args:
        encoder: module applied to the MRI input. ``fc_mri`` expects 960
            flattened features after 2x2x2 adaptive pooling, i.e. 120 encoder
            output channels.
        modality:
            - 'cdata': use only the clinical branch;
            - 'multimodality': use both branches;
            - any other value: use only the MRI branch.

    ``forward(mri, cdata)`` returns ``(out1, out2)``, each produced by a
    Linear(4, 1) followed by a sigmoid.
    """

    def __init__(self, encoder, modality='multimodality'):
        super(NetMTL, self).__init__()

        self.encoder = encoder
        self.use_mri = modality != 'cdata'
        self.use_cdata = modality in ('cdata', 'multimodality')

        # MRI branch: pool to 2x2x2, flatten, then 960 -> 32 -> 8.
        self.fc_mri = nn.Sequential(
            nn.AdaptiveAvgPool3d(2),
            nn.Flatten(),
            FCBlock(960, 32, droprate=0.1),
            FCBlock(32, 8, droprate=0.1),
        )

        # Clinical branch: flatten, then 11 -> 32 -> 8.
        self.fc_cdata = nn.Sequential(
            nn.Flatten(),
            FCBlock(11, 32, droprate=0.1),
            FCBlock(32, 8, droprate=0.1),
        )

        # Each active branch contributes 8 features to the fused vector.
        n_branches = int(self.use_mri) + int(self.use_cdata)
        self.fc_fuse = FCBlock(8 * n_branches, 4)

        self.task1 = FCBlock(4, 1, actv=Sigmoid)
        self.task2 = FCBlock(4, 1, actv=Sigmoid)

    def forward(self, mri, cdata):
        branch_features = []
        if self.use_mri:
            branch_features.append(self.fc_mri(self.encoder(mri)))
        if self.use_cdata:
            branch_features.append(self.fc_cdata(cdata))

        if len(branch_features) == 1:
            fused = branch_features[0]
        else:
            fused = torch.cat(branch_features, 1)
        fused = self.fc_fuse(fused)

        return self.task1(fused), self.task2(fused)


In [None]:
if __name__ == '__main__':
    # Report output sizes, parameter counts and mult-adds for both networks.
    # verbose=0 stops torchinfo from printing the table itself (its default
    # behaviour outside notebooks), so each table is printed exactly once.
    summary_cols = ["output_size", "num_params", "mult_adds"]

    autoencoder = NetCAE()
    print(summary(autoencoder, (1, 1, 96, 96, 96), device='cpu', depth=2,
                  col_names=summary_cols, verbose=0))

    # The multi-task model reuses the autoencoder's encoder. Its inputs are an
    # MRI volume and an 11-feature clinical vector.
    model = NetMTL(autoencoder.encoder)
    print(summary(model, ((1, 1, 96, 96, 96), (1, 11)), device='cpu', depth=1,
                  col_names=summary_cols, verbose=0))