In [None]:
"""
@author: Parisima
This notebook requires predictions
in .mat format
and BraTS'21 test set
"""

In [None]:
!pip install lpips

In [None]:
!pip install tensorflow-addons

In [None]:
!pip install elasticdeform

In [None]:
# Access to google drive
# Colab-only step: mounts the user's Google Drive at /content/drive.
# The `google.colab` package is only available inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import sys
import numpy as np
import nibabel as nib
import tensorflow as tf
import os
import pandas as pd
import scipy.io as sio

In [None]:
# Make the CCL synthesis repository importable. '/YourPath' is a placeholder:
# replace it with the location of your local checkout before running.
sys.path.append('/YourPath/CCL-Synthetis/')

# utils, Synthesis and Datagen are expected to be packages inside that checkout.
# NOTE(review): modelObj, lossObj and DataLoaderObj are not referenced in the
# cells visible here — confirm they are still needed.
from utils.model_utils import modelObj
from Synthesis.synthesis_losses import lossObj
from Datagen.h5_pretrain_Synth_Data_Generator import DataLoaderObj

# cfg is handed to compute_lpips, which reads cfg.contrast_idx and
# cfg.target_contrast_idx from it.
import Synthesis.synth_config as cfg

In [None]:
# Make the multi-contrast contrastive-learning repository importable.
# '/YourPath' is a placeholder: replace it with your local checkout location.
sys.path.append('/YourPath/multi-contrast-contrastive-learning/')

# NOTE(review): an earlier cell already imports from a top-level package named
# `utils` (utils.model_utils). Python caches the first `utils` package it
# imports, so `utils.utils` below is looked up inside that already-loaded
# package. Confirm it provides myCrop3D and contrastStretch, or that the two
# repositories share the same `utils` package.
from utils.utils import myCrop3D
from utils.utils import contrastStretch

def normalize_img_zmean(img, mask):
    """Zero-mean, unit-standard-deviation normalization based on a mask.

    The mean and standard deviation are computed only over the voxels where
    ``mask > 0`` and are then applied to the whole image.

    Args:
        img: NumPy array holding the image.
        mask: NumPy array usable to index ``img``; entries greater than zero
            select the voxels that contribute to the statistics.

    Returns:
        Tuple ``(normalized_img, mean_, std_)``. ``std_`` is the divisor that
        was actually applied, so ``normalized_img * std_ + mean_`` recovers the
        input.

    Raises:
        ValueError: If ``mask`` selects no voxels (the statistics would be NaN).
    """
    mask_signal = img[mask > 0]
    if mask_signal.size == 0:
        raise ValueError("normalize_img_zmean: mask selects no voxels")

    mean_ = mask_signal.mean()
    std_ = mask_signal.std()
    if std_ == 0:
        # Constant signal inside the mask: dividing by zero would fill the image
        # with inf/NaN, so only centre the data and leave the scale untouched.
        std_ = 1.0

    img = (img - mean_) / std_
    return img, mean_, std_

def normalize_img(img):
    """Min-max rescale ``img`` so its values span [0, 1].

    Args:
        img: NumPy array.

    Returns:
        Array of the same shape with the minimum mapped to 0 and the maximum
        mapped to 1. A constant image (max == min) is mapped to all zeros
        instead of producing NaN from a 0/0 division.
    """
    img_min = img.min()
    value_range = img.max() - img_min
    if value_range == 0:
        # Constant image: keep the division well-defined; the result is all zeros.
        value_range = 1
    return (img - img_min) / value_range

def load_subject(datadir, subName, crop_size=(192, 192), slice_range=(40, 120)):
    """Load, preprocess and stack the four MRI contrasts of one subject.

    For each of the suffixes t1ce, t2, t1 and flair (in that order) the file
    ``<datadir>/<subName>/<subName><suffix>`` is read, rotated, cropped,
    contrast-stretched and z-normalized. The brain mask is derived from the
    first contrast (t1ce, voxels > 0 after cropping) and reused for the others.

    Args:
        datadir: Directory that contains one sub-directory per subject.
        subName: Subject identifier; used both as directory name and file prefix.
        crop_size: Size forwarded to ``myCrop3D``. Defaults to ``(192, 192)``.
        slice_range: ``(start, stop)`` window applied along the first axis of
            the stacked volume, or ``None`` to keep every slice.
            Defaults to ``(40, 120)``.

    Returns:
        Array whose last axis indexes the four contrasts and whose first axis
        is the third axis of the loaded volumes, restricted to ``slice_range``.
    """
    data_suffix = ['_t1ce.nii.gz', '_t2.nii.gz', '_t1.nii.gz', '_flair.nii.gz']
    subject_dir = os.path.join(datadir, subName)

    sub_img = []
    mask = None
    for suffix in data_suffix:
        img_path = os.path.join(subject_dir, subName + suffix)
        img_data = nib.load(img_path).get_fdata()
        # Rotate by -90 degrees in the plane spanned by the first two axes.
        img_data = np.rot90(img_data, -1)
        # presumably an in-plane crop to crop_size — confirm against myCrop3D
        img_data = myCrop3D(img_data, crop_size)

        if mask is None:
            # Built once from the first contrast and shared by all contrasts.
            mask = np.zeros(img_data.shape)
            mask[img_data > 0] = 1

        img_data = contrastStretch(img_data, mask, 0.01, 99.9)
        # Change to normalize_img if your model was trained with min-max scaling.
        img_data, _, _ = normalize_img_zmean(img_data, mask)
        sub_img.append(img_data)

    sub_img = np.stack(sub_img, axis=-1)
    # Move the third volume axis to the front: (a, b, c, contrast) -> (c, a, b, contrast).
    sub_img = np.transpose(sub_img, (2, 0, 1, 3))

    if slice_range is not None:
        start, stop = slice_range
        sub_img = sub_img[start:stop]

    return sub_img

#-----------------------------------------------------------------

def get_data(img, contrast_idx, target_contrast_idx):
    """Split a multi-contrast array into an (input, target) pair of tensors.

    Both selections are taken along the last axis of ``img`` and each is
    passed through ``tf.identity`` before being returned.
    """
    inputs = generate_X(img, contrast_idx)
    targets = generate_Y(img, target_contrast_idx)
    return tf.identity(inputs), tf.identity(targets)

def generate_X(img, contrast_idx):
    """Return the input contrast(s): ``img`` indexed by ``contrast_idx`` on its last axis."""
    return img[..., contrast_idx]

def generate_Y(img, target_contrast_idx):
    """Return the target contrast(s): ``img`` indexed by ``target_contrast_idx`` on its last axis."""
    return img[..., target_contrast_idx]

In [None]:
import lpips
import torch

In [None]:
# Choose the metric model
# compute_lpips reads this module-level `loss_fn`, so this cell must be run
# before compute_lpips is called.
loss_fn = lpips.LPIPS(net='alex')  # Using AlexNet
# loss_fn = lpips.LPIPS(net='vgg')  # Using VGG

In [None]:
def load_predictions(predictions_file):
    """Read the per-model prediction arrays stored in a .mat file.

    The file is expected to hold a struct named 'predictions' with the fields
    'Baseline', 'Partial_Decoder' and 'Full_Decoder'.

    Args:
        predictions_file: Path of the .mat file to read.

    Returns:
        Dict mapping each of the three model names to its stored value. A value
        that arrives as raw bytes is reinterpreted as float32 and reshaped to
        (80, 192, 192, 3).
    """
    model_names = ('Baseline', 'Partial_Decoder', 'Full_Decoder')
    pred_struct = sio.loadmat(predictions_file)['predictions']

    # loadmat wraps each struct field in a 1x1 object array; [0, 0] unwraps it.
    predictions = {name: pred_struct[name][0, 0] for name in model_names}

    # Decode byte payloads, if any; adapt dtype/shape to your data as needed.
    for name in model_names:
        payload = predictions[name]
        if isinstance(payload, bytes):
            decoded = np.frombuffer(payload, dtype=np.float32)
            predictions[name] = decoded.reshape((80, 192, 192, 3))

    return predictions

In [None]:
import tensorflow as tf
import torch

def prepare_image(image):
    """Convert one image slice into the tensor layout expected by LPIPS.

    The LPIPS network takes RGB batches shaped (N, 3, H, W) with values scaled
    to [-1, 1]. This helper min-max rescales the slice to that range,
    replicates a single channel to three, moves channels first and adds the
    batch axis.

    Args:
        image: Array-like of shape (H, W), (H, W, 1) or (H, W, 3). Anything
            ``np.asarray`` accepts works (NumPy arrays, eager TensorFlow
            tensors, nested lists).

    Returns:
        ``torch.FloatTensor`` of shape (1, 3, H, W) with values in [-1, 1].
        A constant slice (max == min) is mapped to all -1 rather than NaN.

    Raises:
        ValueError: If the input is not 2-D/3-D or has a channel count other
            than 1 or 3.
    """
    image = np.asarray(image, dtype=np.float32)

    # Accept plain 2-D slices by giving them an explicit channel axis.
    if image.ndim == 2:
        image = image[..., np.newaxis]
    if image.ndim != 3:
        raise ValueError(
            f"prepare_image expects a 2-D or 3-D image, got shape {image.shape}"
        )

    channels = image.shape[-1]
    if channels == 1:
        # Grayscale: replicate the single channel to form an RGB image.
        image = np.repeat(image, 3, axis=-1)
    elif channels != 3:
        raise ValueError(
            f"prepare_image expects 1 or 3 channels in the last axis, got {channels}"
        )

    # Rescale to [-1, 1]; guard the constant-image case to avoid 0/0 -> NaN.
    image_min = image.min()
    value_range = image.max() - image_min
    if value_range == 0:
        value_range = np.float32(1.0)
    image = 2 * (image - image_min) / value_range - 1

    # (H, W, C) -> (C, H, W); torch.from_numpy needs a contiguous buffer.
    image = np.ascontiguousarray(np.transpose(image, (2, 0, 1)))

    # Add the batch axis: (C, H, W) -> (1, C, H, W).
    return torch.from_numpy(image).float().unsqueeze(0)



In [None]:
def compute_lpips(datadir, predictions_dir, cfg):
    """Compute the per-subject mean LPIPS score of each model's predictions.

    For every sub-directory of ``datadir`` whose name starts with "BraTS2021_",
    the ground-truth target contrast is loaded and compared slice by slice with
    the predictions stored in ``<predictions_dir>/<subject>_predictions.mat``.
    The module-level ``loss_fn`` (an LPIPS model) does the scoring.

    Args:
        datadir: Directory containing one sub-directory per subject.
        predictions_dir: Directory containing the ``*_predictions.mat`` files.
        cfg: Object exposing ``contrast_idx`` and ``target_contrast_idx``.

    Returns:
        pandas.DataFrame with one row per subject and the columns
        'Subject_ID', 'Baseline_Avg_LPIPS', 'Full_Decoder_Avg_LPIPS' and
        'Partial_Decoder_Avg_LPIPS'. The frame is empty if no subject is found.

    Raises:
        ValueError: If a model's prediction stack does not have the same number
            of slices as the ground truth.
    """
    # Only directories can be subjects; skip stray files such as archives that
    # happen to share the prefix.
    subjects = sorted(
        name for name in os.listdir(datadir)
        if name.startswith("BraTS2021_")
        and os.path.isdir(os.path.join(datadir, name))
    )

    subject_scores = []
    for subName in subjects:
        print(f"Processing {subName}...")
        img = load_subject(datadir, subName)
        # Only the target contrast (ground truth) is needed for scoring.
        _, y_true = get_data(img, cfg.contrast_idx, cfg.target_contrast_idx)
        y_true = y_true.numpy()

        # Load predictions
        predictions_file = os.path.join(predictions_dir, f'{subName}_predictions.mat')
        predictions = load_predictions(predictions_file)

        model_scores = {}
        for model_name, y_pred in predictions.items():
            # A slice-count mismatch means the stacks are not aligned; fail
            # loudly instead of scoring mismatched or truncated pairs.
            if y_pred.shape[0] != y_true.shape[0]:
                raise ValueError(
                    f"{subName}/{model_name}: prediction has {y_pred.shape[0]} "
                    f"slices but ground truth has {y_true.shape[0]}"
                )

            scores = []
            for true_slice, pred_slice in zip(y_true, y_pred):
                img0 = prepare_image(true_slice)
                img1 = prepare_image(pred_slice)
                # Inference only: no gradients are needed for the metric.
                with torch.no_grad():
                    score = loss_fn(img0, img1)
                scores.append(score.item())
            model_scores[model_name] = np.mean(scores)

        subject_scores.append({
            'Subject_ID': subName,
            'Baseline_Avg_LPIPS': model_scores['Baseline'],
            'Full_Decoder_Avg_LPIPS': model_scores['Full_Decoder'],
            'Partial_Decoder_Avg_LPIPS': model_scores['Partial_Decoder']
        })

    return pd.DataFrame(subject_scores)

In [None]:
# Placeholder locations: replace '/YourPath' with the directories that hold the
# BraTS'21 test subjects and the per-subject *_predictions.mat files.
datadir = '/YourPath/BraTS2021_Test/'
predictions_dir = '/YourPath/predictions/'

# Score every subject; `cfg` is the synthesis config imported in an earlier cell.
results_df = compute_lpips(datadir, predictions_dir, cfg)
# 'SaveDir.csv' is a relative path, so the file is written to the current
# working directory; index=False leaves the row index out of the file.
results_df.to_csv('SaveDir.csv', index=False)
print('CSV file has been saved.')