In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class down_conv(nn.Module):
    """Encoder stage: two 3x3x3 convolutions, ReLU, then batch normalisation.

    Channels go ``inc -> l -> 2*l``. Both convolutions use the default stride
    and no padding, so each one trims two voxels from every spatial axis.

    NOTE(review): only the second convolution is followed by an activation;
    confirm that leaving ``conv1`` linear is intended.
    """

    def __init__(self, inc, l):
        super().__init__()
        widened = l * 2
        self.conv1 = nn.Conv3d(inc, l, kernel_size=3)
        self.conv2 = nn.Conv3d(l, widened, kernel_size=3)
        self.BN = nn.BatchNorm3d(widened)

    def forward(self, x):
        """Return ``BN(relu(conv2(conv1(x))))``."""
        hidden = self.conv1(x)
        hidden = self.conv2(hidden)
        activated = F.relu(hidden)
        return self.BN(activated)

class up_conv(nn.Module):
    """Decoder stage: two 3x3x3 convolutions, ReLU, then batch normalisation.

    Channels go ``inc -> l -> l``. Both convolutions use the default stride
    and no padding, so each one trims two voxels from every spatial axis.

    NOTE(review): only the second convolution is followed by an activation;
    confirm that leaving ``conv1`` linear is intended.
    """

    def __init__(self, inc, l):
        super().__init__()
        self.conv1 = nn.Conv3d(inc, l, kernel_size=3)
        self.conv2 = nn.Conv3d(l, l, kernel_size=3)
        self.BN = nn.BatchNorm3d(l)

    def forward(self, x):
        """Return ``BN(relu(conv2(conv1(x))))``."""
        hidden = self.conv1(x)
        hidden = self.conv2(hidden)
        activated = F.relu(hidden)
        return self.BN(activated)

class up_sample(nn.Module):
    """Learned 2x upsampling via a stride-2, kernel-2 transposed 3D convolution.

    Maps ``inc`` input channels to ``l`` output channels.
    """

    def __init__(self, inc, l):
        super().__init__()
        self.conv_t = nn.ConvTranspose3d(inc, l, kernel_size=2, stride=2)

    def forward(self, x):
        """Apply the transposed convolution to ``x``."""
        upsampled = self.conv_t(x)
        return upsampled

class conv_final(nn.Module):
    """Output head: 1x1x1 transposed convolution to one channel, then sigmoid.

    Takes ``l`` input channels and returns one channel of values in (0, 1),
    which is the range ``nn.BCELoss`` expects.
    """

    def __init__(self,l):
        super(conv_final, self).__init__()
        # Kept as a transposed convolution (not swapped for Conv3d) so the
        # parameter layout in previously saved state_dicts is unchanged.
        self.conv = nn.ConvTranspose3d(in_channels=l, out_channels=1, kernel_size=(1,1,1))

    def forward(self,x):
        """Return ``sigmoid(conv(x))``."""
        # torch.sigmoid replaces the legacy F.sigmoid spelling, which emits a
        # deprecation warning on many PyTorch releases.
        return torch.sigmoid(self.conv(x))

class pooling(nn.Module):
    """2x2x2 max pooling with stride 2 (halves every spatial axis)."""

    def __init__(self):
        super().__init__()
        window = (2, 2, 2)
        self.pool = nn.MaxPool3d(kernel_size=window, stride=2)

    def forward(self, x):
        """Return the max-pooled version of ``x``."""
        pooled = self.pool(x)
        return pooled

class UNet(nn.Module):
    """3D U-Net with ``depth`` resolution levels and additive skip connections.

    Encoder (``self.encoder``): ``down_conv`` at level 0, then ``pooling`` +
    ``down_conv`` for every further level. Level ``i`` uses ``l = 2**(5+i)``
    and therefore emits ``2**(6+i)`` channels.

    Decoder (``self.decoder``): ``up_sample`` + ``up_conv`` for every level
    above the first, followed by ``conv_final``.

    Args:
        depth: number of resolution levels; must be at least 2.

    Raises:
        ValueError: if ``depth`` is smaller than 2 (there would be nothing to
            upsample and no skip connection to add).
    """

    def __init__(self,depth):
        super(UNet, self).__init__()

        if depth < 2:
            # Previously depth == 1 only printed a hint and left the module
            # without encoder/decoder, and depth < 1 hit an unbound local.
            raise ValueError("try higher depth: UNet needs depth >= 2, got {}".format(depth))

        encoder = []
        for level in range(depth):
            l = 2 ** (5 + level)
            if level == 0:
                inc = 1
            else:
                # The previous down_conv emitted 2 * 2**(4 + level) == l
                # channels, so that is what this level receives.
                inc = l
                encoder.append(pooling())
            encoder.append(down_conv(inc, l))
        self.encoder = nn.ModuleList(encoder)

        decoder = []
        # Channel count leaving the deepest down_conv (2 * its ``l``).
        channels = 2 ** (5 + depth)
        for level in range(depth, 1, -1):
            out_channels = 2 ** (4 + level)
            decoder.append(up_sample(channels, out_channels))
            decoder.append(up_conv(out_channels, out_channels))
            channels = out_channels
        decoder.append(conv_final(channels))
        self.decoder = nn.ModuleList(decoder)

    @staticmethod
    def _center_crop(skip, spatial):
        """Crop the trailing ``len(spatial)`` axes of ``skip`` around their centre."""
        index = [Ellipsis]
        for have, want in zip(skip.shape[-len(spatial):], spatial):
            if have < want:
                raise ValueError(
                    "skip feature map ({}) is smaller than the decoder feature map ({})".format(have, want)
                )
            start = (have - want) // 2
            index.append(slice(start, start + want))
        return skip[tuple(index)]

    def forward(self, x):
        """Run the encoder, then the decoder with additive skip connections.

        Before every ``up_conv`` (odd decoder positions) the matching encoder
        output is centre-cropped to the decoder's spatial size and added.
        """
        y = x
        features = [y]
        for layer in self.encoder:
            y = layer(y)
            features.append(y)

        # Iterate the ModuleList directly; calling it raises.
        for index, layer in enumerate(self.decoder):
            if index % 2 != 0:
                skip = features[len(features) - (index + 2)]
                # Out-of-place add keeps autograd away from in-place edits.
                y = y + self._center_crop(skip, y.shape[-3:])
            y = layer(y)
        return y

In [8]:
# Depth-3 U-Net on the GPU, plain SGD (lr = 0.05) over its parameters, and
# binary cross-entropy as the training criterion.
unet = UNet(3)
unet = unet.cuda()
optimizer = torch.optim.SGD(unet.parameters(), lr=0.05)
loss = nn.BCELoss()
loss = loss.cuda()

In [9]:
def train(epoch,x_train,y_train):
    """Run one optimisation step on a single batch and print its metrics.

    Uses the module-level ``unet``, ``optimizer`` and ``loss`` objects.

    Args:
        epoch: value printed as the epoch label; not used otherwise.
        x_train: input tensor; moved to the model's device and cast to float.
        y_train: target tensor; moved to the model's device and cast to float.
            The criterion is applied to ``(output, y_train)`` directly, so it
            must be shaped like the network output.

    Prints the epoch label, the fraction of elements whose thresholded
    prediction (>= 0.5) matches the thresholded target, and the scalar loss.
    """
    unet.train()

    # Follow wherever the model lives instead of hard-coding .cuda().
    device = next(unet.parameters()).device
    inputs = x_train.to(device).float()
    targets = y_train.to(device).float()

    optimizer.zero_grad()
    out_train = unet(inputs)

    loss_train = loss(out_train, targets)

    loss_train.backward()
    optimizer.step()

    # The network ends in a sigmoid over a single channel, so argmax along the
    # channel axis is always 0; threshold the probabilities instead.
    with torch.no_grad():
        predicted = out_train >= 0.5
        acc = (predicted == (targets >= 0.5)).float().mean().item()

    print('Epoch: ',epoch,'\t','acc: ',acc,'\t','loss: ',loss_train.item())

In [10]:
import os
import pydicom
import numpy as np
def read(f1):
    """Load every file in directory ``f1`` with pydicom into one int16 volume.

    Each directory entry is read with ``pydicom.dcmread``, the datasets are
    ordered by ``int(InstanceNumber)``, and their ``pixel_array`` values are
    stacked along a new leading axis.

    Args:
        f1: path of the directory holding the slice files.

    Returns:
        ``numpy.ndarray`` of dtype ``int16`` with one leading-axis entry per file.

    Raises:
        ValueError: if the directory contains no entries.
    """
    names = os.listdir(f1)
    if not names:
        # Previously this fell through to an unbound local and raised a
        # confusing UnboundLocalError.
        raise ValueError("no DICOM files found in {!r}".format(f1))

    slices = [pydicom.dcmread(os.path.join(f1, name)) for name in names]
    # os.listdir order is arbitrary; InstanceNumber defines the slice order.
    slices.sort(key=lambda ds: int(ds.InstanceNumber))

    return np.stack([ds.pixel_array for ds in slices]).astype(np.int16)