In [1]:
import csv
import re
import tqdm
import scipy
import os, os.path
import pydicom as dicom
import numpy as np
from sklearn.metrics import roc_curve, auc
import time
import sys
sys.path.append("..")
import torch
import torchvision
from utils import save_checkpoint, load_checkpoint, save_metrics, load_metrics
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import TensorDataset, DataLoader, Dataset
import torchvideo.transforms as VT
from cv2 import resize
import cv2

from sklearn.metrics import roc_auc_score
from sklearn import metrics
from data_loader_hcm import load_video_data

In [3]:
# --- Test-set configuration and data loading --------------------------------
# Normalisation statistics forwarded to load_video_data as `mean` / `std`
# (three values each; presumably one per colour channel -- TODO confirm).
mean = np.array([15.890775, 48.323906, 48.834034])
std = np.array([33.840668, 62.168327, 62.729694])

# Fall back to CPU when no GPU is visible.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
period = 1
length = 16                 # passed as length / min_len / max_len below
spatial_dim = (112, 112)
#target_video_shape = (228, 288)
frame_perm = False
jitter_ratio = 0.5          # not used in this cell
train_transform = transform = transforms.Compose([
    VT.PILVideoToTensor(ordering='TCHW'),
    ])


train_val_test_flag = 'test'
# The original cell stored the dataset under `valid_dataset` but then read
# `test_dataset`, which raised NameError on a fresh kernel. The dataset is now
# bound to `test_dataset`; `valid_dataset` is kept as an alias for any cell
# that still refers to the old name.
test_dataset = load_video_data(
    train_val_test_flag = train_val_test_flag,
    mean = mean, std = std,
    noise = None,
    length = length,
    min_len = length,
    max_len = length,
    period = period,
    select_channel = None,
    frame_perm = frame_perm,
    spatial_dim = spatial_dim,
    give_me_all_clips = True,
    device = device)
valid_dataset = test_dataset

print("Test num {}".format(len(test_dataset)))

# dataloader
# NOTE(review): the evaluation loop converts `frame_nums` to a single int, so
# it appears to require exactly one video per batch -- keep this at 1.
test_batch_size = 1

test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size = test_batch_size, shuffle = False)

In [3]:
# Load Model
# Build an R(2+1)D-18 video network, replace its classification head with a
# 2-way linear layer, then restore weights via load_checkpoint.
# The device is chosen once, with a CPU fallback; the original hard-coded
# 'cuda' twice, which breaks on hosts without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): pretrained=True fetches the torchvision weights first; they are
# presumably overwritten by the checkpoint below -- confirm against
# utils.load_checkpoint before switching this off.
model = torchvision.models.video.r2plus1d_18(pretrained=True)
# Wrapped before loading so that parameters are addressed through `model.module`.
# NOTE(review): the checkpoint was presumably saved from a DataParallel-wrapped
# model (keys prefixed with `module.`) -- confirm.
model = torch.nn.DataParallel(model)

model.module.fc = nn.Linear(in_features=model.module.fc.in_features, out_features=2, bias=True)

model_dir = '/model_dir/'
load_checkpoint(os.path.join(model_dir,'best_epoch.pt'), model)

model = model.to(device)

In [1]:
# --- Video-level evaluation on the test set ---------------------------------
# Each video is split into consecutive, non-overlapping clips of `clip_len`
# frames. Every clip is cropped, resized and scored by the model, and the
# per-clip results are averaged to obtain a single score per video. The
# video-level ROC AUC is printed at the end.
model.eval()

# Crop window applied to every frame: rows top_cut:bottom_cut and columns
# right_cut:left_cut are kept, then resized to `target_size`.
top_cut = 20
bottom_cut = 90
right_cut = 30
left_cut = 90
target_size = (112,112)
# Frames per clip. Tied to the `length` used to build the dataset instead of a
# bare literal 16 so the two cannot drift apart.
clip_len = length

with torch.no_grad():

    # validation loop view level
    print("Begin Test")
    valid_sample_count = 0.0
    valid_running_loss = 0.0
    valid_running_acc = 0.0
    valid_labels = []
    valid_outputs = []
    valid_logits = []
    all_logits = []
    # video_path_list = []
    ground_truth_list = []
    for videos, frame_nums, labels in tqdm.tqdm(test_dataloader):
        videos = videos.type(torch.FloatTensor).to(device)
        frame_nums = frame_nums.type(torch.LongTensor).to(device)
        labels = labels.type(torch.LongTensor).to(device)

        # The clip count and the averaging below treat the whole batch as one
        # video, so anything other than batch size 1 would mix videos.
        if videos.shape[0] != 1:
            raise ValueError(
                "Evaluation expects one video per batch, got {}".format(videos.shape[0]))

        # Number of complete clips; trailing frames that do not fill a clip
        # are ignored (same as floor(frame_nums / clip_len)).
        loop_num = int(frame_nums.item()) // clip_len
        if loop_num == 0:
            # Skipping silently would shift every later prediction relative to
            # any externally kept list of videos, so fail loudly instead.
            raise ValueError(
                "Video has {} frames, fewer than one clip of {} frames".format(
                    int(frame_nums.item()), clip_len))

        sub_output = []
        sub_logits = []
        for clip_idx in range(loop_num):
            # Slicing dim 2 as time: assumes videos is (batch, channel, time, H, W)
            # -- TODO confirm against load_video_data.
            video_sub = videos[:, :, clip_len * clip_idx:clip_len * (clip_idx + 1), :, :]
            cropped_img = []
            for sample_idx in range(video_sub.shape[0]):
                new_data = video_sub[sample_idx, :, :, top_cut:bottom_cut, right_cut:left_cut]
                # Bicubic F.interpolate takes 4-D input, so swap the first two
                # axes to make the frame axis act as the batch axis, resize,
                # then swap back.
                new_data = new_data.permute(1, 0, 2, 3)
                resized_swapped_tensor = F.interpolate(new_data, size=target_size, mode='bicubic', align_corners=False)
                resized_tensor = resized_swapped_tensor.permute(1, 0, 2, 3)
                resized_tensor = torch.unsqueeze(resized_tensor, 0)
                # Append inside the per-sample loop; the original appended
                # after the loop and therefore kept only the last sample.
                cropped_img.append(resized_tensor)
            cropped_img = torch.cat(cropped_img, dim=0)

            output = model(cropped_img)
            # Despite the name, `logits` holds softmax probabilities.
            logits = F.softmax(output, dim=1)
            sub_logits.append(logits)
            sub_output.append(output)

        labels = labels.reshape(-1,1)

        # Despite the `max_` prefix these are means over the clips.
        max_out = torch.mean(torch.cat(sub_output, dim=0), dim=0)
        max_logits = torch.mean(torch.cat(sub_logits, dim=0), dim=0)
        # Mean probability of class 1 -- the score used for ROC/AUC.
        weighted_out = max_logits[1]

        valid_labels.append(labels)
        valid_outputs.append(weighted_out)
        # Alternative score: softmax of the averaged raw outputs. `max_out` is
        # 1-D, so dim=0 is stated explicitly (the implicit form is deprecated).
        valid_logits.append(F.softmax(max_out, dim=0)[1])

        pred_labels = weighted_out >= 0.5
        pred_labels = pred_labels.int()

        ground_truth_list.append(labels.cpu())
        all_logits.append(weighted_out.detach().cpu().numpy())

        valid_sample_count += labels.shape[0]
        valid_running_acc += (pred_labels == labels).sum().item()


valid_labels = torch.cat(valid_labels, dim=0)

valid_outputs = torch.stack(valid_outputs, dim=0)
valid_logits = torch.stack(valid_logits, dim=0)

valid_outputs_np = valid_outputs.cpu().data.numpy()
valid_labels_np = valid_labels.cpu().data.numpy()


label = np.concatenate(ground_truth_list).ravel()
fpr, tpr, thresh = metrics.roc_curve(valid_labels_np, valid_outputs_np)
test_auc = metrics.auc(fpr, tpr)

print(test_auc)
# Accuracy at the 0.5 threshold was accumulated above but never reported.
print("Test accuracy @0.5: {:.4f}".format(valid_running_acc / valid_sample_count))

In [13]:
# Persist the video-level scores and labels so the study-level analysis can be
# re-run without repeating inference.
model_test_results_dir = '/model_test_results_dir/'
# np.save does not create missing directories; make sure the target exists.
os.makedirs(model_test_results_dir, exist_ok=True)

# hcm_output_test: class-1 score from the mean of per-clip softmax outputs.
np.save(os.path.join(model_test_results_dir, 'hcm_output_test.npy'), valid_outputs.cpu().data.numpy())
# hcm_logits_test: class-1 score from the softmax of the mean raw model output.
np.save(os.path.join(model_test_results_dir, 'hcm_logits_test.npy'), valid_logits.cpu().data.numpy())
np.save(os.path.join(model_test_results_dir, 'hcm_labels_test.npy'), valid_labels.cpu().data.numpy())

In [15]:
# Reload the saved video-level results and rebuild the list of test videos.
hcm_outputs = np.load(os.path.join(model_test_results_dir, 'hcm_output_test.npy'))
hcm_logits = np.load(os.path.join(model_test_results_dir, 'hcm_logits_test.npy'))
hcm_labels = np.load(os.path.join(model_test_results_dir, 'hcm_labels_test.npy'))

valid_outputs_np = hcm_outputs
valid_labels_np = hcm_labels


def _read_first_column(csv_path):
    """Return the first field of every non-empty row of the CSV at `csv_path`.

    Blank rows are skipped; the original loop raised IndexError on them.
    The file is opened with newline='' as the csv module recommends.
    """
    with open(csv_path, 'r', newline='') as csvfile:
        return [row[0] for row in csv.reader(csvfile) if row]


# NOTE(review): entries are later paired with predictions purely by position,
# so this order (cases first, then controls) presumably has to match the order
# in which load_video_data yields the test videos -- confirm.
cases_path = '/data/cases_test_list.csv'
controls_path = '/data/controls_test_list.csv'

video_data_list = []
video_data_list.extend(_read_first_column(cases_path))
video_data_list.extend(_read_first_column(controls_path))

In [2]:
# --- Study-level aggregation -------------------------------------------------
# Group the video-level scores by study, average them within each study, and
# report the study-level ROC AUC.
valid_labels_np = hcm_labels
valid_outputs_np = hcm_outputs

# Videos and predictions are paired by index, so the three sequences must have
# the same, non-zero length; otherwise the pairing fails or silently misaligns.
if not (len(video_data_list) == len(valid_labels_np) == len(valid_outputs_np)):
    raise ValueError(
        "Length mismatch: {} videos, {} labels, {} outputs".format(
            len(video_data_list), len(valid_labels_np), len(valid_outputs_np)))
if len(video_data_list) == 0:
    raise ValueError("No test videos listed; cannot compute a study-level AUC")

# study key -> (video scores, unused placeholder list, study label)
study_level_prediction = {}
# Unused in this cell; kept in case another cell reads them.
num_views = 0
case_count = 0
cases_list = []
for idx, e_data in enumerate(video_data_list): # validation list
    # Study identifier = file-name part before the first underscore.
    study_name = e_data.split('/')[-1].split('_')[0]
    if valid_labels_np[idx] == 0:
        study_name = "/data/hcm/" + study_name
    else:
        study_name = "/data/hcm/cases/" + study_name
        # One entry per case video, so a study can appear more than once.
        cases_list.append(study_name)

    study_entry = study_level_prediction.setdefault(
        study_name, ([], [], valid_labels_np[idx]))
    assert study_entry[2] == valid_labels_np[idx], \
        "Conflicting labels for study {}".format(study_name)
    study_entry[0].append(valid_outputs_np[idx])

study_level_prediction_list = []
study_level_prediction_label = []

for e_study, (preds, _, study_label) in study_level_prediction.items():
    # Mean score over the study's videos, accumulated in float64 and kept as a
    # length-1 array so the stacked result keeps its (n_studies, 1) shape.
    study_pred = np.reshape(np.mean(np.asarray(preds, dtype=np.float64), axis=0), (1,))
    study_level_prediction_list.append(study_pred)
    study_level_prediction_label.append(study_label)

study_level_prediction_list = np.vstack(study_level_prediction_list)
study_level_prediction_label = np.array(study_level_prediction_label)

fpr, tpr, thresh = metrics.roc_curve(study_level_prediction_label,study_level_prediction_list)
test_study_auc = metrics.auc(fpr,tpr)
print(test_study_auc)