<a href="https://colab.research.google.com/github/qzlinqian/6_869_project_med_seg/blob/main/segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 0. Load data and basic setup

In [None]:
# Data-source switch: True mounts Google Drive and reads the data from there; False reads from ./data.
use_gdrive = True  # want to use data in my google drive

In [None]:
import os
from tqdm import tqdm

# Resolve the root data directory: a local folder by default, or the mounted
# Google Drive when `use_gdrive` is enabled.
if not use_gdrive:
  data_dir = "./data"
else:
  from google.colab import drive
  drive.mount('/content/drive', force_remount=True)

  data_dir = "/content/drive/MyDrive/data"

# Root folder of the liver dataset; created if it does not exist yet.
datasets_dir = f"{data_dir}/Task03_Liver"
os.makedirs(datasets_dir, exist_ok=True)

# Folders holding the image volumes, label volumes and test volumes.
training_imgs_dir = f"{datasets_dir}/imagesTr"
training_labels_dir = f"{datasets_dir}/labelsTr"
test_imgs_dir = f"{datasets_dir}/imagesTs"
# The trailing slash matters: the 2-D dataset builds file names by plain
# string concatenation onto these prefixes.
two_d_imgs_dir = f"{datasets_dir}/2d_data/images/"
two_d_labels_dir = f"{datasets_dir}/2d_data/labels/"


Mounted at /content/drive


In [None]:
import nibabel as nib  # to read .nii.gz files
import numpy as np
import matplotlib.pyplot as plt

### To clone from repo

In [None]:
# Repository coordinates used by the clone and push cells.
username = 'qzlinqian'
repository = '6_869_project_med_seg'
# SECURITY: never store a personal access token in the notebook. The token that
# was hard-coded here has been published with the file and must be revoked.
# Provide a fresh one through the GITHUB_TOKEN environment variable before
# running the git cells; when it is unset the value is an empty string.
git_token = os.environ.get('GITHUB_TOKEN', '')

In [None]:
# Clone the project into a scratch directory, copy its contents into the
# working directory, then delete the scratch clone.
!git clone https://{git_token}@github.com/{username}/{repository} temp
%cp -r temp/* .
%rm -rf temp
# Remove the copy of this notebook that came with the clone.
%rm segmentation.ipynb

Cloning into 'temp'...
remote: Enumerating objects: 16, done.[K
remote: Counting objects: 100% (16/16), done.[K
remote: Compressing objects: 100% (15/15), done.[K
remote: Total 16 (delta 4), reused 3 (delta 0), pack-reused 0[K
Unpacking objects: 100% (16/16), done.


# 1. Initialize a new model

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import time
import copy
import PIL 
  
# Select the compute device: the first CUDA GPU when one is visible, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print("Using the GPU!")
else:
    device = torch.device("cpu")
    # Tell the user how to enable a GPU runtime in Colab.
    print("\n".join([
        "WARNING: Could not find GPU! Using CPU only",
        "You may want to try to use the GPU in Google Colab by clicking in:",
        "Runtime > Change Runtime type > Hardware accelerator > GPU.",
    ]))

Using the GPU!


In [None]:
# import the necessary packages
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from torchvision import transforms
from imutils import paths

### Define parameters

In [None]:
# Target output size in pixels; width and height share the same value.
image_width = image_height = 512
# presumably a probability cut-off — not referenced by the cells visible here.
threshold = 0.5

### Dataset

In [None]:
from torch.utils.data import Dataset
import cv2

In [None]:
# for reference
class SegmentationDataset(Dataset):
  """Dataset that yields whole volumes together with their label volumes.

  Sample ``idx`` is read from ``<tr_image_paths>/liver_<idx>.nii.gz`` and
  ``<tr_label_paths>/liver_<idx>.nii.gz``.

  Args:
    tr_image_paths: directory containing the training image volumes.
    tr_label_paths: directory containing the matching label volumes.
    ts_image_paths: directory of the test volumes; stored but not read here.
    transforms: optional callable applied to the image and, separately, to
      the label.
    num_samples: number of volumes reported by ``__len__``. Defaults to 120,
      the value that used to be hard-coded.
  """

  def __init__(self, tr_image_paths, tr_label_paths, ts_image_paths, transforms=None,
               num_samples=120):
    # store the image and label directories
    self.tr_image_paths = tr_image_paths
    self.tr_label_paths = tr_label_paths
    self.ts_image_paths = ts_image_paths
    self.transforms = transforms
    self.num_samples = num_samples

  def __len__(self):
    """Return the number of volumes exposed by the dataset."""
    return self.num_samples

  def __getitem__(self, idx):
    """Load volume ``idx`` and return ``{'image': ..., 'label': ...}``."""
    # build both file paths from the sample index
    tr_image_path = self.tr_image_paths + '/liver_' + str(idx) + '.nii.gz'
    tr_label_path = self.tr_label_paths + '/liver_' + str(idx) + '.nii.gz'
    # read the volume data and drop singleton axes
    image = nib.load(tr_image_path).get_fdata().squeeze()
    label = nib.load(tr_label_path).get_fdata().squeeze()
    if self.transforms is not None:
      # NOTE(review): the callable runs independently on image and label, so a
      # randomised transform would not stay aligned between the two.
      image = self.transforms(image)
      label = self.transforms(label)
    return {'image': image, 'label': label}

In [None]:
# This was just for test
class Test2DDataset(Dataset):
  """In-memory slice dataset assembled from volumes liver_2, liver_0, liver_1.

  The constructor loads the three image/label volumes and joins them along
  axis 2; item ``idx`` is the ``idx``-th slice along that axis.
  """

  def __init__(self, tr_image_paths, tr_label_paths, ts_image_paths, transforms=None):
    self.tr_image_paths = tr_image_paths
    self.tr_label_paths = tr_label_paths
    self.ts_image_paths = ts_image_paths
    self.transforms = transforms

    image_volumes, label_volumes = [], []
    # Volume 2 comes first, followed by volumes 0 and 1.
    for volume_id in (2, 0, 1):
      file_name = '/liver_' + str(volume_id) + '.nii.gz'
      image_volumes.append(
        nib.load(self.tr_image_paths + file_name).get_fdata().squeeze())
      label_volumes.append(
        nib.load(self.tr_label_paths + file_name).get_fdata().squeeze())
    # Stack all slices along the third axis.
    self.images = np.concatenate(image_volumes, axis=2)
    self.labels = np.concatenate(label_volumes, axis=2)

  def __len__(self):
    """Number of slices along axis 2 of the stacked volume."""
    return self.images.shape[2]

  def __getitem__(self, idx):
    """Return slice ``idx`` as ``{'image': float tensor with a leading axis, 'label': long tensor}``."""
    image = torch.from_numpy(self.images[:, :, idx]).unsqueeze(dim=0).float()
    label = torch.from_numpy(self.labels[:, :, idx]).long()
    if self.transforms is None:
      return {'image': image, 'label': label}
    # The same callable is applied to the image first, then to the label.
    return {'image': self.transforms(image), 'label': self.transforms(label)}

In [None]:
class TwoDimImageDataset(Dataset):
  """Dataset over pre-extracted ``.npy`` samples addressed by integer id.

  ``tr_image_path`` and ``tr_label_path`` are string prefixes: the sample id
  and ``.npy`` are appended directly, with no separator inserted. The
  ``transforms`` argument is stored but never applied.
  """

  def __init__(self, indices, tr_image_path, tr_label_path, transforms=None):
    self.indices, self.transforms = indices, transforms
    self.tr_image_path = tr_image_path
    self.tr_label_path = tr_label_path

  def __len__(self):
    """Number of sample ids held by this dataset."""
    return len(self.indices)

  def __getitem__(self, idx):
    """Load sample ``indices[idx]``; the label is entry 1 along the first axis of the stored label array."""
    file_name = str(self.indices[idx]) + '.npy'
    image_array = np.load(self.tr_image_path + file_name)
    label_array = np.load(self.tr_label_path + file_name)
    image = torch.from_numpy(image_array).float()
    label = torch.from_numpy(label_array).long()[1, :, :].squeeze()
    return {'image': image, 'label': label}

In [None]:
class Block(nn.Module):
  """Two unpadded 3x3 convolutions, each followed by BatchNorm and ReLU.

  Every convolution trims one pixel from each border, so an ``H x W`` input
  leaves the block as ``(H - 4) x (W - 4)``.

  Args:
    inChannels: number of channels of the input tensor.
    outChannels: number of channels produced by both convolutions.
  """

  def __init__(self, inChannels, outChannels):
    super().__init__()
    self.conv1 = nn.Conv2d(inChannels, outChannels, 3)
    self.batch1 = nn.BatchNorm2d(outChannels)
    self.conv2 = nn.Conv2d(outChannels, outChannels, 3)
    # Each convolution gets its own normalisation layer. Previously `batch1`
    # was assigned twice, so a single BatchNorm (one set of running statistics
    # and affine parameters) served both positions. Checkpoints saved with the
    # old layout have no `batch2.*` entries.
    self.batch2 = nn.BatchNorm2d(outChannels)
    # ReLU holds no state, so one instance serves both activations.
    self.relu = nn.ReLU(inplace=True)

  def forward(self, x):
    """Apply CONV => BN => RELU twice and return the result."""
    x = self.relu(self.batch1(self.conv1(x)))
    x = self.relu(self.batch2(self.conv2(x)))
    return x

class Encoder(nn.Module):
  """Chain of ``Block`` stages, each followed by 2x2 max-pooling.

  ``channels`` gives the channel count entering the first stage and leaving
  every stage, so ``len(channels) - 1`` stages are built.
  """

  def __init__(self, channels=(1, 16, 32, 64)):
    super().__init__()
    stage_io = zip(channels[:-1], channels[1:])
    self.encBlocks = nn.ModuleList(
      [Block(n_in, n_out) for n_in, n_out in stage_io])
    self.pool = nn.MaxPool2d(2)

  def forward(self, x):
    """Return the pre-pooling output of every stage, first stage first."""
    features = []
    current = x
    for stage in self.encBlocks:
      current = stage(current)
      features.append(current)
      # Halve the resolution for the next stage (also runs after the last
      # stage, where the pooled tensor is simply not used).
      current = self.pool(current)
    return features

class Decoder(nn.Module):
  """Up-sampling path: transposed convolution, skip concatenation, ``Block``.

  ``channels`` lists the channel count entering the first step and leaving
  every step; ``len(channels) - 1`` steps are built.
  """

  def __init__(self, channels=(64, 32, 16, 4)):
    super().__init__()
    self.channels = channels
    step_io = list(zip(channels[:-1], channels[1:]))
    # stride-2 transposed convolutions that double the spatial size
    self.upconvs = nn.ModuleList(
      [nn.ConvTranspose2d(n_in, n_out, 2, 2) for n_in, n_out in step_io])
    self.dec_blocks = nn.ModuleList(
      [Block(n_in, n_out) for n_in, n_out in step_io])

  def forward(self, x, encFeatures):
    """Decode ``x`` using ``encFeatures[i]`` as the skip input of step ``i``."""
    for step in range(len(self.channels) - 1):
      upsampled = self.upconvs[step](x)
      # Trim the skip tensor to the up-sampled size, then join on channels.
      skip = self.crop(encFeatures[step], upsampled)
      x = self.dec_blocks[step](torch.cat([upsampled, skip], dim=1))
    return x

  def crop(self, encFeatures, x):
    """Centre-crop ``encFeatures`` to the height and width of 4-D ``x``."""
    _, _, height, width = x.shape
    return transforms.CenterCrop([height, width])(encFeatures)

In [None]:
class UNet(nn.Module):
  """Encoder/decoder segmentation network returning per-class probabilities.

  Args:
    encChannels: channel progression handed to ``Encoder``.
    decChannels: channel progression handed to ``Decoder``; its last entry is
      the input width of the 1x1 classifier.
    nbClasses: number of output channels of the classifier.
    retainDim: when true, the class map is resized to ``outSize``.
    outSize: target size passed unchanged to ``nn.functional.interpolate``.
  """

  def __init__(self, encChannels=(3, 16, 32, 64),
      decChannels=(64, 32, 16),
      nbClasses=3, retainDim=True,
      outSize=(image_width, image_height)):
    super().__init__()
    self.encoder = Encoder(encChannels)
    self.decoder = Decoder(decChannels)
    # 1x1 convolution mapping decoder features to class scores
    self.classifier = nn.Conv2d(decChannels[-1], nbClasses, 1)
    self.softmax = nn.Softmax(dim=1)
    self.retainDim = retainDim
    self.outSize = outSize

  def forward(self, x):
    """Run the network and return softmax probabilities over dim 1."""
    # Deepest encoder output seeds the decoder; the rest serve as skip
    # inputs, ordered deepest first.
    deepest, *skips = self.encoder(x)[::-1]
    scores = self.classifier(self.decoder(deepest, skips))
    if self.retainDim:
      scores = nn.functional.interpolate(scores, self.outSize)
    # NOTE(review): the output is already normalised; a loss that applies its
    # own softmax would normalise a second time — confirm against the caller.
    return self.softmax(scores)

# 2. Training

In [None]:
batch_size = 32
# define transformations
transforms_def = transforms.Compose([transforms.ToPILImage(),
  transforms.Resize((image_width, image_height)),
  transforms.ToTensor()])

# Split configuration: `num_selected` sample ids are drawn without replacement
# from `range(num_slices)`; the first `num_train` of them form the training set
# and the remainder the test set.
num_slices = 2556
num_selected = 2000
num_train = 1600
# Fixed seed so that the split is identical after every kernel restart.
split_seed = 0

# create the train and test datasets
indices = np.random.default_rng(split_seed).choice(
  num_slices, size=num_selected, replace=False)
# Hand over the composed pipeline rather than the `transforms` module itself
# (the dataset currently stores this argument without applying it).
train_ds = TwoDimImageDataset(indices[:num_train], two_d_imgs_dir, two_d_labels_dir, transforms_def)
test_ds = TwoDimImageDataset(indices[num_train:], two_d_imgs_dir, two_d_labels_dir, transforms_def)
print(f"[INFO] found {len(train_ds)} examples in the training set...")
print(f"[INFO] found {len(test_ds)} examples in the test set...")
# create the training and test data loaders
trainLoader = DataLoader(train_ds, shuffle=True,
  batch_size=batch_size, pin_memory=True,
  num_workers=os.cpu_count())
testLoader = DataLoader(test_ds, shuffle=False,
  batch_size=batch_size, pin_memory=True,
  num_workers=os.cpu_count())

[INFO] found 1600 examples in the training set...
[INFO] found 400 examples in the test set...


In [None]:
# Sanity check: shape of the first training image tensor.
train_ds[0]['image'].shape

torch.Size([3, 512, 512])

In [None]:
learning_rate = 1e-3
# initialize our UNet model
# unet = UNet().to(device)
# The architecture is fetched from a public repository through torch.hub and
# starts from random weights (pretrained=False).
unet = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
    in_channels=3, out_channels=3, init_features=16, pretrained=False).to(device)
# initialize the optimizer
opt = optim.Adam(unet.parameters(), lr=learning_rate)
# Batches per epoch, taken from the loaders so that a final partial batch is
# counted. Floor-dividing the dataset size by the batch size undercounts
# whenever the size is not a multiple of the batch size, which inflates the
# per-epoch averages computed from these values.
train_steps = len(trainLoader)
test_steps = len(testLoader)
# initialize a dictionary to store training history
H = {"train_loss": [], "test_loss": []}

Downloading: "https://github.com/mateuszbuda/brain-segmentation-pytorch/archive/master.zip" to /root/.cache/torch/hub/master.zip


In [None]:
# Uniform weights, one per output class; not handed to the loss at the moment.
weight = torch.ones(3).to(device)
# Plain (unweighted) cross-entropy.
loss_function = nn.CrossEntropyLoss()
# def my_loss_function(output, labels):
#   selected_output = torch.reshape(output, (output.shape[0], output.shape[1], output.shape[2]*output.shape[3]))
#   selected_output = torch.transpose(selected_output, 0, 1)
#   selected_labels = torch.reshape(labels, (labels.shape[0], labels.shape[1]*output.shape[2]))
#   selected_output = selected_output[:,torch.where(selected_labels > 0.1)]
#   selected_labels = labels[torch.where(labels > 0.1)]
#   selected_output = torch.transpose(selected_output, 0, 1)
#   return loss_function(selected_output, selected_labels)
def my_dice_loss(preds, labels):
  """Dice-style loss between two tensors of identical shape.

  Both the overlap and the squared magnitudes are reduced by summing axis 0
  three times in a row (a full reduction for 3-D inputs); the loss is one
  minus the smoothed ratio, averaged and clamped to ``[0, 1]``.

  Args:
    preds: predicted values.
    labels: reference values, same shape as ``preds``.

  Returns:
    A scalar tensor in ``[0, 1]``.
  """
  smooth = 1  # keeps the denominator non-zero when both inputs are all zeros
  # The body used to read an undefined/global `pred` instead of the argument.
  overlap = torch.mul(preds, labels).sum(dim=0).sum(dim=0).sum(dim=0)
  magnitude = (preds.pow(2) + labels.pow(2)).sum(dim=0).sum(dim=0).sum(dim=0)
  dice = 2 * overlap / (magnitude + smooth)
  return torch.clamp((1 - dice).mean(), 0, 1)

In [None]:
# loop over epochs
num_epochs = 5
print("[INFO] training the network...")
startTime = time.time()
best_acc = 0.0
# Start from the current weights so that `best_model_wts` exists even when no
# epoch improves on `best_acc`.
best_model_wts = copy.deepcopy(unet.state_dict())
train_loss_history = []
val_loss_history = []
epoch = 0  # index of the epoch that produced `best_model_wts`
for e in tqdm(range(num_epochs)):
  # set the model in training mode
  unet.train()
  # running sums for this epoch
  total_train_loss = 0
  total_test_loss = 0
  train_acc = 0
  test_acc = 0
  # loop over the training set
  with torch.set_grad_enabled(True):
    for batch in trainLoader:
      # send the input to the device
      x, y = batch['image'].to(device), batch['label'].to(device).squeeze()
      # perform a forward pass and calculate the training loss
      prob = unet(x)
      _, preds = torch.max(prob, 1)
      # NOTE(review): nn.CrossEntropyLoss normalises its input itself; if the
      # model output is already squashed to probabilities the loss sees a
      # doubly-normalised signal — confirm what the hub model returns.
      loss = loss_function(prob, y)
      # alternative: loss = my_dice_loss(preds, y)
      # zero out accumulated gradients, backpropagate, update the parameters
      opt.zero_grad()
      loss.backward()
      opt.step()
      # Accumulate a detached copy: summing the live loss tensor would keep
      # every batch's autograd graph referenced until the end of the epoch.
      total_train_loss += loss.detach()
      # fraction of correctly classified elements in this batch
      train_acc += torch.sum(preds == y) / y.numel()
  # switch off autograd
  with torch.set_grad_enabled(False):
    # set the model in evaluation mode
    unet.eval()
    # loop over the validation set
    for batch in testLoader:
      # send the input to the device
      x, y = batch['image'].to(device), batch['label'].to(device).squeeze()
      # make the predictions and calculate the validation loss
      prob = unet(x)
      _, preds = torch.max(prob, 1)
      total_test_loss += loss_function(prob, y)
      # alternative: total_test_loss += my_dice_loss(preds, y)
      test_acc += torch.sum(preds == y) / y.numel()
  # calculate the average training and validation loss / accuracy
  avg_train_loss = total_train_loss / train_steps
  avg_test_loss = total_test_loss / test_steps
  train_acc /= train_steps
  test_acc /= test_steps

  # keep a copy of the weights with the best test accuracy so far
  if test_acc > best_acc:
    best_acc = test_acc
    best_model_wts = copy.deepcopy(unet.state_dict())
    epoch = e
  # The training history used to receive the test loss by mistake.
  train_loss_history.append(avg_train_loss)
  val_loss_history.append(avg_test_loss)
  # print the model training and validation information
  print("[INFO] EPOCH: {}/{}".format(e + 1, num_epochs))
  print("Train loss: {:.6f}, Test loss: {:.4f}".format(
    avg_train_loss, avg_test_loss))
  print("Train acc: {:.6f}, Test acc: {:.4f}".format(
    train_acc, test_acc))
# display the total time needed to perform the training
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
  endTime - startTime))

[INFO] training the network...


  0%|          | 0/5 [00:48<?, ?it/s]


KeyboardInterrupt: ignored

Traceback (most recent call last):
  File "/usr/lib/python3.7/multiprocessing/queues.py", line 224, in _feed
    nwait()
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


In [None]:
# Debug: largest value in output channel 2 for the most recent batch.
prob[:,2,:,:].max()

tensor(0.9743, device='cuda:0')

In [None]:
# Debug: fraction of elements where prediction and label agree, for the most recent batch.
torch.sum(preds == y) / (y.shape[0] * y.shape[1] * y.shape[2])

tensor(0.9887, device='cuda:0')

In [None]:
# Write two checkpoints into ./models: the weights with the best test accuracy
# seen during training, and the weights the model holds right now.
save_dir = './models'
os.makedirs(save_dir, exist_ok=True)
best_weights_path = os.path.join(save_dir, 'weights_best_val_acc.pt')
last_weights_path = os.path.join(save_dir, 'weights_last.pt')
torch.save(best_model_wts, best_weights_path)
torch.save(unet.state_dict(), last_weights_path)

In [None]:
# Publish the saved weights and rendered videos: clone the repository into
# ./temp, copy the artefacts in, then commit and push from inside the clone.
# NOTE(review): the token is embedded in the clone URL and therefore also in
# the clone's remote configuration — keep it out of saved notebooks.
!git clone https://{git_token}@github.com/{username}/{repository} temp
%cp -rf models temp
# NOTE(review): the three .mp4 files are written by the visualisation cells
# further down, so on a fresh kernel those cells have to run before this one.
%cp predictions.mp4 temp
%cp labels.mp4 temp
%cp input.mp4 temp
%cd temp

# Commit identity used for the push below.
!git config --global user.email "qzlinqian@126.com"
!git config --global user.name "Qian Lin"

# Stage everything in the clone, commit and push to the main branch.
!git add .
!git commit -m"2d model with 3 classes"
!git push origin main

Cloning into 'temp'...
remote: Enumerating objects: 62, done.[K
remote: Counting objects: 100% (62/62), done.[K
remote: Compressing objects: 100% (55/55), done.[K
remote: Total 62 (delta 17), reused 39 (delta 6), pack-reused 0[K
Unpacking objects: 100% (62/62), done.
/content/temp
[main cefa89f] 2d model with 3 classes
 5 files changed, 0 insertions(+), 0 deletions(-)
 rewrite input.mp4 (98%)
 rewrite labels.mp4 (94%)
 rewrite predictions.mp4 (99%)
Counting objects: 8, done.
Delta compression using up to 4 threads.
Compressing objects: 100% (8/8), done.
Writing objects: 100% (8/8), 13.87 MiB | 10.57 MiB/s, done.
Total 8 (delta 0), reused 0 (delta 0)
To https://github.com/qzlinqian/6_869_project_med_seg
   4d55d57..cefa89f  main -> main


In [None]:
# Remove the scratch clone.
# NOTE(review): the push cell ends with `%cd temp`, so when the cells run in
# order this executes inside the clone and the relative path `temp` does not
# refer to it — confirm the intended cell order.
%rm -rf temp

In [None]:
# Step back out of the clone directory into its parent.
%cd ..
#%rm -rf temp

/content


# Visualize

In [None]:
# Load one training volume and its label volume for visual inspection.
vis_index = 0
vis_image_path = f"{training_imgs_dir}/liver_{vis_index}.nii.gz"
vis_label_path = f"{training_labels_dir}/liver_{vis_index}.nii.gz"
# Read the voxel data and drop singleton axes, image first, then label.
vis_images, vis_labels = (
  nib.load(volume_path).get_fdata().squeeze()
  for volume_path in (vis_image_path, vis_label_path))

In [None]:
# Predict the volume slice by slice: each network input is a window of three
# consecutive slices moved to the channel axis.
vis_pred = np.zeros(vis_labels.shape)
# Inference only: put normalisation layers in evaluation mode and skip
# gradient bookkeeping.
unet.eval()
with torch.no_grad():
  # The last valid window starts at index shape[2] - 3, hence the "- 2" bound
  # (the earlier "- 3" bound left that window out).
  for i in range(vis_labels.shape[2] - 2):
    window = vis_images[:, :, i:i + 3].transpose([2, 0, 1])
    x = torch.from_numpy(window).float().unsqueeze(dim=0).to(device)
    prob = unet(x)
    _, pred = torch.max(prob, 1)
    # NOTE(review): the result is stored at the window's first slice, while the
    # training labels are read from index 1 of the stored label arrays —
    # confirm whether `i + 1` is the intended target slice.
    vis_pred[:, :, i] = pred.cpu().detach().numpy()

In [None]:
# Voxel accuracy: number of matching voxels divided by each of the three
# volume dimensions in turn.
dim0, dim1, dim2 = vis_labels.shape[:3]
n_matching = np.sum(vis_pred == vis_labels)
acc = n_matching / dim0 / dim1 / dim2
print(acc)

0.9612523905436198


In [None]:
import imageio
import matplotlib.animation as animate
%matplotlib inline
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from cv2 import imread, createCLAHE # read and equalize images
from glob import glob
import h5py
# for display the MRI images in animation
from IPython.display import HTML

def create_gif(input_image, title='.gif', filename='test.gif'):
    """Build an animation that steps through ``input_image`` along axis 2.

    Args:
        input_image: 3-D array; one frame is drawn per index of axis 2.
        title: text placed above the plot.
        filename: accepted for compatibility; not used by this function.

    Returns:
        The ``ArtistAnimation``; the figure itself is closed before returning.
    """
    fig = plt.figure()
    n_frames = input_image.shape[2]
    # one single-artist frame per slice
    frames = [[plt.imshow(input_image[:, :, k], animated=True)]
              for k in range(n_frames)]
    ani = animate.ArtistAnimation(fig, frames, interval=50, blit=True, repeat_delay=1000)
    plt.title(title, fontsize=20)
    plt.axis('off')
    plt.close()
    return ani

In [None]:
def display_result(input_image, label, pred_label, title='.gif', filename='test.gif'):
    """Animate a volume with reference labels (left) and predictions (right).

    For every index along axis 2 the slice is drawn in both panels and the
    positions where the mask equals 1 (yellow) or 2 (blue) are overlaid.

    Args:
        input_image: 3-D array; one frame per index of axis 2.
        label: reference mask with the same shape as ``input_image``.
        pred_label: predicted mask with the same shape as ``input_image``.
        title: text placed above the current axes.
        filename: accepted for compatibility; not used by this function.

    Returns:
        The ``ArtistAnimation``; the figure itself is closed before returning.
    """
    frames = []
    fig = plt.figure()
    ax1 = fig.add_subplot(1, 2, 1)
    ax2 = fig.add_subplot(1, 2, 2)
    panels = ((ax1, label), (ax2, pred_label))
    class_colors = ((1, 'y'), (2, 'b'))
    for i in range(input_image.shape[2]):
        frame_artists = []
        for ax, mask in panels:
            frame_artists.append(ax.imshow(input_image[:, :, i], animated=True))
            for class_id, color in class_colors:
                rows, cols = np.where(mask[:, :, i] == class_id)
                # imshow puts array rows on the y axis and columns on the x
                # axis, so the scatter needs (cols, rows) to line up.
                # The scatter also has to be part of the frame: artists left
                # out of the frame list are never hidden, so the points of
                # every slice would stay visible in all frames.
                frame_artists.append(ax.scatter(cols, rows, color=color))
        frames.append(frame_artists)
    ani = animate.ArtistAnimation(fig, frames, interval=50, blit=True, repeat_delay=1000)
    plt.title(title, fontsize=20)
    plt.axis('off')
    plt.close()
    return ani

# Render the side-by-side animation (labels vs. predictions) inline as HTML5 video.
ani = display_result(vis_images, vis_labels, vis_pred)
HTML(ani.to_html5_video())

In [None]:
# Animate the input volume, save it as input.mp4 and show it inline.
# `filename` is accepted by create_gif but not used by it; the output file is the one given to ani.save.
ani = create_gif(vis_images, title='image', filename='image.gif')
ani.save('./input.mp4')
HTML(ani.to_html5_video())

In [None]:
# Animate the label volume, save it as labels.mp4 and show it inline.
ani = create_gif(vis_labels, title='label', filename='label.gif')
ani.save('./labels.mp4')
HTML(ani.to_html5_video())

In [None]:
# Animate the predicted volume, save it as predictions.mp4 and show it inline.
ani = create_gif(vis_pred, title='prediction', filename='prediction.gif')
ani.save('./predictions.mp4')
HTML(ani.to_html5_video())

In [None]:
# Write whatever animation `ani` currently holds to predictions.mp4
# (the prediction animation when the cells above are run in order).
ani.save('./predictions.mp4')

In [None]:
# Debug: largest value stored in the predicted volume.
vis_pred.max()

2.0

# Fake 3D