In [None]:
!pip install nilearn
!pip install monai

In [2]:
from google.colab import drive

# Mount Google Drive at /content/drive so files under MyDrive can be read and written.
drive.mount('/content/drive')
# NOTE(review): further down in this cell `from nibabel.testing import data_path`
# rebinds the name `data_path`, so when the cell runs top to bottom this
# assignment has no lasting effect. `file_path` (defined at the end of this
# cell) points at the same directory and is the name the visible cells use.
data_path = r'/content/drive/MyDrive/Deep Learning Project/training'

import os
from nibabel.testing import data_path
import nilearn.image
import numpy as np
from nilearn.image import resample_img
import nibabel as nib
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
import pandas as pd
import monai
import torch
from tqdm import tqdm
import scipy
import random
import cv2

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'The used device is {device}')

# Root folder of the training data; each sub-directory holds one patient.
file_path = r'/content/drive/MyDrive/Deep Learning Project/training'
# os.listdir returns entries in arbitrary order, so sort the directory names to
# keep patient indices (and therefore processing order) reproducible.
patient_files = sorted(name for name in os.listdir(file_path) if os.path.isdir(os.path.join(file_path, name)))

Mounted at /content/drive
The used device is cpu


In [5]:
# Function for zero-padding the images to a certain shape
def to_shape(a, shape):
    """Zero-pad a 2-D array symmetrically up to the requested shape.

    Parameters
    ----------
    a : 2-D array to pad.
    shape : pair ``(rows, cols)`` with the target shape.

    Returns
    -------
    The padded array. When the padding needed along an axis is odd, the
    extra row/column goes at the end of that axis.
    """
    target_rows, target_cols = shape
    rows, cols = a.shape

    # divmod gives the floor half of the padding plus the 0/1 remainder.
    row_half, row_extra = divmod(target_rows - rows, 2)
    col_half, col_extra = divmod(target_cols - cols, 2)

    pad_width = ((row_half, row_half + row_extra),
                 (col_half, col_half + col_extra))
    return np.pad(a, pad_width, mode='constant')

# Function for normalizing the histograms of each image
def image_histogram_equalization(image, number_bins=256):
    """Equalize an image by mapping its intensities through its own CDF.

    Adapted from
    http://www.janeriksolem.net/histogram-equalization-with-python-and.html

    Parameters
    ----------
    image : array of intensities (any shape).
    number_bins : number of histogram bins used to estimate the CDF.

    Returns
    -------
    (equalized, cdf) : the remapped image with the same shape as ``image``,
        and the cumulative histogram scaled so that its last entry is 255.
    """
    flat = image.flatten()

    # Density histogram of all intensities and the matching bin edges.
    density, bin_edges = np.histogram(flat, number_bins, density=True)

    # Cumulative distribution, rescaled so the top of the range is 255.
    cumulative = density.cumsum()
    cdf = 255 * cumulative / cumulative[-1]

    # Look every pixel up on the CDF, interpolating between left bin edges.
    equalized_flat = np.interp(flat, bin_edges[:-1], cdf)

    return equalized_flat.reshape(image.shape), cdf

# Function for rotating the images
def rotate_CV(image, angel, interpolation):
    """Rotate an image around its centre pixel, keeping the original size.

    Parameters
    ----------
    image : array whose first two axes are (height, width).
    angel : rotation angle handed to ``cv2.getRotationMatrix2D``.
    interpolation : OpenCV interpolation flag passed to ``cv2.warpAffine``.

    Returns
    -------
    The rotated image, with the same width and height as the input.
    """
    height, width = image.shape[:2]
    centre = (width // 2, height // 2)
    # Scale factor 1: pure rotation, no zoom.
    rotation = cv2.getRotationMatrix2D(centre, angel, 1)
    return cv2.warpAffine(image, rotation, (width, height), flags=interpolation)

# Function for scaling the images
def zoom_CV(image, scale, interpolation):
    """Scale an image about its centre pixel, keeping the original size.

    Parameters
    ----------
    image : array whose first two axes are (height, width).
    scale : isotropic scale factor handed to ``cv2.getRotationMatrix2D``.
    interpolation : OpenCV interpolation flag passed to ``cv2.warpAffine``.

    Returns
    -------
    The scaled image, with the same width and height as the input.
    """
    height, width = image.shape[:2]
    centre = (width // 2, height // 2)
    # Angle 0: the affine matrix only scales around the centre.
    scaling = cv2.getRotationMatrix2D(centre, 0, scale)
    return cv2.warpAffine(image, scaling, (width, height), flags=interpolation)

# Helper for crop_img: zero-pad a 2-D array up to a minimum size
def _pad_to_min_shape(a, min_rows, min_cols):
  """Zero-pad a 2-D array so that it is at least (min_rows, min_cols).

  The padding of each axis is split the same way `to_shape` splits it: the
  floor half goes before the data and the remainder after it. An axis that is
  already long enough is left untouched, and the input is returned as-is when
  no padding is needed at all.
  """
  rows, cols = a.shape
  row_pad = max(min_rows - rows, 0)
  col_pad = max(min_cols - cols, 0)
  if row_pad == 0 and col_pad == 0:
    return a
  pad_width = ((row_pad // 2, row_pad - row_pad // 2),
               (col_pad // 2, col_pad - col_pad // 2))
  return np.pad(a, pad_width, mode='constant')

# Function for cropping the images to output_size
def crop_img(img_sys, gt_sys, img_dia, gt_dia, output_size):
  """Crop two images and their ground truths to the same square window.

  The window is the output_size x output_size region in which the summed
  absolute difference between `img_sys` and `img_dia` is largest; the first
  such window in row-major order wins when several tie.

  Parameters
  ----------
  img_sys, img_dia : 2-D images; their difference drives the window search.
  gt_sys, gt_dia : 2-D ground truths, cropped with the same window.
      All four arrays are assumed to share one shape.
  output_size : side length of the square crop.

  Returns
  -------
  (img_sys_cut, gt_sys_cut, img_dia_cut, gt_dia_cut), each of shape
  (output_size, output_size).
  """
  # The search below tries offsets 0 .. side - output_size - 1, so every side
  # must be at least output_size + 1 for a single offset to exist. Both axes
  # are padded in one pass from the arrays' current shapes; padding them one
  # after the other against a stale shape made np.pad fail with a negative pad
  # width whenever both sides were too small.
  min_side = output_size + 1
  img_sys = _pad_to_min_shape(img_sys, min_side, min_side)
  gt_sys = _pad_to_min_shape(gt_sys, min_side, min_side)
  img_dia = _pad_to_min_shape(img_dia, min_side, min_side)
  gt_dia = _pad_to_min_shape(gt_dia, min_side, min_side)

  absdiff = abs(img_sys - img_dia)
  n_row_offsets = absdiff.shape[0] - output_size
  n_col_offsets = absdiff.shape[1] - output_size

  # Summed absolute difference for every candidate window position.
  # The built-in sum is kept on purpose: it adds in a fixed order, so windows
  # covering the same non-zero pixels get exactly equal scores and ties are
  # resolved the same way on every run.
  intsum = np.full((n_row_offsets, n_col_offsets), np.nan)
  for i in range(n_row_offsets):
    for j in range(n_col_offsets):
      intsum[i, j] = sum(sum(absdiff[i:i+output_size, j:j+output_size]))

  # First (row-major) position with the maximal score.
  row0, col0 = np.argwhere(intsum == np.amax(intsum))[0]
  rows = slice(row0, row0 + output_size)
  cols = slice(col0, col0 + output_size)

  return img_sys[rows, cols], gt_sys[rows, cols], img_dia[rows, cols], gt_dia[rows, cols]

In [4]:
# Remove some patients if the data does not exist or the image is too large for resampling
patient_files_true = []

for patient in patient_files:
  try:
    # Loading and resampling frame01 is only a probe: a patient is kept when
    # both steps succeed, the resampled image itself is not used further here.
    img = nib.load(file_path+'/'+patient+'/'+patient+'_frame01.nii.gz')
    ds_img = resample_img(img, target_affine=np.diag([1.25,1.25,1]), interpolation='nearest')
  except Exception as exc:
    # Catch Exception rather than everything so that interrupting the kernel
    # still stops the loop, and report why the patient was dropped.
    print("Removing patient:", patient, f"({type(exc).__name__}: {exc})")
  else:
    patient_files_true.append(patient)


In [None]:
def _find_frame_files(file_names):
  """Pick the four NIfTI files of one patient out of a directory listing.

  Returns a dict with the keys 'sys', 'dia', 'sys_gt' and 'dia_gt'. The
  frame01 files are used as the diastolic image / ground truth; the other
  'frame' file and the other '_gt' file are used as the systolic pair. When
  several names match one role, the last one in the listing wins.

  Raises FileNotFoundError when one of the four files is missing.
  """
  found = {}
  for name in file_names:
    if name.endswith('frame01.nii.gz'):
      found['dia'] = name
    if name.endswith('frame01_gt.nii.gz'):
      found['dia_gt'] = name
    if 'frame' in name and not name.endswith('_frame01.nii.gz') and not name.endswith('_gt.nii.gz'):
      found['sys'] = name
    if name.endswith('_gt.nii.gz') and not name.endswith('frame01_gt.nii.gz'):
      found['sys_gt'] = name

  missing = [role for role in ('sys', 'dia', 'sys_gt', 'dia_gt') if role not in found]
  if missing:
    raise FileNotFoundError(f"Missing {missing} file(s) among {list(file_names)}")
  return found


def _bhattacharyya_distance(values_a, values_b, bin_edges, fallback=1.3, limit=3):
  """Bhattacharyya distance between the histograms of two intensity samples.

  Returns `fallback` when either sample is empty, or when the distance is NaN
  or lies outside [-limit, limit] (e.g. the histograms do not overlap).
  """
  if len(values_a) == 0 or len(values_b) == 0:
    return fallback

  freqs_a = np.histogram(values_a, bins=bin_edges)[0] / len(values_a)
  freqs_b = np.histogram(values_b, bins=bin_edges)[0] / len(values_b)

  coefficient = np.sum(np.sqrt(freqs_a * freqs_b)) # Bhattacharyya coefficient (between 0-1)
  with np.errstate(divide='ignore'):
    distance = -np.log(coefficient) # Bhattacharyya distance

  if np.isnan(distance) or distance > limit or distance < -limit:
    return fallback
  return distance


def augment_images(file_path, patient_files, pt_nr, n_images, V_var, output_size=152):
  """Create cropped and augmented systolic/diastolic slices for one patient.

  For every slice the histogram-equalized original is stored as image 0,
  followed by n_images-1 augmented versions. Each augmented version gets an
  intensity change on the pixels labelled 3 (LV) and 2 (MYO), a random
  rotation and a random zoom; every image is then cropped with `crop_img`.

  Inputs:
  file_path: path to all patient maps
  patient_files: list of patient files
  pt_nr: idx in patient files
  n_images: number of generated images per slice (original included)
  V_var: spread of the V-component in the intensity transformation; it is
         passed to np.random.normal as `scale`
  output_size: side length of the square crops (default 152)

  Outputs:
  aug_img_all: images of shape (output_size, output_size, n_images, 2, n_slices),
               where index 0 on the fourth axis is systole and 1 is diastole
  aug_gt_all: ground truths of the same shape as aug_img_all

  Raises FileNotFoundError when the patient folder lacks one of the four
  expected files.
  """
  # Use the list that was passed in (not a notebook global) for every path.
  patient_dir = file_path+'/'+patient_files[pt_nr]
  names = _find_frame_files(os.listdir(patient_dir))

  # NOTE(review): an earlier comment here said the volumes are resampled to
  # 1,25x1,25x10, but the target affine below has 1 on the third axis - confirm
  # which spacing is intended.
  target_affine = np.diag([1.25,1.25,1])

  def load_volume(name):
    # Load, resample and read the voxel data once; the array is reused for all
    # slices and augmentations instead of being fetched again each time.
    img = nib.load(patient_dir+'/'+name)
    img_ds = resample_img(img, target_affine=target_affine, interpolation='nearest')
    return img_ds.get_fdata(caching='unchanged')

  vol_sys = load_volume(names['sys'])
  vol_dia = load_volume(names['dia'])
  vol_sys_gt = load_volume(names['sys_gt'])
  vol_dia_gt = load_volume(names['dia_gt'])
  n_slices = vol_sys.shape[2]

  # One rotation angle and zoom factor per augmented image, shared by all slices.
  theta_ = np.random.uniform(0, 180, size=n_images-1)
  zoom_ = np.random.uniform(1, 1.3, size=n_images-1)

  aug_img_all = np.full([output_size,output_size,n_images,2,n_slices], np.nan)
  aug_gt_all = np.zeros([output_size,output_size,n_images,2,n_slices])

  bins_ = np.arange(0,400,1)

  for s in range(n_slices):
    sys_raw = vol_sys[:,:,s]
    sys_gt = vol_sys_gt[:,:,s]
    dia_raw = vol_dia[:,:,s]
    dia_gt = vol_dia_gt[:,:,s]

    # Pixels labelled LV (3) and MYO (2) in each ground truth.
    lv_sys = sys_gt == 3
    myo_sys = sys_gt == 2
    lv_dia = dia_gt == 3
    myo_dia = dia_gt == 2

    # Separation of the LV and MYO intensity distributions in the raw slices;
    # an arbitrary value is used when it cannot be computed.
    DB0 = _bhattacharyya_distance(sys_raw[lv_sys], sys_raw[myo_sys], bins_)
    DB1 = _bhattacharyya_distance(dia_raw[lv_dia], dia_raw[myo_dia], bins_)

    # Histogram normalization (the same for every image of this slice)
    sys_eq, _ = image_histogram_equalization(sys_raw, number_bins=256)
    dia_eq, _ = image_histogram_equalization(dia_raw, number_bins=256)

    # Store original image as k=0
    img_s, gt_s, img_d, gt_d = crop_img(sys_eq, sys_gt, dia_eq, dia_gt, output_size)
    aug_img_all[:,:,0,0,s] = img_s
    aug_gt_all[:,:,0,0,s] = gt_s
    aug_img_all[:,:,0,1,s] = img_d
    aug_gt_all[:,:,0,1,s] = gt_d

    for k in range(1,n_images):
      # Work on copies: the intensity transform below modifies in place.
      img_s = sys_eq.copy()
      img_d = dia_eq.copy()

      # Transform intensity: LV pixels are scaled by (1 - DB*V + W) and MYO
      # pixels by (1 + DB*V + W), with a fresh W for every pixel.
      V = np.random.normal(loc=0.1, scale=V_var)
      img_s[lv_sys] *= (1-DB0*V) + np.random.uniform(low=-0.05, high=0.05, size=np.count_nonzero(lv_sys))
      img_d[lv_dia] *= (1-DB1*V) + np.random.uniform(low=-0.05, high=0.05, size=np.count_nonzero(lv_dia))
      img_s[myo_sys] *= (1+DB0*V) + np.random.uniform(low=-0.05, high=0.05, size=np.count_nonzero(myo_sys))
      img_d[myo_dia] *= (1+DB1*V) + np.random.uniform(low=-0.05, high=0.05, size=np.count_nonzero(myo_dia))

      # Rotate randomly (linear for img and nearest for gt)
      theta = theta_[k-1]
      img_s = rotate_CV(img_s, theta, cv2.INTER_LINEAR)
      gt_s = rotate_CV(sys_gt, theta, cv2.INTER_NEAREST)
      img_d = rotate_CV(img_d, theta, cv2.INTER_LINEAR)
      gt_d = rotate_CV(dia_gt, theta, cv2.INTER_NEAREST)

      # Scale randomly
      zoom = zoom_[k-1]
      img_s = zoom_CV(img_s, zoom, cv2.INTER_LINEAR)
      gt_s = zoom_CV(gt_s, zoom, cv2.INTER_NEAREST)
      img_d = zoom_CV(img_d, zoom, cv2.INTER_LINEAR)
      gt_d = zoom_CV(gt_d, zoom, cv2.INTER_NEAREST)

      # Crop to output_size x output_size
      img_s, gt_s, img_d, gt_d = crop_img(img_s, gt_s, img_d, gt_d, output_size)

      # Store augmented data
      aug_img_all[:,:,k,0,s] = img_s
      aug_gt_all[:,:,k,0,s] = gt_s
      aug_img_all[:,:,k,1,s] = img_d
      aug_gt_all[:,:,k,1,s] = gt_d

  return aug_img_all, aug_gt_all

In [None]:
output_size = 152
count_false = 0
false_crops = []
failed_files = []
count_total = 0
n_images = 3 # 3 images in total, so 1 original + 2 augmented
V_var = 0.02

# Output folders for the augmented images and ground truths. They are created
# up front so that a missing folder does not make every single save fail.
im_dir = '/content/drive/MyDrive/Deep Learning Project/im_aug_REAL2'
gt_dir = '/content/drive/MyDrive/Deep Learning Project/gt_aug_REAL2'
os.makedirs(im_dir, exist_ok=True)
os.makedirs(gt_dir, exist_ok=True)

for pt_nr in range(len(patient_files_true)):
  pt = patient_files_true[pt_nr][7:] # get patient number

  # x holds the images and y the ground truths; both are indexed as
  # (row, column, image n, frame f, slice z).
  x,y = augment_images(file_path, patient_files_true, pt_nr, n_images, V_var)

  for n in range(x.shape[2]):
    for f in range(x.shape[3]):
      for z in range(x.shape[4]):
        tag = f'{pt}_z{z}_f{f}_n{n}'
        try:
          x_cut = x[:,:,n,f,z]
          y_cut = y[:,:,n,f,z]
          np.save(f'{im_dir}/im_p{tag}.npy',x_cut)
          np.save(f'{gt_dir}/gt_p{tag}.npy',y_cut)

          # Keep track of how many falsely cropped images there are: a crop
          # counts as false when labelled pixels touch any of its four borders.
          count_total = count_total + 1
          border_sums = (sum(y_cut[0,:]), sum(y_cut[-1,:]), sum(y_cut[:,0]), sum(y_cut[:,-1]))
          if not all(border < 1 for border in border_sums):
            count_false = count_false + 1
            false_crops.append(tag)

        except Exception as exc:
          # Keep going with the remaining files, but record and show the cause.
          print(f"file {tag} failed: {type(exc).__name__}: {exc}")
          failed_files.append(tag)
          continue

  print(f"{pt_nr+1} out of {len(patient_files_true)} done")

print(f"{count_false} out of {count_total} falsely cropped")
np.save('/content/drive/MyDrive/Deep Learning Project/false_crop_list3',false_crops)
np.save('/content/drive/MyDrive/Deep Learning Project/failed_files_list3',failed_files)