## Data Preparation

You should prepare the following before running this step. Please refer to the `example_data/data` folder for guidance:

1. **image data**
   - you want to prepare the SAX data as a 4D array [x,y,time_frame,slice_num] saved as a nii file. in our study we sample 15 time frames as default. please refer ```example_data/data/ID_0002``` as SAX reference  
   - you want to prepare the LAX data as a 3D array [x,y,time_frame]. please refer ```example_data/data/ID_0085``` as SAX reference  

2. **A patient list** that enumerates all your cases
   - To understand the standard format, please refer to the file:  
     `example_data/Patient_list/patient_list.xlsx`
   - make sure column ***total_slice_num*** is correct for each case

4. **Text prompts** that specifies the view type
   - our model takes text prompt "SAX" or "LAX" to specify the view type 
   - we use "CLIP" model to embed text prompts (code: ```dataset/CMR/clip_extractor.ipynb```)
   - we have prepared the embedded feature in `example_data/data/text_prompt_clip`, please download to your local

5. **Box prompts** that indicates the location of myocardium
   - in the prediction you need to define box prompts manually by yourself if you want to use this feature
   - we prepare examplar bounding box ```example_data/data/ID_0002/bounding_box.npy``` which saves the bounding box as a 4D array [f,s,2,4] where f is the number of cases, s is the slice num in each case, 2 refers to ED and ES, 4 refers to at each frame the definition of [xmin, ymin, xmax, ymax] of the bounding box. 
   - If you don't define the box, the model will just pass None as box prompt. 


---

### Docker environment
Please use `docker`, it will build a pytorch-based container


In [1]:
import os
import sys
sys.path.append('/workspace/Documents')  ### remove this if not needed!
import numpy as np
import pandas as pd
import random
import torch
from tqdm import tqdm
import torch.backends.cudnn as cudnn
 
from cineCMR_SAM.utils.model_util import *
from cineCMR_SAM.segment_anything.model import build_model 
from cineCMR_SAM.utils.save_utils import *
from cineCMR_SAM.utils.config_util import Config

import cineCMR_SAM.inference_engine as inference_engine

import cineCMR_SAM.dataset.build_CMR_datasets as build_CMR_datasets
import cineCMR_SAM.functions_collection as ff
import cineCMR_SAM.get_args_parser as get_args_parser

main_path = '/mnt/camca_NAS/SAM_for_CMR/'  # replace with your own path

100%|████████████████████████████████████████| 338M/338M [00:02<00:00, 128MiB/s]


### define parameters for this experiment

In [11]:
# set experiment-specific parameters
trial_name = 'cineCMR_sam_trial' 

output_dir = os.path.join(main_path, 'example_data/models', trial_name)
ff.make_folder([os.path.join(main_path, 'example_data/models'), output_dir])

text_prompt = True # whether we need to input text prompt to specify the view types (LAX or SAX). True or False. default = True
box_prompt = False # whether we have the bounding box for myocardium defined by the user. False means no box, 'one' means one box at ED and 'two' means two boxes at ED and ES

# define trained model
pretrained_model = os.path.join(main_path, 'example_data/models',trial_name,'models/model-sax.pth')  # replace with your own path

In [12]:
# default
# preload the text prompt feature 
sax_text_prompt_feature = np.load(os.path.join(main_path,'example_data/data/text_prompt_clip/sax.npy'))
lax_text_prompt_feature = np.load(os.path.join(main_path,'example_data/data/text_prompt_clip/lax.npy'))

# define the original SAM model
original_sam = os.path.join( main_path, 'example_data/pretrained_sam/sam_vit_h_4b8939.pth') 

args = get_args_parser.get_args_parser(text_prompt = text_prompt, 
                                       box_prompt = box_prompt, 
                                       pretrained_model = pretrained_model, 
                                       original_sam = original_sam, )
args = args.parse_args([])

# some other settings
cfg = Config(args.config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cudnn.benchmark = True

### define data


In [14]:
sax_or_lax = 'sax'
save_folder_name = 'predicts_'+sax_or_lax
patient_list_file = os.path.join(main_path,'example_data/data/Patient_list/patient_list_sax.xlsx') if sax_or_lax == 'sax' else os.path.join(main_path,'example_data/data/Patient_list/patient_list_lax.xlsx')
patient_index_list = np.arange(0,1,1)

dataset_pred = build_CMR_datasets.build_dataset(
        args,
        view_type = sax_or_lax,
        patient_list_file = patient_list_file, 
        index_list = patient_index_list, 
        text_prompt_feature = sax_text_prompt_feature if sax_or_lax == 'sax' else lax_text_prompt_feature,
        only_myo = True, 
        shuffle = False, 
        augment = False)

### predict

In [15]:
data_loader_pred = torch.utils.data.DataLoader(dataset_pred, batch_size = 1, shuffle = False, pin_memory = True, num_workers = 0)

with torch.no_grad():
    with torch.cuda.amp.autocast():
        model = build_model(args, device)#skip_nameing = True, chunk = np.shape(np.zeros(0)))

        # load the pretrained model
        if args.pretrained_model is not None:
            print('loading pretrained model : ', args.pretrained_model)
            finetune_checkpoint = torch.load(args.pretrained_model)
            model.load_state_dict(finetune_checkpoint["model"])
                            
        # do the prediction for each slice (2D+T) one by one
        for data_iter_step, batch in tqdm(enumerate(data_loader_pred)):
                
            patient_id = batch["patient_id"][0]
            slice_index = batch["slice_index"].item()
            print('patient_id: ', patient_id, ' slice_index: ', slice_index)
                
            save_folder_patient = os.path.join(main_path, 'example_data/models',trial_name, save_folder_name, patient_id)
            ff.make_folder([os.path.dirname(save_folder_patient), save_folder_patient])

            batch["image"]= batch["image"].cuda()

            batch["text_prompt_feature"] = batch["text_prompt_feature"].to(torch.float32)

            bbox = batch["box_prompt"].detach().cpu().numpy()[0]
                    
            output = model(batch, args.img_size)

            torch.cuda.synchronize()
            
            inference_engine.save_predictions(view_type = sax_or_lax, batch = batch, output = output, args = args, save_folder_patient = save_folder_patient)

  with torch.cuda.amp.autocast():


Important! text prompt: True
Important! box prompt: True
loading pretrained model :  /mnt/camca_NAS/SAM_for_CMR/example_data/models/cineCMR_sam_trial/models/model-62.pth


  finetune_checkpoint = torch.load(args.pretrained_model)
0it [00:00, ?it/s]

patient_id:  ID_0002  slice_index:  0


1it [00:01,  1.41s/it]

patient_id:  ID_0002  slice_index:  1


2it [00:01,  1.10it/s]

patient_id:  ID_0002  slice_index:  2


3it [00:02,  1.33it/s]

patient_id:  ID_0002  slice_index:  3


4it [00:03,  1.48it/s]

patient_id:  ID_0002  slice_index:  4


5it [00:03,  1.58it/s]

patient_id:  ID_0002  slice_index:  5


6it [00:04,  1.65it/s]

patient_id:  ID_0002  slice_index:  6


7it [00:04,  1.62it/s]

patient_id:  ID_0002  slice_index:  7


8it [00:05,  1.67it/s]

patient_id:  ID_0002  slice_index:  8


9it [00:05,  1.70it/s]

patient_id:  ID_0002  slice_index:  9


10it [00:06,  1.73it/s]

patient_id:  ID_0002  slice_index:  10


11it [00:07,  1.56it/s]
