In [3]:
import os
import sys
sys.path.append('/workspace/Documents')  ### remove this if not needed!
import numpy as np
import pandas as pd
import random
import torch
from tqdm import tqdm
import torch.backends.cudnn as cudnn
 
from cineCMR_SAM.utils.model_util import *
from cineCMR_SAM.segment_anything.model import build_model 
from cineCMR_SAM.utils.save_utils import *
from cineCMR_SAM.utils.config_util import Config

import cineCMR_SAM.inference_engine as inference_engine

import cineCMR_SAM.dataset.build_CMR_datasets as build_CMR_datasets
import cineCMR_SAM.functions_collection as ff
import cineCMR_SAM.get_args_parser as get_args_parser

### define parameters for this experiment

In [4]:
# set experiment-specific parameters
main_path = '/mnt/camca_NAS/SAM_for_CMR/'  # replace with your own path
trial_name = 'cineCMR_sam_github'
text_prompt = True # whether we need to input text prompt to specify the view types (LAX or SAX). True or False. default = True
box_prompt = False # whether we have the bounding box for myocardium defined by the user. None means no box, 'one' means one box at ED and 'two' means two boxes at ED and ES

if box_prompt == 'two':
    pretrained_model = os.path.join(main_path, 'models',trial_name, 'models/model_text_2boxes.pth')  # replace with your own model
elif box_prompt == 'one':
    pretrained_model = os.path.join(main_path, 'models',trial_name, 'models/model_text_1box.pth')  # replace with your own model
else:
    pretrained_model = os.path.join(main_path, 'models',trial_name, 'models/model_text.pth') # replace with your own model

# preload the text prompt feature (it's the output of a CLIP model when I input "LAX" or "SAX" into it)
sax_text_prompt_feature = np.load('/mnt/camca_NAS/SAM_for_CMR/data/text_prompt_clip/sax.npy')
lax_text_prompt_feature = np.load('/mnt/camca_NAS/SAM_for_CMR/data/text_prompt_clip/lax.npy')

# also define the original SAM model 
original_sam = os.path.join( "/mnt/camca_NAS/SAM_for_CMR/", 'models/pretrained_sam/sam_vit_h_4b8939.pth') # replace with your own path (you can easily download the original SAM model from online)

args = get_args_parser.get_args_parser(text_prompt = text_prompt, box_prompt = box_prompt, pretrained_model = pretrained_model, original_sam = original_sam)
args = args.parse_args([])

# some other settings
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cudnn.benchmark = True

### define our data


In [5]:
sax_or_lax = 'sax'
save_folder_name = 'predicts_'+sax_or_lax
patient_list_file = os.path.join(main_path,'models/cineCMR_sam_github/patient_list_sax.xlsx') if sax_or_lax == 'sax' else os.path.join(main_path,'models/cineCMR_sam_github/patient_list_lax.xlsx')
patient_index_list = np.arange(0,1,1)

dataset_pred = build_CMR_datasets.build_dataset(
        args,
        view_type = sax_or_lax,
        patient_list_file = patient_list_file, 
        index_list = patient_index_list, 
        text_prompt_feature = sax_text_prompt_feature if sax_or_lax == 'sax' else lax_text_prompt_feature,
        only_myo = True, 
        shuffle = False, 
        augment = False)

### predict

In [6]:
data_loader_pred = torch.utils.data.DataLoader(dataset_pred, batch_size = 1, shuffle = False, pin_memory = True, num_workers = 0)# cpu_count())

with torch.no_grad():
    with torch.cuda.amp.autocast():
        model = build_model(args, device)#skip_nameing = True, chunk = np.shape(np.zeros(0)))

        # load the pretrained model
        if args.pretrained_model is not None:
            print('loading pretrained model : ', args.pretrained_model)
            finetune_checkpoint = torch.load(args.pretrained_model)
            model.load_state_dict(finetune_checkpoint["model"])
                            
        # do the prediction for each slice (2D+T) one by one
        for data_iter_step, batch in tqdm(enumerate(data_loader_pred)):
                
            patient_id = batch["patient_id"][0]
            slice_index = batch["slice_index"].item()
            print('patient_id: ', patient_id, ' slice_index: ', slice_index)
                
            save_folder_patient = os.path.join(main_path, 'models',trial_name, save_folder_name, patient_id)
            ff.make_folder([os.path.dirname(save_folder_patient), save_folder_patient])

            batch["image"]= batch["image"].cuda()

            batch["text_prompt_feature"] = batch["text_prompt_feature"].to(torch.float32)

            bbox = batch["box_prompt"].detach().cpu().numpy()[0]
                    
            output = model(batch, args.img_size)

            torch.cuda.synchronize()
            
            inference_engine.save_predictions(view_type = sax_or_lax, batch = batch, output = output, args = args, save_folder_patient = save_folder_patient)

  with torch.cuda.amp.autocast():


Important! text prompt: True
Important! box prompt: False
loading pretrained model :  /mnt/camca_NAS/SAM_for_CMR/models/cineCMR_sam_github/models/model_text.pth


  finetune_checkpoint = torch.load(args.pretrained_model)
0it [00:00, ?it/s]

in dataset_SAX, patient_id is:  ID_0002
patient_id:  ID_0002  slice_index:  0


1it [00:02,  2.17s/it]

in dataset_SAX, patient_id is:  ID_0002
patient_id:  ID_0002  slice_index:  1


2it [00:02,  1.27s/it]

in dataset_SAX, patient_id is:  ID_0002
patient_id:  ID_0002  slice_index:  2


3it [00:03,  1.01it/s]

in dataset_SAX, patient_id is:  ID_0002
patient_id:  ID_0002  slice_index:  3


4it [00:04,  1.21it/s]

in dataset_SAX, patient_id is:  ID_0002
patient_id:  ID_0002  slice_index:  4


5it [00:04,  1.36it/s]

in dataset_SAX, patient_id is:  ID_0002
patient_id:  ID_0002  slice_index:  5


6it [00:05,  1.47it/s]

in dataset_SAX, patient_id is:  ID_0002
patient_id:  ID_0002  slice_index:  6


7it [00:05,  1.54it/s]

in dataset_SAX, patient_id is:  ID_0002
patient_id:  ID_0002  slice_index:  7


8it [00:06,  1.58it/s]

in dataset_SAX, patient_id is:  ID_0002
patient_id:  ID_0002  slice_index:  8


9it [00:06,  1.62it/s]

in dataset_SAX, patient_id is:  ID_0002
patient_id:  ID_0002  slice_index:  9


10it [00:07,  1.66it/s]

in dataset_SAX, patient_id is:  ID_0002
patient_id:  ID_0002  slice_index:  10


11it [00:08,  1.35it/s]
