# set working directory

In [None]:
import os

import hydra
import omegaconf
import pyrootutils

# Resolve the project root from this notebook's path (pythonpath=True also puts it
# on sys.path, per pyrootutils) and make it the working directory so relative
# paths such as "logs" resolve against it.
root = pyrootutils.setup_root(os.path.abspath("inherently_interpretable_model.ipynb"), pythonpath=True)

os.chdir(root)

# set python path

In [None]:
import sys

# Make the project's ``src`` directory importable (presumably where the
# ``MONET`` package imported in the next cell lives).
sys.path.append(os.fspath(root.joinpath("src")))

# import packages

In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import tqdm
from scipy.special import softmax
from scipy.stats import norm, pearsonr
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from sklearn.model_selection import train_test_split
from torchvision import models, transforms

import clip
from MONET.datamodules.multiplex_datamodule import MultiplexDatamodule
from MONET.utils.loader import custom_collate_per_key, dataloader_apply_func
from MONET.utils.static import (
    concept_to_prompt,
    fitzpatrick17k_disease_label,
    fitzpatrick17k_ninelabel,
    fitzpatrick17k_threelabel,
    skincon_cols,
)
from MONET.utils.text_processing import generate_prompt_token_from_concept
from MONET.utils.io import load_pkl
from PIL import Image

In [None]:
def wandb_to_exppath(wandb, log_path="/gscratch/cse/chanwkim/MONET_log/train/runs"):
    """Return the experiment directory under *log_path* that holds wandb run *wandb*.

    Each experiment directory may contain a ``wandb`` folder; an experiment
    matches when any entry in that folder starts with ``"run"`` and its last
    8 characters equal *wandb*.

    Args:
        wandb: wandb run id (8 characters, e.g. ``"15eh81uv"``).
        log_path: directory containing one sub-directory per experiment.

    Returns:
        ``Path`` of the matching experiment directory.

    Raises:
        RuntimeError: if no experiment under *log_path* contains the run.
    """
    log_path = Path(log_path)
    for experiment in os.listdir(log_path):
        wandb_dir = log_path / experiment / "wandb"
        if not wandb_dir.is_dir():
            continue
        # Experiments whose wandb folder has no ``run*`` entry are skipped
        # rather than crashing on an empty list.
        run_entries = [name for name in os.listdir(wandb_dir) if name.startswith("run")]
        if any(name[-8:] == wandb for name in run_entries):
            return log_path / experiment
    raise RuntimeError(f"wandb run {wandb!r} not found under {log_path}")


# exppath = wandb_to_exppath(
#     wandb="15eh81uv", log_path="/projects/leelab2/chanwkim/dermatology_datasets/logs/train/runs"
# )
# print([exppath / "checkpoints" / ckpt for ckpt in os.listdir(exppath / "checkpoints/")])

In [None]:
# Root for cached artifacts such as pre-computed image features. The path is
# relative, so it resolves against the working directory set by os.chdir(root).
log_dir = Path("logs")

In [None]:
# IPython shell escape: show current GPU usage (needs the ``gpustat`` CLI),
# e.g. to choose ``model_device`` below.
!gpustat

# Initialize Model

In [None]:
model_name = "zt0n2xd0"
model_device = "cuda:7"

# Build the contrastive image/text model from its Hydra config with a ViT-L/14
# backbone on ``model_device``. Fine-tuned weights for ``model_name`` are loaded
# in the next cell.
cfg_model = omegaconf.OmegaConf.load(root / "configs" / "model" / "contrastive.yaml")
cfg_model.net.model_name_or_path = "ViT-L/14"
cfg_model.net.device = model_device

# nn.Module.to returns the module itself, so the call can be chained.
model = hydra.utils.instantiate(cfg_model).to(model_device)
model.eval()

In [None]:
# Fine-tuned checkpoints, keyed by wandb run id.
model_path_dir = dict(
    zt0n2xd0="/projects/leelab2/chanwkim/dermatology_datasets/logs/train/runs/2023-01-17_20-58-15/checkpoints/last.ckpt",
)

# The name "ViT-L/14" means "keep the backbone weights as instantiated"; any other
# model_name is looked up above and its checkpoint's ``state_dict`` is loaded.
# NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
if model_name != "ViT-L/14":
    model_path = model_path_dir[model_name]
    loaded = torch.load(model_path, map_location=model_device)
    model.load_state_dict(loaded["state_dict"])
    print("loaded")

In [None]:
model_name = "zt0n2xd0"
model_device = "cuda:7"

# A second instance of the same architecture that is not given the fine-tuned
# checkpoint (presumably the pretrained ViT-L/14 baseline).
cfg_model = omegaconf.OmegaConf.load(root / "configs" / "model" / "contrastive.yaml")
cfg_model.net.model_name_or_path = "ViT-L/14"
cfg_model.net.device = model_device

# nn.Module.to returns the module itself, so the call can be chained.
model_vanilla = hydra.utils.instantiate(cfg_model).to(model_device)
model_vanilla.eval()

In [None]:
# Per-dataset state: variable_dict[dataset_name] accumulates the dataloader,
# features, labels and concept lists produced by the cells below.
variable_dict = dict()

In [None]:
# Notebook dataset name -> ``dataset_name_test`` value for the multiplex datamodule.
_DATASET_NAME_TEST = {
    "clinical_fd_clean_nodup_nooverlap": "clinical_fd_clean_nodup_nooverlap=all",
    "derm7pt_derm_nodup": "derm7pt_derm_nodup=all",
    "derm7pt_clinical_nodup": "derm7pt_clinical_nodup=all",
    "derm7pt": "derm7pt=all",
    "allpubmedtextbook": "pubmed=all,textbook=all",
    "isic_nodup_nooverlap": "isic_nodup_nooverlap=all",
}


def setup_dataloader(dataset_name, *, data_dir="/sdata/chanwkim/dermatology_datasets", split_seed=42):
    """Build the test dataloader for one of the known evaluation datasets.

    The datamodule config is loaded from ``configs/datamodule/multiplex.yaml``;
    only ``data_dir``, ``dataset_name_test`` and ``split_seed`` are overridden.

    Args:
        dataset_name: one of the keys of ``_DATASET_NAME_TEST``.
        data_dir: root directory of the datasets (previously hard-coded default).
        split_seed: seed passed to the datamodule's split.

    Returns:
        ``{"dataloader": dm.test_dataloader()}``.

    Raises:
        ValueError: if ``dataset_name`` is not a known dataset.
    """
    try:
        dataset_name_test = _DATASET_NAME_TEST[dataset_name]
    except KeyError:
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; expected one of {sorted(_DATASET_NAME_TEST)}"
        ) from None

    cfg_dm = omegaconf.OmegaConf.load(root / "configs" / "datamodule" / "multiplex.yaml")
    cfg_dm.data_dir = data_dir
    cfg_dm.dataset_name_test = dataset_name_test
    cfg_dm.split_seed = split_seed

    dm = hydra.utils.instantiate(cfg_dm)
    dm.setup()

    return {"dataloader": dm.test_dataloader()}

In [None]:
# Attach a test dataloader to each evaluation dataset's entry.
for dataset_name in ("clinical_fd_clean_nodup_nooverlap", "derm7pt_derm_nodup", "isic_nodup_nooverlap"):
    entry = variable_dict.setdefault(dataset_name, {})
    entry.update(setup_dataloader(dataset_name))

In [None]:
def batch_func(batch):
    """Embed one batch with both the fine-tuned and the vanilla model.

    Moves ``batch["image"]`` to ``model_device`` (mutating *batch*), runs each
    model's ``model_step_with_image`` without gradient tracking, and returns the
    features on CPU together with ``batch["metadata"]``.
    """
    with torch.no_grad():
        batch["image"] = batch["image"].to(model_device)
        features = {
            out_key: net.model_step_with_image(batch)["image_features"]
            for out_key, net in (("image_features", model), ("image_features_vanilla", model_vanilla))
        }
    result = {out_key: feats.detach().cpu() for out_key, feats in features.items()}
    result["metadata"] = batch["metadata"]
    return result

def setup_features(dataset_name, dataloader):
    """Return fine-tuned and vanilla image embeddings plus metadata for one dataset.

    For ``"isic_nodup_nooverlap"`` the embeddings are read from pre-computed files
    under ``log_dir / "image_features"`` and ``dataloader`` is ignored. For every
    other dataset they are computed by running ``batch_func`` over ``dataloader``.

    Returns:
        dict with ``"image_features"`` and ``"image_features_vanilla"`` (CPU
        tensors) and ``"metadata_all"``.

    Raises:
        RuntimeError: if the cached fine-tuned and vanilla ISIC files do not share
            the same metadata index, i.e. their rows are not aligned.
    """
    if dataset_name == "isic_nodup_nooverlap":
        feature_dir = log_dir / "image_features"
        # NOTE(review): torch.load unpickles; only point it at trusted files.
        finetuned = torch.load(feature_dir / "isic_nodup_nooverlap.pt", map_location="cpu")
        vanilla = torch.load(feature_dir / "isic_nodup_nooverlap_vanilla.pt", map_location="cpu")
        image_features = finetuned["image_features"].cpu()
        image_features_vanilla = vanilla["image_features_vanilla"].cpu()
        metadata_all = finetuned["metadata_all"]
        # An explicit check, not ``assert`` (which ``python -O`` strips).
        # ``Index.equals`` returns False on a length mismatch instead of raising
        # like element-wise ``==`` does.
        if not metadata_all.index.equals(vanilla["metadata_all"].index):
            raise RuntimeError(
                "Cached ISIC fine-tuned and vanilla features are not row-aligned "
                "(metadata indices differ)."
            )
    else:
        loader_applied = dataloader_apply_func(
            dataloader=dataloader,
            func=batch_func,
            collate_fn=custom_collate_per_key,
        )
        image_features = loader_applied["image_features"].cpu()
        image_features_vanilla = loader_applied["image_features_vanilla"].cpu()
        metadata_all = loader_applied["metadata"]

    return {
        "image_features": image_features,
        "image_features_vanilla": image_features_vanilla,
        "metadata_all": metadata_all,
    }

# def setup_features(dataset_name, dataloader):

#     loader_applied = dataloader_apply_func(
#         dataloader=dataloader,
#         func=batch_func,
#         collate_fn=custom_collate_per_key,
#     )
#     image_features = loader_applied["image_features"].cpu()
#     image_features_vanilla = loader_applied["image_features_vanilla"].cpu()
#     metadata_all = loader_applied["metadata"]

#     return {"image_features":image_features, 
#             "image_features_vanilla":image_features_vanilla,
#             "metadata_all": metadata_all}
# Compute (or load) image embeddings for each evaluation dataset.
for dataset_name in ("clinical_fd_clean_nodup_nooverlap", "derm7pt_derm_nodup", "isic_nodup_nooverlap"):
    print(dataset_name)
    entry = variable_dict[dataset_name]
    entry.update(setup_features(dataset_name, entry["dataloader"]))

In [None]:
def normalize_embedding(dataset_name, image_features):
    """Scale ``image_features`` to unit L2 norm along dim 1 (each row of a 2-D tensor).

    ``dataset_name`` is accepted for symmetry with the other per-dataset helpers
    and is not used.

    Returns:
        ``{"image_features_norm": image_features / norm(image_features, dim=1)}``.
    """
    row_norms = image_features.norm(dim=1, keepdim=True)
    return {"image_features_norm": image_features.div(row_norms)}

In [None]:
# Store unit-normalised copies of both embedding sets: vanilla first, then fine-tuned.
for dataset_name in ("clinical_fd_clean_nodup_nooverlap", "derm7pt_derm_nodup", "isic_nodup_nooverlap"):
    entry = variable_dict[dataset_name]
    for raw_key, norm_key in (
        ("image_features_vanilla", "image_features_vanilla_norm"),
        ("image_features", "image_features_norm"),
    ):
        entry[norm_key] = normalize_embedding(dataset_name, entry[raw_key])["image_features_norm"]

In [None]:
# Display the keys stored for ``dataset_name``, the loop variable left over from
# the previous cell (i.e. the last dataset that loop visited).
variable_dict[dataset_name].keys()

In [None]:
# ISIC ``diagnosis`` -> benign/malignant/indeterminate, used when a record has no
# explicit ``benign_malignant`` value.
isic_diagnosis_malignant_mapping = {
    'AIMP': 'indeterminate',
    'acrochordon': 'benign',
    'actinic keratosis': 'benign',
    'angiofibroma or fibrous papule': 'benign',
    'angiokeratoma': 'benign',
    'angioma': 'benign',
    'atypical melanocytic proliferation': 'indeterminate',
    'atypical spitz tumor': 'indeterminate',
    'basal cell carcinoma': 'malignant',
    'cafe-au-lait macule': 'benign',
    'clear cell acanthoma': 'benign',
    'dermatofibroma': 'benign',
    'lentigo NOS': 'benign',
    'lentigo simplex': 'benign',
    'lichenoid keratosis': 'benign',
    'melanoma': 'malignant',
    'melanoma metastasis': 'malignant',
    'neurofibroma': 'benign',
    'nevus': 'benign',
    'other': 'indeterminate',
    'pigmented benign keratosis': 'benign',  # NOTE(review): flagged as uncertain; confirm
    'scar': 'benign',
    'seborrheic keratosis': 'benign',
    'solar lentigo': 'benign',
    'squamous cell carcinoma': 'malignant',
    'vascular lesion': 'benign',
    'verruca': 'benign',
}


def isic_map_diagnosis_malignant(diagnosis, benign_malignant):
    """Return the benign/malignant status for one ISIC record.

    An explicit ``benign_malignant`` string wins. Otherwise the status comes from
    ``isic_diagnosis_malignant_mapping[diagnosis]``. A missing diagnosis
    (NaN/None) maps to ``"indeterminate"``.

    Raises:
        RuntimeError: if ``diagnosis`` is present but not in the mapping.
    """
    if isinstance(benign_malignant, str):
        return benign_malignant
    if diagnosis in isic_diagnosis_malignant_mapping:
        return isic_diagnosis_malignant_mapping[diagnosis]
    # pd.isna (unlike np.isnan) accepts strings and None, so an unmapped
    # diagnosis reaches the RuntimeError below instead of a TypeError.
    if pd.isna(diagnosis):
        return "indeterminate"
    raise RuntimeError(f"Unmapped ISIC diagnosis: {diagnosis!r}")

In [None]:
# derm7pt ``diagnosis`` -> benign/malignant/unknown.
derm7pt_diagnosis_malignant_mapping = {
    'basal cell carcinoma': 'malignant',
    'blue nevus': 'benign',
    'clark nevus': 'benign',
    'combined nevus': 'benign',
    'congenital nevus': 'benign',
    'dermal nevus': 'benign',
    'dermatofibroma': 'benign',
    'lentigo': 'benign',
    'melanoma (in situ)': 'malignant',
    'melanoma (less than 0.76 mm)': 'malignant',
    'melanoma (0.76 to 1.5 mm)': 'malignant',
    'melanoma (more than 1.5 mm)': 'malignant',
    'melanoma metastasis': 'malignant',
    'melanosis': 'benign',
    'miscellaneous': 'unknown',
    'recurrent nevus': 'benign',
    'reed or spitz nevus': 'benign',
    'seborrheic keratosis': 'benign',
    'vascular lesion': 'benign',
    'melanoma': 'malignant',
}


def derm7pt_map_diagnosis_malignant(diagnosis):
    """Return the benign/malignant status for one derm7pt diagnosis.

    Mapped diagnoses come from ``derm7pt_diagnosis_malignant_mapping``. A missing
    diagnosis (NaN/None) maps to ``"indeterminate"``.

    Raises:
        RuntimeError: if ``diagnosis`` is present but not in the mapping.
    """
    if diagnosis in derm7pt_diagnosis_malignant_mapping:
        return derm7pt_diagnosis_malignant_mapping[diagnosis]
    # pd.isna (unlike np.isnan) accepts strings and None, so an unmapped
    # diagnosis reaches the RuntimeError below instead of a TypeError.
    if pd.isna(diagnosis):
        return "indeterminate"
    raise RuntimeError(f"Unmapped derm7pt diagnosis: {diagnosis!r}")

In [None]:
# DDI disease label -> coarse group for the melanoma task: "melanoma" for the
# melanoma labels, "melanoma look-alike" for the rest. Insertion order follows
# the listing below.
ddi_map = {
    disease: ("melanoma" if is_melanoma else "melanoma look-alike")
    for disease, is_melanoma in [
        ("acral-melanotic-macule", False),
        ("atypical-spindle-cell-nevus-of-reed", False),
        ("benign-keratosis", False),
        ("blue-nevus", False),
        ("dermatofibroma", False),
        ("dysplastic-nevus", False),
        ("epidermal-nevus", False),
        ("hyperpigmentation", False),
        ("keloid", False),
        ("inverted-follicular-keratosis", False),
        ("melanocytic-nevi", False),
        ("melanoma", True),
        ("melanoma-acral-lentiginous", True),
        ("melanoma-in-situ", True),
        ("nevus-lipomatosus-superficialis", False),
        ("nodular-melanoma-(nm)", True),
        ("pigmented-spindle-cell-nevus-of-reed", False),
        ("seborrheic-keratosis", False),
        ("seborrheic-keratosis-irritated", False),
        ("solar-lentigo", False),
    ]
}

In [None]:
# Display the DDI disease labels covered by ddi_map.
ddi_map.keys()

In [None]:
# Generic "normal skin" terms, the default reference for concepts without a
# dedicated opposite.
normal_skin = ['clean', "smooth", 'Healthy', 'normal', 'soft', 'flat']

# Concept -> reference ("opposite") terms, presumably used as negative prompts.
# At the end, keys/terms without a ``skincon_``/``derm7ptconcept_`` namespace
# get a ``cbm_`` prefix.
concept_reference_dict = {
    "Asymmetry": ["Symmetry", "Regular", "Uniform"],
    "Irregular": ["Regular", "Smooth"],
    "Black": ["White", "Creamy", "Colorless", "Unpigmented"],
    "Blue": ["Green", "Red"],
    "White": ["Black", "Colored", "Pigmented"],
    "Brown": ["Pale", "White"],
    "Erosion": ["Deposition", "Buildup"],
    "Multiple Colors": ["Single Color", "Unicolor"],
    "Tiny": ["Large", "Big"],
    "Regular": ["Irregular"],
}

# The seven derm7pt criteria are contrasted with the generic normal-skin terms.
for derm7pt_criterion in [
    "pigment network",
    "regression structure",
    "pigmentation",
    "blue whitish veil",
    "vascular structures",
    "streaks",
    "dots and globules",
]:
    concept_reference_dict[f"derm7ptconcept_{derm7pt_criterion}"] = normal_skin

# SkinCon concepts with a dedicated opposite term. Every other SkinCon column
# falls back to ``normal_skin``.
_SKINCON_NEGATIVE_TERMS = {
    "skincon_Patch": ["Spotted"],
    "skincon_Exudate": ["Absence"],
    "skincon_Xerosis": ["Moisturized"],
    "skincon_Warty/Papillomatous": ["Smooth"],
    "skincon_Dome-shaped": ["Flat"],
    "skincon_Brown(Hyperpigmentation)": ["Hypopigmentation"],
    "skincon_Translucent": ["Opaque"],
    "skincon_White(Hypopigmentation)": ["Hyperpigmentation"],
    "skincon_Purple": ["Yellow"],
    "skincon_Yellow": ["Purple"],
    "skincon_Black": ["White", "Creamy", "Colorless", "Unpigmented"],
    "skincon_Lichenification": ["Softening"],
    "skincon_Blue": ["Orange"],
    "skincon_Gray": ["Colorful"],
}

for concept_name in skincon_cols:
    negative_terms = _SKINCON_NEGATIVE_TERMS.get(concept_name, normal_skin)
    concept_reference_dict[concept_name] = negative_terms


def _with_cbm_prefix(term):
    """Return *term* unchanged if it is namespaced (``skincon_``/``derm7ptconcept_``), else ``"cbm_" + term``."""
    return term if ("skincon_" in term) or ("derm7ptconcept_" in term) else f"cbm_{term}"


# Namespace every key and every reference term individually. Already-namespaced
# terms are kept as the term itself, not replaced by the whole list. The new
# lists also break the sharing of ``normal_skin`` between entries.
concept_reference_dict = {
    _with_cbm_prefix(key): [_with_cbm_prefix(term) for term in value]
    for key, value in concept_reference_dict.items()
}

print()
for key in concept_reference_dict.keys():
    print(f"{key}: {concept_reference_dict[key]}")

In [None]:
# ---- Concept vocabularies used by set_config -----------------------------------

# Artifact / non-lesion prompts plus two diagnosis prompts (ISIC and derm7pt).
_ARTIFACT_CONCEPTS = [
    "purple pen",
    "finger",
    "nail",
    "pinkish",
    "red",
    "hair",
    "orange sticker",
    "blue sticker",
    "red sticker",
    "dermoscope border",
    "gel",
    "malignant",
    "melanoma",
]

# derm7pt checklist criteria, including their typical/atypical and
# regular/irregular variants.
_DERM7PT_CONCEPTS = [
    f"derm7ptconcept_{name}"
    for name in [
        "pigment network", "typical pigment network", "atypical pigment network",
        "regression structure",
        "pigmentation", "regular pigmentation", "irregular pigmentation",
        "blue whitish veil",
        "vascular structures", "typical vascular structures", "atypical vascular structures",
        "streaks", "regular streaks", "irregular streaks",
        "dots and globules", "regular dots and globules", "irregular dots and globules",
    ]
]

_ISIC_CONCEPTS = [
    f"isicconcept_{name}"
    for name in ["pigment_network", "negative_network", "milia_like_cyst", "streaks", "globules"]
]

# Chest findings appended to every concept list (presumably out-of-domain controls).
# NOTE(review): "odema" looks like a typo for "edema". It is kept because the
# string is a concept key, and get_concept_embedding turns it into prompt text.
_CHEST_CONCEPTS = [
    "chest_atelectasis",
    "chest_cardiomegaly",
    "chest_consolidation",
    "chest_odema",
    "chest_pleural effusion",
]

# Terms for the clinical dataset; each becomes ``cbm_<term.lower()>``, de-duplicated.
_CLINICAL_CBM_TERMS = (
    ["clean", "smooth", "Healthy", "normal", "soft", "flat"]
    + ["Asymmetry", "Symmetry", "Regular", "Uniform"]
    + ["Irregular", "Regular", "Smooth"]
    + ["Black", "White", "Creamy", "Colorless", "Unpigmented"]
    + ["Blue", "Green", "Red"]
    + ["White", "Black", "Colored", "Pigmented"]
    + ["Brown", "Pale", "White"]
    + ["Erosion", "Deposition", "Buildup"]
    + ["Multiple colors", "Single Color", "Unicolor"]
    + ["Tiny", "Large", "Big"]
    + ["Regular", "Irregular"]
)

_ISIC_DISEASES = [
    'seborrheic keratosis', 'nevus', 'squamous cell carcinoma',
    'melanoma', 'lichenoid keratosis', 'lentigo',
    'actinic keratosis', 'basal cell carcinoma', 'dermatofibroma',
    'atypical melanocytic proliferation', 'verruca',
    'clear cell acanthoma', 'angiofibroma or fibrous papule', 'scar',
    'angioma', 'atypical spitz tumor', 'solar lentigo', 'AIMP',
    'neurofibroma', 'lentigo simplex', 'acrochordon',
    'angiokeratoma', 'vascular lesion', 'cafe-au-lait macule',
    'pigmented benign keratosis',
]

_DERM7PT_DISEASES = [
    'basal cell carcinoma', 'blue nevus', 'clark nevus',
    'combined nevus', 'congenital nevus', 'dermal nevus',
    'dermatofibroma', 'lentigo', 'melanoma', 'melanosis',
    'recurrent nevus', 'reed or spitz nevus',
    'seborrheic keratosis', 'vascular lesion',
]


def _reference_concepts():
    """Sorted unique names in ``concept_reference_dict`` (keys and terms), minus SkinCon columns."""
    names = set(concept_reference_dict)
    for terms in concept_reference_dict.values():
        names.update(terms)
    return [name for name in sorted(names) if "skincon_" not in name]


def _clinical_config(metadata_all):
    """Labels, task masks and concept list for the clinical set (``source`` is "fitz" or "ddi")."""
    is_fitz = metadata_all["source"] == "fitz"
    is_ddi = metadata_all["source"] == "ddi"
    considered = metadata_all["skincon_Do not consider this image"] != 1

    y_pos_malignant = (
        (is_fitz & (metadata_all["three_partition_label"] == "malignant"))
        | (is_ddi & metadata_all["malignant"].eq(True))
    ).values
    # DDI diseases missing from ddi_map map to NaN, which never equals "melanoma".
    y_pos_melanoma = (
        (is_fitz & (metadata_all["nine_partition_label"] == "malignant melanoma"))
        | (is_ddi & metadata_all["disease"].map(ddi_map).eq("melanoma"))
    ).values

    valid_idx_malignant = considered.values
    fitz_melanoma_cohort = (
        (metadata_all["nine_partition_label"] == 'malignant melanoma')
        | (metadata_all["nine_partition_label"] == 'benign melanocyte')
        | (metadata_all["label"] == 'seborrheic keratosis')
        | (metadata_all["label"] == 'dermatofibroma')
    )
    valid_idx_melanoma = (
        considered
        & ((is_fitz & fitz_melanoma_cohort) | (is_ddi & metadata_all["disease"].isin(ddi_map.keys())))
    ).values

    cbm_terms = sorted({f"cbm_{term.lower()}" for term in _CLINICAL_CBM_TERMS})
    concept_list = list(skincon_cols) + cbm_terms + _reference_concepts()
    concept_list = [name for name in concept_list if "derm7ptconcept" not in name]
    concept_list += _CHEST_CONCEPTS

    return y_pos_malignant, valid_idx_malignant, y_pos_melanoma, valid_idx_melanoma, concept_list


def _isic_config(metadata_all):
    """Labels, task masks and concept list for ISIC. Adds two status columns to *metadata_all* in place."""
    metadata_all["benign_malignant_full"] = [
        isic_map_diagnosis_malignant(diagnosis, benign_malignant)
        for diagnosis, benign_malignant in zip(metadata_all["diagnosis"], metadata_all["benign_malignant"])
    ]
    status = metadata_all["benign_malignant_full"]
    metadata_all["benign_malignant_bool"] = status.str.contains("malignant")

    y_pos_malignant = metadata_all["benign_malignant_bool"].values
    y_pos_melanoma = metadata_all["diagnosis"].fillna('null').str.contains("melanoma").values
    valid_idx_malignant = (status.str.contains("malignant") | status.str.contains("benign")).values
    valid_idx_melanoma = (~metadata_all["diagnosis"].isnull()).values

    concept_list = (
        list(skincon_cols)
        + _ARTIFACT_CONCEPTS
        + _DERM7PT_CONCEPTS
        + _ISIC_CONCEPTS
        + [f"disease_{disease_name}" for disease_name in _ISIC_DISEASES]
        + _reference_concepts()
        + _CHEST_CONCEPTS
    )
    return y_pos_malignant, valid_idx_malignant, y_pos_melanoma, valid_idx_melanoma, concept_list


def _derm7pt_config(metadata_all):
    """Labels, task masks and concept list for derm7pt. Adds two status columns to *metadata_all* in place."""
    metadata_all["benign_malignant_full"] = metadata_all["diagnosis"].map(derm7pt_map_diagnosis_malignant)
    status = metadata_all["benign_malignant_full"]
    metadata_all["benign_malignant_bool"] = status.str.contains("malignant")

    y_pos_malignant = metadata_all["benign_malignant_bool"].values
    # na=False keeps the array boolean when a diagnosis is missing; otherwise it
    # becomes an object array with NaN and np.isnan below raises TypeError.
    y_pos_melanoma = metadata_all["diagnosis"].str.contains("melanoma", na=False).values
    valid_idx_malignant = (status.str.contains("malignant") | status.str.contains("benign")).values
    valid_idx_melanoma = (~metadata_all["diagnosis"].isnull()).values

    concept_list = (
        list(skincon_cols)
        + _ARTIFACT_CONCEPTS
        + _DERM7PT_CONCEPTS
        + _ISIC_CONCEPTS
        + [f"disease_{disease_name}" for disease_name in _DERM7PT_DISEASES]
        + _reference_concepts()
        + _CHEST_CONCEPTS
    )
    return y_pos_malignant, valid_idx_malignant, y_pos_melanoma, valid_idx_melanoma, concept_list


def set_config(dataset_name, metadata_all):
    """Build task labels, valid-sample masks and the concept list for a dataset.

    The dataset family is picked by substring of ``dataset_name``, checked in
    this order: ``"clinical_fd_clean"``, ``"isic"``, ``"derm7pt"``. For ISIC and
    derm7pt, ``benign_malignant_full``/``benign_malignant_bool`` columns are
    added to *metadata_all* in place. Positive and valid counts are printed.

    Returns:
        dict with boolean arrays ``valid_idx_malignant``, ``valid_idx_melanoma``,
        ``y_pos_malignant`` and ``y_pos_melanoma``, plus ``metadata_all`` and
        ``concept_list``.

    Raises:
        ValueError: if ``dataset_name`` matches none of the supported families.
    """
    if "clinical_fd_clean" in dataset_name:
        config = _clinical_config(metadata_all)
    elif "isic" in dataset_name:
        config = _isic_config(metadata_all)
    elif "derm7pt" in dataset_name:
        config = _derm7pt_config(metadata_all)
    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")
    y_pos_malignant, valid_idx_malignant, y_pos_melanoma, valid_idx_melanoma, concept_list = config

    print("y_pos_malignant", y_pos_malignant.sum(), "null", np.isnan(y_pos_malignant).sum())
    print("valid_idx_malignant", valid_idx_malignant.sum(), np.isnan(valid_idx_malignant).sum())
    print("y_pos_melanoma", y_pos_melanoma.sum(), np.isnan(y_pos_melanoma).sum())
    print("valid_idx_melanoma", valid_idx_melanoma.sum(), np.isnan(valid_idx_melanoma).sum())

    return {"valid_idx_malignant": valid_idx_malignant,
            "valid_idx_melanoma": valid_idx_melanoma,
            "y_pos_malignant": y_pos_malignant,
            "y_pos_melanoma": y_pos_melanoma,
            "metadata_all": metadata_all,
            "concept_list": concept_list}

In [None]:
# Attach labels, masks and concept lists to each evaluation dataset's entry.
for dataset_name in ("clinical_fd_clean_nodup_nooverlap", "derm7pt_derm_nodup", "isic_nodup_nooverlap"):
    print(dataset_name)
    entry = variable_dict[dataset_name]
    entry.update(set_config(dataset_name, entry["metadata_all"]))
    print()

In [None]:
# Number of clinical images usable for the malignant task.
variable_dict["clinical_fd_clean_nodup_nooverlap"]["valid_idx_malignant"].sum()

In [None]:
# Number of clinical images in the melanoma cohort.
variable_dict["clinical_fd_clean_nodup_nooverlap"]["valid_idx_melanoma"].sum()

In [None]:
# Per-md5hash counts within the malignant cohort, with missing hashes shown as
# "null" (presumably a duplicate-image check).
variable_dict["clinical_fd_clean_nodup_nooverlap"]["metadata_all"]\
[variable_dict["clinical_fd_clean_nodup_nooverlap"]["valid_idx_malignant"]]["md5hash"].fillna("null").value_counts()

In [None]:
# Total rows in the melanoma cohort. fillna("null") means no rows are dropped,
# so this equals valid_idx_melanoma.sum().
variable_dict["clinical_fd_clean_nodup_nooverlap"]["metadata_all"]\
[variable_dict["clinical_fd_clean_nodup_nooverlap"]["valid_idx_melanoma"]]["md5hash"].fillna("null").value_counts().sum()

In [None]:
# NOTE(review): scratch arithmetic with hard-coded counts, presumably copied from
# earlier outputs; verify or remove.
769-275

In [None]:
# NOTE(review): scratch arithmetic with hard-coded counts; verify or remove.
494+275

In [None]:
# Number of derm7pt_derm_nodup images usable for the malignant task.
variable_dict["derm7pt_derm_nodup"]["valid_idx_malignant"].sum()

In [None]:
# Prompt templates shared by the dermoscopy datasets (ISIC and Derm7pt).
_DERMOSCOPY_TEMPLATES = ("This is dermatoscopy of {}", "This is dermoscopy of {}")

# Leading phrases stripped (in this order) from concept_to_prompt output to get bare concept terms.
_SKINCON_PROMPT_PREFIXES = ("This is ", "This photo is ", "This lesion is ", "skin has become ")

# Prompt terms per Derm7pt concept (key = concept name without the "derm7ptconcept_" prefix).
# Used by both the ISIC and the Derm7pt dataset branches.
_DERM7PT_CONCEPT_TERMS = {
    "pigment network": ["pigment network", "intersecting brown lines"],
    "typical pigment network": ["typical pigment network", "regularly meshed pigment network"],
    "atypical pigment network": ["atypical pigment network", "irregularly meshed pigment network"],
    "regression structure": ["regression structure"],
    "pigmentation": ["pigmented", "pigmented lesion", "colored lesion"],
    "regular pigmentation": ["regular pigmentation", "uniform and consistent coloration"],
    "irregular pigmentation": ["irregular pigmentation"],
    "blue whitish veil": ["blue whitish veil", "blue white veil"],
    "vascular structures": ["vascular structures", "Hairpin vessels", "Comma vessels",
                            "dotted vessels", "arborizing vessels"],
    "typical vascular structures": ["typical vascular structures"],
    "atypical vascular structures": ["atypical vascular structures"],
    "streaks": ["streaks", "starburst", "linear patterns", "radially oriented linear projections",
                "regular, pigmented linear extensions", "irregular, pigmented linear extensions"],
    "regular streaks": ["regular streaks", "uniformly spaced linear patterns"],
    "irregular streaks": ["irregular streaks"],
    "dots and globules": ["black dots and globules", "brown dots and globules", "scattered globules"],
    "regular dots and globules": ["regular dots and globules"],
    "irregular dots and globules": ["irregular dots and globules"],
}

# Prompt terms per ISIC concept (key = concept name without the "isicconcept_" prefix).
_ISIC_CONCEPT_TERMS = {
    "pigment_network": ["pigment network"],
    "negative_network": ["negative network"],
    # Deliberately prompted via the associated diagnosis rather than the literal concept name.
    "milia_like_cyst": ["seborrheic keratosis"],
    "streaks": ["streaks", "starburst", "linear patterns", "radially oriented linear projections",
                "regular, pigmented linear extensions", "irregular, pigmented linear extensions"],
    "globules": ["globules"],
}


def _skincon_concept_terms(skincon_name):
    """Return the de-duplicated, lower-cased concept terms generated by concept_to_prompt.

    Prompts under the "original" key are skipped; the leading phrases in
    _SKINCON_PROMPT_PREFIXES are removed from the rest.
    """
    prompt_dict, _ = concept_to_prompt(skincon_name)
    terms = set()
    for key, prompts in prompt_dict.items():
        if key == "original":
            continue
        for prompt in prompts:
            for prefix in _SKINCON_PROMPT_PREFIXES:
                prompt = prompt.replace(prefix, "")
            terms.add(prompt.lower())
    return list(terms)


def _pool_over_dermoscopy_templates(terms):
    """Return ONE prompt group crossing every dermoscopy template with every term (template-major)."""
    return [[template.format(term) for template in _DERMOSCOPY_TEMPLATES for term in terms]]


def _clinical_concept_prompts(concept_name):
    """Build (prompt_target, prompt_ref) for datasets whose name contains "clinical_fd_clean"."""
    generic_refs = [["This is skin image"], ["This is dermatology image"], ["This is image"]]
    if concept_name.startswith("cbm_"):
        term = concept_name[len("cbm_"):]
        templates = ("This is skin image of {}", "This is dermatology image of {}", "This is image of {}")
        return [[template.format(term)] for template in templates], generic_refs
    if concept_name.startswith("chest_"):
        term = concept_name[len("chest_"):]
        return [[f"This is image of {term}"]], [["This is image"]]
    # Any other concept is assumed to carry an 8-character prefix (e.g. "skincon_").
    terms = _skincon_concept_terms(concept_name[8:])
    return [[f"This is image of {term}" for term in terms]], generic_refs


def _dermoscopy_concept_prompts(concept_name, expand_aimp):
    """Build (prompt_target, prompt_ref) for the ISIC / Derm7pt dermoscopy datasets.

    Args:
        concept_name: prefixed concept name (skincon_, derm7ptconcept_, isicconcept_,
            cbm_, chest_, disease_) or a free-text concept.
        expand_aimp: if True, "disease_AIMP" is prompted with both the acronym and
            its expansion (only the ISIC branch does this).

    Raises:
        ValueError: unknown derm7ptconcept_/isicconcept_ concept.
    """
    pooled_ref = [["This is dermatoscopy", "This is dermoscopy"]]
    per_template_ref = [["This is dermatoscopy"], ["This is dermoscopy"]]

    if concept_name.startswith("skincon_"):
        return _pool_over_dermoscopy_templates(_skincon_concept_terms(concept_name[8:])), pooled_ref
    if concept_name.startswith("derm7ptconcept_"):
        key = concept_name[len("derm7ptconcept_"):]
        if key not in _DERM7PT_CONCEPT_TERMS:
            raise ValueError(key)
        return _pool_over_dermoscopy_templates(_DERM7PT_CONCEPT_TERMS[key]), pooled_ref
    if concept_name.startswith("isicconcept_"):
        key = concept_name[len("isicconcept_"):]
        if key not in _ISIC_CONCEPT_TERMS:
            raise ValueError(key)
        return _pool_over_dermoscopy_templates(_ISIC_CONCEPT_TERMS[key]), pooled_ref
    for prefix in ("cbm_", "chest_"):
        if concept_name.startswith(prefix):
            term = concept_name[len(prefix):]
            # NOTE(review): dermoscopy-phrased targets are paired with generic skin/image
            # references, and the group counts differ (2 targets vs 3 references).
            return (
                [[template.format(term)] for template in _DERMOSCOPY_TEMPLATES],
                [["This is skin image"], ["This is dermatology image"], ["This is image"]],
            )
    if concept_name.startswith("disease_"):
        if expand_aimp and concept_name == "disease_AIMP":
            aimp_terms = ("AIMP", "Atypical intraepidermal melanocytic proliferation")
            return (
                [[template.format(term) for term in aimp_terms] for template in _DERMOSCOPY_TEMPLATES],
                per_template_ref,
            )
        disease_name = concept_name[len("disease_"):]
        return [[template.format(disease_name)] for template in _DERMOSCOPY_TEMPLATES], per_template_ref
    if concept_name == "dermoscope border":
        return [["This is dermatoscopy of dermoscopy"]], [["This is dermatoscopy"]]
    # Free-text concepts (including "gel"): one single-prompt group per template.
    return [[template.format(concept_name)] for template in _DERMOSCOPY_TEMPLATES], per_template_ref


def _concept_prompts(dataset_name, concept_name):
    """Dispatch to the prompt builder for *dataset_name*; raises ValueError for unknown datasets."""
    if "clinical_fd_clean" in dataset_name:
        return _clinical_concept_prompts(concept_name)
    if "isic" in dataset_name:
        return _dermoscopy_concept_prompts(concept_name, expand_aimp=True)
    if "derm7pt_derm" in dataset_name:
        return _dermoscopy_concept_prompts(concept_name, expand_aimp=False)
    raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")


def _encode_prompt_groups(prompt_groups, clip_model, device):
    """Encode each prompt group with clip_model and L2-normalise along dim 2.

    Every group is tokenized and encoded separately, then the per-group
    "text_features" are stacked, so all groups must contain the same number
    of prompts (torch.stack requirement). The result is presumably
    (n_groups, n_prompts, dim).
    """
    tokenized_groups = [clip.tokenize(group, truncate=True) for group in prompt_groups]
    with torch.no_grad():
        embeddings = torch.stack([
            clip_model.model_step_with_text({"text": tokens.to(device)})["text_features"].detach().cpu()
            for tokens in tokenized_groups
        ])
    return embeddings / embeddings.norm(dim=2, keepdim=True)


def get_concept_embedding(dataset_name, concept_list, clip_model, device=None):
    """Build and embed target/reference prompts for every concept of a dataset.

    Args:
        dataset_name: selects the prompt vocabulary; must contain
            "clinical_fd_clean", "isic" or "derm7pt_derm".
        concept_list: concept names to embed.
        clip_model: object exposing model_step_with_text({"text": tokens})
            that returns a dict with "text_features".
        device: device the tokens are moved to; defaults to the notebook-level
            global `model_device`.

    Returns:
        {"prompt_info": {concept_name: {"prompt_ref_embedding_norm": Tensor,
                                        "prompt_target_embedding_norm": Tensor}}}

    Raises:
        ValueError: unsupported dataset_name or unknown derm7pt/ISIC concept.
    """
    if device is None:
        device = model_device
    prompt_info = {}
    for concept_name in concept_list:
        prompt_target, prompt_ref = _concept_prompts(dataset_name, concept_name)
        target_embedding_norm = _encode_prompt_groups(prompt_target, clip_model, device)
        ref_embedding_norm = _encode_prompt_groups(prompt_ref, clip_model, device)
        prompt_info[concept_name] = {"prompt_ref_embedding_norm": ref_embedding_norm,
                                     "prompt_target_embedding_norm": target_embedding_norm,
                                     }
        print(dataset_name, concept_name, "\n", prompt_target, prompt_ref)
        print('-----------')
    return {"prompt_info": prompt_info}

In [None]:
# Embed every dataset's concepts twice: once with model_vanilla, once with model.
for dataset_name in (
    "clinical_fd_clean_nodup_nooverlap",
    "derm7pt_derm_nodup",
    "isic_nodup_nooverlap",
):
    variable_dict[dataset_name]["prompt_info_vanilla"] = get_concept_embedding(
        dataset_name,
        concept_list=variable_dict[dataset_name]["concept_list"],
        clip_model=model_vanilla,
    )["prompt_info"]

    variable_dict[dataset_name]["prompt_info"] = get_concept_embedding(
        dataset_name,
        concept_list=variable_dict[dataset_name]["concept_list"],
        clip_model=model,
    )["prompt_info"]

In [None]:
# Recompute both embedding sets (model_vanilla and model) for the clinical dataset only.
# dataset_name stays bound afterwards, exactly as a one-iteration loop would leave it.
dataset_name = "clinical_fd_clean_nodup_nooverlap"

variable_dict[dataset_name]["prompt_info_vanilla"] = get_concept_embedding(
    dataset_name,
    concept_list=variable_dict[dataset_name]["concept_list"],
    clip_model=model_vanilla,
)["prompt_info"]

variable_dict[dataset_name]["prompt_info"] = get_concept_embedding(
    dataset_name,
    concept_list=variable_dict[dataset_name]["concept_list"],
    clip_model=model,
)["prompt_info"]

In [None]:
# IPython shell escape: long listing of saved results, oldest first (-t sorts by mtime, -r reverses).
!ls logs/experiment_results/ -trl

In [None]:
# torch.save(variable_dict, "logs/experiment_results/cbm_variables_1003.pt")

In [None]:
# Reload the cached variable_dict. It holds pandas DataFrames ("metadata_all") and
# other non-tensor objects, so full unpickling is required: PyTorch >= 2.6 defaults
# to weights_only=True, which refuses such objects. (weights_only needs torch >= 1.13.)
# Only load files this notebook produced itself; unpickling can execute arbitrary code.
variable_dict = torch.load("logs/experiment_results/cbm_variables_1003.pt", map_location="cpu", weights_only=False)

In [None]:
def calculate_similaity_score(image_features_norm, 
                              prompt_target_embedding_norm,
                              prompt_ref_embedding_norm,
                              temp=1,
                              normalize=True):
    """Score each image against a concept's target prompts, optionally relative to reference prompts.

    Args:
        image_features_norm: image embeddings, presumably (n_images, dim) and
            L2-normalised — TODO confirm.
        prompt_target_embedding_norm: target prompt embeddings, presumably
            (n_groups, n_prompts, dim) as built by get_concept_embedding.
        prompt_ref_embedding_norm: reference prompt embeddings with the same
            layout. Its group count must equal the target's, or one of the two
            must have a single group (which is then broadcast).
        temp: softmax temperature; similarities are divided by it.
        normalize: if True, return the per-image mean over groups of the
            softmax probability of target vs. reference (numpy array); if
            False, return the mean raw target similarity (torch tensor).

    Raises:
        ValueError: target and reference group counts cannot be broadcast.
    """
    # Local import: the module-level `scipy` name is only bound by a later cell.
    from scipy.special import softmax

    target_similarity = prompt_target_embedding_norm.float() @ image_features_norm.T.float()
    ref_similarity = prompt_ref_embedding_norm.float() @ image_features_norm.T.float()

    # Average over the prompts inside each group.
    target_similarity_mean = target_similarity.mean(dim=1)
    ref_similarity_mean = ref_similarity.mean(dim=1)

    if not normalize:
        return target_similarity_mean.mean(dim=0)

    # Pair target group g with reference group g; a single group on either side is
    # broadcast instead of crashing on a ragged [target, ref] array.
    target_logits, ref_logits = np.broadcast_arrays(
        target_similarity_mean.numpy() / temp, ref_similarity_mean.numpy() / temp
    )
    similarity_score = softmax(np.stack([target_logits, ref_logits]), axis=0)[0].mean(axis=0)
    return similarity_score

In [None]:
import scipy

In [None]:
# Concepts embedded for `dataset_name`; note dataset_name is whatever value an
# earlier dataset loop cell left bound, not something set in this cell.
variable_dict[dataset_name]["prompt_info"].keys()

In [None]:
# Keys stored for the clinical dataset entry.
variable_dict["clinical_fd_clean_nodup_nooverlap"].keys()

In [None]:
# def calculate_similaity_score(image_features_norm, 
#                               prompt_target_embedding_norm,
#                               prompt_ref_embedding_norm,
#                               temp=1,
#                               normalize=True):

#     target_similarity=prompt_target_embedding_norm.float()@image_features_norm.T.float()
#     ref_similarity=prompt_ref_embedding_norm.float()@image_features_norm.T.float()
    
    
#     target_similarity_mean=target_similarity.mean(dim=[1])
#     ref_similarity_mean=ref_similarity.mean(axis=1)
         
#     if normalize:
#         similarity_score=scipy.special.softmax([target_similarity_mean.numpy()/temp, 
#                             ref_similarity_mean.numpy()/temp], axis=0)[0,:].mean(axis=0)   
#     else:
#         similarity_score=target_similarity_mean.mean(axis=0)

    
#     return similarity_score

In [None]:
# Dataset-level keys of variable_dict.
variable_dict.keys()

In [None]:
# Value distribution of skincon_Pigmented among clinical rows that have the annotation.
# NOTE: null rows are filtered out first, so the trailing fillna(-9) is a no-op.
variable_dict["clinical_fd_clean_nodup_nooverlap"]["metadata_all"][
~variable_dict["clinical_fd_clean_nodup_nooverlap"]["metadata_all"]["skincon_Pigmented"].isnull()]\
["skincon_Pigmented"].fillna(-9).value_counts()

# manual annotation

In [None]:
# Derm7pt concept name -> (metadata column, predicate marking the concept as present).
_DERM7PT_CONCEPT_RULES = {
    "derm7ptconcept_pigment network": ("pigment_network", lambda s: s != "absent"),
    "derm7ptconcept_typical pigment network": ("pigment_network", lambda s: s == "typical"),
    "derm7ptconcept_atypical pigment network": ("pigment_network", lambda s: s == "atypical"),
    "derm7ptconcept_regression structure": ("regression_structures", lambda s: s != "absent"),
    "derm7ptconcept_pigmentation": ("pigmentation", lambda s: s != "absent"),
    "derm7ptconcept_regular pigmentation": (
        "pigmentation", lambda s: s.str.contains(" regular", na=False)),
    "derm7ptconcept_irregular pigmentation": (
        "pigmentation", lambda s: s.str.contains(" irregular", na=False)),
    "derm7ptconcept_blue whitish veil": ("blue_whitish_veil", lambda s: s != "absent"),
    "derm7ptconcept_vascular structures": ("vascular_structures", lambda s: s != "absent"),
    "derm7ptconcept_typical vascular structures": (
        "vascular_structures",
        lambda s: s.isin(["within regression", "arborizing", "comma", "hairpin", "wreath"])),
    "derm7ptconcept_atypical vascular structures": (
        "vascular_structures", lambda s: s.isin(["dotted", "linear irregular"])),
    "derm7ptconcept_streaks": ("streaks", lambda s: s != "absent"),
    "derm7ptconcept_regular streaks": ("streaks", lambda s: s == "regular"),
    "derm7ptconcept_irregular streaks": ("streaks", lambda s: s == "irregular"),
    "derm7ptconcept_dots and globules": ("dots_and_globules", lambda s: s != "absent"),
    "derm7ptconcept_regular dots and globules": ("dots_and_globules", lambda s: s == "regular"),
    "derm7ptconcept_irregular dots and globules": ("dots_and_globules", lambda s: s == "irregular"),
}

# ISIC concept name -> metadata column holding its numeric annotation (-9 marks "missing").
_ISIC_CONCEPT_COLUMNS = {
    "isicconcept_pigment_network": "pigment_network",
    "isicconcept_negative_network": "negative_network",
    "isicconcept_milia_like_cyst": "milia_like_cyst",
    "isicconcept_streaks": "streaks",
    "isicconcept_globules": "globules",
}


def get_concept_bool_from_metadata(dataset_name, metadata_all, concept_name, isic_threshold=30):
    """Return boolean masks of rows annotated as having / not having a concept.

    Rows whose annotation is missing fall in neither mask: null values for
    Derm7pt, -9 for ISIC, and anything other than 1/0 for clinical SkinCon labels.

    Args:
        dataset_name: must contain "derm7pt_derm", "isic" or "clinical_fd_clean".
        metadata_all: metadata DataFrame with the annotation columns.
        concept_name: prefixed concept name (derm7ptconcept_, isicconcept_, skincon_).
        isic_threshold: an ISIC annotation counts as present when it is strictly
            greater than this value (an earlier variant used the 90th percentile
            of positive values instead).

    Returns:
        {"concept_bool_true": mask, "concept_bool_false": mask}.

    Raises:
        ValueError: unknown concept for the dataset, or unsupported dataset_name.
    """
    if "derm7pt_derm" in dataset_name:
        if concept_name not in _DERM7PT_CONCEPT_RULES:
            raise ValueError(concept_name)
        column, is_present = _DERM7PT_CONCEPT_RULES[concept_name]
        values = metadata_all[column]
        valid_idx = values.notnull().values
        # astype(bool) so that `~` is a logical (not bitwise-integer) negation.
        concept_bool = is_present(values).astype(bool)
    elif "isic" in dataset_name:
        if concept_name not in _ISIC_CONCEPT_COLUMNS:
            raise ValueError(concept_name)
        label = metadata_all[_ISIC_CONCEPT_COLUMNS[concept_name]]
        valid_idx = (label != -9).values
        concept_bool = (label > isic_threshold).astype(bool)
    elif "clinical_fd_clean" in dataset_name:
        if not concept_name.startswith("skincon_"):
            raise ValueError(concept_name)
        return {"concept_bool_true": metadata_all[concept_name] == 1,
                "concept_bool_false": metadata_all[concept_name] == 0,
               }
    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

    # Restrict both masks to annotated rows, matching the clinical ==1 / ==0 convention.
    return {"concept_bool_true": concept_bool & valid_idx,
            "concept_bool_false": ~concept_bool & valid_idx,
           }

In [None]:
def temp(
    dataset_name="clinical_fd_clean_nodup_nooverlap",
    concept_names=("skincon_Nodule",),
    other_concept_names=("chest_pleural effusion",),
):
    """Compare dermatology-prompt vs. off-domain-prompt similarity scores.

    For every concept, an image's score is the product of the concept's
    target-prompt embedding with the normalized image features stored in the
    global ``variable_dict``, averaged over the two leading axes of the
    product. Scores of images that the metadata marks positive for the
    dermatology concept are pooled per prompt type and drawn as overlaid
    histograms. Per-concept "top-100" and mean scores are summarized with
    descriptive statistics.

    Args:
        dataset_name: Key into ``variable_dict``.
        concept_names: Dermatology concepts to score. The default is the
            single concept that was previously hard-coded.
        other_concept_names: Off-domain (chest X-ray) concepts used as a
            reference. Their scores are restricted to images positive for the
            LAST concept in ``concept_names``, as before.

    Raises:
        ValueError: If ``concept_names`` is empty, because the positive-image
            mask is derived from the dermatology concepts.
    """
    concept_names = list(concept_names)
    if not concept_names:
        raise ValueError("concept_names must contain at least one concept")

    dataset_vars = variable_dict[dataset_name]
    image_features = dataset_vars["image_features_norm"].float()

    def _mean_similarity(concept):
        # Per-image similarity averaged over the first two axes of the
        # product. Assumes the image axis ends up last -- TODO confirm shapes.
        prompt_embedding = dataset_vars["prompt_info"][concept]["prompt_target_embedding_norm"].float()
        return (prompt_embedding @ image_features.T).mean(axis=[0, 1]).numpy()

    def _record(scores, prompt_type):
        # Index 100 of the descending sort: the value with 100 scores above it
        # (ties aside). Requires more than 100 images.
        return {
            "score_top100": np.sort(scores)[::-1][100],
            "score_mean": np.mean(scores),
            "prompt_type": prompt_type,
        }

    def _describe(values):
        # Descriptive statistics of one group of per-concept scores.
        return {
            "mean": values.mean(),
            "std": values.std(),
            "max": values.max(),
            "q3": values.quantile(q=0.75),
            "median": values.median(),
            "q1": values.quantile(q=0.25),
            "min": values.min(),
        }

    score_derm_list = []
    score_other_list = []
    record_list = []

    positive_mask = None
    for concept_name in concept_names:
        score_derm = _mean_similarity(concept_name)
        concept_bool = get_concept_bool_from_metadata(
            dataset_name, dataset_vars["metadata_all"], concept_name
        )
        positive_mask = concept_bool["concept_bool_true"]
        score_derm_list += score_derm[positive_mask].tolist()
        record_list.append(_record(score_derm, "derm"))

    for concept_name_other in other_concept_names:
        score_other = _mean_similarity(concept_name_other)
        # Off-domain scores are taken on the images positive for the last
        # dermatology concept, mirroring the derm pool for that concept.
        score_other_list += score_other[positive_mask].tolist()
        record_list.append(_record(score_other, "other"))

    records = pd.DataFrame(record_list)
    title = concept_names[-1]

    # Density histograms of the pooled scores, with the per-concept top-100
    # scores overlaid on the same axes.
    fig, ax = plt.subplots()
    ax.hist(score_other_list, density=True, alpha=0.5, label="other")
    ax.hist(score_derm_list, density=True, alpha=0.5, label="derm")
    sns.histplot(x="score_top100", hue="prompt_type", bins=100, data=records, ax=ax)

    for column in ("score_top100", "score_mean"):
        print(records.groupby("prompt_type")[column].apply(_describe))

    # Raw-count histograms of the pooled scores. The labels make legend() meaningful.
    fig, ax = plt.subplots()
    ax.hist(score_derm_list, alpha=0.5, label="derm")
    ax.hist(score_other_list, alpha=0.5, label="other")
    ax.legend()
    ax.set_title(title)


temp()


In [None]:
def mean_confidence_interval(data, confidence=0.95):
    """Return ``(mean, lower, upper)`` of a Student-t confidence interval for the mean.

    The half-width is ``sem * t.ppf((1 + confidence) / 2, n - 1)``, where
    ``sem`` is ``scipy.stats.sem`` of *data* and ``n`` is its length.

    Args:
        data: Sequence of numbers. It is converted to a float array.
        confidence: Two-sided confidence level, e.g. 0.95.

    Returns:
        Tuple ``(mean, mean - half_width, mean + half_width)``.
    """
    # Local import: the notebook only binds names *from* scipy.stats, not the
    # `scipy` package itself.
    import scipy.stats

    values = np.asarray(data, dtype=float)
    n = len(values)
    mean = np.mean(values)
    sem = scipy.stats.sem(values)
    half_width = sem * scipy.stats.t.ppf((1 + confidence) / 2.0, n - 1)
    return mean, mean - half_width, mean + half_width

def mean_confidence_interval2(data, confidence=0.95):
    """Return the ``(lower, upper)`` Student-t confidence interval for the mean of *data*.

    The interval is centred on ``np.mean(data)``, scaled by
    ``scipy.stats.sem(data)``, and uses ``len(data) - 1`` degrees of freedom.

    Args:
        data: Sequence or pandas Series of numbers.
        confidence: Two-sided confidence level, e.g. 0.95.
    """
    # Local import: the notebook only binds names *from* scipy.stats, not the
    # `scipy` package itself.
    import scipy.stats

    # The confidence level is passed positionally. SciPy renamed this keyword
    # from `alpha` to `confidence` and later dropped `alpha`, so
    # `alpha=confidence` fails on current SciPy. A positional argument works
    # on both.
    return scipy.stats.t.interval(
        confidence,
        df=len(data) - 1,
        loc=np.mean(data),
        scale=scipy.stats.sem(data),
    )

def temp(
    dataset_name="clinical_fd_clean_nodup_nooverlap",
    concept_names=None,
    other_concept_names=("chest_pleural effusion",),
    min_pos=30,
):
    """Test whether dermatology prompts score higher than an off-domain prompt.

    For each dermatology concept and each off-domain (chest X-ray) concept,
    per-image similarity scores are computed. Each score is the prompt
    embedding times the image features, averaged over the two leading axes.
    The scores are compared on three image subsets, all intersected with
    ``valid_idx_malignant``:

    - images positive for the concept;
    - images negative for it;
    - all valid images.

    For each subset, the record stores two values:

    - the fraction of images where the dermatology score is larger;
    - the one-sided paired t-test p-value (``alternative='greater'``).

    ``prop_larger_ref`` compares the dermatology score against the off-domain
    concept's *reference* prompt on the positive images. Concepts with at
    least ``min_pos`` positive images are then summarized.

    Args:
        dataset_name: Key into the global ``variable_dict``.
        concept_names: Dermatology concepts to evaluate. ``None`` means
            ``skincon_cols``.
        other_concept_names: Off-domain reference concepts.
        min_pos: Minimum number of positive images for a concept to be
            included in the printed summary (previously hard-coded as 30).
    """
    # Local import: the notebook only binds names *from* scipy.stats, not the
    # `scipy` package itself.
    import scipy.stats

    if concept_names is None:
        concept_names = skincon_cols

    dataset_vars = variable_dict[dataset_name]
    image_features = dataset_vars["image_features_norm"].float()
    sample_all = dataset_vars["valid_idx_malignant"]

    def _mean_similarity(concept, embedding_key):
        # Per-image similarity averaged over the first two axes of the
        # product. Assumes the image axis ends up last -- TODO confirm shapes.
        prompt_embedding = dataset_vars["prompt_info"][concept][embedding_key].float()
        return (prompt_embedding @ image_features.T).mean(axis=[0, 1]).numpy()

    def _prop_and_pval(score_a, score_b, mask):
        # Fraction of masked images with score_a > score_b, and the one-sided
        # paired t-test p-value for score_a > score_b.
        a, b = score_a[mask], score_b[mask]
        return (
            np.mean((a - b) > 0),
            scipy.stats.ttest_rel(a, b, alternative="greater")[1],
        )

    # The off-domain scores do not depend on the dermatology concept, so they
    # are computed once rather than once per concept.
    other_scores = [
        (
            concept_name_other,
            _mean_similarity(concept_name_other, "prompt_target_embedding_norm"),
            _mean_similarity(concept_name_other, "prompt_ref_embedding_norm"),
        )
        for concept_name_other in other_concept_names
    ]

    record_list = []
    for concept_name in concept_names:
        score_derm = _mean_similarity(concept_name, "prompt_target_embedding_norm")
        concept_bool = get_concept_bool_from_metadata(
            dataset_name, dataset_vars["metadata_all"], concept_name
        )
        sample_idx = concept_bool["concept_bool_true"] & sample_all
        sample_non_idx = concept_bool["concept_bool_false"] & sample_all
        print(sample_idx.sum(), sample_non_idx.sum())

        for concept_name_other, score_other, score_other_ref in other_scores:
            prop_pos, pval_pos = _prop_and_pval(score_derm, score_other, sample_idx)
            prop_neg, pval_neg = _prop_and_pval(score_derm, score_other, sample_non_idx)
            prop_all, pval_all = _prop_and_pval(score_derm, score_other, sample_all)
            record_list.append(
                {
                    "concept_name": concept_name,
                    "concept_name_other": concept_name_other,
                    "num_pos": sample_idx.sum(),
                    "prop_larger_pos": prop_pos,
                    "pval_larger_pos": pval_pos,
                    "prop_larger_neg": prop_neg,
                    "pval_larger_neg": pval_neg,
                    "prop_larger_all": prop_all,
                    "pval_larger_all": pval_all,
                    "prop_larger_ref": np.mean(
                        (score_derm[sample_idx] - score_other_ref[sample_idx]) > 0
                    ),
                }
            )

    record_list_df = pd.DataFrame(record_list).set_index("concept_name")
    # Keep only concepts with enough positive images for a meaningful test.
    kept = record_list_df[record_list_df["num_pos"] >= min_pos]

    kept["prop_larger_pos"].hist()
    print(kept)

    # Summaries for the positive-image subset and the all-valid-image subset.
    # Each prints the CI and descriptive stats of the fractions, then of the
    # p-values. The median is read from the same p-value column as the rest.
    for label in ("larger_pos", "larger_all"):
        prop = kept["prop_" + label]
        pval = kept["pval_" + label]
        print(label)
        print(
            mean_confidence_interval2(prop),
            "prop mean", prop.mean(),
            "prop std", prop.std(),
            "prop max", prop.max(),
            "prop min", prop.min(),
            mean_confidence_interval2(pval),
            "pval mean", pval.mean(),
            "pval median", pval.median(),
            "pval std", pval.std(),
            "pval max", pval.max(),
            "pval min", pval.min(),
        )


temp()

In [None]:
# Scratch: manual Bonferroni-style adjustment -- a raw p-value multiplied by
# the number of tests. NOTE(review): 21 is presumably the number of concepts
# kept by the num_pos >= 30 filter in the analysis above; confirm. The
# product is not capped at 1 here.
0.00034191898671290515*21

In [None]:
# Same manual adjustment (x21) for a second raw p-value.
1.4730998132949117e-19*21

In [None]:
from scipy import stats

In [None]:
from statsmodels.stats import multitest


# statsmodels' multiple-testing module; the next cell calls
# multitest.multipletests.

In [None]:
# Toy check of the Bonferroni correction on four example p-values.
# NOTE(review): per the statsmodels docs this returns
# (reject, pvals_corrected, alphacSidak, alphacBonf) -- confirm for the
# installed version.
multitest.multipletests([0.01,0.02, 0.05, 0.067], method="bonferroni")

In [None]:
# `stats.` on its own is a SyntaxError (an unfinished tab-completion probe).
# List the public scipy.stats names instead so the cell runs.
[name for name in dir(stats) if not name.startswith("_")]

In [None]:
# Load the saved concept-annotation data-auditing results onto the CPU. The
# relative path resolves against the project root because the notebook
# os.chdir()s there at the top.
# NOTE(review): torch.load unpickles the file, so only load trusted files.
# Newer PyTorch versions default to weights_only=True, which may reject
# non-tensor objects stored here -- confirm.
data_audit_loaded=torch.load("logs/experiment_results/concept_annotation_data_auditing_0930.pt", map_location="cpu")

In [None]:
# Load the saved comparison-with-manual-annotation results ("manaul" is the
# spelling used by both the file name and the variable; kept as-is).
compare_manaul_0930=torch.load("logs/experiment_results/compare_manaul_0930.pt", map_location="cpu")

In [None]:
# Inspect which result entries are stored for the clinical dataset.
data_audit_loaded["clinical_fd_clean_nodup_nooverlap"].keys()

In [None]:
compare_manaul_0930["clinical_fd_clean_nodup_nooverlap"].keys()

In [None]:
# Display the stored concept-level evaluation.
compare_manaul_0930["clinical_fd_clean_nodup_nooverlap"]["concept_eval"]

In [None]:
# Display the stored skin-tone evaluation.
compare_manaul_0930["clinical_fd_clean_nodup_nooverlap"]["skintone_eval"]

In [None]:
# Scratch value; shown as the cell output.
1635

In [None]:
# IPython `??` shows the function's source. This is IPython syntax, not
# plain Python.
get_concept_bool_from_metadata??

In [None]:
# Scratch arithmetic; shown as the cell output.
26+1691

In [None]:
import clip

In [None]:
# Reload CLIP ViT-B/32. This rebinds the notebook-level `model` and
# `preprocess` names.
# NOTE(review): `device` is not defined in this section; it is assumed to be
# set earlier in the notebook -- confirm.
model, preprocess = clip.load('ViT-B/32', device)

In [None]:
# Scratch arithmetic; shown as the cell output.
0.863+0.105

In [None]:
# Size and sum of the valid_idx_malignant selector that the analysis above
# intersects with the concept masks (presumably boolean, so .sum() counts
# the selected images -- confirm).
variable_dict["clinical_fd_clean_nodup_nooverlap"]["valid_idx_malignant"].shape

In [None]:
variable_dict["clinical_fd_clean_nodup_nooverlap"]["valid_idx_malignant"].sum()

In [None]:
variable_dict["clinical_fd_clean_nodup_nooverlap"].keys()

In [None]:
# Normal-approximation upper 95% bound: mean + 1.96 * standard error.
# 0.022826282925137 equals 0.10460316933 / sqrt(21).
0.8639559576+1.96*0.022826282925137

In [None]:
# Standard error of the mean: standard deviation / sqrt(n), with n = 21.
0.10460316933/np.sqrt(21)

In [None]:
# Sanity check of pandas' standard error of the mean on a toy series.
pd.Series([1, 2, 3]).sem()