In [1]:
import os
# Expose only physical GPU 9 to this process; it will appear as cuda:0.
os.environ.update({"CUDA_VISIBLE_DEVICES": '9'})

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class in_conv(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size=3,
                 dropout_p=0.3, leakiness=1e-2, conv_bias=True,
                 inst_norm_affine=True, res=False, lrelu_inplace=True):
        """[The initial convolution to enter the network, kind of like encode]

        [Builds conv0 (input_channels -> output_channels) followed by two
        instance-norm / leaky-ReLU / convolution stages that keep
        output_channels, with optional dropout and an optional residual
        connection around the two stages]

        Arguments:
            input_channels {[int]} -- [the input number of channels, in our case
                                       the number of modalities]
            output_channels {[int]} -- [the output number of channels, will
                                        determine the upcoming channels]

        Keyword Arguments:
            kernel_size {int} -- [spatial size of every convolution filter; an
                                  odd size keeps the spatial shape unchanged]
                                  (default: {3})
            dropout_p {float or None} -- [Dropout3d probability applied after
                                          conv1; None or 0 disables dropout]
                                          (default: {0.3})
            leakiness {number} -- [the negative slope of the leaky ReLU]
                                  (default: {1e-2})
            conv_bias {bool} -- [to use the bias in filters] (default: {True})
            inst_norm_affine {bool} -- [affine use in norm] (default: {True})
            res {bool} -- [add the output of conv0 onto the output of conv2]
                          (default: {False})
            lrelu_inplace {bool} -- [apply the leaky ReLU in place on the
                                     norm outputs] (default: {True})
        """
        nn.Module.__init__(self)
        self.residual = res
        self.dropout_p = dropout_p
        self.conv_bias = conv_bias
        self.leakiness = leakiness
        self.inst_norm_affine = inst_norm_affine
        self.lrelu_inplace = lrelu_inplace
        # nn.Dropout3d rejects None, while forward() treats None as "no
        # dropout"; build an inert layer in that case instead of crashing.
        self.dropout = nn.Dropout3d(0.0 if dropout_p is None else dropout_p)
        self.in_0 = nn.InstanceNorm3d(output_channels,
                                      affine=self.inst_norm_affine,
                                      track_running_stats=True)
        self.in_1 = nn.InstanceNorm3d(output_channels,
                                      affine=self.inst_norm_affine,
                                      track_running_stats=True)
        # kernel_size drives both the filter size and the matching "same"
        # padding, so the two can never disagree.
        padding = (kernel_size - 1) // 2
        self.conv0 = nn.Conv3d(input_channels, output_channels,
                               kernel_size=kernel_size, stride=1,
                               padding=padding, bias=self.conv_bias)
        self.conv1 = nn.Conv3d(output_channels, output_channels,
                               kernel_size=kernel_size, stride=1,
                               padding=padding, bias=self.conv_bias)
        self.conv2 = nn.Conv3d(output_channels, output_channels,
                               kernel_size=kernel_size, stride=1,
                               padding=padding, bias=self.conv_bias)

    def forward(self, x):
        """The forward function for initial convolution

        [input --> conv0 --> | --> in --> lrelu --> conv1 --> dropout --> in -|
                             |                                                |
                  output <-- + <-------------------------- conv2 <-- lrelu <--|]

        The "+" is only applied when the module was built with res=True.

        Arguments:
            x {[Tensor]} -- [Takes in a type of torch Tensor]

        Returns:
            [Tensor] -- [Returns a torch Tensor]
        """
        x = self.conv0(x)
        skip = x if self.residual else None
        # in_0 returns a fresh tensor, so the in-place activation cannot
        # clobber `skip`.
        x = F.leaky_relu(self.in_0(x), negative_slope=self.leakiness,
                         inplace=self.lrelu_inplace)
        x = self.conv1(x)
        if self.dropout_p is not None and self.dropout_p > 0:
            x = self.dropout(x)
        x = F.leaky_relu(self.in_1(x), negative_slope=self.leakiness,
                         inplace=self.lrelu_inplace)
        x = self.conv2(x)
        if skip is not None:
            x = x + skip
        return x

class DownsamplingModule(nn.Module):
    def __init__(self, input_channels, output_channels, leakiness=1e-2,
                 dropout_p=0.3, kernel_size=3, conv_bias=True,
                 inst_norm_affine=True, lrelu_inplace=True):
        """[To Downsample a given input with convolution operation]

        [A stride-2 convolution halves every spatial dimension while mapping
        input_channels -> output_channels (usually double), followed by
        instance norm and a leaky ReLU]

        Arguments:
            input_channels {[int]} -- [the input number of channels]
            output_channels {[int]} -- [the output number of channels, usually
                                        double the input]

        Keyword Arguments:
            leakiness {float} -- [the negative leakiness] (default: {1e-2})
            dropout_p {float} -- [stored on the module; not applied by
                                  forward] (default: {0.3})
            kernel_size {int} -- [spatial size of the strided filter]
                                 (default: {3})
            conv_bias {bool} -- [to use the bias in filters] (default: {True})
            inst_norm_affine {bool} -- [affine use in norm] (default: {True})
            lrelu_inplace {bool} -- [apply the leaky ReLU in place on the
                                     norm output] (default: {True})
        """
        super(DownsamplingModule, self).__init__()
        self.dropout_p = dropout_p
        self.conv_bias = conv_bias
        self.leakiness = leakiness
        self.inst_norm_affine = inst_norm_affine
        # Honour the caller's choice; this was previously forced to True.
        self.lrelu_inplace = lrelu_inplace
        self.in_0 = nn.InstanceNorm3d(output_channels,
                                      affine=self.inst_norm_affine,
                                      track_running_stats=True)
        self.conv0 = nn.Conv3d(input_channels, output_channels,
                               kernel_size=kernel_size, stride=2,
                               padding=(kernel_size - 1) // 2,
                               bias=self.conv_bias)

    def forward(self, x):
        """[Downsample x]

        [input --> ConvDS (stride 2) --> in --> lrelu --> output]

        Arguments:
            x {[Tensor]} -- [Takes in a type of torch Tensor]

        Returns:
            [Tensor] -- [Returns a torch Tensor]
        """
        return F.leaky_relu(self.in_0(self.conv0(x)),
                            negative_slope=self.leakiness,
                            inplace=self.lrelu_inplace)

class EncodingModule(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size=3,
                 dropout_p=0.3, leakiness=1e-2, conv_bias=True,
                 inst_norm_affine=True, res=False, lrelu_inplace=True):
        """[The Encoding convolution module to learn the information and use later]

        [Two pre-activation stages (instance norm -> leaky ReLU -> conv),
        with optional dropout after the first convolution and an optional
        residual connection around both stages]

        Arguments:
            input_channels {[int]} -- [the input number of channels, in our case
                                       the number of channels from downsample]
            output_channels {[int]} -- [the output number of channels, will
                                        determine the upcoming channels]

        Keyword Arguments:
            kernel_size {int} -- [spatial size of every convolution filter; an
                                  odd size keeps the spatial shape unchanged]
                                  (default: {3})
            dropout_p {float or None} -- [Dropout3d probability applied after
                                          conv0; None or 0 disables dropout]
                                          (default: {0.3})
            leakiness {number} -- [the negative leakiness] (default: {1e-2})
            conv_bias {bool} -- [to use the bias in filters] (default: {True})
            inst_norm_affine {bool} -- [affine use in norm] (default: {True})
            res {bool} -- [to use residual connections; requires
                           input_channels == output_channels] (default: {False})
            lrelu_inplace {bool} -- [apply the leaky ReLU in place on the
                                     norm outputs] (default: {True})

        Raises:
            ValueError -- [if res is True and the channel counts differ, since
                           the skip connection could not be added]
        """
        nn.Module.__init__(self)
        if res and input_channels != output_channels:
            raise ValueError(
                "residual EncodingModule needs input_channels == "
                "output_channels, got {} and {}".format(input_channels,
                                                        output_channels))
        self.res = res
        self.dropout_p = dropout_p
        self.conv_bias = conv_bias
        self.leakiness = leakiness
        self.inst_norm_affine = inst_norm_affine
        self.lrelu_inplace = lrelu_inplace
        # nn.Dropout3d rejects None, while forward() treats None as "no dropout".
        self.dropout = nn.Dropout3d(0.0 if dropout_p is None else dropout_p)
        # in_0 normalises the raw input, so it must be sized by input_channels
        # (identical to before whenever the two channel counts are equal).
        self.in_0 = nn.InstanceNorm3d(input_channels,
                                      affine=self.inst_norm_affine,
                                      track_running_stats=True)
        self.in_1 = nn.InstanceNorm3d(output_channels,
                                      affine=self.inst_norm_affine,
                                      track_running_stats=True)
        padding = (kernel_size - 1) // 2
        self.conv0 = nn.Conv3d(input_channels, output_channels,
                               kernel_size=kernel_size, stride=1,
                               padding=padding, bias=self.conv_bias)
        self.conv1 = nn.Conv3d(output_channels, output_channels,
                               kernel_size=kernel_size, stride=1,
                               padding=padding, bias=self.conv_bias)

    def forward(self, x):
        """The forward function for the encoding module

        [input --> | --> in --> lrelu --> conv0 --> dropout --> in -|
                   |                                                |
        output <-- + <-------------------------- conv1 <-- lrelu <--|]

        The "+" is only applied when the module was built with res=True.

        Arguments:
            x {[Tensor]} -- [Takes in a type of torch Tensor]

        Returns:
            [Tensor] -- [Returns a torch Tensor]
        """
        skip = x if self.res else None
        # in_0 returns a fresh tensor, so the in-place activation leaves
        # `skip` (the caller's input) untouched.
        x = F.leaky_relu(self.in_0(x), negative_slope=self.leakiness,
                         inplace=self.lrelu_inplace)
        x = self.conv0(x)
        if self.dropout_p is not None and self.dropout_p > 0:
            x = self.dropout(x)
        x = F.leaky_relu(self.in_1(x), negative_slope=self.leakiness,
                         inplace=self.lrelu_inplace)
        x = self.conv1(x)
        if skip is not None:
            x = x + skip
        return x

class Interpolate(nn.Module):
    """Module wrapper around ``nn.functional.interpolate``.

    Lets a fixed resize (by ``size`` or ``scale_factor``) be registered as a
    submodule. ``align_corners`` is forwarded only for the interpolation modes
    that accept it; for other modes (e.g. the default 'nearest') it is passed
    as None, because F.interpolate raises ValueError otherwise.
    """

    # Modes for which F.interpolate accepts a non-None align_corners.
    _ALIGNABLE_MODES = ('linear', 'bilinear', 'bicubic', 'trilinear')

    def __init__(self, size=None, scale_factor=None, mode='nearest',
                 align_corners=True):
        super(Interpolate, self).__init__()
        self.align_corners = align_corners
        self.mode = mode
        self.scale_factor = scale_factor
        self.size = size

    def forward(self, x):
        """Resize x with the configured size/scale_factor and mode."""
        if self.mode in self._ALIGNABLE_MODES:
            align_corners = self.align_corners
        else:
            align_corners = None
        return nn.functional.interpolate(x, size=self.size,
                                         scale_factor=self.scale_factor,
                                         mode=self.mode,
                                         align_corners=align_corners)
    
class UpsamplingModule(nn.Module):
    def __init__(self, input_channels, output_channels, leakiness=1e-2,
                 lrelu_inplace=True, kernel_size=3, scale_factor=2,
                 conv_bias=True, inst_norm_affine=True):
        """[Enlarge a feature map, then project it to a new channel count]

        [Trilinear interpolation (align_corners=True) by scale_factor is
        followed by a 1x1x1 convolution mapping input_channels ->
        output_channels]

        Arguments:
            input_channels {int} -- [channels of the incoming feature map]
            output_channels {int} -- [channels produced by the projection]

        Keyword Arguments:
            leakiness {number} -- [stored on the module; not used by forward]
                                  (default: {1e-2})
            lrelu_inplace {bool} -- [stored on the module; not used by
                                     forward] (default: {True})
            kernel_size {number} -- [accepted for signature compatibility; the
                                     projection is always 1x1x1] (default: {3})
            scale_factor {number} -- [spatial upsampling factor] (default: {2})
            conv_bias {bool} -- [whether the projection has a bias]
                                (default: {True})
            inst_norm_affine {bool} -- [stored on the module; not used by
                                        forward] (default: {True})
        """
        super(UpsamplingModule, self).__init__()
        self.lrelu_inplace = lrelu_inplace
        self.inst_norm_affine = inst_norm_affine
        self.conv_bias = conv_bias
        self.leakiness = leakiness
        self.scale_factor = scale_factor
        self.interpolate = Interpolate(scale_factor=scale_factor,
                                       mode='trilinear', align_corners=True)
        self.conv0 = nn.Conv3d(input_channels, output_channels,
                               kernel_size=1, stride=1, padding=0,
                               bias=conv_bias)

    def forward(self, x):
        """[Return conv0(interpolate(x))]"""
        enlarged = self.interpolate(x)
        return self.conv0(enlarged)

class FCNUpsamplingModule(nn.Module):
    def __init__(self, input_channels, output_channels, leakiness=1e-2,
                 lrelu_inplace=True, kernel_size=3, scale_factor=2,
                 conv_bias=True, inst_norm_affine=True):
        """[Project a feature map with a 1x1x1 convolution, then upsample it]

        [Unlike UpsamplingModule the projection comes first, and the
        trilinear interpolation factor is 2 ** (scale_factor - 1): a
        scale_factor of 1 keeps the resolution, 2 doubles it, 3 quadruples
        it, and so on]

        Arguments:
            input_channels {int} -- [channels of the incoming feature map]
            output_channels {int} -- [channels produced by the projection]

        Keyword Arguments:
            leakiness {number} -- [stored on the module; not used by forward]
                                  (default: {1e-2})
            lrelu_inplace {bool} -- [stored on the module; not used by
                                     forward] (default: {True})
            kernel_size {number} -- [accepted for signature compatibility; the
                                     projection is always 1x1x1] (default: {3})
            scale_factor {number} -- [level index; the spatial factor is
                                      2 ** (scale_factor - 1)] (default: {2})
            conv_bias {bool} -- [whether the projection has a bias]
                                (default: {True})
            inst_norm_affine {bool} -- [stored on the module; not used by
                                        forward] (default: {True})
        """
        super(FCNUpsamplingModule, self).__init__()
        self.lrelu_inplace = lrelu_inplace
        self.inst_norm_affine = inst_norm_affine
        self.conv_bias = conv_bias
        self.leakiness = leakiness
        self.scale_factor = scale_factor
        upsample_by = 2 ** (scale_factor - 1)
        self.interpolate = Interpolate(scale_factor=upsample_by,
                                       mode='trilinear', align_corners=True)
        self.conv0 = nn.Conv3d(input_channels, output_channels,
                               kernel_size=1, stride=1, padding=0,
                               bias=conv_bias)

    def forward(self, x):
        """[Return interpolate(conv0(x))]"""
        projected = self.conv0(x)
        return self.interpolate(projected)


class DecodingModule(nn.Module):
    def __init__(self, input_channels, output_channels, leakiness=1e-2, conv_bias=True, kernel_size=3,
        inst_norm_affine=True, res=True, lrelu_inplace=True):
        """[The Decoding convolution module that fuses upsampled and skip features]

        [forward() concatenates its two inputs along the channel axis, then
        applies norm -> lrelu -> conv0 (input_channels -> output_channels)
        and two more norm/lrelu/conv stages, with an optional residual
        connection from the output of conv0 to the output of conv2]

        Arguments:
            input_channels {[int]} -- [total channels of the concatenated
                                       inputs x1 and x2]
            output_channels {[int]} -- [the output number of channels, will
                                        determine the upcoming channels]

        Keyword Arguments:
            kernel_size {int} -- [spatial size of every convolution filter; an
                                  odd size keeps the spatial shape unchanged]
                                  (default: {3})
            leakiness {number} -- [the negative slope of the leaky ReLU]
                                  (default: {1e-2})
            conv_bias {bool} -- [to use the bias in filters] (default: {True})
            inst_norm_affine {bool} -- [affine use in norm] (default: {True})
            res {bool} -- [to use residual connections] (default: {True})
            lrelu_inplace {bool} -- [apply the leaky ReLU in place on the
                                     norm outputs] (default: {True})
        """
        nn.Module.__init__(self)
        self.lrelu_inplace = lrelu_inplace
        self.inst_norm_affine = inst_norm_affine
        self.conv_bias = conv_bias
        self.leakiness = leakiness
        self.res = res
        self.in_0 = nn.InstanceNorm3d(input_channels,
                                      affine=self.inst_norm_affine,
                                      track_running_stats=True)
        self.in_1 = nn.InstanceNorm3d(output_channels,
                                      affine=self.inst_norm_affine,
                                      track_running_stats=True)
        self.in_2 = nn.InstanceNorm3d(output_channels,
                                      affine=self.inst_norm_affine,
                                      track_running_stats=True)
        padding = (kernel_size - 1) // 2
        self.conv0 = nn.Conv3d(input_channels, output_channels,
                               kernel_size=kernel_size, stride=1,
                               padding=padding, bias=self.conv_bias)
        self.conv1 = nn.Conv3d(output_channels, output_channels,
                               kernel_size=kernel_size, stride=1,
                               padding=padding, bias=self.conv_bias)
        self.conv2 = nn.Conv3d(output_channels, output_channels,
                               kernel_size=kernel_size, stride=1,
                               padding=padding, bias=self.conv_bias)

    def _act(self, x):
        """Leaky ReLU using the configured slope and in-place setting."""
        return F.leaky_relu(x, negative_slope=self.leakiness,
                            inplace=self.lrelu_inplace)

    def forward(self, x1, x2):
        """Fuse x1 and x2 and decode them.

        [cat(x1, x2) --> in --> lrelu --> conv0 --> | --> in --> lrelu --> conv1 -|
                                                    |                             |
                                         output <-- + <-- conv2 <-- lrelu <-- in <-|]

        The "+" is only applied when the module was built with res=True.
        """
        x = torch.cat([x1, x2], dim=1)
        x = self.conv0(self._act(self.in_0(x)))
        skip = x if self.res else None
        # Each norm returns a fresh tensor, so in-place activations never
        # modify `skip`.
        x = self._act(self.in_1(x))
        x = self._act(self.in_2(self.conv1(x)))
        x = self.conv2(x)
        if skip is not None:
            x = x + skip
        return x

class out_conv(nn.Module):
    def __init__(self, input_channels, output_channels, leakiness=1e-2, kernel_size=3,
        conv_bias=True, inst_norm_affine=True, res=True, lrelu_inplace=True):
        """[The output convolution module that produces per-voxel probabilities]

        [forward() concatenates its two inputs along the channel axis, halves
        the channel count with conv0, applies two more norm/lrelu/conv stages
        (optionally residual), then a final norm/lrelu and a 1x1x1
        convolution to output_channels followed by a sigmoid]

        Arguments:
            input_channels {[int]} -- [total channels of the concatenated
                                       inputs x1 and x2]
            output_channels {[int]} -- [number of sigmoid output channels]

        Keyword Arguments:
            kernel_size {int} -- [spatial size of conv0..conv2; the final
                                  projection is always 1x1x1] (default: {3})
            leakiness {number} -- [the negative slope of the leaky ReLU]
                                  (default: {1e-2})
            conv_bias {bool} -- [to use the bias in filters] (default: {True})
            inst_norm_affine {bool} -- [affine use in norm] (default: {True})
            res {bool} -- [to use residual connections] (default: {True})
            lrelu_inplace {bool} -- [apply the leaky ReLU in place on the
                                     norm outputs] (default: {True})
        """
        nn.Module.__init__(self)
        self.lrelu_inplace = lrelu_inplace
        self.inst_norm_affine = inst_norm_affine
        self.conv_bias = conv_bias
        self.leakiness = leakiness
        self.res = res
        half = input_channels // 2
        self.in_0 = nn.InstanceNorm3d(input_channels,
                                      affine=self.inst_norm_affine,
                                      track_running_stats=True)
        self.in_1 = nn.InstanceNorm3d(half,
                                      affine=self.inst_norm_affine,
                                      track_running_stats=True)
        self.in_2 = nn.InstanceNorm3d(half,
                                      affine=self.inst_norm_affine,
                                      track_running_stats=True)
        self.in_3 = nn.InstanceNorm3d(half,
                                      affine=self.inst_norm_affine,
                                      track_running_stats=True)
        padding = (kernel_size - 1) // 2
        self.conv0 = nn.Conv3d(input_channels, half, kernel_size=kernel_size,
                               stride=1, padding=padding, bias=self.conv_bias)
        self.conv1 = nn.Conv3d(half, half, kernel_size=kernel_size,
                               stride=1, padding=padding, bias=self.conv_bias)
        self.conv2 = nn.Conv3d(half, half, kernel_size=kernel_size,
                               stride=1, padding=padding, bias=self.conv_bias)
        self.conv3 = nn.Conv3d(half, output_channels, kernel_size=1,
                               stride=1, padding=0, bias=self.conv_bias)

    def _act(self, x):
        """Leaky ReLU using the configured slope and in-place setting."""
        return F.leaky_relu(x, negative_slope=self.leakiness,
                            inplace=self.lrelu_inplace)

    def forward(self, x1, x2):
        """Fuse x1 and x2 and map them to sigmoid probabilities."""
        x = torch.cat([x1, x2], dim=1)
        x = self.conv0(self._act(self.in_0(x)))
        skip = x if self.res else None
        # Each norm returns a fresh tensor, so in-place activations never
        # modify `skip`.
        x = self._act(self.in_1(x))
        x = self._act(self.in_2(self.conv1(x)))
        x = self.conv2(x)
        if skip is not None:
            x = x + skip
        x = self._act(self.in_3(x))
        # torch.sigmoid replaces the deprecated F.sigmoid (same values).
        return torch.sigmoid(self.conv3(x))


In [3]:
class unet_light(nn.Module):
    """Four-level 3-D U-Net with channel widths 8-16-32-64-128.

    The head (out_conv) emits n_classes - 1 sigmoid channels. Decoder and
    head residual behaviour follows the DecodingModule / out_conv defaults.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        c0, c1, c2, c3, c4 = 8, 16, 32, 64, 128  # channel width per level
        self.ins = in_conv(n_channels, c0)
        self.ds_0 = DownsamplingModule(c0, c1)
        self.en_1 = EncodingModule(c1, c1)
        self.ds_1 = DownsamplingModule(c1, c2)
        self.en_2 = EncodingModule(c2, c2)
        self.ds_2 = DownsamplingModule(c2, c3)
        self.en_3 = EncodingModule(c3, c3)
        self.ds_3 = DownsamplingModule(c3, c4)
        self.en_4 = EncodingModule(c4, c4)
        # Decoder inputs are the upsampled map concatenated with the skip,
        # hence twice the level width.
        self.us_3 = UpsamplingModule(c4, c3)
        self.de_3 = DecodingModule(2 * c3, c3)
        self.us_2 = UpsamplingModule(c3, c2)
        self.de_2 = DecodingModule(2 * c2, c2)
        self.us_1 = UpsamplingModule(c2, c1)
        self.de_1 = DecodingModule(2 * c1, c1)
        self.us_0 = UpsamplingModule(c1, c0)
        self.out = out_conv(2 * c0, n_classes - 1)

    def forward(self, x):
        # Encoder: keep every level's output for the skip connections.
        skip_0 = self.ins(x)
        skip_1 = self.en_1(self.ds_0(skip_0))
        skip_2 = self.en_2(self.ds_1(skip_1))
        skip_3 = self.en_3(self.ds_2(skip_2))
        bottom = self.en_4(self.ds_3(skip_3))

        # Decoder: upsample, then fuse with the matching encoder features.
        y = self.de_3(self.us_3(bottom), skip_3)
        y = self.de_2(self.us_2(y), skip_2)
        y = self.de_1(self.us_1(y), skip_1)
        return self.out(self.us_0(y), skip_0)

class unet_crisp(nn.Module):
    """Four-level 3-D U-Net with channel widths 4-8-16-32-64.

    The head (out_conv) emits n_classes - 1 sigmoid channels. Decoder and
    head residual behaviour follows the DecodingModule / out_conv defaults.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        c0, c1, c2, c3, c4 = 4, 8, 16, 32, 64  # channel width per level
        self.ins = in_conv(n_channels, c0)
        self.ds_0 = DownsamplingModule(c0, c1)
        self.en_1 = EncodingModule(c1, c1)
        self.ds_1 = DownsamplingModule(c1, c2)
        self.en_2 = EncodingModule(c2, c2)
        self.ds_2 = DownsamplingModule(c2, c3)
        self.en_3 = EncodingModule(c3, c3)
        self.ds_3 = DownsamplingModule(c3, c4)
        self.en_4 = EncodingModule(c4, c4)
        # Decoder inputs are the upsampled map concatenated with the skip,
        # hence twice the level width.
        self.us_3 = UpsamplingModule(c4, c3)
        self.de_3 = DecodingModule(2 * c3, c3)
        self.us_2 = UpsamplingModule(c3, c2)
        self.de_2 = DecodingModule(2 * c2, c2)
        self.us_1 = UpsamplingModule(c2, c1)
        self.de_1 = DecodingModule(2 * c1, c1)
        self.us_0 = UpsamplingModule(c1, c0)
        self.out = out_conv(2 * c0, n_classes - 1)

    def forward(self, x):
        # Encoder: keep every level's output for the skip connections.
        skip_0 = self.ins(x)
        skip_1 = self.en_1(self.ds_0(skip_0))
        skip_2 = self.en_2(self.ds_1(skip_1))
        skip_3 = self.en_3(self.ds_2(skip_2))
        bottom = self.en_4(self.ds_3(skip_3))

        # Decoder: upsample, then fuse with the matching encoder features.
        y = self.de_3(self.us_3(bottom), skip_3)
        y = self.de_2(self.us_2(y), skip_2)
        y = self.de_1(self.us_1(y), skip_1)
        return self.out(self.us_0(y), skip_0)
    
class unet(nn.Module):
    """Four-level 3-D U-Net with channel widths 16-32-64-128-256.

    The head (out_conv) emits n_classes - 1 sigmoid channels. Encoder blocks
    are non-residual; decoder and head residual behaviour follows the
    DecodingModule / out_conv defaults.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        c0, c1, c2, c3, c4 = 16, 32, 64, 128, 256  # channel width per level
        self.ins = in_conv(n_channels, c0)
        self.ds_0 = DownsamplingModule(c0, c1)
        self.en_1 = EncodingModule(c1, c1)
        self.ds_1 = DownsamplingModule(c1, c2)
        self.en_2 = EncodingModule(c2, c2)
        self.ds_2 = DownsamplingModule(c2, c3)
        self.en_3 = EncodingModule(c3, c3)
        self.ds_3 = DownsamplingModule(c3, c4)
        self.en_4 = EncodingModule(c4, c4)
        # Decoder inputs are the upsampled map concatenated with the skip,
        # hence twice the level width.
        self.us_3 = UpsamplingModule(c4, c3)
        self.de_3 = DecodingModule(2 * c3, c3)
        self.us_2 = UpsamplingModule(c3, c2)
        self.de_2 = DecodingModule(2 * c2, c2)
        self.us_1 = UpsamplingModule(c2, c1)
        self.de_1 = DecodingModule(2 * c1, c1)
        self.us_0 = UpsamplingModule(c1, c0)
        self.out = out_conv(2 * c0, n_classes - 1)

    def forward(self, x):
        # Encoder: keep every level's output for the skip connections.
        skip_0 = self.ins(x)
        skip_1 = self.en_1(self.ds_0(skip_0))
        skip_2 = self.en_2(self.ds_1(skip_1))
        skip_3 = self.en_3(self.ds_2(skip_2))
        bottom = self.en_4(self.ds_3(skip_3))

        # Decoder: upsample, then fuse with the matching encoder features.
        y = self.de_3(self.us_3(bottom), skip_3)
        y = self.de_2(self.us_2(y), skip_2)
        y = self.de_1(self.us_1(y), skip_1)
        return self.out(self.us_0(y), skip_0)

class resunet(nn.Module):
    """Residual four-level 3-D U-Net with channel widths 16-32-64-128-256.

    Every in_conv, EncodingModule, DecodingModule and out_conv is built with
    ``res=True``. The head emits n_classes - 1 sigmoid channels.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        c0, c1, c2, c3, c4 = 16, 32, 64, 128, 256  # channel width per level
        self.ins = in_conv(n_channels, c0, res=True)
        self.ds_0 = DownsamplingModule(c0, c1)
        self.en_1 = EncodingModule(c1, c1, res=True)
        self.ds_1 = DownsamplingModule(c1, c2)
        self.en_2 = EncodingModule(c2, c2, res=True)
        self.ds_2 = DownsamplingModule(c2, c3)
        self.en_3 = EncodingModule(c3, c3, res=True)
        self.ds_3 = DownsamplingModule(c3, c4)
        self.en_4 = EncodingModule(c4, c4, res=True)
        # Decoder inputs are the upsampled map concatenated with the skip,
        # hence twice the level width.
        self.us_3 = UpsamplingModule(c4, c3)
        self.de_3 = DecodingModule(2 * c3, c3, res=True)
        self.us_2 = UpsamplingModule(c3, c2)
        self.de_2 = DecodingModule(2 * c2, c2, res=True)
        self.us_1 = UpsamplingModule(c2, c1)
        self.de_1 = DecodingModule(2 * c1, c1, res=True)
        self.us_0 = UpsamplingModule(c1, c0)
        self.out = out_conv(2 * c0, n_classes - 1, res=True)

    def forward(self, x):
        # Encoder: keep every level's output for the skip connections.
        skip_0 = self.ins(x)
        skip_1 = self.en_1(self.ds_0(skip_0))
        skip_2 = self.en_2(self.ds_1(skip_1))
        skip_3 = self.en_3(self.ds_2(skip_2))
        bottom = self.en_4(self.ds_3(skip_3))

        # Decoder: upsample, then fuse with the matching encoder features.
        y = self.de_3(self.us_3(bottom), skip_3)
        y = self.de_2(self.us_2(y), skip_2)
        y = self.de_1(self.us_1(y), skip_1)
        return self.out(self.us_0(y), skip_0)

class fcn(nn.Module):
    """Fully convolutional network with multi-scale score fusion.

    The same 16-32-64-128-256 encoder as the U-Nets is used. Instead of a
    decoder, each of the five encoder levels is projected to a single channel
    and upsampled by 2 ** (level - 1) with FCNUpsamplingModule. The five maps
    are concatenated, fused by a 1x1x1 convolution into n_classes - 1
    channels, and passed through a sigmoid. The spatial input size is
    presumably expected to be divisible by 16 so that the upsampled maps
    align -- TODO confirm.
    """

    def __init__(self, n_channels, n_classes):
        super(fcn, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.ins = in_conv(self.n_channels, 16)
        self.ds_0 = DownsamplingModule(16, 32)
        self.en_1 = EncodingModule(32, 32)
        self.ds_1 = DownsamplingModule(32, 64)
        self.en_2 = EncodingModule(64, 64)
        self.ds_2 = DownsamplingModule(64, 128)
        self.en_3 = EncodingModule(128, 128)
        self.ds_3 = DownsamplingModule(128, 256)
        self.en_4 = EncodingModule(256, 256)
        # scale_factor is the level index: the map is upsampled by 2 ** (level - 1).
        self.us_4 = FCNUpsamplingModule(256, 1, scale_factor=5)
        self.us_3 = FCNUpsamplingModule(128, 1, scale_factor=4)
        self.us_2 = FCNUpsamplingModule(64, 1, scale_factor=3)
        self.us_1 = FCNUpsamplingModule(32, 1, scale_factor=2)
        self.us_0 = FCNUpsamplingModule(16, 1, scale_factor=1)
        # in_channels=5: one single-channel score map per encoder level.
        self.conv_0 = nn.Conv3d(in_channels=5, out_channels=self.n_classes - 1,
                                kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x):
        x1 = self.ins(x)
        x2 = self.en_1(self.ds_0(x1))
        x3 = self.en_2(self.ds_1(x2))
        x4 = self.en_3(self.ds_2(x3))
        x5 = self.en_4(self.ds_3(x4))

        u5 = self.us_4(x5)
        u4 = self.us_3(x4)
        u3 = self.us_2(x3)
        u2 = self.us_1(x2)
        u1 = self.us_0(x1)
        x = torch.cat([u5, u4, u3, u2, u1], dim=1)
        x = self.conv_0(x)
        # torch.sigmoid replaces the deprecated F.sigmoid (same values).
        return torch.sigmoid(x)

In [4]:
# Single input channel; the sigmoid head emits n_classes - 1 = 1 channel.
net = unet(n_channels=1, n_classes=2)

In [5]:
import numpy as np
import torch.utils.data
from tfedlrn.datasets import load_dataset

def reshape_for_3d(x, n_slices=155, start=14, stop=142):
    """Regroup a stack of 2-D slices into single-channel volumes and crop depth.

    ``x`` is reshaped to ``(x.shape[0] // n_slices, 1, n_slices, 128, 128)``,
    so every ``n_slices`` consecutive entries along axis 0 form one volume.
    The depth axis is then cropped to ``[start, stop)``; with the defaults the
    cropped depth is 128.

    Arguments:
        x -- array with a ``reshape`` method whose first axis counts slices and
             whose size is x.shape[0] * 128 * 128 (presumably channels-first
             slices of shape (N * n_slices, 1, 128, 128) -- TODO confirm).

    Keyword Arguments:
        n_slices {int} -- slices per volume (default: {155})
        start {int} -- first kept depth index (default: {14})
        stop {int} -- one past the last kept depth index (default: {142})

    Returns:
        A view/array of shape (N, 1, len(range(start, stop)), 128, 128).

    Raises:
        ValueError -- if x.shape[0] is not a multiple of n_slices.
    """
    n_volumes, leftover = divmod(x.shape[0], n_slices)
    if leftover:
        raise ValueError(
            "expected a multiple of {} slices along axis 0, got {}".format(
                n_slices, x.shape[0]))
    volumes = x.reshape(n_volumes, 1, n_slices, 128, 128)
    return volumes[:, :, start:stop, :, :]

def create_loader(X, y, **kwargs):
    """Wrap paired samples in a DataLoader.

    Every element of ``X`` and ``y`` is converted with ``torch.Tensor`` and
    the results are stacked along a new leading axis; ``kwargs`` are passed
    straight to ``torch.utils.data.DataLoader``.
    """
    def _stack(samples):
        return torch.stack(list(map(torch.Tensor, samples)))

    dataset = torch.utils.data.TensorDataset(_stack(X), _stack(y))
    return torch.utils.data.DataLoader(dataset, **kwargs)

def init_data_pipeline(batch_size=1, n_institutions=10):
    """Load and pool the BraTS17 institutions into train/validation loaders.

    Each ``load_dataset('BraTS17_institution', channels_first=True,
    institution=i)`` result is unpacked as (X_train, y_train, X_val, y_val).
    Same-kind arrays from all institutions are concatenated, regrouped into
    volumes by reshape_for_3d, and wrapped by create_loader.

    Keyword Arguments:
        batch_size {int} -- batch size of both loaders (default: {1})
        n_institutions {int} -- institutions 0 .. n_institutions - 1 are
                                loaded (default: {10})

    Returns:
        (train_loader, val_loader) -- the training loader reshuffles every
        epoch; the validation loader iterates in a fixed order.
    """
    data_by_institution = [load_dataset('BraTS17_institution',
                                        channels_first=True,
                                        institution=i)
                           for i in range(n_institutions)]
    # Regroup per-institution tuples into per-kind tuples, then pool each kind.
    data_by_type = [np.concatenate(d) for d in zip(*data_by_institution)]
    X_train, y_train, X_val, y_val = [reshape_for_3d(d) for d in data_by_type]
    train_loader = create_loader(X_train, y_train, batch_size=batch_size,
                                 shuffle=True)
    # Validation needs no shuffling; a fixed order keeps evaluation passes
    # reproducible.
    val_loader = create_loader(X_val, y_val, batch_size=batch_size,
                               shuffle=False)
    return train_loader, val_loader

# Two volumes per batch for both the training and validation loaders.
(train_loader, val_loader) = init_data_pipeline(2)

In [6]:
# from torchviz import make_dot
# import torch.utils.data
# device = torch.device('cpu')
# data = torch.Tensor(np.random.random(size=155*128*128).reshape(1, 1, 128, 128, 128))
# data = data.to(device)
# net = net.to(device)
# output = net(data)
# make_dot(output, params=dict(net.named_parameters()))

In [13]:
def dice_coef(pred, target, smoothing=1.0, dim=(1,2,3,4)):
    """Return the smoothed soft Dice coefficient averaged over the batch.

    For each slice along the axes not in ``dim`` (by default every axis but
    the first of a 5-D tensor), computes
    (2 * sum(pred * target) + smoothing) / (sum(pred + target) + smoothing)
    and then takes the mean of those values.
    """
    overlap = torch.sum(pred * target, dim=dim)
    total = torch.sum(pred + target, dim=dim)
    per_sample = (2 * overlap + smoothing) / (total + smoothing)
    return per_sample.mean()


def dice_coef_loss(pred, target, smoothing=1.0, dim=(1,2,3,4)):
    """Return the batch-mean negative log of the smoothed soft Dice score.

    Computed as mean(log(sum(pred + target) + smoothing))
    - mean(log(2 * sum(pred * target) + smoothing)), with sums over ``dim``.
    This is zero for a perfect match and grows as the overlap shrinks.
    """
    overlap = torch.sum(pred * target, dim=dim)
    total = torch.sum(pred + target, dim=dim)
    log_total = torch.log(total + smoothing)
    log_overlap = torch.log(2 * overlap + smoothing)
    return log_total.mean() - log_overlap.mean()


In [8]:
import torch.optim as optim

def init_optimizer(net, optimizer='RMSprop', lr=1e-4, momentum=0.9):
    """Build a torch optimizer over ``net.parameters()`` selected by name.

    Arguments:
        net -- module whose parameters are optimised.

    Keyword Arguments:
        optimizer {str} -- one of 'SGD', 'RMSprop', 'Adam' (default: {'RMSprop'})
        lr {float} -- learning rate (default: {1e-4})
        momentum {float} -- momentum for SGD and RMSprop; ignored by Adam
                            (default: {0.9})

    Returns:
        The constructed torch.optim optimizer.

    Raises:
        ValueError -- if ``optimizer`` names an unsupported optimizer.
    """
    if optimizer == 'SGD':
        return optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    if optimizer == 'RMSprop':
        return optim.RMSprop(net.parameters(), lr=lr, momentum=momentum)
    if optimizer == 'Adam':
        return optim.Adam(net.parameters(), lr=lr)
    raise ValueError(
        "unknown optimizer {!r}; expected 'SGD', 'RMSprop' or 'Adam'".format(
            optimizer))

def train_epoch(net, train_loader, device, optimizer):
    """Run one optimisation pass over ``train_loader``; return the mean loss.

    Each batch is moved to ``device``, scored with dice_coef_loss
    (smoothing=4.0), and used for one ``optimizer`` step. The per-batch
    losses are averaged with np.mean.
    """
    net.train()  # training-mode behaviour for dropout / norm layers
    batch_losses = []
    for data, target in train_loader:
        inputs = data.to(device)
        labels = target.to(device)
        optimizer.zero_grad()
        loss = dice_coef_loss(net(inputs), labels, smoothing=4.0)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.detach().cpu().numpy())
    return np.mean(batch_losses)

def validate(net, val_loader, device):
    """Return the sample-weighted mean Dice coefficient of ``net`` on ``val_loader``.

    Each batch's dice_coef (a batch mean) is weighted by the batch size, so
    uneven final batches are handled correctly. An empty loader raises
    ZeroDivisionError.
    """
    net.eval()
    weighted_sum = 0
    n_seen = 0
    with torch.no_grad():
        for data, target in val_loader:
            batch_size = target.shape[0]
            n_seen += batch_size
            inputs = data.to(device)
            labels = target.to(device)
            batch_dice = dice_coef(net(inputs), labels)
            weighted_sum += batch_dice.cpu().numpy() * batch_size
    return weighted_sum / n_seen

In [9]:
# Alternatives tried: Adam with lr=1e-4, and SGD with lr=1e-3 / momentum=0.9.
optimizer = init_optimizer(net, 'RMSprop', lr=1e-5, momentum=0.9)

In [10]:
# NOTE(review): the optimizer was built before this move. Module.to keeps the
# same Parameter objects, but PyTorch recommends moving a model to the GPU
# before constructing its optimizer.
device = torch.device(type='cuda')
net = net.to(device)

In [11]:
# 16 epochs, reporting the training loss and then the validation Dice each epoch.
for epoch in range(16):
    train_loss = train_epoch(net, train_loader, device, optimizer)
    print('loss for epoch', epoch, ':', train_loss)
    val_dice = validate(net, val_loader, device)
    print('dice for epoch', epoch, ':', val_dice)



loss for epoch 0 : 4.5720387
dice for epoch 0 : 0.10173853304651048
loss for epoch 1 : 4.463547
dice for epoch 1 : 0.10610134568479326
loss for epoch 2 : 4.399174
dice for epoch 2 : 0.10362325575616625


KeyboardInterrupt: 

In [16]:
# Sanity check on the first training batch: the Dice of a binary mask with
# itself should print 1.
for data, target in train_loader:
    data = data.to(device)
    target = target.to(device)
    print(dice_coef(target, target))
    break

tensor(1., device='cuda:0')


In [None]:
# The number of tumor pixels per slice follows a roughly bell-shaped curve, starting around slice 30 and ending around slice 138.