In [None]:
### please specify your input path here
import os  # this cell runs before the import cell, so import what it needs locally

PROJECT_FOLDER = "YOUR_PROJECT_FOLDER"  # parent folder of the input images
# os.path.join inserts the separator whether or not PROJECT_FOLDER ends with one; the
# trailing "" keeps the trailing separator of the original "images/" suffix.
IMAGE_DATA_FOLDER = os.path.join(PROJECT_FOLDER, "images", "")  # folder of the input images
INPUT_TEST_CSV_FILE = "YOUR_TEST_FILE"  # csv listing the test cases; needs 'StudyInstanceUID' and 'image_folder' (DICOM dir) columns
OUTPUT_FOLDER = "YOUR_OUTPUT_FOLDER"  # folder where the per-case Grad-CAM volumes (.npy) are saved

**Step 1: Import packages**

In [None]:
import os
import pickle
import time
import SimpleITK as sitk
import numpy as np
import torch
import path
import sys
import pandas as pd
import skimage
import skimage.morphology
import threading
import math
import gc

# Enumerate GPUs in PCI bus order and expose only the first one to this kernel
# (these variables only take effect if set before CUDA is first initialised).
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": "0"})

# ------------------------------------ #
# Put the working directory at the front of the module search path (once) so that the
# project packages imported below (Utils, Training) resolve from here.
src_path = './'
if src_path not in sys.path:
    sys.path.insert(0, src_path)

# Sanity check: what is visible from the working directory, and the resulting search path
print(os.listdir(src_path), sys.path, sep="\n")
# ------------------------------------ #

from Utils.CommonTools import sitk_base
from Utils.CommonTools.NiiIO import read_from_DICOM_dir
from Utils.PreProcessing.resampling import sitk_dummy_3D_resample
from Utils.CommonTools.bbox import get_bbox, extend_bbox
from Utils.post_processing import keep_largest_cervical_cc
from Utils.Inference.nnunet_inference import NNUnetCTPredictor
from Utils.CommonTools.dir import try_recursive_mkdir
from Utils.CommonTools.NiiIO import read_from_DICOM_dir
from Utils.CommonTools.sitk_base import resample, copy_nii_info, get_nii_info

from Training.Task_301_PostProcessing_Overall.model import Model as ModelStage3
print("==> Import success")

**Step 2: Define predictors**

* PredictorStage1: Segment the cervical vertebrae C1-C7
* PredictorStage2: Segment the bone fracture region
* PredictorStage3: Predict the final score from the outputs of PredictorStage1 and PredictorStage2

In [2]:
class PredictorStage2(NNUnetCTPredictor):
    """nnU-Net fracture-segmentation predictor with a fixed, speed-oriented resampling grid.

    Instead of the plan's target spacing, ``resampling`` maps the input onto a grid with a
    fixed z spacing (``fast_z_spacing``) and a fixed in-plane size of
    ``fast_xy_size`` x ``fast_xy_size`` voxels, keeping the physical extent of the image.
    """

    # Target grid used by resampling(): z spacing (same units as the image spacing) and
    # in-plane size in voxels. Faster inference, but may reduce accuracy.
    fast_z_spacing = 0.8
    fast_xy_size = 224

    def __init__(self, *args, **kwargs):
        super(PredictorStage2, self).__init__(*args, **kwargs)

    def resampling(self, ct_nii):
        """Resample ``ct_nii`` onto the fixed stage-2 grid.

        Args:
            ct_nii: SimpleITK image.

        Returns:
            The resampled image, or ``ct_nii`` itself when its spacing already matches the
            target spacing within ``self.resampling_tolerance``.
        """
        ori_spacing = ct_nii.GetSpacing()[::-1]  # to z,y,x
        ori_size = ct_nii.GetSize()[::-1]

        z_spacing = self.fast_z_spacing
        xy_size = self.fast_xy_size

        # Build a fresh list rather than writing into
        # self.plan['plans_per_stage'][...]['current_spacing'], which would modify the shared
        # plan in place (and fail outright if that entry is immutable).
        new_spacing = [
            z_spacing,
            ori_size[1] * ori_spacing[1] / xy_size,
            ori_size[2] * ori_spacing[2] / xy_size,
        ]
        new_size = [int(math.ceil(ori_size[0] * ori_spacing[0] / z_spacing)), xy_size, xy_size]

        do_resampling = np.any(np.abs(np.array(ori_spacing) - np.array(new_spacing)) > self.resampling_tolerance)
        if do_resampling:
            ct_nii = sitk_dummy_3D_resample(
                ct_nii,
                new_spacing=new_spacing[::-1],  # back to sitk x,y,z order
                new_size=new_size[::-1],
                interp_xy=self.resampling_mode,
                interp_z=sitk.sitkNearestNeighbor,
                out_dtype=self.resampling_dtype,
                constant_value=self.resampling_constance_value
            )
        else:
            print(f'==> No necessary to do resampling ori {ori_spacing}, new: {new_spacing}')

        return ct_nii

class PredictorStage3:
    """Ensemble of stage-3 ``ModelStage3`` classifiers producing one scalar score per input.

    Args:
        list_model_pth: checkpoint paths; one model is built and loaded per path.
        device: torch device the models and inputs are moved to.
        tta: enable flip test-time augmentation.
        tta_flip_axis: dims of the batched 5-D input to flip when ``tta`` is on
            (2 = z, 3 = y, 4 = x).
    """

    def __init__(self, list_model_pth, device, tta=False, tta_flip_axis=(4, )):
        self.list_model_pth = list_model_pth
        self.device = device
        self.tta = tta
        self.tta_flip_axis = tta_flip_axis

        self.list_model = None

        # Architecture hyper-parameters; must match the checkpoints. The 2 input channels
        # are the fracture map and the C1-C7 labels (see FractureDetector.get_stage3_input).
        self.in_ch = 2
        self.out_ch = 1
        self.list_ch = [-1, 16, 32, 64, 128]

        self.init_model()

    def init_model(self):
        """Build one model per checkpoint, load its weights on CPU, switch to eval, move to device."""
        with torch.no_grad():
            self.list_model = []
            for model_pth in self.list_model_pth:
                model = ModelStage3(in_ch=self.in_ch, out_ch=self.out_ch, list_ch=self.list_ch, random_init=False)

                # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
                ckpt = torch.load(model_pth, map_location='cpu')

                model.load_state_dict(ckpt)
                model.eval()
                model = model.to(self.device)
                self.list_model.append(model)

                print(f'==> Init model from {model_pth} to device {self.device}')

    def _flip_axes_list(self):
        """Return one list of dims to flip per TTA variant ([] = un-flipped), in z, y, x loop order."""
        if self.tta:
            p_flip_z = (0, 1) if 2 in self.tta_flip_axis else (0,)
            p_flip_y = (0, 1) if 3 in self.tta_flip_axis else (0,)
            p_flip_x = (0, 1) if 4 in self.tta_flip_axis else (0,)
        else:
            p_flip_z = (0,)
            p_flip_y = (0,)
            p_flip_x = (0,)

        flip_axes_list = []
        for flip_z in p_flip_z:
            for flip_y in p_flip_y:
                for flip_x in p_flip_x:
                    flags = (flip_z, flip_y, flip_x)
                    flip_axes_list.append([axis for axis, flag in zip((2, 3, 4), flags) if flag == 1])
        return flip_axes_list

    def get_input_tensor(self, image):
        """Build the TTA inputs for ``image``.

        Args:
            image: numpy array without batch dim; a batch dim is prepended.

        Returns:
            List of ``(flip_axis, tensor)`` pairs: ``flip_axis`` is the list of dims that were
            flipped (empty for the un-flipped input), so callers can undo the flip on spatial
            outputs; ``tensor`` lives on ``self.device``.
        """
        # Copy so the un-flipped tensor never shares memory with the caller's array
        base_input = torch.from_numpy(image.copy()).to(self.device).unsqueeze(0)

        patch_inputs = []
        for flip_axis in self._flip_axes_list():
            patch_input = torch.flip(base_input, dims=flip_axis) if flip_axis else base_input
            patch_inputs.append((flip_axis, patch_input))
        return patch_inputs

    def predict(self, image):
        """Mean sigmoid score of every model over every TTA variant of ``image``.

        Assumes each model returns a sequence whose first element is indexable as ``[0, 0]``
        (presumably logits of shape (batch, 1)) -- TODO confirm against ModelStage3.

        Returns:
            numpy float scalar: the mean over (TTA variants x models).
        """
        with torch.no_grad():
            list_pred = []
            for _, patch_input in self.get_input_tensor(image):
                for model in self.list_model:
                    pred = model(patch_input)[0]
                    list_pred.append(torch.sigmoid(pred[0, 0]).cpu().numpy())
            return np.mean(list_pred)

class DICOMReader(threading.Thread):
    """Worker thread that runs ``func(*args)`` (by default ``read_from_DICOM_dir``) once.

    Start it with ``start()`` and fetch the value with ``get_result()``, which waits for the
    thread. An exception raised by ``func`` is re-raised by ``get_result()`` on the caller's
    thread instead of being lost in the worker (which would otherwise leave ``result`` None).
    """

    def __init__(self, func=read_from_DICOM_dir, args=()):
        super(DICOMReader, self).__init__()

        self.func = func
        self.args = args

        self.result = None
        # Exception raised by func inside run(), if any
        self.error = None

    def run(self):
        try:
            self.result = self.func(*self.args)
        except Exception as exc:  # surfaced to the caller by get_result()
            self.error = exc

    def get_result(self):
        """Wait for the thread to finish and return ``func``'s result.

        Raises:
            Exception: whatever ``func`` raised in the worker thread.
        """
        self.join()
        if self.error is not None:
            raise self.error
        return self.result


class FractureDetector:
    """Coarse-to-fine cervical fracture pipeline.

    Stage 1 segments vertebrae C1-C7. Stage 2 segments fractures inside the extended C1-C7
    bounding box. ``get_score`` turns both maps into 8 scores (overall, C1..C7).

    Args:
        predictor_stage1: C1-C7 segmentation predictor (its ``resampling``, ``pre_processing``
            and ``sliding_window_inference`` are used).
        predictor_stage2: fracture segmentation predictor with the same interface.
        predictor_stage3: stage-3 predictor; stored but not used by ``predict``.
        extend_roi: per-axis margin passed to ``extend_bbox`` together with the image spacing
            (presumably physical units such as mm -- TODO confirm in Utils.CommonTools.bbox).
    """

    def __init__(self, predictor_stage1, predictor_stage2, predictor_stage3,  extend_roi=(5.0, 5.0, 5.0)):
        self.predictor_stage1 = predictor_stage1
        self.predictor_stage2 = predictor_stage2
        self.predictor_stage3 = predictor_stage3
        self.extend_roi = extend_roi

        # Hand-tuned scoring parameters, one entry per output [overall, C1, ..., C7]:
        #   alpha                 - minimum fracture probability for a voxel to be used
        #   beta                  - percentile (as a fraction of 100) taken as the score
        #   min_score / max_score - clamp range of the score
        self.params = {
            'alpha': [0.055, 0.044, 0.052, 0.05, 0.07, 0.077, 0.09, 0.024],
            'beta': [0.475, 0.34, 0.37, 0.38, 0.31, 0.35, 0.38, 0.36],
            'min_score': [0.116, 0.01, 0.015, 0.015, 0.01, 0.02, 0.032, 0.048],
            'max_score': [0.99, 0.999, 0.993, 0.99, 1.0, 0.943, 0.997, 0.999]
        }

        # case_id -> 8-element float32 score array, filled by predict()
        self.results = {}

    def get_c1_c7_bbox(self, pred, image_spacing):
        """Bounding box of all labelled voxels in ``pred``, extended by ``self.extend_roi``.

        Returns:
            ``(bz, ez, by, ey, bx, ex)`` in array order, used by callers with inclusive ends,
            or None when ``pred`` has no foreground.
        """
        c1_c7_bbox = get_bbox(pred > 0)
        if c1_c7_bbox is None:
            return None
        c1_c7_bbox = extend_bbox(
            c1_c7_bbox,
            max_shape=pred.shape,
            list_extend_length=self.extend_roi,
            spacing=image_spacing,
            approximate_method=np.ceil
        )
        return c1_c7_bbox

    def predict_stage1(self, ct_nii):
        """Segment C1-C7 and return a uint8 label map resampled back onto ``ct_nii``'s grid.

        Labels above 7 are zeroed after keeping the largest cervical connected component.
        """
        # Resampling
        time_start = time.time()
        ori_nii_info = get_nii_info(ct_nii)
        ct_nii = self.predictor_stage1.resampling(ct_nii)
        print(f"                  Resampling use: {time.time() - time_start}")

        # Pre_processing
        time_start = time.time()
        image = sitk.GetArrayFromImage(ct_nii)[np.newaxis]
        image = self.predictor_stage1.pre_processing(image)
        print(f"                  Pre_processing use: {time.time() - time_start}")

        # Model forward
        time_start = time.time()
        pred = self.predictor_stage1.sliding_window_inference(image)
        print(f"                  Model forward use: {time.time() - time_start}")

        # Post processing
        time_start = time.time()
        pred = np.argmax(pred, axis=0)
        pred = keep_largest_cervical_cc(pred, ct_nii.GetSpacing()[::-1])
        pred[pred > 7] = 0
        print(f"                  Post processing use: {time.time() - time_start}")

        # Resampling back (nearest neighbour keeps labels intact)
        time_start = time.time()
        pred_nii = sitk.GetImageFromArray(np.uint8(pred))
        pred_nii = copy_nii_info(ct_nii, pred_nii)
        pred_nii = resample(
            pred_nii,
            new_spacing=ori_nii_info['spacing'],
            new_origin=ori_nii_info['origin'],
            new_size=ori_nii_info['size'],
            new_direction=ori_nii_info['direction'],
            center_origin=None,
            interp=sitk.sitkNearestNeighbor,
            dtype=sitk.sitkUInt8,
            constant_value=0
        )

        pred = sitk.GetArrayFromImage(pred_nii)
        print(f"                  Resampling back use: {time.time() - time_start}")
        return pred

    def predict_stage2(self, ct_nii):
        """Return the fracture foreground map (model channel 1) resampled back onto ``ct_nii``'s grid.

        The map is resampled with linear interpolation and returned as a float32 numpy array.
        """
        # Resampling
        time_start = time.time()
        ori_nii_info = get_nii_info(ct_nii)
        ct_nii = self.predictor_stage2.resampling(ct_nii)
        print(f"                  Resampling use: {time.time() - time_start}")

        # Pre_processing
        time_start = time.time()
        image = sitk.GetArrayFromImage(ct_nii)[np.newaxis]
        image = self.predictor_stage2.pre_processing(image)
        print(f"                  Pre_processing use: {time.time() - time_start}")
        print(f"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!{image.shape}")
        # Model forward
        time_start = time.time()
        pred = self.predictor_stage2.sliding_window_inference(image)
        print(f"                  Model forward use: {time.time() - time_start}")

        # Post processing
        time_start = time.time()
        pred = pred[1]  # 0 for background, 1 for foreground
        print(f"                  Post processing use: {time.time() - time_start}")

        # Resampling back
        time_start = time.time()
        pred_nii = sitk.GetImageFromArray(pred)
        pred_nii = copy_nii_info(ct_nii, pred_nii)
        pred_nii = resample(
            pred_nii,
            new_spacing=ori_nii_info['spacing'],
            new_origin=ori_nii_info['origin'],
            new_size=ori_nii_info['size'],
            new_direction=ori_nii_info['direction'],
            center_origin=None,
            interp=sitk.sitkLinear,
            dtype=sitk.sitkFloat32,
            constant_value=0.
        )
        pred = sitk.GetArrayFromImage(pred_nii)
        print(f"                  Resampling back use: {time.time() - time_start}")

        return pred

    def get_stage3_input(self, pred_c1_c7, pred_fracture):
        """Resize both ROI maps to 96^3 and stack them as ``[fracture, c1_c7]`` (shape (2, 96, 96, 96)).

        The fracture map is resampled linearly (float32); the labels use nearest neighbour (uint8).
        """
        # Resampling
        pred_c1_c7_nii = sitk.GetImageFromArray(pred_c1_c7)
        pred_fracture_nii = sitk.GetImageFromArray(pred_fracture)

        new_size = (96, 96, 96)
        # Spacing that maps the whole array extent onto new_size voxels (z, y, x order)
        new_spacing = list(np.array(pred_fracture_nii.GetSize()[::-1]) / np.array(new_size))

        pred_fracture_nii = sitk_base.resample(
            pred_fracture_nii,
            new_spacing[::-1],
            new_origin=None,
            new_size=new_size[::-1],
            new_direction=None,
            center_origin=None,
            interp=sitk.sitkLinear,
            dtype=sitk.sitkFloat32,
            constant_value=0
        )
        pred_c1_c7_nii = sitk_base.resample(
            pred_c1_c7_nii,
            new_spacing[::-1],
            new_origin=None,
            new_size=new_size[::-1],
            new_direction=None,
            center_origin=None,
            interp=sitk.sitkNearestNeighbor,
            dtype=sitk.sitkUInt8,
            constant_value=0
        )

        pred_c1_c7 = sitk.GetArrayFromImage(pred_c1_c7_nii)
        pred_fracture = sitk.GetArrayFromImage(pred_fracture_nii)

        input_ = np.concatenate((pred_fracture[np.newaxis], pred_c1_c7[np.newaxis]), axis=0)
        return input_

    def predict_stage3(self, pred_c1_c7, pred_fracture):
        """Return the stage-3 input array; identical to ``get_stage3_input``.

        NOTE(review): despite its name this does not run ``self.predictor_stage3``; the score
        computation is disabled, so it returns the (2, 96, 96, 96) input, not a score.
        """
        return self.get_stage3_input(pred_c1_c7, pred_fracture)

    def get_score(self, pred_c1_c7, pred_fracture):
        """Turn a C1-C7 label map and a fracture-probability map into 8 scores.

        Args:
            pred_c1_c7: label array (0 = background, 1..7 = C1..C7), or None.
            pred_fracture: fracture probability array of the same shape, or None.

        Returns:
            float32 array ``[overall, C1, ..., C7]``.

            - Each vertebra score is the ``beta`` percentile of the fracture probabilities that
              are ``>= alpha`` inside that vertebra, clamped to ``[min_score, max_score]``.
              It is ``min_score`` when no voxel qualifies or an input is None.
            - The overall score is the largest vertebra score, floored at ``min_score[0]``.
        """
        output = np.zeros(8, np.float32)  # Overall, C1-C7
        have_preds = (pred_c1_c7 is not None) and (pred_fracture is not None)

        for C_i in range(1, 8):
            output[C_i] = self.params['min_score'][C_i]
            if not have_preds:
                continue
            roi_fracture = pred_fracture[np.logical_and(pred_fracture >= self.params['alpha'][C_i], pred_c1_c7 == C_i)]
            if len(roi_fracture) > 0:
                output[C_i] = max(self.params['min_score'][C_i],
                                  min(self.params['max_score'][C_i],
                                      np.percentile(roi_fracture, 100 * self.params['beta'][C_i])))

        # The overall score is derived from the vertebra scores, not from its own ROI.
        output[0] = max(self.params['min_score'][0], np.max(output[1:]))
        return output

    @staticmethod
    def read_DICOM_multi_thread(list_DICOM_dirs):
        """Read every DICOM dir in its own DICOMReader thread; results keep the input order."""
        list_thread = []
        list_outputs = []

        for DICOM_dir in list_DICOM_dirs:
            cur_thread = DICOMReader(func=read_from_DICOM_dir, args=(DICOM_dir, ))
            cur_thread.start()
            list_thread.append(cur_thread)

        for cur_thread in list_thread:
            list_outputs.append(cur_thread.get_result())
        list_thread.clear()

        return list_outputs

    def _predict_rois(self, ct_nii):
        """Run stages 1-2 on one CT and return ``(roi_pred_1, roi_pred_2)`` cropped to the C1-C7 box.

        When no vertebra is found, 2x2x2 zero placeholders are returned so scoring still works.
        """
        # Record original nii info before processing, e.g. spacing, size
        ori_nii_info = get_nii_info(ct_nii)

        # Stage 1: segment C1-C7 (numpy array in z, y, x order on the original grid)
        pred_1 = self.predict_stage1(ct_nii)

        # C1-C7 bounding box; spacing reversed to match the numpy z, y, x order
        c1_c7_bbox = self.get_c1_c7_bbox(pred_1, ori_nii_info['spacing'][::-1])
        if c1_c7_bbox is None:
            return np.zeros((2, 2, 2), np.uint8), np.zeros((2, 2, 2), np.float32)

        bz, ez, by, ey, bx, ex = c1_c7_bbox
        # SimpleITK images are indexed (x, y, z); numpy arrays (z, y, x)
        roi_ct_nii = ct_nii[bx:ex + 1, by:ey + 1, bz:ez + 1]
        roi_pred_1 = pred_1[bz:ez + 1, by:ey + 1, bx:ex + 1]

        # Stage 2: segment fractures inside the ROI
        roi_pred_2 = self.predict_stage2(roi_ct_nii)
        return roi_pred_1, roi_pred_2

    def predict(self,
                test_df: pd.DataFrame,
                num_thread=4,
                output_dir=None
                ):
        """Run the pipeline on every case of ``test_df`` and store the scores in ``self.results``.

        Args:
            test_df: DataFrame with ``StudyInstanceUID`` and ``image_folder`` (DICOM dir) columns.
                Only this argument is read; no notebook globals are used.
            num_thread: unused, kept for backward compatibility (one reader thread is started
                per row of ``test_df``).
            output_dir: when given, each case's ROI label map and fracture map are written
                there as ``{case_id}_C1_C7.nii.gz`` and ``{case_id}_fracture.nii.gz``.

        Returns:
            ``(roi_pred_1, roi_pred_2)`` of the LAST case in ``test_df``, or ``(None, None)``
            when it is empty. Pass one row at a time if you need the per-case ROI maps.
        """
        roi_pred_1, roi_pred_2 = None, None

        with torch.no_grad():
            if output_dir is not None:
                os.makedirs(output_dir, exist_ok=True)
            overall_time_start = time.time()
            cur_case_ids = test_df['StudyInstanceUID'].tolist()

            print(f"==> Predicting: {cur_case_ids}")

            # Step 1, read all images
            time_start = time.time()
            cur_ct_niis = self.read_DICOM_multi_thread(test_df['image_folder'].tolist())
            print(f"    Finish Reading use : {time.time() - time_start} seconds")

            time_start = time.time()
            for case_id, ct_nii in zip(cur_case_ids, cur_ct_niis):
                # Steps 2-4, C1-C7 segmentation, ROI crop, fracture segmentation
                roi_pred_1, roi_pred_2 = self._predict_rois(ct_nii)

                if output_dir is not None:
                    sitk.WriteImage(sitk.GetImageFromArray(np.uint8(roi_pred_1)), f"{output_dir}/{case_id}_C1_C7.nii.gz")
                    sitk.WriteImage(sitk.GetImageFromArray(np.float32(roi_pred_2)), f"{output_dir}/{case_id}_fracture.nii.gz")

                # Step 5, get score
                score = self.get_score(roi_pred_1, roi_pred_2)
                self.results[case_id] = score
                print(f"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!     {score}")
            print(f"    Finish this split use : {time.time() - time_start} seconds")
            print(f"    Overall use : {time.time() - overall_time_start} seconds")

            gc.collect()
        return roi_pred_1, roi_pred_2

In [3]:
import path

**Step 3: Initialize all models, then compute the final predictions for every DICOM directory in the test set**

In [None]:
# ----------------------- main ---------------------------- #
time_start = time.time()


# Predictor
with torch.no_grad():
    # --------------------------- Init models --------------------------------- #
    list_model_C1_C7_segmentation = [
        f'{path.path_root}/nnUnet/Models/nnUNet/3d_fullres/Task_101_VertebralLocation_GeneratePseudoLabel/nnUNetTrainerV2__nnUNetPlansv2.1/all/model_final_checkpoint.model',
        f'{path.path_root}/nnUnet/Models/nnUNet/3d_fullres/Task_102_VertebralLocation_GeneratePseudoLabel/nnUNetTrainerV2__nnUNetPlansv2.1/all/model_final_checkpoint.model',
        f'{path.path_root}/nnUnet/Models/nnUNet/3d_fullres/Task_103_VertebralLocation_GeneratePseudoLabel/nnUNetTrainerV2__nnUNetPlansv2.1/all/model_final_checkpoint.model'
    ]
    plan_C1_C7_segmentation = f'{path.path_root}/nnUnet/Models/nnUNet/3d_fullres/Task_101_VertebralLocation_GeneratePseudoLabel/nnUNetTrainerV2__nnUNetPlansv2.1/plans.pkl'
    
    list_model_fracture_detection = [
        f'{path.path_root}/nnUnet/Models/nnUNet/3d_fullres/Task_203_FractureDetection_Real5Fold/nnUNetTrainerV2__nnUNetPlansv2.1/fold_0/model_final_checkpoint.model',
        f'{path.path_root}/nnUnet/Models/nnUNet/3d_fullres/Task_203_FractureDetection_Real5Fold/nnUNetTrainerV2__nnUNetPlansv2.1/fold_1/model_final_checkpoint.model',
        f'{path.path_root}/nnUnet/Models/nnUNet/3d_fullres/Task_203_FractureDetection_Real5Fold/nnUNetTrainerV2__nnUNetPlansv2.1/fold_2/model_final_checkpoint.model',
        f'{path.path_root}/nnUnet/Models/nnUNet/3d_fullres/Task_203_FractureDetection_Real5Fold/nnUNetTrainerV2__nnUNetPlansv2.1/fold_3/model_final_checkpoint.model',
        f'{path.path_root}/nnUnet/Models/nnUNet/3d_fullres/Task_203_FractureDetection_Real5Fold/nnUNetTrainerV2__nnUNetPlansv2.1/fold_4/model_final_checkpoint.model',
    ]
    list_model_post_processing = [
        # f'{path.path_root}/Models/no_val_58/best_train_model.pkl',
        # f'{path.path_root}/Models/no_val_59/best_train_model.pkl',
        f'{path.path_root}/Models/no_val_58/best_val_model.pkl',
        f'{path.path_root}/Models/no_val_59/best_val_model.pkl',
    ]
    plan_fracture_detection = f'{path.path_root}/nnUnet/Models/nnUNet/3d_fullres/Task_203_FractureDetection_Real5Fold/nnUNetTrainerV2__nnUNetPlansv2.1/plans.pkl'
    
    predictor_1 = NNUnetCTPredictor(
        list_model_pth=list_model_C1_C7_segmentation,
        plan_file=plan_C1_C7_segmentation,
        plan_stage=-1,
        device=torch.device('cuda:0'),
        use_gaussian_for_sliding_window=True,

        patch_size=None,
        stride=None,
        tta=False,
        tta_flip_axis=(4,),

        resampling_tolerance=0.01,
        resampling_mode=sitk.sitkNearestNeighbor,
        resampling_dtype=sitk.sitkInt16,
        resampling_constance_value=-1024,

        remove_air_CT=True
    )
    predictor_2 = PredictorStage2(
        list_model_pth=list_model_fracture_detection,
        plan_file=plan_fracture_detection,
        plan_stage=-1,
        device=torch.device('cuda:0'),
        use_gaussian_for_sliding_window=True,

        patch_size=(96, 224, 224),
        stride=(96, 224, 224),
        tta=True,
        tta_flip_axis=(4,),

        resampling_tolerance=0.01,
        resampling_mode=sitk.sitkNearestNeighbor,
        resampling_dtype=sitk.sitkInt16,
        resampling_constance_value=-1024,

        remove_air_CT=False,

        save_dtype=np.float32
    )

    predictor_3 = PredictorStage3(
            list_model_pth=list_model_post_processing,
            device=torch.device('cuda:0'),
            tta=True,
            tta_flip_axis=(4,)
        )

    c2f_predictor = FractureDetector(
        predictor_stage1=predictor_1,
        predictor_stage2=predictor_2,
        predictor_stage3=predictor_3
    )

In [None]:
# The 3-D Grad-CAM implementation lives in a sibling folder. It must be on sys.path BEFORE the
# imports below; keep this line first, since that folder may also shadow other packages.
sys.path.append('../custom_grad_cam')

import matplotlib.pyplot as plt

from grad_cam_3d import GradCAM3D

from pytorch_grad_cam.utils.model_targets import BinaryClassifierOutputTarget
import cv2  # used by cam_to_voxel_space for in-plane resizing


# Grad-CAM target for the single-logit stage-3 classifier: category 1 explains the positive
# (fracture) output -- per pytorch_grad_cam's BinaryClassifierOutputTarget; confirm if upgraded
targets = [BinaryClassifierOutputTarget(1)] # BinaryClassifierOutputTarget

In [6]:
def cam_to_voxel_space(cam: np.ndarray, voxel: np.ndarray) -> np.ndarray:
    """Resample a stack of 2-D CAM slices onto the ``(length, h, w)`` grid of ``voxel``.

    The central slices are in-plane ``cv2.resize``s of the CAM, linearly blended between the
    two nearest CAM slices along the first axis. ``offset_index`` slices at each end are
    padded with the nearest end slice: ``cam[0]`` at the start and ``cam[-1]`` at the end.
    The padding is weighted 0.5 at the outermost slice and rises towards 1 inward.

    Args:
        cam: CAM volume indexed ``cam[k]`` -> 2-D slice (presumably (depth, H, W) float32).
        voxel: target volume; only its shape ``(length, h, w)`` is used.

    Returns:
        float64 array of shape ``voxel.shape``.
    """
    length, h, w = voxel.shape

    grayscale_cam_resized = np.zeros((length, h, w))
    cam_length = cam.shape[0]
    frame_interval = cam_length / length
    max_index = cam_length - 1

    # Number of padded slices at each end (about length / (2 * cam_length))
    offset_index = int((length - max_index * length / cam_length) / 2)

    if offset_index > 0:
        first_slice = cv2.resize(cam[0], (w, h))
        # The trailing padding mirrors the leading one, so it must use the LAST CAM slice
        last_slice = cv2.resize(cam[max_index], (w, h))
        for i in range(offset_index):
            weight = 0.5 + 0.5 * i / offset_index
            grayscale_cam_resized[i] = first_slice * weight
            grayscale_cam_resized[length - i - 1] = last_slice * weight

    for j in range(offset_index, length - offset_index):
        # Indices of the two CAM slices to interpolate between
        ind = j - offset_index
        idx1 = int(ind * frame_interval)
        idx2 = min(idx1 + 1, max_index)

        # Interpolation weight of idx2
        alpha = (ind * frame_interval) - idx1

        cam1 = cv2.resize(cam[idx1], (w, h))
        cam2 = cv2.resize(cam[idx2], (w, h))
        grayscale_cam_resized[j] = cv2.addWeighted(cam1, 1 - alpha, cam2, alpha, 0)

    return grayscale_cam_resized

In [10]:
# Destination of the per-case CAM .npy files. exist_ok avoids the check-then-create race, and
# the call fails right away (not after a full pipeline run) if the path exists as a file.
output_folder = OUTPUT_FOLDER
os.makedirs(output_folder, exist_ok=True)

# One row per test case; 'StudyInstanceUID' and 'image_folder' are used downstream
test_df = pd.read_csv(INPUT_TEST_CSV_FILE)
test_df.shape

(97, 3)

In [None]:
# One case per split: FractureDetector.predict only returns the ROI maps of the last case it
# processes, and one CAM volume is produced per case.
num_thread = 1
num_split = math.ceil(test_df.shape[0] / num_thread)
for split_i in range(num_split):
    cur_test_files = test_df.iloc[num_thread * split_i:num_thread * (split_i + 1)]
    # ------------------------- Stages 1-2 on this split ----------------------------- #
    roi_1, roi_2 = c2f_predictor.predict(
        test_df=cur_test_files
    )

    # Stage-3 input (fracture map + C1-C7 labels resized to 96^3) and its TTA flip variants
    stage3_input = c2f_predictor.get_stage3_input(roi_1, roi_2)
    stage3_input_tensor = predictor_3.get_input_tensor(stage3_input)

    # Average Grad-CAM over every (flip, model) pair, undoing each flip first
    gray_cam = None
    num_cams = 0
    for flip, input_tensor in stage3_input_tensor:
        for model in predictor_3.list_model:
            with GradCAM3D(model=model, target_layers=[model.encoder.encoder_4], use_cuda=torch.cuda.is_available()) as cam:
                local_cam = torch.tensor(cam(input_tensor=input_tensor, targets=targets))
                if len(flip) > 0:
                    local_cam = torch.flip(local_cam, dims=flip)
                local_cam = local_cam.numpy().squeeze(0).squeeze(0)
                gray_cam = local_cam if gray_cam is None else gray_cam + local_cam
                num_cams += 1

    if num_cams == 0:
        raise RuntimeError("No Grad-CAM maps were computed: predictor_3.list_model is empty")
    # Divide by the real number of maps (was a hard-coded 4 = 2 flips x 2 models)
    gray_cam = gray_cam / num_cams

    # Re-read the CT (predict() does not return it) to get the target voxel grid (z, y, x)
    images = c2f_predictor.read_DICOM_multi_thread(cur_test_files['image_folder'].tolist())
    nd_image = sitk.GetArrayFromImage(images[0])
    # NOTE(review): gray_cam lives in the 96^3 space of the cropped C1-C7 ROI but is stretched over
    # the whole CT volume here; the ROI bounding box is not available at this point.
    resized_cam = cam_to_voxel_space(gray_cam, nd_image)
    # NOTE(review): reverses the slice (z) order, presumably to match the CT orientation -- confirm.
    resized_cam = resized_cam[::-1, :]

    # The study UID names the output file
    series_uid = cur_test_files['StudyInstanceUID'].values[0]
    save_path = os.path.join(output_folder, f"{series_uid}.npy")
    np.save(save_path, resized_cam)

    del gray_cam, resized_cam, nd_image, images
    gc.collect()