<a href="https://colab.research.google.com/github/ser5kovskiy/Multi_MRI_diffusion/blob/main/samling_notebook_preliminary.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive2; the project folder (path_to_data) lives under it.
from google.colab import drive
drive.mount('/content/drive2')

In [None]:
import torch
from functools import partial
import os
import argparse
import yaml

import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Project root on the mounted Drive: holds the code folders copied into the working
# directory, the `configs/` tree and the per-dataset result folders.
path_to_data = '/content/drive2/MyDrive/deep_generative_model_project/'

In [None]:
# Put the Drive project folder first on the import path so modules stored there
# can be imported directly.
import sys
sys.path.insert(0, path_to_data)

In [None]:
from distutils.dir_util import copy_tree

# Copy the `data` and `guided_diffusion` packages from Drive into the working
# directory so they can be imported in the following cells.
#
# shutil.copytree(..., dirs_exist_ok=True) replaces distutils' copy_tree:
# distutils is deprecated (removed in Python 3.12), and copy_tree caches the
# directories it has created, so re-running this cell after deleting a target
# folder can fail in the same kernel.
import shutil

folder_to_copy = ['data', 'guided_diffusion',]

for folder in folder_to_copy:
    # NOTE: despite its name, dest_path is the *source* folder on Drive.
    dest_path = os.path.join(path_to_data, folder)
    target_path = os.path.join('.', folder)
    shutil.copytree(dest_path, target_path, dirs_exist_ok=True)

In [None]:
from guided_diffusion.condition_methods import get_conditioning_method
from guided_diffusion.measurements import get_noise, get_operator
from guided_diffusion.unet import create_model
from guided_diffusion.gaussian_diffusion import create_sampler
from data.dataloader import get_dataset, get_dataloader

In [None]:
from skimage import io
import numpy as np
from torch.backends import cudnn
import torch
import random
import numpy as np

# Seed every random-number source used by the notebook so runs are repeatable.
seed = 42

# Ask cuDNN for deterministic kernels and turn off its benchmark-driven
# algorithm selection.
cudnn.deterministic = True
cudnn.benchmark = False

# Python `random`, NumPy's global RNG, torch, the current GPU and all GPUs.
for _seed_rng in (random.seed,
                  np.random.seed,
                  torch.manual_seed,
                  torch.cuda.manual_seed,
                  torch.cuda.manual_seed_all):
    _seed_rng(seed)

In [None]:
def load_yaml(file_path: str) -> dict:
    """Parse the YAML document stored at *file_path* and return it.

    Uses PyYAML's ``FullLoader``. These are local experiment configs; do not point
    this at untrusted files.
    """
    with open(file_path) as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)

In [None]:
# Run settings: batch size and compute device.
batch_size = 15
device_str = 'cuda' if torch.cuda.is_available() else 'cpu'

# Dataset / experiment selection (could instead be read interactively with input()).
dataset_name = 'brain_nii'
folder_congigs = 'experiments_' + dataset_name

# Results go to <project>/<dataset>/res; YAML configs are read from
# <project>/configs/<dataset>/experiments/<experiment folder>.
result_folder_parent_part = os.path.join(path_to_data, dataset_name, 'res')
config_folder = os.path.join(path_to_data, 'configs', dataset_name,
                             'experiments', folder_congigs)

In [None]:
# Model, diffusion and task (data / measurement / conditioning) configurations.
model_config = load_yaml(os.path.join(config_folder, 'model_config_MRI_T1.yaml'))
diffusion_config = load_yaml(os.path.join(config_folder, 'diffusion_config.yaml'))
task_config = load_yaml(os.path.join(
    config_folder, 'MRI_config_hyperparameters_one_object_test_05.yaml'))

# Last '/'-separated component of the configured data folder; used in result folder names.
dataset_mode = task_config['data']['folder'].rsplit('/', 1)[-1]

In [None]:
# Build the network from its config, move it to the chosen device and put it in eval mode.
device = torch.device(device_str)
model = create_model(**model_config).to(device).eval()

In [None]:
# Measurement operator and noise model described by the task config
# (constructed in the original order, in case either touches global state).
measure_config = task_config['measurement']
operator = get_operator(**measure_config['operator'], device=device)
noiser = get_noise(**measure_config['noise'])

In [None]:
# Modality selection read from stdin:
#   0 -> keep every modality index in `modalities` (stays a list),
#   k -> keep only modalities[k - 1] (becomes an int).
# Out-of-range values are rejected instead of raising an IndexError or
# (for negative k) silently selecting a different modality.
modalities = [0, 1]
a = int(input())
if not 0 <= a <= len(modalities):
    raise ValueError(f'Expected an integer in [0, {len(modalities)}], got {a}')
if a != 0:
    modalities = modalities[a - 1]

In [None]:
# The checkpoint path is expected to carry (modality index + 1) as its 4th
# character from the end (presumably a digit just before a 3-char extension such
# as '.pt' — confirm). The check only applies when a single modality was
# selected; with the full list, `modalities + 1` used to raise TypeError.
# Raise explicitly so the check survives `python -O`.
if isinstance(modalities, int) and modalities + 1 != int(model_config['model_path'][-4]):
    raise ValueError(
        f"Selected modality {modalities} does not match model_path "
        f"{model_config['model_path']!r}")

In [None]:
# Pass the selected modality (or list of modalities) to the dataset factory;
# this edits task_config['data'] in place.
data_config = task_config['data']
data_config.update(modalities=modalities)

In [None]:
# Display the data config (now including 'modalities') for inspection.
data_config

In [None]:
# Dataset from the data config and an evaluation loader (train=False, no worker processes).
dataset = get_dataset(**data_config)
loader = get_dataloader(dataset, train=False, batch_size=batch_size, num_workers=0)

In [None]:
# Number of modalities per object: leading axis of the first element of the first
# dataset item (per the loop's shape notes, presumably [modalities, channels, H, W]).
modalities_num = dataset[0][0].shape[0]
modalities_num

In [None]:
from PIL import Image
from skimage.metrics import structural_similarity, mean_squared_error, peak_signal_noise_ratio
import imageio as io

def saving_object(parent_dir, obj_type, obj_num, obj):
    """Save one object as a JPEG with its channels tiled left to right.

    Args:
        parent_dir: root output directory.
        obj_type: sub-folder name, e.g. 'recon', 'label' or 'input'.
        obj_num: index used as the file stem (``<obj_num>.jpeg``).
        obj: array whose last axis indexes channels/modalities; callers pass
            8-bit ``(H, W, C)`` slices. The C channel images are concatenated
            horizontally. Any channel count works; the old version hard-coded
            exactly two and failed for a single modality.
    """
    tiled = np.hstack([obj[..., c] for c in range(obj.shape[-1])])
    image = Image.fromarray(tiled)

    path_to_save = os.path.join(parent_dir, obj_type, f'{obj_num}.jpeg')
    os.makedirs(os.path.dirname(path_to_save), exist_ok=True)
    # `io` is the module-level `imageio` alias imported in this cell.
    io.imwrite(path_to_save, image, format='jpeg')

def object_metrics(prediction, GT):
    """Compute SSIM and PSNR between prediction and ground truth, channel by channel.

    Args:
        prediction: reconstruction with channels on axis 2 (``(H, W, C)``).
        GT: reference with the same layout; ``GT.shape[2]`` sets the channel count.

    Returns:
        ``(SSIM, PSNR)``: two 1-D NumPy arrays with one score per channel.
        skimage infers the data range from the arrays' dtype.
    """
    channels = range(GT.shape[2])
    ssim_scores = [structural_similarity(GT[:, :, c], prediction[:, :, c])
                   for c in channels]
    psnr_scores = [peak_signal_noise_ratio(GT[:, :, c], prediction[:, :, c])
                   for c in channels]
    return np.array(ssim_scores), np.array(psnr_scores)

def metrics_calculate(prediction, GT, raw_data, cur_objs_num, parent_dir, middle_result_txt_file):
    """Save images for one batch and score every object per modality.

    Args:
        prediction: reconstructions, modality axis 1 (presumably ``(B, M, H, W)``),
            values expected in [0, 1].
        GT: ground truth with a singleton axis 2 (``(B, M, 1, H, W)``); that axis is dropped.
        raw_data: measurements, same layout as ``GT``.
        cur_objs_num: index of the first object in this batch (file-name offset).
        parent_dir: directory receiving 'recon', 'label' and 'input' images.
        middle_result_txt_file: text file that per-object metrics are appended to.

    Returns:
        ``(SSIM_sum, PSNR_sum)``: arrays of length M with per-modality sums over the batch.
    """
    # Move the modality axis last -> (B, H, W, M).
    prediction = np.moveaxis(np.array(prediction), 1, -1)
    GT = np.moveaxis(np.array(GT[:, :, 0, ]), 1, -1)
    raw_data = np.moveaxis(np.array(raw_data[:, :, 0, ]), 1, -1)

    # Quantise to 8-bit (truncating) for saving and for skimage's dtype-based data range.
    prediction = (prediction * 255).astype(np.uint8)
    GT = (GT * 255).astype(np.uint8)
    raw_data = (raw_data * 255).astype(np.uint8)

    # One accumulator slot per modality, i.e. the LAST axis after moveaxis.
    # (The previous `prediction.shape[1]` was the image height, so the `+=`
    # below could not broadcast against the per-modality metrics.)
    num_modalities = prediction.shape[-1]
    set_batch_SSIM = np.zeros(num_modalities)
    set_batch_PSNR = np.zeros(num_modalities)

    for idx, (cur_pred, cur_GT, cur_raw) in enumerate(zip(prediction, GT, raw_data)):
        obj_num = cur_objs_num + idx

        saving_object(parent_dir=parent_dir, obj_type='recon', obj_num=obj_num, obj=cur_pred)
        saving_object(parent_dir=parent_dir, obj_type='label', obj_num=obj_num, obj=cur_GT)
        saving_object(parent_dir=parent_dir, obj_type='input', obj_num=obj_num, obj=cur_raw)

        cur_obj_metrics = object_metrics(prediction=cur_pred, GT=cur_GT)
        print(f'object:{obj_num} ', cur_obj_metrics)
        result_line = f'object:{obj_num} - SSIM: {cur_obj_metrics[0]} PSNR:{cur_obj_metrics[1]} SSIM mean: {np.mean(cur_obj_metrics[0])} PSNR common: {np.mean(cur_obj_metrics[1])}\n'
        with open(middle_result_txt_file, 'a+') as file:
            file.write(result_line)

        set_batch_SSIM += cur_obj_metrics[0]
        set_batch_PSNR += cur_obj_metrics[1]
    return set_batch_SSIM, set_batch_PSNR

def fftc(image):
    """Return the 2-D FFT of *image* over its last two axes, zero frequency shifted to the centre.

    ``ifftc`` undoes this transform.
    """
    spectrum = torch.fft.fft2(image, dim=(-2, -1))
    return torch.fft.fftshift(spectrum, dim=(-2, -1))

def ifftc(kspace):
    """Invert ``fftc``: undo the centring shift, then inverse 2-D FFT over the last two axes.

    The result is complex-valued.
    """
    unshifted = torch.fft.ifftshift(kspace, dim=(-2, -1))
    return torch.fft.ifft2(unshifted, dim=(-2, -1))

def call_metrics(sample,
                 batch,
                 measurement,
                 mask,
                 cur_objs_num,
                 out_path,
                 middle_result_txt_file,
                 post_processing=False,
                 SSIM_total=None,
                 PSNR_total=None):
    """Normalise one batch, optionally enforce data consistency, then score and log it.

    Args:
        sample: sampler output, modality axis 1 (presumably ``(B, M, H, W)``).
        batch: ground truth, presumably ``(B, M, 1, H, W)``.
        measurement: measured input, same layout as ``batch``.
        mask: k-space mask, broadcastable against ``fftc(measurement)``.
        cur_objs_num: number of objects scored before this batch. It is the
            file-name offset and is used for the running mean, which assumes all
            earlier batches were full.
        out_path: directory receiving the 'recon', 'label' and 'input' images.
        middle_result_txt_file: text file that per-object and running results are appended to.
        post_processing: if True, keep the measured k-space where ``mask`` is set
            and the sample's k-space elsewhere before scoring.
        SSIM_total, PSNR_total: running per-modality sums from earlier batches.
            ``None`` starts a fresh zero accumulator. The old ndarray defaults
            were shared between calls, updated in place and fixed at 2 modalities.

    Returns:
        ``(SSIM_total, PSNR_total)``: the updated running per-modality sums.
    """
    num_samples = batch.shape[0]
    sample = min_max_norm(sample)
    batch = min_max_norm(batch)

    if post_processing:
        print("post_processing")
        inverse_mask = 1. - mask
        # Add the singleton channel axis so the sample lines up with `measurement`.
        sample = sample[:, :, None, :, :]
        # NOTE(review): `measurement` is not min-max normalised yet at this point
        # while `sample` is; confirm the two k-spaces share an intensity scale.
        f_masked_data = fftc(measurement) * mask + fftc(sample) * inverse_mask
        # Take the real part explicitly; `.to(real_dtype)` on a complex tensor
        # drops the imaginary part too, but with a warning.
        masked_data = ifftc(f_masked_data).real.to(sample.dtype)
        sample = min_max_norm(masked_data[:, :, 0, :, :])

    measurement = min_max_norm(measurement).cpu().detach()
    set_obj_SSIM, set_obj_PSNR = metrics_calculate(prediction=sample,
                                                   GT=batch,
                                                   cur_objs_num=cur_objs_num,
                                                   parent_dir=out_path,
                                                   raw_data=measurement,
                                                   middle_result_txt_file=middle_result_txt_file)

    if SSIM_total is None:
        SSIM_total = np.zeros_like(set_obj_SSIM)
    if PSNR_total is None:
        PSNR_total = np.zeros_like(set_obj_PSNR)
    # Build new arrays rather than mutating the caller's accumulators in place.
    SSIM_total = SSIM_total + set_obj_SSIM
    PSNR_total = PSNR_total + set_obj_PSNR
    print('Current batch metrics: ', set_obj_SSIM / num_samples, set_obj_PSNR / num_samples, np.mean(set_obj_SSIM) / num_samples, np.mean(set_obj_PSNR) / num_samples)

    total_considered_objs = cur_objs_num + num_samples
    cur_SSIM_total = SSIM_total / total_considered_objs
    cur_PSNR_total = PSNR_total / total_considered_objs
    print('Current total: ', cur_SSIM_total, cur_PSNR_total, np.mean(cur_SSIM_total), np.mean(cur_PSNR_total))

    result_line = f'Total number of considered objects {total_considered_objs} - SSIM: {cur_SSIM_total} PSNR:{cur_PSNR_total} SSIM mean: {np.mean(cur_SSIM_total)} PSNR common: {np.mean(cur_PSNR_total)}\n'

    with open(middle_result_txt_file, 'a+') as file:
        file.write(result_line)

    return SSIM_total, PSNR_total

In [None]:
import itertools

# Build the grid of conditioning hyper-parameter combinations to sweep.
cond_config = task_config['conditioning']
hyper_params_to_search = list(cond_config['params'].keys())
# A list/tuple value is a set of candidates; any other value (number, string,
# None) is a single fixed setting. Wrapping it stops itertools.product from
# failing on a number or iterating a string character by character.
hyper_params_list = [
    cond_config['params'][hp]
    if isinstance(cond_config['params'][hp], (list, tuple))
    else [cond_config['params'][hp]]
    for hp in hyper_params_to_search
]
hyper_param_combinations = list(itertools.product(*hyper_params_list))
len(hyper_param_combinations)

In [None]:
# Display the hyper-parameter combinations that will be swept.
hyper_param_combinations

In [None]:
def set_new_config(original_config, hp_names, hp_values):
    """Write one hyper-parameter combination into a conditioning config.

    ``original_config['params']`` is updated in place with the pairs from
    ``zip(hp_names, hp_values)``; surplus names or values are ignored.

    Returns:
        The same ``original_config`` object, now modified.
    """
    original_config['params'].update(zip(hp_names, hp_values))
    return original_config

In [None]:
def min_max_norm(data):
    """Rescale every 2-D slice (the last two axes) of *data* to [0, 1].

    Each slice has its own minimum subtracted and is divided by its own range.
    A constant slice (e.g. an all-zero image) has zero range. It is mapped to
    all zeros instead of the NaNs produced by dividing 0 by 0.
    """
    min_ = torch.amin(data, dim=(-2, -1), keepdim=True)
    max_ = torch.amax(data, dim=(-2, -1), keepdim=True)
    span = max_ - min_
    # Avoid 0/0: a constant slice has data - min_ == 0 everywhere, so dividing by 1 yields 0.
    span = torch.where(span > 0, span, torch.ones_like(span))
    return (data - min_) / span

In [None]:
from itertools import islice

# Number of *batches* (not objects) processed per hyper-parameter combination;
# at most total_samples_to_consider * batch_size objects are scored.
total_samples_to_consider = 500

for cur_hp in hyper_param_combinations:
    print(cur_hp)
    # Write this combination into the shared conditioning config (in place).
    cond_config = set_new_config(original_config=cond_config,
                                 hp_names=hyper_params_to_search,
                                 hp_values=cur_hp)

    # Output layout:
    # <result_folder_parent_part>/<method_folder>/<unique_add_results_folder>/<operator name>/{input,recon,progress,label}
    method_folder = f'''T{modalities}_{dataset_mode}_{dataset_name}_{hyper_params_to_search}'''.replace('\n', '')
    unique_add_results_folder = f'''{hyper_params_to_search}_{cur_hp}_'''.replace('\n', '')

    result_folder = f'{result_folder_parent_part}/{method_folder}/{unique_add_results_folder}'
    os.makedirs(result_folder, exist_ok=True)

    out_path = os.path.join(result_folder, measure_config['operator']['name'])
    # makedirs also creates out_path itself.
    for img_dir in ['input', 'recon', 'progress', 'label']:
        os.makedirs(os.path.join(out_path, img_dir), exist_ok=True)

    cond_method = get_conditioning_method(cond_config['method'], operator, noiser, **cond_config['params'])
    measurement_cond_fn = cond_method.conditioning

    sampler = create_sampler(**diffusion_config)
    sample_fn = partial(sampler.p_sample_loop, model=model, measurement_cond_fn=measurement_cond_fn)

    # Running per-modality sums for this combination.
    SSIM_total = np.zeros(modalities_num)
    PSNR_total = np.zeros(modalities_num)

    middle_result_txt_file = os.path.join(result_folder, 'middle_results.txt')
    # Log file for the optional post-processing pass (see the note inside the loop).
    middle_result_txt_file_postprocessing = os.path.join(result_folder, 'middle_results_post_processing.txt')

    # islice stops after the batch limit, so no extra batch is loaded just to be discarded.
    for i, all_batch in tqdm(enumerate(islice(loader, total_samples_to_consider))):
        all_batch = [x.to(device) for x in all_batch]
        # Shapes (example with batch size 9):
        # batch       - [9, 2, 1, 128, 128] - [batch_size; modalities; channels; spatial#1, spatial#2]
        # measurement - [9, 2, 1, 128, 128] - [batch_size; modalities; channels; spatial#1, spatial#2]
        # mask        - [9, 1, 1, 128, 128] - [batch_size; modalities; channels; spatial#1, spatial#2]
        batch, measurement, mask = all_batch

        # x_start - [9, 2, 128, 128] - [batch_size; channels; spatial#1, spatial#2]
        # (grad enabled on the starting noise)
        x_start = torch.randn(batch[:, :, 0, ].shape, device=device).requires_grad_()

        meas_support = {'mask': mask}

        sample = sample_fn(x_start=x_start, measurement=measurement,
                           meas_support=meas_support,
                           record=True, save_root=out_path)

        sample = sample.cpu().detach()
        batch = batch.cpu().detach()
        measurement = measurement.cpu().detach()
        mask = mask.cpu().detach()

        # call_metrics appends per-object scores and the running mean over all
        # objects so far to middle_results.txt; the last "Total number of
        # considered objects" line is the final result for this combination.
        SSIM_total, PSNR_total = call_metrics(sample=sample,
                                              batch=batch,
                                              measurement=measurement,
                                              mask=mask,
                                              cur_objs_num=i * batch_size,
                                              out_path=out_path,
                                              middle_result_txt_file=middle_result_txt_file,
                                              post_processing=False,
                                              SSIM_total=SSIM_total,
                                              PSNR_total=PSNR_total)
        # Post-processed (data-consistency) metrics are disabled. To enable them,
        # call call_metrics again with post_processing=True,
        # middle_result_txt_file=middle_result_txt_file_postprocessing and a
        # separate pair of SSIM/PSNR accumulators.