In [None]:
# Mount Google Drive at /content/drive (Colab only). Later cells copy the
# result CSV files and the saved model to and from this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# --- Install jarvis-md
% pip install \
    --index-url https://test.pypi.org/simple/ \
    --extra-index-url https://pypi.org/simple \
    jarvis-md==0.0.1a7

In [None]:
# Import libraries
import tensorflow as tf
from tensorflow.keras import Input, Model, layers, losses, optimizers
import os, numpy as np
from jarvis.train.client import Client
from jarvis.train import custom, models, datasets
from jarvis.utils.general import overload

### Download data

In [None]:
# Download the 'xr/mimic-ett' dataset through the jarvis datasets helper
datasets.download(name='xr/mimic-ett')

In [None]:
# Prepare generators and a client for the 'ett-crp' keyword.
# NOTE(review): gen_train, gen_valid and client are all re-created in later
# cells (after preprocess is overloaded), so these objects are only temporary.
gen_train, gen_valid, client = datasets.prepare(name='xr/mimic-ett', keyword='ett-crp')

### Create client and generators

In [None]:
import random
from scipy import ndimage

# Notebook-level switch read by preprocess(): augmentation is applied only
# when this equals 1
augEnabled = 1

# Size (rows, columns) of the cropped image handled by preprocess()
_AUG_Y_DIM = 256
_AUG_X_DIM = 128

# Zoom factor for each zoom augmentation id (> 1 zooms in, < 1 zooms out)
_AUG_ZOOM_SCALES = {1: 1.1, 2: 1.05, 3: 0.95, 4: 0.9}

# Signed pixel offsets (dy, dx) for each shift augmentation id. Negative
# values move the image content up / left, positive values down / right.
_AUG_SHIFT_OFFSETS = {
    5: (-14, -7),   # NW
    6: (-14, 7),    # NE
    7: (14, -7),    # SW
    8: (14, 7),     # SE
    9: (-14, 0),    # N
    10: (0, 7),     # E
    11: (14, 0),    # S
    12: (0, -7),    # W
}


def _fit_to_canvas(img, rows, cols):
  """Copy img into a zero-filled (rows, cols) canvas anchored at the top-left corner, cropping any overflow."""
  # Use the image's own dtype so that non-integer data is not truncated
  canvas = np.zeros((rows, cols), img.dtype)
  h = min(rows, img.shape[0])
  w = min(cols, img.shape[1])
  canvas[:h, :w] = img[:h, :w]
  return canvas


def _shift_image(img, dy, dx):
  """Translate a 2D image by (dy, dx) pixels, filling the exposed border with zeros."""
  rows, cols = img.shape
  shifted = np.zeros((rows, cols), img.dtype)
  shifted[max(dy, 0):rows + min(dy, 0), max(dx, 0):cols + min(dx, 0)] = \
      img[max(-dy, 0):rows - max(dy, 0), max(-dx, 0):cols - max(dx, 0)]
  return shifted


def _shift_points(pts, dy, dx, rows, cols):
  """Return a copy of a (z, y, x) point array moved by (dy, dx) pixels, expressed as a fraction of the image size."""
  # Work on a copy: the caller's array is never modified in place.
  # NOTE(review): whether the client hands out views of cached data after
  # load_data_in_memory() is not visible here; copying is safe either way.
  moved = np.array(pts, copy=True)
  moved[0][0][1] += dy / rows
  moved[0][0][2] += dx / cols
  return moved


@overload(Client)
def preprocess(self, arrays, **kwargs):
  """
  Method to preprocess arrays

  Optionally applies one randomly chosen augmentation to the image and its
  point labels, then converts the labels to 2D and builds the loss masks.

  Inputs read from `arrays`:

    arrays['xs']['dat'] ==> cropped dat (reshaped to 256 x 128 here)
    arrays['ys']['car'] ==> (z,y,x) carina or None (if not exists)
    arrays['ys']['ett'] ==> (z,y,x) ETT or None (if not exists)

  Augmentation (only when the global augEnabled == 1) draws an id from 0-12:

     0: no augmentation
     1: zoom in (1.1)
     2: zoom in (1.05)
     3: zoom out (0.95)
     4: zoom out (0.9)
     5: shift NW (y:14, x:7)
     6: shift NE (y:14, x:7)
     7: shift SW (y:14, x:7)
     8: shift SE (y:14, x:7)
     9: shift N (y:14, x:0)
    10: shift E (y:0, x:7)
    11: shift S (y:14, x:0)
    12: shift W (y:0, x:7)

  Zooms are anchored at the top-left corner, so labels are multiplied by the
  zoom factor; shifts move labels by the pixel offset divided by the image
  size. The augmented image is written back with the same shape as the input
  so that augmented and un-augmented samples look identical downstream.

  :params

    (dict) arrays : {'xs': {...}, 'ys': {...}} for a single example
    **kwargs      : unused

  :return

    (dict) arrays : the same dict, updated in place

  """
  # =================================================================
  # Dynamic Augmentation
  # =================================================================
  if augEnabled == 1:
    # Pick one of the 13 outcomes (0 leaves the example untouched)
    aug = random.randint(0, 12)

    if aug > 0:
      dat = arrays['xs']['dat']

      # Remove single-dimensional entries from the image array
      img = np.squeeze(dat).reshape(_AUG_Y_DIM, _AUG_X_DIM)

      if aug in _AUG_ZOOM_SCALES:
        # Zoom, then crop (zoom in) or zero-pad (zoom out) back to full size
        scale = _AUG_ZOOM_SCALES[aug]
        img = _fit_to_canvas(
            ndimage.zoom(img, zoom=scale, order=1), _AUG_Y_DIM, _AUG_X_DIM)

        # Top-left anchored zoom scales the label coordinates by the same factor
        for key in ('car', 'ett'):
          if arrays['ys'][key] is not None:
            arrays['ys'][key] = arrays['ys'][key] * scale
      else:
        # Translate the image and move the labels by the same offset
        dy, dx = _AUG_SHIFT_OFFSETS[aug]
        img = _shift_image(img, dy, dx)

        for key in ('car', 'ett'):
          if arrays['ys'][key] is not None:
            arrays['ys'][key] = _shift_points(
                arrays['ys'][key], dy, dx, _AUG_Y_DIM, _AUG_X_DIM)

      # Hand the image back in the shape it arrived in
      arrays['xs']['dat'] = img.reshape(np.shape(dat))

  # =================================================================
  # Masks
  # =================================================================
  for key in ['car', 'ett']:
    if arrays['ys'][key] is None:
      # Initialize missing data
      arrays['ys'][key] = np.zeros((1, 1, 2, 1))
      arrays['xs']['msk-' + key][:] = 0
    else:
      # Prepare pts (convert 3D to 2D)
      arrays['ys'][key] = arrays['ys'][key][:, :, 1:]

      # Prepare msk (ignore points beyond field of view)
      arrays['xs']['msk-' + key][0] = \
          np.all(arrays['ys'][key] > 0, axis=(0, 2, 3), keepdims=True) & \
          np.all(arrays['ys'][key] < 1, axis=(0, 2, 3), keepdims=True)

  return arrays

In [None]:
# Define configs
# Batch size of 12; fold -1 is the same value the full-training cell uses.
# NOTE(review): presumably -1 means "no held-out fold" — verify against the
# jarvis client documentation.
configs = {'batch': {'size': 12, 'fold': -1}}

# Create client
# The client is created after preprocess() was attached to Client above, and
# replaces the client returned by datasets.prepare().
os.environ['JARVIS_PROJECT_ID'] = 'xr/mimic-ett'
client = Client(pattern='client-ett-crp', configs=configs)

# Create generators
gen_train, gen_valid = client.create_generators()

print(client.batch)

In [None]:
# =================================================================
# Sanity check: pull one training batch to confirm augmentation runs
# =================================================================
#
# Draw a single batch from the training generator
xs, ys = next(gen_train)

# List the entries available in the inputs and in the targets
for label, batch in (('xs', xs), ('ys', ys)):
  print('{} keys: {}'.format(label, batch.keys()))

In [None]:
# Report the array shape of every input, then of every target
for name in ('dat', 'msk-car', 'msk-ett'):
  print('xs shape: {}'.format(xs[name].shape))

for name in ('car', 'ett'):
  print('ys shape: {}'.format(ys[name].shape))

In [None]:
from jarvis.utils.display import imshow

# --- Show the first example
# Indexing [0, 0] drops the first two axes of the batch array drawn above
imshow(xs['dat'][0, 0])

# --- Show "montage" of all images
# Passing the whole batch array displays every example at once
imshow(xs['dat'])

### Create model

In [None]:
def prepare_model(inputs):
    """
    Build and compile the two-headed carina / ETT regression network.

    The backbone is seven stages; each stage applies `conv1` a fixed number
    of times and finishes with one `conv2`, all at the same filter count.
    The final feature map is reshaped into a single vector and fed to two
    independent sigmoid heads that each emit two values.

    :params

      (dict) inputs : Keras inputs with keys 'dat', 'msk-car' and 'msk-ett'

    :return

      (Model) compiled model with outputs named 'car' and 'ett'

    """
    # --- Get blocks
    # (the block components are requested exactly as before; the result is unused)
    kwargs = models.create_block_components(names=('kwargs_z1',))[0]
    conv1, conv2 = models.create_blocks(('conv1', 'conv2'))

    def stage(x, filters, depth):
        # `depth` conv1 blocks followed by a single conv2 block
        for _ in range(depth):
            x = conv1(filters, x)
        return conv2(filters, x)

    # --- Define layers: (filters, number of conv1 blocks) per stage
    backbone = ((48, 5), (56, 4), (64, 4), (80, 3), (96, 3), (112, 2), (128, 2))

    features = inputs['dat']
    for filters, depth in backbone:
        features = stage(features, filters, depth)

    # --- Flatten
    # NOTE(review): the target size 2 * 1 * 128 assumes the backbone output
    # holds exactly 256 values per example — confirm if the input size changes
    flat = layers.Reshape((1, 1, 1, 2 * 1 * 128))(features)
    car_head = layers.Conv3D(filters=2, kernel_size=(1, 1, 1), activation='sigmoid')(flat)
    ett_head = layers.Conv3D(filters=2, kernel_size=(1, 1, 1), activation='sigmoid')(flat)

    # --- Create logits
    logits = {
        'car': layers.Reshape((-1, 1, 2, 1), name='car')(car_head),
        'ett': layers.Reshape((-1, 1, 2, 1), name='ett')(ett_head),
    }

    # --- Create model
    model = Model(inputs=inputs, outputs=logits)

    # --- Compile the model (each head uses an MSE loss built from its mask input)
    per_head_loss = {
        'car': custom.mse(inputs['msk-car']),
        'ett': custom.mse(inputs['msk-ett']),
    }
    model.compile(
        optimizer=optimizers.Adam(learning_rate=2e-4),
        loss=per_head_loss)

    return model

### Train and Test

In [None]:
import math
import datetime

def test_model(fold):
  """
  Evaluate the current model and log label-to-prediction distances.

  Reads two notebook-level globals: `client` (to create the test generators)
  and `model` (to predict). For every item of the validation generator the
  first carina and ETT point of the batch is converted to pixel coordinates
  on the 256 x 128 image and the Euclidean distance between label and
  prediction is computed. A landmark counts as present when both label
  coordinates are greater than zero.

  Rows are appended to three CSV files in the working directory:

    carina_distances.csv  : fold, distance
    ett_distances.csv     : fold, distance
    car_ett_distances.csv : fold, labels and predictions (only when both
                            landmarks are present)

  Averages are printed at the end; an average is reported as nan when no
  item contained the landmark.

  :params

    fold : identifier written as the first column of every CSV row

  """
  # Image size used to convert normalized coordinates to pixels
  yDim = 256
  xDim = 128

  # Factor used to convert a pixel distance to centimeters
  cmPerPixel = 0.08

  def toPixels(points):
    # First (y, x) pair of the array, scaled to pixel coordinates
    flat = points.flatten()
    return flat[0] * yDim, flat[1] * xDim

  def average(total, cnt):
    # Avoid ZeroDivisionError when a landmark never occurred
    return total / cnt if cnt else float('nan')

  # Set start time
  startTime = datetime.datetime.now().time()

  # Initialize counters and accumulated distances
  cntTotal = 0
  cntCar = 0
  cntEtt = 0
  distInPixelsCar = 0
  distInPixelsEtt = 0

  # --- Create test generators (only the validation one is used)
  _, test_valid = client.create_generators(test=True)

  # Open each result file once instead of once per item
  with open('carina_distances.csv', 'a') as carDistFile, \
       open('ett_distances.csv', 'a') as ettDistFile, \
       open('car_ett_distances.csv', 'a') as carEttDistFile:

    # Calculate the distance between labels and predictions
    for x, y in test_valid:
      logits = model.predict(x)

      # predict() returns a dict keyed by output name for dict-output models
      # in some Keras versions and a list in others; support both
      if isinstance(logits, dict):
        predCar, predEtt = logits['car'], logits['ett']
      else:
        predCar, predEtt = logits[0], logits[1]

      # Get carina and ETT labels
      labelCarY, labelCarX = toPixels(y['car'])
      labelEttY, labelEttX = toPixels(y['ett'])

      # Get carina and ETT predictions
      predCarY, predCarX = toPixels(predCar)
      predEttY, predEttX = toPixels(predEtt)

      # Calculate distance between label and prediction for carina
      carExists = labelCarY > 0 and labelCarX > 0
      if carExists:
        cntCar += 1
        distInPixels = math.sqrt((predCarY - labelCarY) ** 2 + (predCarX - labelCarX) ** 2)
        distInPixelsCar += distInPixels
        carDistFile.write(str(fold) + ',' + str(distInPixels) + '\n')

      # Calculate distance between label and prediction for ETT
      ettExists = labelEttY > 0 and labelEttX > 0
      if ettExists:
        cntEtt += 1
        distInPixels = math.sqrt((predEttY - labelEttY) ** 2 + (predEttX - labelEttX) ** 2)
        distInPixelsEtt += distInPixels
        ettDistFile.write(str(fold) + ',' + str(distInPixels) + '\n')

      # Log the raw coordinates when both landmarks are present
      if carExists and ettExists:
        row = [fold, labelCarY, labelCarX, labelEttY, labelEttX,
               predCarY, predCarX, predEttY, predEttX]
        carEttDistFile.write(','.join(str(value) for value in row) + '\n')

      cntTotal += 1

  # Set end time
  endTime = datetime.datetime.now().time()

  # Print start and end times
  print("Start Time: " + str(startTime))
  print("End Time: " + str(endTime))
  print("Item Count: " + str(cntTotal) + '\n')

  # Display test results
  print('\n\nCarina Count: ' + str(cntCar))
  print('Carina Avg Dist in Pixels: ' + str(average(distInPixelsCar, cntCar)))
  print('Carina Avg Dist in Cm: ' + str(average(distInPixelsCar * cmPerPixel, cntCar)))
  print('\nETT Count: ' + str(cntEtt))
  print('ETT Avg Dist in Pixels: ' + str(average(distInPixelsEtt, cntEtt)))
  print('ETT Avg Dist in Cm: ' + str(average(distInPixelsEtt * cmPerPixel, cntEtt)) + '\n')

In [None]:
#--------------------------------------------
# Train and test with 5-fold cross validation
#--------------------------------------------

# Start every cross-validation run with fresh result files. Write mode is
# used on purpose: with append mode, re-running this cell would leave rows
# from the previous run and a second header line in the middle of each CSV.
# test_model() appends its rows below these headers.
with open('carina_distances.csv', 'w') as carDistFile:
  carDistFile.write('fold,dist\n')
with open('ett_distances.csv', 'w') as ettDistFile:
  ettDistFile.write('fold,dist\n')
with open('car_ett_distances.csv', 'w') as carEttDistFile:
  carEttDistFile.write('fold,car_label_y,car_label_x,ett_label_y,ett_label_x,car_pred_y,car_pred_x,ett_pred_y,ett_pred_x\n')

# For each of the 5 folds: build a fresh client and model, train for 4 epochs
# (3 at the initial learning rate, 1 at a tenth of it), then evaluate with
# test_model(), which reads the globals `client` and `model` set here.
for fold in range(5):
  print('fold: ' + str(fold))

  # Set fold
  configs = {'batch': {'size': 12, 'fold': fold}}

  # --- Create client
  os.environ['JARVIS_PROJECT_ID'] = 'xr/mimic-ett'
  client = Client(pattern='client-ett-crp', configs=configs)

  # --- Create generators
  gen_train, gen_valid = client.create_generators()

  # Get inputs
  inputs = client.get_inputs(Input)

  # Prepare the model
  model = prepare_model(inputs)

  # Load data into memory for faster training
  client.load_data_in_memory()

  # Initialize learning rate and epoch
  lr = 0.0005
  epoch = 1

  # Enable the random augmentation performed in preprocess()
  augEnabled = 1

  # Outer loop: one pass per learning-rate stage
  for j in range(2):
    # Re-compile so the optimizer picks up the current learning rate
    model.compile(
        optimizer=optimizers.Adam(learning_rate=lr),
        loss={
            'car': custom.mse(inputs['msk-car']),
            'ett': custom.mse(inputs['msk-ett'])})

    # Inner loop: up to three single-epoch fits per stage
    for i in range(3):
      print('learning-rate: ' + str(lr))
      print('epoch: ' + str(epoch))

      # Train the model
      model.fit(
          x=gen_train,
          steps_per_epoch=6500, # (6000 training items / 12 batch size) x (1 + 12 augmentations)
          epochs=1,
          validation_data=gen_valid,
          validation_steps=6500,
          validation_freq=1)

      # Increment epoch
      epoch += 1

      # Stop after the 4th epoch has been trained
      if epoch == 5:
        break

    # Leave the stage loop as well once training is complete
    if epoch == 5:
      break

    # Divide learning rate by 10
    lr /= 10

  # Test the model
  test_model(fold)

  # Delete the model at the end of each fold and reset the Keras global
  # state; dropping the Python reference alone does not release the state
  # Keras accumulates when models are created in a loop
  del model
  tf.keras.backend.clear_session()

# Copy the three result CSV files to the mounted Google Drive folder
!cp ./carina_distances.csv '/content/drive/My Drive/Colab Notebooks/carina_distances.csv'
!cp ./ett_distances.csv '/content/drive/My Drive/Colab Notebooks/ett_distances.csv'
!cp ./car_ett_distances.csv '/content/drive/My Drive/Colab Notebooks/car_ett_distances.csv'

In [None]:
#--------------------------------------------------
# Generate a model file from the whole training set
#--------------------------------------------------

# Set fold to -1
# NOTE(review): presumably -1 means no fold is held out, matching the
# "whole training set" intent of this cell — verify against jarvis docs.
configs = {'batch': {'size': 12, 'fold': -1}}

# --- Create client
os.environ['JARVIS_PROJECT_ID'] = 'xr/mimic-ett'
client = Client(pattern='client-ett-crp', configs=configs)

# --- Create generators
gen_train, gen_valid = client.create_generators()

# Get inputs
inputs = client.get_inputs(Input)

# Prepare the model
# The resulting `model` is the one saved to disk by the next cell.
model = prepare_model(inputs)

# Load data into memory for faster training
client.load_data_in_memory()

# Initialize learning rate and epoch
lr = 0.0005
epoch = 1

# Schedule produced by the two loops below: epochs 1-3 run at the initial
# learning rate, the rate is then divided by 10 and epoch 4 runs at the lower
# rate; both loops exit as soon as `epoch` reaches 5.
for j in range(2):
  # Compile the model
  # Re-compiling creates a new Adam optimizer with the current learning rate
  model.compile(
      optimizer=optimizers.Adam(learning_rate=lr),
      loss={
          'car': custom.mse(inputs['msk-car']), 
          'ett': custom.mse(inputs['msk-ett'])})

  for i in range(3):
    print('learning-rate: ' + str(lr))
    print('epoch: ' + str(epoch))

    # Augmentation enabled
    # (global switch read by preprocess(); re-assigned on every iteration)
    augEnabled = 1 # change random start from 0 to 1 in preprocess

    # Train the model
    # NOTE(review): the step count below is derived with 12 augmentations,
    # while the cross-validation cell uses (1 + 12) — confirm which is intended.
    model.fit(
        x=gen_train,
        steps_per_epoch=7500, # (7500 training items / 12 batch size) x (12 augmentations)
        epochs=1,
        validation_data=gen_valid,
        validation_steps=7500,
        validation_freq=1)

    # Increment epoch
    epoch += 1

    # Check epoch to break the loop
    if epoch == 5:
      break

  # Check epoch to break the loop
  if epoch == 5:
    break

  # Divide learning rate by 10
  lr /= 10

In [None]:
# --- Save the model to a file
# `model` is the network trained on the full dataset in the previous cell
model.save('./cnn_3_carina_ett_regression.hdf5')

# Copy the model to google drive
!cp ./cnn_3_carina_ett_regression.hdf5 '/content/drive/My Drive/Colab Notebooks/cnn_3_carina_ett_regression.hdf5'

In [None]:
# Copy the model file from google drive
!cp '/content/drive/My Drive/Colab Notebooks/cnn_3_carina_ett_regression.hdf5' ./cnn_3_carina_ett_regression.hdf5

# Load the model from the file
# compile=False: the model is restored without its optimizer / loss setup and
# is only used for prediction in the next cell
from tensorflow.keras import models as tfModels
model = tfModels.load_model('./cnn_3_carina_ett_regression.hdf5', compile=False)

In [None]:
# Calculate distance from ETT to carina on all ETT positive images and compare with the values in ett.csv
import tensorflow as tf
import os
import glob
import numpy as np
import math

# Compare the predicted ETT-to-carina distance with the distance recorded in
# ett.csv for every frontal image that has an ETT entry and is not listed as
# a true negative. Per-image absolute errors are written to
# ett_rpt_distances.csv and a summary is printed at the end.

# Factor used to convert a pixel distance on the 256x128 crop to centimeters
cmPerPixel = 0.08

# Initialize global variables
count = 0
totalError = 0
ettBelowCarina = 0

# Initialize input data (the masks stay zero; only 'dat' is replaced per image)
xs = {}
xs['dat'] = np.zeros((1, 1, 256, 128, 1))
xs['msk-car'] = np.zeros((1, 1, 1, 2, 1))
xs['msk-ett'] = np.zeros((1, 1, 1, 2, 1))

# Read the contents of the cxr-record-list file
with open("./cxr-record-list.csv", 'r') as recFile:
  recContent = recFile.read()

# Read the contents of the true negative file
with open("./trueNeg.csv", 'r') as trueNegFile:
  trueNegContent = trueNegFile.read()

# Read the contents of the ett file
with open("./ett.csv", 'r') as ettFile:
  ettContent = ettFile.read()

# Read the contents of the frontal file
with open("./db-mimic-frontal-only.csv", 'r') as frontalFile:
  imgNames = frontalFile.readlines()

# Open the output once, in write mode, so that re-running this cell does not
# mix errors from an earlier (possibly partial) run into the median
with open('ett_rpt_distances.csv', 'w') as ettRptDistFile:
  # Process each frontal image
  for imgName in imgNames:
    # Strip whitespaces
    imgName = imgName.strip()

    # A blank line would otherwise match the first ".dcm" in the record list
    if not imgName:
      continue

    # Search for the image name in cxr-record-list
    imgStart = recContent.find(imgName + ".dcm")

    # Without a match the fixed offsets below would slice unrelated text
    if imgStart == -1:
      continue

    # The IDs are read at fixed offsets in front of the file name.
    # NOTE(review): this assumes 8-character patient / study IDs and a
    # 2-character folder in the record path — confirm against the CSV.
    patientID = recContent[imgStart-19:imgStart-11]
    studyID = recContent[imgStart-9:imgStart-1]
    folder = recContent[imgStart-23:imgStart-21]

    # Search for the study in ETT file
    patientStart = ettContent.find(patientID + "," + studyID)

    # Patient and study not found in ett
    if patientStart == -1:
      continue

    # Skip images listed in the true negative file
    if trueNegContent.find(imgName) != -1:
      continue

    # Search for new line at the end of distance; the last line may not end
    # with one, in which case the value runs to the end of the file
    newLinePos = ettContent.find("\n", patientStart)
    if newLinePos == -1:
      newLinePos = len(ettContent)

    # Get distance to carina from ett.
    # NOTE(review): the distance is assumed to start 52 characters after the
    # patient ID — confirm against the ett.csv layout.
    distToCarinaLabel = float(ettContent[patientStart+52:newLinePos])

    # Set path
    imgFile = "./crops/p" + folder + "/p" + patientID + "/s" + studyID + "/" + imgName + ".npy"

    # Load an image file
    imgS = np.load(imgFile)

    # Normalize the image
    imgN = (imgS - np.mean(imgS)) / np.std(imgS)

    # Predict coordinates of carina and ett on the 256x128 cropped image
    xs['dat'] = imgN.reshape(1, 1, 256, 128, 1)
    logits = model.predict(x=xs)

    # predict() returns a dict keyed by output name for dict-output models in
    # some Keras versions and a list in others; support both
    if isinstance(logits, dict):
      predCar, predEtt = logits['car'], logits['ett']
    else:
      predCar, predEtt = logits[0], logits[1]

    # Get carina and ETT predictions
    predCarY = predCar.flatten()[0]*256
    predCarX = predCar.flatten()[1]*128
    predEttY = predEtt.flatten()[0]*256
    predEttX = predEtt.flatten()[1]*128

    # Calculate distance from ett to carina
    distToCarinaPred = math.sqrt(((predCarY - predEttY) ** 2) + ((predCarX - predEttX) ** 2)) * cmPerPixel

    # Check whether ett is below carina
    if predEttY > predCarY:
      ettBelowCarina += 1

    # Add error
    absError = abs(distToCarinaLabel - distToCarinaPred)
    totalError += absError

    # Write distance to a file to calculate median
    ettRptDistFile.write(str(absError) + '\n')

    # Increment count
    count += 1

# Print calculations (the average is nan when no image was evaluated)
print("Total error: " + str(totalError))
print("Total count: " + str(count))
print("Average error: " + str(totalError / count if count else float('nan')))
print("ETT below carina: " + str(ettBelowCarina))