In [1]:
# Run inference on the FLARE 2021 test set and save the predictions so they can be scored on the FLARE21 grand-challenge site.

In [15]:
from types import SimpleNamespace
import nibabel as nib
import warnings
warnings.filterwarnings("ignore")

In [16]:
from gg_tools import get_test_txt_loader, get_test_data_loader, dice_score, TEMPLATE, get_key, NUM_CLASS, ORGAN_NAME, merge_label_v1, save_result, organ_post_process

In [17]:
import torch
import os
import numpy as np

In [18]:
from monai.inferers import sliding_window_inference
from tqdm import tqdm

In [19]:
def test(model, test_loader, test_transform, args, post_process=False, threshold=0.5, sw_batch_size=1):
    """Run sliding-window inference over ``test_loader`` and save one merged label map per case.

    For every batch the model output goes through the following steps:
    - passed through a sigmoid;
    - binarised at ``threshold``;
    - optionally post-processed per organ with ``organ_post_process``;
    - merged into a single label volume with ``merge_label_v1``;
    - written to disk with ``save_result``.

    Args:
        model: segmentation network. It is called as ``model(patch)``, or as
            ``model(patch, prompt)`` when ``args.model_type == 'film'``.
        test_loader: iterable of dict batches with ``'image'`` and ``'name'`` keys,
            plus ``'prompt'`` for the FiLM model. Only ``name[0]`` is used, so this
            assumes a batch size of 1 (TODO confirm).
        test_transform: transform object forwarded unchanged to ``save_result``.
        args: namespace providing ``model_type``, ``device``, ``roi_x``, ``roi_y``,
            ``roi_z`` and ``os_save_fold``.
        post_process: if True, apply ``organ_post_process`` to the binary masks.
        threshold: sigmoid probability at or above which a voxel is foreground
            (default 0.5).
        sw_batch_size: number of windows per forward pass in
            ``sliding_window_inference`` (default 1).
    """
    model.eval()
    roi_size = (args.roi_x, args.roi_y, args.roi_z)
    is_film = args.model_type == 'film'

    for batch in tqdm(test_loader):
        image, name = batch['image'].to(args.device), batch['name']

        with torch.no_grad():
            if is_film:
                prompt = batch['prompt']
                # Bind this batch's prompt so the sliding-window helper can call a one-argument predictor.
                predictor = lambda image_patch: model(image_patch, prompt)
            else:
                predictor = model
            pred = sliding_window_inference(image, roi_size, sw_batch_size, predictor)
            pred_sigmoid = torch.sigmoid(pred)

        # Squeeze singleton dims, threshold to a {0, 1} uint8 mask and move to numpy for post-processing.
        pred_mask = (torch.squeeze(pred_sigmoid) >= threshold).to(torch.uint8).cpu().numpy()

        template_key = get_key(name[0])  # only the first item of the batch is used
        organ_list = TEMPLATE[template_key]

        if post_process:
            pred_mask = organ_post_process(pred_mask, organ_list)

        pred_mask_merged = merge_label_v1(pred_mask, name[0]).astype(np.uint8)
        # Store the merged mask with a leading singleton axis in batch['result'] before calling
        # save_result (presumably save_result reads it from there - verify against gg_tools).
        batch['result'] = torch.from_numpy(np.expand_dims(pred_mask_merged, axis=0))

        # The output sub-folder is the second '/'-separated component of the case name
        # (after dropping everything from the first '.').
        file_name = name[0].split('.')[0]
        subfold_path = file_name.split('/')[1]
        save_dir = os.path.join(args.os_save_fold, subfold_path)
        save_result(batch, test_transform, save_dir)

In [57]:
# Inference configuration for the CLIP-driven Universal Model on the FLARE21 test split.
# data_root_path defaults to the original cluster path but can be overridden with the
# FLARE_DATA_ROOT environment variable instead of editing the notebook.
args_clip = SimpleNamespace(
    space_x = 1.5,
    space_y = 1.5,
    space_z = 1.5,
    # sliding-window ROI size used by test() and as the model's img_size
    roi_x = 96,
    roi_y = 96,
    roi_z = 96,
    num_samples = 2,
    data_root_path = os.environ.get(
        'FLARE_DATA_ROOT',
        '/blue/kgong/s.kapoor/language_guided_segmentation/CLIP-Driven-Universal-Model/',
    ),
    data_txt_path = './dataset/dataset_list/',
    batch_size = 4,
    num_workers = 8,
    a_min = -175,
    a_max = 250,
    b_min = 0.0,
    b_max = 1.0,
    dataset_list = ['flaretest'],  # split evaluated by this notebook
    NUM_CLASS = NUM_CLASS,
    backbone = 'swinunetr',
    trans_encoding = 'word_embedding',
    pretrain = './swinunetr.pth',  # checkpoint loaded into the Universal Model below
    lr = 4e-4,
    weight_decay = 1e-5,
    precomputed_prompt_path = './embeddings_template_flare.pkl',
    word_embedding = './pretrained_weights/txt_encoding.pth',
    dist = False,
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    model_type = None,  # not 'film': test() calls the model without a prompt
    file_name = 'universal_flaretest.txt',
    os_save_fold = './flaretest/universal_model'  # root folder for saved predictions
)

In [58]:
from model.Universal_model import Universal_model

In [59]:
# Build the CLIP-driven Universal Model; img_size matches the sliding-window ROI used in test().
model = Universal_model(
    encoding=args_clip.trans_encoding,
    backbone=args_clip.backbone,
    in_channels=1,
    out_channels=32,  # NOTE(review): presumably equal to NUM_CLASS - confirm
    img_size=(args_clip.roi_x, args_clip.roi_y, args_clip.roi_z),
)

In [60]:
# Load the authors' released weights and remap their key names onto this model's state dict.
# map_location='cpu': the model is still on CPU at this point (it is moved to args_clip.device
# in a later cell), and this also lets a GPU-saved checkpoint load on a CPU-only machine.
checkpoint = torch.load(args_clip.pretrain, map_location='cpu')
store_dict = model.state_dict()
load_dict = checkpoint['net']
for key, value in load_dict.items():
    # Remove the 'module.' wrapper prefix, only when it is actually present.
    if key.startswith('module.'):
        key = key[len('module.'):]
    # In this implementation the SwinUNETR sub-modules sit under `backbone`.
    if 'swinViT' in key or 'encoder' in key or 'decoder' in key:
        key = 'backbone.' + key
    if key in store_dict:
        store_dict[key] = value
    else:
        print('unmatched checkpoint key:', key)
model.load_state_dict(store_dict)

<All keys matched successfully>

In [61]:
model = model.to(device=args_clip.device)  # GPU when available, otherwise CPU

In [62]:
(clip_loader, clip_transform) = get_test_data_loader(args_clip)  # NOTE(review): the FiLM run uses get_test_txt_loader - confirm both read the same cases

test len 90


In [63]:
test(model, clip_loader, clip_transform, args_clip, post_process=False)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [11:51<00:00,  7.91s/it]


In [21]:
# Inference configuration for the SwinUNETR deep-FiLM model on the FLARE21 test split.
# data_root_path defaults to the original cluster path but can be overridden with the
# FLARE_DATA_ROOT environment variable instead of editing the notebook.
args_film = SimpleNamespace(
    space_x = 1.5,
    space_y = 1.5,
    space_z = 1.5,
    # sliding-window ROI size used by test() and as the model's img_size
    roi_x = 96,
    roi_y = 96,
    roi_z = 96,
    num_samples = 2,
    data_root_path = os.environ.get(
        'FLARE_DATA_ROOT',
        '/blue/kgong/s.kapoor/language_guided_segmentation/CLIP-Driven-Universal-Model/',
    ),
    data_txt_path = './dataset/dataset_list/',
    batch_size = 4,
    num_workers = 8,
    a_min = -175,
    a_max = 250,
    b_min = 0.0,
    b_max = 1.0,
    dataset_list = ['flaretest'],
    NUM_CLASS = NUM_CLASS,
    backbone = 'swinunetr',
    trans_encoding = 'word_embedding',
    pretrain = './out/deep_film_total/epoch_380.pth',  # checkpoint loaded into film_model below
    lr = 4e-4,
    weight_decay = 1e-5,
    precomputed_prompt_path = 'embeddings_template_flare.pkl',  # passed to SwinUNETR_DEEP_FILM
    dist = False,
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    model_type='film',  # makes test() call the model with the batch prompt
    file_name='flare_test_test.txt',
    os_save_fold = './flaretest/deep_film_model'  # root folder for saved predictions
)

In [22]:
from model.SwinUNETR_DEEP_FILM import SwinUNETR_DEEP_FILM

In [23]:
# Build the deep-FiLM SwinUNETR; img_size matches the sliding-window ROI used in test().
film_model = SwinUNETR_DEEP_FILM(
    precomputed_prompt_path=args_film.precomputed_prompt_path,
    out_channels=32,
    in_channels=1,
    img_size=(args_film.roi_x, args_film.roi_y, args_film.roi_z),
)

In [24]:
# Load the FiLM checkpoint onto CPU (film_model is moved to args_film.device later); this also
# allows a GPU-saved checkpoint to be loaded on a CPU-only machine.
film_checkpoint = torch.load(args_film.pretrain, map_location='cpu')
store_dict = film_model.state_dict()
load_dict = film_checkpoint['net']

In [25]:
# Copy checkpoint tensors into film_model's state dict. Each key loses its first dotted
# component (presumably a 'module.' wrapper prefix - confirm); this applies to every key,
# backbone or not.
for key, value in load_dict.items():
    new_key = '.'.join(key.split('.')[1:])
    if new_key in store_dict:
        store_dict[new_key] = value
    else:
        print('unmatched checkpoint key:', new_key)

In [26]:
film_model.load_state_dict(store_dict, strict=True)  # strict: every model key must be present

<All keys matched successfully>

In [27]:
film_model = film_model.to(device=args_film.device)  # GPU when available, otherwise CPU

In [28]:
(film_loader, film_transform) = get_test_txt_loader(args_film)

test len 90


In [29]:
test(film_model, film_loader, film_transform, args_film, post_process=False)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [11:05<00:00,  7.39s/it]


In [30]:
def remove_0000(folder_path, suffix='_0000', ext='.nii.gz'):
    """Strip a trailing ``suffix`` tag from file names in ``folder_path``.

    Each ``<case><suffix><ext>`` is renamed to ``<case><ext>`` (for example
    ``FLARETs_0001_0000.nii.gz`` -> ``FLARETs_0001.nii.gz``). Only the tag that
    immediately precedes the extension is removed, so a ``_0000`` elsewhere in the
    case id is left intact. Files not ending in ``suffix + ext`` are skipped, which
    makes re-running the function a no-op.

    Args:
        folder_path: directory whose files are renamed in place.
        suffix: trailing tag to remove (default ``'_0000'``).
        ext: extension the files must end with (default ``'.nii.gz'``).
    """
    tagged_ext = suffix + ext
    for filename in tqdm(os.listdir(folder_path)):
        if not filename.endswith(tagged_ext):
            continue
        # Slice by explicit length so an empty tagged_ext cannot produce filename[:-0] == ''.
        new_filename = filename[:len(filename) - len(tagged_ext)] + ext
        if new_filename == filename:
            continue
        os.rename(os.path.join(folder_path, filename), os.path.join(folder_path, new_filename))

In [64]:
remove_0000(folder_path=os.path.join(args_clip.os_save_fold, 'imagesTs'))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 2517.14it/s]


In [31]:
remove_0000(folder_path=os.path.join(args_film.os_save_fold, 'imagesTs'))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 2562.36it/s]
