In [1]:
import numpy as np
import pandas as pd

from skimage.io import imread
from skimage.transform import resize

import pydicom
from pydicom.data import get_testdata_files

import os

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision.transforms as transforms
from torch.utils.data import Dataset

In [3]:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [4]:
class brains(Dataset):
    
    def __init__(self, folder='../scansmasks/', crop=True, crop_size=(256,256)):
        self.folder = folder
        self.patients = os.listdir(folder)
        self.max_dim = 24
        
        self.crop = crop
        self.crop_size = crop_size
        
        if self.crop:
            self.i, self.j = 0, 0
        
    def __len__(self):
        return len(self.patients)
    
    def pad(self, x): #never used
        dim_differs = self.max_dim - x.shape[0]
        zeros = torch.zeros_like(x)[:dim_differs,...]
        x = torch.cat([x, zeros])
        return x
    
    def get_random_crop_index(self, x):
        size=self.crop_size
        x_shape = x.shape[1:]
        i = np.random.randint(low=0, high=x_shape[0]-size[0])
        j = np.random.randint(low=0, high=x_shape[1]-size[1])
        return i, j
    
    def do_crop(self, x, index=(0,0)):
        size=self.crop_size
        return x[:,index[0]:index[0]+size[0], index[1]:index[1]+size[1]]
    
    def transform_scan(self, x):
        x = torch.FloatTensor(np.array(x, dtype=np.float64))
        x = self.pad(x)
        return x
    
    def transform_mask(self, x):
        x = torch.FloatTensor(np.array(x, dtype=np.float64))
        x = self.pad(x)
        return x
    
    def __getitem__(self, i):
        
        scan = np.load(self.folder+self.patients[i]+'/' + self.patients[i] + '_scan.npy')
        mask = np.load(self.folder+self.patients[i]+'/' + self.patients[i] + '_mask.npy')
        
        scan = self.transform_scan(scan)
        mask = self.transform_mask(mask)
        
        if self.crop:
            scan = self.do_crop(scan, index=(self.i,self.j))
            mask = self.do_crop(mask, index=(self.i,self.j))
        
        return scan[None,...], mask[None,...]

In [5]:
data = brains(crop_size=(64,64))
datait = torch.utils.data.DataLoader(data, batch_size=1, shuffle=True, drop_last=False)

In [6]:
a,b = next(iter(datait))

In [7]:
class Conv2D1(nn.Module):
    def __init__(self, in_slices=24, in_channels=1, out_channels=1):
        super(Conv2D1, self).__init__()
        self.in_slices = in_slices
        self.in_channels = in_channels
        
        self.conv_set = [nn.Conv2d(in_channels=in_channels,
                             out_channels=out_channels,
                             kernel_size=(4,4), 
                             stride=2, padding=1).cuda() for _ in range(self.in_slices)]
                
    def forward(self, x):
        x_slice = x[:,:,0,...][:,None,...].squeeze(1)
        
        
        downsampled = self.conv_set[0](x_slice).unsqueeze(2)
        
        for i in range(1,self.in_slices):
            x_slice = x[:,:,i,...][:,None,...].squeeze(1)
            downsampled_curr = self.conv_set[i](x_slice).unsqueeze(2)
            downsampled = torch.cat([downsampled, downsampled_curr],2)
            
        return downsampled

In [8]:
class ConvTranspose2D1(nn.Module):
    def __init__(self, in_slices=24, in_channels=1, out_channels=1):
        super(ConvTranspose2D1, self).__init__()
        self.in_chanels = in_slices
        
        convs2d = [nn.ConvTranspose2d(in_channels=in_channels,
                             out_channels=out_channels,
                             kernel_size=(4,4), 
                             stride=2, padding=1).cuda() for _ in range(self.in_chanels)]
        
        self.conv_set = convs2d
        
    def forward(self, x):
        
        x_slice = x[:,:,0,...][:,None,...].squeeze(1)
        downsampled = self.conv_set[0](x_slice).unsqueeze(2)
        
        for i in range(1,self.in_chanels):
            x_slice = x[:,:,i,...][:,None,...].squeeze(1)
            downsampled_curr = self.conv_set[i](x_slice).unsqueeze(2)
            downsampled = torch.cat([downsampled, downsampled_curr],2)
        return downsampled

In [9]:
class downstream_block(nn.Module):
    def __init__(self, in_chan, out_chan, kernel_dim=3, padding=1, dropout=False):
        super(downstream_block, self).__init__()
        
        block = [
            nn.Conv3d(in_channels=in_chan, out_channels=out_chan, kernel_size=kernel_dim, padding=padding),
            nn.BatchNorm3d(out_chan),
            nn.LeakyReLU()
        ]
        
        if dropout:
            block += [nn.Dropout3d()]
        
        block += [
            nn.Conv3d(in_channels=out_chan, out_channels=out_chan, kernel_size=kernel_dim, padding=padding),
            nn.BatchNorm3d(out_chan),
            nn.LeakyReLU(),
            Conv2D1(24, out_chan, out_chan).cuda()
            #nn.Conv3d(in_channels=out_chan, out_channels=out_chan, kernel_size=4, stride=2, padding=1)
        ]
        
        self.block = nn.Sequential(*block)
    
    def forward(self, x):
        return self.block(x)

In [10]:
class upstream_block(nn.Module):
    def __init__(self, dim_x, dim_y, out_dim, kernel_dim=3, padding=1, dropout=False):
        super(upstream_block, self).__init__()
        
        self.up = ConvTranspose2D1(24, dim_x, dim_x).cuda()
        #self.up = nn.ConvTranspose3d(in_channels=dim_x, out_channels=dim_x, kernel_size=4, stride=2, padding=1)
        
        block = [
            nn.Conv3d(in_channels=dim_x+dim_y, out_channels=out_dim, kernel_size=kernel_dim, padding=padding),
            nn.BatchNorm3d(out_dim),
            nn.LeakyReLU(),
        ]
        
        if dropout:
            block += [nn.Dropout3d()]

        
        block += [
            nn.Conv3d(in_channels=out_dim, out_channels=out_dim, kernel_size=kernel_dim, padding=padding),
            nn.BatchNorm3d(out_dim),
            nn.LeakyReLU(),
        ]
        
        self.block = nn.Sequential(*block)
    
    def forward(self, x, y):
        x = self.up(x)
        return self.block(torch.cat([x,y],1))

In [11]:
class UNet3D(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, dropout=False):
        super(UNet3D, self).__init__()
        
        self.inconv = nn.Sequential(
            nn.Conv3d(in_channels=input_nc, out_channels=ngf, kernel_size=3, padding=1),
            nn.BatchNorm3d(ngf),
            nn.LeakyReLU(),
            
            nn.Conv3d(in_channels=ngf, out_channels=ngf, kernel_size=3, padding=1),
            nn.BatchNorm3d(ngf),
            nn.LeakyReLU()
        )
        
        self.down1 = downstream_block(in_chan=ngf, out_chan=ngf*2, dropout=dropout)
        self.down2 = downstream_block(in_chan=ngf*2, out_chan=ngf*4, dropout=dropout)
        self.down3 = downstream_block(in_chan=ngf*4, out_chan=ngf*8, dropout=dropout)
        self.down4 = downstream_block(in_chan=ngf*8, out_chan=ngf*16, dropout=dropout)
        
        self.up1 = upstream_block(ngf*16, ngf*8, ngf*8, dropout=dropout)
        self.up2 = upstream_block(ngf*8, ngf*4, ngf*4, dropout=dropout)
        self.up3 = upstream_block(ngf*4, ngf*2, ngf*2, dropout=dropout)
        self.up4 = upstream_block(ngf*2, ngf, ngf, dropout=dropout)
        
        self.outconv = nn.Sequential(
            nn.Conv3d(in_channels=ngf, out_channels=ngf, kernel_size=3, padding=1),
            nn.BatchNorm3d(ngf),
            nn.LeakyReLU(),
            
            nn.Conv3d(in_channels=ngf, out_channels=ngf, kernel_size=3, padding=1),
            nn.BatchNorm3d(ngf),
            nn.LeakyReLU(),
            
            nn.Conv3d(in_channels=ngf, out_channels=output_nc, kernel_size=3, padding=1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
            
        x_in = self.inconv(x)
#         print(x_in.shape)
        x1 = self.down1(x_in)
#         print(x1.shape)
        
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
#         print(x4.shape)
#         print(x3.shape)
        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        x = self.up4(x, x_in)
        
        return self.outconv(x)

In [12]:
baseline_unet = UNet3D(input_nc=1, output_nc=2).cuda()

In [13]:
criterion = nn.CrossEntropyLoss()
opt = torch.optim.Adam(baseline_unet.parameters(), lr=1e-3)
epoch = 0

In [14]:
while True:
    print("Epoch", epoch)
    baseline_unet.train(True)
    
    epoch_loss = []
    
    for X_batch, masks_batch in datait:
#         X_batch = X_batch[0]
        preds = baseline_unet(Variable(X_batch).cuda()).squeeze(1)
        print(preds.shape, masks_batch.shape)
        loss = criterion(preds, masks_batch)

        # train on batch
        opt.zero_grad()
        loss.backward()
        opt.step()
        
        try:
            epoch_loss.append(loss.data.item())
        except:
            epoch_loss.append(loss.data.numpy()[0])

        train_loss.append(np.mean(epoch_loss))
            
        plt.clf()
        plt.plot(train_loss[-1000:])
        plt.title('Train loss')
        
        display.clear_output(wait=True)
        display.display(plt.gcf())

        print("Loss: {}".format(train_loss[-1]))
        
        epoch += 1

Epoch 0
torch.Size([1, 2, 24, 64, 64]) torch.Size([1, 1, 24, 64, 64])


ValueError: Expected target size (1, 24, 64, 64), got torch.Size([1, 1, 24, 64, 64])

In [None]:
aa = baseline_unet(Variable(a).cuda())

In [None]:
aa.shape