In [1]:
"""
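Per-organ Dice evaluation on the validation split after organ post-processing (organ_post_process).
"""

'\nPer-organ Dice evaluation on the validation split after organ post-processing (organ_post_process).\n'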

In [1]:
# Show the working directory: the relative paths used in the configs below
# (./out, ./dataset, ./pretrained_weights) are resolved against it.
!pwd

/blue/kgong/s.kapoor/language_guided_segmentation/CLIP-Driven-Universal-Model


In [2]:
from types import SimpleNamespace
import nibabel as nib
import warnings
# NOTE(review): this silences every warning for the rest of the kernel session,
# including NumPy/PyTorch RuntimeWarnings such as divide-by-zero and deprecation
# notices — consider narrowing it to specific categories or modules.
warnings.filterwarnings("ignore")


In [3]:
from gg_tools import get_train_val_data_loader, get_train_val_txt_loader, dice_score, TEMPLATE, get_key_2, NUM_CLASS, ORGAN_NAME, organ_post_process

In [4]:
import torch
import os
import numpy as np

In [5]:
from monai.inferers import sliding_window_inference
from tqdm import tqdm

In [6]:
def dice_score_2(preds, labels, spe_sen=False, smooth=1):  # on CPU with NumPy
    """Dice coefficient between two masks of identical shape, computed with NumPy.

    Dice = 2 * sum(P * T) / (sum(P) + sum(T) + smooth).

    The masks are expected to be binary (0/1). The ``smooth`` term (default 1)
    guards against division by zero. Note that with ``smooth > 0`` two empty
    masks score 0.0 rather than 1.0.

    Args:
        preds: predicted mask, e.g. shaped (w, h, d).
        labels: ground-truth mask with exactly the same shape as ``preds``.
        spe_sen: accepted for signature compatibility only; it is currently
            ignored and only the Dice score is returned.
        smooth: constant added to the denominator (default 1, matching the
            previous hard-coded value).

    Returns:
        The Dice score as a NumPy scalar.

    Raises:
        AssertionError: if ``preds`` and ``labels`` have different shapes.
    """
    # The check compares full shapes, so the message says "shape", not "batch size".
    assert preds.shape == labels.shape, "predict & target shape don't match: %s vs %s" % (preds.shape, labels.shape)

    # Flatten so the sums run over every voxel.
    predict = preds.ravel()
    target = labels.ravel()

    # True positives: voxels that are 1 in both masks.
    tp = np.sum(predict * target)

    # Denominator: predicted + target foreground voxels, plus the smoothing term.
    den = np.sum(predict) + np.sum(target) + smooth

    return 2 * tp / den

In [7]:
# Validation with organ post-processing: sliding-window inference -> sigmoid ->
# 0.5 threshold -> organ_post_process -> per-organ Dice, reported per task.

def _format_mean_dice(dice_sum, count):
    """Format ``dice_sum / count`` as '%.4f', or 'nan' when ``count`` is zero."""
    # Guarding the division keeps the old 'nan' text without a 0/0 warning.
    return '%.4f' % (dice_sum / count if count > 0 else float('nan'))


def _write_dice_report(dice_list, file_name):
    """Write per-task and average Dice lines to ``file_name``.

    Returns a (2, NUM_CLASS) array: row 0 holds the per-organ Dice sums across
    all tasks, and row 1 holds the matching sample counts.
    """
    avg_organ_dice = np.zeros((2, NUM_CLASS))

    with open(file_name, 'w') as f:
        for key, organ_list in TEMPLATE.items():
            content = 'Task%s| ' % (key)
            for organ in organ_list:
                dice_sum = dice_list[key][0][organ-1]
                count = dice_list[key][1][organ-1]
                content += '%s: %s, ' % (ORGAN_NAME[organ-1], _format_mean_dice(dice_sum, count))
                avg_organ_dice[0][organ-1] += dice_sum
                avg_organ_dice[1][organ-1] += count
            f.write(content)
            f.write('\n')

        content = 'Average | '
        for i in range(NUM_CLASS):
            content += '%s: %s, ' % (ORGAN_NAME[i], _format_mean_dice(avg_organ_dice[0][i], avg_organ_dice[1][i]))
        f.write(content)
        f.write('\n')

    return avg_organ_dice


def validation_postprocess(model,val_loader,args):
    """Evaluate ``model`` on ``val_loader`` with organ post-processing.

    For each batch:
    1. Run MONAI ``sliding_window_inference`` with
       ``(args.roi_x, args.roi_y, args.roi_z)`` windows.
    2. Apply a sigmoid and threshold at 0.5.
    3. Apply ``organ_post_process``.
    4. Score every organ listed in ``TEMPLATE`` for the sample's task key
       with ``dice_score_2``.

    A text report is written to ``args.file_name``.

    Args:
        model: segmentation network. Called as ``model(patch, prompt)`` when
            ``args.model_type == 'film'``, otherwise as ``model(patch)``.
        val_loader: iterable of batch dicts with ``'image'``, ``'post_label'``
            and ``'name'`` keys, plus ``'prompt'`` for FiLM models. Only
            ``name[0]`` selects the task, so this presumably expects batch
            size 1 — TODO confirm.
        args: namespace providing ``device``, ``model_type``, ``roi_x/y/z`` and
            ``file_name``.

    Returns:
        np.ndarray of shape (2, NUM_CLASS). Row 0 holds the per-organ sum of
        Dice scores and row 1 the per-organ number of contributing samples.
        Divide row 0 by row 1 for the mean Dice.
    """
    model.eval()
    is_film = args.model_type == 'film'
    roi_size = (args.roi_x, args.roi_y, args.roi_z)

    # Accumulate on the CPU in NumPy float64. dice_score_2 returns NumPy scalars,
    # and this avoids mixing device tensors into the NumPy report arrays.
    dice_list = {key: np.zeros((2, NUM_CLASS)) for key in TEMPLATE.keys()}

    for batch in tqdm(val_loader):
        image = batch['image'].to(args.device)
        label = batch['post_label']
        name = batch['name']

        with torch.no_grad():
            if is_film:
                prompt = batch['prompt']
                predictor = lambda image_patch: model(image_patch, prompt)
            else:
                predictor = model
            pred = sliding_window_inference(image, roi_size, 1, predictor)
            # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
            pred_sigmoid = torch.sigmoid(pred)

        template_key = get_key_2(name[0])  # only the first name of the batch is used
        organ_list = TEMPLATE[template_key]

        # Remove singleton dims, binarise at 0.5, then post-process on the CPU.
        pred_sigmoid = torch.squeeze(pred_sigmoid)
        pred_mask = torch.where(pred_sigmoid > 0.5, 1., 0.).cpu().numpy()
        post_processed_mask = organ_post_process(pred_mask, organ_list)
        label = np.squeeze(np.array(label))

        for organ in organ_list:
            # Organ ids are 1-based; their channel index is organ - 1.
            dice_organ = dice_score_2(post_processed_mask[organ-1, :, :, :], label[organ-1, :, :, :])
            dice_list[template_key][0][organ-1] += dice_organ
            dice_list[template_key][1][organ-1] += 1

    return _write_dice_report(dice_list, args.file_name)

In [16]:
# Evaluation settings for the CLIP-driven Universal model (model_type=None ->
# validation_postprocess calls model(patch) without a prompt).
args_clip = SimpleNamespace(
    # Voxel spacing; presumably consumed by the loader transforms — TODO confirm.
    space_x=1.5,
    space_y=1.5,
    space_z=1.5,
    # Sliding-window patch size used by validation_postprocess.
    roi_x=96,
    roi_y=96,
    roi_z=96,
    num_samples=2,
    # Data locations and loader settings.
    data_root_path='/blue/kgong/s.kapoor/language_guided_segmentation/CLIP-Driven-Universal-Model/',
    data_txt_path='./dataset/dataset_list/',
    batch_size=4,
    num_workers=8,
    # Intensity window parameters for the loader (presumably a -> b rescaling).
    a_min=-175,
    a_max=250,
    b_min=0.0,
    b_max=1.0,
    dataset_list=['PAOTtest'],  # split used here to validate the model
    NUM_CLASS=NUM_CLASS,
    # Model / checkpoint.
    backbone='swinunetr',
    trans_encoding='word_embedding',
    pretrain='./out/universal_total_org/epoch_400.pth',
    lr=4e-4,
    weight_decay=1e-5,
    precomputed_prompt_path='./pretrained_weights/embeddings_template.pkl',
    word_embedding='./pretrained_weights/txt_encoding.pth',
    # Runtime.
    dist=False,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    model_type=None,
    # Output report written by validation_postprocess.
    file_name='paot_test_universal_postprocess.txt',
    os_save_fold='./not_required'
)

In [17]:
from model.Universal_model import Universal_model

In [18]:
# Build the Universal model with the same patch size used for sliding-window
# inference; weights are loaded in the next cell.
clip_model = Universal_model(
    img_size=(args_clip.roi_x, args_clip.roi_y, args_clip.roi_z),
    in_channels=1,
    out_channels=32,
    backbone=args_clip.backbone,
    encoding=args_clip.trans_encoding,
)

In [19]:
# Copy the trained weights into the freshly built Universal model.
# map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only node
# (args_clip.device falls back to CPU); the model is moved to the device later.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
clip_checkpoint = torch.load(args_clip.pretrain, map_location="cpu")
store_dict = clip_model.state_dict()
load_dict = clip_checkpoint['net']

for key, value in load_dict.items():
    # Every checkpoint key has its first dotted component dropped (presumably a
    # 'module.' prefix from DataParallel/DDP training — TODO confirm).
    name = '.'.join(key.split('.')[1:])
    if name in store_dict:
        store_dict[name] = value
    else:
        # Report checkpoint entries with no counterpart in the model.
        print(name)

In [20]:
clip_model.load_state_dict(store_dict, strict=True)  # strict (the default): any key mismatch is an error

<All keys matched successfully>

In [21]:
clip_model = clip_model.to(device=args_clip.device)  # move weights to GPU when available

In [22]:
# Only clip_val_loader is used by the cells shown; the others are kept for parity.
(clip_train_loader, clip_val_loader,
 clip_train_sampler, clip_val_sampler) = get_train_val_data_loader(args_clip)

train len 583
val len 583


In [23]:
# Runs post-processed validation and writes the report to args_clip.file_name.
# NOTE(review): validation_postprocess returns ONE (2, NUM_CLASS) array. Unpacking
# puts row 0 (per-organ Dice sums) into `universal_dice` and row 1 (per-organ
# sample counts) into `universal_dice_list`, despite the names.
# Mean per-organ Dice = universal_dice / universal_dice_list.
universal_dice, universal_dice_list = validation_postprocess(clip_model,clip_val_loader,args_clip)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 583/583 [1:06:30<00:00,  6.85s/it]


In [8]:
from model.SwinUNETR_DEEP_FILM import SwinUNETR_DEEP_FILM

In [9]:
# Evaluation settings for the SwinUNETR Deep-FiLM model (model_type='film' ->
# validation_postprocess calls model(patch, prompt)).
args_film = SimpleNamespace(
    # Voxel spacing; presumably consumed by the loader transforms — TODO confirm.
    space_x=1.5,
    space_y=1.5,
    space_z=1.5,
    # Sliding-window patch size used by validation_postprocess.
    roi_x=96,
    roi_y=96,
    roi_z=96,
    num_samples=2,
    # Data locations and loader settings.
    data_root_path='/blue/kgong/s.kapoor/language_guided_segmentation/CLIP-Driven-Universal-Model/',
    data_txt_path='./dataset/dataset_list/',
    batch_size=4,
    num_workers=8,
    # Intensity window parameters for the loader (presumably a -> b rescaling).
    a_min=-175,
    a_max=250,
    b_min=0.0,
    b_max=1.0,
    dataset_list=['PAOTtest'],
    NUM_CLASS=NUM_CLASS,
    # Model / checkpoint.
    backbone='swinunetr',
    trans_encoding='word_embedding',
    pretrain='./out/deep_film_org_setting/epoch_190.pth',
    lr=4e-4,
    weight_decay=1e-5,
    precomputed_prompt_path='embeddings_template_flare.pkl',
    # Runtime.
    dist=False,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    model_type='film',
    # Output report; the 'postporcess' spelling is kept so existing filenames stay stable.
    file_name='paot_test_film_org_setting_postporcess.txt',
    os_save_fold='./flaretrain/deep_film'
)

In [10]:
# Build the Deep-FiLM network with the same patch size used for sliding-window
# inference; weights are loaded in the next cell.
film_model = SwinUNETR_DEEP_FILM(
    img_size=(args_film.roi_x, args_film.roi_y, args_film.roi_z),
    in_channels=1,
    out_channels=32,
    precomputed_prompt_path=args_film.precomputed_prompt_path,
)

In [11]:
# Copy the trained weights into the freshly built Deep-FiLM model.
# map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only node
# (args_film.device falls back to CPU); the model is moved to the device later.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
film_checkpoint = torch.load(args_film.pretrain, map_location="cpu")
store_dict = film_model.state_dict()
load_dict = film_checkpoint['net']

for key, value in load_dict.items():
    # Every checkpoint key has its first dotted component dropped (presumably a
    # 'module.' prefix from DataParallel/DDP training — TODO confirm).
    name = '.'.join(key.split('.')[1:])
    if name in store_dict:
        store_dict[name] = value
    else:
        # Report checkpoint entries with no counterpart in the model.
        print(name)

In [12]:
film_model.load_state_dict(store_dict, strict=True)  # strict (the default): any key mismatch is an error

<All keys matched successfully>

In [13]:
film_model = film_model.to(device=args_film.device)  # move weights to GPU when available

In [None]:
# The txt loader also yields 'prompt' entries, needed by the FiLM model.
# Only film_val_loader is used by the cells shown.
(film_train_loader, film_val_loader,
 film_train_sampler, film_val_sampler) = get_train_val_txt_loader(args_film)

train len 583
val len 583


In [15]:
# Runs post-processed validation and writes the report to args_film.file_name.
# NOTE(review): validation_postprocess returns ONE (2, NUM_CLASS) array. Unpacking
# puts row 0 (per-organ Dice sums) into `film_avg_organ_dice` and row 1 (per-organ
# sample counts) into `film_dice_list`, despite the names.
# Mean per-organ Dice = film_avg_organ_dice / film_dice_list.
film_avg_organ_dice, film_dice_list = validation_postprocess(film_model,film_val_loader,args_film)

  0%|                                                                                                                                                   | 0/583 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 583/583 [57:52<00:00,  5.96s/it]
