In [None]:
import nibabel as nib
import matplotlib.pyplot as plt
import configparser
import numpy as np
import cv2
import torch

In [None]:
# DCM patient: show every slice of patient001 / frame 01 next to its
# ground-truth segmentation, with the patient's info.cfg metadata in the title.
patient_cfg = configparser.ConfigParser()
# info.cfg has no section header, so one is prepended before parsing.
# The `with` block closes the file handle (the bare open() leaked it).
with open('./ACDC/training/patient001/info.cfg') as cfg_file:
    patient_cfg.read_string('[Info]\n' + cfg_file.read())

ed = patient_cfg.get('Info', 'ED')
es = patient_cfg.get('Info', 'ES')
group = patient_cfg.get('Info', 'Group')
nbFrame = patient_cfg.get('Info', 'NbFrame')
height = patient_cfg.get('Info', 'Height')
weight = patient_cfg.get('Info', 'Weight')

nii_file_path1 = "./ACDC/training/patient001/patient001_frame01.nii.gz"
nii_file_path2 = "./ACDC/training/patient001/patient001_frame01_gt.nii.gz"

nii_img1 = nib.load(nii_file_path1)
nii_data1 = nii_img1.get_fdata()

nii_img2 = nib.load(nii_file_path2)
nii_data2 = nii_img2.get_fdata()

# Display slices of both NIfTI images side by side using matplotlib
num_slices = nii_data1.shape[2]


plt.figure(figsize=(12, 5 * num_slices))

for slice_idx in range(num_slices):
    plt.subplot(num_slices, 2, 2*slice_idx + 1)
    plt.imshow(nii_data1[:, :, slice_idx], cmap="gray")
    plt.title(f"DCM - Slice {slice_idx}")
    plt.axis("off")

    plt.subplot(num_slices, 2, 2*slice_idx + 2)
    plt.imshow(nii_data2[:, :, slice_idx], cmap="gray")
    plt.title(f"GT - Slice {slice_idx}")
    plt.axis("off")

title = f"Patient 1, ED: {ed}, ES: {es}, Group: {group}, NbFrame: {nbFrame}, Height: {height}, Weight: {weight}"
plt.suptitle(title)
plt.tight_layout()
plt.show()


In [None]:
def convert_mask_single(y, num_classes=4):
    """
    One-hot encode a single integer label mask: one binary mask per class.

    y: 2-D array of class labels, shape (w, h).
    num_classes: number of output channels (default 4); channel k is 1
        where y == k. Labels outside 0..num_classes-1 get no channel set.

    Returns a float64 array of shape (num_classes, w, h).
    """
    y = np.asarray(y)
    # Compare y against every class id at once via broadcasting:
    # (1, w, h) == (num_classes, 1, 1) -> (num_classes, w, h)
    class_ids = np.arange(num_classes).reshape(-1, 1, 1)
    return (y[np.newaxis, :, :] == class_ids).astype(np.float64)

def get_images(img_path, input_size=(224,224,1)):
    """
    Load a NIfTI volume and resize each slice (taken along axis 2).

    img_path: path to a .nii / .nii.gz file. An already-loaded volume array is
        accepted too, so this helper and the array-based get_images defined
        further down in the notebook are interchangeable.
    input_size: (width, height, channels); only width and height are used, as
        the cv2.resize target size.

    Returns [images, masks]:
        images: numpy array of shape (n_slices, height, width, 1)
        masks:  float32 zero tensor of shape (n_slices, height, width, 1), a
                placeholder because no ground truth is loaded here.
    """
    import os  # local import: the notebook only imports os in a later cell

    if isinstance(img_path, (str, os.PathLike)):
        volume = nib.load(img_path).get_fdata()
    else:
        volume = np.asarray(img_path)

    target_size = (input_size[0], input_size[1])  # cv2.resize expects (width, height)
    all_imgs = [
        cv2.resize(volume[:, :, idx], target_size, interpolation=cv2.INTER_NEAREST)
        for idx in range(volume.shape[2])
    ]
    all_imgs = np.expand_dims(all_imgs, axis=3)
    # torch.empty returned uninitialised memory that was later treated as a mask;
    # zeros keep the placeholder deterministic.
    masks = torch.zeros((all_imgs.shape[0], all_imgs.shape[1], all_imgs.shape[2], 1), dtype=torch.float32)
    return [all_imgs, masks]


def visualize(image_raw, mask):
    """
    Blend a multi-class segmentation mask onto a grayscale image.

    image_raw: gray image with shape [width, height, 1]
    mask: per-class segment masks with shape [num_class, width, height];
        channel 0 is background and is not coloured.

    Returns an RGB image in which each non-background class is added on top of
    the original intensities in its own hue, with weight 0.5.
    """
    overlay = cv2.cvtColor(image_raw, cv2.COLOR_GRAY2RGB)
    num_class = mask.shape[0]

    def _class_color(class_idx):
        # Hues evenly spread over OpenCV's 0..179 range, full saturation/value.
        hsv = np.zeros((1, 1, 3), dtype=np.uint8)
        hsv[0, 0, 0] = int(class_idx / float(num_class - 1) * 179)
        hsv[0, 0, 1:] = 255
        return cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

    foreground_ids = range(1, num_class)
    palette = [_class_color(class_idx) for class_idx in foreground_ids]

    for class_idx, color in zip(foreground_ids, palette):
        layer = np.repeat(mask[class_idx, :, :][:, :, np.newaxis], 3, axis=2)
        layer = layer.astype(overlay.dtype) * color
        overlay = cv2.addWeighted(overlay, 1.0, layer, 0.5, 0.0)

    return overlay

def evaluate_model(model, dataloader):
    """
    Run `model` over `dataloader` and plot, for each batch, the first sample's
    input image, ground-truth mask, predicted mask and a colour overlay.

    model: segmentation network; `model(inputs)[2]` is used as the per-class
        scores (presumably shape (batch, num_class, H, W) -- TODO confirm).
    dataloader: yields (inputs, targets); targets are per-class masks that are
        collapsed with argmax over their first (class) axis for display.

    Returns None -- results are only displayed.
    """
    # Fall back to CPU so the notebook still runs without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    for inputs, targets in dataloader:
        inputs = inputs.to(device)

        with torch.no_grad():
            outputs = model(inputs)

        y_pred = torch.argmax(outputs[2], axis=1)
        mask = convert_mask_single(y_pred[0, :, :].cpu().numpy())

        # Everything below visualises the first sample of the batch only.
        input_image = inputs[0].cpu().numpy().transpose(1, 2, 0)
        # collapse the per-class target into a single label map to visualize
        ground_truth_mask = torch.argmax(targets[0], dim=0).cpu().numpy()
        predicted_mask = y_pred[0].cpu().numpy()

        mask_with_image = visualize(input_image, mask)
        # Min-max rescale to 0..255 for display; a constant image would divide by zero.
        value_range = mask_with_image.max() - mask_with_image.min()
        if value_range > 0:
            mask_with_image = (mask_with_image - mask_with_image.min()) / value_range * 255
        else:
            mask_with_image = np.zeros_like(mask_with_image)

        fig = plt.figure(figsize=(12, 4))
        plt.subplot(1, 4, 1)
        plt.title("Input Image")
        plt.imshow(input_image, cmap='gray')

        plt.subplot(1, 4, 2)
        plt.title("Ground Truth Mask")
        plt.imshow(ground_truth_mask, cmap='gray')

        plt.subplot(1, 4, 3)
        plt.title("Predicted Mask")
        plt.imshow(predicted_mask, cmap='gray')

        plt.subplot(1, 4, 4)
        plt.title("Predicted Mask2")
        plt.imshow(mask_with_image.astype(np.uint8))
        plt.show()
        # One figure per batch: release it so long loaders don't accumulate memory.
        plt.close(fig)
    

## Image and its ground truth

In [None]:
# img_path = "./ACDC/training/patient002/patient002_frame01.nii.gz"

# input_size = (224, 224, 1)
# all_imgs = []
# img = nib.load(img_path).get_fdata()
# for idx in range(img.shape[2]):
    
#     i = cv2.resize(img[:,:,idx], (input_size[0], input_size[1]), interpolation=cv2.INTER_NEAREST)
#     all_imgs.append(i)
# #     plt.title("Input Image")
# #     plt.imshow(i, cmap='gray')
# #     plt.show()
#     # cv2.imwrite(f"raw_{idx}.png", img[:,:,idx].astype("float32"))


# img_path = "pred_feature_patient002_frame01.nii.gz"
# Segmentation volume ("_sg" file) for patient002, frame 01: resize every slice
# with nearest-neighbour interpolation (label values stay intact), collect the
# slices, and show each one in its own figure.
img_path = "./ACDC/training/patient002/patient002_frame01_sg.nii.gz"
input_size = (224, 224, 1)
all_imgs_gt = []
img = nib.load(img_path).get_fdata()
target_wh = (input_size[0], input_size[1])
for slice_no in range(img.shape[2]):
    resized = cv2.resize(img[:, :, slice_no], target_wh, interpolation=cv2.INTER_NEAREST)
    all_imgs_gt.append(resized)
    plt.title("Predicted Mask")
    plt.imshow(resized, cmap='gray')
    plt.show()
    
    # cv2.imwrite(f"raw_{idx}.png", img[:,:,idx].astype("float32"))


In [None]:
# img_path = "pred_feature_patient002_frame01.nii.gz"
# Ground-truth segmentation for patient002, frame 01, resized to the model's
# input size and shown one slice per figure.
img_path = "./ACDC/training/patient002/patient002_frame01_gt.nii.gz"
input_size = (224, 224, 1)
all_imgs_gt = []
img = nib.load(img_path).get_fdata()
for idx in range(img.shape[2]):
    # nearest-neighbour keeps the integer label values intact
    i = cv2.resize(img[:,:,idx], (input_size[0], input_size[1]), interpolation=cv2.INTER_NEAREST)
    all_imgs_gt.append(i)
    # This file is the ground truth (not a prediction): title it accordingly.
    plt.title(f"Ground Truth Mask - Slice {idx}")
    plt.imshow(i, cmap='gray')
    plt.show()


## Model

## Inference on a single .nii image

In [None]:
# torch.load unpickles the whole model object: only load trusted checkpoints.
# On machines without a GPU, map tensors to the CPU so a checkpoint saved on CUDA still loads.
# NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects fully pickled
# models; pass weights_only=False there if loading fails (trusted file only).
model = torch.load('models/fct.model', map_location=None if torch.cuda.is_available() else "cpu")

In [None]:
from utils.data_utils import get_acdc,convert_masks
import pandas as pd
from torch.utils.data import DataLoader,TensorDataset
import os

img_path = "./ACDC/testing/patient101/patient101_frame01.nii.gz"

# Slice/resize the volume; get_images also hands back a placeholder mask tensor.
images, masks = get_images(img_path, input_size=(224, 224, 1))
masks = convert_masks(masks)
# Move the trailing channel axis next to the batch axis (NHWC -> NCHW).
images = np.transpose(images, (0, 3, 1, 2))
masks = np.transpose(masks, (0, 3, 1, 2))
acdc_data = TensorDataset(torch.Tensor(images), torch.Tensor(masks))
test_loader = DataLoader(acdc_data, batch_size=1, num_workers=2)
print(len(test_loader))

In [None]:
# Plot input / ground truth / prediction / overlay for every slice of patient101.
# NOTE(review): these targets come from get_images' placeholder tensor rather than
# real labels, so the "Ground Truth" panel presumably carries no information.
evaluate_model(model, test_loader)

## With all the testing data

In [None]:
from utils.data_utils import get_acdc,convert_masks
import pandas as pd
from torch.utils.data import DataLoader,TensorDataset

# Every test volume, resized to 224x224; only the first element of get_acdc's
# result (the [images, masks] pair) is used here.
acdc_data, _, _ = get_acdc("ACDC/testing", input_size=(224, 224, 1))
print(acdc_data[1].shape, acdc_data[0].shape)

masks = convert_masks(acdc_data[1])
# NHWC -> NCHW: channels go right after the batch axis for PyTorch.
images = np.transpose(acdc_data[0], (0, 3, 1, 2))
masks = np.transpose(masks, (0, 3, 1, 2))
acdc_data = TensorDataset(torch.Tensor(images), torch.Tensor(masks))
test_loader = DataLoader(acdc_data, batch_size=1, num_workers=2)
print(len(test_loader))

In [None]:
# The plotting evaluate_model defined above returns None, so binding its result to
# `scores` only ever produced None; just run it for the plots.
evaluate_model(model, test_loader);

## Save the predicted segmentation masks for each patient

In [None]:
# Test volume used for the predict -> save -> reload round trip below.
# (The unused `input_size` / `all_imgs_gt` leftovers from earlier cells were dropped.)
img_path = "./ACDC/testing/patient103/patient103_frame01.nii.gz"
img = nib.load(img_path).get_fdata()

In [None]:
# Shape of the original (not yet resized) volume; slices run along the last axis.
img.shape

In [None]:
# dtype of the loaded voxel array.
img.dtype

In [None]:
def evaluate_model(model, dataloader):
    """
    Run `model` over `dataloader` and collect the predicted label map of the
    first sample in every batch.

    Shapes observed with batch_size=1 and 224x224 inputs:
        outputs[2]   torch.Size([1, 4, 224, 224])
        y_pred       torch.Size([1, 224, 224])

    model: segmentation network whose output is indexable; element 2 is used
        as the per-class scores (presumably the final head -- confirm).
    dataloader: yields (inputs, targets); targets are ignored here.

    Returns a list of numpy arrays, one (H, W) argmax label map per batch.
    """
    # Fall back to CPU so inference still works without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    results = []

    with torch.no_grad():
        for inputs, _targets in dataloader:
            outputs = model(inputs.to(device))
            y_pred = torch.argmax(outputs[2], dim=1)
            results.append(y_pred[0, :, :].cpu().numpy())
    return results
        
def get_images(img, input_size=(224,224,1)):
    """
    Resize every slice (taken along axis 2) of one volume to the model input size.

    img: an already-loaded volume array of shape (H, W, n_slices). A path to a
        .nii / .nii.gz file is accepted too (loaded with nibabel), so this
        matches the path-based get_images defined earlier in the notebook.
    input_size: (width, height, channels); only width and height are used, as
        the cv2.resize target size.

    Returns [images, masks]:
        images: numpy array of shape (n_slices, height, width, 1)
        masks:  float32 zero tensor of the same shape, a placeholder because no
                ground truth is available at prediction time.
    """
    import os  # local import so the function does not rely on a prior cell

    if isinstance(img, (str, os.PathLike)):
        img = nib.load(img).get_fdata()

    target_size = (input_size[0], input_size[1])  # cv2.resize expects (width, height)
    all_imgs = [
        cv2.resize(img[:, :, idx], target_size, interpolation=cv2.INTER_NEAREST)
        for idx in range(img.shape[2])
    ]
    all_imgs = np.expand_dims(all_imgs, axis=3)
    # torch.empty returned uninitialised memory that was later treated as a mask;
    # zeros keep the placeholder deterministic.
    masks = torch.zeros((all_imgs.shape[0], all_imgs.shape[1], all_imgs.shape[2], 1), dtype=torch.float32)
    return [all_imgs, masks]
 
def predict(img, model):
    """
    Segment every slice of one volume.

    img: volume handed to get_images (slices along axis 2).
    model: network passed straight through to evaluate_model.

    Returns whatever evaluate_model returns for the per-slice loader.
    """
    images, masks = get_images(img, input_size=(224, 224, 1))
    masks = convert_masks(masks)
    # NHWC -> NCHW: channels right after the batch axis for PyTorch.
    images = np.transpose(images, (0, 3, 1, 2))
    masks = np.transpose(masks, (0, 3, 1, 2))
    dataset = TensorDataset(torch.Tensor(images), torch.Tensor(masks))
    loader = DataLoader(dataset, batch_size=1, num_workers=2)
    return evaluate_model(model, loader)

In [None]:
# torch.load unpickles the whole model object: only load trusted checkpoints.
# On machines without a GPU, map tensors to the CPU so a checkpoint saved on CUDA still loads.
# NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects fully pickled
# models; pass weights_only=False there if loading fails (trusted file only).
model = torch.load('models/fct.model', map_location=None if torch.cuda.is_available() else "cpu")

In [None]:
# Segment every slice of `img`: a list with one 2-D label map per slice.
res = predict(img, model)

In [None]:
# Number of predicted slices (one per slice of img, since batch_size=1).
len(res)

In [None]:
# Shape of a single predicted slice.
res[0].shape

In [None]:
# Stack the per-slice (H, W) predictions into one (H, W, n_slices) float64 volume
# for NIfTI export. The guard makes the cell idempotent: re-running it no longer
# transposes an already stacked array a second time.
if isinstance(res, list):
    res = np.stack(res, axis=-1).astype("float64")

In [None]:
# Expected (224, 224, n_slices) after stacking.
res.shape

In [None]:
# Wrap the predicted label volume in a NIfTI image and write it to disk.
# NOTE(review): an identity affine discards the source scan's orientation and voxel
# spacing (and the slices were resized to 224x224), so the saved file will not line
# up with the original volume in a viewer -- confirm this is acceptable.
affine = np.eye(4)
nifti_file = nib.Nifti1Image(res, affine)

nib.save(nifti_file, "predicted.nii.gz") # Here you put the path + the extension 'nii' or 'nii.gz'



In [None]:
# Reload the saved prediction to check that it round-trips unchanged.
# (The unused `input_size` / `all_imgs_gt` leftovers were dropped.)
img_path = "predicted.nii.gz"
img_mask = nib.load(img_path).get_fdata()

In [None]:
# Compare the reloaded mask's shape with the in-memory prediction
# (img.shape, the original volume, was already shown above).
img_mask.shape, res.shape

In [None]:
# True only if the saved/reloaded volume matches the prediction exactly
# (same shape and every voxel equal); a raw count of equal voxels is not a pass/fail check.
np.array_equal(img_mask, res)

In [None]:
# import re
# model = torch.load('models/fct.model')

# for root, directories, files in os.walk("ACDC/testing"):
#         for file in files:
#             if ".gz" and "frame" in file:
#                 if "_gt" not in file:
#                     img_path = root + "/" + file
#                     out_path = root + "/pred_" + file
#                     print(img_path)
#                     img = nib.load(img_path).get_fdata()
#                     data_loader = prepare_data(img)
#                     result = model_output(model, dataloader)
#                     affine = np.eye(4)
#                     nifti_file = nib.Nifti1Image(result, affine)
#                     nib.save(nifti_file, out_path) # Here you put the path + the extionsion 'nii' or 'nii.gz'
   