In [None]:
import os

import hydra
import omegaconf
import pyrootutils

# Locate the project root starting from this notebook's path (pythonpath=True also puts
# it on sys.path), then make it the working directory so relative paths such as
# "logs/..." used further down resolve against the project root.
root = pyrootutils.setup_root(
    os.path.abspath("concept_annotation.ipynb"),
    pythonpath=True,
)
os.chdir(root)

In [None]:
import sys

# Make the project's src/ directory importable (needed for the MONET package below).
sys.path.append(os.fspath(root / "src"))

In [None]:
import os
import time
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib import gridspec
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn.metrics
import torch
import tqdm
from IPython import display
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

import clip

In [None]:
# !cp /System/Library/Fonts/Supplemental ~/.local/share/fonts/
# rm -fr ~/.cache/matplotlib
from matplotlib import font_manager
from matplotlib.lines import Line2D
from cycler import cycler
from matplotlib.ticker import MultipleLocator, AutoMinorLocator

# Font lookups: results are not used, the calls only probe the system font list and
# whether "Arial" resolves (try "Special Elite" as well to check custom fonts).
font_manager.findSystemFonts(fontpaths=None, fontext="ttf")
font_manager.findfont("Arial")

plt.rcParams.update(
    {
        # Default to Arial for all sans-serif text.
        "font.family": "sans-serif",
        "font.sans-serif": "Arial",
        # Plain, frameless, fully transparent legends.
        "legend.fancybox": False,
        "legend.edgecolor": "1.0",
        "legend.framealpha": 0,
    }
)

# https://github.com/dsc/colorbrewer-python/blob/master/colorbrewer.py

# ColorBrewer qualitative palettes as 0-255 RGB triples, keyed by number of colours.
# Every n-colour palette here is a prefix of the largest one, so each dict is built by
# slicing that list (fresh inner lists per entry, all entries lists — the hand-copied
# table had a stray tuple at Paired[3][0]).
_SET1_RGB = [
    [228, 26, 28], [55, 126, 184], [77, 175, 74], [152, 78, 163], [255, 127, 0],
    [255, 255, 51], [166, 86, 40], [247, 129, 191], [153, 153, 153],
]
_PAIRED_RGB = [
    [166, 206, 227], [31, 120, 180], [178, 223, 138], [51, 160, 44], [251, 154, 153],
    [227, 26, 28], [253, 191, 111], [255, 127, 0], [202, 178, 214], [106, 61, 154],
    [255, 255, 153], [177, 89, 40],
]

Set1 = {n: [list(rgb) for rgb in _SET1_RGB[:n]] for n in range(3, len(_SET1_RGB) + 1)}
Paired = {n: [list(rgb) for rgb in _PAIRED_RGB[:n]] for n in range(3, len(_PAIRED_RGB) + 1)}

# Custom 7-colour qualitative palette: hex strings plus one matplotlib RGB tuple in [0, 1].
color_qual_7 = [
    "#F53345",
    "#87D303",
    "#04CBCC",
    "#8650CD",
    # 8-bit RGB (160, 95, 0); 8-bit channels map onto [0, 1] by dividing by 255, not 256.
    (160 / 255, 95 / 255, 0),
    "#F5A637",
    "#DBD783",
]

pd.set_option("display.max_rows", 500)

In [None]:
import scipy.special
import tqdm.contrib.concurrent

In [None]:
from MONET.datamodules.multiplex_datamodule import MultiplexDatamodule
from MONET.utils.loader import custom_collate_per_key, dataloader_apply_func
from MONET.utils.metrics import skincon_calcualte_auc_all
from MONET.utils.static import (
    concept_to_prompt,
    fitzpatrick17k_disease_label,
    fitzpatrick17k_ninelabel,
    fitzpatrick17k_threelabel,
    skincon_cols,
)
from MONET.utils.text_processing import generate_prompt_token_from_concept

In [None]:
def wandb_to_exppath(wandb, log_path="/gscratch/cse/chanwkim/MONET_log/train/runs"):
    """Return the experiment directory under ``log_path`` that contains W&B run ``wandb``.

    Each experiment directory may hold a ``wandb/`` folder; an experiment matches when
    any entry of that folder whose name starts with ``"run"`` ends in ``wandb`` (the last
    8 characters of the entry name are compared, i.e. an 8-character run id).

    Parameters
    ----------
    wandb : str
        The W&B run id to look for.
    log_path : str or Path
        Directory whose immediate children are experiment directories.

    Returns
    -------
    pathlib.Path
        ``log_path / <experiment>`` for the first match in sorted name order.

    Raises
    ------
    RuntimeError
        If no experiment contains a matching run entry.
    """
    log_path = Path(log_path)
    # Sorted so the result does not depend on os.listdir ordering.
    for experiment in sorted(os.listdir(log_path)):
        wandb_dir = log_path / experiment / "wandb"
        if not wandb_dir.is_dir():
            continue
        # Check every run entry; an empty wandb/ folder simply does not match.
        if any(name.startswith("run") and name[-8:] == wandb for name in os.listdir(wandb_dir)):
            return log_path / experiment
    raise RuntimeError(f"W&B run {wandb!r} not found under {log_path}")


# Resolve the experiment directory of W&B run "baqqmm5v" and list its checkpoint files.
exppath = wandb_to_exppath(
    wandb="baqqmm5v",
    log_path="/projects/leelab2/chanwkim/dermatology_datasets/logs/train/runs",
)
print([exppath / "checkpoints" / ckpt_name for ckpt_name in os.listdir(exppath / "checkpoints/")])

In [None]:
# Per-dataset state, keyed by dataset name (dataloader, features, labels, concepts, ...).
variable_dict = dict()

In [None]:
def setup_dataloader(dataset_name):    
    if dataset_name=="clinical_fd_clean_nodup_nooverlap":
        cfg_dm = omegaconf.OmegaConf.load(root / "configs" / "datamodule" / "multiplex.yaml")
        # cfg.data_dir="/scr/chanwkim/dermatology_datasets"
        cfg_dm.data_dir = "/sdata/chanwkim/dermatology_datasets"
        cfg_dm.dataset_name_test = "clinical_fd_clean_nodup_nooverlap=all"
        cfg_dm.split_seed = 42

        dm = hydra.utils.instantiate(cfg_dm)
        dm.setup()     
        
        dataloader = dm.test_dataloader()  
        
    elif dataset_name=="derm7pt_derm_nodup":
        cfg_dm = omegaconf.OmegaConf.load(root / "configs" / "datamodule" / "multiplex.yaml")
        # cfg.data_dir="/scr/chanwkim/dermatology_datasets"
        cfg_dm.data_dir = "/sdata/chanwkim/dermatology_datasets"
        cfg_dm.dataset_name_test = "derm7pt_derm_nodup=all"
        cfg_dm.split_seed = 42

        dm = hydra.utils.instantiate(cfg_dm)
        dm.setup()     
        
        dataloader = dm.test_dataloader()  
        
    elif dataset_name=="derm7pt_clinical_nodup":
        cfg_dm = omegaconf.OmegaConf.load(root / "configs" / "datamodule" / "multiplex.yaml")
        # cfg.data_dir="/scr/chanwkim/dermatology_datasets"
        cfg_dm.data_dir = "/sdata/chanwkim/dermatology_datasets"
        cfg_dm.dataset_name_test = "derm7pt_clinical_nodup=all"
        cfg_dm.split_seed = 42

        dm = hydra.utils.instantiate(cfg_dm)
        dm.setup()     
        
        dataloader = dm.test_dataloader()          
        
        
    elif dataset_name=="allpubmedtextbook":
        cfg_dm = omegaconf.OmegaConf.load(root / "configs" / "datamodule" / "multiplex.yaml")
        # cfg.data_dir="/scr/chanwkim/dermatology_datasets"
        cfg_dm.data_dir = "/sdata/chanwkim/dermatology_datasets"
        cfg_dm.dataset_name_test = "pubmed=all,textbook=all"
        cfg_dm.split_seed = 42

        dm = hydra.utils.instantiate(cfg_dm)
        dm.setup()
        
        dataloader = dm.test_dataloader()          
    
    elif dataset_name=="isic_nodup_nooverlap":
        cfg_dm = omegaconf.OmegaConf.load(root / "configs" / "datamodule" / "multiplex.yaml")
        # cfg.data_dir="/scr/chanwkim/dermatology_datasets"
        cfg_dm.data_dir = "/sdata/chanwkim/dermatology_datasets"
        cfg_dm.dataset_name_test = "isic_nodup_nooverlap=all"
        cfg_dm.split_seed = 42

        dm = hydra.utils.instantiate(cfg_dm)
        dm.setup()     
        
        dataloader = dm.test_dataloader()    
        
    return {"dataloader": dataloader}

In [None]:
# Build (or refresh) the test dataloader of every evaluation dataset.
for dataset_name in (
    "clinical_fd_clean_nodup_nooverlap",
    "derm7pt_derm_nodup",
    "derm7pt_clinical_nodup",
    "isic_nodup_nooverlap",
    "allpubmedtextbook",
):
    variable_dict.setdefault(dataset_name, {}).update(setup_dataloader(dataset_name))

In [None]:
# Show current GPU usage (gpustat CLI) before choosing `model_device` below.
!gpustat

# initialize model

In [None]:
# Which fine-tuned checkpoint to load (key of model_path_dir) and the GPU to run on.
model_name, model_device = "zt0n2xd0", "cuda:3"

In [None]:
# Contrastive model config with a ViT-L/14 backbone placed on `model_device`.
cfg_model = omegaconf.OmegaConf.load(root / "configs" / "model" / "contrastive.yaml")
cfg_model["net"]["model_name_or_path"] = "ViT-L/14"
cfg_model["net"]["device"] = model_device
cfg_model

In [None]:
# Instantiate the model from its config, move it to the chosen GPU and switch to eval mode.
model = hydra.utils.instantiate(cfg_model)
model.to(model_device).eval()

In [None]:
# Fine-tuned checkpoint paths keyed by `model_name` (the keys look like W&B run ids).
# Paths are relative, so they resolve against the cwd set by os.chdir(root) above.
model_path_dir = {
    "zt0n2xd0": "logs/train/runs/2023-01-17_20-58-15/checkpoints/last.ckpt",
}
# When model_name is the bare architecture name "ViT-L/14", no checkpoint is loaded.
if model_name != "ViT-L/14":
    model_path = model_path_dir[model_name]
    # NOTE(review): this is a full training checkpoint; torch.load defaults to
    # weights_only=True from torch 2.6 on, which may reject it — confirm when upgrading.
    loaded = torch.load(model_path, map_location=model_device)
    model.load_state_dict(loaded["state_dict"])
    print('loaded')

In [None]:
# torch.save(loaded["state_dict"], "/cse/web/research/aimslab/MONET/weight.pt")

In [None]:
# Un-finetuned ("vanilla") ViT-L/14 baseline built from the same config as `model`,
# without loading any checkpoint. It deliberately reuses `model_name`/`model_device`
# from the config cell instead of re-assigning them: re-assigning here silently undid
# any change made there, and `batch_func` reads the global `model_device`.
cfg_model = omegaconf.OmegaConf.load(root / "configs" / "model" / "contrastive.yaml")
cfg_model.net.model_name_or_path = "ViT-L/14"
cfg_model.net.device = model_device

model_vanilla = hydra.utils.instantiate(cfg_model)
model_vanilla.to(model_device)
model_vanilla.eval()

# Image features

In [None]:
def batch_func(batch):
    """Encode one dataloader batch with the global ``model``.

    Moves ``batch["image"]`` to the global ``model_device`` (this mutates ``batch``),
    runs ``model.model_step_with_image`` without gradient tracking, and returns the
    resulting image features on the CPU alongside the batch's ``metadata`` as-is.
    """
    with torch.no_grad():
        batch["image"] = batch["image"].to(model_device)
        step_output = model.model_step_with_image(batch)
        features = step_output["image_features"]
    return {
        "image_features": features.detach().cpu(),
        "metadata": batch["metadata"],
    }

In [None]:
def _load_feature_cache(path):
    """Load a ``torch.save``-d feature cache onto the CPU.

    The caches store pandas metadata next to the tensors, which the weights-only
    loader (torch's default since 2.6) rejects, hence ``weights_only=False``.
    That runs the full unpickler: only load caches you produced yourself.
    """
    return torch.load(path, map_location="cpu", weights_only=False)


def setup_features(dataset_name, dataloader, features_dir=None):
    """Return image features and metadata for one dataset.

    ``isic_nodup_nooverlap`` and ``allpubmedtextbook`` are read from pre-computed
    ``<features_dir>/<dataset_name>.pt`` caches (ISIC additionally from
    ``isic_nodup_nooverlap_vanilla.pt``). Every other dataset is encoded on the fly by
    applying ``batch_func`` to ``dataloader``.

    Parameters
    ----------
    dataset_name : str
        Dataset key.
    dataloader
        The dataset's dataloader; ignored for the cached datasets.
    features_dir : str or Path, optional
        Directory holding the cached ``.pt`` files. Defaults to
        ``log_dir / "image_features"`` (global ``log_dir``), resolved only when a
        cache is actually read.

    Returns
    -------
    dict
        ``"image_features"`` (CPU tensor) and ``"metadata_all"``; for ISIC also
        ``"image_features_vanilla"``.

    Raises
    ------
    AssertionError
        If the ISIC fine-tuned and vanilla caches do not share the same metadata index.
    """

    def cache_path(file_stem):
        base_dir = log_dir / "image_features" if features_dir is None else Path(features_dir)
        return base_dir / f"{file_stem}.pt"

    if dataset_name == "isic_nodup_nooverlap":
        cached = _load_feature_cache(cache_path("isic_nodup_nooverlap"))
        cached_vanilla = _load_feature_cache(cache_path("isic_nodup_nooverlap_vanilla"))
        metadata_all = cached["metadata_all"]
        # Index.equals also copes with a length mismatch, where an elementwise
        # `==` comparison would raise ValueError instead of failing the check.
        assert metadata_all.index.equals(cached_vanilla["metadata_all"].index), (
            "fine-tuned and vanilla ISIC feature caches cover different images"
        )
        return {
            "image_features": cached["image_features"].cpu(),
            "image_features_vanilla": cached_vanilla["image_features_vanilla"].cpu(),
            "metadata_all": metadata_all,
        }

    if dataset_name == "allpubmedtextbook":
        cached = _load_feature_cache(cache_path("allpubmedtextbook"))
        return {
            "image_features": cached["image_features"].cpu(),
            "metadata_all": cached["metadata_all"],
        }

    loader_applied = dataloader_apply_func(
        dataloader=dataloader,
        func=batch_func,
        collate_fn=custom_collate_per_key,
    )
    return {
        "image_features": loader_applied["image_features"].cpu(),
        "metadata_all": loader_applied["metadata"],
    }

In [None]:
# Root of training/evaluation artefacts (relative to the project root / cwd).
log_dir = Path() / "logs"

In [None]:
# Attach image features + metadata to each dataset (cached or freshly encoded).
for dataset_name in (
    "clinical_fd_clean_nodup_nooverlap",
    "derm7pt_derm_nodup",
    "isic_nodup_nooverlap",
    "allpubmedtextbook",
):
    print(dataset_name)
    variable_dict[dataset_name].update(
        setup_features(dataset_name, variable_dict[dataset_name]["dataloader"])
    )

In [None]:
# Fallback benign/malignant/indeterminate label per ISIC diagnosis, used by
# isic_map_diagnosis_malignant below when a record has no benign_malignant string.
isic_diagnosis_malignant_mapping=\
{'AIMP':'indeterminate',
'acrochordon':'benign',
'actinic keratosis':'benign',
'angiofibroma or fibrous papule':'benign',
'angiokeratoma':'benign',
'angioma':'benign',
'atypical melanocytic proliferation':'indeterminate',
'atypical spitz tumor':'indeterminate',
'basal cell carcinoma':'malignant',
'cafe-au-lait macule':'benign',
'clear cell acanthoma':'benign',
'dermatofibroma':'benign',
'lentigo NOS':'benign',
'lentigo simplex':'benign',
'lichenoid keratosis':'benign',
'melanoma':'malignant',
'melanoma metastasis':'malignant',
'neurofibroma':'benign',
'nevus':'benign',
'other':'indeterminate',
'pigmented benign keratosis':'benign', # NOTE(review): flagged as uncertain ("??") by the original author
'scar':'benign',
'seborrheic keratosis':'benign',
'solar lentigo':'benign',
'squamous cell carcinoma':'malignant',
'vascular lesion':'benign',
'verruca':'benign'
}

def isic_map_diagnosis_malignant(diagnosis, benign_malignant):
    """Return the benign/malignant label for one ISIC record.

    The record's own ``benign_malignant`` value wins when it is a string; otherwise the
    label is looked up from ``diagnosis`` in ``isic_diagnosis_malignant_mapping``, and a
    missing diagnosis (None/NaN) maps to ``"indeterminate"``.

    Raises
    ------
    RuntimeError
        If ``diagnosis`` is present but not in the mapping.
    """
    if isinstance(benign_malignant, str):
        return benign_malignant
    if diagnosis in isic_diagnosis_malignant_mapping:
        return isic_diagnosis_malignant_mapping[diagnosis]
    # pd.isna (unlike np.isnan) accepts None and strings, so unmapped diagnoses reach
    # the RuntimeError below instead of failing with a TypeError.
    if pd.isna(diagnosis):
        return "indeterminate"
    raise RuntimeError(f"unmapped ISIC diagnosis: {diagnosis!r}")

In [None]:
# Benign/malignant label per Derm7pt diagnosis; "miscellaneous" is labelled "unknown".
derm7pt_diagnosis_malignant_mapping=\
{   'basal cell carcinoma':'malignant',
    'blue nevus': 'benign',
    'clark nevus':'benign',
    'combined nevus': 'benign',
    'congenital nevus': 'benign',
    'dermal nevus': 'benign',
    'dermatofibroma':'benign',
    'lentigo': 'benign',
    'melanoma (in situ)': 'malignant',
    'melanoma (less than 0.76 mm)': 'malignant',
    'melanoma (0.76 to 1.5 mm)': 'malignant',
    'melanoma (more than 1.5 mm)': 'malignant',
    'melanoma metastasis': 'malignant',
    'melanosis': 'benign',
    'miscellaneous': 'unknown',
    'recurrent nevus': 'benign',
    'reed or spitz nevus': 'benign',
    'seborrheic keratosis':'benign',
    'vascular lesion': 'benign',
    'melanoma': 'malignant',
}

def derm7pt_map_diagnosis_malignant(diagnosis):
    """Return the benign/malignant label for one Derm7pt diagnosis.

    Known diagnoses are looked up in ``derm7pt_diagnosis_malignant_mapping``; a missing
    diagnosis (None/NaN) maps to ``"indeterminate"``.

    Raises
    ------
    RuntimeError
        If ``diagnosis`` is present but not in the mapping.
    """
    if diagnosis in derm7pt_diagnosis_malignant_mapping:
        return derm7pt_diagnosis_malignant_mapping[diagnosis]
    # pd.isna (unlike np.isnan) accepts None and strings, so unmapped diagnoses reach
    # the RuntimeError below instead of failing with a TypeError.
    if pd.isna(diagnosis):
        return "indeterminate"
    raise RuntimeError(f"unmapped Derm7pt diagnosis: {diagnosis!r}")

In [None]:
def set_config(dataset_name, metadata_all):
    if "clinical_fd_clean" in dataset_name:
        y_pos=(((metadata_all["source"]=="fitz")&(metadata_all["three_partition_label"]=="malignant"))|
              ((metadata_all["source"]=="ddi")&(metadata_all["malignant"] == True))).values
        
        valid_idx=(metadata_all["skincon_Do not consider this image"]!=1).values
        
        concept_list=skincon_cols
        
    elif "isic" in dataset_name:  
        metadata_all["benign_malignant_full"]=\
        metadata_all.apply(lambda x: isic_map_diagnosis_malignant(x["diagnosis"], x["benign_malignant"]), axis=1)
        #metadata_all["benign_malignant_full"].value_counts()
        #metadata_all.groupby("diagnosis").apply(lambda x: x["benign_malignant_full"].value_counts())
        metadata_all["benign_malignant_bool"]=metadata_all["benign_malignant_full"].str.contains("malignant")
        
        y_pos=metadata_all["benign_malignant_bool"].values
        
        valid_idx = (metadata_all["benign_malignant_full"].str.contains("malignant")|metadata_all["benign_malignant_full"].str.contains("benign")).values
        
        concept_list=skincon_cols
        
        concept_list=concept_list+["purple pen", 
                     "nail", 
                     "pinkish", 
                     "red", 
                     "hair", 
                     "orange sticker", 
                     "dermoscope border",
                     "gel",]      
        
        
        concept_list=concept_list+[f"derm7ptconcept_{derm7ptconcept}" for derm7ptconcept in ["pigment network", "typical pigment network", "atypical pigment network",
                                   "regression structure",
                                   "pigmentation", "regular pigmentation", "irregular pigmentation",
                                   "blue whitish veil", 
                                   "vascular structures", "typical vascular structures", "atypical vascular structures",
                                   "streaks", "regular streaks", "irregular streaks",
                                   "dots and globules", "regular dots and globules", "irregular dots and globules",
                                  ]]
        
        concept_list=concept_list+[f"isicconcept_{isicconcept}" for isicconcept in ["pigment_network", 
                                                                                   "negative_network",
                                                                                   "milia_like_cyst", 
                                                                                   "streaks", 
                                                                                   "globules"]]
        
        concept_list=concept_list+[f"disease_{disease_name}" for disease_name in ['seborrheic keratosis', 'nevus', 'squamous cell carcinoma',
                        'melanoma', 'lichenoid keratosis', 'lentigo',
                        'actinic keratosis', 'basal cell carcinoma', 'dermatofibroma',
                        'atypical melanocytic proliferation', 'verruca',
                        'clear cell acanthoma', 'angiofibroma or fibrous papule', 'scar',
                        'angioma', 'atypical spitz tumor', 'solar lentigo', 'AIMP',
                        'neurofibroma', 'lentigo simplex', 'acrochordon', 
                        'angiokeratoma', 'vascular lesion', 'cafe-au-lait macule',
                        'pigmented benign keratosis']]

        
    elif "derm7pt_clinical" in dataset_name:  
        metadata_all["benign_malignant_full"]=\
        metadata_all.apply(lambda x: derm7pt_map_diagnosis_malignant(x["diagnosis"]), axis=1)
        metadata_all["benign_malignant_bool"]=metadata_all["benign_malignant_full"].str.contains("malignant")
        
        y_pos=metadata_all["benign_malignant_bool"].values
        
        valid_idx = (~metadata_all["diagnosis"].isnull()).values
        
        concept_list=skincon_cols        
        
    elif "derm7pt_derm" in dataset_name:  
        metadata_all["benign_malignant_full"]=\
        metadata_all.apply(lambda x: derm7pt_map_diagnosis_malignant(x["diagnosis"]), axis=1)
        metadata_all["benign_malignant_bool"]=metadata_all["benign_malignant_full"].str.contains("malignant")
        
        y_pos=metadata_all["benign_malignant_bool"].values
        
        valid_idx = (~metadata_all["diagnosis"].isnull()).values
        
        concept_list=skincon_cols
        
        concept_list=concept_list+["purple pen", 
                             "nail", 
                             "pinkish", 
                             "red", 
                             "hair", 
                             "orange sticker", 
                             "dermoscope border",
                             "gel",]    
        
        concept_list=concept_list+[f"derm7ptconcept_{derm7ptconcept}" for derm7ptconcept in ["pigment network", "typical pigment network", "atypical pigment network",
                                   "regression structure",
                                   "pigmentation", "regular pigmentation", "irregular pigmentation",
                                   "blue whitish veil", 
                                   "vascular structures", "typical vascular structures", "atypical vascular structures",
                                   "streaks", "regular streaks", "irregular streaks",
                                   "dots and globules", "regular dots and globules", "irregular dots and globules",
                                  ]]
        
        concept_list=concept_list+[f"isicconcept_{isicconcept}" for isicconcept in ["pigment_network", 
                                                                                   "negative_network",
                                                                                   "milia_like_cyst", 
                                                                                   "streaks", 
                                                                                   "globules"]]
             
        
        concept_list=concept_list+[f"disease_{disease_name}" for disease_name in ['basal cell carcinoma', 'blue nevus', 'clark nevus',
                                                               'combined nevus', 'congenital nevus', 'dermal nevus',
                                                               'dermatofibroma', 'lentigo', 'melanoma', 'melanosis',
                                                                'recurrent nevus', 'reed or spitz nevus',
                                                               'seborrheic keratosis', 'vascular lesion']]    
        
        
    elif "allpubmedtextbook" in dataset_name:  
        y_pos=None
        
        valid_idx = (~variable_dict["allpubmedtextbook"]["metadata_all"].index.isnull())
        
        concept_list=None         
        
    return {"valid_idx": valid_idx,
            "y_pos": y_pos,
            "metadata_all": metadata_all,
            "concept_list": concept_list}

In [None]:
# NOTE(review): this rebinds `display` — previously the IPython.display module
# (`from IPython import display` at the top) — to the display() function; any later
# code must use `display` as a function.
from IPython.display import display, HTML
# Widen the notebook's main container to 80% of the browser window.
display(HTML("<style>.container { width:80% !important; }</style>"))

In [None]:
# Attach labels, valid-row masks and concept lists to each dataset's state.
for dataset_name in (
    "clinical_fd_clean_nodup_nooverlap",
    "derm7pt_derm_nodup",
    "isic_nodup_nooverlap",
    "allpubmedtextbook",
):
    variable_dict[dataset_name].update(
        set_config(
            dataset_name,
            variable_dict[dataset_name]["metadata_all"],
        )
    )

In [None]:
def normalize_embedding(dataset_name, image_features):
    """L2-normalize every row of ``image_features``.

    ``dataset_name`` is accepted but not used. No epsilon is applied, so an all-zero
    row yields NaNs.

    Returns
    -------
    dict
        ``{"image_features_norm": image_features / row_l2_norm}`` (same shape as input).
    """
    row_norms = image_features.norm(dim=1, keepdim=True)
    return {"image_features_norm": image_features / row_norms}

In [None]:
# Add L2-normalized image features ("image_features_norm") to each dataset's state.
for dataset_name in (
    "clinical_fd_clean_nodup_nooverlap",
    "derm7pt_derm_nodup",
    "isic_nodup_nooverlap",
    "allpubmedtextbook",
):
    variable_dict[dataset_name].update(
        normalize_embedding(
            dataset_name,
            variable_dict[dataset_name]["image_features"],
        )
    )

In [None]:
# ISIC also carries features from the vanilla (un-finetuned) model: store the
# normalized vanilla features, then (re)store the normalized fine-tuned ones.
for dataset_name in ["isic_nodup_nooverlap"]:
    variable_dict[dataset_name].update(
        image_features_vanilla_norm=normalize_embedding(
            dataset_name, variable_dict[dataset_name]["image_features_vanilla"]
        )["image_features_norm"]
    )
    variable_dict[dataset_name].update(
        image_features_norm=normalize_embedding(
            dataset_name, variable_dict[dataset_name]["image_features"]
        )["image_features_norm"]
    )

In [None]:
# Prompt templates and the matching "no concept" reference sentences.
_CLINICAL_TEMPLATES = ["This is skin image of {}", "This is dermatology image of {}", "This is image of {}"]
_CLINICAL_REFS = ["This is skin image", "This is dermatology image", "This is image"]
_DERMOSCOPY_TEMPLATES = ["This is dermatoscopy of {}", "This is dermoscopy of {}"]
_DERMOSCOPY_REFS = ["This is dermatoscopy", "This is dermoscopy"]

_STREAK_TERMS = ["streaks", "starburst", "linear patterns", "radially oriented linear projections",
                 "regular, pigmented linear extensions", "irregular, pigmented linear extensions"]

# Derm7pt concept (text after "derm7ptconcept_") -> descriptive terms used in prompts.
_DERM7PT_CONCEPT_TERMS = {
    "pigment network": ["pigment network", "intersecting brown lines"],
    "typical pigment network": ["typical pigment network", "regularly meshed pigment network"],
    "atypical pigment network": ["atypical pigment network", "irregularly meshed pigment network"],
    "regression structure": ["regression structure"],
    "pigmentation": ["pigmented", "pigmented lesion", "colored lesion"],
    "regular pigmentation": ["regular pigmentation", "uniform and consistent coloration"],
    "irregular pigmentation": ["irregular pigmentation"],
    "blue whitish veil": ["blue whitish veil", "blue white veil"],
    "vascular structures": ["vascular structures", "Hairpin vessels", "Comma vessels",
                            "dotted vessels", "arborizing vessels"],
    "typical vascular structures": ["typical vascular structures"],
    "atypical vascular structures": ["atypical vascular structures"],
    "streaks": _STREAK_TERMS,
    "regular streaks": ["regular streaks", "uniformly spaced linear patterns"],
    "irregular streaks": ["irregular streaks"],
    "dots and globules": ["black dots and globules", "brown dots and globules", "scattered globules"],
    "regular dots and globules": ["regular dots and globules"],
    "irregular dots and globules": ["irregular dots and globules"],
}

# ISIC concept (text after "isicconcept_") -> descriptive terms used in prompts.
_ISIC_CONCEPT_TERMS = {
    "pigment_network": ["pigment network"],
    "negative_network": ["negative network"],
    # An earlier version prompted with "milia like cyst"; the current prompt uses this term instead.
    "milia_like_cyst": ["seborrheic keratosis"],
    "streaks": _STREAK_TERMS,
    "globules": ["globules"],
}

# Only the ISIC recipe expands the AIMP acronym; derm7pt dermoscopy prompts use "AIMP" verbatim.
_AIMP_TERMS = ["AIMP", "Atypical intraepidermal melanocytic proliferation"]


def _skincon_terms(skincon_name):
    """Expand a SkinCon concept via concept_to_prompt into sorted, de-duplicated, lower-cased terms."""
    prompt_dict, _ = concept_to_prompt(skincon_name)
    terms = {
        prompt.replace("This is ", "").replace("This photo is ", "")
              .replace("This lesion is ", "").replace("skin has become ", "").lower()
        for key, prompts in prompt_dict.items() if key != "original"
        for prompt in prompts
    }
    # sorted() rather than list(set(...)): set order varies with the per-process string
    # hash seed, which made printed prompts and float rounding differ between kernels.
    return sorted(terms)


def _lookup_terms(table, key):
    """Return table[key], raising ValueError(key) for an unknown concept."""
    if key not in table:
        raise ValueError(key)
    return table[key]


def _per_template_prompts(terms, templates, refs):
    """One target group per template (covering all terms), each paired with its own one-sentence reference group."""
    target = [[template.format(term) for term in terms] for template in templates]
    ref = [[ref_sentence] for ref_sentence in refs]
    return target, ref


def _pooled_dermoscopy_prompts(terms):
    """A single target group of every (template, term) prompt against one group of all dermoscopy references."""
    target = [[template.format(term) for template in _DERMOSCOPY_TEMPLATES for term in terms]]
    return target, [list(_DERMOSCOPY_REFS)]


def _dermoscopy_prompts(concept_name, expand_aimp):
    """(prompt_target, prompt_ref) for a concept on a dermoscopic dataset (ISIC / derm7pt dermoscopy)."""
    if concept_name.startswith("skincon_"):
        return _pooled_dermoscopy_prompts(_skincon_terms(concept_name[len("skincon_"):]))
    if concept_name.startswith("derm7ptconcept_"):
        return _pooled_dermoscopy_prompts(
            _lookup_terms(_DERM7PT_CONCEPT_TERMS, concept_name[len("derm7ptconcept_"):]))
    if concept_name.startswith("isicconcept_"):
        return _pooled_dermoscopy_prompts(
            _lookup_terms(_ISIC_CONCEPT_TERMS, concept_name[len("isicconcept_"):]))
    if concept_name.startswith("disease_"):
        if expand_aimp and concept_name == "disease_AIMP":
            terms = _AIMP_TERMS
        else:
            terms = [concept_name[len("disease_"):]]
        return _per_template_prompts(terms, _DERMOSCOPY_TEMPLATES, _DERMOSCOPY_REFS)
    if concept_name == "dermoscope border":
        return [["This is dermatoscopy of dermoscopy"]], [["This is dermatoscopy"]]
    # "gel" and any other free-text concept: one prompt per template, each in its own group.
    return _per_template_prompts([concept_name], _DERMOSCOPY_TEMPLATES, _DERMOSCOPY_REFS)


def _concept_prompts(dataset_name, concept_name):
    """Choose the prompt recipe by dataset-name substring (checked in this order) and build its prompts."""
    if "clinical_fd_clean" in dataset_name:
        # Prefix is stripped unconditionally: clinical concepts are expected to be "skincon_*".
        return _per_template_prompts(_skincon_terms(concept_name[len("skincon_"):]),
                                     _CLINICAL_TEMPLATES, _CLINICAL_REFS)
    if "isic" in dataset_name:
        return _dermoscopy_prompts(concept_name, expand_aimp=True)
    if "derm7pt_clinical" in dataset_name:
        return _per_template_prompts(_skincon_terms(concept_name[len("skincon_"):]),
                                     _CLINICAL_TEMPLATES, _CLINICAL_REFS)
    if "derm7pt_derm" in dataset_name:
        return _dermoscopy_prompts(concept_name, expand_aimp=False)
    raise ValueError(f"No prompt recipe for dataset {dataset_name!r}")


def _embed_prompt_groups(clip_model, prompt_groups, device):
    """Tokenise and encode each prompt group, stack the groups, and L2-normalise along dim 2."""
    tokenized_groups = [clip.tokenize(group, truncate=True) for group in prompt_groups]
    with torch.no_grad():
        embeddings = torch.stack([
            clip_model.model_step_with_text({"text": tokens.to(device)})["text_features"].detach().cpu()
            for tokens in tokenized_groups
        ])
    return embeddings / embeddings.norm(dim=2, keepdim=True)


def get_concept_embedding(dataset_name, concept_list, clip_model, device=None):
    """Encode target and reference text prompts for every concept with a CLIP-style model.

    Args:
        dataset_name: Selects the prompt recipe by substring match, checked in order:
            "clinical_fd_clean", "isic", "derm7pt_clinical", "derm7pt_derm".
        concept_list: Concept names such as "skincon_<x>", "derm7ptconcept_<x>",
            "isicconcept_<x>", "disease_<x>", or free text (e.g. "gel").
        clip_model: Object exposing model_step_with_text({"text": tokens}) that returns
            a dict with "text_features".
        device: Where token tensors are moved before encoding. None (the default)
            uses the notebook-global `model_device`, as before.

    Returns:
        {"prompt_info": {concept_name: {"prompt_ref_embedding_norm": Tensor,
                                        "prompt_target_embedding_norm": Tensor}}}
        Each tensor stacks one embedding block per prompt group along dim 0. It is
        presumably shaped (n_groups, n_prompts, dim), assuming text_features is
        (n_prompts, dim) -- confirm.

    Raises:
        ValueError: For an unknown dataset_name, or an unknown derm7pt/ISIC concept.
    """
    prompt_info = {}
    for concept_name in concept_list:
        # Resolved lazily so an empty concept_list never needs `model_device`.
        resolved_device = model_device if device is None else device
        prompt_target, prompt_ref = _concept_prompts(dataset_name, concept_name)
        prompt_target_embedding_norm = _embed_prompt_groups(clip_model, prompt_target, resolved_device)
        prompt_ref_embedding_norm = _embed_prompt_groups(clip_model, prompt_ref, resolved_device)

        prompt_info[concept_name] = {"prompt_ref_embedding_norm": prompt_ref_embedding_norm,
                                     "prompt_target_embedding_norm": prompt_target_embedding_norm,
                                     }
        print(dataset_name, concept_name, "\n", prompt_target, prompt_ref)
        print('-----------')

    return {"prompt_info": prompt_info}

In [None]:
# Text-prompt embeddings from `model` for every concept of each dataset, stored under "prompt_info".
for dataset_name in ("clinical_fd_clean_nodup_nooverlap",
                     "derm7pt_derm_nodup",
                     "isic_nodup_nooverlap"):
    variable_dict[dataset_name].update({
        "prompt_info": get_concept_embedding(
            dataset_name,
            concept_list=variable_dict[dataset_name]["concept_list"],
            clip_model=model,
        )["prompt_info"],
    })

In [None]:
# Same prompts for ISIC, encoded with `model_vanilla`, stored under "prompt_info_vanilla".
for dataset_name in ("isic_nodup_nooverlap",):
    variable_dict[dataset_name].update({
        "prompt_info_vanilla": get_concept_embedding(
            dataset_name,
            concept_list=variable_dict[dataset_name]["concept_list"],
            clip_model=model_vanilla,
        )["prompt_info"],
    })

In [None]:
# Count ISIC images per (attribution, copyright_license) pair; missing values are shown as -9.
variable_dict["isic_nodup_nooverlap"]["metadata_all"][["attribution", "copyright_license"]].fillna(-9).value_counts().to_frame()

In [None]:
# Merge the "ViDIR group" / "ViDIR Group" spellings into one attribution.
# Overwrites the column in place; re-running is harmless because the replacement is idempotent.
variable_dict["isic_nodup_nooverlap"]["metadata_all"]["attribution"]=variable_dict["isic_nodup_nooverlap"]["metadata_all"]["attribution"].str.replace("ViDIR group", "ViDIR Group")

In [None]:
# Re-count after merging the two ViDIR spellings.
variable_dict["isic_nodup_nooverlap"]["metadata_all"][["attribution", "copyright_license"]].fillna(-9).value_counts().to_frame()

In [None]:
# Show the 4th most frequent attribution string (value_counts sorts descending; index 3).
variable_dict["isic_nodup_nooverlap"]["metadata_all"]["attribution"].value_counts().index[3]

In [None]:
# Show the 7th most frequent attribution string (value_counts sorts descending; index 6).
variable_dict["isic_nodup_nooverlap"]["metadata_all"]["attribution"].value_counts().index[6]

In [None]:
# List the contents of the ISIC data directory.
!ls data/isic

In [None]:
import scipy.stats

In [None]:
_DERM7PT_TYPICAL_VESSELS = ["within regression", "arborizing", "comma", "hairpin", "wreath"]
_DERM7PT_ATYPICAL_VESSELS = ["dotted", "linear irregular"]

# derm7pt concept name -> (metadata column, predicate marking rows where the concept is present).
_DERM7PT_CONCEPT_RULES = {
    "derm7ptconcept_pigment network": ("pigment_network", lambda s: s != "absent"),
    "derm7ptconcept_typical pigment network": ("pigment_network", lambda s: s == "typical"),
    "derm7ptconcept_atypical pigment network": ("pigment_network", lambda s: s == "atypical"),
    "derm7ptconcept_regression structure": ("regression_structures", lambda s: s != "absent"),
    "derm7ptconcept_pigmentation": ("pigmentation", lambda s: s != "absent"),
    # The leading space stops " regular" from matching inside "irregular". This presumes
    # values look like "<extent> regular" -- confirm against the derm7pt metadata.
    "derm7ptconcept_regular pigmentation": ("pigmentation", lambda s: s.str.contains(" regular", na=False)),
    "derm7ptconcept_irregular pigmentation": ("pigmentation", lambda s: s.str.contains(" irregular", na=False)),
    "derm7ptconcept_blue whitish veil": ("blue_whitish_veil", lambda s: s != "absent"),
    "derm7ptconcept_vascular structures": ("vascular_structures", lambda s: s != "absent"),
    "derm7ptconcept_typical vascular structures": ("vascular_structures",
                                                   lambda s: s.isin(_DERM7PT_TYPICAL_VESSELS)),
    "derm7ptconcept_atypical vascular structures": ("vascular_structures",
                                                    lambda s: s.isin(_DERM7PT_ATYPICAL_VESSELS)),
    "derm7ptconcept_streaks": ("streaks", lambda s: s != "absent"),
    "derm7ptconcept_regular streaks": ("streaks", lambda s: s == "regular"),
    "derm7ptconcept_irregular streaks": ("streaks", lambda s: s == "irregular"),
    "derm7ptconcept_dots and globules": ("dots_and_globules", lambda s: s != "absent"),
    "derm7ptconcept_regular dots and globules": ("dots_and_globules", lambda s: s == "regular"),
    "derm7ptconcept_irregular dots and globules": ("dots_and_globules", lambda s: s == "irregular"),
}

# ISIC concept name -> metadata column holding its numeric score (-9 = not annotated).
_ISIC_CONCEPT_COLUMNS = {
    "isicconcept_pigment_network": "pigment_network",
    "isicconcept_negative_network": "negative_network",
    "isicconcept_milia_like_cyst": "milia_like_cyst",
    "isicconcept_streaks": "streaks",
    "isicconcept_globules": "globules",
}


def get_concept_bool_from_metadata(dataset_name, metadata_all, concept_name, isic_threshold=30):
    """Build positive/negative row masks for a concept from a dataset's metadata table.

    Rows whose annotation is missing are excluded from BOTH masks. For derm7pt this means
    null cells; for ISIC it means the -9 sentinel. The clinical masks use the exact values
    1 and 0, so any other value also falls into neither mask.

    Args:
        dataset_name: Dispatch by substring, checked in order: "derm7pt", "isic", "clinical_fd_clean".
        metadata_all: pandas DataFrame with one row per image.
        concept_name: "derm7ptconcept_*", "isicconcept_*" or "skincon_*", matching the dataset.
        isic_threshold: An ISIC concept counts as present when its score is strictly greater
            than this value (default 30, as before).

    Returns:
        {"concept_bool_true": bool Series, "concept_bool_false": bool Series}

    Raises:
        ValueError: For an unknown concept for the dataset, or an unknown dataset.
    """
    if "derm7pt" in dataset_name:
        if concept_name not in _DERM7PT_CONCEPT_RULES:
            raise ValueError(concept_name)
        column, is_present = _DERM7PT_CONCEPT_RULES[concept_name]
        values = metadata_all[column]
        valid = values.notna()
        concept_bool = is_present(values)
    elif "isic" in dataset_name:
        if concept_name not in _ISIC_CONCEPT_COLUMNS:
            raise ValueError(concept_name)
        label = metadata_all[_ISIC_CONCEPT_COLUMNS[concept_name]]
        valid = label != -9
        concept_bool = label > isic_threshold
    elif "clinical_fd_clean" in dataset_name:
        if not concept_name.startswith("skincon_"):
            raise ValueError(concept_name)
        return {"concept_bool_true": metadata_all[concept_name] == 1,
                "concept_bool_false": metadata_all[concept_name] == 0,
                }
    else:
        raise ValueError(dataset_name)

    # Apply the validity mask so unannotated rows count as neither positive nor negative.
    return {"concept_bool_true": concept_bool & valid,
            "concept_bool_false": ~concept_bool & valid,
            }

In [None]:
# Show which entries have been stored for the derm7pt dermoscopy dataset so far.
variable_dict["derm7pt_derm_nodup"].keys()

In [None]:
# List saved experiment snapshots.
!ls logs/experiment_results/

In [None]:
# Load the 0823 snapshot of variable_dict onto the CPU.
# NOTE(review): this replaces everything computed above in the current kernel, and torch.load
# unpickles the file, so only load trusted snapshots.
variable_dict=torch.load("logs/experiment_results/concept_annotation_data_auditing_0823.pt", map_location="cpu")

In [None]:
# torch.save(variable_dict, "logs/experiment_results/concept_annotation_data_auditing_0823.pt")

In [None]:
# torch.save(variable_dict, "logs/experiment_results/data_audit_new_0429.pt")

In [None]:
# NOTE(review): loads an older snapshot and again overwrites variable_dict. The next cells index
# it with the key "isic" rather than "isic_nodup_nooverlap", so the schema appears to differ.
variable_dict=torch.load("logs/experiment_results/data_audit_new.pt", map_location="cpu")

In [None]:
# Inspect the normalised ISIC image features stored in the loaded snapshot.
variable_dict["isic"]["image_features_norm"]

In [None]:
# Inspect the stored target-prompt embedding for the SkinCon "Vesicle" concept.
variable_dict["isic"]["prompt_info"]["skincon_Vesicle"]["prompt_target_embedding_norm"]

In [None]:
# List saved snapshots, most recently modified first.
! ls -lt logs/experiment_results/

In [None]:
# Load the 0930 snapshot; it overwrites any variable_dict loaded or computed earlier in the kernel.
# torch.load unpickles the file, so only load trusted snapshots.
variable_dict=torch.load("logs/experiment_results/concept_annotation_data_auditing_0930.pt", map_location="cpu")

# Concept retrieval

In [None]:
# def calculate_similaity_score(image_features_norm, 
#                               prompt_target_embedding_norm,
#                               prompt_ref_embedding_norm,
#                               normalize=True):

#     target_similarity=prompt_target_embedding_norm.float()@image_features_norm.T.float()
#     ref_similarity=prompt_ref_embedding_norm.float()@image_features_norm.T.float()
    
    
#     target_similarity_mean=target_similarity.mean(dim=[1])
#     ref_similarity_mean=ref_similarity.mean(axis=1)
         
#     if normalize:
#         similarity_score=scipy.special.softmax([target_similarity_mean.numpy(), 
#                             ref_similarity_mean.numpy()], axis=0)[0,:].mean(axis=0)   
#     else:
#         similarity_score=target_similarity_mean.mean(axis=0)

    
#     return similarity_score

In [None]:
def calculate_similaity_score(image_features_norm, 
                              prompt_target_embedding_norm,
                              prompt_ref_embedding_norm,
                              temp=1,
                              normalize=True):
    """Score images by their similarity to a concept's target prompts versus its reference prompts.

    Similarities are the matrix products ``prompt_embedding @ image_features.T``
    (computed in float32), averaged over axis 1 of the result. The ``_norm`` suffix
    suggests both inputs are already L2-normalised, which would make these cosine
    similarities -- TODO confirm against where ``variable_dict`` is built.

    NOTE(review): callers wrap the result in ``pd.Series`` and mask it with a
    per-image ``valid_idx``. That implies the prompt embeddings are 3-D (e.g.
    ``(n_prompt_groups, n_templates, dim)``), so the returned score has one entry
    per image -- confirm.

    Args:
        image_features_norm: image embeddings, one row per image.
        prompt_target_embedding_norm: embeddings of the prompts describing the concept.
        prompt_ref_embedding_norm: embeddings of the reference prompts. Once averaged
            over axis 1, it must have the same shape as the averaged target
            similarities. Unused when ``normalize`` is False.
        temp: softmax temperature; similarities are divided by it before the softmax.
        normalize: if True, return the softmax probability of "target" against
            "reference", averaged over axis 0 (a NumPy array). If False, return the
            raw target similarity averaged over axis 0 (a torch tensor).

    Returns:
        The score array/tensor described under ``normalize``.
    """
    # `scipy` is not among the notebook's top-level imports; import it here so the
    # function does not depend on an import happening somewhere else.
    import scipy.special

    image_features = image_features_norm.float()
    target_similarity_mean = (prompt_target_embedding_norm.float() @ image_features.T).mean(dim=1)

    if not normalize:
        return target_similarity_mean.mean(dim=0)

    ref_similarity_mean = (prompt_ref_embedding_norm.float() @ image_features.T).mean(dim=1)
    # Two-way softmax between target and reference at every position; row 0 is P(target).
    target_probability = scipy.special.softmax(
        [target_similarity_mean.numpy() / temp, ref_similarity_mean.numpy() / temp],
        axis=0,
    )[0, :]
    return target_probability.mean(axis=0)

In [None]:
# Quick check of the display-name mapping.
# NOTE(review): shorten_concept_name is defined in the next cell; on a fresh
# top-to-bottom run this raises NameError unless it was already defined earlier.
shorten_concept_name("skincon_Brown(Hyperpigmentation)")

In [None]:
def shorten_concept_name(concept_name):
    """Return the label used in figures for a concept identifier.

    Concepts with a hand-picked label are looked up in a table. Otherwise one of
    the source prefixes "skincon_", "derm7ptconcept_" or "isicconcept_" is
    stripped. Any other name is returned unchanged.
    """
    display_names = {
        "skincon_Erythema": "Erythema",
        "skincon_Bulla": "Bulla",
        "skincon_Lichenification": "Lichenification",
        "skincon_Pustule": "Pustule",
        "skincon_Ulcer": "Ulcer",
        "skincon_Warty/Papillomatous": "Warty",
        "skincon_White(Hypopigmentation)": "Hypopigmentation",
        "skincon_Brown(Hyperpigmentation)": "Hyperpigmentation",
        "purple pen": "Purple pen",
        "nail": "Nail",
        "orange sticker": "Orange sticker",
        "hair": "Hair",
        "gel": "Gel",
        "dermoscope border": "Dermoscopic border",
        "derm7ptconcept_pigment network": "Pigment network",
        "derm7ptconcept_regression structure": "Regression structure",
        "derm7ptconcept_pigmentation": "Pigmentation",
        "derm7ptconcept_blue whitish veil": "Blue whitish veil",
        "derm7ptconcept_vascular structures": "Vascular structures",
        "derm7ptconcept_streaks": "Streaks",
        "derm7ptconcept_dots and globules": "Dots and globules",
    }
    if concept_name in display_names:
        return display_names[concept_name]

    # Fall back to dropping the dataset prefix, if any.
    for prefix in ("skincon_", "derm7ptconcept_", "isicconcept_"):
        if concept_name.startswith(prefix):
            return concept_name[len(prefix):]
    return concept_name

In [None]:
def check_image_default(dataset_name, idx):
    """Decide whether an image may be shown in a figure, printing a note where relevant.

    Reads the notebook-global ``variable_dict`` for ISIC datasets.

    Args:
        dataset_name: dataset key; the rules below match it by substring.
        idx: index label of the image in ``variable_dict[dataset_name]["metadata_all"]``.

    Returns:
        True to keep the image, False to skip it.

    Raises:
        ValueError: for an ISIC image whose licence, or whose CC-BY-NC attribution,
            is not one of the cases handled below.
    """
    # Clinical images flagged by hand as not appropriate for display.
    inappropriate_clinical = (
        "58b4bc079ca94e6e9377a42ca7564b40.jpg",
        "720cf31558966c82c118ab75b50632eb.jpg",
        "5f046cda32a3cc547205662e7be774f9.jpg",
        "d8bf377acc45a3beb0c6e81bf7ac1ff5.jpg",
    )
    # CC-BY-NC attribution -> (message to print, keep the image?).
    noncommercial_policy = {
        "MSKCC": ("Excluded CC-BY-NC MSKCC", False),
        "Pascale Guitera": ("Excluded CC-BY-NC Pascale", False),
        "Hospital Clínic de Barcelona": ("Included but note that CC-BY-NC Barcelona", True),
        "ViDIR Group, Department of Dermatology, Medical University of Vienna": ("Included but note that CC-BY-NC Vienna", True),
        "Department of Dermatology, Hospital Clínic de Barcelona": ("Included but note that CC-BY-NC Barcelona", True),
    }

    if "clinical_fd_clean_nodup_nooverlap" in dataset_name:
        keep = idx not in inappropriate_clinical
        if not keep:
            print("Excluded not appropriate image", dataset_name, idx)
        return keep

    if "isic" in dataset_name:
        metadata_row = variable_dict[dataset_name]["metadata_all"].loc[idx]
        license_type = metadata_row["copyright_license"]
        attribution = metadata_row["attribution"]

        if license_type == "CC-BY-NC":
            if attribution not in noncommercial_policy:
                raise ValueError(attribution)
            message, keep = noncommercial_policy[attribution]
            print(message, dataset_name, idx)
            return keep
        if license_type == "CC-0":
            return True
        if license_type == "CC-BY":
            print(f"Included but note that CC-BY {attribution}", dataset_name, idx)
            return True
        raise ValueError(license_type)

    if "derm7pt" in dataset_name:
        print("Included but note that permission", dataset_name, idx)
        return True

    return True

In [None]:
def check_image_ddionly(dataset_name, idx):
    """Like ``check_image_default``, but keep only DDI images from the clinical dataset.

    For a dataset whose name contains "clinical_fd_clean_nodup_nooverlap", an image
    is kept iff its ``source`` metadata equals "ddi". The hand-picked exclusion list
    of ``check_image_default`` is not applied here. Every other dataset is handled
    by ``check_image_default``: the ISIC licence/attribution policy, the derm7pt
    note, and keep-everything-else.

    Reads the notebook-global ``variable_dict``.

    Args:
        dataset_name: dataset key; matched by substring.
        idx: index label of the image in ``variable_dict[dataset_name]["metadata_all"]``.

    Returns:
        True to keep the image, False to skip it.

    Raises:
        ValueError: propagated from ``check_image_default`` for unhandled ISIC
            licences/attributions.
    """
    if "clinical_fd_clean_nodup_nooverlap" in dataset_name:
        # Use the dataset that was actually passed in rather than a hard-coded key,
        # so a variant whose name merely contains the clinical key reads its own metadata.
        source = variable_dict[dataset_name]["metadata_all"].loc[idx, "source"]
        return bool(source == "ddi")

    # The remaining rules were a verbatim copy of check_image_default's. Delegate to it
    # so the two filters cannot drift apart.
    return check_image_default(dataset_name, idx)

In [None]:
# Inspect the clinical dataset's metadata table (check_image_ddionly filters it on the "source" column).
variable_dict["clinical_fd_clean_nodup_nooverlap"]["metadata_all"]

In [None]:
# Attributions of the CC-BY-NC ISIC images. Each one must be handled in the CC-BY-NC
# branch of the check_image_* filters; otherwise they raise ValueError.
variable_dict["isic_nodup_nooverlap"]["metadata_all"][variable_dict["isic_nodup_nooverlap"]["metadata_all"]["copyright_license"]=="CC-BY-NC"]\
["attribution"].value_counts()

In [None]:
# Licence mix of the ISIC images; the check_image_* filters accept only CC-BY-NC, CC-0 and CC-BY.
variable_dict["isic_nodup_nooverlap"]["metadata_all"]["copyright_license"].value_counts()

In [None]:
def plot_main(dataset_concept_list_list, check_image, fontsize=30, normalize=True):
    """Plot the top-10 retrieved images for each (dataset, concept) pair, one row per pair.

    Images are ranked by ``calculate_similaity_score`` over the dataset's valid images
    (``variable_dict[dataset_name]["valid_idx"]``). Images rejected by
    ``check_image`` are reported with an 'error' line and replaced by the next-ranked
    image. Reads the notebook-global ``variable_dict``.

    Args:
        dataset_concept_list_list: list of panels, labelled A, B, C, ... in order.
            Each panel is a list of ``(dataset_name, concept_name)`` rows.
        check_image: predicate ``(dataset_name, metadata_index) -> bool``.
        fontsize: font size of the concept label at the start of each row.
        normalize: forwarded to ``calculate_similaity_score``.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if fewer than 10 images pass ``check_image`` for a row.
    """
    example_per_concept = 10
    n_cols = 10
    n_rows_per_concept = example_per_concept // n_cols
    # NOTE(review): 4.5944 looks like a learned log logit scale (exp(4.5944) ~ 98.9),
    # i.e. temp ~ 1/98.9 -- confirm where the constant comes from.
    temp = 1 / np.exp(4.5944)

    n_concepts = sum(len(panel) for panel in dataset_concept_list_list)
    fig = plt.figure(figsize=(3 * n_cols, 3 * n_rows_per_concept * n_concepts))
    outer_grid = gridspec.GridSpec(len(dataset_concept_list_list), 1, wspace=0.0, hspace=0.05)

    # Create every axis up front so the layout does not depend on which images are found.
    axd = {}
    for idx1, dataset_concept_list in enumerate(dataset_concept_list_list):
        panel_grid = gridspec.GridSpecFromSubplotSpec(len(dataset_concept_list), 1,
                                                      subplot_spec=outer_grid[idx1], wspace=0.0, hspace=0.1)
        for idx2, (dataset_name, concept_name) in enumerate(dataset_concept_list):
            row_grid = gridspec.GridSpecFromSubplotSpec(n_rows_per_concept, n_cols,
                                                        subplot_spec=panel_grid[idx2], wspace=0, hspace=0.1)
            for rank_num in range(example_per_concept):
                axd[(idx1, dataset_name, concept_name, rank_num)] = fig.add_subplot(row_grid[rank_num])

    for idx1, dataset_concept_list in enumerate(dataset_concept_list_list):
        for idx2, (dataset_name, concept_name) in enumerate(dataset_concept_list):
            data = variable_dict[dataset_name]
            prompt_info = data["prompt_info"][concept_name]
            similaity_score = calculate_similaity_score(
                image_features_norm=data["image_features_norm"],
                prompt_target_embedding_norm=prompt_info["prompt_target_embedding_norm"],
                prompt_ref_embedding_norm=prompt_info["prompt_ref_embedding_norm"],
                temp=temp,
                normalize=normalize)
            # Positions (within the dataset) of the valid images, best score first.
            image_idx_list = pd.Series(similaity_score)[data["valid_idx"]].sort_values(ascending=False).index.tolist()

            rank_num = 0
            for image_idx in image_idx_list:
                if rank_num == example_per_concept:
                    break
                if not check_image(dataset_name, data["metadata_all"].index[image_idx]):
                    print('error', dataset_name, concept_name, image_idx)
                    continue

                plot_key = (idx1, dataset_name, concept_name, rank_num)
                ax = axd[plot_key]
                item = data["dataloader"].dataset.getitem(image_idx)
                print(plot_key, item["metadata"].name)
                ax.imshow(item["image"].resize((300, 300)))
                ax.set_xticks([])
                ax.set_yticks([])

                if rank_num == 0:
                    ax.set_ylabel(shorten_concept_name(concept_name), fontsize=fontsize, zorder=-10)
                    if idx2 == 0:
                        # Panel letter: A, B, C, ... (no longer capped at E).
                        ax.text(x=-0.3, y=1.05, transform=ax.transAxes,
                                s=chr(ord("A") + idx1), fontsize=35, weight='bold')

                for axis in ('top', 'bottom', 'left', 'right'):
                    ax.spines[axis].set_linewidth(1)
                rank_num += 1

            if rank_num < example_per_concept:
                raise ValueError(
                    f"Only {rank_num} of {example_per_concept} candidate images for concept "
                    f"{concept_name!r} in dataset {dataset_name!r} passed check_image")
    return fig

In [None]:
def plot_main_new(dataset_concept_list_list, check_image, debug=True, debug_mode="default", fontsize=30, normalize=True):
    """Plot the top-10 images per concept, ranking images pooled across several datasets.

    Works like ``plot_main``, except each row is ``(dataset_name_list, concept_name)``.
    Scores are computed per dataset with ``calculate_similaity_score`` and pooled.
    The 10 best-scoring valid images that pass ``check_image`` are shown, whichever
    dataset they come from. Reads the notebook-global ``variable_dict``.

    Args:
        dataset_concept_list_list: list of panels, labelled A, B, C, ... in order.
            Each panel is a list of ``(dataset_name_list, concept_name)`` rows.
        check_image: predicate ``(dataset_name, metadata_index) -> bool``; rejected
            images are skipped silently.
        debug: if True, give every image a title chosen by ``debug_mode``.
        debug_mode: one of the following.
            - "default": dataset prefix, position in the dataset and metadata index.
            - "skincon": additionally prefixes the image's value in the
              ``concept_name`` metadata column (when present), and prints a per-row
              tally of truthy labels.
            - "attribution": a short form of the ISIC attribution, "derm7pt", or ""
              for other datasets.
        fontsize: font size of the concept label at the start of each row.
        normalize: forwarded to ``calculate_similaity_score``.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if fewer than 10 images pass ``check_image`` for a row.
    """
    example_per_concept = 10
    n_cols = 10
    n_rows_per_concept = example_per_concept // n_cols
    # NOTE(review): 4.5944 looks like a learned log logit scale (exp(4.5944) ~ 98.9),
    # i.e. temp ~ 1/98.9 -- confirm where the constant comes from.
    temp = 1 / np.exp(4.5944)
    # Full ISIC attribution -> short title used by debug_mode="attribution".
    attribution_short_names = {
        'Memorial Sloan Kettering Cancer Center': "MSKCC",
        'Attributed to Konstantinos Liopyris': "Konstantinos Liopyris",
        'For educational purpose only': "For educational purpose only",
        'Department of Dermatology, Hospital Clínic de Barcelona': "Hospital Clínic de Barcelona",
        'Pascale Guitera': "Pascale Guitera",
        'The University of Queensland Diamantina Institute, The University of Queensland, Dermatology Research Centre': "University of Queensland",
        'ViDIR Group, Department of Dermatology, Medical University of Vienna': "Medical University of Vienna",
        'MSKCC': "MSKCC",
        'Hospital Clínic de Barcelona': "Hospital Clínic de Barcelona",
        'Anonymous': "Anonymous",
        'Dermoscopedia': "Dermoscopedia",
    }

    n_concepts = sum(len(panel) for panel in dataset_concept_list_list)
    fig = plt.figure(figsize=(3 * n_cols, 3 * n_rows_per_concept * n_concepts))
    outer_grid = gridspec.GridSpec(len(dataset_concept_list_list), 1, wspace=0.0, hspace=0.05)

    # Create every axis up front so the layout does not depend on which images are found.
    axd = {}
    for idx1, dataset_concept_list in enumerate(dataset_concept_list_list):
        panel_grid = gridspec.GridSpecFromSubplotSpec(len(dataset_concept_list), 1,
                                                      subplot_spec=outer_grid[idx1], wspace=0.0, hspace=0.1)
        for idx2, (dataset_name_list, concept_name) in enumerate(dataset_concept_list):
            row_grid = gridspec.GridSpecFromSubplotSpec(n_rows_per_concept, n_cols,
                                                        subplot_spec=panel_grid[idx2], wspace=0, hspace=0.1)
            for rank_num in range(example_per_concept):
                axd[(idx1, str(dataset_name_list), concept_name, rank_num)] = fig.add_subplot(row_grid[rank_num])

    for idx1, dataset_concept_list in enumerate(dataset_concept_list_list):
        print('-----------------------------------')
        for idx2, (dataset_name_list, concept_name) in enumerate(dataset_concept_list):
            print(concept_name)
            per_dataset_scores = []
            for dataset_name in dataset_name_list:
                prompt_info = variable_dict[dataset_name]["prompt_info"][concept_name]
                score_frame = pd.Series(calculate_similaity_score(
                    image_features_norm=variable_dict[dataset_name]["image_features_norm"],
                    prompt_target_embedding_norm=prompt_info["prompt_target_embedding_norm"],
                    prompt_ref_embedding_norm=prompt_info["prompt_ref_embedding_norm"],
                    temp=temp,
                    normalize=normalize)).to_frame()
                score_frame["dataset_name"] = dataset_name
                score_frame["valid_idx"] = variable_dict[dataset_name]["valid_idx"]
                per_dataset_scores.append(score_frame)
            dataset_score = pd.concat(per_dataset_scores)
            # Column 0 holds the score; the index is each image's position within its own dataset.
            dataset_score_sorted = dataset_score[dataset_score["valid_idx"].eq(True)].sort_values(0, ascending=False)

            # Tallies for debug_mode="skincon" (previously used without being initialised).
            concept_count = 0
            concept_count_total = 0
            rank_num = 0
            for image_index, candidate in dataset_score_sorted.iterrows():
                if rank_num == example_per_concept:
                    break
                dataset_name = candidate["dataset_name"]
                metadata_all = variable_dict[dataset_name]["metadata_all"]
                if not check_image(dataset_name, metadata_all.index[image_index]):
                    continue

                plot_key = (idx1, str(dataset_name_list), concept_name, rank_num)
                ax = axd[plot_key]
                item = variable_dict[dataset_name]["dataloader"].dataset.getitem(image_index)
                ax.imshow(item["image"].resize((300, 300)))
                ax.set_xticks([])
                ax.set_yticks([])

                if debug:
                    image_tag = str(dataset_name)[:5] + "_" + str(image_index) + "_" + str(metadata_all.index[image_index])
                    if debug_mode == "default":
                        ax.set_title(image_tag, fontsize=8)
                    elif debug_mode == "skincon":
                        if concept_name in metadata_all.columns:
                            concept_value = metadata_all.iloc[image_index][concept_name]
                            if concept_value:
                                concept_count += 1
                            concept_count_total += 1
                            concept_label = str(concept_value)
                        else:
                            concept_label = ""
                        ax.set_title(concept_label + "_" + image_tag, fontsize=8)
                    elif debug_mode == "attribution":
                        # Reset per image so a dataset without attribution info never
                        # inherits the previous image's title.
                        if "isic" in dataset_name:
                            raw_attribution = str(metadata_all.iloc[image_index]["attribution"])
                            # Unmapped attributions are shown verbatim instead of raising KeyError.
                            attribution = attribution_short_names.get(raw_attribution, raw_attribution)
                        elif "derm7pt" in dataset_name:
                            attribution = "derm7pt"
                        else:
                            attribution = ""
                        ax.set_title(attribution, fontsize=12, pad=3)

                if rank_num == 0:
                    ax.set_ylabel(shorten_concept_name(concept_name), fontsize=fontsize, zorder=-10)
                    if idx2 == 0:
                        # Panel letter: A, B, C, ... (no longer capped at E).
                        ax.text(x=-0.3, y=1.05, transform=ax.transAxes,
                                s=chr(ord("A") + idx1), fontsize=35, weight='bold')

                for axis in ('top', 'bottom', 'left', 'right'):
                    ax.spines[axis].set_linewidth(1)
                rank_num += 1

            if rank_num < example_per_concept:
                raise ValueError(
                    f"Only {rank_num} of {example_per_concept} candidate images for concept "
                    f"{concept_name!r} in datasets {dataset_name_list!r} passed check_image")
            if debug and debug_mode == "skincon":
                print(f"{concept_name}: {concept_count}/{concept_count_total} shown images have a truthy label")
    return fig

In [None]:
# Main-text example figure.
# Panel A: clinical images, restricted by check_image_ddionly to images whose
# metadata "source" is "ddi".
# Panel B: ISIC and derm7pt images, ranked jointly.
# debug=False, so no per-image titles are drawn; debug_mode only takes effect when debug=True.
fig=plot_main_new(dataset_concept_list_list=[
    [
        (["clinical_fd_clean_nodup_nooverlap"], "skincon_Erythema"),
        (["clinical_fd_clean_nodup_nooverlap"], "skincon_Pigmented"),  
        (["clinical_fd_clean_nodup_nooverlap"], "skincon_Papule"),
        (["clinical_fd_clean_nodup_nooverlap"], "skincon_Dome-shaped"),
        (["clinical_fd_clean_nodup_nooverlap"], "skincon_Ulcer"),
    ],
    [
        (["isic_nodup_nooverlap", "derm7pt_derm_nodup"], "skincon_Erythema"),
        (["isic_nodup_nooverlap", "derm7pt_derm_nodup"], "skincon_Blue"),  
        (["isic_nodup_nooverlap", "derm7pt_derm_nodup"], "skincon_Nodule"),
        (["isic_nodup_nooverlap", "derm7pt_derm_nodup"], "skincon_Ulcer"),    
        (["isic_nodup_nooverlap", "derm7pt_derm_nodup"], "skincon_Warty/Papillomatous"),
    ]    
], check_image=check_image_ddionly, debug=False, debug_mode="attribution")


# NOTE(review): `log_dir` is not defined in the visible part of this notebook; confirm it
# exists before re-enabling these savefig calls.
# fig.savefig(log_dir/"plots"/"main_example.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"main_example.jpg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"main_example.svg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"main_example.pdf", bbox_inches='tight')
# plt.close(fig)

In [None]:
# NOTE(review): `glob` is imported mid-notebook and is not used in this part of it;
# consider moving it to the import cell at the top, or dropping it.
import glob

In [None]:
# List the dataset entries available in the loaded variable_dict.
variable_dict.keys()

In [None]:
def plot_supplement(dataset_concept_list_list, check_image, debug=True, offset=0, normalize=True):
    """Plot the top-30 retrieved images for each (dataset, concept) pair (supplementary figure).

    Each pair gets a 3x10 block of images, ranked by ``calculate_similaity_score``
    over the dataset's valid images. Images rejected by ``check_image`` are reported
    with an 'error' line and replaced by the next-ranked image. The concept's display
    name is written above its first image, and each panel gets a bold letter label.
    Reads the notebook-global ``variable_dict``.

    Args:
        dataset_concept_list_list: list of panels; each is a list of
            ``(dataset_name, concept_name)`` pairs.
        check_image: predicate ``(dataset_name, metadata_index) -> bool``.
        debug: if True, title every image with its position in the dataset.
        offset: added to the panel number before picking its label, so a figure can
            continue the lettering of a previous one. Labels available: "A."-"G.".
        normalize: forwarded to ``calculate_similaity_score``.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if fewer than 30 images pass ``check_image`` for a pair.
        IndexError: if a panel number plus ``offset`` runs past "G.".
    """
    example_per_concept = 30
    n_cols = 10
    n_rows_per_concept = example_per_concept // n_cols
    # NOTE(review): 4.5944 looks like a learned log logit scale (exp(4.5944) ~ 98.9),
    # i.e. temp ~ 1/98.9 -- confirm where the constant comes from.
    temp = 1 / np.exp(4.5944)
    panel_labels = ["A.", "B.", "C.", "D.", "E.", "F.", "G."]

    n_concepts = sum(len(panel) for panel in dataset_concept_list_list)
    fig = plt.figure(figsize=(3 * n_cols, 3 * n_rows_per_concept * n_concepts))
    outer_grid = gridspec.GridSpec(len(dataset_concept_list_list), 1, wspace=0.0, hspace=0.1)

    # Create every axis up front so the layout does not depend on which images are found.
    axd = {}
    for idx1, dataset_concept_list in enumerate(dataset_concept_list_list):
        panel_grid = gridspec.GridSpecFromSubplotSpec(len(dataset_concept_list), 1,
                                                      subplot_spec=outer_grid[idx1], wspace=0.0, hspace=0.1)
        for idx2, (dataset_name, concept_name) in enumerate(dataset_concept_list):
            row_grid = gridspec.GridSpecFromSubplotSpec(n_rows_per_concept, n_cols,
                                                        subplot_spec=panel_grid[idx2], wspace=0, hspace=0.1)
            for rank_num in range(example_per_concept):
                axd[(idx1, dataset_name, concept_name, rank_num)] = fig.add_subplot(row_grid[rank_num])

    for idx1, dataset_concept_list in enumerate(dataset_concept_list_list):
        print('------------------------------')
        for idx2, (dataset_name, concept_name) in enumerate(dataset_concept_list):
            data = variable_dict[dataset_name]
            prompt_info = data["prompt_info"][concept_name]
            similaity_score = calculate_similaity_score(
                image_features_norm=data["image_features_norm"],
                prompt_target_embedding_norm=prompt_info["prompt_target_embedding_norm"],
                prompt_ref_embedding_norm=prompt_info["prompt_ref_embedding_norm"],
                temp=temp,
                normalize=normalize)
            # Positions (within the dataset) of the valid images, best score first.
            image_idx_list = pd.Series(similaity_score)[data["valid_idx"]].sort_values(ascending=False).index.tolist()

            rank_num = 0
            for image_idx in image_idx_list:
                if rank_num == example_per_concept:
                    break
                metadata_index = data["metadata_all"].index[image_idx]
                if not check_image(dataset_name, metadata_index):
                    print('error', dataset_name, concept_name, metadata_index)
                    continue

                ax = axd[(idx1, dataset_name, concept_name, rank_num)]
                image = data["dataloader"].dataset.getitem(image_idx)["image"]
                ax.imshow(image.resize((300, 300)))
                ax.set_xticks([])
                ax.set_yticks([])

                if debug:
                    ax.set_title(image_idx)

                if rank_num == 0:
                    ax.text(x=-0.0, y=1.06, transform=ax.transAxes,
                            s=shorten_concept_name(concept_name), fontsize=35, zorder=-10)
                    if idx2 == 0:
                        ax.text(x=-0.3, y=1.06, transform=ax.transAxes,
                                s=panel_labels[idx1 + offset], fontsize=35, weight='bold')

                for axis in ('top', 'bottom', 'left', 'right'):
                    ax.spines[axis].set_linewidth(1.5)
                rank_num += 1

            if rank_num < example_per_concept:
                raise ValueError(
                    f"Only {rank_num} of {example_per_concept} candidate images for concept "
                    f"{concept_name!r} in dataset {dataset_name!r} passed check_image")
    return fig

In [None]:
def plot_supplement_new(dataset_concept_list_list, check_image, debug=True, debug_mode="default", offset=0, normalize=True):
    """Plot the 30 top-ranked images per concept, pooling images from one or more datasets.

    For every ``(dataset_name_list, concept_name)`` pair, the images of all datasets in
    ``dataset_name_list`` are scored with ``calculate_similaity_score`` against the concept's
    target/reference prompt embeddings, restricted to rows whose ``valid_idx`` is True, sorted
    by descending score, and the first 30 images accepted by ``check_image`` are drawn as a
    3 x 10 grid.

    Args:
        dataset_concept_list_list: list of panels; each panel is a list of
            ``(dataset_name_list, concept_name)`` tuples. The first concept of each panel
            gets a bold letter label ("A.", "B.", ...).
        check_image: callable ``(dataset_name, metadata_index_value) -> bool``; images for
            which it returns a falsy value are skipped.
        debug: if True, give every image a title chosen by ``debug_mode``.
        debug_mode: "default" (dataset prefix, position, metadata index), "skincon" (the
            image's value in metadata column ``concept_name`` as a prefix, plus a printed
            fraction of truthy values), or "attribution" (shortened ISIC attribution and
            collection flags; "derm7pt" for derm7pt datasets).
        offset: added to the panel number before choosing its letter (offset=5 -> "F.").
        normalize: forwarded to ``calculate_similaity_score``.

    Returns:
        The matplotlib Figure.

    Relies on notebook globals ``variable_dict``, ``calculate_similaity_score`` and
    ``shorten_concept_name``.
    """
    example_per_concept = 30
    n_cols = 10
    n_rows = example_per_concept // n_cols
    n_concepts = sum(len(dataset_concept_list) for dataset_concept_list in dataset_concept_list_list)

    # Short display names for values of the ISIC "attribution" column (debug_mode="attribution").
    attribution_short_name = {
        'Memorial Sloan Kettering Cancer Center': "MSCKK",
        'Attributed to Konstantinos Liopyris': "Konstantinos Liopyris",
        'For educational purpose only': "For educational purpose only",
        'Department of Dermatology, Hospital Clínic de Barcelona': "Hospital Clínic de Barcelona",
        'Pascale Guitera': "Pascale Guitera",
        'The University of Queensland Diamantina Institute, The University of Queensland, Dermatology Research Centre': "University of Queensland",
        'ViDIR Group, Department of Dermatology, Medical University of Vienna': "Medical University of Vienna",
        'MSKCC': "MSKCC",
        'Hospital Clínic de Barcelona': "Hospital Clínic de Barcelona",
        'Anonymous': "Anonymous",
        'Dermoscopedia': "Dermoscopedia",
    }

    def _panel_label(panel_number):
        # Generated instead of indexed from a fixed 7-entry list, so >7 panels or a large
        # offset no longer raise IndexError.
        return chr(ord("A") + panel_number) + "."

    def _ranked_candidates(dataset_name_list, concept_name):
        """Valid images of all listed datasets, sorted by descending score (column 0).

        Column "dataset_name" records the source dataset; the index is the image's position
        within that dataset (so index values repeat across datasets).
        """
        frames = []
        for dataset_name in dataset_name_list:
            variables = variable_dict[dataset_name]
            prompt = variables["prompt_info"][concept_name]
            score = calculate_similaity_score(
                image_features_norm=variables["image_features_norm"],
                prompt_target_embedding_norm=prompt["prompt_target_embedding_norm"],
                prompt_ref_embedding_norm=prompt["prompt_ref_embedding_norm"],
                temp=1 / np.exp(4.5944),
                normalize=normalize,
            )
            frame = pd.Series(score).to_frame()
            frame["dataset_name"] = dataset_name
            frame["valid_idx"] = variables["valid_idx"]
            frames.append(frame)
        scores = pd.concat(frames)
        return scores[scores["valid_idx"] == True].sort_values(0, ascending=False)

    fig = plt.figure(figsize=(3 * n_cols, 3 * n_rows * n_concepts))
    box1 = gridspec.GridSpec(len(dataset_concept_list_list), 1, wspace=0.0, hspace=0.1)

    # Create every axis up front so the layout is fixed before any image is loaded.
    axd = {}
    for idx1, dataset_concept_list in enumerate(dataset_concept_list_list):
        box2 = gridspec.GridSpecFromSubplotSpec(len(dataset_concept_list), 1,
                                                subplot_spec=box1[idx1], wspace=0.0, hspace=0.1)
        for idx2, (dataset_name_list, concept_name) in enumerate(dataset_concept_list):
            box3 = gridspec.GridSpecFromSubplotSpec(n_rows, n_cols,
                                                    subplot_spec=box2[idx2], wspace=0, hspace=0.1)
            for rank_num in range(example_per_concept):
                # Figure.add_subplot accepts a SubplotSpec directly.
                axd[(idx1, str(dataset_name_list), concept_name, rank_num)] = fig.add_subplot(box3[rank_num])

    for idx1, dataset_concept_list in enumerate(dataset_concept_list_list):
        print('------------------------------')
        for idx2, (dataset_name_list, concept_name) in enumerate(dataset_concept_list):
            print(concept_name)
            dataset_score_sorted = _ranked_candidates(dataset_name_list, concept_name)

            count = 0
            rank_num = 0
            concept_count = 0
            concept_count_total = 0
            while rank_num < example_per_concept:
                if count >= len(dataset_score_sorted):
                    # Out of candidates: leave the remaining axes empty instead of raising IndexError.
                    print('warning: only', rank_num, 'images passed check_image for', concept_name)
                    break
                dataset_name = dataset_score_sorted["dataset_name"].iat[count]
                image_index = dataset_score_sorted.index[count]
                count += 1

                metadata_all = variable_dict[dataset_name]["metadata_all"]
                if not check_image(dataset_name, metadata_all.index[image_index]):
                    continue

                ax = axd[(idx1, str(dataset_name_list), concept_name, rank_num)]
                image = variable_dict[dataset_name]["dataloader"].dataset.getitem(image_index)["image"]
                ax.imshow(image.resize((300, 300)))
                ax.set_xticks([])
                ax.set_yticks([])

                if debug:
                    position_tag = (str(dataset_name)[:5] + "_" + str(image_index) + "_"
                                    + str(metadata_all.index[image_index]))
                    if debug_mode == "default":
                        ax.set_title(position_tag, fontsize=8)
                    elif debug_mode == "skincon":
                        if concept_name in metadata_all.columns:
                            concept_value = metadata_all.iloc[image_index][concept_name]
                            if concept_value:
                                concept_count += 1
                            concept_count_total += 1
                            concept_label = str(concept_value)
                        else:
                            concept_label = ""
                        ax.set_title(concept_label + "_" + position_tag, fontsize=8)
                    elif debug_mode == "attribution":
                        if "isic" in dataset_name:
                            row = metadata_all.iloc[image_index]
                            attribution = str(row["attribution"])
                            # Unmapped attributions are shown verbatim instead of raising KeyError.
                            attribution = attribution_short_name.get(attribution, attribution)
                            collection_flags = row[row.index.str.contains("collection")]
                            # [11:] strips an assumed "collection_" column prefix — TODO confirm column names.
                            collections = "_".join(name[11:] for name in collection_flags[collection_flags == 1].index)
                        elif "derm7pt" in dataset_name:
                            attribution = "derm7pt"
                            collections = ""
                        else:
                            # Other datasets have no attribution metadata; fall back to the dataset name
                            # rather than hitting an undefined (or stale) `attribution`.
                            attribution = str(dataset_name)
                            collections = ""
                        ax.set_title(attribution + "_" + collections, fontsize=12, pad=3)

                if rank_num == 0:
                    ax.text(x=-0.0, y=1.06, transform=ax.transAxes,
                            s=shorten_concept_name(concept_name), fontsize=35, zorder=-10)
                    if idx2 == 0:
                        ax.text(x=-0.3, y=1.06, transform=ax.transAxes,
                                s=_panel_label(idx1 + offset), fontsize=35, weight='bold')

                for axis in ['top', 'bottom', 'left', 'right']:
                    ax.spines[axis].set_linewidth(1.5)

                rank_num += 1

            if debug_mode == "skincon":
                if concept_count_total:
                    print(concept_name, concept_count / concept_count_total)
                else:
                    # Nothing counted (debug=False or the concept column is missing): avoid ZeroDivisionError.
                    print(concept_name, 'no skincon labels counted')

    return fig

In [None]:
# Top-ranked clinical images for five SkinCon concepts (saving is left commented out).
fig = plot_supplement_new(
    dataset_concept_list_list=[
        [(["clinical_fd_clean_nodup_nooverlap"], concept_name)]
        for concept_name in (
            "skincon_Erythema",
            "skincon_Pigmented",
            "skincon_Papule",
            "skincon_Dome-shaped",
            "skincon_Ulcer",
        )
    ],
    check_image=check_image_ddionly,
    debug=False,
)

# To save, uncomment:
# for ext in ("png", "jpg", "svg", "pdf"):
#     fig.savefig(log_dir/"plots"/f"supple_example_clinical.{ext}", bbox_inches='tight')

In [None]:
# SkinCon concepts ranked across both dermoscopy datasets; saved in four formats.
fig = plot_supplement_new(
    dataset_concept_list_list=[
        [(["derm7pt_derm_nodup", "isic_nodup_nooverlap"], concept_name)]
        for concept_name in (
            "skincon_Erythema",
            "skincon_Blue",
            "skincon_Nodule",
            "skincon_Ulcer",
            "skincon_Warty/Papillomatous",
        )
    ],
    check_image=check_image_ddionly,
    offset=0,
    debug=False,
)

for ext in ("png", "jpg", "svg", "pdf"):
    fig.savefig(log_dir/"plots"/f"supple_example_dermoscopy.{ext}", bbox_inches='tight')
plt.close(fig)

In [None]:
# Scratch: load one ISIC item by its position (51649) through the dataset's getitem.
variable_dict["isic_nodup_nooverlap"]['dataloader'].dataset.getitem(51649)

In [None]:
# Scratch: display `collection_list` (defined outside this section — TODO confirm it is still needed).
collection_list

In [None]:
# Scratch: display the helper the plotting functions use to shorten concept titles.
shorten_concept_name

In [None]:
# Scratch notes: image positions of interest (51649 is the position loaded with getitem earlier).
# These literals do nothing except echo the last one.
42203
51649

In [None]:
# Confounder concepts on ISIC; saved in four formats.
# NOTE: debug_mode="attribution" only applies when debug=True, so no titles are drawn here.
fig = plot_supplement_new(
    dataset_concept_list_list=[
        [(["isic_nodup_nooverlap"], concept_name)]
        for concept_name in ("purple pen", "orange sticker", "nail", "hair", "dermoscope border")
    ],
    check_image=check_image_ddionly,
    debug=False,
    debug_mode="attribution",
)
for ext in ("png", "jpg", "svg", "pdf"):
    fig.savefig(log_dir/"plots"/f"supple_example_dermoscopy_confounder.{ext}", bbox_inches='tight')
plt.close(fig)

In [None]:
# First five derm7pt concepts ranked across both dermoscopy datasets (panels A.-E.).
fig = plot_supplement_new(
    dataset_concept_list_list=[
        [(["derm7pt_derm_nodup", "isic_nodup_nooverlap"], concept_name)]
        for concept_name in (
            'derm7ptconcept_pigment network',
            'derm7ptconcept_regression structure',
            'derm7ptconcept_pigmentation',
            'derm7ptconcept_blue whitish veil',
            'derm7ptconcept_vascular structures',
        )
    ],
    check_image=check_image_ddionly,
    offset=0,
    debug=False,
)

for ext in ("png", "jpg", "svg", "pdf"):
    fig.savefig(log_dir/"plots"/f"supple_example_dermoscopy_derm7pt_1.{ext}", bbox_inches='tight')
plt.close(fig)

In [None]:
# Remaining two derm7pt concepts; offset=5 continues the panel letters at "F.".
fig = plot_supplement_new(
    dataset_concept_list_list=[
        [(["derm7pt_derm_nodup", "isic_nodup_nooverlap"], concept_name)]
        for concept_name in ('derm7ptconcept_streaks', 'derm7ptconcept_dots and globules')
    ],
    check_image=check_image_ddionly,
    offset=5,
    debug=False,
)

for ext in ("png", "jpg", "svg", "pdf"):
    fig.savefig(log_dir/"plots"/f"supple_example_dermoscopy_derm7pt_2.{ext}", bbox_inches='tight')
plt.close(fig)

In [None]:
def plot_supplement_new_many(dataset_concept_list_list, model_select, debug=True, offset=0, normalize=True, check_image=None):
    """Plot the 100 top-ranked images per concept, scored with MONET or vanilla CLIP features.

    Same layout as ``plot_supplement_new`` but with 100 examples (10 x 10) per concept and
    a choice of embedding set.

    Args:
        dataset_concept_list_list: list of panels; each panel is a list of
            ``(dataset_name_list, concept_name)`` tuples. Images of all datasets in
            ``dataset_name_list`` are pooled before ranking.
        model_select: "MONET" uses ``image_features_norm`` / ``prompt_info``;
            "CLIP" uses ``image_features_vanilla_norm`` / ``prompt_info_vanilla``.
        debug: if True, title each image with its dataset name followed by its position.
        offset: added to the panel number before choosing its letter (offset=0 -> "A.").
        normalize: forwarded to ``calculate_similaity_score``.
        check_image: callable ``(dataset_name, metadata_index_value) -> bool`` used to skip
            images. When omitted, the notebook-global ``check_image`` is used (the only
            behaviour available before this parameter existed).

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if ``model_select`` is not "MONET" or "CLIP".
        NameError: if ``check_image`` is omitted and no global ``check_image`` exists.
    """
    # model -> (image feature key, prompt info key) inside variable_dict[dataset_name]
    feature_keys = {
        "MONET": ("image_features_norm", "prompt_info"),
        "CLIP": ("image_features_vanilla_norm", "prompt_info_vanilla"),
    }
    if model_select not in feature_keys:
        raise ValueError(model_select)
    image_key, prompt_key = feature_keys[model_select]

    if check_image is None:
        try:
            check_image = globals()["check_image"]
        except KeyError:
            raise NameError("check_image was not passed and no global `check_image` is defined") from None

    example_per_concept = 100
    n_cols = 10
    n_rows = example_per_concept // n_cols
    n_concepts = sum(len(dataset_concept_list) for dataset_concept_list in dataset_concept_list_list)

    def _panel_label(panel_number):
        # Generated rather than indexed from a fixed 7-entry list (no IndexError past "G.").
        return chr(ord("A") + panel_number) + "."

    fig = plt.figure(figsize=(3 * n_cols, 3 * n_rows * n_concepts))
    box1 = gridspec.GridSpec(len(dataset_concept_list_list), 1, wspace=0.0, hspace=0.1)

    # Create every axis up front so the layout is fixed before any image is loaded.
    axd = {}
    for idx1, dataset_concept_list in enumerate(dataset_concept_list_list):
        box2 = gridspec.GridSpecFromSubplotSpec(len(dataset_concept_list), 1,
                                                subplot_spec=box1[idx1], wspace=0.0, hspace=0.1)
        for idx2, (dataset_name_list, concept_name) in enumerate(dataset_concept_list):
            box3 = gridspec.GridSpecFromSubplotSpec(n_rows, n_cols,
                                                    subplot_spec=box2[idx2], wspace=0, hspace=0.1)
            for rank_num in range(example_per_concept):
                axd[(idx1, str(dataset_name_list), concept_name, rank_num)] = fig.add_subplot(box3[rank_num])

    for idx1, dataset_concept_list in enumerate(dataset_concept_list_list):
        for idx2, (dataset_name_list, concept_name) in enumerate(dataset_concept_list):
            # Score every dataset's images; column 0 = score, index = position within its dataset.
            frames = []
            for dataset_name in dataset_name_list:
                variables = variable_dict[dataset_name]
                prompt = variables[prompt_key][concept_name]
                score = calculate_similaity_score(
                    image_features_norm=variables[image_key],
                    prompt_target_embedding_norm=prompt["prompt_target_embedding_norm"],
                    prompt_ref_embedding_norm=prompt["prompt_ref_embedding_norm"],
                    temp=1 / np.exp(4.5944),
                    normalize=normalize,
                )
                frame = pd.Series(score).to_frame()
                frame["dataset_name"] = dataset_name
                frame["valid_idx"] = variables["valid_idx"]
                frames.append(frame)
            dataset_score = pd.concat(frames)
            dataset_score_sorted = dataset_score[dataset_score["valid_idx"] == True].sort_values(0, ascending=False)

            count = 0
            rank_num = 0
            while rank_num < example_per_concept:
                if count >= len(dataset_score_sorted):
                    # Out of candidates: leave the remaining axes empty instead of raising IndexError.
                    print('warning: only', rank_num, 'images passed check_image for', concept_name)
                    break
                dataset_name = dataset_score_sorted["dataset_name"].iat[count]
                image_index = dataset_score_sorted.index[count]
                count += 1

                metadata_index = variable_dict[dataset_name]["metadata_all"].index[image_index]
                if not check_image(dataset_name, metadata_index):
                    print('error', dataset_name, concept_name, metadata_index)
                    continue

                ax = axd[(idx1, str(dataset_name_list), concept_name, rank_num)]
                image = variable_dict[dataset_name]["dataloader"].dataset.getitem(image_index)["image"]
                ax.imshow(image.resize((300, 300)))
                ax.set_xticks([])
                ax.set_yticks([])

                if debug:
                    ax.set_title(str(dataset_name) + str(image_index))

                if rank_num == 0:
                    ax.text(x=-0.0, y=1.06, transform=ax.transAxes,
                            s=shorten_concept_name(concept_name), fontsize=35, zorder=-10)
                    if idx2 == 0:
                        ax.text(x=-0.3, y=1.06, transform=ax.transAxes,
                                s=_panel_label(idx1 + offset), fontsize=35, weight='bold')

                for axis in ['top', 'bottom', 'left', 'right']:
                    ax.spines[axis].set_linewidth(1.5)

                rank_num += 1

    return fig

In [None]:
# method MONET CLIP
# pen 92 83 
# orange sticker 96 99
# nail 70 54
# hair 100 97
# derm border 95 12 

In [None]:
# Scratch: the confounder concepts in `(dataset_name, concept_name)` form.
# Each line is a standalone expression (trailing commas make tuples); the cell only echoes the last one.
[("isic_nodup_nooverlap", "purple pen")],
[("isic_nodup_nooverlap", "orange sticker")],  
[("isic_nodup_nooverlap", "nail")],
[("isic_nodup_nooverlap", "hair")],
[("isic_nodup_nooverlap", "dermoscope border")],

In [None]:
# 100 top-ranked ISIC images for "orange sticker" using the vanilla CLIP embeddings.
fig = plot_supplement_new_many(
    dataset_concept_list_list=[[(["isic_nodup_nooverlap"], "orange sticker")]],
    model_select="CLIP",
    debug=False,
)

In [None]:
# Same ranking as the CLIP cell, but with the MONET embeddings.
fig = plot_supplement_new_many(
    dataset_concept_list_list=[[(["isic_nodup_nooverlap"], "orange sticker")]],
    model_select="MONET",
    debug=False,
)

In [None]:
# Confounder concepts on ISIC with `plot_supplement_new_vanilla` (defined outside this section).
fig = plot_supplement_new_vanilla(
    dataset_concept_list_list=[
        [(["isic_nodup_nooverlap"], concept_name)]
        for concept_name in ("purple pen", "orange sticker", "nail", "hair", "dermoscope border")
    ],
    debug=False,
)

In [None]:
# Scratch: list what is stored for the ISIC dataset.
variable_dict["isic_nodup_nooverlap"].keys()

In [None]:
# Scratch: shorthand for the clinical metadata frame (a reference, not a copy).
metadata_all_temp=variable_dict["clinical_fd_clean_nodup_nooverlap"]["metadata_all"]

In [None]:
# Scratch: softmax over the target/reference similarity means, each divided by `temp`.
# NOTE(review): relies on `target_similarity_mean`, `ref_similarity_mean`, `temp` and `scipy`
# from elsewhere in the kernel; the next cell rebinds `temp` to a function, so re-running
# this cell after it raises a TypeError.
scipy.special.softmax([target_similarity_mean.numpy()/temp, ref_similarity_mean.numpy()/temp], axis=0)

In [None]:
def temp():
    """Print image counts per skin-tone group for the clinical dataset.

    For each group ("all", "12", "34", "56"), prints the following:
    - how many images fall in the group;
    - how many are kept by SkinCon (column "skincon_Do not consider this image" == 0);
    - how many satisfy both.

    Missing Fitzpatrick / skin_tone values are encoded as -9 before matching.

    NOTE: calling this cell rebinds the notebook name `temp`, which the softmax scratch
    cell uses as a temperature value.
    """
    metadata_temp = variable_dict["clinical_fd_clean_nodup"]["metadata_all"]
    kept_by_skincon = (metadata_temp["skincon_Do not consider this image"] == 0).values
    fitzpatrick = metadata_temp["fitzpatrick_scale"].fillna(-9).astype(int)
    skin_tone = metadata_temp["skin_tone"].fillna(-9).astype(int)

    # group name -> (accepted fitzpatrick_scale values, accepted skin_tone values)
    tone_groups = {
        "all": ([1, 2, 3, 4, 5, 6, -9], [12, 34, 56, -9]),
        "12": ([1, 2], [12]),
        "34": ([3, 4], [34]),
        "56": ([5, 6], [56]),
    }
    for group_name, (fitzpatrick_values, skin_tone_values) in tone_groups.items():
        in_group = (fitzpatrick.isin(fitzpatrick_values) | skin_tone.isin(skin_tone_values)).values
        in_both = kept_by_skincon & in_group
        print(group_name)
        print(in_group.shape, in_group.sum())
        print(kept_by_skincon.shape, kept_by_skincon.sum())
        print(in_both.shape, in_both.sum())
temp()

In [None]:
# Scratch: number of clinical images per `source` value (missing sources counted as 0).
variable_dict["clinical_fd_clean_nodup"]["metadata_all"]["source"].fillna(0).value_counts()

Erythema
2741 720cf31558966c82c118ab75b50632eb.jpg
1690 58b4bc079ca94e6e9377a42ca7564b40.jpg
2894 5f046cda32a3cc547205662e7be774f9.jpg

Ulcer
2484 d8bf377acc45a3beb0c6e81bf7ac1ff5.jpg

In [None]:
def plot_training_data(dataset_concept_list_list, debug=True, offset=0, normalize=True,
                       check_image=None, prompt_dataset_name="derm7pt_derm_nodup", keyword_map=None):
    """Plot the 30 top-ranked training images per concept among captions matching a keyword.

    Image features of ``dataset_name`` are scored against the concept prompt embeddings
    stored for ``prompt_dataset_name``. Only images whose caption matches the concept's
    keyword are ranked. Captions come from column ``prompt`` of the notebook-global
    ``allpubmedtextbook_annotation`` (lower-cased), assumed row-aligned with the image
    features — TODO confirm. The caption of each plotted image is printed.

    Args:
        dataset_concept_list_list: list of panels; each panel is a list of
            ``(dataset_name, concept_name)`` tuples.
        debug: if True, title each image with its global index.
        offset: added to the panel number before choosing its letter.
        normalize: forwarded to ``calculate_similaity_score``.
        check_image: callable ``(dataset_name, image_idx) -> bool`` used to skip images;
            defaults to the notebook-global ``check_image`` (the original behaviour).
        prompt_dataset_name: dataset whose ``prompt_info`` supplies the concept embeddings.
        keyword_map: optional ``{concept_name: regex}`` merged over the built-in caption
            keywords, so concepts other than the five defaults can be plotted.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if a concept has no caption keyword (checked before the figure is built).
    """
    # Built-in caption keywords (regex patterns for Series.str.contains).
    caption_keywords = {
        "purple pen": "marking",
        "orange sticker": "",  # empty pattern matches every caption, i.e. no caption filter
        "nail": "nail",
        "hair": "hair",
        "dermoscope border": "border",
    }
    if keyword_map:
        caption_keywords.update(keyword_map)

    # Validate concepts up front rather than failing half-way through plotting.
    for dataset_concept_list in dataset_concept_list_list:
        for _, concept_name in dataset_concept_list:
            if concept_name not in caption_keywords:
                raise ValueError(f"no caption keyword for concept {concept_name!r}; pass keyword_map")

    if check_image is None:
        try:
            check_image = globals()["check_image"]
        except KeyError:
            raise NameError("check_image was not passed and no global `check_image` is defined") from None

    def get_idx_from_concat_dataset(idx, concat_dataset):
        """Map a global index over a list of datasets to ``(dataset position, local index)``.

        Raises IndexError when ``idx`` is past the combined length instead of implicitly
        returning None (which then failed on tuple unpacking).
        """
        assert isinstance(concat_dataset, list)
        assert len(concat_dataset) > 0 and isinstance(concat_dataset[0], torch.utils.data.Dataset)
        start = 0
        for count, dataset in enumerate(concat_dataset):
            if idx - start < len(dataset):
                return count, idx - start
            start += len(dataset)
        raise IndexError(f"index {idx} out of range for {start} concatenated samples")

    def _panel_label(panel_number):
        # Generated rather than indexed from a fixed 7-entry list (no IndexError past "G.").
        return chr(ord("A") + panel_number) + "."

    example_per_concept = 30
    n_cols = 10
    n_rows = example_per_concept // n_cols
    n_concepts = sum(len(dataset_concept_list) for dataset_concept_list in dataset_concept_list_list)

    fig = plt.figure(figsize=(3 * n_cols, 3 * n_rows * n_concepts))
    box1 = gridspec.GridSpec(len(dataset_concept_list_list), 1, wspace=0.0, hspace=0.1)

    # Create every axis up front so the layout is fixed before any image is loaded.
    axd = {}
    for idx1, dataset_concept_list in enumerate(dataset_concept_list_list):
        box2 = gridspec.GridSpecFromSubplotSpec(len(dataset_concept_list), 1,
                                                subplot_spec=box1[idx1], wspace=0.0, hspace=0.1)
        for idx2, (dataset_name, concept_name) in enumerate(dataset_concept_list):
            box3 = gridspec.GridSpecFromSubplotSpec(n_rows, n_cols,
                                                    subplot_spec=box2[idx2], wspace=0, hspace=0.1)
            for rank_num in range(example_per_concept):
                axd[(idx1, dataset_name, concept_name, rank_num)] = fig.add_subplot(box3[rank_num])

    for idx1, dataset_concept_list in enumerate(dataset_concept_list_list):
        for idx2, (dataset_name, concept_name) in enumerate(dataset_concept_list):
            prompt = variable_dict[prompt_dataset_name]["prompt_info"][concept_name]
            similaity_score = calculate_similaity_score(
                image_features_norm=variable_dict[dataset_name]["image_features_norm"],
                prompt_target_embedding_norm=prompt["prompt_target_embedding_norm"],
                prompt_ref_embedding_norm=prompt["prompt_ref_embedding_norm"],
                temp=1 / np.exp(4.5944),
                normalize=normalize,
            )

            # na=False: missing captions count as non-matching instead of producing a NaN mask.
            captions = allpubmedtextbook_annotation['prompt'].str.lower()
            valid_idx = captions.str.contains(caption_keywords[concept_name], na=False).values

            image_idx_list = pd.Series(similaity_score)[valid_idx].sort_values(ascending=False).index.tolist()
            datasets = variable_dict[dataset_name]["dataloader"].dataset.datasets

            count = 0
            rank_num = 0
            while rank_num < example_per_concept:
                if count >= len(image_idx_list):
                    # Out of candidates: leave the remaining axes empty instead of raising IndexError.
                    print('warning: only', rank_num, 'images passed check_image for', concept_name)
                    break
                image_idx = image_idx_list[count]
                count += 1
                if not check_image(dataset_name, image_idx):
                    print('error', dataset_name, concept_name, image_idx)
                    continue

                ax = axd[(idx1, dataset_name, concept_name, rank_num)]
                concat_dataset_idx, sample_idx = get_idx_from_concat_dataset(image_idx, datasets)
                item = datasets[concat_dataset_idx].getitem(sample_idx)
                print(item["prompt"])
                ax.imshow(item["image"].resize((300, 300)))
                ax.set_xticks([])
                ax.set_yticks([])

                if debug:
                    ax.set_title(image_idx)

                if rank_num == 0:
                    ax.text(x=-0.0, y=1.06, transform=ax.transAxes,
                            s=shorten_concept_name(concept_name), fontsize=35, zorder=-10)
                    if idx2 == 0:
                        ax.text(x=-0.3, y=1.06, transform=ax.transAxes,
                                s=_panel_label(idx1 + offset), fontsize=35, weight='bold')

                for axis in ['top', 'bottom', 'left', 'right']:
                    ax.spines[axis].set_linewidth(1.5)

                rank_num += 1

    return fig

In [None]:
# Caption table for the allpubmedtextbook images; plot_training_data filters its `prompt` column.
allpubmedtextbook_annotation = pd.read_csv(
    Path("logs") / "experiment_results" / "skintone_annotation" / "allpubmedtextbook_annotation.csv",
    index_col=0,
)


In [None]:
# Top-ranked allpubmedtextbook images for "purple pen" only.
fig = plot_training_data(
    dataset_concept_list_list=[[("allpubmedtextbook", "purple pen")]],
    offset=0,
    debug=False,
)

In [None]:
# Top-ranked allpubmedtextbook images for all five confounder concepts.
fig = plot_training_data(
    dataset_concept_list_list=[
        [("allpubmedtextbook", concept_name)]
        for concept_name in ("purple pen", "orange sticker", "nail", "hair", "dermoscope border")
    ],
    offset=0,
    debug=False,
)

In [None]:
# Top-ranked allpubmedtextbook images for five derm7pt concepts.
# NOTE(review): plot_training_data has no caption keyword for these concept names and
# raises ValueError for them — TODO add keywords before relying on this cell.
fig = plot_training_data(
    dataset_concept_list_list=[
        [("allpubmedtextbook", concept_name)]
        for concept_name in (
            'derm7ptconcept_pigment network',
            'derm7ptconcept_regression structure',
            'derm7ptconcept_pigmentation',
            'derm7ptconcept_blue whitish veil',
            'derm7ptconcept_vascular structures',
        )
    ],
    offset=0,
    debug=False,
)

In [None]:
# Scratch: number of missing values per derm7pt metadata column.
variable_dict["derm7pt_derm_nodup"]["metadata_all"].isnull().sum(axis=0)

In [None]:
# Scratch: display the clinical metadata frame (consider .head() to keep the output small).
variable_dict["clinical_fd_clean_nodup_nooverlap"]["metadata_all"]

In [None]:
# Scratch: list the dataset names stored in `variable_dict`.
variable_dict.keys()

In [None]:
# Scratch. NOTE(review): depends on `dataset_name`, `metadata_all` and `concept_name`
# left over in the kernel from earlier cells — not reproducible on a fresh run.
get_concept_bool_from_metadata(dataset_name, metadata_all, concept_name)

In [None]:
# Scratch: same call for derm7pt. NOTE(review): `metadata_all` and `concept_name` are
# leftover kernel globals, not defined in this cell.
get_concept_bool_from_metadata("derm7pt_derm_nodup",
                               metadata_all, concept_name)

In [None]:
def perform_benchmark(dataset_name, metadata_all, image_features_norm, concept_list, valid_idx, prompt_info, replace=True, min_sample=30,
                      attribution_list=None, num_sample_list=(100,), noise_level_list=(0, 0.1, 0.2, 0.3, 0.4, 0.5),
                      num_seeds=20, num_candidates=48, verbose=True):
    """Benchmark concept recovery from the mean difference of two groups of image embeddings.

    For each ``(pos_attribution, neg_attribution)`` pair and each target concept:

    1. A positive group is sampled from images where the concept is present (within the
       ``pos_attribution`` subset). A negative group is sampled from images where it is absent
       (within the ``neg_attribution`` subset). A fraction ``noise_level`` of each group is
       instead drawn from the opposite label, which simulates label noise.
    2. The difference of the two groups' mean ``image_features_norm`` rows is dotted with the
       mean prompt embedding of every candidate concept. The candidates are ``concept_list``,
       padded with random SkinCon distractors up to ``num_candidates``.
    3. The 0-based rank of the target concept among the candidates is recorded.

    Args:
        dataset_name: dataset key passed to ``get_concept_bool_from_metadata`` and ``get_subset_index``.
        metadata_all: per-image metadata; filtered with ``valid_idx`` after ``reset_index()``.
        image_features_norm: image embeddings indexed by positional row of ``metadata_all``
            (presumably a torch tensor, since rows are converted with ``.float().numpy()``).
        concept_list: target concepts. Each one, and each distractor, must be a key of ``prompt_info``.
        valid_idx: boolean mask over the rows of ``metadata_all``.
        prompt_info: mapping ``concept -> {"prompt_target_embedding_norm": tensor, ...}``.
            The tensor is averaged over its first two axes to obtain one embedding per concept.
        replace: sample with replacement. Without replacement, a noise level is skipped when any
            of the four sampling pools is smaller than the number of rows requested from it.
        min_sample: a target concept is skipped when any of its four pools has fewer rows.
        attribution_list: ``[[pos_attribution, neg_attribution], ...]``; defaults to ``[["all", "all"]]``.
        num_sample_list, noise_level_list, num_seeds: benchmark grid. The defaults reproduce the
            grid previously hard-coded here.
        num_candidates: target size of the candidate list. A shorter ``concept_list`` is padded
            with distractors drawn without replacement, seeded by the run's random seed.
        verbose: print per-attribution / per-concept progress and skip messages.

    Returns:
        ``{"benchmark_dict_list_dict": {(pos_attribution, neg_attribution): [record, ...]}}``.
        Each record holds the target concept, sample size, noise level, seed, rank, number of
        images, and the rank's percentile and 1-percentile.
    """
    if attribution_list is None:
        attribution_list = [["all", "all"]]

    # reset_index() yields a RangeIndex, so after masking the surviving index values are row
    # positions in metadata_all. Those are the positions used to index image_features_norm below.
    metadata_all_valid = metadata_all.reset_index()[valid_idx]

    # Candidate concepts per seed. concept_list is padded up to num_candidates with distractors
    # from skincon_cols. Colour/pigment concepts are excluded, and so is anything already in
    # concept_list, so no candidate appears twice.
    num_distractors = num_candidates - len(concept_list)
    if num_distractors > 0:
        excluded_concepts = ['skincon_Brown(Hyperpigmentation)', 'skincon_White(Hypopigmentation)', 'skincon_Blue', 'skincon_Pigmented']
        distractor_pool = [i for i in skincon_cols if i not in excluded_concepts and i not in concept_list]
        num_drawn = min(num_distractors, len(distractor_pool))
        concept_name_test_list_by_seed = {
            random_seed: list(concept_list)
            + np.random.RandomState(random_seed).choice(distractor_pool, size=num_drawn, replace=False).tolist()
            for random_seed in range(num_seeds)
        }
    else:
        concept_name_test_list_by_seed = {random_seed: list(concept_list) for random_seed in range(num_seeds)}

    # Mean prompt embedding per concept: computed once, then reused across all runs.
    concept_embedding_mean_cache = {}

    def _concept_embedding_mean(concept_name):
        if concept_name not in concept_embedding_mean_cache:
            embedding = prompt_info[concept_name]["prompt_target_embedding_norm"].numpy()
            concept_embedding_mean_cache[concept_name] = embedding.mean(axis=(0, 1))
        return concept_embedding_mean_cache[concept_name]

    benchmark_dict_list_dict = {}
    for pos_attribution, neg_attribution in attribution_list:
        if verbose:
            print(pos_attribution, neg_attribution)
        # The attribution subsets do not depend on the target concept, so compute them once.
        pos_attribution_idx = get_subset_index(dataset_name, metadata_all_valid, pos_attribution)
        neg_attribution_idx = get_subset_index(dataset_name, metadata_all_valid, neg_attribution)

        benchmark_dict_list = []
        for concept_name_target in tqdm.tqdm(concept_list):
            concept_bool = get_concept_bool_from_metadata(dataset_name, metadata_all_valid, concept_name_target)
            concept_bool_true = concept_bool["concept_bool_true"].values
            concept_bool_false = concept_bool["concept_bool_false"].values
            if verbose:
                print(concept_name_target, pos_attribution_idx.sum(), neg_attribution_idx.sum())

            # "*_true" pools carry the correct group label; "*_false" pools supply the label noise.
            pool_pos_true = metadata_all_valid[pos_attribution_idx & concept_bool_true]
            pool_pos_false = metadata_all_valid[pos_attribution_idx & concept_bool_false]
            pool_neg_true = metadata_all_valid[neg_attribution_idx & concept_bool_false]
            pool_neg_false = metadata_all_valid[neg_attribution_idx & concept_bool_true]
            if min(len(pool_pos_true), len(pool_pos_false), len(pool_neg_true), len(pool_neg_false)) < min_sample:
                if verbose:
                    print(concept_name_target, "continue")
                continue

            for num_sample in num_sample_list:
                for noise_level in noise_level_list:
                    num_false = int(num_sample * noise_level)
                    num_true = num_sample - num_false
                    # Without replacement, every pool must hold at least as many rows as are drawn
                    # from it; otherwise DataFrame.sample would raise.
                    if not replace and (len(pool_pos_true) < num_true or len(pool_neg_true) < num_true
                                        or len(pool_pos_false) < num_false or len(pool_neg_false) < num_false):
                        continue

                    for random_seed in range(num_seeds):
                        idx_pos = (pool_pos_true.sample(n=num_true, replace=replace, random_state=random_seed).index.tolist()
                                   + pool_pos_false.sample(n=num_false, replace=replace, random_state=random_seed).index.tolist())
                        idx_neg = (pool_neg_true.sample(n=num_true, replace=replace, random_state=random_seed).index.tolist()
                                   + pool_neg_false.sample(n=num_false, replace=replace, random_state=random_seed).index.tolist())

                        image_features_norm_pos_mean = image_features_norm[idx_pos].float().numpy().mean(axis=0)
                        image_features_norm_neg_mean = image_features_norm[idx_neg].float().numpy().mean(axis=0)
                        image_features_norm_diff = image_features_norm_pos_mean - image_features_norm_neg_mean

                        record_dict_list = [
                            {"concept_name_test": concept_name_test,
                             "value": (image_features_norm_diff @ _concept_embedding_mean(concept_name_test)).item()}
                            for concept_name_test in concept_name_test_list_by_seed[random_seed]
                        ]
                        rank = (pd.DataFrame(record_dict_list).sort_values('value', ascending=False)
                                ["concept_name_test"].tolist().index(concept_name_target))
                        percentile = rank / len(record_dict_list)

                        benchmark_dict_list.append({"concept_name_target": concept_name_target,
                                                    "num_sample": num_sample,
                                                    "noise_level": noise_level,
                                                    "random_seed": random_seed,
                                                    "rank": rank,
                                                    "num_images": len(idx_pos) + len(idx_neg),
                                                    "percentile": percentile,
                                                    "1-percentile": 1 - percentile,
                                                    })
        benchmark_dict_list_dict[(pos_attribution, neg_attribution)] = benchmark_dict_list
    return {"benchmark_dict_list_dict":
            benchmark_dict_list_dict}

In [None]:
def get_subset_index(dataset_name, metadata_all, attribution):
    """Return a boolean mask selecting the rows of ``metadata_all`` that belong to ``attribution``.

    ``attribution == "all"`` selects every row. Other values depend on the dataset:

    * ISIC (``"isic"`` in ``dataset_name``): ``"barcelona_all"`` and ``"mskcc_all"`` merge the two
      spellings of each institution. Any other value is matched exactly against the
      ``attribution`` column. Non-"all" subsets are further restricted to rows with
      ``collection_65 == 1``.
    * derm7pt: only ``"all"`` is supported.
    * clinical_fd_clean: ``"dark"`` selects Fitzpatrick 5-6 or skin_tone code 56. ``"light"``
      selects Fitzpatrick 1-4 or skin_tone codes 12/34.

    Returns:
        numpy bool array of length ``len(metadata_all)``.

    Raises:
        NotImplementedError: for an unsupported dataset/attribution combination. Some of these
            used to fall through to an UnboundLocalError.
    """
    def _all_rows():
        # np.ones keeps a bool dtype even for zero rows; np.array([]) would be float64.
        return np.ones(len(metadata_all), dtype=bool)

    if "isic" in dataset_name:
        if attribution == "all":
            return _all_rows()
        collection_65 = (metadata_all["collection_65"] == 1).values
        if attribution == "barcelona_all":
            subset_idx = ((metadata_all["attribution"] == "Department of Dermatology, Hospital Clínic de Barcelona")
                          | (metadata_all["attribution"] == "Hospital Clínic de Barcelona")).values
        elif attribution == "mskcc_all":
            subset_idx = ((metadata_all["attribution"] == "MSKCC")
                          | (metadata_all["attribution"] == "Memorial Sloan Kettering Cancer Center")).values
        else:
            subset_idx = (metadata_all["attribution"] == attribution).values
        return subset_idx & collection_65

    if "derm7pt" in dataset_name:
        if attribution == "all":
            return _all_rows()
        raise NotImplementedError(f"derm7pt only supports attribution='all', got {attribution!r}")

    if "clinical_fd_clean" in dataset_name:
        if attribution == "all":
            return _all_rows()
        if attribution not in ("dark", "light"):
            raise NotImplementedError(f"Unsupported attribution {attribution!r} for {dataset_name!r}")
        # -9 is a sentinel so missing values can be cast to int and never match a code below.
        fitzpatrick = metadata_all["fitzpatrick_scale"].fillna(-9).astype(int)
        skin_tone = metadata_all["skin_tone"].fillna(-9).astype(int)
        if attribution == "dark":
            return (fitzpatrick.isin([5, 6]) | skin_tone.isin([56])).values
        return (fitzpatrick.isin([1, 2, 3, 4]) | skin_tone.isin([12, 34])).values

    raise NotImplementedError(f"Unsupported dataset_name {dataset_name!r}")

In [None]:
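# Runs per derm7pt target concept, apparently from benchmark_df["concept_name_target"].value_counts():
# every concept passed the min_sample filter and got 6 noise levels x 20 seeds = 120 runs.
# derm7ptconcept_pigment network         120
# derm7ptconcept_regression structure    120
# derm7ptconcept_pigmentation            120
# derm7ptconcept_blue whitish veil       120
# derm7ptconcept_vascular structures     120
# derm7ptconcept_streaks                 120
# derm7ptconcept_dots and globules       120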

In [None]:
def top_n_accuracy(rank_all, max_n=5):
    """Expand per-run ranks into top-N hit indicators for N = 1..max_n.

    Args:
        rank_all: DataFrame with a 0-based ``rank`` column (0 means the target ranked first).
            It is not modified.
        max_n: largest N to evaluate. The default of 5 matches the previously hard-coded table.

    Returns:
        ``max_n`` stacked copies of ``rank_all`` with the original index preserved (and therefore
        repeated). Each copy gets a ``top_n`` column and an int ``accuracy`` column that is 1 when
        ``rank < top_n``.

    Raises:
        ValueError: if ``max_n < 1`` (there would be nothing to concatenate).
    """
    if max_n < 1:
        raise ValueError(f"max_n must be >= 1, got {max_n}")
    per_n_frames = []
    for top_n in range(1, max_n + 1):
        rank_all_top_n = rank_all.copy()
        rank_all_top_n["top_n"] = top_n
        rank_all_top_n["accuracy"] = (rank_all_top_n["rank"] < top_n).astype(int)
        per_n_frames.append(rank_all_top_n)
    return pd.concat(per_n_frames)

In [None]:
# Concept-recovery benchmark on the clinical dataset over all SkinCon concepts.
for dataset_name in ("clinical_fd_clean_nodup_nooverlap",):
    x = perform_benchmark(
        dataset_name,
        variable_dict[dataset_name]["metadata_all"],
        variable_dict[dataset_name]["image_features_norm"],
        skincon_cols,
        variable_dict[dataset_name]["valid_idx"],
        variable_dict[dataset_name]["prompt_info"],
        min_sample=30,
    )

In [None]:
# Per-run records for the ("all", "all") attribution pair, expanded into top-N hit indicators.
benchmark_df = pd.DataFrame(x["benchmark_dict_list_dict"]["all", "all"])
rank_all_top = top_n_accuracy(benchmark_df)

In [None]:
# Inspect the top-N accuracy table built from the benchmark records.
rank_all_top

In [None]:
# Number of distinct target concepts that survived the min_sample filter (NaN counted, as unique() does).
rank_all_top["concept_name_target"].nunique(dropna=False)

In [None]:
# Same count via value_counts semantics (NaN excluded).
rank_all_top["concept_name_target"].nunique()

In [None]:
# Re-display the top-N accuracy table.
rank_all_top

In [None]:
# Mean accuracy per (top_n, noise_level, target concept); 100 samples per group, noise 0.3/0.5 excluded.
(
    rank_all_top[(rank_all_top["num_sample"] == 100) & ~rank_all_top["noise_level"].isin([0.3, 0.5])]
    .groupby(["top_n", "noise_level", "concept_name_target"])["accuracy"]
    .mean()
    .reset_index()
)

In [None]:
# Mean and std (over random seeds) of the per-seed top-N accuracy for each noise level.
# Derived directly from rank_all_top, so this cell no longer depends on rank_all_top_group,
# which is only defined in the plotting cell further down. It also groups by the real
# "noise_level" column.
(
    rank_all_top[(rank_all_top["num_sample"] == 100) & ~rank_all_top["noise_level"].isin([0.3, 0.5])]
    .groupby(["top_n", "noise_level", "random_seed"])["accuracy"]
    .mean()
    .groupby(level=["top_n", "noise_level"])
    .agg(["mean", "std"])
)

In [None]:
# Rows used by the figures: 100 samples per group, noise levels 0.3 and 0.5 left out.
rank_all_top[(rank_all_top["num_sample"] == 100) & ~rank_all_top["noise_level"].isin([0.3, 0.5])]

In [None]:
# Full top-N table again, for comparison with the filtered view above.
rank_all_top

In [None]:
# Accuracy per (top_n, noise_level, seed), averaged over target concepts.
(
    rank_all_top[(rank_all_top["num_sample"] == 100) & ~rank_all_top["noise_level"].isin([0.3, 0.5])]
    .groupby(["top_n", "noise_level", "random_seed"])["accuracy"]
    .mean()
    .reset_index()
)

In [None]:
import scipy.stats

def mean_confidence_interval(data, confidence=0.95):
    """Return ``(mean, lower, upper)`` with a two-sided Student-t confidence interval.

    The half-width is ``sem(data) * t.ppf((1 + confidence) / 2, n - 1)``, where ``n = len(data)``.
    """
    values = np.asarray(data) * 1.0
    dof = len(values) - 1
    mean = np.mean(values)
    half_width = scipy.stats.sem(values) * scipy.stats.t.ppf((1 + confidence) / 2., dof)
    return mean, mean - half_width, mean + half_width

In [None]:
# Top-N accuracy of the concept-recovery benchmark on the clinical dataset
# (num_sample == 100; noise levels 0.3 and 0.5 are left out of the figure).
# Bars: mean over random seeds. Hollow markers: one point per seed (accuracy averaged over targets).
plt.rcParams["axes.prop_cycle"] = cycler(
    'color',
    # Paired holds 0-255 RGB triples; matplotlib wants floats in [0, 1], so divide by 255 (not 256).
    [np.array(Paired[12][i]) / 255 for i in (1, 3, 5, 7, 9, 11)],
)
fig = plt.figure(constrained_layout=True, figsize=(10, 5))
subfigs = fig.subfigures(1, 1)
ax = subfigs.subplots(1, 1)

rank_all_top_group = (
    rank_all_top[(rank_all_top["num_sample"] == 100)
                 & (rank_all_top["noise_level"] != 0.3)
                 & (rank_all_top["noise_level"] != 0.5)]
    .groupby(["top_n", "noise_level", "random_seed"])["accuracy"].mean()
    .reset_index()
)
sns.barplot(x="top_n", y="accuracy", hue="noise_level",
            capsize=.07,
            errcolor=(0, 0, 0, 1),
            errwidth=1.3,
            data=rank_all_top_group,
            ax=ax)

# Per-seed points with black edges; their fill is made transparent just below.
sns.stripplot(y="accuracy", x="top_n", hue="noise_level",
              dodge=True,
              palette=sns.color_palette([[1, 1, 1]]),
              edgecolor=(0, 0, 0, 1),
              linewidth=0.8,
              size=4,
              data=rank_all_top_group,
              legend=False,
              ax=ax)

# Hollow markers: zero the fill alpha of every scatter collection and keep the edge colour.
for path_collection in ax.collections:
    facecolors = path_collection.get_facecolor().copy()
    facecolors[:, -1] = 0.0
    path_collection.set_facecolor(facecolors)

for axis in ['top', 'bottom', 'left', 'right']:
    ax.spines[axis].set_linewidth(1.5)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)

ax.yaxis.set_major_locator(MultipleLocator(0.1))
# NOTE(review): the minor locator uses the same 0.1 step as the major one. Coinciding minor
# ticks are likely dropped, so the minor grid probably draws nothing -- 0.05 may have been intended.
ax.yaxis.set_minor_locator(MultipleLocator(0.1))
ax.yaxis.grid(True, which='major', linewidth=0.4, alpha=0.6)
ax.yaxis.grid(True, which='minor', linewidth=0.4, alpha=0.2)

ax.tick_params(axis='both', which='major', labelsize=16)
ax.tick_params(axis='both', which='minor', labelsize=16)
ax.set_xlabel("Top-N", fontsize=16)
ax.set_ylabel("Accuracy", fontsize=16)
ax.set_ylim(0, 1.01)

# Black outlines on the bars.
for patch in ax.patches:
    patch.set_linewidth(1)
    patch.set_edgecolor("black")

leg = ax.legend(fontsize=14, facecolor='white', framealpha=0.2,
                loc="upper left", bbox_to_anchor=(0, 0.55, 0.3, 0.5),
                labelspacing=0.3)
leg.set_title("Noise Level", prop={"size": 14})

# Export (uncomment to write the supplementary figure):
# fig.savefig(log_dir/"plots"/"supple_data_audit_eval.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"supple_data_audit_eval.jpg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"supple_data_audit_eval.svg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"supple_data_audit_eval.pdf", bbox_inches='tight')
# plt.close(fig)

In [None]:
# Concept-recovery benchmark on derm7pt over its seven dermoscopic concepts
# (padded with SkinCon distractors inside perform_benchmark).
for dataset_name in ("derm7pt_derm_nodup",):
    x = perform_benchmark(
        dataset_name,
        variable_dict[dataset_name]["metadata_all"],
        variable_dict[dataset_name]["image_features_norm"],
        ["derm7ptconcept_" + name for name in ("pigment network",
                                               "regression structure",
                                               "pigmentation",
                                               "blue whitish veil",
                                               "vascular structures",
                                               "streaks",
                                               "dots and globules")],
        variable_dict[dataset_name]["valid_idx"],
        variable_dict[dataset_name]["prompt_info"],
        min_sample=30,
        attribution_list=[["all", "all"]],
    )

In [None]:
# derm7pt records for the ("all", "all") attribution pair, expanded into top-N hit indicators.
benchmark_df = pd.DataFrame(x["benchmark_dict_list_dict"]["all", "all"])
rank_all_top = top_n_accuracy(benchmark_df)

In [None]:
# Top-N accuracy of the concept-recovery benchmark on derm7pt
# (num_sample == 100; noise levels 0.3 and 0.5 are left out of the figure).
# Bars: mean over random seeds. Hollow markers: one point per seed (accuracy averaged over targets).
plt.rcParams["axes.prop_cycle"] = cycler(
    'color',
    # Paired holds 0-255 RGB triples; matplotlib wants floats in [0, 1], so divide by 255 (not 256).
    [np.array(Paired[12][i]) / 255 for i in (1, 3, 5, 7, 9, 11)],
)
fig = plt.figure(constrained_layout=True, figsize=(10, 5))
subfigs = fig.subfigures(1, 1)
ax = subfigs.subplots(1, 1)

rank_all_top_group = (
    rank_all_top[(rank_all_top["num_sample"] == 100)
                 & (rank_all_top["noise_level"] != 0.3)
                 & (rank_all_top["noise_level"] != 0.5)]
    .groupby(["top_n", "noise_level", "random_seed"])["accuracy"].mean()
    .reset_index()
)
sns.barplot(x="top_n", y="accuracy", hue="noise_level",
            capsize=.07,
            errcolor=(0, 0, 0, 1),
            errwidth=1.3,
            data=rank_all_top_group,
            ax=ax)

# Per-seed points with black edges; their fill is made transparent just below.
sns.stripplot(y="accuracy", x="top_n", hue="noise_level",
              dodge=True,
              palette=sns.color_palette([[1, 1, 1]]),
              edgecolor=(0, 0, 0, 1),
              linewidth=0.8,
              size=4,
              data=rank_all_top_group,
              legend=False,
              ax=ax)

# Hollow markers: zero the fill alpha of every scatter collection and keep the edge colour.
for path_collection in ax.collections:
    facecolors = path_collection.get_facecolor().copy()
    facecolors[:, -1] = 0.0
    path_collection.set_facecolor(facecolors)

for axis in ['top', 'bottom', 'left', 'right']:
    ax.spines[axis].set_linewidth(1.5)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)

ax.yaxis.set_major_locator(MultipleLocator(0.1))
# NOTE(review): the minor locator uses the same 0.1 step as the major one. Coinciding minor
# ticks are likely dropped, so the minor grid probably draws nothing -- 0.05 may have been intended.
ax.yaxis.set_minor_locator(MultipleLocator(0.1))
ax.yaxis.grid(True, which='major', linewidth=0.4, alpha=0.6)
ax.yaxis.grid(True, which='minor', linewidth=0.4, alpha=0.2)

ax.tick_params(axis='both', which='major', labelsize=16)
ax.tick_params(axis='both', which='minor', labelsize=16)
ax.set_xlabel("Top-N", fontsize=16)
ax.set_ylabel("Accuracy", fontsize=16)
ax.set_ylim(0, 1.01)

# Black outlines on the bars.
for patch in ax.patches:
    patch.set_linewidth(1)
    patch.set_edgecolor("black")

leg = ax.legend(fontsize=14, facecolor='white', framealpha=0.2,
                loc="upper left", bbox_to_anchor=(0, 0.55, 0.3, 0.5),
                labelspacing=0.3)
leg.set_title("Noise Level", prop={"size": 14})

# Export (uncomment to write the supplementary figure):
# fig.savefig(log_dir/"plots"/"supple_data_audit_eval_derm7pt.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"supple_data_audit_eval_derm7pt.jpg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"supple_data_audit_eval_derm7pt.svg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"supple_data_audit_eval_derm7pt.pdf", bbox_inches='tight')
# plt.close(fig)

In [None]:
# Re-display the most recently created figure (the derm7pt top-N accuracy plot when run in order).
fig

In [None]:
# Runs recorded per target concept; the default benchmark grid gives 6 noise levels x 20 seeds = 120.
benchmark_df["concept_name_target"].value_counts()

In [None]:
from sklearn.model_selection import train_test_split
from MONET.datamodules.components.base_dataset import BaseDataset

def get_test_data(dataloader, metadata_all, valid_idx, y_pos, n_px=None):
    """Wrap every row of ``dataloader.dataset`` in a non-shuffled evaluation DataLoader.

    ``y_pos`` is attached as an int ``label`` column. ``metadata_all`` and ``valid_idx`` are
    accepted for signature parity with ``get_training_data`` but are not used: all rows are kept.
    Images are not augmented and are normalised with the mean/std of the source dataset's
    ``transforms_aftertensor.transforms[1]``. ``n_px`` defaults to the source dataset's ``n_px``.
    """
    source_dataset = dataloader.dataset
    labelled_metadata = source_dataset.metadata_all.copy()
    labelled_metadata["label"] = y_pos.astype(int)
    print("test:", len(labelled_metadata))

    resolved_n_px = source_dataset.n_px if n_px is None else n_px
    normalize = source_dataset.transforms_aftertensor.transforms[1]
    eval_dataset = BaseDataset(
        image_path_or_binary_dict=source_dataset.image_path_dict,
        n_px=resolved_n_px,
        norm_mean=normalize.mean,
        norm_std=normalize.std,
        augment=False,
        metadata_all=labelled_metadata,
        integrity_level="weak",
        return_label=["label"],
    )

    from MONET.utils.loader import custom_collate
    return torch.utils.data.DataLoader(
        dataset=eval_dataset,
        batch_size=64,
        num_workers=4,
        pin_memory=False,
        persistent_workers=False,
        shuffle=False,
        collate_fn=custom_collate,
    )

In [None]:
# Evaluation loader over every image of each dataset (no train/val split).
for dataset_name in ("isic_nodup_nooverlap", "clinical_fd_clean_nodup_nooverlap"):
    x = get_test_data(
        variable_dict[dataset_name]["dataloader"],
        variable_dict[dataset_name]["metadata_all"],
        variable_dict[dataset_name]["valid_idx"],
        variable_dict[dataset_name]["y_pos"],
    )
    variable_dict[dataset_name]["classifier_dataloader_all"] = x

In [None]:
from sklearn.model_selection import train_test_split
from MONET.datamodules.components.base_dataset import BaseDataset

def get_training_data(dataloader, metadata_all, valid_idx, y_pos, subset_idx_train, n_px=None,
                      test_size=0.2, random_state=42, batch_size=64, num_workers=4):
    """Build a shuffled train DataLoader and a non-shuffled validation DataLoader.

    The rows of ``dataloader.dataset.metadata_all`` selected by ``valid_idx & subset_idx_train``
    are labelled with ``y_pos`` (cast to int). They are then split with
    ``train_test_split(test_size=test_size, random_state=random_state)``. Images are not
    augmented and are normalised with the mean/std of the source dataset's
    ``transforms_aftertensor.transforms[1]``.

    Args:
        dataloader: source loader whose ``dataset`` provides metadata, image paths and transforms.
        metadata_all: unused; kept for call compatibility (rows come from ``dataloader.dataset``).
        valid_idx, subset_idx_train: boolean masks over ``dataloader.dataset.metadata_all``.
            NOTE(review): callers build ``subset_idx_train`` from a separate ``metadata_all``;
            this assumes both tables have the same row order -- confirm.
        y_pos: per-row labels, cast to int.
        n_px: image size; defaults to ``dataloader.dataset.n_px``.
        test_size, random_state, batch_size, num_workers: split and loader settings. The defaults
            match the previously hard-coded values.

    Returns:
        ``(train_dataloader, val_dataloader)``.
    """
    source_dataset = dataloader.dataset
    metadata_all_new = source_dataset.metadata_all.copy()
    metadata_all_new["label"] = y_pos.astype(int)
    metadata_all_new_train = metadata_all_new[valid_idx & subset_idx_train]

    train_idx, val_idx = train_test_split(metadata_all_new_train.index, test_size=test_size, random_state=random_state)
    # Report both split sizes; the old log printed train+val under the "train:" label.
    print("train:", len(train_idx), "val:", len(val_idx))

    if n_px is None:
        n_px = source_dataset.n_px
    normalize = source_dataset.transforms_aftertensor.transforms[1]

    from MONET.utils.loader import custom_collate

    def _make_loader(index, shuffle):
        """Wrap the labelled metadata rows in ``index`` in a BaseDataset plus a DataLoader."""
        split_dataset = BaseDataset(
            image_path_or_binary_dict=source_dataset.image_path_dict,
            n_px=n_px,
            norm_mean=normalize.mean,
            norm_std=normalize.std,
            augment=False,
            metadata_all=metadata_all_new_train.loc[index],
            integrity_level="weak",
            return_label=["label"],
        )
        return torch.utils.data.DataLoader(
            dataset=split_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=False,
            persistent_workers=False,
            shuffle=shuffle,
            collate_fn=custom_collate,
        )

    return _make_loader(train_idx, shuffle=True), _make_loader(val_idx, shuffle=False)

In [None]:
# Train/val loaders restricted to the images of each attribution subset.
for dataset_name in ("isic_nodup_nooverlap", "clinical_fd_clean_nodup_nooverlap"):
    for attribution_name in attribution_dict[dataset_name]:
        x = get_training_data(
            variable_dict[dataset_name]["dataloader"],
            variable_dict[dataset_name]["metadata_all"],
            variable_dict[dataset_name]["valid_idx"],
            variable_dict[dataset_name]["y_pos"],
            get_subset_index(dataset_name, variable_dict[dataset_name]["metadata_all"], attribution_name),
        )
        variable_dict[dataset_name][f"classifier_dataloader_{attribution_name}"] = x

In [None]:
from MONET.datamodules.components.base_dataset import BaseDataset

In [None]:
# Check GPU usage before training; train_classifier below places the model on a hard-coded "cuda:4".
!gpustat

In [None]:
import torchvision
from torch import nn
from torch.nn import functional as F
from torchmetrics import AUROC

class Classifier(nn.Module):
    """ImageNet-pretrained ResNet-50 feature extractor followed by a linear head.

    The backbone's own ``fc`` layer is replaced by ``nn.Identity`` so that the backbone returns
    its pooled features. ``head`` maps those features to ``output_dim`` outputs. Every backbone
    parameter is left trainable (full fine-tuning).
    """

    def __init__(self, output_dim):
        super().__init__()
        backbone = torchvision.models.resnet50(weights="ResNet50_Weights.IMAGENET1K_V1")
        # Explicitly unfreeze everything: full fine-tuning rather than a linear probe.
        backbone.requires_grad_(True)
        num_features = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.backbone = backbone
        self.head = nn.Linear(num_features, output_dim)

    def forward(self, x):
        return self.head(self.backbone(x))
    
class Inception(nn.Module):
    """ImageNet-pretrained Inception-v3 trunk followed by a new linear head.

    ``input_to_representation`` replays the feature path of torchvision's ``Inception3``: stem,
    Mixed_5b..Mixed_7c, average pooling and dropout. It returns flattened ``N x 2048`` features,
    and ``self.fc`` maps them to ``output_dim`` outputs. torchvision's own classifier
    (``self.inception.fc``) and auxiliary head (``self.inception.AuxLogits``) are not used.
    """

    def __init__(self, output_dim):
        super().__init__()
        self.inception = torchvision.models.inception_v3(weights="Inception_V3_Weights.IMAGENET1K_V1")
        self.fc = nn.Linear(2048, output_dim)

    def forward(self, x):
        x = self.input_to_representation(x)
        x = self.fc(x)
        # N x output_dim
        return x

    def input_to_representation(self, x):
        """Return pooled Inception-v3 features of shape ``N x 2048``.

        The shape comments below assume 299 x 299 inputs.

        NOTE(review): torchvision's ``Inception3.forward`` first calls ``_transform_input``, which
        appears to be enabled by default for the pretrained weights. This path skips it, so
        confirm that the dataloader's normalisation matches what the weights expect.
        """
        # N x 3 x 299 x 299
        x = self.inception.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.inception.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.inception.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = self.inception.maxpool1(x)
        # N x 64 x 73 x 73
        x = self.inception.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.inception.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = self.inception.maxpool2(x)
        # N x 192 x 35 x 35
        x = self.inception.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.inception.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.inception.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.inception.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.inception.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.inception.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.inception.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.inception.Mixed_6e(x)
        # N x 768 x 17 x 17
        # The auxiliary classifier (self.inception.AuxLogits) is intentionally not evaluated:
        # its output would be discarded and never reaches the returned features.
        x = self.inception.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.inception.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.inception.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = self.inception.avgpool(x)
        # N x 2048 x 1 x 1
        x = self.inception.dropout(x)
        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 2048
        return x


class EarlyStopper:
    """Patience-based early stopping on a validation loss.

    Call ``early_stop`` once per epoch with that epoch's validation loss:

    * a loss strictly below the best seen so far becomes the new best and
      clears the bad-epoch counter;
    * a loss above ``best + min_delta`` counts as a bad epoch, and once
      ``patience`` bad epochs have accumulated ``early_stop`` returns True;
    * anything in between leaves the counter unchanged.
    """

    def __init__(self, patience=1, min_delta=0):
        # Bad epochs tolerated before signalling a stop.
        self.patience = patience
        # Slack above the best loss that does not count as a bad epoch.
        self.min_delta = min_delta
        # Bad epochs seen since the last improvement.
        self.counter = 0
        # Lowest validation loss observed so far.
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        """Record ``validation_loss``; return True when training should stop."""
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False
        # Spelled ``not (a > b)`` instead of ``a <= b`` so a NaN loss is
        # ignored (no counter change) rather than counted as a bad epoch.
        if not validation_loss > self.min_validation_loss + self.min_delta:
            return False
        self.counter += 1
        return True if self.counter >= self.patience else False
        
def _evaluate_classifier(classifier, dataloader, device, auroc):
    """Return the sample-summed BCE loss of ``classifier`` over ``dataloader``.

    Runs in eval mode without gradients and feeds every batch into ``auroc``.
    The target is ``label == 1``, matching training. Each batch's mean loss is
    multiplied by the batch size, so divide by ``len(dataloader.dataset)`` to
    get the mean.
    """
    classifier.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in tqdm.tqdm(dataloader):
            image = batch["image"].to(device)
            label = batch["label"].to(device)
            # Single-logit head: keep column 0 so preds and target are both (N,).
            logits = classifier(image)[:, 0]
            loss = F.binary_cross_entropy_with_logits(input=logits, target=(label == 1).float())
            total_loss += loss.item() * image.size(0)
            auroc.update(logits, label == 1)
    return total_loss


def train_classifier(
    train_dataloader,
    val_dataloader,
    test_dataloader,
    classifier_type="resnet",
    device="cuda:4",
    max_epochs=20,
    lr=1e-3,
):
    """Train a binary (``label == 1`` vs. rest) image classifier and report metrics.

    Uses Adam, with ReduceLROnPlateau (patience 2) and EarlyStopper (patience 5)
    both driven by the summed validation loss. Train/val loss and AUROC are
    printed every epoch, and test loss and AUROC once at the end.

    Args:
        train_dataloader, val_dataloader, test_dataloader: yield dict batches
            with ``"image"`` and ``"label"`` tensors.
        classifier_type: ``"resnet"`` (``Classifier``) or ``"inception"``
            (``Inception``), each built with ``output_dim=1``.
        device: device the model and batches are moved to (previously hard-coded).
        max_epochs: upper bound on training epochs (previously fixed at 20).
        lr: Adam learning rate (previously fixed at 1e-3).

    Returns:
        The classifier as of the last epoch run (the best checkpoint is not
        restored). It is left in eval mode.

    Raises:
        ValueError: if ``classifier_type`` is not recognised.
    """
    if classifier_type == "resnet":
        classifier = Classifier(output_dim=1)
    elif classifier_type == "inception":
        classifier = Inception(output_dim=1)
    else:
        raise ValueError(
            f"Unknown classifier_type {classifier_type!r}; expected 'resnet' or 'inception'."
        )
    classifier.to(device)
    optimizer = torch.optim.Adam(classifier.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=2, verbose=True)
    early_stopper = EarlyStopper(patience=5, min_delta=0)

    train_auroc = AUROC(task="binary")
    val_auroc = AUROC(task="binary")
    for _ in range(max_epochs):
        classifier.train()
        train_loss = 0.0
        for batch in tqdm.tqdm(train_dataloader):
            image = batch["image"].to(device)
            label = batch["label"].to(device)
            logits = classifier(image)[:, 0]
            loss = F.binary_cross_entropy_with_logits(input=logits, target=(label == 1).float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * image.size(0)
            # Detach so the metric state cannot keep this batch's autograd graph alive.
            train_auroc.update(logits.detach(), label == 1)

        val_loss = _evaluate_classifier(classifier, val_dataloader, device, val_auroc)

        print(
            f"Train loss: {train_loss/len(train_dataloader.dataset):.3f} AUROC: {train_auroc.compute():.3f} Val loss: {val_loss/len(val_dataloader.dataset):.3f} AUROC: {val_auroc.compute():.3f}"
        )
        scheduler.step(val_loss)
        if early_stopper.early_stop(val_loss):
            print("break")
            break
        train_auroc.reset()
        val_auroc.reset()

    test_auroc = AUROC(task="binary")
    test_loss = _evaluate_classifier(classifier, test_dataloader, device, test_auroc)
    print(
        f"Test loss: {test_loss/len(test_dataloader.dataset):.3f} AUROC: {test_auroc.compute():.3f}"
    )
    return classifier

In [None]:
# Show current GPU status (via the `gpustat` CLI) before choosing `classifier_device`.
!gpustat

In [None]:
# GPU index used for classifier inference; `get_logits` moves batches to this device.
CLASSIFIER_GPU_INDEX = 4
classifier_device = f"cuda:{CLASSIFIER_GPU_INDEX}"

In [None]:
# For each dataset, train one classifier per hospital. It is fit on that
# hospital's loaders 0/1 (train/val) and tested on the *other* hospital's
# loader 0. The model is stored as `classifier_model_<hospital>`.
for dataset_name in ["clinical_fd_clean_nodup_nooverlap", "isic_nodup_nooverlap"]:
    dataset_vars = variable_dict[dataset_name]
    hospitals = attribution_dict[dataset_name]
    for fit_pos, eval_pos in ((0, 1), (1, 0)):
        fit_loaders = dataset_vars[f"classifier_dataloader_{hospitals[fit_pos]}"]
        eval_loaders = dataset_vars[f"classifier_dataloader_{hospitals[eval_pos]}"]
        x = train_classifier(
            train_dataloader=fit_loaders[0],
            val_dataloader=fit_loaders[1],
            test_dataloader=eval_loaders[0],
        )
        dataset_vars.update({f"classifier_model_{hospitals[fit_pos]}": x})

In [None]:
def get_logits(classifier, dataloader, device=None):
    """Run ``classifier`` over ``dataloader`` and collect per-image logits.

    The model is switched to eval mode first. Dropout or batch-norm left in
    train mode would otherwise make the logits stochastic and dependent on
    batch composition.

    Args:
        classifier: module whose output column 0 is the logit (the same
            ``[:, 0]`` convention used in training).
        dataloader: yields dict batches with ``"image"`` and ``"metadata"``
            (the latter has an ``.index``, presumably a DataFrame).
        device: where images are moved. Defaults to the device of the model's
            first parameter, so inputs always match the model. It falls back to
            the global ``classifier_device`` for parameter-less models.

    Returns:
        The result of ``dataloader_apply_func`` combining per-batch dicts with
        ``custom_collate_per_key``. Each batch contributes ``"logits"`` (a Series
        indexed like the batch metadata) and ``"metadata"``.
    """
    if device is None:
        first_param = next(classifier.parameters(), None)
        device = first_param.device if first_param is not None else classifier_device
    classifier.eval()

    def batch_func(batch):
        with torch.no_grad():
            logits = classifier(batch["image"].to(device))
        return {
            "logits": pd.Series(logits[:, 0].cpu().numpy(), index=batch["metadata"].index),
            "metadata": batch["metadata"],
        }

    return dataloader_apply_func(
        dataloader=dataloader,
        func=batch_func,
        collate_fn=custom_collate_per_key,
    )

In [None]:
# for dataset_name in ["clinical_fd_clean_nodup_nooverlap", "isic_nodup_nooverlap"]: 
#     x=train_classifier(train_dataloader=variable_dict[dataset_name][f"classifier_dataloader_{attribution_dict[dataset_name][0]}"][0], 
#                        val_dataloader=variable_dict[dataset_name][f"classifier_dataloader_{attribution_dict[dataset_name][0]}"][1],
#                        test_dataloader=variable_dict[dataset_name][f"classifier_dataloader_{attribution_dict[dataset_name][1]}"][0])    
#     variable_dict[dataset_name].update({f"classifier_model_{attribution_dict[dataset_name][0]}": x}) 
    
#     x=train_classifier(train_dataloader=variable_dict[dataset_name][f"classifier_dataloader_{attribution_dict[dataset_name][1]}"][0], 
#                        val_dataloader=variable_dict[dataset_name][f"classifier_dataloader_{attribution_dict[dataset_name][1]}"][1],
#                        test_dataloader=variable_dict[dataset_name][f"classifier_dataloader_{attribution_dict[dataset_name][0]}"][0])    
#     variable_dict[dataset_name].update({f"classifier_model_{attribution_dict[dataset_name][1]}": x}) 

In [None]:
# Score every image in `classifier_dataloader_all` with both hospital-specific
# classifiers. The collected outputs are stored as `classifier_model_<hospital>_eval`.
for dataset_name in ["clinical_fd_clean_nodup_nooverlap", "isic_nodup_nooverlap"]:
    dataset_vars = variable_dict[dataset_name]
    first_hospital = attribution_dict[dataset_name][0]
    second_hospital = attribution_dict[dataset_name][1]
    all_images_loader = dataset_vars["classifier_dataloader_all"]

    logits_hospital_1 = get_logits(
        classifier=dataset_vars[f"classifier_model_{first_hospital}"],
        dataloader=all_images_loader,
    )
    dataset_vars.update({f"classifier_model_{first_hospital}_eval": logits_hospital_1})

    logits_hospital_2 = get_logits(
        classifier=dataset_vars[f"classifier_model_{second_hospital}"],
        dataloader=all_images_loader,
    )
    dataset_vars.update({f"classifier_model_{second_hospital}_eval": logits_hospital_2})

In [None]:
# Recompute the evaluation logits for the clinical dataset only, with each
# hospital's classifier, over all of its images.
dataset_name = "clinical_fd_clean_nodup_nooverlap"
hospital_pair = attribution_dict[dataset_name]
loader_all = variable_dict[dataset_name]["classifier_dataloader_all"]

logits_hospital_1 = get_logits(
    classifier=variable_dict[dataset_name][f"classifier_model_{hospital_pair[0]}"],
    dataloader=loader_all,
)
variable_dict[dataset_name].update({f"classifier_model_{hospital_pair[0]}_eval": logits_hospital_1})

logits_hospital_2 = get_logits(
    classifier=variable_dict[dataset_name][f"classifier_model_{hospital_pair[1]}"],
    dataloader=loader_all,
)
variable_dict[dataset_name].update({f"classifier_model_{hospital_pair[1]}_eval": logits_hospital_2})

In [None]:
# Label the collected logits with the metadata index so they can be looked up
# per image with `.loc`. Values are taken positionally via `np.asarray`.
# Wrapping an existing Series in `pd.Series(..., index=...)` would *reindex by
# label* instead, silently producing NaNs if the collated logits carry a
# different index (e.g. a fresh RangeIndex).
# Assumes logits and metadata were collated batch-by-batch in the same order.
for dataset_name in ["clinical_fd_clean_nodup_nooverlap", "isic_nodup_nooverlap"]:
    for model_hospital in (attribution_dict[dataset_name][0], attribution_dict[dataset_name][1]):
        model_eval = variable_dict[dataset_name][f"classifier_model_{model_hospital}_eval"]
        logit_values = np.asarray(model_eval["logits"])
        metadata_index = model_eval["metadata"].index
        if len(logit_values) != len(metadata_index):
            raise ValueError(
                f"{dataset_name}/{model_hospital}: {len(logit_values)} logits "
                f"but {len(metadata_index)} metadata rows"
            )
        model_eval["logits_with_index"] = pd.Series(logit_values, index=metadata_index)

In [None]:
# Persist all per-dataset results (dataloaders, models, evaluation logits).
# torch.save does not create missing directories, so ensure the target exists first.
results_path = Path("logs/experiment_results/concept_annotation_data_auditing_0826.pt")
results_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(variable_dict, results_path)

In [None]:
# torch.save(variable_dict, log_dir/"experiment_results"/"data_audit_new.pt")

In [None]:
def calculate_test_auc(metadata_all, logits, idx):
    """ROC AUC of ``logits`` against ``metadata_all["label"]`` on the rows ``idx``.

    Args:
        metadata_all: table with a ``"label"`` column; selected labels are cast to int.
        logits: scores indexed compatibly with ``metadata_all`` (looked up with ``.loc``).
        idx: index labels selecting the evaluation rows.

    Returns:
        float: ``sklearn.metrics.roc_auc_score`` on the selected rows.
    """
    y_true = metadata_all["label"].loc[idx].values.astype(int)
    y_score = logits.loc[idx].values
    return sklearn.metrics.roc_auc_score(y_true=y_true, y_score=y_score)

# AUROC of the `hospital_1` model on the loader-1 (validation) split of
# hospital_2, then of hospital_1.
# NOTE(review): logits come from `variable_dict[dataset_name]` (whatever value the
# last loop left `dataset_name` at), while row indices come from
# `variable_dict["isic"]`. Confirm these entries share the same index.
hospital_1_eval = variable_dict[dataset_name][f"classifier_model_{hospital_1}_eval"]
for auc_split_hospital in (hospital_2, hospital_1):
    auc_split_index = variable_dict["isic"][f"classifier_dataloader_{auc_split_hospital}"][1].dataset.metadata_all.index
    print(calculate_test_auc(metadata_all=hospital_1_eval["metadata"],
                             logits=hospital_1_eval["logits"],
                             idx=auc_split_index))


In [None]:
# List the entries stored so far for the clinical dataset.
clinical_vars = variable_dict["clinical_fd_clean_nodup_nooverlap"]
clinical_vars.keys()

In [None]:
# AUROC of the `hospital_1` model on its own loader-0 (training) split.
# NOTE(review): indices come from `variable_dict["isic"]`, logits from `variable_dict[dataset_name]`.
h1_eval = variable_dict[dataset_name][f"classifier_model_{hospital_1}_eval"]
h1_train_index = variable_dict["isic"][f"classifier_dataloader_{hospital_1}"][0].dataset.metadata_all.index
print(calculate_test_auc(metadata_all=h1_eval["metadata"], logits=h1_eval["logits"], idx=h1_train_index))

In [None]:
# AUROC of the `hospital_2` model on the loader-1 (validation) split of
# hospital_1, then of hospital_2.
# NOTE(review): indices come from `variable_dict["isic"]`, logits from `variable_dict[dataset_name]`.
hospital_2_eval = variable_dict[dataset_name][f"classifier_model_{hospital_2}_eval"]
for auc_split_hospital in (hospital_1, hospital_2):
    auc_split_index = variable_dict["isic"][f"classifier_dataloader_{auc_split_hospital}"][1].dataset.metadata_all.index
    print(calculate_test_auc(metadata_all=hospital_2_eval["metadata"],
                             logits=hospital_2_eval["logits"],
                             idx=auc_split_index))