<a href="https://colab.research.google.com/github/qzlinqian/6_869_project_med_seg/blob/main/segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 0. Load data and basic setup

In [1]:
use_gdrive = True  # want to use data in my google drive

In [2]:
import os
from tqdm import tqdm

if use_gdrive:
  from google.colab import drive
  drive.mount('/content/drive')

  data_dir = "/content/drive/MyDrive/data"
else:
  data_dir = "./data"

datasets_dir = data_dir + '/Task03_Liver'

os.makedirs(datasets_dir, exist_ok=True)

training_imgs_dir = datasets_dir + '/imagesTr'
training_labels_dir = datasets_dir + '/labelsTr'
test_imgs_dir = datasets_dir + '/imagesTs'
two_d_imgs_dir = datasets_dir + '/2d/images/'
two_d_labels_dir = datasets_dir + '/2d/labels/'


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import nibabel as nib  # to read .nii.gz files
import numpy as np
import matplotlib.pyplot as plt

### To clone from repo

In [56]:
username = 'qzlinqian'
repository = '6_869_project_med_seg'
git_token =  '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'

In [None]:
!git clone https://{git_token}@github.com/{username}/{repository} temp
%cp -r temp/* .
%rm -rf temp
%rm segmentation.ipynb

Cloning into 'temp'...
remote: Enumerating objects: 16, done.[K
remote: Counting objects: 100% (16/16), done.[K
remote: Compressing objects: 100% (15/15), done.[K
remote: Total 16 (delta 4), reused 3 (delta 0), pack-reused 0[K
Unpacking objects: 100% (16/16), done.


In [None]:
from dense_unet import DenseUNet

pretrained_encoder_uri = 'https://download.pytorch.org/models/densenet121-a639ec97.pth'
num_output_classes = 3
model = DenseUNet(num_output_classes, downsample=True, pretrained_encoder_uri=pretrained_encoder_uri)


Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth


# 1. Initialize a new model

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import time
import copy
import PIL 
  
# Detect if we have a GPU available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    print("Using the GPU!")
else:
    print("WARNING: Could not find GPU! Using CPU only")
    print("You may want to try to use the GPU in Google Colab by clicking in:")
    print("Runtime > Change Runtime type > Hardware accelerator > GPU.")

Using the GPU!


In [5]:
# import the necessary packages
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from torchvision import transforms
from imutils import paths

### Define parameters

In [6]:
image_width = 512
image_height = 512
threshold = 0.5

### Dataset

In [7]:
from torch.utils.data import Dataset
import cv2

In [None]:
# for reference
class SegmentationDataset(Dataset):
  def __init__(self, tr_image_paths, tr_label_paths, ts_image_paths, transforms=None):
		# store the image and label filepaths
    self.tr_image_paths = tr_image_paths
    self.tr_label_paths = tr_label_paths
    self.ts_image_paths = ts_image_paths
    self.transforms = transforms

  def __len__(self):
		# return the number of total samples contained in the dataset
    return 120
  
  def __getitem__(self, idx):
		# grab the image path from the current index
    tr_image_path = self.tr_image_paths + '/liver_' + str(idx) + '.nii.gz'
    tr_label_path = self.tr_label_paths + '/liver_' + str(idx) + '.nii.gz'
		# load the image from disk, swap its channels from BGR to RGB,
		# and read the associated mask from disk in grayscale mode
    image = nib.load(tr_image_path).get_fdata().squeeze()
    label = nib.load(tr_label_path).get_fdata().squeeze()
		# check to see if we are applying any transformations
    if self.transforms is not None:
			# apply the transformations to both image and its mask
      image = self.transforms(image)
      label = self.transforms(label)
		# return a tuple of the image and its mask
    return {'image': image, 'label': label}

In [None]:
# This was just for test
class Test2DDataset(Dataset):
  def __init__(self, tr_image_paths, tr_label_paths, ts_image_paths, transforms=None):
		# store the image and label filepaths
    self.tr_image_paths = tr_image_paths
    self.tr_label_paths = tr_label_paths
    self.ts_image_paths = ts_image_paths
    self.transforms = transforms
    
    tr_image_path = self.tr_image_paths + '/liver_2.nii.gz'
    tr_label_path = self.tr_label_paths + '/liver_2.nii.gz'
    # load the image from disk, swap its channels from BGR to RGB,
    # and read the associated mask from disk in grayscale mode
    self.images = nib.load(tr_image_path).get_fdata().squeeze()
    self.labels = nib.load(tr_label_path).get_fdata().squeeze()
    for i in range(2):
      tr_image_path = self.tr_image_paths + '/liver_' + str(i) + '.nii.gz'
      tr_label_path = self.tr_label_paths + '/liver_' + str(i) + '.nii.gz'
      # load the image from disk, swap its channels from BGR to RGB,
      # and read the associated mask from disk in grayscale mode
      self.images = np.concatenate([self.images, nib.load(tr_image_path).get_fdata().squeeze()], axis=2)
      self.labels = np.concatenate([self.labels, nib.load(tr_label_path).get_fdata().squeeze()], axis=2)

  def __len__(self):
		# return the number of total samples contained in the dataset
    return self.images.shape[2]
  
  def __getitem__(self, idx):
		# grab the image path from the current index
		# check to see if we are applying any transformations
    image = self.images[:,:,idx]
    label = self.labels[:,:,idx]
    # print(image.shape)
    image = torch.from_numpy(image).unsqueeze(dim=0).float()
    label = torch.from_numpy(label).long()
    if self.transforms is not None:
			#apply the transformations to both image and its mask
      image = self.transforms(image)
      label = self.transforms(label)
		# return a tuple of the image and its mask
    return {'image': image, 'label': label}

In [76]:
class TwoDimImageDataset(Dataset):
  def __init__(self, indices, tr_image_path, tr_label_path, transforms=None):
		# store the image and label filepaths
    self.indices = indices
    self.tr_image_path = tr_image_path
    self.tr_label_path = tr_label_path
    self.transforms = transforms

  def __len__(self):
    return len(self.indices)

  def __getitem__(self, idx):
    data_index = self.indices[idx]
    image = torch.from_numpy(np.load(self.tr_image_path + str(data_index) + '.npy')).unsqueeze(dim=0).float()
    label = torch.from_numpy(np.load(self.tr_label_path + str(data_index) + '.npy')).long()
    return {'image': image, 'label': label}

In [71]:
class Block(nn.Module):
  def __init__(self, inChannels, outChannels):
    super().__init__()
    # store the convolution and RELU layers
    self.conv1 = nn.Conv2d(inChannels, outChannels, 3)
    self.batch1 = nn.BatchNorm2d(outChannels)
    self.relu = nn.ReLU(inplace=True)
    self.conv2 = nn.Conv2d(outChannels, outChannels, 3)
    self.batch1 = nn.BatchNorm2d(outChannels)
    self.relu = nn.ReLU(inplace=True)
  def forward(self, x):
    # apply CONV => RELU => CONV block to the inputs and return it
    x = self.conv1(x)
    x = self.batch1(x)
    x = self.relu(x)
    x = self.conv2(x)
    x = self.batch1(x)
    x = self.relu(x)
    return x

class Encoder(nn.Module):
  def __init__(self, channels=(1, 4, 16, 32, 64)):
    super().__init__()
    # store the encoder blocks and maxpooling layer
    self.encBlocks = nn.ModuleList(
      [Block(channels[i], channels[i + 1])
        for i in range(len(channels) - 1)])
    self.pool = nn.MaxPool2d(2)
  def forward(self, x):
    # initialize an empty list to store the intermediate outputs
    blockOutputs = []
    # loop through the encoder blocks
    for block in self.encBlocks:
      # pass the inputs through the current encoder block, store
      # the outputs, and then apply maxpooling on the output
      x = block(x)
      blockOutputs.append(x)
      x = self.pool(x)
    # return the list containing the intermediate outputs
    return blockOutputs

class Decoder(nn.Module):
  def __init__(self, channels=(64, 32, 16, 4)):
    super().__init__()
    # initialize the number of channels, upsampler blocks, and
    # decoder blocks
    self.channels = channels
    self.upconvs = nn.ModuleList(
      [nn.ConvTranspose2d(channels[i], channels[i + 1], 2, 2)
        for i in range(len(channels) - 1)])
    self.dec_blocks = nn.ModuleList(
      [Block(channels[i], channels[i + 1])
        for i in range(len(channels) - 1)])
  def forward(self, x, encFeatures):
    # loop through the number of channels
    for i in range(len(self.channels) - 1):
      # pass the inputs through the upsampler blocks
      x = self.upconvs[i](x)
      # crop the current features from the encoder blocks,
      # concatenate them with the current upsampled features,
      # and pass the concatenated output through the current
      # decoder block
      encFeat = self.crop(encFeatures[i], x)
      x = torch.cat([x, encFeat], dim=1)
      x = self.dec_blocks[i](x)
    # return the final decoder output
    return x
  def crop(self, encFeatures, x):
    # grab the dimensions of the inputs, and crop the encoder
    # features to match the dimensions
    (_, _, H, W) = x.shape
    encFeatures = transforms.CenterCrop([H, W])(encFeatures)
    # return the cropped features
    return encFeatures

In [10]:
class UNet(nn.Module):
  def __init__(self, encChannels=(1, 4, 8, 16, 32),
      decChannels=(32, 16, 8, 4),
      nbClasses=3, retainDim=True,
      outSize=(image_width, image_height)):
    super().__init__()
    # initialize the encoder and decoder
    self.encoder = Encoder(encChannels)
    self.decoder = Decoder(decChannels)
    # initialize the regression head and store the class variables
    self.classifier = nn.Conv2d(decChannels[-1], nbClasses, 1)
    self.softmax = nn.Softmax(dim=1)
    self.retainDim = retainDim
    self.outSize = outSize

  def forward(self, x):
    # grab the features from the encoder
    encFeatures = self.encoder(x)
    # pass the encoder features through decoder making sure that
    # their dimensions are suited for concatenation
    decFeatures = self.decoder(encFeatures[::-1][0],
      encFeatures[::-1][1:])
    # pass the decoder features through the regression head to
    # obtain the segmentation mask
    map = self.classifier(decFeatures)
    # check to see if we are retaining the original output
    # dimensions and if so, then resize the output to match them
    if self.retainDim:
      map = nn.functional.interpolate(map, self.outSize)
    # return the segmentation map
    return self.softmax(map)

# 2. Training

In [77]:
batch_size = 16
# define transformations
transforms_def = transforms.Compose([transforms.ToPILImage(),
  transforms.Resize((image_width, image_height)),
  transforms.ToTensor()])
# create the train and test datasets
indices = np.random.choice(range(7190), size=(1000), replace=False)
train_ds = TwoDimImageDataset(indices[:800], two_d_imgs_dir, two_d_labels_dir, transforms)
test_ds = TwoDimImageDataset(indices[800:], two_d_imgs_dir, two_d_labels_dir, transforms)
print(f"[INFO] found {len(train_ds)} examples in the training set...")
print(f"[INFO] found {len(test_ds)} examples in the test set...")
# create the training and test data loaders
trainLoader = DataLoader(train_ds, shuffle=True,
  batch_size=batch_size, pin_memory=True,
  num_workers=os.cpu_count())
testLoader = DataLoader(test_ds, shuffle=False,
	batch_size=batch_size, pin_memory=True,
	num_workers=os.cpu_count())

[INFO] found 800 examples in the training set...
[INFO] found 200 examples in the test set...


In [73]:
learning_rate = 1e-4
# initialize our UNet model
unet = UNet().to(device)
# initialize loss function and optimizer
opt = optim.Adam(unet.parameters(), lr=learning_rate)
# calculate steps per epoch for training and test set
train_steps = len(train_ds) // batch_size
test_steps = len(test_ds) // batch_size
# initialize a dictionary to store training history
H = {"train_loss": [], "test_loss": []}

In [74]:
weight = torch.tensor([0.1, 0.9]).to(device)
loss_function = nn.CrossEntropyLoss(weight=weight, reduction='sum')
# def my_loss_function(output, labels):
#   selected_output = torch.reshape(output, (output.shape[0], output.shape[1], output.shape[2]*output.shape[3]))
#   selected_output = torch.transpose(selected_output, 0, 1)
#   selected_labels = torch.reshape(labels, (labels.shape[0], labels.shape[1]*output.shape[2]))
#   selected_output = selected_output[:,torch.where(selected_labels > 0.1)]
#   selected_labels = labels[torch.where(labels > 0.1)]
#   selected_output = torch.transpose(selected_output, 0, 1)
#   return loss_function(selected_output, selected_labels)
def my_dice_loss(preds, labels):
  smooth = 1
  dice = 2 * (torch.mul(pred, labels)).sum(dim=0).sum(dim=0).sum(dim=0) / ((pred.pow(2) + labels.pow(2)).sum(dim=0).sum(dim=0).sum(dim=0) + smooth)
  return torch.clamp((1 - dice).mean(), 0, 1)

class DiceLoss(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        
        # pred.shape is [batch_size,1,1,512,512], target.shape is [batch_size,1,512,512]
        pred = pred.squeeze(dim=1)

        smooth = 1

        # dice
        dice = 2 * torch.sum(torch.mul(pred, target)) / (torch.sum(pred.pow(2)) +
                                            torch.sum(target.pow(2)) + smooth)

        # 返回的是dice距离,torch.clamp(input, min, max, out=None),讲张量截断在[0,1]区间里 
        return torch.clamp((1 - dice).mean(), 0, 1)

# loss_function = DiceLoss()

In [78]:
# loop over epochs
num_epochs = 15
print("[INFO] training the network...")
startTime = time.time()
best_acc = 0.0
train_loss_history = []
val_loss_history = []
epoch = 0
for e in tqdm(range(num_epochs)):
  # set the model in training mode
  unet.train()
  # initialize the total training and validation loss
  total_train_loss = 0
  total_test_loss = 0
  # loop over the training set
  with torch.set_grad_enabled(True):
    train_acc = 0
    test_acc = 0
    for i, map in enumerate(trainLoader):
      # send the input to the device
      x, y = map['image'].to(device), map['label'].to(device).squeeze()
      # perform a forward pass and calculate the training loss
      prob = unet(x)
      _, preds = torch.max(prob, 1)
      loss = loss_function(prob, y)
      # loss = my_dice_loss(preds, y)
      # first, zero out any previously accumulated gradients, then
      # perform backpropagation, and then update model parameters
      opt.zero_grad()
      loss.backward()
      opt.step()
      # add the loss to the total training loss so far
      total_train_loss += loss
      train_acc += torch.sum(preds == y) / (y.shape[0] * y.shape[1] * y.shape[2])
  # switch off autograd
  with torch.set_grad_enabled(False):
    # set the model in evaluation mode
    unet.eval()
    # loop over the validation set
    for map in testLoader:
      # send the input to the device
      x, y = map['image'].to(device), map['label'].to(device).squeeze()
      # make the predictions and calculate the validation loss
      prob = unet(x)
      _, preds = torch.max(prob, 1)
      total_test_loss += loss_function(prob, y)
      # total_test_loss += my_dice_loss(preds, y)
      test_acc += torch.sum(preds == y) / (y.shape[0] * y.shape[1] * y.shape[2])
  # calculate the average training and validation loss
  avg_train_loss = total_train_loss / train_steps
  avg_test_loss = total_test_loss / test_steps
  train_acc /= train_steps
  test_acc /= test_steps

  if test_acc > best_acc:
    best_acc = test_acc
    best_model_wts = copy.deepcopy(unet.state_dict())
    epoch = e
  train_loss_history.append(avg_test_loss)
  val_loss_history.append(avg_test_loss)
  # update our training history
  # H["train_loss"].append(avg_train_loss.cpu().detach().numpy())
  # H["test_loss"].append(avg_test_loss.cpu().detach().numpy())
  # print the model training and validation information
  print("[INFO] EPOCH: {}/{}".format(e + 1, num_epochs))
  print("Train loss: {:.6f}, Test loss: {:.4f}".format(
    avg_train_loss, avg_test_loss))
  print("Train acc: {:.6f}, Test acc: {:.4f}".format(
    train_acc, test_acc))
# display the total time needed to perform the training
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
  endTime - startTime))

[INFO] training the network...


  3%|▎         | 1/30 [01:17<37:36, 77.80s/it]

[INFO] EPOCH: 1/30
Train loss: 499531.312500, Test loss: 522729.2188
Train acc: 0.907194, Test acc: 0.9864


  7%|▋         | 2/30 [01:29<18:06, 38.80s/it]

[INFO] EPOCH: 2/30
Train loss: 471396.500000, Test loss: 524583.3750
Train acc: 0.914518, Test acc: 0.9864


 10%|█         | 3/30 [01:40<11:43, 26.04s/it]

[INFO] EPOCH: 3/30
Train loss: 445088.875000, Test loss: 527134.5625
Train acc: 0.927342, Test acc: 0.9864


 13%|█▎        | 4/30 [01:50<08:37, 19.92s/it]

[INFO] EPOCH: 4/30
Train loss: 425901.375000, Test loss: 527619.1250
Train acc: 0.946223, Test acc: 0.9864


 17%|█▋        | 5/30 [02:01<06:55, 16.64s/it]

[INFO] EPOCH: 5/30
Train loss: 417697.437500, Test loss: 528701.6875
Train acc: 0.958961, Test acc: 0.9864


 20%|██        | 6/30 [02:12<05:50, 14.59s/it]

[INFO] EPOCH: 6/30
Train loss: 411281.468750, Test loss: 528861.3750
Train acc: 0.963854, Test acc: 0.9864


 23%|██▎       | 7/30 [02:23<05:08, 13.40s/it]

[INFO] EPOCH: 7/30
Train loss: 406116.500000, Test loss: 529020.2500
Train acc: 0.965217, Test acc: 0.9864


 27%|██▋       | 8/30 [02:33<04:35, 12.52s/it]

[INFO] EPOCH: 8/30
Train loss: 400427.031250, Test loss: 529548.2500
Train acc: 0.967277, Test acc: 0.9864


 30%|███       | 9/30 [02:44<04:10, 11.92s/it]

[INFO] EPOCH: 9/30
Train loss: 395577.156250, Test loss: 530050.8750
Train acc: 0.967621, Test acc: 0.9864


 33%|███▎      | 10/30 [02:54<03:50, 11.53s/it]

[INFO] EPOCH: 10/30
Train loss: 390365.062500, Test loss: 530288.3750
Train acc: 0.968156, Test acc: 0.9864


 37%|███▋      | 11/30 [03:05<03:35, 11.36s/it]

[INFO] EPOCH: 11/30
Train loss: 385080.500000, Test loss: 529538.7500
Train acc: 0.967928, Test acc: 0.9864


 40%|████      | 12/30 [03:16<03:21, 11.20s/it]

[INFO] EPOCH: 12/30
Train loss: 379985.718750, Test loss: 530877.7500
Train acc: 0.968326, Test acc: 0.9864


 43%|████▎     | 13/30 [03:27<03:07, 11.03s/it]

[INFO] EPOCH: 13/30
Train loss: 375225.562500, Test loss: 533580.9375
Train acc: 0.968333, Test acc: 0.9864


 47%|████▋     | 14/30 [03:38<02:55, 10.97s/it]

[INFO] EPOCH: 14/30
Train loss: 369903.718750, Test loss: 536406.4375
Train acc: 0.968538, Test acc: 0.9864


 50%|█████     | 15/30 [03:48<02:43, 10.91s/it]

[INFO] EPOCH: 15/30
Train loss: 365716.312500, Test loss: 535739.0000
Train acc: 0.968237, Test acc: 0.9864


 53%|█████▎    | 16/30 [03:59<02:32, 10.86s/it]

[INFO] EPOCH: 16/30
Train loss: 361630.625000, Test loss: 536199.7500
Train acc: 0.968291, Test acc: 0.9864


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4d33774680>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4d33774680>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4d33774680>
Traceback (most recent call last):
    if w.is_alive():
  File "/

[INFO] EPOCH: 17/30
Train loss: 358143.781250, Test loss: 534937.3125
Train acc: 0.967899, Test acc: 0.9864


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4d33774680>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4d33774680>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/pytho

[INFO] EPOCH: 18/30
Train loss: 354668.187500, Test loss: 536938.0625
Train acc: 0.967326, Test acc: 0.9864


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4d33774680>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4d33774680>
Traceback (most recent call last):
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4d33774680>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3

[INFO] EPOCH: 19/30
Train loss: 350924.500000, Test loss: 536352.2500
Train acc: 0.967498, Test acc: 0.9864


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4d33774680>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4d33774680>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/pytho

[INFO] EPOCH: 20/30
Train loss: 348141.437500, Test loss: 536619.2500
Train acc: 0.966580, Test acc: 0.9864


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4d33774680>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4d33774680>
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4d33774680>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
Traceback (most recent call last):
    self._shutdown_workers()
  File "/usr/local/lib/python3

[INFO] EPOCH: 21/30
Train loss: 345078.156250, Test loss: 535921.4375
Train acc: 0.965474, Test acc: 0.9864


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4d33774680>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4d33774680>
    self._shutdown_workers()
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child

[INFO] EPOCH: 22/30
Train loss: 341360.218750, Test loss: 537951.2500
Train acc: 0.966546, Test acc: 0.9864


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4d33774680>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4d33774680>
    self._shutdown_workers()
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    self._shutdown_workers()
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    assert self._parent_pid == os.getpid(), 'can only test a child process'
    if w.is_alive():
AssertionError: can only test a child process
  File "/usr/lib/pytho

[INFO] EPOCH: 23/30
Train loss: 338288.593750, Test loss: 525769.6250
Train acc: 0.965738, Test acc: 0.9864


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4d33774680>
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4d33774680>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    self._shutdown_workers()
    if w.is_alive():
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child

[INFO] EPOCH: 24/30
Train loss: 335817.406250, Test loss: 527087.4375
Train acc: 0.965184, Test acc: 0.9864


 83%|████████▎ | 25/30 [05:39<00:54, 11.00s/it]

[INFO] EPOCH: 25/30
Train loss: 332932.406250, Test loss: 526483.3125
Train acc: 0.965020, Test acc: 0.9864


 87%|████████▋ | 26/30 [05:50<00:43, 10.93s/it]

[INFO] EPOCH: 26/30
Train loss: 330105.625000, Test loss: 527394.8750
Train acc: 0.964632, Test acc: 0.9864


 90%|█████████ | 27/30 [06:01<00:32, 10.90s/it]

[INFO] EPOCH: 27/30
Train loss: 327364.406250, Test loss: 522276.0000
Train acc: 0.964417, Test acc: 0.9864


 93%|█████████▎| 28/30 [06:11<00:21, 10.87s/it]

[INFO] EPOCH: 28/30
Train loss: 325165.687500, Test loss: 519017.3438
Train acc: 0.964001, Test acc: 0.9864


 97%|█████████▋| 29/30 [06:22<00:10, 10.92s/it]

[INFO] EPOCH: 29/30
Train loss: 322398.031250, Test loss: 523728.5625
Train acc: 0.964468, Test acc: 0.9864


100%|██████████| 30/30 [06:33<00:00, 13.13s/it]

[INFO] EPOCH: 30/30
Train loss: 319513.718750, Test loss: 518585.5000
Train acc: 0.964395, Test acc: 0.9864
[INFO] total time taken to train the model: 393.86s





In [None]:
torch.sum(preds == y) / (y.shape[0] * y.shape[1] * y.shape[2])

tensor(0.8088, device='cuda:0')

In [51]:
save_dir = './models'
os.makedirs(save_dir, exist_ok=True)
torch.save(best_model_wts, os.path.join(save_dir, 'weights_best_val_acc.pt'))
torch.save(unet.state_dict(), os.path.join(save_dir, 'weights_last.pt'.format(epoch)))

In [57]:
!git clone https://{git_token}@github.com/{username}/{repository} temp
%cp -rf models temp/models2
%cd temp

!git config --global user.email "qzlinqian@126.com"
!git config --global user.name "Qian Lin"

!git add .
!git commit -m"add 2d model v2"
!git push origin main

Cloning into 'temp'...
remote: Enumerating objects: 35, done.[K
remote: Counting objects: 100% (35/35), done.[K
remote: Compressing objects: 100% (31/31), done.[K
remote: Total 35 (delta 11), reused 15 (delta 3), pack-reused 0[K
Unpacking objects: 100% (35/35), done.
/content/temp
[main f5e8bae] add 2d model v2
 2 files changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 models2/weights_best_val_acc.pt
 create mode 100644 models2/weights_last.pt
Counting objects: 5, done.
Delta compression using up to 4 threads.
Compressing objects: 100% (5/5), done.
Writing objects: 100% (5/5), 226.19 KiB | 14.14 MiB/s, done.
Total 5 (delta 1), reused 0 (delta 0)
remote: Resolving deltas: 100% (1/1), completed with 1 local object.[K
To https://github.com/qzlinqian/6_869_project_med_seg
   3851a4f..f5e8bae  main -> main


In [58]:
%cd ..
%rm -rf temp

/content


# Visualize

In [79]:
vis_index = 0
vis_image_path = training_imgs_dir + '/liver_' + str(vis_index) + '.nii.gz'
vis_label_path = training_labels_dir + '/liver_' + str(vis_index) + '.nii.gz'
vis_images = nib.load(vis_image_path).get_fdata().squeeze()
vis_labels = nib.load(vis_label_path).get_fdata().squeeze()
vis_pred = np.zeros(vis_labels.shape)

In [80]:
for i in range(vis_labels.shape[2]):
  x = torch.from_numpy(vis_images[:,:,i].squeeze()).float().unsqueeze(dim=0).unsqueeze(dim=0).to(device)
  prob = unet(x)
  _, pred = torch.max(prob, 1)
  vis_pred[:,:,i] = pred.cpu().detach().numpy()

In [81]:
acc = np.sum(vis_pred == vis_labels) / vis_labels.shape[0] / vis_labels.shape[1] / vis_labels.shape[2]
print(acc)

0.9718919372558594


In [52]:
import imageio
import matplotlib.animation as animate
%matplotlib inline
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from cv2 import imread, createCLAHE # read and equalize images
from glob import glob
import h5py
# for display the MRI images in animation
from IPython.display import HTML

def create_gif(input_image, title='.gif', filename='test.gif'):
    # see example from matplotlib documentation
    images = []
    fig = plt.figure()
    for i in range(input_image.shape[2]):
        im = plt.imshow(input_image[:,:,i], animated=True)
        images.append([im])
    ani = animate.ArtistAnimation(fig, images, interval=50, blit=True, repeat_delay=1000)
    plt.title(title, fontsize=20)
    plt.axis('off')
    plt.close()
    return ani

In [None]:
def display_result(input_image, label, pred_label, title='.gif', filename='test.gif'):
    # see example from matplotlib documentation
    images = []
    fig = plt.figure()
    ax1 = fig.add_subplot(1, 2, 1)
    ax2 = fig.add_subplot(1, 2, 2)
    for i in range(input_image.shape[2]):
        x11, y11 = np.where(label[:,:,i] == 1)
        x12, y12 = np.where(label[:,:,i] == 2)
        x21, y21 = np.where(pred_label[:,:,i] == 1)
        x22, y22 = np.where(pred_label[:,:,i] == 2)
        im1 = ax1.imshow(input_image[:,:,i], animated=True)
        ax1.scatter(x11, y11, color='y')
        ax1.scatter(x12, y12, color='b')
        im2 = ax2.imshow(input_image[:,:,i], animated=True)
        ax2.scatter(x21, y21, color='y')
        ax2.scatter(x22, y22, color='b')
        images.append([im1, im2])
    ani = animate.ArtistAnimation(fig, images, interval=50, blit=True, repeat_delay=1000)
    plt.title(title, fontsize=20)
    plt.axis('off')
    plt.close()
    return ani

ani = display_result(vis_images, vis_labels, vis_pred)
HTML(ani.to_html5_video())

In [None]:
ani.save('./mymovie.mp4')

In [None]:
ani = create_gif(vis_images, title='image', filename='image.gif')
ani.save('./input.mp4')
HTML(ani.to_html5_video())

In [None]:
ani = create_gif(vis_labels, title='label', filename='label.gif')
ani.save('./labels.mp4')
HTML(ani.to_html5_video())

In [None]:
ani = create_gif(vis_pred, title='prediction', filename='prediction.gif')
ani.save('./predictions.mp4')
HTML(ani.to_html5_video())

In [None]:
vis_pred[:,:,100].max()

0.0