In [1]:
%config Completer.use_jedi = False

import echonet
from echonet.datasets import Echo

import torch.nn.functional as F
from torchvision.models.video import r2plus1d_18
from torch.utils.data import Dataset, DataLoader, Subset
from multiprocessing import cpu_count

from src.utils.torch_utils import TransformDataset, torch_collate
from src.utils.echo_utils import get2dPucks
from src.utils.camus_validate import cleanupSegmentation
from src.transform_utils import generate_2dmotion_field
from src.visualization_utils import categorical_dice
from src.loss_functions import huber_loss, convert_to_1hot, convert_to_1hot_tensor
from src.model.R2plus1D_18_MotionNet import R2plus1D_18_MotionNet
from src.echonet_dataset import EchoNetDynamicDataset, EDESpairs
# from src.visualization_utils import categorical_dice

import numpy as np
from scipy.signal import find_peaks
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

import random
import pickle
import time

# Simple wall-clock timer aliases: t0 = tic(); ...; elapsed = toc() - t0
tic = toc = time.time

In [2]:
batch_size = 4                           # samples per DataLoader batch
num_workers = max(cpu_count() // 2, 4)   # half of the CPUs, but never fewer than 4 workers


def worker_init_fn_valid(worker_id):
    """Re-seed NumPy's global RNG inside a DataLoader worker.

    The seed is the first word of the MT19937 key the worker inherited from the
    parent process plus ``worker_id``. Workers forked from the same parent state
    therefore get different seeds.

    Args:
        worker_id: index of the DataLoader worker (supplied by torch).
    """
    base_seed = int(np.random.get_state()[1][0])
    # np.random.seed only accepts [0, 2**32). The key word can be as large as
    # 2**32 - 1, so adding worker_id may overflow. Wrap explicitly instead of
    # raising or warning.
    np.random.seed((base_seed + worker_id) % 2 ** 32)
    

def worker_init_fn(worker_id):
    """Seed NumPy and Python's ``random`` in a DataLoader worker from torch's per-worker seed.

    See https://pytorch.org/docs/stable/notes/randomness.html#dataloader and
    https://github.com/pytorch/pytorch/issues/5059#issuecomment-817373837 for why
    this is needed (forked workers otherwise share identical NumPy RNG state).

    Args:
        worker_id: index of the DataLoader worker (unused; torch already folds it
            into ``torch.initial_seed()``).
    """
    # NumPy seeds must fit in 32 bits, and torch's seed is 64-bit.
    seed = torch.initial_seed() % (2 ** 32)
    for seeder in (np.random.seed, random.seed):
        seeder(seed)
    

def permuter(list1, list2):
    """Lazily yield every ``(a, b)`` pair with ``a`` from list1 and ``b`` from list2.

    Pairs come out list1-major (all partners of the first element of list1 first).
    ``list2`` is iterated afresh for every element of ``list1``.
    """
    yield from ((first, second) for first in list1 for second in list2)
            

# DataLoader keyword arguments for each split. The worker count reuses `num_workers`
# from above so there is a single source of truth.
param_trainLoader = {'collate_fn': torch_collate,
                     'batch_size': batch_size,
                     'num_workers': num_workers,
                     'worker_init_fn': worker_init_fn}
# NOTE(review): no 'shuffle' key here, so DataLoader's default (shuffle=False)
# applies to the training split — confirm this is intended.

param_testLoader = {'collate_fn': torch_collate,
                    'batch_size': batch_size,
                    'shuffle': False,
                    'num_workers': num_workers,
                    'worker_init_fn': worker_init_fn}

# 'valid' gets its own copy so that tweaking one evaluation split's settings
# (e.g. paramLoader['valid']['batch_size'] = 1) cannot silently change the other.
paramLoader = {'train': param_trainLoader,
               'valid': dict(param_testLoader),
               'test':  param_testLoader}

In [3]:
# Indices of the sampled validation subset, pickled by an earlier step.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
# The `with` block closes the file on exit, so no explicit close() is needed.
with open("fold_indexes/stanford_valid_sampled_indices", "rb") as infile:
    valid_mask = pickle.load(infile)

# `full_dataset` is the (subsampled) validation split; `test_dataset` is the test split.
# raise_for_es_ed=False on the test split — presumably lets videos without usable
# ED/ES annotations load instead of raising; TODO confirm in EchoNetDynamicDataset.
full_dataset = EchoNetDynamicDataset(split='val', clip_length="full", subset_indices=valid_mask, period=1)
test_dataset = EchoNetDynamicDataset(split='test', clip_length="full", raise_for_es_ed=False, period=1)

100%|██████████| 16/16 [00:01<00:00, 14.63it/s]
100%|██████████| 16/16 [00:01<00:00, 13.68it/s]


In [4]:
def divide_to_consecutive_clips(video, clip_length=32, interpolate_last=False):
    """Split a video into consecutive, non-overlapping clips of ``clip_length`` frames.

    Args:
        video: NumPy array laid out as (channels, frames, height, width); frames
            are along axis 1.
        clip_length: number of frames per clip.
        interpolate_last: if True and the frame count is not a multiple of
            ``clip_length``, the whole video is first resampled in time (trilinear)
            to ``round(frames / clip_length) * clip_length`` frames, with at least
            one clip. No clip is short and no frames are dropped. If False, trailing
            frames that do not fill a whole clip are discarded.

    Returns:
        float64 array of shape (n_clips, channels, clip_length, height, width).
        ``n_clips`` may be 0 when ``interpolate_last`` is False and the video is
        shorter than one clip.
    """
    source_video = np.asarray(video)
    video_length = source_video.shape[1]
    spatial_size = tuple(source_video.shape[2:])

    if video_length % clip_length != 0 and interpolate_last:
        # np.round rounds halves to even (80 / 32 = 2.5 -> 2, 48 / 32 = 1.5 -> 2).
        # This is kept so clip counts match earlier results. max(1, ...) avoids
        # resampling very short videos to zero frames.
        n_clips = max(1, int(np.round(video_length / clip_length)))
        resampled = torch.Tensor(source_video).unsqueeze(0)  # add batch dim: (1, C, T, H, W)
        resampled = F.interpolate(resampled, size=(n_clips * clip_length, *spatial_size),
                                  mode="trilinear", align_corners=False)
        # Remove only the batch dim; keep the channel axis even when C == 1.
        source_video = resampled.squeeze(0).numpy()
    else:
        # Whole clips only. Rounding up here would create a short last clip
        # that cannot be stacked with the others.
        n_clips = video_length // clip_length

    if n_clips == 0:
        return np.empty((0, source_video.shape[0], clip_length, *spatial_size))

    clips = [source_video[:, start:start + clip_length]
             for start in range(0, n_clips * clip_length, clip_length)]
    # Earlier versions accumulated into a float64 buffer; keep that dtype for callers.
    return np.stack(clips).astype(np.float64)


def get_all_possible_start_points(ed_index, es_index, video_length, clip_length):
    """Return every clip start frame whose clip contains both the ED and the ES frame.

    A start ``s`` is valid when ``s <= ed_index``, ``s + clip_length - 1 >= es_index``,
    ``s >= 0`` and ``s + clip_length <= video_length``.

    If the ED->ES span does not fit in one clip (``es_index - ed_index >= clip_length``),
    the fallback ``[ed_index]`` is returned: a clip starting at ED.

    Args:
        ed_index: frame index of end-diastole.
        es_index: frame index of end-systole; must be greater than ``ed_index``.
        video_length: number of frames in the video.
        clip_length: number of frames per clip.

    Returns:
        1-D integer array of start indices, in increasing order. It is empty when no
        in-bounds start exists (e.g. the video is shorter than one clip).

    Raises:
        AssertionError: if ``es_index <= ed_index``.
    """
    assert es_index - ed_index > 0, "not a ED to ES clip pair"
    # Slack frames: how far the clip may slide left of ED while still reaching ES.
    possible_shift = clip_length - (es_index - ed_index)
    if possible_shift <= 0:
        # A clip starting at ED ends at ed_index + clip_length - 1 < es_index,
        # so no clip can hold both frames.
        return np.array([ed_index])
    lowest = max(0, ed_index - possible_shift + 1)           # clip must still reach ES; not before frame 0
    highest = min(ed_index, video_length - clip_length)      # clip must start at/before ED and stay in bounds
    return np.arange(lowest, highest + 1)

In [6]:
# Load the trained R(2+1)D MotionNet checkpoint for inference.
model_save_path = "save_models/R2plus1DMotionSegNet_model_tmp.pth"

# Wrapped in DataParallel, presumably because the checkpoint was saved from a
# DataParallel model (state-dict keys with the "module." prefix) — TODO confirm.
model = torch.nn.DataParallel(R2plus1D_18_MotionNet())
model.to("cuda")
torch.cuda.empty_cache()
# map_location remaps tensors saved on any CUDA device (e.g. "cuda:1") onto the
# current default GPU. The checkpoint then also loads on machines with fewer GPUs.
model.load_state_dict(torch.load(model_save_path, map_location="cuda")["model"])
print(f'R2+1D MotionNet has {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters.')

model.eval();

R2+1D MotionNet has 31575731 parameters.


### Compute EF for all test patients

In [8]:
# Predict EF for every test patient:
#   1. segment every frame with the network, clip by clip;
#   2. find ED frames (area peaks) and ES frames (area troughs) in the per-frame
#      label-1 area curve;
#   3. compute a Simpson's-monoplane EF for each ED->ES pair from EDESpairs and
#      average the pairs.
CLIP_LENGTH = 32  # frames per network input clip (also divide_to_consecutive_clips' default)

EF_list = []       # per patient: list of EFs, one per ED->ES pair
true_EF_list = []  # per patient: reported EF from the dataset
mean_EF_list = []  # per patient: NaN-ignoring mean of that patient's EFs

for test_pat_index in range(len(test_dataset)):
    try:
        video, (filename, EF, es_clip_index, ed_clip_index, es_index, ed_index, es_frame, ed_frame, es_label, ed_label) = test_dataset[test_pat_index]
    except Exception:  # skip unreadable videos, but let KeyboardInterrupt through
        print("Get exception when trying to read the video from patient:{:04d}".format(test_pat_index))
        continue
    interpolate_last = True
    n_frames = video.shape[1]

    consecutive_clips = divide_to_consecutive_clips(video, clip_length=CLIP_LENGTH, interpolate_last=interpolate_last)
    if len(consecutive_clips) == 0:
        print("No complete clip for patient:{:04d}".format(test_pat_index))
        continue

    # Inference only. no_grad stops autograd from keeping every clip's graph alive.
    # Only the segmentation head is used below; the motion output is discarded.
    seg_prob_chunks = []
    with torch.no_grad():
        for clip in consecutive_clips:
            segmentation_output, motion_output = model(torch.Tensor(np.expand_dims(clip, 0)))
            seg_prob_chunks.append(F.softmax(segmentation_output, 1).cpu().numpy())

    # (n_clips, classes, CLIP_LENGTH, H, W) -> (classes, n_clips * CLIP_LENGTH, H, W)
    segmentation_outputs = np.concatenate(seg_prob_chunks).transpose([1, 0, 2, 3, 4])
    segmentation_outputs = segmentation_outputs.reshape(segmentation_outputs.shape[0], -1,
                                                        *segmentation_outputs.shape[-2:])

    if interpolate_last and n_frames % CLIP_LENGTH != 0:
        # The video was resampled in time to a multiple of CLIP_LENGTH frames.
        # Resample the class probabilities back to the original frame count
        # before taking the argmax.
        resampled = torch.Tensor(segmentation_outputs).unsqueeze(0)
        resampled = F.interpolate(resampled, size=(n_frames, *segmentation_outputs.shape[-2:]),
                                  mode="trilinear", align_corners=False)
        segmentation_outputs = resampled.squeeze(0).numpy()
    segmentations = np.argmax(segmentation_outputs, axis=0)  # per-frame label map

    # Per-frame count of label-1 pixels; label 1 is what get2dPucks measures below
    # (presumably the LV).
    size = np.sum(segmentations, axis=(1, 2)).ravel()
    _05cut, _85cut, _95cut = np.percentile(size, [5, 85, 95])

    # Peaks must rise at least half the 5th-95th percentile area range above their
    # surroundings and be at least 20 frames apart.
    trim_range = _95cut - _05cut
    systole = find_peaks(-size, distance=20, prominence=(0.50 * trim_range))[0]
    diastole = find_peaks(size, distance=20, prominence=(0.50 * trim_range))[0]

    # keep only real diastoles (large-area peaks)..
    diastole = [x for x in diastole if size[x] >= _85cut]
    # find_peaks never reports the first sample, so add frame 0 when the video starts large.
    if np.mean(size[:3]) >= _85cut:
        diastole = [0] + diastole
    diastole = np.array(diastole)

    clip_pairs = EDESpairs(diastole, systole)

    predicted_efs = []
    for pair in clip_pairs:
        output_ED = segmentations[pair[0]]
        output_ES = segmentations[pair[1]]

        length_ed, radius_ed = get2dPucks((output_ED == 1).astype('int'), (1.0, 1.0))
        length_es, radius_es = get2dPucks((output_ES == 1).astype('int'), (1.0, 1.0))

        # Simpson's monoplane: sum over disks of pi * r^2 * (length / n_disks).
        edv = np.sum(((np.pi * radius_ed * radius_ed) * length_ed / len(radius_ed)))
        esv = np.sum(((np.pi * radius_es * radius_es) * length_es / len(radius_es)))

        predicted_efs.append((edv - esv) / edv * 100)

    # Check for "no pairs" first; np.mean([]) would emit a "Mean of empty slice" warning.
    if len(predicted_efs) == 0:
        print("Cannot identify clips at patient:{:04d}".format(test_pat_index))
        continue
    if np.isnan(np.mean(predicted_efs)):
        print("Nan EF at patient:{:04d}".format(test_pat_index))

    EF_list.append(predicted_efs)
    true_EF_list.append(EF)
    mean_EF_list.append(np.nanmean(predicted_efs))

  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


Cannot identify clips at patient:0095
Cannot identify clips at patient:0156
Cannot identify clips at patient:0294
Cannot identify clips at patient:0390
Cannot identify clips at patient:0573
Cannot identify clips at patient:0612
Cannot identify clips at patient:0761
Cannot identify clips at patient:0832
Cannot identify clips at patient:1061


### Mean Absolute Error

In [11]:
# Error between the reported EF and each patient's mean predicted EF.
# Patients whose mean prediction is NaN are excluded, the same way the
# correlation cell drops them.
# NOTE: the ground-truth-label cell further down rebinds true_EF_list. Re-run the
# prediction cell before re-running this one, or the lists will not line up.
reported_ef = np.asarray(true_EF_list)
predicted_ef = np.asarray(mean_EF_list, dtype=float)
finite_pred = ~np.isnan(predicted_ef)
if not finite_pred.all():
    print("Excluding {} patient(s) with NaN predicted EF".format(int((~finite_pred).sum())))

errors = reported_ef[finite_pred] - predicted_ef[finite_pred]
abs_errors = np.abs(errors)

print("Mean absolute error (standard deviation):  {:.4f} ({:.4f}) %".format(np.mean(abs_errors), np.std(abs_errors)))
print("Median absolute error:  {:.4f} %".format(np.median(abs_errors)))
print("Bias +- 1.96 x std:  {:.4f} +- {:.4f}".format(np.mean(errors), 1.96 * np.std(errors)))
print("Percentile of mae 50%: {:6.4f}  75%: {:6.4f}  95%: {:6.4f}".format(np.percentile(abs_errors, 50), np.percentile(abs_errors, 75),
                                                                    np.percentile(abs_errors, 95)))

Mean absolute error (standard deviation):  5.3737 (4.9693) %
Median absolute error:  4.2037 %
Bias +- 1.96 x std:  -2.2112 +- 13.6753
Percentile of mae 50%: 4.2037  75%: 7.4805  95%: 14.1865


### Correlation (Pearson's r) with True EF

In [15]:
# Pearson correlation between reported EF and mean predicted EF,
# leaving out patients whose prediction is NaN.
reported_ef = np.array(true_EF_list)
predicted_ef = np.array(mean_EF_list)
drop_ind_arr = np.logical_not(np.isnan(predicted_ef))
corrcoef = np.corrcoef(reported_ef[drop_ind_arr], predicted_ef[drop_ind_arr])
r_value = corrcoef[0, 1]
print("Cross correlation with True EF: {:.3f}".format(r_value))

Cross correlation with True EF: 0.833


### Check with Ground-Truth Labels
Estimate the ejection fraction from the provided manual $LV_{endo}$ annotations of the ED and ES frames. Volumes are approximated with the same Simpson's monoplane method (method of disks) used above.

In [16]:
# Sanity check of the EF formula: apply the same Simpson's-monoplane computation
# to the manual ED/ES labels and compare the result with the reported EF.
# NOTE: this rebinds true_EF_list (the prediction cell above builds its own
# version, skipping different patients). Re-run the prediction cell before
# re-using its MAE/correlation cells.
derived_reported_EF_list = []
true_EF_list = []

for test_pat_index in range(len(test_dataset)):
    try:
        video, (filename, EF, es_clip_index, ed_clip_index, es_index, ed_index, es_frame, ed_frame, es_label, ed_label) = test_dataset[test_pat_index]
    except Exception:  # skip unreadable videos, but let KeyboardInterrupt through
        print("Get exception when trying to read the video from patient:{:04d}".format(test_pat_index))
        continue

    output_ED = ed_label
    output_ES = es_label

    length_ed, radius_ed = get2dPucks((output_ED == 1).astype('int'), (1.0, 1.0))
    length_es, radius_es = get2dPucks((output_ES == 1).astype('int'), (1.0, 1.0))

    # Simpson's monoplane: sum over disks of pi * r^2 * (length / n_disks).
    edv = np.sum(((np.pi * radius_ed * radius_ed) * length_ed / len(radius_ed)))
    esv = np.sum(((np.pi * radius_es * radius_es) * length_es / len(radius_es)))

    ef_predicted = (edv - esv) / edv * 100

    if np.isnan(ef_predicted):
        print("Nan EF at patient:{:04d}".format(test_pat_index))

    derived_reported_EF_list.append(ef_predicted)
    true_EF_list.append(EF)

### Mean Absolute Error between the Reported EF and the EF Computed from the Reported ED/ES Labels

In [19]:
# Error between the reported EF and the EF derived from the manual ED/ES labels.
# NaN EFs (reported as "Nan EF" by the previous cell) are excluded, not propagated.
reported_ef = np.asarray(true_EF_list)
derived_ef = np.asarray(derived_reported_EF_list, dtype=float)
finite_derived = ~np.isnan(derived_ef)
if not finite_derived.all():
    print("Excluding {} patient(s) with NaN derived EF".format(int((~finite_derived).sum())))

errors = reported_ef[finite_derived] - derived_ef[finite_derived]
abs_errors = np.abs(errors)

print("Mean absolute error (standard deviation):  {:.4f} ({:.4f}) %".format(np.mean(abs_errors), np.std(abs_errors)))
print("Median absolute error:  {:.4f} %".format(np.median(abs_errors)))
print("Bias +- 1.96 x std:  {:.4f} +- {:.4f}".format(np.mean(errors), 1.96 * np.std(errors)))
print("Percentile of mae 50%: {:6.4f}  75%: {:6.4f}  95%: {:6.4f}".format(np.percentile(abs_errors, 50), np.percentile(abs_errors, 75),
                                                                    np.percentile(abs_errors, 95)))

Mean absolute error (standard deviation):  1.5450 (2.1868) %
Median absolute error:  1.0394 %
Bias +- 1.96 x std:  -0.5295 +- 5.1443
Percentile of mae 50%: 1.0394  75%: 2.0723  95%: 3.8775


### Correlation (Pearson's r) between the Reported EF and the EF Computed from the Reported ED/ES Labels

In [23]:
# Pearson correlation between the reported EF and the label-derived EF.
# NaN-derived EFs are dropped, matching the predicted-EF correlation cell.
derived_ef = np.asarray(derived_reported_EF_list, dtype=float)
finite_derived = ~np.isnan(derived_ef)
corrcoef = np.corrcoef(np.asarray(true_EF_list)[finite_derived], derived_ef[finite_derived])
print("Cross correlation with True EF: {:.3f}".format(corrcoef[0, 1]))

Cross correlation with True EF: 0.978
