# How to perform inference

In [None]:
import sys
sys.path.insert(1, '../')
import matplotlib.pyplot as plt
import numpy as np 
import torch
from validate.metric_utils import make_prediction_on_volume
from validate.metric_utils import add_gaussian_noise
import torch
import h5py
import matplotlib.pyplot as plt
import numpy as np
import os
from os import walk
import json
import argparse

# define the models here
from models.phirec.src import phirec_skip as phirec_skip
from models.phirec.src.phirec_skip import PHISeg as PHiRec

from models.punet.src import probabilistic_unet as probabilistic_unet
from models.punet.src.probabilistic_unet import ProbabilisticUnet

from models.unet.src import unet as unet
from models.unet.src.unet import UNet as UNet

from models.unet_dropout.src import unet as unet_dropout
from models.unet_dropout.src.unet import UNet as UNet_dropout

from models.unet_het.src import unet as unet_het
from models.unet_het.src.unet import UNet as UNet_het

from models.unet_het_dropout.src import unet as unet_het_dr
from models.unet_het_dropout.src.unet import UNet as UNet_het_dr


In [None]:
def _find_checkpoint(model_dir):
    """Return the path of a checkpoint file stored directly in ``model_dir``.

    Only the top level of ``model_dir`` is inspected. File names are sorted so the
    choice is deterministic if more than one file is present (``os.walk`` yields
    names in arbitrary directory order).

    Raises:
        FileNotFoundError: if ``model_dir`` does not exist or holds no files.
    """
    for _, _, file_names in os.walk(model_dir):
        if file_names:
            return os.path.join(model_dir, sorted(file_names)[0])
        break  # only the top level is of interest
    raise FileNotFoundError(f'no checkpoint file found in {model_dir}')


def _build_reconstruction_model(model_name, input_channels, num_classes, num_filters):
    """Instantiate the architecture for ``model_name``.

    Returns:
        tuple: ``(model, n_samples)`` where ``n_samples`` is the number of samples
        passed to ``make_prediction_on_volume`` for this architecture.

    Raises:
        ValueError: if ``model_name`` is not one of the known architectures.
    """
    if model_name == 'dropout':
        return UNet_dropout(input_channels, num_classes), 20
    if model_name == 'het':
        return UNet_het(input_channels, num_classes), 20
    if model_name == 'het_dr':
        return UNet_het_dr(input_channels, num_classes), 20
    if model_name == 'phirec':
        return PHiRec(input_channels, num_classes, num_filters, image_size=(2, 512, 512)), 20
    if model_name == 'punet':
        return ProbabilisticUnet(input_channels, num_classes, num_filters, image_size=(2, 512, 512)), 20
    if model_name == 'unet':
        # only a single sample is drawn for the plain U-Net
        return UNet(input_channels, num_classes), 1
    print(model_name)
    raise ValueError('Not the right model loaded')


def perform_inference(us_factors, base_data_origin, split_file, model_names, settings, base_path_models, base_save_path):
    """Run reconstruction inference for every undersampling factor, model and setting.

    For each combination, the checkpoint stored in
    ``<base_path_models>/<model_name>/<setting>/`` is loaded into the matching
    architecture and applied to every volume listed in ``split_file``. The result
    of ``make_prediction_on_volume`` is written to the HDF5 dataset ``'recon'`` in
    ``<base_save_path>/<us_factor[8:]>/<model_name>/<setting>/<file_name>``.

    Args:
        us_factors (list[str]): sub-directories of ``base_data_origin`` holding the
            undersampled inputs (e.g. ``'skm-tea-4x'``); the part after the first
            8 characters (``'4x'``) names the output sub-directory.
        base_data_origin (str): root directory of the undersampled HDF5 volumes,
            each read from its ``'img_us'`` dataset.
        split_file (str): JSON annotation file; ``images[*].file_name`` selects the
            volumes to process.
        model_names (list[str]): any of ``'dropout'``, ``'het'``, ``'het_dr'``,
            ``'phirec'``, ``'punet'``, ``'unet'``.
        settings (list[str]): training-setting sub-directories (e.g. ``'train_4x'``).
        base_path_models (str): root directory of the model checkpoints.
        base_save_path (str): root directory the predictions are written to;
            missing sub-directories are created.

    Raises:
        FileNotFoundError: if a model/setting directory holds no checkpoint.
        ValueError: if a model name is not recognised.
    """
    # model hyper-parameters shared by all architectures
    input_channels = 2
    num_classes = 2
    num_filters = [32, 64, 128, 192, 192, 192, 192]

    # The test split does not depend on the undersampling factor: read it once.
    with open(split_file, 'r') as split_fh:
        split = json.load(split_fh)
    test_files = [image['file_name'] for image in split['images']]
    print('files:', test_files)
    print('len files:', len(test_files))

    for us_factor in us_factors:
        files_origin = os.path.join(base_data_origin, us_factor)
        print(files_origin)

        for model_name in model_names:
            print('doing predictions for', model_name)
            for setting in settings:
                print('doing predictions for', setting)
                model_file_path = _find_checkpoint(os.path.join(base_path_models, model_name, setting))
                model, n_samples = _build_reconstruction_model(model_name, input_channels, num_classes, num_filters)
                checkpoint = torch.load(model_file_path)
                model.load_state_dict(checkpoint['model_state_dict'])

                for input_file in test_files:
                    with h5py.File(os.path.join(files_origin, input_file), 'r') as vol_file:
                        vol_us = vol_file['img_us'][()]
                    print('vol us shape', vol_us.shape)
                    reconstruction = make_prediction_on_volume(vol_us, model, n_samples)
                    print('vol pred shape', reconstruction.shape)

                    save_path = os.path.join(base_save_path, us_factor[8:], model_name, setting, input_file)
                    print(save_path)
                    # h5py does not create parent directories itself
                    os.makedirs(os.path.dirname(save_path), exist_ok=True)
                    with h5py.File(save_path, 'w') as save_file:
                        save_file.create_dataset('recon', data=reconstruction)

In [None]:
# Undersampling factors: each is a sub-directory of base_data_origin.
us_factors = ['skm-tea-4x', 'skm-tea-8x', 'skm-tea-16x']
# Training settings: each is a sub-directory of a model's checkpoint directory.
settings = ['train_4x', 'train_all']
# Reconstruction architectures to run.
model_names = ['dropout', 'het', 'het_dr', 'phirec', 'punet', 'unet']

# Inputs: undersampled volumes and the SKM-TEA test split annotation.
base_data_origin = '/mnt/qb/work/baumgartner/pfischer23/fastmriUQ'
split_file = '/mnt/qb/baumgartner/rawdata/SKM-TEA/skm-tea/v1-release/annotations/v1.0.0/test.json'

# Checkpoints are read from here; reconstructions are written there.
base_path_models = '/mnt/qb/baumgartner/pfischer23/final_eval_skmtea/models'
base_save_path = '/mnt/qb/baumgartner/pfischer23/final_eval_skmtea/predictions/reconstructions'

In [None]:
# Reconstruct every test volume for every undersampling factor / model / setting.
perform_inference(
    us_factors=us_factors,
    base_data_origin=base_data_origin,
    split_file=split_file,
    model_names=model_names,
    settings=settings,
    base_path_models=base_path_models,
    base_save_path=base_save_path,
)

# Inference for the ensembles

In [None]:
def generate_samples_ensembles(path_to_models, model, vol_us, n_models=20):
    """Generate one reconstruction per ensemble member for a single volume.

    Every file found (recursively) below ``path_to_models`` is treated as the
    checkpoint of one ensemble member. Each checkpoint's ``'model_state_dict'`` is
    loaded into ``model`` in turn. The first element returned by
    ``make_prediction_on_volume(vol_us, model, n_samples=1)`` is stored as that
    member's sample.

    Args:
        path_to_models (str): directory containing the ensemble checkpoints.
        model (torch.nn.Module): architecture shared by all members; its weights are
            overwritten by every checkpoint.
        vol_us (np.ndarray): the undersampled volume, expected shape (x, y, z, 2).
        n_models (int): expected number of checkpoint files (default 20).

    Returns:
        np.ndarray: float64 array of shape ``(n_models,) + vol_us.shape``. Entry ``i``
        is the prediction of the ``i``-th checkpoint in sorted path order.

    Raises:
        ValueError: if the number of checkpoint files differs from ``n_models``.
    """
    # os.walk order is arbitrary; sort so sample i always comes from the same member
    model_paths = sorted(
        os.path.join(dir_path, name)
        for dir_path, _, file_names in os.walk(path_to_models)
        for name in file_names
    )
    if len(model_paths) != n_models:
        raise ValueError(
            f'expected {n_models} ensemble checkpoints in {path_to_models}, '
            f'found {len(model_paths)}'
        )

    # loop over all members and collect their predictions
    final_shape = (n_models,) + vol_us.shape
    samples = np.zeros(final_shape)
    print(final_shape)
    for i, model_path in enumerate(model_paths):
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        pred = make_prediction_on_volume(vol_us, model, n_samples=1)[0]
        print(pred.shape)
        samples[i] = pred
    print(samples.shape)
    return samples

def perform_inference_ensembles(us_factors, base_data_origin, split_file, model_names, settings, base_path_models, base_save_path):
    """Run deep-ensemble reconstruction inference on the test split.

    For each undersampling factor, model name and setting, every checkpoint below
    ``<base_path_models>/<model_name>/<setting>/`` is applied to every test volume
    (via ``generate_samples_ensembles``). The stacked per-member predictions are
    written to the HDF5 dataset ``'recon'`` in
    ``<base_save_path>/<us_factor[8:]>/<model_name>/<setting>/<file_name>``.

    Args:
        us_factors (list[str]): sub-directories of ``base_data_origin`` with the
            undersampled inputs (e.g. ``'skm-tea-4x'``); the part after the first
            8 characters names the output sub-directory.
        base_data_origin (str): root directory of the undersampled HDF5 volumes.
        split_file (str): JSON annotation file listing the test volumes.
        model_names (list[str]): names of the checkpoint directories holding
            ensemble members.
        settings (list[str]): training-setting sub-directories.
        base_path_models (str): root directory of the model checkpoints.
        base_save_path (str): root directory the predictions are written to;
            missing sub-directories are created.
    """
    # Every volume of the test split is processed, matching perform_inference.
    with open(split_file, 'r') as split_fh:
        split = json.load(split_fh)
    test_files = [image['file_name'] for image in split['images']]
    print('files:', test_files)
    print('len files:', len(test_files))

    for us_factor in us_factors:
        files_origin = os.path.join(base_data_origin, us_factor)
        print(files_origin)

        for model_name in model_names:
            print('doing predictions for', model_name)
            for setting in settings:
                print('doing predictions for', setting)
                model_path = os.path.join(base_path_models, model_name, setting)

                # All members share this architecture; generate_samples_ensembles
                # loads each member's weights into this instance in turn.
                model = UNet(2, 2)

                for input_file in test_files:
                    with h5py.File(os.path.join(files_origin, input_file), 'r') as vol_file:
                        vol_us = vol_file['img_us'][()]
                    print('vol us shape', vol_us.shape)
                    reconstruction = generate_samples_ensembles(model_path, model, vol_us)
                    print('vol pred shape', reconstruction.shape)

                    save_path = os.path.join(base_save_path, us_factor[8:], model_name, setting, input_file)
                    print(save_path)
                    # h5py does not create parent directories itself
                    os.makedirs(os.path.dirname(save_path), exist_ok=True)
                    with h5py.File(save_path, 'w') as save_file:
                        save_file.create_dataset('recon', data=reconstruction)

In [None]:
# Only the deep ensemble is run here. `model_names` still lists the single-model
# architectures, whose checkpoint directories presumably hold one checkpoint each
# (perform_inference loads a single file per directory), not the 20 members
# generate_samples_ensembles requires. The results are saved under the 'ensemble'
# name that the evaluation cells expect.
# NOTE(review): assumes the ensemble checkpoints are stored in
# <base_path_models>/ensemble/<setting>/ — confirm.
perform_inference_ensembles(us_factors, base_data_origin, split_file, ['ensemble'], settings, base_path_models, base_save_path)

# Evaluate SSIM/PSNR for the reconstructions

In [None]:
from validate.metric_utils import eval_ssim_psnr_big
# Evaluate the single-model reconstructions together with the deep ensemble.
model_names = [
    'dropout',
    'het',
    'het_dr',
    'phirec',
    'punet',
    'unet',
    'ensemble',
]

In [None]:
# SSIM / PSNR of every saved reconstruction.
eval_ssim_psnr_big(
    us_factors,
    base_data_origin,
    model_names,
    settings,
    base_save_path,
)

# Evaluate the Reconstruction NCC

In [None]:
from validate.metric_utils import eval_ncc_big

In [None]:
# NCC of every saved reconstruction.
eval_ncc_big(
    us_factors,
    base_data_origin,
    model_names,
    settings,
    base_save_path,
)

# Evaluate the variance

In [None]:
from validate.metric_utils import eval_var_big

In [None]:
# Variance of the saved reconstructions.
eval_var_big(
    us_factors,
    model_names,
    settings,
    base_save_path,
)

# Perform segmentation inference for the reconstructions

In [None]:
from validate.metric_utils import make_prediction_on_volume_segmentation
def perform_inference_segm(us_factors, base_data_origin, split_file, model_names, settings, model_path, base_save_path):
    """Segment saved reconstructions with a pretrained U-Net.

    For each undersampling factor, model name, setting and test volume, the HDF5
    dataset ``'recon'`` is read from
    ``<base_data_origin>/<us_factor>/<model_name>/<setting>/<file_name>``. It is
    passed to ``make_prediction_on_volume_segmentation``, and the result is written
    to the dataset ``'segm'`` in
    ``<base_save_path>/<us_factor>/<model_name>/<setting>/<file_name>``.

    Args:
        us_factors (list[str]): undersampling-factor sub-directories (used verbatim
            for both reading and writing).
        base_data_origin (str): root directory of the reconstructions to segment.
        split_file (str): JSON annotation file listing the test volumes.
        model_names (list[str]): reconstruction-model sub-directories.
        settings (list[str]): training-setting sub-directories.
        model_path (str): checkpoint of the segmentation U-Net
            (2 input channels, 7 classes).
        base_save_path (str): root directory the segmentations are written to;
            missing sub-directories are created.
    """
    input_channels = 2
    num_classes = 7

    # The segmentation network and the test split are the same for every
    # undersampling factor, so load them once rather than once per factor.
    model = UNet(input_channels, num_classes)
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['model_state_dict'])

    with open(split_file, 'r') as split_fh:
        split = json.load(split_fh)
    test_files = [image['file_name'] for image in split['images']]
    print('files:', test_files)
    print('len files:', len(test_files))

    for us_factor in us_factors:
        for model_name in model_names:
            print('doing predictions for', model_name)
            for setting in settings:
                for input_file in test_files:
                    input_file_dir = os.path.join(base_data_origin, us_factor, model_name, setting, input_file)
                    with h5py.File(input_file_dir, 'r') as recon_file:
                        recon = recon_file['recon'][()]
                    print('vol data shape', recon.shape)
                    segmentation = make_prediction_on_volume_segmentation(recon, model)
                    print('vol segm shape', segmentation.shape)

                    save_path = os.path.join(base_save_path, us_factor, model_name, setting, input_file)
                    print(save_path)
                    # h5py does not create parent directories itself
                    os.makedirs(os.path.dirname(save_path), exist_ok=True)
                    with h5py.File(save_path, 'w') as save_file:
                        save_file.create_dataset('segm', data=segmentation)

In [None]:
# Checkpoint of the segmentation U-Net — replace the placeholder before running.
model_path = 'path/to/your/segmentation/model.pth'
# NOTE(review): perform_inference_segm reads '<base_data_origin>/<us_factor>/<model>/<setting>/<file>'
# (dataset 'recon') and writes to '<base_save_path>/<us_factor>/...', whereas the
# reconstructions above were written to '<base_save_path>/<us_factor[8:]>/...'.
# Confirm these arguments point at the reconstructions and at a separate output root.
perform_inference_segm(
    us_factors=us_factors,
    base_data_origin=base_data_origin,
    split_file=split_file,
    model_names=model_names,
    settings=settings,
    model_path=model_path,
    base_save_path=base_save_path,
)

# Evaluate NCC for the segmentations

In [None]:
from validate.metric_utils import eval_ncc_big_segmentations

In [None]:
# NCC of the segmentations.
eval_ncc_big_segmentations(
    us_factors,
    base_data_origin,
    model_names,
    settings,
    base_save_path,
)