In [1]:
import numpy as np
import os
import sys

In [2]:
from easydict import EasyDict as edict
from collections import Counter
from sklearn.metrics import confusion_matrix as sklearn_cm


In [3]:
sys.path.insert(0, '/cluster/tufts/hugheslab/zhuang12/general_utilities')
from shared_utilities import load_pickle, load_txt, save_pickle

In [4]:
# Required input files:

# '{}_patient_order_list.pkl'.format(split)
# '{}_patient_level_count_dicts.pkl'.format(split)

# Validation-set ensemble predictions pickle, for each method and each fold
# Test-set ensemble predictions pickle, for each method and each fold


In [5]:
#other utilities
def calculate_balanced_accuracy(true_labels, predictions, return_type = 'balanced_accuracy'):
    '''
    Balanced accuracy (mean of the per-class recalls) for this 3-class
    classification task, expressed in percent.

    Args:
        true_labels: sequence of ground-truth class ids; expected to be in {0, 1, 2}.
        predictions: sequence of predicted class ids, aligned with true_labels.
        return_type: 'balanced_accuracy' (default) returns only the balanced
            accuracy; 'all' returns the 4-tuple
            (balanced_accuracy, class0_recall, class1_recall, class2_recall).

    Returns:
        The value(s) selected by return_type, each multiplied by 100.

    Raises:
        NameError: if return_type is neither 'all' nor 'balanced_accuracy'.

    Note:
        A class that has no true samples has an undefined recall (0/0 -> nan),
        and that nan propagates into the balanced accuracy.
    '''

    # Pin the label set so the matrix is always 3x3 and row k is class k.
    # Without it sklearn sizes the matrix from the labels it actually sees, so a
    # class missing from both inputs would either raise IndexError below or
    # silently shift another class into the wrong row.
    confusion_matrix = sklearn_cm(true_labels, predictions, labels=[0, 1, 2])

    # recall_k = correctly predicted samples of class k / all true samples of class k
    class0_recall = confusion_matrix[0,0]/np.sum(confusion_matrix[0])
    class1_recall = confusion_matrix[1,1]/np.sum(confusion_matrix[1])
    class2_recall = confusion_matrix[2,2]/np.sum(confusion_matrix[2])

    balanced_accuracy = (1/3)*class0_recall + (1/3)*class1_recall + (1/3)*class2_recall

    if return_type == 'all':
        return balanced_accuracy * 100, class0_recall * 100, class1_recall * 100, class2_recall * 100
    elif return_type == 'balanced_accuracy':
        return balanced_accuracy * 100
    else:
        raise NameError('Unsupported return_type in hz_utils calculate_balanced_accuracy fn')


def find_most_frequent(input_array):
    '''
    Return the value that occurs most often in input_array.

    When several values share the highest count, the one encountered first
    wins. An empty input raises IndexError.
    '''
    ranked_values = Counter(input_array).most_common(1)

    # ranked_values is a list of (value, count) pairs; keep only the value.
    return ranked_values[0][0]


def generate_PatientOrder_DataIndicesRange_ImageCount(patients_split_stats_dir, split='test'):
    '''
    Load a split's patient bookkeeping and derive each patient's image index range.

    Reads '{split}_patient_order_list.pkl' and
    '{split}_patient_level_count_dicts.pkl' from patients_split_stats_dir.
    Both are pre-generated from Echo_multitask/generate_data/ by:
      Processing_Echo_method1MaintainingAspectRatio_ResizingThenPad_ExcludeDoppler_grayscale_diagnosis.py
      Processing_Echo_method1MaintainingAspectRatio_ResizingThenPad_ExcludeDoppler_grayscale_UsedForFullUnlabeledOnly_diagnosis.py

    Returns:
        (patients_order_list, patients_DataIndicesRange_list, patients_ImageCount_list).
        Entry i of the last two lists belongs to patient i of the order list:
        the image count is the sum of that patient's 'view_labels_count'
        values, and the range is the (start, end) pair of consecutive
        cumulative image counts (the caller in this file slices with it as [start:end]).
    '''
    order_list = load_pickle(patients_split_stats_dir, '{}_patient_order_list.pkl'.format(split))
    count_dicts = load_pickle(patients_split_stats_dir, '{}_patient_level_count_dicts.pkl'.format(split))

    # sanity check for the count dicts (currently disabled):
    #     assert patientlevel_count_dicts_sanity_check(count_dicts)

    # Total number of images per patient, summed over all of the patient's views.
    image_counts = [
        sum(count_dicts[patient_id]['view_labels_count'].values())
        for patient_id in order_list
    ]

    # Cumulative counts with a leading 0: endpoints[i] is where patient i starts
    # and endpoints[i+1] is where that patient ends.
    endpoints = np.insert(np.cumsum(image_counts), 0, 0)
    index_ranges = list(zip(endpoints[:-1], endpoints[1:]))

    return order_list, index_ranges, image_counts
    
    
    

In [6]:
#This script compares only the approaches selected for the MLHC paper: soft majority vote and soft probabilistic majority vote (plus the relevance-filtered hard majority vote ablation suggested by a reviewer)
def perform_PatientLevel_integration(args_dict):
    '''
    Aggregate image-level predictions into one diagnosis per patient and score
    three aggregation strategies with balanced accuracy.

    Strategies:
        RelevanceFiltered_HardMajorityVote (ablation suggested by reviewer):
            keep only the images whose predicted view is not class 2 and whose
            summed view scores for classes 0 and 1 exceed relevance_threshold,
            then take the most frequent per-image diagnosis argmax. When no
            image survives the filter, fall back to the argmax of this
            patient's averaged diagnosis predictions.
        SoftMajorityVote:
            argmax of the mean diagnosis prediction over all of the patient's images.
        NotFiltered_ConfidenceBased_SoftMajorityVote_PrioritizedView:
            like SoftMajorityVote, but each image's diagnosis prediction is
            first multiplied by its view relevance (sum of view scores for
            classes 0 and 1); no image is filtered out.

    Args:
        args_dict: attribute-style config providing fold_idx, split,
            relevance_threshold, View_predictions_path,
            Diagnosis_predictions_path, Diagnosis_true_labels_path and
            fold_multitask_patients_split_stats_dir.

    Returns:
        EasyDict with
            true_diagnosis_labels: list with one true label per patient.
            RelevanceFiltered_HardMajorityVote: dict with 'balanced_accuracy'
                and 'predicted_labels'.
            SoftMajorityVote / NotFiltered_ConfidenceBased_SoftMajorityVote_PrioritizedView:
                dicts with 'balanced_accuracy', 'predicted_labels' and 'predictions'.

    Raises:
        AssertionError: if the number of images does not match the expected
            size for a known fold, or if a patient's images carry more than
            one distinct diagnosis label.
    '''

    fold_idx = args_dict.fold_idx
    split = args_dict.split
    relevance_threshold = args_dict.relevance_threshold

    # NOTE(review): load_pickle is a project helper; presumably it unpickles the
    # file, so only point these paths at trusted data.
    View_predictions = load_pickle(args_dict.View_predictions_path, 'view_predictions.pkl')
    Diagnosis_predictions = load_pickle(args_dict.Diagnosis_predictions_path, 'diagnosis_predictions.pkl')
    Diagnosis_true_labels = load_pickle(args_dict.Diagnosis_true_labels_path, '{}_diagnosis_labels.pkl'.format(split))

    # Sanity check on the number of images for the known folds. Folds without a
    # recorded size skip the check instead of crashing on an undefined variable.
    # NOTE(review): these sizes are applied regardless of `split` -- confirm
    # they are valid for every split this function is used with.
    expected_set_sizes = {0: 5690, 1: 5855, 2: 5377, 3: 5535}
    total_images = Diagnosis_predictions.shape[0]
    if fold_idx in expected_set_sizes:
        assert total_images == expected_set_sizes[fold_idx], \
            'fold {} expects {} images, got {}'.format(fold_idx, expected_set_sizes[fold_idx], total_images)

    # Patient order and, for each patient, the (start, end) range of its images
    # inside the image-level arrays.
    fold_multitask_patients_order_list, fold_multitask_patients_DataIndicesRange_list, _ = \
        generate_PatientOrder_DataIndicesRange_ImageCount(args_dict.fold_multitask_patients_split_stats_dir, split)

    #compared approaches
    #the ablation suggested by reviewer:
    RelevanceFiltered_HardMajorityVote_predicted_labels = []

    #Take all available images of a patient, average the diagnosis predictions
    SoftMajorityVote_predicted_labels = []
    SoftMajorityVote_predictions = []

    #confidence based priority vote, not filtered
    NotFiltered_ConfidencedBased_SoftMajorityVote_PrioritizedView_predicted_labels = []
    NotFiltered_ConfidencedBased_SoftMajorityVote_PrioritizedView_predictions = []

    patient_true_diagnosis_labels = []

    #loop through each patient
    for idx, patient_id in enumerate(fold_multitask_patients_order_list):
        print('Currently aggregating predictions for {}'.format(patient_id).center(100, '-'))

        # Slice this patient's contiguous block of images directly instead of
        # materialising list(range(total_images)) for every patient.
        range_start, range_end = fold_multitask_patients_DataIndicesRange_list[idx]

        this_patient_diagnosis_true_labels = Diagnosis_true_labels[range_start:range_end]

        assert len(set(this_patient_diagnosis_true_labels)) == 1, '1 patient can only have 1 diagnosis label'
        this_patient_diagnosis_single_label = this_patient_diagnosis_true_labels[0]

        #record this patient's true diagnosis label
        patient_true_diagnosis_labels.append(this_patient_diagnosis_single_label)

        this_patient_diagnosis_predictions = Diagnosis_predictions[range_start:range_end]
        this_patient_view_predictions = View_predictions[range_start:range_end]

        # View relevance of every image: summed scores of view classes 0 and 1.
        this_patient_ViewRelevance = np.sum(this_patient_view_predictions[:,:2], axis=1)

        #Suggested Ablation:
        this_patient_TargetView_mask = this_patient_view_predictions.argmax(1) != 2
        this_patient_ConfidenceThreshold_mask = this_patient_ViewRelevance > relevance_threshold
        this_patient_TargetViewRelevanceMask = np.logical_and(this_patient_TargetView_mask, this_patient_ConfidenceThreshold_mask)
        this_patient_TargetView_DiagnosisPredictions = this_patient_diagnosis_predictions[this_patient_TargetViewRelevanceMask]

        if len(this_patient_TargetView_DiagnosisPredictions) > 0:
            this_patient_RelevanceFiltered_HardMajorityVote_predicted_label = find_most_frequent(this_patient_TargetView_DiagnosisPredictions.argmax(1))
        else: #fall back to majority vote
            print('{} at relevance threshold {} fall back to majority vote'.format(patient_id, relevance_threshold))
            # Use THIS patient's averaged prediction. The previous code read the
            # soft-majority-vote variable that is only assigned further down,
            # i.e. undefined for the first patient and stale (previous
            # patient's value) for every later one.
            this_patient_RelevanceFiltered_HardMajorityVote_prediction = np.mean(this_patient_diagnosis_predictions, axis=0)
            this_patient_RelevanceFiltered_HardMajorityVote_predicted_label = np.argmax(this_patient_RelevanceFiltered_HardMajorityVote_prediction)

        RelevanceFiltered_HardMajorityVote_predicted_labels.append(this_patient_RelevanceFiltered_HardMajorityVote_predicted_label)
        print('true_diagnosis: {}, RelevanceFiltered_HardMajorityVote predicted_diagnosis: {}\n'.format(this_patient_diagnosis_single_label, this_patient_RelevanceFiltered_HardMajorityVote_predicted_label))

        #SoftMajorityVote:
        print('SoftMajorityVote:')
        this_patient_SoftMajorityVote_prediction = np.mean(this_patient_diagnosis_predictions, axis = 0) # one score per diagnosis class
        this_patient_SoftMajorityVote_predicted_label = np.argmax(this_patient_SoftMajorityVote_prediction)

        #record this patient's SoftMajorityVote predictions and predicted labels
        SoftMajorityVote_predictions.append(this_patient_SoftMajorityVote_prediction)
        SoftMajorityVote_predicted_labels.append(this_patient_SoftMajorityVote_predicted_label)

        print('true_diagnosis:{}, SoftMajorityVote predicted_diagnosis:{}\n'.format(this_patient_diagnosis_single_label, this_patient_SoftMajorityVote_predicted_label))

        #NotFiltered_ConfidenceBased_SoftMajorityVote_PrioritizedView
        # Weight every image's diagnosis prediction by its view relevance before averaging.
        this_patient_DiagnosisPrediction_with_ViewRelevance = np.mean(this_patient_diagnosis_predictions * this_patient_ViewRelevance[:, np.newaxis], axis=0)
        this_patient_DiagnosisPredictedLabel_with_ViewRelevance = np.argmax(this_patient_DiagnosisPrediction_with_ViewRelevance)

        #record this patient's ConfidenceBased_SoftMajorityVote_PrioritizedView
        NotFiltered_ConfidencedBased_SoftMajorityVote_PrioritizedView_predictions.append(this_patient_DiagnosisPrediction_with_ViewRelevance)
        NotFiltered_ConfidencedBased_SoftMajorityVote_PrioritizedView_predicted_labels.append(this_patient_DiagnosisPredictedLabel_with_ViewRelevance)

        print('true_diagnosis:{}, NotFiltered_ConfidenceBased_SoftMajorityVote_PrioritizedView predicted_diagnosis:{}\n'.format(this_patient_diagnosis_single_label, this_patient_DiagnosisPredictedLabel_with_ViewRelevance))

        print('\n')

    RelevanceFiltered_HardMajorityVote_balanced_accuracy = calculate_balanced_accuracy(patient_true_diagnosis_labels, RelevanceFiltered_HardMajorityVote_predicted_labels)
    SoftMajorityVote_balanced_accuracy = calculate_balanced_accuracy(patient_true_diagnosis_labels, SoftMajorityVote_predicted_labels)
    NotFiltered_ConfidenceBased_SoftMajorityVote_PrioritizedView_balanced_accuracy = calculate_balanced_accuracy(patient_true_diagnosis_labels, NotFiltered_ConfidencedBased_SoftMajorityVote_PrioritizedView_predicted_labels)

    returned_dict = edict()
    returned_dict.true_diagnosis_labels = patient_true_diagnosis_labels
    returned_dict.RelevanceFiltered_HardMajorityVote = {'balanced_accuracy':RelevanceFiltered_HardMajorityVote_balanced_accuracy, 'predicted_labels':np.array(RelevanceFiltered_HardMajorityVote_predicted_labels)}
    returned_dict.SoftMajorityVote = {'balanced_accuracy': SoftMajorityVote_balanced_accuracy, 'predicted_labels': np.array(SoftMajorityVote_predicted_labels), 'predictions':np.array(SoftMajorityVote_predictions)}
    returned_dict.NotFiltered_ConfidenceBased_SoftMajorityVote_PrioritizedView = {'balanced_accuracy':NotFiltered_ConfidenceBased_SoftMajorityVote_PrioritizedView_balanced_accuracy, 'predicted_labels':np.array(NotFiltered_ConfidencedBased_SoftMajorityVote_PrioritizedView_predicted_labels), 'predictions':np.array(NotFiltered_ConfidencedBased_SoftMajorityVote_PrioritizedView_predictions)}

    return returned_dict


## FS

In [7]:
#each fold's specification

In [8]:
# Fold 0, test split, FS method: where to find the image-level view/diagnosis
# predictions, the true diagnosis labels and the patient split statistics.
fold0_args_dict = edict({
    'fold_idx': 0,
    'split': 'test',
    'relevance_threshold': 0.69,
    'View_predictions_path': '/cluster/tufts/hugheslab/zhuang12/MLHCCode_Release/predictions/ImageLevel_predictions/fold0/test/FS',
    'Diagnosis_predictions_path': '/cluster/tufts/hugheslab/zhuang12/MLHCCode_Release/predictions/ImageLevel_predictions/fold0/test/FS',
    'Diagnosis_true_labels_path': '/cluster/tufts/hugheslab/zhuang12/MLHCCode_Release/split_info/E4VD-156-52/fold0/test',
    'fold_multitask_patients_split_stats_dir': '/cluster/tufts/hugheslab/zhuang12/MLHCCode_Release/split_info/E4VD-156-52/fold0/test',
})



### fold0

In [9]:
# Run the patient-level aggregation for the FS fold-0 test configuration defined
# above; the call prints one log entry per patient and returns the per-strategy results.
returned_dict = perform_PatientLevel_integration(fold0_args_dict) 

---------------------------Currently aggregating predictions for 1241777----------------------------
true_diagnosis: 1, RelevanceFiltered_HardMajorityVote predicted_diagnosis: 1

SoftMajorityVote:
true_diagnosis:1, SoftMajorityVote predicted_diagnosis:1

true_diagnosis:1, NotFiltered_ConfidenceBased_SoftMajorityVote_PrioritizedView predicted_diagnosis:1



---------------------------Currently aggregating predictions for 2929640----------------------------
true_diagnosis: 1, RelevanceFiltered_HardMajorityVote predicted_diagnosis: 1

SoftMajorityVote:
true_diagnosis:1, SoftMajorityVote predicted_diagnosis:1

true_diagnosis:1, NotFiltered_ConfidenceBased_SoftMajorityVote_PrioritizedView predicted_diagnosis:1



---------------------------Currently aggregating predictions for 2930384----------------------------
true_diagnosis: 1, RelevanceFiltered_HardMajorityVote predicted_diagnosis: 1

SoftMajorityVote:
true_diagnosis:1, SoftMajorityVote predicted_diagnosis:0

true_diagnosis:1, NotFilter

In [10]:
# Report the balanced accuracy (in percent) of every aggregation strategy for
# the run stored in returned_dict.
for report_template, strategy_name in (
    ('SoftMajorityVote_balanced_accuracy: {}', 'SoftMajorityVote'),
    ('NotFiltered_ConfidenceBased_SoftMajorityVote_PrioritizedView_balanced_accuracy: {}', 'NotFiltered_ConfidenceBased_SoftMajorityVote_PrioritizedView'),
    ('SuggestedAblation: {}', 'RelevanceFiltered_HardMajorityVote'),
):
    print(report_template.format(getattr(returned_dict, strategy_name)['balanced_accuracy']))


SoftMajorityVote_balanced_accuracy: 87.91666666666667
NotFiltered_ConfidenceBased_SoftMajorityVote_PrioritizedView_balanced_accuracy: 90.83333333333333
SuggestedAblation: 85.41666666666666


## MixMatch

In [7]:
#each fold's specification

In [11]:
# Fold 0, test split, MixMatch method: where to find the image-level
# view/diagnosis predictions, the true diagnosis labels and the patient split statistics.
fold0_args_dict = edict({
    'fold_idx': 0,
    'split': 'test',
    'relevance_threshold': 0.63,
    'View_predictions_path': '/cluster/tufts/hugheslab/zhuang12/MLHCCode_Release/predictions/ImageLevel_predictions/fold0/test/MixMatch',
    'Diagnosis_predictions_path': '/cluster/tufts/hugheslab/zhuang12/MLHCCode_Release/predictions/ImageLevel_predictions/fold0/test/MixMatch',
    'Diagnosis_true_labels_path': '/cluster/tufts/hugheslab/zhuang12/MLHCCode_Release/split_info/E4VD-156-52/fold0/test',
    'fold_multitask_patients_split_stats_dir': '/cluster/tufts/hugheslab/zhuang12/MLHCCode_Release/split_info/E4VD-156-52/fold0/test',
})


### fold0

In [12]:
# Run the patient-level aggregation for the MixMatch fold-0 test configuration
# defined above; the call prints one log entry per patient and returns the per-strategy results.
returned_dict = perform_PatientLevel_integration(fold0_args_dict) 

---------------------------Currently aggregating predictions for 1241777----------------------------
true_diagnosis: 1, RelevanceFiltered_HardMajorityVote predicted_diagnosis: 1

SoftMajorityVote:
true_diagnosis:1, SoftMajorityVote predicted_diagnosis:1

true_diagnosis:1, NotFiltered_ConfidenceBased_SoftMajorityVote_PrioritizedView predicted_diagnosis:1



---------------------------Currently aggregating predictions for 2929640----------------------------
true_diagnosis: 1, RelevanceFiltered_HardMajorityVote predicted_diagnosis: 1

SoftMajorityVote:
true_diagnosis:1, SoftMajorityVote predicted_diagnosis:1

true_diagnosis:1, NotFiltered_ConfidenceBased_SoftMajorityVote_PrioritizedView predicted_diagnosis:1



---------------------------Currently aggregating predictions for 2930384----------------------------
true_diagnosis: 1, RelevanceFiltered_HardMajorityVote predicted_diagnosis: 1

SoftMajorityVote:
true_diagnosis:1, SoftMajorityVote predicted_diagnosis:1

true_diagnosis:1, NotFilter

In [13]:
# Report the balanced accuracy (in percent) of every aggregation strategy for
# the run stored in returned_dict.
for report_template, strategy_name in (
    ('SoftMajorityVote_balanced_accuracy: {}', 'SoftMajorityVote'),
    ('NotFiltered_ConfidenceBased_SoftMajorityVote_PrioritizedView_balanced_accuracy: {}', 'NotFiltered_ConfidenceBased_SoftMajorityVote_PrioritizedView'),
    ('SuggestedAblation: {}', 'RelevanceFiltered_HardMajorityVote'),
):
    print(report_template.format(getattr(returned_dict, strategy_name)['balanced_accuracy']))


SoftMajorityVote_balanced_accuracy: 89.99999999999999
NotFiltered_ConfidenceBased_SoftMajorityVote_PrioritizedView_balanced_accuracy: 88.75
SuggestedAblation: 83.33333333333333


## Pretrained MixMatch

In [7]:
#each fold's specification

In [14]:
# Fold 0, test split, PretrainedMixMatch method: where to find the image-level
# view/diagnosis predictions, the true diagnosis labels and the patient split statistics.
fold0_args_dict = edict({
    'fold_idx': 0,
    'split': 'test',
    'relevance_threshold': 0.75,
    'View_predictions_path': '/cluster/tufts/hugheslab/zhuang12/MLHCCode_Release/predictions/ImageLevel_predictions/fold0/test/PretrainedMixMatch',
    'Diagnosis_predictions_path': '/cluster/tufts/hugheslab/zhuang12/MLHCCode_Release/predictions/ImageLevel_predictions/fold0/test/PretrainedMixMatch',
    'Diagnosis_true_labels_path': '/cluster/tufts/hugheslab/zhuang12/MLHCCode_Release/split_info/E4VD-156-52/fold0/test',
    'fold_multitask_patients_split_stats_dir': '/cluster/tufts/hugheslab/zhuang12/MLHCCode_Release/split_info/E4VD-156-52/fold0/test',
})


### fold0

In [15]:
# Run the patient-level aggregation for the PretrainedMixMatch fold-0 test
# configuration defined above; the call prints one log entry per patient and returns the per-strategy results.
returned_dict = perform_PatientLevel_integration(fold0_args_dict) 

---------------------------Currently aggregating predictions for 1241777----------------------------
true_diagnosis: 1, RelevanceFiltered_HardMajorityVote predicted_diagnosis: 1

SoftMajorityVote:
true_diagnosis:1, SoftMajorityVote predicted_diagnosis:1

true_diagnosis:1, NotFiltered_ConfidenceBased_SoftMajorityVote_PrioritizedView predicted_diagnosis:1



---------------------------Currently aggregating predictions for 2929640----------------------------
true_diagnosis: 1, RelevanceFiltered_HardMajorityVote predicted_diagnosis: 1

SoftMajorityVote:
true_diagnosis:1, SoftMajorityVote predicted_diagnosis:1

true_diagnosis:1, NotFiltered_ConfidenceBased_SoftMajorityVote_PrioritizedView predicted_diagnosis:1



---------------------------Currently aggregating predictions for 2930384----------------------------
true_diagnosis: 1, RelevanceFiltered_HardMajorityVote predicted_diagnosis: 1

SoftMajorityVote:
true_diagnosis:1, SoftMajorityVote predicted_diagnosis:0

true_diagnosis:1, NotFilter

In [16]:
# Report the balanced accuracy (in percent) of every aggregation strategy for
# the run stored in returned_dict.
for report_template, strategy_name in (
    ('SoftMajorityVote_balanced_accuracy: {}', 'SoftMajorityVote'),
    ('NotFiltered_ConfidenceBased_SoftMajorityVote_PrioritizedView_balanced_accuracy: {}', 'NotFiltered_ConfidenceBased_SoftMajorityVote_PrioritizedView'),
    ('SuggestedAblation: {}', 'RelevanceFiltered_HardMajorityVote'),
):
    print(report_template.format(getattr(returned_dict, strategy_name)['balanced_accuracy']))


SoftMajorityVote_balanced_accuracy: 87.5
NotFiltered_ConfidenceBased_SoftMajorityVote_PrioritizedView_balanced_accuracy: 93.75
SuggestedAblation: 86.66666666666666
