In [None]:
# model loader - BraTS

import os
import sys
import torch
sys.path.append(".")
sys.path.append("./brats-mri")
import monai
from PIL import Image
from tqdm import tqdm
from monai.utils import first
from generative.inferers import LatentDiffusionInferer
from generative.networks.schedulers import DDIMScheduler

from torch.utils.data import DataLoader

from pretrained import load_autoencoder, load_unet
import utils

BUNDLE = './brats-mri/brats_mri_class_cond/'
sys.path.append(BUNDLE)
from scripts.inferer import LatentDiffusionInfererWithClassConditioning

def get_monai_autoencoder(bundle_target, training_args, weights_override_path):
    """Build the bundle's autoencoder, optionally overriding its config and weights.

    Args:
        bundle_target: path to the MONAI bundle directory.
        training_args: object whose ``config`` attribute is forwarded as the
            model-config override (presumably the saved training arguments).
        weights_override_path: checkpoint path forwarded as the weights override.

    Returns:
        Whatever ``pretrained.load_autoencoder`` returns.
    """
    return load_autoencoder(
        bundle_target,
        override_model_cfg_json=training_args.config,
        override_weights_load_path=weights_override_path,
    )

def get_monai_unet(bundle_target, training_args, weights_override_path):
    """Build the bundle's diffusion UNet with conditioning enabled.

    Args:
        bundle_target: path to the MONAI bundle directory.
        training_args: object exposing ``conditioning`` and ``config``; context
            conditioning is requested iff ``conditioning == 'context'``.
        weights_override_path: checkpoint path forwarded as the weights override.

    Returns:
        Whatever ``pretrained.load_unet`` returns.
    """
    uses_context = training_args.conditioning == 'context'
    return load_unet(
        bundle_target,
        context_conditioning=uses_context,
        override_model_cfg_json=training_args.config,
        override_weights_load_path=weights_override_path,
        use_conditioning=True,
    )
    
def get_monai_model_dict(bundle_target, training_args, autoencoder_weights_path, unet_weights_path,
                         scale_factor=1.0):
    """Assemble the autoencoder, UNet, noise scheduler and inferer for one checkpoint pair.

    Args:
        bundle_target: path to the MONAI bundle directory.
        training_args: saved training arguments; must expose ``config`` (model
            config override) and ``conditioning`` (selects the inferer type).
        autoencoder_weights_path: checkpoint file for the autoencoder.
        unet_weights_path: checkpoint file for the diffusion UNet.
        scale_factor: latent scale factor handed to the inferer. Defaults to 1.0,
            the ``LatentDiffusionInferer`` default. NOTE(review): confirm this
            matches the scale factor used at training time.

    Returns:
        dict with keys ``'autoencoder'``, ``'unet'``, ``'scheduler'`` and ``'inferer'``.
    """
    monai_dict = {}
    # Use the caller's training_args as-is; do not reload them from notebook globals.
    monai_dict['autoencoder'] = get_monai_autoencoder(bundle_target, training_args, autoencoder_weights_path)
    monai_dict['unet'] = get_monai_unet(bundle_target, training_args, unet_weights_path)

    # Noise scheduler is taken from the (possibly overridden) bundle config.
    config = utils.model_config(bundle_target, training_args.config)
    scheduler = config.get_parsed_content('noise_scheduler')
    monai_dict['scheduler'] = scheduler

    # 'context' / 'none' conditioning use the stock inferer; any other mode uses the
    # class-conditioning inferer from the bundle scripts.
    if training_args.conditioning in ['context', 'none']:
        inferer_cls = LatentDiffusionInferer
    else:
        inferer_cls = LatentDiffusionInfererWithClassConditioning
    monai_dict['inferer'] = inferer_cls(scheduler=scheduler, scale_factor=scale_factor)
    return monai_dict

In [None]:
# confusion matrix plotter

from sklearn.metrics import confusion_matrix
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

def labels_to_human_labels(labels, human_labels_list):
    """Map integer class labels to their human-readable names.

    Args:
        labels: iterable of class indices; each element may be a scalar (or
            single-element) torch tensor on any device, a numpy integer, or a
            plain Python int.
        human_labels_list: sequence of names indexed by class id.

    Returns:
        list of names, one per element of ``labels``, in the same order.
    """
    # int() accepts single-element tensors (CPU or CUDA, with or without grad) as
    # well as numpy and Python ints, so no detach/cpu/numpy round-trip is needed.
    return [human_labels_list[int(label)] for label in labels]

def plot_confusion_matrix(true_labels, labels_pred, human_labels_list):
    """Show a confusion matrix of predicted vs. ground-truth labels.

    Both label collections are converted to names with ``labels_to_human_labels``;
    rows/columns of the matrix follow the order of ``human_labels_list``. The plot
    is drawn on a 12x12-inch figure and shown immediately; nothing is returned.
    """
    names_true = labels_to_human_labels(true_labels, human_labels_list)
    names_pred = labels_to_human_labels(labels_pred, human_labels_list)
    matrix = confusion_matrix(names_true, names_pred, labels=human_labels_list)

    display_obj = ConfusionMatrixDisplay(confusion_matrix=matrix, display_labels=human_labels_list)
    _fig, axes = plt.subplots(figsize=(12, 12))
    display_obj.plot(ax=axes)
    plt.show()

In [None]:
import os
from enum import Enum
import pickle

from ldm_classifier_brats import MonaiLdmClassifier

def strip_epoch_num_from_ckpt(ckpt_full_path):
    """Extract the epoch number encoded at the end of a checkpoint file name.

    The file name (directory removed) is cut at its first ``.`` and the token
    after the last ``_`` of that stem is parsed as an int, e.g.
    ``'out/run/autoencoder_12.pt' -> 12``.

    Raises:
        ValueError: if that trailing token is not an integer.
    """
    ckpt_name = os.path.basename(ckpt_full_path)
    epoch_token = ckpt_name.split('.')[0].split('_')[-1]
    try:
        return int(epoch_token)
    except ValueError as err:
        # Re-raise with the file name so a stray non-checkpoint file is easy to spot.
        raise ValueError(
            f"cannot parse an epoch number from checkpoint name {ckpt_name!r}"
        ) from err

def get_monai_training_ckpt_files(output_dir, training_name):
    """List the autoencoder and diffusion-UNet checkpoint paths of a training run.

    Entries of ``<output_dir>/<training_name>`` whose name contains
    ``'autoencoder'`` (resp. ``'diffusion'``) are returned as full paths. Both
    lists are sorted by file name so the result is deterministic.

    Returns:
        tuple ``(autoencoder_ckpt_files, unet_ckpt_files)``.
    """
    ckpt_dir = os.path.join(output_dir, training_name)
    # os.listdir() order is arbitrary; list once and sort so both lists are stable.
    entries = sorted(os.listdir(ckpt_dir))
    autoencoder_ckpt_files = [os.path.join(ckpt_dir, name) for name in entries if 'autoencoder' in name]
    unet_ckpt_files = [os.path.join(ckpt_dir, name) for name in entries if 'diffusion' in name]
    return autoencoder_ckpt_files, unet_ckpt_files

def _pair_ckpts_by_epoch(autoencoder_ckpt_files, unet_ckpt_files):
    """Yield ``(epoch, autoencoder_ckpt, unet_ckpt)`` for epochs present in both lists, ascending."""
    autoencoder_by_epoch = {strip_epoch_num_from_ckpt(p): p for p in autoencoder_ckpt_files}
    unet_by_epoch = {strip_epoch_num_from_ckpt(p): p for p in unet_ckpt_files}
    for epoch in sorted(autoencoder_by_epoch.keys() & unet_by_epoch.keys()):
        yield epoch, autoencoder_by_epoch[epoch], unet_by_epoch[epoch]


def evaluate_accuracy_over_epochs(output_dir,
                                  training_name,
                                  dataset,
                                  framework_type=None,
                                  t_sampling_stride=50,
                                  n_trials=1
                                  ):
    """Classify ``dataset`` with the LDM classifier for every saved checkpoint epoch.

    For each epoch that has BOTH an autoencoder and a diffusion-UNet checkpoint in
    ``<output_dir>/<training_name>`` (epochs with only one of the two are skipped),
    the models are loaded, wrapped in a ``MonaiLdmClassifier`` and used to classify
    ``dataset``. The cumulative results are re-pickled after every epoch to
    ``<output_dir>/<training_name>/classification/predictions_<k>``, where ``k`` is
    the number of entries already in that directory when the call starts.

    Args:
        output_dir: root directory holding training runs.
        training_name: run sub-directory; it must also contain the ``training_args``
            file written at training time (read with ``torch.load``).
        dataset: dataset passed to ``MonaiLdmClassifier.classify_dataset``.
        framework_type: currently unused; kept for interface compatibility.
            NOTE(review): was annotated with an undefined ``FrameworkType`` —
            confirm whether it is still needed.
        t_sampling_stride: forwarded to ``classify_dataset``.
        n_trials: forwarded to ``classify_dataset``.

    Returns:
        dict with key ``'dataset'`` plus one entry per evaluated epoch, mapping the
        epoch number to ``{'true_labels', 'l1_pred_labels', 'l2_pred_labels'}``.
    """
    run_dir = os.path.join(output_dir, training_name)
    clf_dir = os.path.join(run_dir, 'classification')
    os.makedirs(clf_dir, exist_ok=True)
    # Number the output file after the existing entries so earlier results are kept.
    predictions_path = os.path.join(clf_dir, f'predictions_{len(os.listdir(clf_dir))}')

    # Training arguments (config override, conditioning type) are shared by all epochs.
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which may
    # reject a pickled argparse.Namespace; pass weights_only=False if this file is trusted.
    training_args = torch.load(os.path.join(run_dir, 'training_args'))

    clf_res_per_epoch = {'dataset': dataset}
    autoencoder_ckpt_files, unet_ckpt_files = get_monai_training_ckpt_files(output_dir, training_name)
    # Pair checkpoints by their epoch number rather than by list position.
    for epoch_num, autoenc_ckpt, unet_ckpt in _pair_ckpts_by_epoch(autoencoder_ckpt_files, unet_ckpt_files):
        model_dict = get_monai_model_dict(BUNDLE, training_args, autoenc_ckpt, unet_ckpt)
        ldm_clf = MonaiLdmClassifier(**model_dict)
        l2_labels_pred, l1_labels_pred, true_labels = ldm_clf.classify_dataset(
            dataset=dataset,
            batch_size=1,
            n_trials=n_trials,
            t_sampling_stride=t_sampling_stride,
        )
        clf_res_per_epoch[epoch_num] = {
            'true_labels': true_labels,
            'l1_pred_labels': l1_labels_pred,
            'l2_pred_labels': l2_labels_pred,
        }
        # Rewrite the cumulative results after every epoch so progress survives a crash.
        with open(predictions_path, 'wb') as f:
            pickle.dump(clf_res_per_epoch, f)

        # Drop this epoch's models before the next checkpoint is loaded.
        del model_dict, ldm_clf

    return clf_res_per_epoch

In [None]:
import sys
# TODO - define the correct path to data
sys.path.append('./brats-mri/brats_mri_class_cond/scripts')
from ct_rsna import CTSubset
from torchvision import transforms
import torch
from matplotlib import pyplot as plt
import numpy as np

In [None]:
# prepare data
# NOTE(review): this loads the CT-RSNA validation set while the models come from a
# BraTS run (see the TODO on the data path above) — confirm the intended dataset.
train_dir = './data/ct-rsna/train'
val_dir = './data/ct-rsna/validation'

subset_len = 1
ds = CTSubset(data_dir=val_dir, labels_file='validation_set_dropped_nans.csv', size=256, flip_prob=0., subset_len=subset_len)

# LDM classifier params
t_sampling_stride = 50
n_trials = 1

# training dir
output_dir = './data/outputs'
training_name = 'brats001'

# evaluate — sampler settings are passed by keyword: positionally they would land in
# the `framework_type` / `t_sampling_stride` slots of the signature.
clf_res_per_epoch = evaluate_accuracy_over_epochs(output_dir,
                                                  training_name,
                                                  ds,
                                                  framework_type=None,
                                                  t_sampling_stride=t_sampling_stride,
                                                  n_trials=n_trials)