In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
def ELU(elu, nchan):
    if elu:
        return nn.ELU(inplace=True)
    else:
        return nn.PReLU(nchan)
    
def n_conv(nchan, depth, elu):
    layers = []
    for _ in range(depth):
        layers.append(single_conv(nchan, elu))
    return nn.Sequential(*layers)


class single_conv(nn.Module):
    def __init__(self, nchan, elu):
        super(single_conv, self).__init__()
        self.relu = ELU(elu, nchan)
        self.conv = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2)
        self.bn = nn.BatchNorm3d(nchan)

    def forward(self, x):
        out = self.relu(self.bn(self.conv(x)))
        return out


In [3]:
class input_layer(nn.Module):
    def __init__(self, in_ch, out_ch, elu):
        super(input_layer, self).__init__()

        self.conv = nn.Conv3d(in_ch,out_ch,kernel_size=5,padding=2)
        self.relu = ELU(elu, out_ch)
        self.bn = nn.BatchNorm3d(out_ch)
        
    def forward(self, x):
        # 1. Convolve
        out = self.conv(x)
        # 2. Normalize
        out = self.bn(out)
        # 3. Concat output accross 16 channels
        x16 = torch.cat((x, x, x, x, x, x, x, x, x, x, x, x, x, x, x, x), 0)
        # 4. Add concatenated output and input
        out = torch.add(out, x16)
        # 5. Activation
        out = self.relu(out)
        return out
    

In [4]:
class down_layer(nn.Module):
    def __init__(self, in_ch, nConv, elu):
        super(down_layer, self).__init__()
        out_ch = 2*in_ch
        self.down_conv = nn.Conv3d(in_ch, out_ch, kernel_size=2, stride=2)
        self.bn = nn.BatchNorm3d(out_ch)
        self.relu = ELU(elu, out_ch)
        self.layers = n_conv(out_ch, nConv, elu)
        
    def forward(self, x):
        down = self.relu1(self.bn(self.down_conv(x)))
        out = self.layers(out)
        out = self.relu2(torch.add(out, down))
        return down, out

In [5]:
class up_layer(nn.Module):
    def __init__(self, in_ch, out_ch, nConv, elu):
        super(up_layer, self).__init__()
        self.up_conv = nn.ConvTranspose3d(in_ch, out_ch // 2, kernel_size=2, stride=2)
        self.bn = nn.BatchNorm3d(out_ch // 2)
        self.relu1 = ELU(elu, out_ch // 2)
        self.relu2 = ELU(elu, out_ch)
        self.layers = n_conv(out_ch, nConv, elu)

    def forward(self, x, skipx):
        out = self.relu1(self.bn(self.up_conv(out)))
        xcat = torch.cat((out, skipx), 1)
        out = self.layers(xcat)
        out = self.relu2(torch.add(out, xcat))
        return out

In [6]:
class output_layer(nn.Module):
    def __init__(self, in_ch, elu, nll):
        super(output_layer, self).__init__()
        self.conv1 = nn.Conv3d(in_ch, 2, kernel_size=5, padding=2)
        self.bn = nn.BatchNorm3d(2)
        self.conv2 = nn.Conv3d(2, 2, kernel_size=1)
        self.relu1 = ELU(elu, 2)
        if nll:
            self.softmax = F.log_softmax
        else:
            self.softmax = F.softmax

    def forward(self, x):
        out = self.relu1(self.bn(self.conv1(x)))
        out = self.conv2(out)
        out = out.permute(0, 2, 3, 4, 1).contiguous()
        out = out.view(out.numel() // 2, 2)
        out = self.softmax(out)
        return out

In [7]:
class VNet(nn.Module):
    
    def __init__(self, elu=True, nll=False):
        super(VNet, self).__init__()
        #In
        self.input = input_layer(1, 16, elu)
        
        #Down
        self.down32 = down_layer(16, 2, elu)
        self.down64 = down_layer(32, 3, elu)
        self.down128 = down_layer(64, 3, elu)
        self.down256 = down_layer(128, 3, elu)
        
        #Up
        self.up256 = up_layer(256,256, 3, elu)
        self.up128 = up_layer(128,128, 3, elu)
        self.up64 = up_layer(64,64, 2, elu)
        self.up32 = up_layer(32,32, 1, elu)
        
        #Out
        self.output = output_layer(16, elu, nll)
    
    def forward(self, x):
        #Layer 1: In
        out16 = self.input(x)
        
        #Layer 2 : Down ( 2 conv layers deep)
        d_32, out32 = self.down32(out16)
        
        #Layer 3 : Down ( 3 conv layers deep)
        d_64, out64 = self.down64(out32)
        
        #Layer 4 : Down ( 3 conv layers deep)
        d_128, out128 = self.down128(out64)      
        
        #Layer 5 : Down ( 3 conv layers deep)
        d_256, out256 = self.down256(out128)
        
        #Layer 5 : up ( 3 conv layers deep)
        output = self.up256(out256, out128)
        
        #Layer 4 : up ( 3 conv layers deep)
        output = self.up128(output, out64)
        
        #Layer 3 : up ( 3 conv layers deep)
        output = self.up64(output, out32)
        
        #Layer 2 : up ( 2 conv layers deep)
        output = self.up32(output, out16)
        
        #Layer 1 : out
        output = self.output(output)
        
        return output

In [8]:
net = VNet(elu=True, nll=False)
#print(model)

In [9]:
n_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print('Number of parameters in network: ', n_params)

Number of parameters in network:  65096700


In [10]:
import glob
import numpy as np
import nibabel as nib
import os
import pickle
from torch.utils.data import Dataset
from torchvision import transforms
from torch.autograd import Function, Variable
from tqdm import tqdm

# get all the image and mask path and number of images
image_paths = glob.glob("train/*.gz")
mask_paths = glob.glob("train_masks/*.gz")

# split these path using a certain percentage
for path in image_paths:
    img = nib.load(path)
    data = img.get_fdata()
    print('Data: ',data.shape)
    break
    
for path in mask_paths:
    img = nib.load(path)
    data = img.get_fdata()
    print('Mask: ',data.shape)
    break

This code is using an older version of pydicom, which is no longer 
maintained as of Jan 2017.  You can access the new pydicom features and API 
by installing `pydicom` from PyPI.
See 'Transitioning to pydicom 1.x' section at pydicom.readthedocs.org 
for more information.



Data:  (144, 512, 512)
Mask:  (144, 512, 512)


In [11]:
def split_train_val(image_paths, mask_paths, train_size):
    img_paths_dic = {}
    mask_paths_dic = {}
    len_data = len(image_paths)
    print('total len:', len_data)
    for i in range(len(image_paths)):
        img_paths_dic[os.path.basename(image_paths[i])[8:]] = image_paths[i]

    for i in range(len(mask_paths)):
        mask_paths_dic[os.path.basename(mask_paths[i])[19:]] = mask_paths[i]
        
    img_mask_list = []
    #print(img_paths_dic)
    #print(mask_paths_dic)
    for key in img_paths_dic:
        img_mask_list.append((img_paths_dic[key], mask_paths_dic[key]))
        
    train_img_mask_paths = img_mask_list[:int(len_data*train_size)] 
    val_img_mask_paths = img_mask_list[int(len_data*train_size):]
    return train_img_mask_paths, val_img_mask_paths
    #return img_mask_list
    
    
def preprocess_image(image_mask_paths):
    img_mask_list = []
    #new_h, new_w = 80, 100
    
    for i in tqdm(range(len(image_mask_paths))):
        vol = nib.load(image_mask_paths[i][0])
        m = nib.load(image_mask_paths[i][1])
        img = np.array(vol.dataobj, np.float32) / 255.0 
        mask = np.array(m.dataobj,np.uint8)
        
        #print(img)
        # Use cv2 to resize images to 80x100, use INTER_CUBIC interpolation
        #img_resize = cv2.resize(img, dsize=(new_w, new_h), interpolation=cv2.INTER_CUBIC)
        #mask_resize = np.uint8(cv2.resize(mask, dsize=(new_w, new_h), interpolation=cv2.INTER_CUBIC))
        
        img_mask_list.append((img, mask))
    return img_mask_list

In [12]:
train_img_mask_paths, val_img_mask_paths = split_train_val(image_paths, mask_paths, 0.95)


def pickle_store(file_name,save_data):
    fileObj = open(file_name,'wb')
    pickle.dump(save_data,fileObj)
    fileObj.close()

train_img_masks_save_path = './train_img_masks.pickle'
if os.path.exists(train_img_masks_save_path):
    with open(train_img_masks_save_path,'rb') as f:
        train_img_masks = pickle.load(f)
    f.close()
else:
    train_img_masks = preprocess_image(train_img_mask_paths)
    pickle_store(train_img_masks_save_path,train_img_masks)
print('train len: {}'.format(len(train_img_masks)))

total len: 7
train len: 6


In [13]:
# For validation data!!!!!


val_img_masks_save_path = '.val_img_masks.pickle'
if os.path.exists(val_img_masks_save_path):
    with open(val_img_masks_save_path,'rb') as f:
        val_img_masks = pickle.load(f)
    f.close()
else:
    val_img_masks = preprocess_image(val_img_mask_paths)
    pickle_store(val_img_masks_save_path,val_img_masks)
print('val len: {}'.format(len(val_img_masks)))

val len: 1


In [22]:
class ToTensor(object):
    """
    Convert ndarrays in sample to Tensors.
    """
    def __init__(self):
        pass

    def __call__(self, sample):
        image, label = sample['img'], sample['label']
        image = image[None,:,:]
        label = label[None,:,:]
        return {'img': torch.from_numpy(image.copy()).type(torch.FloatTensor),
                'label': torch.from_numpy(label.copy()).type(torch.FloatTensor)}

In [23]:
class CustomDataset(Dataset):
    def __init__(self, image_masks, transforms=None): 

        self.image_masks = image_masks
        self.transforms = transforms
    
    def __len__(self):  # return count of sample we have

        return len(self.image_masks)
    
    def __getitem__(self, index):

        image = self.image_masks[index][0] # H, W, C
        mask = self.image_masks[index][1]
        
#         image = np.transpose(image, axes=[2, 0, 1]) # C, H, W
        
        sample = {'img': image, 'label': mask}
        
        if transforms:
            sample = self.transforms(sample)
            
        return sample

train_dataset = CustomDataset(train_img_masks, transforms=transforms.Compose([ToTensor()]))
val_dataset = CustomDataset(val_img_masks, transforms=transforms.Compose([ToTensor()]))

In [24]:

# define dice coefficient 
class DiceCoeff(Function):
    """Dice coeff for one pair of input image and target image"""
    def forward(self, prediction, target):
        self.save_for_backward(prediction, target)
        eps = 0.0001 # in case union = 0
        # Calculate intersection and union. 
        # You can convert the input image into a vector with input.contiguous().view(-1)
        # Then use torch.dot(A, B) to calculate the intersection.
        A = prediction.view(-1)
        B = target.view(-1)
        inter = torch.dot(A.float(),B.float())
        union = torch.sum(A.float()) + torch.sum(B.float()) - inter + eps
        # Calculate DICE 
        d = inter / union
        return d

# Calculate dice coefficients for batches
def dice_coeff(prediction, target):
    """Dice coeff for batches"""
    s = torch.FloatTensor(1).zero_()
    
    # For each pair of input and target, call DiceCoeff().forward(prediction, target) to calculate dice coefficient
    # Then average
    for i, (a,b) in enumerate(zip(prediction, target)):
        s += DiceCoeff().forward(a,b)
    s = s / (i + 1)
    return s


In [25]:
def eval_net(net, dataset):
    # set net mode to evaluation
    net.eval()
    tot = 0
    for i, b in enumerate(dataset):
        img = b['img'].to(device)
        B = img.shape[0]
        true_mask = b['label'].to(device)
        # Feed the image to the network to get predicted mask
        mask_pred = net(img.float())
        
        # For all pixels in predicted mask, set them to 1 if larger than 0.5. Otherwise set them to 0
        mask_pred = mask_pred > 0.5
        # calculate dice_coeff()
        # note that you should add all the dice_coeff in validation/testing dataset together 
        # call dice_coeff() here
        tot += dice_coeff(true_mask,mask_pred)
        # Return average dice_coeff()
    return tot / (i + 1)

In [26]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [27]:
from torch import optim
epochs = 5 # e.g. 10, or more until dice converge
batch_size = 1 # e.g. 16
lr = 0.01        # e.g. 0.01
N_train = len(train_img_masks)
model_save_path = './model/'  # directory to same the model after each epoch. 

optimizer = optim.SGD(net.parameters(),lr = lr,momentum=0.9, weight_decay=0.0005)

criterion = nn.BCELoss()

# Start training
for epoch in range(epochs):
    print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
    net.train()
    # Reload images and masks for training and validation and perform random shuffling at the begining of each epoch
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    
    epoch_loss = 0
    count = 0

    for i, b in enumerate(train_loader):
        # Get images and masks from each batch
        
        imgs = b['img'].to(device)
        true_masks = b['label'].to(device)
        
        # Feed your images into the network
        
        masks_pred = net.forward(imgs.float())
        #masks_pred = nn.functional.interpolate(masks_pred, size=(80,100), mode='bilinear')
        # Flatten the predicted masks and true masks. For example, A_flat = A.view(-1)
        masks_probs_flat = masks_pred.view(-1)
        true_masks_flat = true_masks.view(-1)
        # Calculate the loss by comparing the predicted masks vector and true masks vector
        # And sum the losses together 
        loss = criterion(masks_probs_flat,true_masks_flat.float())
        epoch_loss += loss.item()
        if count % 50 == 0:
            print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train, loss.item()))
        count = count + 1
        # optimizer.zero_grad() clears x.grad for every parameter x in the optimizer. 
        # It’s important to call this before loss.backward(), otherwise you’ll accumulate the gradients from multiple passes.
        optimizer.zero_grad()
        # loss.backward() computes dloss/dx for every parameter x which has requires_grad=True. 
        # These are accumulated into x.grad for every parameter x
        loss.backward()
        # optimizer.step updates the value of x using the gradient x.grad.
        optimizer.step()
    print('Epoch finished ! Loss: {}'.format(epoch_loss / i))
    
    # Perform validation with eval_net() on the validation data
    val_dice = eval_net(net,val_loader)
    print('Validation Dice Coeff: {}'.format(val_dice))
    # Save the model after each epoch
    if os.path.isdir(model_save_path):
        torch.save(net.state_dict(),model_save_path + 'Car_Seg_Epoch{}.pth'.format(epoch + 1))
    else:
        os.makedirs(model_save_path, exist_ok=True)
        torch.save(net.state_dict(),model_save_path + 'Car_Seg_Epoch{}.pth'.format(epoch + 1))
    print('Checkpoint {} saved !'.format(epoch + 1))


Starting epoch 1/5.


RuntimeError: [enforce fail at ..\c10\core\CPUAllocator.cpp:72] data. DefaultCPUAllocator: not enough memory: you tried to allocate 18874368000 bytes. Buy new RAM!
