In [None]:
import os

import hydra
import omegaconf
import pyrootutils

# Resolve the project root, starting the search from this notebook's absolute path.
# NOTE(review): per the pyrootutils docs, `pythonpath=True` should also add the root to
# sys.path — confirm against the installed version.
root = pyrootutils.setup_root(os.path.abspath("compare_manual-230911.ipynb"), pythonpath=True)
import os  # redundant: `os` is already imported at the top of this cell

# Work from the project root; the relative paths used later ("logs/...") depend on this.
os.chdir(root)

In [None]:
import sys

# Put <root>/src on the import path.
# NOTE(review): presumably this is what makes the `MONET` package importable — confirm.
sys.path.append(str(root / "src"))

In [None]:
import os
import time
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib import gridspec
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn.metrics
import torch
import tqdm
from IPython import display
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

import clip

In [None]:
# !cp /System/Library/Fonts/Supplemental ~/.local/share/fonts/
# rm -fr ~/.cache/matplotlib
from matplotlib import font_manager
from matplotlib.lines import Line2D
from cycler import cycler
from matplotlib.ticker import MultipleLocator, AutoMinorLocator

# Font discovery. Both return values are discarded; findfont("Arial") is the probe for the
# Arial face (test with "Special Elite" too).
font_manager.findSystemFonts(fontpaths=None, fontext="ttf")
font_manager.findfont("Arial")

# Typography and legend styling, applied in a single call (same keys, same order).
plt.rcParams.update(
    {
        "font.family": "sans-serif",
        "font.sans-serif": "Arial",
        "legend.fancybox": False,
        "legend.edgecolor": "1.0",
        "legend.framealpha": 0,
    }
)

# https://github.com/dsc/colorbrewer-python/blob/master/colorbrewer.py

# ColorBrewer "Set1" palette as 0-255 RGB triples, keyed by palette size (3 to 9).
# Each size-n palette is the first n colours of the 9-colour list, so the table is built by
# slicing; every colour is a fresh list, so palettes never share colour objects.
Set1 = {
    size: [
        list(rgb)
        for rgb in (
            (228,26,28), (55,126,184), (77,175,74), (152,78,163), (255,127,0),
            (255,255,51), (166,86,40), (247,129,191), (153,153,153),
        )[:size]
    ]
    for size in range(3, 10)
}

# ColorBrewer "Paired" palette as 0-255 RGB triples, keyed by palette size (3 to 12).
# Each size-n palette is the first n colours of the 12-colour list, so the table is built by
# slicing. Every colour is a list: the hand-written table had one stray tuple
# (the first colour of Paired[3]), which made that entry compare unequal to the same colour
# in every other palette.
Paired = {
    size: [
        list(rgb)
        for rgb in (
            (166,206,227), (31,120,180), (178,223,138), (51,160,44),
            (251,154,153), (227,26,28), (253,191,111), (255,127,0),
            (202,178,214), (106,61,154), (255,255,153), (177,89,40),
        )[:size]
    ]
    for size in range(3, 13)
}

# Seven-colour qualitative palette: six hex strings and one RGB tuple of 0-1 floats.
# NOTE(review): the tuple divides by 256 rather than 255 — values kept as-is; confirm intent.
color_qual_7 = [
    "#F53345",
    "#87D303",
    "#04CBCC",
    "#8650CD",
    (160 / 256, 95 / 256, 0),
    "#F5A637",
    "#DBD783",
]

# Raise pandas' displayed-row limit to 500.
pd.set_option('display.max_rows', 500)

In [None]:
import scipy.special
import tqdm.contrib.concurrent

In [None]:
from MONET.datamodules.multiplex_datamodule import MultiplexDatamodule
from MONET.utils.loader import custom_collate_per_key, dataloader_apply_func
from MONET.utils.metrics import skincon_calcualte_auc_all
from MONET.utils.static import (
    concept_to_prompt,
    fitzpatrick17k_disease_label,
    fitzpatrick17k_ninelabel,
    fitzpatrick17k_threelabel,
    skincon_cols,
)
from MONET.utils.text_processing import generate_prompt_token_from_concept

In [None]:
def wandb_to_exppath(wandb, log_path="/gscratch/cse/chanwkim/MONET_log/train/runs"):
    """Return the experiment directory under ``log_path`` that contains a given wandb run.

    Every entry of ``log_path`` is treated as an experiment. An experiment matches when its
    ``wandb`` sub-directory holds an entry whose name starts with ``"run"`` and whose last
    eight characters equal ``wandb``.

    Args:
        wandb: Eight-character wandb run id to look for.
        log_path: Directory that holds one sub-directory per experiment.

    Returns:
        ``Path(log_path) / <experiment>`` for the first matching experiment.

    Raises:
        RuntimeError: If no experiment contains the requested run.
    """
    log_path = Path(log_path)
    for experiment in os.listdir(log_path):
        wandb_dir = log_path / experiment / "wandb"
        # Skip experiments without a wandb folder (or where "wandb" is not a directory).
        if not wandb_dir.is_dir():
            continue
        # Compare against every run entry. Indexing the first one crashed with IndexError
        # when the folder had no "run*" entry, and could miss the id when it had several.
        run_ids = [name[-8:] for name in os.listdir(wandb_dir) if name.startswith("run")]
        if wandb in run_ids:
            return log_path / experiment
    raise RuntimeError("not found")


# Locate the experiment directory of wandb run "baqqmm5v" and list its checkpoint files.
# NOTE(review): within the visible cells `exppath` is only printed; the checkpoint loaded
# further down is hard-coded for run id "zt0n2xd0".
exppath = wandb_to_exppath(
    wandb="baqqmm5v", log_path="/projects/leelab2/chanwkim/dermatology_datasets/logs/train/runs"
)
print([exppath / "checkpoints" / ckpt for ckpt in os.listdir(exppath / "checkpoints/")])

In [None]:
# Per-dataset store: dataset name -> dict of artefacts, filled by later cells via .update().
variable_dict={}

In [None]:
def setup_dataloader(dataset_name):    
    if dataset_name=="clinical_fd_clean_nodup_nooverlap":
        cfg_dm = omegaconf.OmegaConf.load(root / "configs" / "datamodule" / "multiplex.yaml")
        # cfg.data_dir="/scr/chanwkim/dermatology_datasets"
        cfg_dm.data_dir = "/sdata/chanwkim/dermatology_datasets"
        cfg_dm.dataset_name_test = "clinical_fd_clean_nodup_nooverlap=all"
        cfg_dm.split_seed = 42

        dm = hydra.utils.instantiate(cfg_dm)
        dm.setup()     
        
        dataloader = dm.test_dataloader()  
        
    elif dataset_name=="derm7pt_derm_nodup":
        cfg_dm = omegaconf.OmegaConf.load(root / "configs" / "datamodule" / "multiplex.yaml")
        # cfg.data_dir="/scr/chanwkim/dermatology_datasets"
        cfg_dm.data_dir = "/sdata/chanwkim/dermatology_datasets"
        cfg_dm.dataset_name_test = "derm7pt_derm_nodup=all"
        cfg_dm.split_seed = 42

        dm = hydra.utils.instantiate(cfg_dm)
        dm.setup()     
        
        dataloader = dm.test_dataloader()       
    
    elif dataset_name=="isic_nodup_nooverlap":
        cfg_dm = omegaconf.OmegaConf.load(root / "configs" / "datamodule" / "multiplex.yaml")
        # cfg.data_dir="/scr/chanwkim/dermatology_datasets"
        cfg_dm.data_dir = "/sdata/chanwkim/dermatology_datasets"
        cfg_dm.dataset_name_test = "isic_nodup_nooverlap=all"
        cfg_dm.split_seed = 42

        dm = hydra.utils.instantiate(cfg_dm)
        dm.setup()     
        
        dataloader = dm.test_dataloader()    
        
    return {"dataloader": dataloader}

In [None]:
# Create an entry per dataset (if missing) and store its test dataloader under "dataloader".
for dataset_name in ["clinical_fd_clean_nodup_nooverlap", "derm7pt_derm_nodup",
                     "isic_nodup_nooverlap", 
                    ]:
    variable_dict.setdefault(dataset_name, {})
    variable_dict[dataset_name].update(setup_dataloader(dataset_name))

In [None]:
!gpustat

In [None]:
# Run id of the fine-tuned checkpoint to evaluate (looked up in `model_path_dir` later) and
# the device the models are placed on.
model_name = "zt0n2xd0"
model_device = "cuda:4"

In [None]:
# Load the contrastive-model config and point it at the "ViT-L/14" backbone on `model_device`.
cfg_model = omegaconf.OmegaConf.load(root / "configs" / "model" / "contrastive.yaml")
cfg_model.net.model_name_or_path = "ViT-L/14"
cfg_model.net.device = model_device
cfg_model  # last expression of the cell: displays the edited config

In [None]:
# Instantiate the model from the config, move it to the device and switch to eval mode.
model = hydra.utils.instantiate(cfg_model)
model.to(model_device)
model.eval()

In [None]:
# Checkpoint file per fine-tuned run id; the path is relative to the working directory.
model_path_dir = {
    "zt0n2xd0": "logs/train/runs/2023-01-17_20-58-15/checkpoints/last.ckpt",
}
# A model_name of "ViT-L/14" means: keep the weights the model was instantiated with.
if model_name != "ViT-L/14":
    model_path = model_path_dir[model_name]
    # NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
    loaded = torch.load(model_path, map_location=model_device)
    model.load_state_dict(loaded["state_dict"])
    print('loaded')

In [None]:
# Second model built from the same config. No checkpoint is loaded into it in this cell, so
# it serves as the "vanilla" baseline next to the fine-tuned `model`.
# model_name / model_device are re-assigned to the same values used earlier.
model_name = "zt0n2xd0"
model_device = "cuda:4"

cfg_model = omegaconf.OmegaConf.load(root / "configs" / "model" / "contrastive.yaml")
cfg_model.net.model_name_or_path = "ViT-L/14"
cfg_model.net.device = model_device
cfg_model  # no visible effect: a bare expression is only displayed when it ends the cell

model_vanilla = hydra.utils.instantiate(cfg_model)
model_vanilla.to(model_device)
model_vanilla.eval()

In [None]:
# Log directory, relative to the working directory; cached image features are read from
# log_dir / "image_features".
log_dir = Path("logs")

In [None]:
def batch_func(batch):
    """Embed one batch of images with both the fine-tuned and the vanilla model.

    Relies on the notebook globals ``model``, ``model_vanilla`` and ``model_device``.
    The ``"image"`` entry of ``batch`` is replaced in place by its copy on
    ``model_device`` before the models are called.

    Args:
        batch: Mapping with at least ``"image"`` (tensor) and ``"metadata"``.

    Returns:
        dict with ``"image_features"`` and ``"image_features_vanilla"`` (detached, on CPU)
        and ``"metadata"`` (passed through from the batch).
    """
    encoders = (("image_features", model), ("image_features_vanilla", model_vanilla))
    with torch.no_grad():
        batch["image"] = batch["image"].to(model_device)
        # Fine-tuned model first, then the vanilla one, both on the same moved batch.
        raw_features = {
            key: net.model_step_with_image(batch)["image_features"] for key, net in encoders
        }

    result = {key: feats.detach().cpu() for key, feats in raw_features.items()}
    result["metadata"] = batch["metadata"]
    return result

def setup_features(dataset_name, dataloader):
    """Collect image features (fine-tuned and vanilla) and metadata for one dataset.

    For ``"isic_nodup_nooverlap"`` the features are read from two cached files under
    ``log_dir / "image_features"`` and the indices of the two metadata objects are
    asserted to be equal. For every other dataset the features are computed by running
    ``batch_func`` over ``dataloader`` through ``dataloader_apply_func``.

    Args:
        dataset_name: Dataset identifier; only the ISIC name triggers the cached path.
        dataloader: Dataloader to embed (unused on the cached path).

    Returns:
        dict with ``"image_features"``, ``"image_features_vanilla"`` (both moved to CPU)
        and ``"metadata_all"``.
    """
    if dataset_name != "isic_nodup_nooverlap":
        applied = dataloader_apply_func(
            dataloader=dataloader,
            func=batch_func,
            collate_fn=custom_collate_per_key,
        )
        image_features = applied["image_features"].cpu()
        image_features_vanilla = applied["image_features_vanilla"].cpu()
        metadata_all = applied["metadata"]
    else:
        feature_dir = log_dir / "image_features"

        # NOTE(review): torch.load unpickles these files — only use trusted caches.
        cached = torch.load(feature_dir / "isic_nodup_nooverlap.pt", map_location="cpu")
        image_features = cached["image_features"].cpu()
        metadata_all = cached["metadata_all"]

        cached_vanilla = torch.load(
            feature_dir / "isic_nodup_nooverlap_vanilla.pt", map_location="cpu"
        )
        image_features_vanilla = cached_vanilla["image_features_vanilla"].cpu()
        metadata_all_vanilla = cached_vanilla["metadata_all"]

        # Both caches must describe the same samples in the same order.
        assert np.all(metadata_all.index == metadata_all_vanilla.index)

    return {
        "image_features": image_features,
        "image_features_vanilla": image_features_vanilla,
        "metadata_all": metadata_all,
    }

In [None]:
# Add image features (fine-tuned and vanilla) and metadata to every dataset's entry.
for dataset_name in ["clinical_fd_clean_nodup_nooverlap", "derm7pt_derm_nodup",
                     "isic_nodup_nooverlap", 
                    ]:
    print(dataset_name)
    variable_dict[dataset_name].update(setup_features(dataset_name, variable_dict[dataset_name]["dataloader"])) 

In [None]:
# ISIC diagnosis name -> "benign" / "malignant" / "indeterminate".
# Trailing "#" / "#??" marks are the original author's annotations; their meaning is not
# recorded in this notebook.
isic_diagnosis_malignant_mapping = {
    'AIMP': 'indeterminate',
    'acrochordon': 'benign',
    'actinic keratosis': 'benign',  #
    'angiofibroma or fibrous papule': 'benign',
    'angiokeratoma': 'benign',
    'angioma': 'benign',
    'atypical melanocytic proliferation': 'indeterminate',
    'atypical spitz tumor': 'indeterminate',  #
    'basal cell carcinoma': 'malignant',  #
    'cafe-au-lait macule': 'benign',
    'clear cell acanthoma': 'benign',
    'dermatofibroma': 'benign',  #
    'lentigo NOS': 'benign',
    'lentigo simplex': 'benign',
    'lichenoid keratosis': 'benign',
    'melanoma': 'malignant',
    'melanoma metastasis': 'malignant',
    'neurofibroma': 'benign',
    'nevus': 'benign',
    'other': 'indeterminate',
    'pigmented benign keratosis': 'benign',  #??
    'scar': 'benign',
    'seborrheic keratosis': 'benign',
    'solar lentigo': 'benign',
    'squamous cell carcinoma': 'malignant',
    'vascular lesion': 'benign',  #
    'verruca': 'benign',
}


def isic_map_diagnosis_malignant(diagnosis, benign_malignant):
    """Resolve the benign/malignant label of one ISIC sample.

    Args:
        diagnosis: Diagnosis name, or a missing value (NaN / None).
        benign_malignant: Label already present in the metadata. When it is a string it
            wins over the diagnosis-based mapping.

    Returns:
        ``benign_malignant`` if it is a string; otherwise the mapped label of
        ``diagnosis``; ``"indeterminate"`` when ``diagnosis`` is missing.

    Raises:
        RuntimeError: If ``diagnosis`` is present but not in the mapping.
    """
    if isinstance(benign_malignant, str):
        return benign_malignant
    if diagnosis in isic_diagnosis_malignant_mapping:
        return isic_diagnosis_malignant_mapping[diagnosis]
    # Only test float-like values with np.isnan: calling it on a string raises TypeError,
    # which used to mask the intended RuntimeError for unmapped diagnosis names.
    if diagnosis is None or (
        isinstance(diagnosis, (float, np.floating)) and np.isnan(diagnosis)
    ):
        return "indeterminate"
    raise RuntimeError(f"unmapped ISIC diagnosis: {diagnosis!r}")

In [None]:
# derm7pt diagnosis name -> "benign" / "malignant" / "unknown".
# Trailing "#" marks are the original author's annotations; their meaning is not recorded
# in this notebook.
# NOTE(review): 'miscellaneous' maps to 'unknown', whereas a missing diagnosis yields
# "indeterminate" in the function below — kept as-is; confirm the two labels are intended.
derm7pt_diagnosis_malignant_mapping = {
    'basal cell carcinoma': 'malignant',
    'blue nevus': 'benign',  #
    'clark nevus': 'benign',  #
    'combined nevus': 'benign',  #
    'congenital nevus': 'benign',  #
    'dermal nevus': 'benign',
    'dermatofibroma': 'benign',
    'lentigo': 'benign',
    'melanoma (in situ)': 'malignant',
    'melanoma (less than 0.76 mm)': 'malignant',
    'melanoma (0.76 to 1.5 mm)': 'malignant',
    'melanoma (more than 1.5 mm)': 'malignant',
    'melanoma metastasis': 'malignant',
    'melanosis': 'benign',  #
    'miscellaneous': 'unknown',  #
    'recurrent nevus': 'benign',  #
    'reed or spitz nevus': 'benign',  #
    'seborrheic keratosis': 'benign',
    'vascular lesion': 'benign',
    'melanoma': 'malignant',
}


def derm7pt_map_diagnosis_malignant(diagnosis):
    """Map one derm7pt diagnosis name to its benign/malignant label.

    Args:
        diagnosis: Diagnosis name, or a missing value (NaN / None).

    Returns:
        The mapped label, or ``"indeterminate"`` when ``diagnosis`` is missing.

    Raises:
        RuntimeError: If ``diagnosis`` is present but not in the mapping.
    """
    if diagnosis in derm7pt_diagnosis_malignant_mapping:
        return derm7pt_diagnosis_malignant_mapping[diagnosis]
    # Only test float-like values with np.isnan: calling it on a string raises TypeError,
    # which used to mask the intended RuntimeError for unmapped diagnosis names.
    if diagnosis is None or (
        isinstance(diagnosis, (float, np.floating)) and np.isnan(diagnosis)
    ):
        return "indeterminate"
    raise RuntimeError(f"unmapped derm7pt diagnosis: {diagnosis!r}")

In [None]:
# DDI diagnosis name -> "melanoma" or "melanoma look-alike".
# Every listed diagnosis starts out as a look-alike; the melanoma variants are then
# relabelled. Re-assigning existing keys keeps the original key order.
ddi_map = dict.fromkeys(
    [
        "acral-melanotic-macule",
        "atypical-spindle-cell-nevus-of-reed",
        "benign-keratosis",
        "blue-nevus",
        "dermatofibroma",
        "dysplastic-nevus",
        "epidermal-nevus",
        "hyperpigmentation",
        "keloid",
        "inverted-follicular-keratosis",
        "melanocytic-nevi",
        "melanoma",
        "melanoma-acral-lentiginous",
        "melanoma-in-situ",
        "nevus-lipomatosus-superficialis",
        "nodular-melanoma-(nm)",
        "pigmented-spindle-cell-nevus-of-reed",
        "seborrheic-keratosis",
        "seborrheic-keratosis-irritated",
        "solar-lentigo",
    ],
    "melanoma look-alike",
)
ddi_map.update(
    dict.fromkeys(
        [
            "melanoma",
            "melanoma-acral-lentiginous",
            "melanoma-in-situ",
            "nodular-melanoma-(nm)",
        ],
        "melanoma",
    )
)

In [None]:
def set_config(dataset_name, metadata_all):
    """Pick the valid-sample mask and the concept names to evaluate for one dataset.

    The dataset family is chosen by substring match on ``dataset_name``, tested in this
    order: ``"clinical_fd_clean"``, ``"isic"``, ``"derm7pt"``.

    * clinical_fd_clean: a row is valid unless its ``"skincon_Do not consider this image"``
      value equals 1; concepts are ``skincon_cols``, the two skin-tone concepts, and the
      ``fitzdisease_*`` and ``ddidisease_*`` names.
    * isic: every row is valid; concepts are the ``isicdisease_*`` names.
    * derm7pt: a row is valid when its ``"diagnosis"`` is not null; concepts are the
      ``derm7ptconcept_*`` and ``derm7ptdisease_*`` names.

    The size of the mask is printed as a side effect.

    Args:
        dataset_name: Name containing one of the family substrings above.
        metadata_all: Per-sample metadata. The clinical and derm7pt branches index it by
            column name (presumably a pandas DataFrame — TODO confirm); the isic branch
            only needs ``len()``.

    Returns:
        dict with ``"valid_idx"`` (boolean array), ``"metadata_all"`` (the input, passed
        through unchanged) and ``"concept_list"`` (list of concept names).

    Raises:
        ValueError: If ``dataset_name`` matches none of the families. (Previously this
            surfaced as an UnboundLocalError on ``valid_idx``.)
    """
    if "clinical_fd_clean" in dataset_name:
        valid_idx = (metadata_all["skincon_Do not consider this image"] != 1).values

        # Disease names are kept verbatim, including their original spellings, because
        # they are used as identifiers.
        fitz_disease_names = [
            'dermatofibroma', 'granuloma annulare', 'necrobiosis lipoidica',
            'hidradenitis', 'melanoma', 'actinic keratosis', 'scleroderma',
            'drug eruption', 'neutrophilic dermatoses', 'nevocytic nevus',
            'superficial spreading melanoma ssm', 'milia', 'granuloma pyogenic',
            'neurotic excoriations', 'epidermal nevus', 'erythema annulare centrifigum',
            'pustular psoriasis', 'lyme disease', 'striae', 'folliculitis',
            'basal cell carcinoma', 'calcinosis cutis', 'allergic contact dermatitis',
            'acne vulgaris', 'eczema', 'lichen planus', 'fixed eruptions',
            'dariers disease', 'psoriasis', 'pilar cyst', 'neurodermatitis',
            'lentigo maligna', 'sarcoidosis', 'squamous cell carcinoma',
            'keloid', 'halo nevus', 'basal cell carcinoma morpheiform', 'porokeratosis of mibelli',
            'nematode infection', 'solid cystic basal cell carcinoma', 'scabies', 'becker nevus',
            'stasis edema', 'pyogenic granuloma', 'porphyria', 'tick bite', 'malignant melanoma',
            'factitial dermatitis', 'papilomatosis confluentes and reticulate', 'sun damaged skin',
            'pityriasis rosea', 'lupus subacute', 'lupus erythematosus', 'photodermatoses',
            'pediculosis lids', 'porokeratosis actinic', 'mycosis fungoides', 'acanthosis nigricans',
            'urticaria', 'lymphangioma', 'keratosis pilaris', 'dyshidrotic eczema',
            'juvenile xanthogranuloma', 'nevus sebaceous of jadassohn', 'acne', 'xanthomas',
            'neurofibromatosis', 'naevus comedonicus', 'hailey hailey disease',
            'incontinentia pigmenti', 'vitiligo', 'mucinosis', 'erythema multiforme',
            'xeroderma pigmentosum', 'tuberous sclerosis', 'kaposi sarcoma',
            'telangiectases', 'seborrheic keratosis', 'seborrheic dermatitis',
            'urticaria pigmentosa', 'lichen simplex', 'pityriasis rubra pilaris',
            'prurigo nodularis', 'dermatomyositis', 'acquired autoimmune bullous diseaseherpes gestationis',
            'stevens johnson syndrome', 'congenital nevus', 'scleromyxedema',
            'aplasia cutis', 'rosacea', 'ichthyosis vulgaris', 'disseminated actinic porokeratosis',
            'syringoma', 'drug induced pigmentary changes', 'tungiasis', 'fordyce spots',
            'pilomatricoma', 'erythema elevatum diutinum', 'lichen amyloidosis',
            'port wine stain', 'langerhans cell histiocytosis', 'pityriasis lichenoides chronica',
            'epidermolysis bullosa', 'ehlers danlos syndrome', 'mucous cyst',
            'perioral dermatitis', 'behcets disease', 'erythema nodosum', 'livedo reticularis',
            'cheilitis', 'myiasis', 'rhinophyma', 'acrodermatitis enteropathica',
        ]
        ddi_disease_names = [
            'melanoma-in-situ', 'mycosis-fungoides', 'squamous-cell-carcinoma-in-situ',
            'basal-cell-carcinoma', 'squamous-cell-carcinoma', 'melanoma-acral-lentiginous',
            'basal-cell-carcinoma-superficial', 'squamous-cell-carcinoma-keratoacanthoma',
            'subcutaneous-t-cell-lymphoma', 'melanocytic-nevi', 'seborrheic-keratosis-irritated',
            'focal-acral-hyperkeratosis', 'hyperpigmentation', 'lipoma', 'foreign-body-granuloma',
            'blue-nevus', 'verruca-vulgaris', 'acrochordon', 'wart',
            'abrasions-ulcerations-and-physical-injuries', 'basal-cell-carcinoma-nodular',
            'epidermal-cyst', 'acquired-digital-fibrokeratoma', 'epidermal-nevus',
            'seborrheic-keratosis', 'trichilemmoma', 'pyogenic-granuloma', 'neurofibroma',
            'syringocystadenoma-papilliferum', 'nevus-lipomatosus-superficialis', 'benign-keratosis',
            'inverted-follicular-keratosis', 'onychomycosis', 'dermatofibroma', 'trichofolliculoma',
            'lymphocytic-infiltrations', 'prurigo-nodularis', 'kaposi-sarcoma', 'scar',
            'eccrine-poroma', 'angioleiomyoma', 'keloid', 'hematoma', 'metastatic-carcinoma',
            'melanoma', 'angioma', 'folliculitis', 'atypical-spindle-cell-nevus-of-reed',
            'xanthogranuloma', 'eczema-spongiotic-dermatitis', 'arteriovenous-hemangioma',
            'acne-cystic', 'verruciform-xanthoma', 'molluscum-contagiosum', 'condyloma-accuminatum',
            'morphea', 'neuroma', 'dysplastic-nevus', 'nodular-melanoma-(nm)', 'actinic-keratosis',
            'pigmented-spindle-cell-nevus-of-reed', 'dermatomyositis', 'glomangioma',
            'cellular-neurothekeoma', 'fibrous-papule', 'graft-vs-host-disease', 'lichenoid-keratosis',
            'reactive-lymphoid-hyperplasia', 'coccidioidomycosis', 'leukemia-cutis',
            'sebaceous-carcinoma', 'chondroid-syringoma', 'tinea-pedis', 'solar-lentigo',
            'clear-cell-acanthoma', 'abscess', 'blastic-plasmacytoid-dendritic-cell-neoplasm', 'acral-melanotic-macule',
        ]

        # `+` builds new lists, so the imported skincon_cols is never mutated.
        concept_list = skincon_cols
        concept_list = concept_list + ["skintone_light", "skintone_dark"]
        concept_list = concept_list + [
            f"fitzdisease_{disease_name}" for disease_name in fitz_disease_names
        ]
        concept_list = concept_list + [
            f"ddidisease_{disease_name}" for disease_name in ddi_disease_names
        ]

    elif "isic" in dataset_name:
        valid_idx = np.ones(len(metadata_all)) == 1

        isic_disease_names = [
            'seborrheic keratosis', 'nevus', 'squamous cell carcinoma',
            'melanoma', 'lichenoid keratosis', 'lentigo',
            'actinic keratosis', 'basal cell carcinoma', 'dermatofibroma',
            'atypical melanocytic proliferation', 'verruca',
            'clear cell acanthoma', 'angiofibroma or fibrous papule', 'scar',
            'angioma', 'atypical spitz tumor', 'solar lentigo', 'AIMP',
            'neurofibroma', 'lentigo simplex', 'acrochordon',
            'angiokeratoma', 'vascular lesion', 'cafe-au-lait macule',
            'pigmented benign keratosis',
        ]
        concept_list = [f"isicdisease_{disease_name}" for disease_name in isic_disease_names]

    elif "derm7pt" in dataset_name:
        valid_idx = (~metadata_all["diagnosis"].isnull()).values

        derm7pt_concept_names = [
            "pigment network", "typical pigment network", "atypical pigment network",
            "regression structure",
            "pigmentation", "regular pigmentation", "irregular pigmentation",
            "blue whitish veil",
            "vascular structures", "typical vascular structures", "atypical vascular structures",
            "streaks", "regular streaks", "irregular streaks",
            "dots and globules", "regular dots and globules", "irregular dots and globules",
        ]
        derm7pt_disease_names = [
            'basal cell carcinoma', 'blue nevus', 'clark nevus',
            'combined nevus', 'congenital nevus', 'dermal nevus',
            'dermatofibroma', 'lentigo', 'melanoma', 'melanosis',
            'recurrent nevus', 'reed or spitz nevus',
            'seborrheic keratosis', 'vascular lesion',
        ]
        concept_list = [
            f"derm7ptconcept_{derm7ptconcept}" for derm7ptconcept in derm7pt_concept_names
        ]
        concept_list = concept_list + [
            f"derm7ptdisease_{disease_name}" for disease_name in derm7pt_disease_names
        ]

    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError(
            f"unsupported dataset_name {dataset_name!r}: expected a name containing "
            "'clinical_fd_clean', 'isic' or 'derm7pt'"
        )

    print("valid_idx", valid_idx.sum(), "null", np.isnan(valid_idx).sum())

    return {
        "valid_idx": valid_idx,
        "metadata_all": metadata_all,
        "concept_list": concept_list,
    }

In [None]:
# Add the valid-sample mask and the concept list to every dataset's entry.
for dataset_name in ["clinical_fd_clean_nodup_nooverlap", "derm7pt_derm_nodup",
                     "isic_nodup_nooverlap", 
                    ]:
    print(dataset_name)
    variable_dict[dataset_name].update(set_config(dataset_name, variable_dict[dataset_name]["metadata_all"]))
    print()

In [None]:
def normalize_embedding(dataset_name, image_features):
    """Scale every row of ``image_features`` to unit length.

    Args:
        dataset_name: Accepted for call-site symmetry; not used by the computation.
        image_features: 2-D tensor whose rows are divided by their own norm
            (``norm(dim=1, keepdim=True)``).

    Returns:
        ``{"image_features_norm": <row-normalised tensor>}``.
    """
    row_norms = image_features.norm(dim=1, keepdim=True)
    normalized = image_features / row_norms
    return {"image_features_norm": normalized}

In [None]:
# Store row-normalised copies of both feature matrices for every dataset, under
# "image_features_vanilla_norm" and "image_features_norm".
for dataset_name in ["clinical_fd_clean_nodup_nooverlap", "derm7pt_derm_nodup","isic_nodup_nooverlap", 
                    ]:
    variable_dict[dataset_name].update(
        {"image_features_vanilla_norm":normalize_embedding(dataset_name, 
                        variable_dict[dataset_name]["image_features_vanilla"])["image_features_norm"]}
    )           
    variable_dict[dataset_name].update(
        {"image_features_norm":normalize_embedding(dataset_name, 
                        variable_dict[dataset_name]["image_features"])["image_features_norm"]}
    )  
#     variable_dict[dataset_name].update(
#         {"image_features_vanilla_norm":normalize_embedding(dataset_name, 
#                         variable_dict[dataset_name]["image_features_vanilla"])["image_features_norm"]}
#     )    

In [None]:
# Prompt templates; each template is paired (by position) with the concept-free
# reference prompt that acts as its baseline.
_CLINICAL_TEMPLATES = (
    "This is skin image of {}",
    "This is dermatology image of {}",
    "This is image of {}",
)
_CLINICAL_REFS = ("This is skin image", "This is dermatology image", "This is image")
_DERMOSCOPY_TEMPLATES = ("This is dermatoscopy of {}", "This is dermoscopy of {}")
_DERMOSCOPY_REFS = ("This is dermatoscopy", "This is dermoscopy")

# Boilerplate phrases stripped (in this order) from the engineered SkinCon prompts.
_SKINCON_BOILERPLATE = ("This is ", "This photo is ", "This lesion is ", "skin has become ")

_SKINTONE_TERMS = {
    "skintone_light": ["white skin", "light skin tone"],
    "skintone_dark": ["brown skin", "black skin", "dark skin tone"],
}

_STREAK_TERMS = [
    "streaks",
    "starburst",
    "linear patterns",
    "radially oriented linear projections",
    "regular, pigmented linear extensions",
    "irregular, pigmented linear extensions",
]

# Search terms for the "derm7ptconcept_*" concepts (keyed by the name after the prefix).
_DERM7PT_CONCEPT_TERMS = {
    "pigment network": ["pigment network", "intersecting brown lines"],
    "typical pigment network": ["typical pigment network", "regularly meshed pigment network"],
    "atypical pigment network": ["atypical pigment network", "irregularly meshed pigment network"],
    "regression structure": ["regression structure"],
    "pigmentation": ["pigmented", "pigmented lesion", "colored lesion"],
    "regular pigmentation": ["regular pigmentation", "uniform and consistent coloration"],
    "irregular pigmentation": ["irregular pigmentation"],
    "blue whitish veil": ["blue whitish veil", "blue white veil"],
    "vascular structures": ["vascular structures", "Hairpin vessels", "Comma vessels", "dotted vessels", "arborizing vessels"],
    "typical vascular structures": ["typical vascular structures"],
    "atypical vascular structures": ["atypical vascular structures"],
    "streaks": _STREAK_TERMS,
    "regular streaks": ["regular streaks", "uniformly spaced linear patterns"],
    "irregular streaks": ["irregular streaks"],
    "dots and globules": ["black dots and globules", "brown dots and globules", "scattered globules"],
    "regular dots and globules": ["regular dots and globules"],
    "irregular dots and globules": ["irregular dots and globules"],
}

# Search terms for the "isicconcept_*" concepts (keyed by the name after the prefix).
_ISIC_CONCEPT_TERMS = {
    "pigment_network": ["pigment network"],
    "negative_network": ["negative network"],
    # The previous code assigned "milia like cyst" and immediately overwrote it with
    # this term; the effective value is kept.
    "milia_like_cyst": ["seborrheic keratosis"],
    "streaks": _STREAK_TERMS,
    "globules": ["globules"],
}

# Disease concepts that are queried with more than their literal name.
_DISEASE_ALIASES = {
    "isicdisease_AIMP": ["AIMP", "Atypical intraepidermal melanocytic proliferation"],
}


def _skincon_terms(skincon_name):
    """Return the distinct lower-cased terms behind the engineered prompts of a SkinCon concept."""
    prompt_dict, _ = concept_to_prompt(skincon_name)
    terms = set()
    for prompt_kind, prompts in prompt_dict.items():
        if prompt_kind == "original":
            continue
        for prompt in prompts:
            for phrase in _SKINCON_BOILERPLATE:
                prompt = prompt.replace(phrase, "")
            terms.add(prompt.lower())
    # Sorted so the prompt order no longer depends on string-hash randomisation.
    return sorted(terms)


def _per_template_prompts(templates, refs, terms):
    """Build one prompt group per template, each paired with that template's reference prompt."""
    targets = [[template.format(term) for term in terms] for template in templates]
    return targets, [[ref] for ref in refs]


def _pooled_prompts(templates, refs, terms):
    """Build a single prompt group pooling every template/term pair, with one pooled reference group."""
    targets = [[template.format(term) for template in templates for term in terms]]
    return targets, [list(refs)]


def _clinical_prompts(concept_name):
    """Return (prompt_target, prompt_ref) for a concept of a clinical-photo dataset."""
    if concept_name in _SKINTONE_TERMS:
        terms = _SKINTONE_TERMS[concept_name]
    elif concept_name.startswith("cbm_"):
        terms = [concept_name[len("cbm_"):]]
    elif concept_name.startswith("skincon_"):
        terms = _skincon_terms(concept_name[len("skincon_"):])
    elif concept_name.startswith("ddidisease_"):
        terms = [concept_name[len("ddidisease_"):].replace('-', ' ')]
    elif concept_name.startswith("fitzdisease_"):
        terms = [concept_name[len("fitzdisease_"):]]
    else:
        # Previously this fell through and failed later with an UnboundLocalError.
        raise ValueError(f"Unsupported concept for clinical dataset: {concept_name!r}")
    return _per_template_prompts(_CLINICAL_TEMPLATES, _CLINICAL_REFS, terms)


def _dermoscopy_prompts(concept_name, disease_prefix):
    """Return (prompt_target, prompt_ref) for a concept of a dermoscopy dataset.

    ``disease_prefix`` is the dataset-specific prefix marking disease concepts
    ("isicdisease_" or "derm7ptdisease_").
    """
    if concept_name.startswith("skincon_"):
        terms = _skincon_terms(concept_name[len("skincon_"):])
        return _pooled_prompts(_DERMOSCOPY_TEMPLATES, _DERMOSCOPY_REFS, terms)

    for prefix, term_table in (("derm7ptconcept_", _DERM7PT_CONCEPT_TERMS),
                               ("isicconcept_", _ISIC_CONCEPT_TERMS)):
        if concept_name.startswith(prefix):
            key = concept_name[len(prefix):]
            if key not in term_table:
                raise ValueError(key)
            return _pooled_prompts(_DERMOSCOPY_TEMPLATES, _DERMOSCOPY_REFS, term_table[key])

    if concept_name.startswith("cbm_"):
        # Fix: the dermoscopy targets used to be paired with the three clinical
        # reference prompts, so target and reference had different group counts
        # (2 vs 3). Each template is now paired with its own reference.
        terms = [concept_name[len("cbm_"):]]
    elif concept_name.startswith(disease_prefix):
        terms = _DISEASE_ALIASES.get(concept_name, [concept_name[len(disease_prefix):]])
    elif concept_name == "dermoscope border":
        return [["This is dermatoscopy of dermoscopy"]], [["This is dermatoscopy"]]
    else:
        # Covers "gel" as well as any other free-form concept name. Groups are
        # nested lists so tokenisation never depends on passing a bare string.
        terms = [concept_name]
    return _per_template_prompts(_DERMOSCOPY_TEMPLATES, _DERMOSCOPY_REFS, terms)


def _build_concept_prompts(dataset_name, concept_name):
    """Pick the prompt builder from substrings of ``dataset_name`` (checked in the original order)."""
    if "clinical_fd_clean" in dataset_name:
        return _clinical_prompts(concept_name)
    if "isic" in dataset_name:
        return _dermoscopy_prompts(concept_name, disease_prefix="isicdisease_")
    if "derm7pt_derm" in dataset_name:
        return _dermoscopy_prompts(concept_name, disease_prefix="derm7ptdisease_")
    raise ValueError(f"Unsupported dataset for concept prompts: {dataset_name!r}")


def _embed_prompt_groups(prompt_groups, clip_model):
    """Encode every prompt group with ``clip_model`` and L2-normalise along dim 2."""
    with torch.no_grad():
        embedding = torch.stack([
            clip_model.model_step_with_text(
                {"text": clip.tokenize(prompt_group, truncate=True).to(model_device)}
            )["text_features"].detach().cpu()
            for prompt_group in prompt_groups
        ])
    return embedding / embedding.norm(dim=2, keepdim=True)


def get_concept_embedding(dataset_name, concept_list, clip_model):
    """Embed the target and reference text prompts of every concept.

    For each concept a list of prompt groups ("target") and a list of
    concept-free baseline groups ("reference") is built from the dataset type
    and the concept-name prefix, encoded with ``clip_model`` and normalised.
    The prompts of each concept are printed as a log.

    Args:
        dataset_name: Name containing "clinical_fd_clean", "isic" or
            "derm7pt_derm"; selects the prompt templates.
        concept_list: Iterable of concept names.
        clip_model: Object exposing ``model_step_with_text({"text": tokens})``
            that returns a mapping with a ``"text_features"`` tensor. Tokens
            are moved to the module-level ``model_device`` before the call.

    Returns:
        ``{"prompt_info": {concept_name: {"prompt_ref_embedding_norm": tensor,
        "prompt_target_embedding_norm": tensor}}}``. Each tensor is stacked
        over prompt groups (presumably shaped groups x prompts x features).

    Raises:
        ValueError: If the dataset or concept is not supported (this used to
            surface as an ``UnboundLocalError`` on ``prompt_target``).
    """
    prompt_info = {}

    for concept_name in concept_list:
        prompt_target, prompt_ref = _build_concept_prompts(dataset_name, concept_name)

        prompt_target_embedding_norm = _embed_prompt_groups(prompt_target, clip_model)
        prompt_ref_embedding_norm = _embed_prompt_groups(prompt_ref, clip_model)

        prompt_info[concept_name] = {
            "prompt_ref_embedding_norm": prompt_ref_embedding_norm,
            "prompt_target_embedding_norm": prompt_target_embedding_norm,
        }
        print(dataset_name, concept_name, "\n" , prompt_target, prompt_ref)
        print('-----------')

    return {"prompt_info": prompt_info}

In [None]:
# Embed the concept prompts of every dataset with both text encoders: results from
# model_vanilla go to "prompt_info_vanilla", results from model go to "prompt_info".
for dataset_name in ("clinical_fd_clean_nodup_nooverlap", "derm7pt_derm_nodup", "isic_nodup_nooverlap"):
    for prompt_info_key, prompt_encoder in (("prompt_info_vanilla", model_vanilla), ("prompt_info", model)):
        concept_embeddings = get_concept_embedding(
            dataset_name,
            concept_list=variable_dict[dataset_name]["concept_list"],
            clip_model=prompt_encoder,
        )
        variable_dict[dataset_name].update({prompt_info_key: concept_embeddings["prompt_info"]})
        

In [None]:
def calculate_similaity_score(image_features_norm, 
                              prompt_target_embedding_norm,
                              prompt_ref_embedding_norm,
                              temp=1,
                              normalize=True):
    """Score every image against a concept's prompt embeddings.

    Similarities are matrix products between the prompt embeddings and the
    transposed image features, averaged over dim 1 of the prompt tensor
    (the prompts inside a group, for 3-D group x prompt x feature inputs).

    Args:
        image_features_norm: 2-D tensor of image features (one row per image).
        prompt_target_embedding_norm: Tensor of target prompt embeddings.
        prompt_ref_embedding_norm: Tensor of reference prompt embeddings; only
            used when ``normalize`` is true.
        temp: Temperature dividing both similarities before the softmax.
        normalize: If true, return the softmax weight of the target versus the
            reference similarity, averaged over groups. If false, return the
            target similarity averaged over groups.

    Returns:
        One score per image: a NumPy array when ``normalize`` is true, a torch
        tensor otherwise (return types kept as they were for existing callers).

    Raises:
        ValueError: If ``normalize`` is true and target and reference do not
            reduce to the same shape, so they cannot be compared pairwise.
    """
    image_features_t = image_features_norm.T.float()
    target_similarity_mean = (prompt_target_embedding_norm.float() @ image_features_t).mean(dim=1)

    if not normalize:
        # The reference prompts play no role in the raw score, so they are not evaluated.
        return target_similarity_mean.mean(dim=0)

    # Imported here so the function does not rely on another cell having
    # imported the scipy.special submodule.
    import scipy.special

    ref_similarity_mean = (prompt_ref_embedding_norm.float() @ image_features_t).mean(dim=1)
    if target_similarity_mean.shape != ref_similarity_mean.shape:
        raise ValueError(
            "target and reference prompts must have matching group layouts, got "
            f"{tuple(target_similarity_mean.shape)} and {tuple(ref_similarity_mean.shape)}"
        )

    # Softmax over the (target, reference) pair; index 0 keeps the target weight.
    target_weight = scipy.special.softmax(
        [target_similarity_mean.numpy() / temp, ref_similarity_mean.numpy() / temp], axis=0
    )[0, :]
    return target_weight.mean(axis=0)

In [None]:
from sklearn.model_selection import train_test_split
from MONET.datamodules.components.base_dataset import BaseDataset

def get_training_data(dataloader, metadata_all, valid_idx, y_pos, subset_idx_train, subset_idx_test, n_px=None):
    """Create train / validation / test DataLoaders for a binary ``label`` column.

    The metadata is copied from ``dataloader.dataset.metadata_all`` (the
    ``metadata_all`` argument itself is not read), ``y_pos`` is attached as the
    integer column ``"label"``, and rows are selected with
    ``valid_idx & subset_idx_train`` / ``valid_idx & subset_idx_test``. The
    training rows are further split 80/20 into train and validation with a
    fixed ``random_state`` of 42. Split sizes are printed.

    Args:
        dataloader: Existing loader whose dataset supplies metadata, image
            paths, default image size and normalisation statistics.
        metadata_all: Unused; kept for call-site compatibility.
        valid_idx: Boolean mask of usable rows.
        y_pos: Label values; converted with ``astype(int)``.
        subset_idx_train: Boolean mask of rows eligible for training.
        subset_idx_test: Boolean mask of rows used for testing.
        n_px: Image size handed to ``BaseDataset``; defaults to
            ``dataloader.dataset.n_px``.

    Returns:
        Tuple ``(train_dataloader, val_dataloader, test_dataloader)``; only the
        training loader shuffles.
    """
    source_dataset = dataloader.dataset

    labelled = source_dataset.metadata_all.copy()
    labelled["label"] = y_pos.astype(int)
    train_pool = labelled[valid_idx & subset_idx_train]

    train_idx, val_idx = train_test_split(train_pool.index, test_size=0.2, random_state=42)
    train_split = train_pool.loc[train_idx]
    val_split = train_pool.loc[val_idx]

    print("All:", len(labelled))
    print("train:", len(train_pool))
    print("train_train:", len(train_split))
    print("train_val:", len(val_split))
    test_split = labelled[valid_idx & subset_idx_test]
    print("test:", len(test_split))

    if n_px is None:
        n_px = source_dataset.n_px

    # NOTE(review): assumes the transform at index 1 carries the normalisation mean/std.
    normalizer = source_dataset.transforms_aftertensor.transforms[1]

    def _dataset_for(split_metadata):
        # Same image source and preprocessing for every split; only the rows differ.
        return BaseDataset(
            image_path_or_binary_dict=source_dataset.image_path_dict,
            n_px=n_px,
            norm_mean=normalizer.mean,
            norm_std=normalizer.std,
            augment=False,
            metadata_all=split_metadata,
            integrity_level="weak",
            return_label=["label"],
        )

    data_train = _dataset_for(train_split)
    data_val = _dataset_for(val_split)
    data_test = _dataset_for(test_split)

    from MONET.utils.loader import custom_collate

    def _loader_for(split_dataset, shuffle):
        return torch.utils.data.DataLoader(
            dataset=split_dataset,
            batch_size=64,
            num_workers=4,
            shuffle=shuffle,
            collate_fn=custom_collate,
        )

    train_dataloader = _loader_for(data_train, True)
    val_dataloader = _loader_for(data_val, False)
    test_dataloader = _loader_for(data_test, False)

    return train_dataloader, val_dataloader, test_dataloader

In [None]:
import torchvision
from torch import nn
from torch.nn import functional as F
from torchmetrics import AUROC

class Classifier(nn.Module):
    """ResNet-50 feature extractor followed by a linear prediction head.

    The backbone is created with the ``ResNet50_Weights.IMAGENET1K_V1``
    weights, all of its parameters are left trainable, and its original ``fc``
    layer is swapped for an identity so that it outputs the features that used
    to feed ``fc``. ``head`` maps those features to ``output_dim`` outputs.
    """

    def __init__(self, output_dim):
        super().__init__()
        resnet = torchvision.models.resnet50(weights="ResNet50_Weights.IMAGENET1K_V1")

        # Fine-tune the whole network: every backbone weight stays trainable.
        for weight in resnet.parameters():
            weight.requires_grad = True

        feature_dim = resnet.fc.in_features
        resnet.fc = nn.Identity()

        self.backbone = resnet
        self.head = nn.Linear(feature_dim, output_dim)

    def forward(self, x):
        """Return the head's output for the backbone features of ``x``."""
        return self.head(self.backbone(x))


class EarlyStopper:
    """Signal when the validation loss has stopped improving.

    A loss below the best value seen so far resets the counter. A loss above
    ``best + min_delta`` increments it. Losses in between leave the counter
    untouched. Stopping is signalled once the counter reaches ``patience``.
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience, self.min_delta = patience, min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        """Record ``validation_loss`` and return True when training should stop."""
        if validation_loss < self.min_validation_loss:
            # New best: remember it and start counting afresh.
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        if validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True

        return False
        
def train_classifier(train_dataloader, val_dataloader, test_dataloader, n_epochs=20, classifier_type="resnet", verbose=True, device="cuda:5"):
    """Fine-tune a binary image classifier and evaluate it on the test split.

    Every batch yielded by the dataloaders is read as a dict with "image" and
    "label" entries; test batches must additionally carry "metadata". A sample
    is treated as positive when its label equals 1.

    Args:
        train_dataloader: batches used for optimisation.
        val_dataloader: batches used for LR scheduling and early stopping.
        test_dataloader: batches used for the final evaluation.
        n_epochs: maximum number of training epochs.
        classifier_type: "resnet" or "inception".
        verbose: if True, show progress bars and print per-epoch / test metrics.
        device: torch device string the model and batches are moved to.

    Returns:
        Tuple ``(test_auroc, classifier, logits_list, label_list, metadata_list)``
        where the three lists hold one entry per test batch (logits and labels
        as numpy arrays, metadata as provided by the batch).

    Raises:
        ValueError: if ``classifier_type`` is not one of the supported names.
    """
    if classifier_type=="resnet":
        classifier = Classifier(output_dim=1)
    elif classifier_type=="inception":
        classifier = Inception(output_dim=1)
    else:
        # Previously an unknown type surfaced later as a NameError on `classifier`.
        raise ValueError(f"Unknown classifier_type: {classifier_type!r}")
    classifier_device = device
    classifier.to(classifier_device)
    optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=2, verbose=True)
    early_stopper = EarlyStopper(patience=5, min_delta=0)

    train_auroc = AUROC(task="binary")
    val_auroc = AUROC(task="binary")
    for epoch in range(n_epochs):
        train_loss = 0
        classifier.train()
        if verbose:
            pbar=tqdm.tqdm(train_dataloader)
        else:
            pbar=train_dataloader
        for batch in pbar:
            image, label = batch["image"].to(classifier_device), batch["label"].to(classifier_device)
            logits = classifier(image)
            loss = F.binary_cross_entropy_with_logits(
                input=logits[:, 0], target=(label == 1).float()
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch total can be divided by the dataset size.
            train_loss += loss.item() * image.size(0)
            # Pass the squeezed logits so predictions and targets share one shape.
            train_auroc.update(logits[:, 0], (label == 1))

        val_loss = 0
        classifier.eval()
        with torch.no_grad():
            if verbose:
                pbar=tqdm.tqdm(val_dataloader)
            else:
                pbar=val_dataloader
            for batch in pbar:
                image, label = batch["image"].to(classifier_device), batch["label"].to(
                    classifier_device
                )
                logits = classifier(image)
                loss = F.binary_cross_entropy_with_logits(
                    input=logits[:, 0], target=(label == 1).float()
                )
                val_loss += loss.item() * image.size(0)
                val_auroc.update(logits[:, 0], (label == 1))
        if verbose:
            print(
                f"Train loss: {train_loss/len(train_dataloader.dataset):.3f} AUROC: {train_auroc.compute():.3f} Val loss: {val_loss/len(val_dataloader.dataset):.3f} AUROC: {val_auroc.compute():.3f}"
            )
        # Scheduler and early stopping both monitor the summed validation loss.
        scheduler.step(val_loss)
        if early_stopper.early_stop(val_loss):
            print("break")
            break
        train_auroc.reset()
        val_auroc.reset()

    test_auroc = AUROC(task="binary")
    test_loss = 0
    classifier.eval()

    logits_list=[]
    label_list=[]
    metadata_list=[]

    with torch.no_grad():
        if verbose:
            pbar=tqdm.tqdm(test_dataloader)
        else:
            pbar=test_dataloader
        # Iterate `pbar` so that verbose=False really suppresses the progress bar.
        for batch in pbar:
            image, label = batch["image"].to(classifier_device), batch["label"].to(
                classifier_device
            )
            logits = classifier(image)
            loss = F.binary_cross_entropy_with_logits(
                input=logits[:, 0], target=(label == 1).float()
            )
            test_loss += loss.item() * image.size(0)
            test_auroc.update(logits[:, 0], (label == 1))
            logits_list.append(logits.detach().cpu().numpy())
            label_list.append(label.detach().cpu().numpy())
            metadata_list.append(batch["metadata"])

    final_test_auroc = test_auroc.compute()
    if verbose:
        print(
            f"Test loss: {test_loss/len(test_dataloader.dataset):.3f} AUROC: {final_test_auroc:.3f}"
        )
    return final_test_auroc, classifier, logits_list, label_list, metadata_list

In [None]:
!gpustat

In [None]:
def train_classifier_multilabel(train_dataloader, val_dataloader, test_dataloader, 
                                labels, n_epochs=20,
                                classifier_type="resnet", verbose=True, device="cuda:6"):
    """Fine-tune a multi-class image classifier and report per-class AUROC.

    Every batch yielded by the dataloaders is read as a dict with "image" and
    "label" entries; test batches must additionally carry "metadata". Labels
    are used as integer class indices: the loss is cross-entropy over
    ``len(labels)`` outputs, and AUROC is computed per class against one-hot
    encoded targets.

    Args:
        train_dataloader: batches used for optimisation.
        val_dataloader: batches used for LR scheduling and early stopping.
        test_dataloader: batches used for the final evaluation.
        labels: sequence of class names; only its length is used here.
        n_epochs: maximum number of training epochs.
        classifier_type: "resnet" or "inception".
        verbose: if True, show progress bars and print per-epoch / test metrics.
        device: torch device string the model and batches are moved to.

    Returns:
        Tuple ``(test_auc, classifier, logits_list, label_list, metadata_list)``
        where ``test_auc`` holds one AUROC per class and the three lists hold
        one entry per test batch.

    Raises:
        ValueError: if ``classifier_type`` is not one of the supported names.
    """
    num_classes = len(labels)
    if classifier_type=="resnet":
        classifier = Classifier(output_dim=num_classes)
    elif classifier_type=="inception":
        classifier = Inception(output_dim=num_classes)
    else:
        # Previously an unknown type surfaced later as a NameError on `classifier`.
        raise ValueError(f"Unknown classifier_type: {classifier_type!r}")
    classifier_device = device
    classifier.to(classifier_device)
    optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=2, verbose=True)
    early_stopper = EarlyStopper(patience=5, min_delta=0)

    train_auc=AUROC(task="multilabel", num_labels=num_classes, average="none")
    val_auc=AUROC(task="multilabel", num_labels=num_classes, average="none")
    for epoch in range(n_epochs):
        train_loss = 0
        classifier.train()
        if verbose:
            pbar=tqdm.tqdm(train_dataloader)
        else:
            pbar=train_dataloader
        for batch in pbar:
            image, label = batch["image"].to(classifier_device), batch["label"].to(classifier_device)
            logits = classifier(image)
            loss = F.cross_entropy(input=logits.float(), target=label.long())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch total can be divided by the dataset size.
            train_loss += loss.item() * image.size(0)
            train_auc(preds=logits.cpu(), target=F.one_hot(label.long().cpu(), num_classes=num_classes))

        val_loss = 0
        classifier.eval()
        with torch.no_grad():
            if verbose:
                pbar=tqdm.tqdm(val_dataloader)
            else:
                pbar=val_dataloader
            for batch in pbar:
                image, label = batch["image"].to(classifier_device), batch["label"].to(
                    classifier_device
                )
                logits = classifier(image)
                loss = F.cross_entropy(input=logits.float(), target=label.long())
                val_loss += loss.item() * image.size(0)
                val_auc(preds=logits.cpu(), target=F.one_hot(label.long().cpu(), num_classes=num_classes))
        if verbose:
            print(
                f"Train loss: {train_loss/len(train_dataloader.dataset):.3f} Val loss: {val_loss/len(val_dataloader.dataset):.3f}"
            )
            print(f"train: {train_auc.compute()}  val: {val_auc.compute()}")
        # Scheduler and early stopping both monitor the summed validation loss.
        scheduler.step(val_loss)
        if early_stopper.early_stop(val_loss):
            print("break")
            break
        train_auc.reset()
        val_auc.reset()

    test_auc=AUROC(task="multilabel", num_labels=num_classes, average="none")
    test_loss = 0
    classifier.eval()

    logits_list=[]
    label_list=[]
    metadata_list=[]

    with torch.no_grad():
        if verbose:
            pbar=tqdm.tqdm(test_dataloader)
        else:
            pbar=test_dataloader
        # Iterate `pbar` so that verbose=False really suppresses the progress bar.
        for batch in pbar:
            image, label = batch["image"].to(classifier_device), batch["label"].to(
                classifier_device
            )
            logits = classifier(image)
            loss = F.cross_entropy(input=logits.float(), target=label.long())
            test_loss += loss.item() * image.size(0)
            test_auc(preds=logits.cpu(), target=F.one_hot(label.long().cpu(), num_classes=num_classes))

            logits_list.append(logits.detach().cpu().numpy())
            label_list.append(label.detach().cpu().numpy())
            metadata_list.append(batch["metadata"])

    final_test_auc = test_auc.compute()
    if verbose:
        print(
            f"Test loss: {test_loss/len(test_dataloader.dataset):.3f}"
        )
        print(f"test: {final_test_auc}")
    return final_test_auc, classifier, logits_list, label_list, metadata_list

In [None]:
# Load a previously saved results dictionary onto the CPU under a separate name
# (`variable_dict_`) so it can be inspected next to the in-memory `variable_dict`.
# NOTE(review): torch.load unpickles arbitrary objects - only load files from a trusted source.
variable_dict_=torch.load(f="logs/experiment_results/compare_manaul_0930.pt", map_location="cpu")

In [None]:
variable_dict_.keys()

In [None]:
# torch.save(variable_dict, "logs/experiment_results/compare_manaul_0930.pt")

In [None]:
# variable_dict["clinical_fd_clean_nodup_nooverlap"]["disease_eval"]=variable_dict_["clinical_fd_clean_nodup_nooverlap"]["disease_eval"]

In [None]:
# variable_dict["derm7pt_derm_nodup"]["disease_eval"]=variable_dict_["derm7pt_derm_nodup"]["disease_eval"]

In [None]:
variable_dict["derm7pt_derm_nodup"].keys()

In [None]:
variable_dict_["derm7pt_derm_nodup"].keys()

In [None]:
variable_dict["isic_nodup_nooverlap"].keys()

In [None]:
variable_dict_["isic_nodup_nooverlap"].keys()

# disease eval

In [None]:
dataset_name="clinical_fd_clean_nodup_nooverlap"
#dataset_name

In [None]:
disease_list=variable_dict[dataset_name]["metadata_all"]["label"].value_counts()

In [None]:
disease_list=disease_list[disease_list>30]

In [None]:
disease_list.shape

In [None]:
variable_dict[dataset_name]["image_features_norm"]

In [None]:
def evaluate_disease(image_features_norm,
                     image_features_norm_vanilla,
                     prompt_info,
                     prompt_info_vanilla,
                     prompt_prefix,
                     group_info,
                     dataloader,
                     metadata_all,
                     disease_list,
                     valid_idx,
                     num_repeat,
                     y_pos):
    """Compare prompt-based ("automatic") and supervised ("manual") disease scoring.

    "automatic": for every disease and for both the "MONET" and "vanilla"
    feature/prompt sets, a similarity score is computed with
    ``calculate_similaity_score`` and its AUROC is measured on all valid samples.

    "manual": for ``num_repeat`` random seeds, the valid samples are split
    80/20, a classifier is trained with ``train_classifier_multilabel`` and the
    per-disease AUROC is measured on the held-out 20%.

    In both modes the AUROC is reported for all valid samples ("all") and for
    every subgroup in ``group_info``.

    Args:
        image_features_norm: image features used for the "MONET" scores.
        image_features_norm_vanilla: image features used for the "vanilla" scores.
        prompt_info: mapping ``prompt_prefix + disease_name`` -> dict holding
            "prompt_target_embedding_norm" and "prompt_ref_embedding_norm".
        prompt_info_vanilla: same structure, for the "vanilla" model.
        prompt_prefix: prefix prepended to a disease name to build the prompt key.
        group_info: mapping group name -> boolean mask over all samples.
        dataloader: forwarded to ``get_training_data``.
        metadata_all: forwarded to ``get_training_data``.
        disease_list: list of disease names; a disease's position is its class index.
        valid_idx: boolean mask over all samples selecting the evaluated ones.
        num_repeat: number of random train/test splits for the "manual" mode.
        y_pos: per-sample class index, compared against positions in ``disease_list``.

    Returns:
        List of dicts, one per (mode, disease, group[, seed | model]) with the
        sample counts and the AUROC (``nan`` when it is undefined).
    """
    def _auc_or_nan(y_true, y_score):
        # AUROC is undefined unless both classes are present; roc_auc_score
        # raises in that situation, so report NaN instead of aborting the run.
        if y_true.sum()==0 or (~y_true).sum()==0:
            return np.nan
        return sklearn.metrics.roc_auc_score(y_true=y_true, y_score=y_score)

    group_keys=["all"]+list(group_info.keys())
    record_dict_list_disease=[]
    for manual_automatic in ["automatic", "manual"]:
        if manual_automatic=="manual":
            for random_seed in range(num_repeat):
                # Split the positions of the valid samples, then turn both
                # halves back into boolean masks over the full dataset.
                subset_idx_train_, subset_idx_test_ = train_test_split(np.arange(len(valid_idx))[valid_idx],
                                                                       test_size=0.2,
                                                                       random_state=random_seed)
                subset_idx_train=np.zeros(len(valid_idx)).astype(bool)
                subset_idx_train[subset_idx_train_]=True

                subset_idx_test=np.zeros(len(valid_idx)).astype(bool)
                subset_idx_test[subset_idx_test_]=True

                print(subset_idx_train, subset_idx_test)

                train_dataloader, val_dataloader, test_dataloader=\
                get_training_data(dataloader=dataloader,
                                  metadata_all=metadata_all,
                                  valid_idx=valid_idx,
                                  y_pos=y_pos,
                                  subset_idx_train=subset_idx_train,
                                  subset_idx_test=subset_idx_test,
                                  n_px=None)

                best_auc, clf, logits_list, label_list, metadata_list = train_classifier_multilabel(train_dataloader=train_dataloader,
                                            val_dataloader=val_dataloader,
                                            test_dataloader=test_dataloader,
                                            labels=disease_list, n_epochs=20,
                                            classifier_type="resnet", verbose=True)

                # Stack the per-batch logits once instead of once per disease/group.
                # NOTE(review): this assumes the test loader yields samples in the
                # order selected by `valid_idx&subset_idx_test` - confirm in get_training_data.
                test_logits=np.vstack(logits_list)
                test_mask=valid_idx&subset_idx_test

                for disease_name in disease_list:
                    disease_idx=disease_list.index(disease_name)
                    for group_key in group_keys:
                        if group_key=="all":
                            group_idx=valid_idx
                        else:
                            group_idx=group_info[group_key]

                        y_true=(y_pos[test_mask]==disease_idx)[group_idx[test_mask]]
                        y_score=test_logits[:,disease_idx][group_idx[test_mask]]

                        auc=_auc_or_nan(y_true, y_score)

                        record_dict_list_disease.append(
                            {"manual_automatic": manual_automatic,
                             "group_key": group_key,
                             "disease_name": disease_name,
                             "random_seed": random_seed,
                             "num_test": (y_true).sum()+(~y_true).sum(),
                             "num_test_pos": (y_true).sum(),
                             "num_test_neg": (~y_true).sum(),
                             "auc": auc,
                            }
                        )
                        print(record_dict_list_disease[-1])

        elif manual_automatic=="automatic":
            for disease_name in disease_list:
                disease_idx=disease_list.index(disease_name)
                prompt_key=prompt_prefix+disease_name
                for trained in ["MONET", "vanilla"]:
                    # temp=1/np.exp(4.5944): presumably the inverse of the model's
                    # learned logit scale - TODO confirm.
                    if trained=="MONET":
                        similarity_score=calculate_similaity_score(
                                        image_features_norm=image_features_norm,
                                        prompt_target_embedding_norm=prompt_info[prompt_key]["prompt_target_embedding_norm"],
                                        prompt_ref_embedding_norm=prompt_info[prompt_key]["prompt_ref_embedding_norm"],
                                        temp=1/np.exp(4.5944),
                                        normalize=True)
                    elif trained=="vanilla":
                        similarity_score=calculate_similaity_score(
                                        image_features_norm=image_features_norm_vanilla,
                                        prompt_target_embedding_norm=prompt_info_vanilla[prompt_key]["prompt_target_embedding_norm"],
                                        prompt_ref_embedding_norm=prompt_info_vanilla[prompt_key]["prompt_ref_embedding_norm"],
                                        temp=1/np.exp(4.5944),
                                        normalize=True)

                    for group_key in group_keys:
                        if group_key=="all":
                            group_idx=valid_idx
                        else:
                            group_idx=group_info[group_key]

                        y_true=(y_pos==disease_idx)
                        y_true=y_true[valid_idx&group_idx]
                        y_score=similarity_score[valid_idx&group_idx]

                        auc=_auc_or_nan(y_true, y_score)

                        record_dict_list_disease.append(
                            {"manual_automatic": manual_automatic,
                             "group_key": group_key,
                             "trained": trained,
                             "disease_name": disease_name,
                             "num_test": (y_true).sum()+(~y_true).sum(),
                             "num_test_pos": (y_true).sum(),
                             "num_test_neg": (~y_true).sum(),
                             "auc": auc,
                            }
                        )
                        print(record_dict_list_disease[-1])
    return record_dict_list_disease

In [None]:
!ls logs/experiment_results/compare_manaul_* -l

In [None]:
# Run the disease evaluation separately for the two sources contained in the
# clinical dataset ("ddi" and "fitzpatrick17k") and collect the records per source.
dataset_name="clinical_fd_clean_nodup_nooverlap"
record_dict_list_disease_subset={}
for subset_name in ["ddi", "fitzpatrick17k"]:
# for subset_name in ["fitzpatrick17k"]:
# for subset_name in ["ddi"]:
    if subset_name=="fitzpatrick17k":
        # Fitzpatrick17k disease names live in the "label" column; float entries (NaN) are dropped.
        disease_list=variable_dict[dataset_name]["metadata_all"]["label"].unique().tolist()
        disease_list=[disease_name for disease_name in disease_list if not isinstance(disease_name, float)]
#         disease_list=variable_dict[dataset_name]["metadata_all"]["label"].value_counts()
#         disease_list=disease_list[disease_list>30].index.tolist()
        valid_idx=(variable_dict[dataset_name]["metadata_all"]["source"]=="fitz").values
        valid_idx=valid_idx&(variable_dict[dataset_name]["metadata_all"]["skincon_Do not consider this image"]!=1)
        print(valid_idx.sum())
        valid_idx=valid_idx&(variable_dict[dataset_name]["metadata_all"]["label"].map(lambda x: x in disease_list).values)
        print(valid_idx.sum())
        # Class index per sample; -9 marks samples without a usable disease name.
        y_pos=variable_dict[dataset_name]["metadata_all"]["label"].map(lambda x: disease_list.index(x) if x in disease_list else -9)
        prompt_prefix="fitzdisease_"
        
    elif subset_name=="ddi":
        # DDI disease names live in the "disease" column; float entries (NaN) are dropped.
        disease_list=variable_dict[dataset_name]["metadata_all"]["disease"].unique().tolist()
        disease_list=[disease_name for disease_name in disease_list if not isinstance(disease_name, float)]
#         disease_list=variable_dict[dataset_name]["metadata_all"]["disease"].value_counts()
#         disease_list=disease_list[disease_list>30].index.tolist()
        valid_idx=(variable_dict[dataset_name]["metadata_all"]["source"]=="ddi").values        
        valid_idx=valid_idx&(variable_dict[dataset_name]["metadata_all"]["skincon_Do not consider this image"]!=1)
        print(valid_idx.sum())
        valid_idx=valid_idx&(variable_dict[dataset_name]["metadata_all"]["disease"].map(lambda x: x in disease_list).values)
        print(valid_idx.sum())
        # Class index per sample; -9 marks samples without a usable disease name.
        y_pos=variable_dict[dataset_name]["metadata_all"]["disease"].map(lambda x: disease_list.index(x) if x in disease_list else -9)
        prompt_prefix="ddidisease_"
    else:
        raise ValueError(f"Unknown subset_name: {subset_name!r}")
        
    # Skin-tone subgroups: a sample belongs to a group when either its
    # "fitzpatrick_scale" value or its "skin_tone" value falls into that bucket.
    group_info={}
    for tone_split in ["12", "34", "56"]:
        if tone_split=="12":
            tone_idx=((variable_dict[dataset_name]["metadata_all"]["fitzpatrick_scale"].fillna(-9).astype(int).isin([1,2]))|\
                    (variable_dict[dataset_name]["metadata_all"]["skin_tone"].fillna(-9).astype(int).isin([12]))).values
        elif tone_split=="34":
            tone_idx=((variable_dict[dataset_name]["metadata_all"]["fitzpatrick_scale"].fillna(-9).astype(int).isin([3,4]))|\
                    (variable_dict[dataset_name]["metadata_all"]["skin_tone"].fillna(-9).astype(int).isin([34]))).values                            
        elif tone_split=="56":
            tone_idx=((variable_dict[dataset_name]["metadata_all"]["fitzpatrick_scale"].fillna(-9).astype(int).isin([5,6]))|\
                    (variable_dict[dataset_name]["metadata_all"]["skin_tone"].fillna(-9).astype(int).isin([56]))).values                            
        else:
            # The original bare `ValueError` expression was a no-op; actually raise it.
            raise ValueError(f"Unknown tone_split: {tone_split!r}")
        group_info[tone_split]=tone_idx

    record_dict_list_disease=evaluate_disease(
                     image_features_norm=variable_dict[dataset_name]["image_features_norm"],
                     image_features_norm_vanilla=variable_dict[dataset_name]["image_features_vanilla_norm"],
                     prompt_info=variable_dict[dataset_name]["prompt_info"],
                     prompt_info_vanilla=variable_dict[dataset_name]["prompt_info_vanilla"],
                     prompt_prefix=prompt_prefix,
                     group_info=group_info,
        
                     dataloader=variable_dict[dataset_name]["dataloader"], 
                     metadata_all=variable_dict[dataset_name]["metadata_all"],
                     disease_list=disease_list,
                     valid_idx=valid_idx,
                     y_pos=y_pos,
                     num_repeat=20)
    
    record_dict_list_disease_subset[subset_name]=record_dict_list_disease

In [None]:
variable_dict[dataset_name]["disease_eval"]=record_dict_list_disease_subset

In [None]:
# Summarise the Fitzpatrick17k disease-evaluation records per evaluation setting.
record_dict_list_disease_df=pd.DataFrame(
    variable_dict["clinical_fd_clean_nodup_nooverlap"]["disease_eval"]["fitzpatrick17k"])

# The records written by evaluate_disease() carry no "subset_name" field, yet
# the summaries below group by it - add the column when it is absent so the
# groupby does not fail with a KeyError.
if "subset_name" not in record_dict_list_disease_df.columns:
    record_dict_list_disease_df["subset_name"]="fitzpatrick17k"

# Keep only diseases with at least 30 positive samples in a "MONET" record.
disease_name_filtered=record_dict_list_disease_df[
    (record_dict_list_disease_df["trained"]=="MONET")&
    (record_dict_list_disease_df["num_test_pos"]>=30)
]["disease_name"].unique().tolist()

print(len(disease_name_filtered))

# Each statistic first averages the AUROC per disease (over seeds), then
# aggregates across diseases within one evaluation setting.
print("mean AUC")
print(record_dict_list_disease_df[(record_dict_list_disease_df["disease_name"].isin(disease_name_filtered))&
                           (~record_dict_list_disease_df["auc"].isnull())
                           ]\
.fillna("null").groupby(["subset_name", "manual_automatic", "trained", "group_key"])\
.apply(lambda x: x.groupby("disease_name")["auc"].mean().mean()))

print('\n\n')

print("std AUC")
print(record_dict_list_disease_df[(record_dict_list_disease_df["disease_name"].isin(disease_name_filtered))&
                           (~record_dict_list_disease_df["auc"].isnull())
                           ]\
.fillna("null").groupby(["subset_name", "manual_automatic", "trained", "group_key"])\
.apply(lambda x: x.groupby("disease_name")["auc"].mean().std()))

print('\n\n')

# Number of diseases whose mean AUROC reaches 0.7.
print("count AUC")
print(record_dict_list_disease_df[(record_dict_list_disease_df["disease_name"].isin(disease_name_filtered))&
                           (~record_dict_list_disease_df["auc"].isnull())
                           ]\
.fillna("null").groupby(["subset_name", "manual_automatic", "trained", "group_key"])\
.apply(lambda x: (x.groupby("disease_name")["auc"].mean()>=0.7).sum()))
print('\n\n')

In [None]:
# Summarise the DDI disease-evaluation records per evaluation setting.
record_dict_list_disease_df=pd.DataFrame(
    variable_dict["clinical_fd_clean_nodup_nooverlap"]["disease_eval"]["ddi"])

# The records written by evaluate_disease() carry no "subset_name" field, yet
# the summaries below group by it - add the column when it is absent so the
# groupby does not fail with a KeyError.
if "subset_name" not in record_dict_list_disease_df.columns:
    record_dict_list_disease_df["subset_name"]="ddi"

# Keep only diseases with at least 30 positive samples in a "MONET" record.
disease_name_filtered=record_dict_list_disease_df[
    (record_dict_list_disease_df["trained"]=="MONET")&
    (record_dict_list_disease_df["num_test_pos"]>=30)
]["disease_name"].unique().tolist()

print(len(disease_name_filtered))

# Each statistic first averages the AUROC per disease (over seeds), then
# aggregates across diseases within one evaluation setting.
print("mean AUC")
print(record_dict_list_disease_df[(record_dict_list_disease_df["disease_name"].isin(disease_name_filtered))&
                           (~record_dict_list_disease_df["auc"].isnull())
                           ]\
.fillna("null").groupby(["subset_name", "manual_automatic", "trained", "group_key"])\
.apply(lambda x: x.groupby("disease_name")["auc"].mean().mean()))

print('\n\n')

print("std AUC")
print(record_dict_list_disease_df[(record_dict_list_disease_df["disease_name"].isin(disease_name_filtered))&
                           (~record_dict_list_disease_df["auc"].isnull())
                           ]\
.fillna("null").groupby(["subset_name", "manual_automatic", "trained", "group_key"])\
.apply(lambda x: x.groupby("disease_name")["auc"].mean().std()))

print('\n\n')

# Number of diseases whose mean AUROC reaches 0.7.
print("count AUC")
print(record_dict_list_disease_df[(record_dict_list_disease_df["disease_name"].isin(disease_name_filtered))&
                           (~record_dict_list_disease_df["auc"].isnull())
                           ]\
.fillna("null").groupby(["subset_name", "manual_automatic", "trained", "group_key"])\
.apply(lambda x: (x.groupby("disease_name")["auc"].mean()>=0.7).sum()))
print('\n\n')

# Number of diseases that contribute to each setting (length of the per-disease series).
print("total AUC")
print(record_dict_list_disease_df[(record_dict_list_disease_df["disease_name"].isin(disease_name_filtered))&
                           (~record_dict_list_disease_df["auc"].isnull())
                           ]\
.fillna("null").groupby(["subset_name", "manual_automatic", "trained", "group_key"])\
.apply(lambda x: len(x.groupby("disease_name")["auc"].mean()>=0.7)))
print('\n\n')

In [None]:
variable_dict["derm7pt_derm_nodup"]["metadata_all"]["diagnosis"].value_counts()

In [None]:
# Disease evaluation on derm7pt (dermoscopic images).
dataset_name="derm7pt_derm_nodup"
# Raw diagnosis strings accepted by the validity filter (melanoma subtypes listed individually).
disease_list=['basal cell carcinoma', 'blue nevus', 'clark nevus',
       'combined nevus', 'congenital nevus', 'dermal nevus',
       'dermatofibroma', 'lentigo', 
        'melanoma (in situ)', 'melanoma (less than 0.76 mm)', 'melanoma (0.76 to 1.5 mm)',
        'melanoma (more than 1.5 mm)', 'melanoma metastasis', 
        'melanosis', 'recurrent nevus', 'reed or spitz nevus',
       'seborrheic keratosis', 'vascular lesion', 'melanoma']
# disease_list=variable_dict[dataset_name]["metadata_all"]["diagnosis"].unique().tolist()
#         disease_list=variable_dict[dataset_name]["metadata_all"]["label"].value_counts()
#         disease_list=disease_list[disease_list>30].index.tolist()
valid_idx=variable_dict[dataset_name]["valid_idx"]
print(valid_idx.sum())
valid_idx=valid_idx&(variable_dict[dataset_name]["metadata_all"]["diagnosis"].map(lambda x: x in disease_list).values)
print(valid_idx.sum())
# Evaluation classes: all melanoma subtypes collapse into a single "melanoma"
# class. Each name must appear exactly once - a duplicated entry would add an
# output class that never receives a sample and duplicate the melanoma records.
disease_list=['basal cell carcinoma', 'blue nevus', 'clark nevus',
       'combined nevus', 'congenital nevus', 'dermal nevus',
       'dermatofibroma', 'lentigo', 
        'melanoma',
        'melanosis', 'recurrent nevus', 'reed or spitz nevus',
       'seborrheic keratosis', 'vascular lesion']
# Class index per sample; -9 marks samples whose diagnosis is not evaluated.
# The isinstance guard keeps missing (NaN) diagnoses from raising a TypeError.
y_pos=variable_dict[dataset_name]["metadata_all"]["diagnosis"].map(lambda x:"melanoma" if (isinstance(x,str) and "melanoma" in x) else x).map(lambda x: disease_list.index(x) if x in disease_list else -9)
print(pd.Series(y_pos).value_counts())

prompt_prefix="derm7ptdisease_"

# No subgroup analysis for this dataset: only the "all" group is evaluated.
group_info={}

record_dict_list_disease=evaluate_disease(
                 image_features_norm=variable_dict[dataset_name]["image_features_norm"],
                 image_features_norm_vanilla=variable_dict[dataset_name]["image_features_vanilla_norm"],
                 prompt_info=variable_dict[dataset_name]["prompt_info"],
                 prompt_info_vanilla=variable_dict[dataset_name]["prompt_info_vanilla"],
                 prompt_prefix=prompt_prefix,
                 group_info=group_info,

                 dataloader=variable_dict[dataset_name]["dataloader"], 
                 metadata_all=variable_dict[dataset_name]["metadata_all"],
                 disease_list=disease_list,
                 valid_idx=valid_idx,
                 y_pos=y_pos,
                 num_repeat=20)

In [None]:
variable_dict[dataset_name]["disease_eval"]=record_dict_list_disease

In [None]:
variable_dict.keys()

In [None]:
variable_dict["derm7pt_derm_nodup"]["disease_eval"]

In [None]:
variable_dict[dataset_name]["prompt_info"].keys()

In [None]:
disease_list

In [None]:
sorted(variable_dict["derm7pt_derm_nodup"]["metadata_all"]["diagnosis"].value_counts().index.tolist())

In [None]:
len([
    'melanoma',
    'basal cell carcinoma',
 'blue nevus',
 'clark nevus',
 'combined nevus',
 'congenital nevus',
 'dermal nevus',
 'dermatofibroma',
 'lentigo',
 'melanosis',
 'recurrent nevus',
 'reed or spitz nevus',
 'seborrheic keratosis',
 'vascular lesion'])

In [None]:
record_dict_list_disease_df["disease_name"].value_counts()

In [None]:
# Summarise the derm7pt disease-evaluation records per evaluation setting.
record_dict_list_disease_df=pd.DataFrame(variable_dict["derm7pt_derm_nodup"]["disease_eval"])
# Placeholder value so the records can be grouped by "subset_name".
record_dict_list_disease_df["subset_name"]="temp"

# Diseases with at least 30 positive samples in a "MONET" record.
disease_name_filtered=record_dict_list_disease_df.loc[
    (record_dict_list_disease_df["trained"]=="MONET")
    &(record_dict_list_disease_df["num_test_pos"]>=30),
    "disease_name",
].unique().tolist()

print(len(disease_name_filtered))

# Rows of the retained diseases that have a defined AUROC, grouped by setting.
auc_by_setting=record_dict_list_disease_df[
    record_dict_list_disease_df["disease_name"].isin(disease_name_filtered)
    &record_dict_list_disease_df["auc"].notnull()
].fillna("null").groupby(["subset_name", "manual_automatic", "trained", "group_key"])

# Every statistic is derived from the per-disease mean AUROC inside one setting.
for summary_title, summarise in [
    ("mean AUC", lambda x: x.groupby("disease_name")["auc"].mean().mean()),
    ("std AUC", lambda x: x.groupby("disease_name")["auc"].mean().std()),
    ("count AUC", lambda x: (x.groupby("disease_name")["auc"].mean()>=0.7).sum()),
    ("total AUC", lambda x: len(x.groupby("disease_name")["auc"].mean())),
]:
    print(summary_title)
    print(auc_by_setting.apply(summarise))
    print('\n\n')

In [None]:
print(pd.Series(list(zip(variable_dict[dataset_name]["metadata_all"]["diagnosis"],y_pos))).value_counts())

In [None]:
# Disease evaluation on ISIC.
dataset_name="isic_nodup_nooverlap"
# Raw diagnosis strings accepted by the validity filter.
disease_list=[
    'seborrheic keratosis', 'nevus', 'squamous cell carcinoma',
    'melanoma', 'lichenoid keratosis', 'lentigo NOS',
    'actinic keratosis', 'basal cell carcinoma', 'dermatofibroma',
    'atypical melanocytic proliferation', 'verruca',
    'clear cell acanthoma', 'angiofibroma or fibrous papule', 'scar',
    'angioma', 'atypical spitz tumor', 'solar lentigo', 'AIMP',
    'neurofibroma', 'lentigo simplex', 'acrochordon',
    'angiokeratoma', 'vascular lesion', 'cafe-au-lait macule',
    'pigmented benign keratosis', 'melanoma metastasis',
]
valid_idx=variable_dict[dataset_name]["valid_idx"]
print(valid_idx.sum())
valid_idx=valid_idx&(variable_dict[dataset_name]["metadata_all"]["diagnosis"].map(lambda x: x in disease_list).values)
print(valid_idx.sum())

# Evaluation classes: same order as above, with "lentigo NOS" renamed to
# "lentigo" and "melanoma metastasis" dropped (it is folded into "melanoma").
disease_list=[
    "lentigo" if raw_name=="lentigo NOS" else raw_name
    for raw_name in disease_list
    if raw_name!="melanoma metastasis"
]

def _isic_diagnosis_to_index(diagnosis):
    """Map a raw diagnosis to its position in `disease_list`, or -9 when it is not evaluated."""
    if isinstance(diagnosis,str):
        if "lentigo NOS" in diagnosis:
            diagnosis="lentigo"
        elif "melanoma metastasis" in diagnosis:
            diagnosis="melanoma"
    if diagnosis in disease_list:
        return disease_list.index(diagnosis)
    return -9

y_pos=variable_dict[dataset_name]["metadata_all"]["diagnosis"].map(_isic_diagnosis_to_index)
print(pd.Series(list(zip(variable_dict[dataset_name]["metadata_all"]["diagnosis"],y_pos))).value_counts())
prompt_prefix="isicdisease_"

# No subgroup analysis for this dataset: only the "all" group is evaluated.
group_info={}

record_dict_list_disease=evaluate_disease(
    image_features_norm=variable_dict[dataset_name]["image_features_norm"],
    image_features_norm_vanilla=variable_dict[dataset_name]["image_features_vanilla_norm"],
    prompt_info=variable_dict[dataset_name]["prompt_info"],
    prompt_info_vanilla=variable_dict[dataset_name]["prompt_info_vanilla"],
    prompt_prefix=prompt_prefix,
    group_info=group_info,
    dataloader=variable_dict[dataset_name]["dataloader"],
    metadata_all=variable_dict[dataset_name]["metadata_all"],
    disease_list=disease_list,
    valid_idx=valid_idx,
    y_pos=y_pos,
    num_repeat=1,
)

In [None]:
print(valid_idx.sum())

In [None]:
disease_list

In [None]:
valid_idx.sum()

In [None]:
valid_idx

In [None]:
26976

In [None]:
variable_dict[dataset_name]["metadata_all"][valid_idx]["diagnosis"].fillna('null').value_counts()

In [None]:
y_pos

In [None]:
variable_dict["isic_nodup_nooverlap"].keys()

In [None]:
variable_dict[dataset_name]["disease_eval"]=record_dict_list_disease

In [None]:
['seborrheic keratosis', 'nevus', 'squamous cell carcinoma',
'melanoma', 'lichenoid keratosis', 'lentigo NOS',
'actinic keratosis', 'basal cell carcinoma', 'dermatofibroma',
'atypical melanocytic proliferation', 'verruca',
'clear cell acanthoma', 'angiofibroma or fibrous papule', 'scar',
'angioma', 'atypical spitz tumor', 'solar lentigo', 'AIMP',
'neurofibroma', 'lentigo simplex', 'acrochordon', 
'angiokeratoma', 'vascular lesion', 'cafe-au-lait macule',
'pigmented benign keratosis', 'melanoma metastasis']

In [None]:
# Summarise the ISIC disease-evaluation records per evaluation setting.
record_dict_list_disease_df=pd.DataFrame(variable_dict["isic_nodup_nooverlap"]["disease_eval"])
# Placeholder value so the records can be grouped by "subset_name".
record_dict_list_disease_df["subset_name"]="temp"

# Diseases with at least 30 positive samples in a "MONET" record.
disease_name_filtered=record_dict_list_disease_df.loc[
    (record_dict_list_disease_df["trained"]=="MONET")
    &(record_dict_list_disease_df["num_test_pos"]>=30),
    "disease_name",
].unique().tolist()

print(len(disease_name_filtered))

# Rows of the retained diseases that have a defined AUROC, grouped by setting.
auc_by_setting=record_dict_list_disease_df[
    record_dict_list_disease_df["disease_name"].isin(disease_name_filtered)
    &record_dict_list_disease_df["auc"].notnull()
].fillna("null").groupby(["subset_name", "manual_automatic", "trained", "group_key"])

# Every statistic is derived from the per-disease mean AUROC inside one setting.
for summary_title, summarise in [
    ("mean AUC", lambda x: x.groupby("disease_name")["auc"].mean().mean()),
    ("std AUC", lambda x: x.groupby("disease_name")["auc"].mean().std()),
    ("count AUC", lambda x: (x.groupby("disease_name")["auc"].mean()>=0.7).sum()),
    ("total AUC", lambda x: len(x.groupby("disease_name")["auc"].mean())),
]:
    print(summary_title)
    print(auc_by_setting.apply(summarise))
    print('\n\n')

In [None]:
disease_name_filtered

In [None]:
# Mean AUROC per evaluation setting, restricted to the "melanoma" records.
print("mean AUC")
melanoma_auc_rows=record_dict_list_disease_df[
    record_dict_list_disease_df["disease_name"].isin(["melanoma"])
    &record_dict_list_disease_df["auc"].notnull()
].fillna("null")
melanoma_mean_auc=melanoma_auc_rows.groupby(
    ["subset_name", "manual_automatic", "trained", "group_key"]
).apply(lambda x: x.groupby("disease_name")["auc"].mean().mean())
print(melanoma_mean_auc)

print('\n\n')

In [None]:
record_dict_list_disease_df

In [None]:
# variable_dict=torch.load(f="logs/experiment_results/compare_manaul_0911.pt")

In [None]:
# variable_dict=torch.load(f="logs/experiment_results/compare_manaul_0911.pt")

# concept eval

In [None]:
def get_concept_bool_from_metadata(dataset_name, metadata_all, concept_name):
    """Derive positive / negative labels for one concept from dataset metadata.

    The dataset family is selected by substring match on ``dataset_name``, tested
    in this order: ``"derm7pt_derm"``, ``"isic"``, ``"clinical_fd_clean"``.

    Args:
        dataset_name: Dataset identifier; only used for the substring tests above.
        metadata_all: Metadata table used as a pandas DataFrame (columns are
            compared element-wise and ``.isin`` / ``.str.contains`` are called).
        concept_name: Concept identifier, e.g. ``"derm7ptconcept_streaks"``,
            ``"isicconcept_globules"``, ``"skincon_<column>"`` or
            ``"skintone_dark"``.

    Returns:
        dict with two row-aligned boolean Series:
        ``"concept_bool_true"`` (concept present) and
        ``"concept_bool_false"`` (concept absent).  For the derm7pt and ISIC
        families the second is the element-wise negation of the first; for
        ``clinical_fd_clean`` a row can be in neither (unlabelled).

    Raises:
        ValueError: if ``concept_name`` is not known for the selected dataset
            family, or if ``dataset_name`` matches no known family.
        KeyError: if the metadata column needed by the concept is missing.
    """
    if "derm7pt_derm" in dataset_name:
        # concept -> (metadata column, comparison, reference value)
        derm7pt_rules = {
            "derm7ptconcept_pigment network": ("pigment_network", "ne", "absent"),
            "derm7ptconcept_typical pigment network": ("pigment_network", "eq", "typical"),
            "derm7ptconcept_atypical pigment network": ("pigment_network", "eq", "atypical"),
            "derm7ptconcept_regression structure": ("regression_structures", "ne", "absent"),
            "derm7ptconcept_pigmentation": ("pigmentation", "ne", "absent"),
            # The leading space keeps "... irregular" from matching the "regular" pattern.
            "derm7ptconcept_regular pigmentation": ("pigmentation", "contains", " regular"),
            "derm7ptconcept_irregular pigmentation": ("pigmentation", "contains", " irregular"),
            "derm7ptconcept_blue whitish veil": ("blue_whitish_veil", "ne", "absent"),
            "derm7ptconcept_vascular structures": ("vascular_structures", "ne", "absent"),
            "derm7ptconcept_typical vascular structures": (
                "vascular_structures", "isin",
                ["within regression", "arborizing", "comma", "hairpin", "wreath"]),
            "derm7ptconcept_atypical vascular structures": (
                "vascular_structures", "isin", ["dotted", "linear irregular"]),
            "derm7ptconcept_streaks": ("streaks", "ne", "absent"),
            "derm7ptconcept_regular streaks": ("streaks", "eq", "regular"),
            "derm7ptconcept_irregular streaks": ("streaks", "eq", "irregular"),
            "derm7ptconcept_dots and globules": ("dots_and_globules", "ne", "absent"),
            "derm7ptconcept_regular dots and globules": ("dots_and_globules", "eq", "regular"),
            "derm7ptconcept_irregular dots and globules": ("dots_and_globules", "eq", "irregular"),
        }
        if concept_name not in derm7pt_rules:
            raise ValueError(concept_name)

        column, comparison, reference = derm7pt_rules[concept_name]
        annotation = metadata_all[column]
        # NOTE(review): rows with a missing annotation are not masked out here;
        # e.g. NaN != "absent" evaluates to True.  Confirm the metadata has no
        # missing values in these columns, or filter them in the caller.
        if comparison == "ne":
            concept_bool = annotation != reference
        elif comparison == "eq":
            concept_bool = annotation == reference
        elif comparison == "isin":
            concept_bool = annotation.isin(reference)
        else:  # "contains"
            concept_bool = annotation.str.contains(reference)

        concept_bool_true = concept_bool
        concept_bool_false = ~concept_bool

    elif "isic" in dataset_name:
        isic_columns = {
            'isicconcept_pigment_network': "pigment_network",
            'isicconcept_negative_network': "negative_network",
            'isicconcept_milia_like_cyst': "milia_like_cyst",
            'isicconcept_streaks': "streaks",
            'isicconcept_globules': "globules",
        }
        if concept_name not in isic_columns:
            raise ValueError(concept_name)

        label = metadata_all[isic_columns[concept_name]]
        # A concept counts as present when its label exceeds this fixed cut-off.
        # NOTE(review): the unit of the label (and whether -9 marks "not
        # annotated") is not visible here; such rows end up as negatives — confirm.
        presence_threshold = 30
        concept_bool = label > presence_threshold

        # Previously these two names were never assigned in this branch, so every
        # ISIC concept failed at the return statement with UnboundLocalError.
        concept_bool_true = concept_bool
        concept_bool_false = ~concept_bool

    elif "clinical_fd_clean" in dataset_name:
        if concept_name.startswith("skincon_"):
            # The concept name doubles as the metadata column; 1 = present, 0 = absent.
            concept_bool_true = metadata_all[concept_name] == 1
            concept_bool_false = metadata_all[concept_name] == 0
        elif concept_name == "skintone_dark":
            # -9 is a filler for missing values so the columns can be cast to int.
            fitzpatrick = metadata_all["fitzpatrick_scale"].fillna(-9).astype(int)
            skin_tone = metadata_all["skin_tone"].fillna(-9).astype(int)
            concept_bool_true = fitzpatrick.isin([5, 6]) | skin_tone.isin([56])
            concept_bool_false = fitzpatrick.isin([1, 2, 3, 4]) | skin_tone.isin([12, 34])
        else:
            raise ValueError(concept_name)

    else:
        # Fail with a clear message instead of an UnboundLocalError at the return.
        raise ValueError(dataset_name)

    return {"concept_bool_true": concept_bool_true,
            "concept_bool_false": concept_bool_false,
            }

In [None]:
def _roc_auc_or_nan(y_true, y_score):
    """Return the ROC AUC, or NaN when ``y_true`` does not contain both classes."""
    # ROC AUC is undefined with a single class (no positives, no negatives, or an
    # empty group); report NaN for that group instead of aborting the evaluation.
    if len(np.unique(y_true)) < 2:
        return np.nan
    return sklearn.metrics.roc_auc_score(y_true=y_true, y_score=y_score)


def evaluate_concept(image_features_norm,
                     image_features_norm_vanilla,
                     prompt_info,
                     prompt_info_vanilla,
                     prompt_prefix,
                     group_info,

                     dataset_name,
                     dataloader,
                     metadata_all,
                     concept_list,
                     valid_idx,
                     random_seed_range,
                     method_list=["automatic", "manual"]
                     ):
    """Compute per-concept, per-group AUCs for the requested labelling methods.

    For every concept the ground-truth labels come from
    ``get_concept_bool_from_metadata``.  Concepts with fewer than 30 positives
    among the ``valid_idx`` rows are skipped (a message is printed).

    * ``"automatic"``: a similarity score from ``calculate_similaity_score`` is
      evaluated for the "MONET" inputs (``image_features_norm``, ``prompt_info``)
      and the "vanilla" inputs (``image_features_norm_vanilla``,
      ``prompt_info_vanilla``).
    * ``"manual"``: for every seed the valid rows are split 80/20 with
      ``train_test_split``, a classifier is trained through
      ``get_training_data`` / ``train_classifier``, and its test logits are scored.

    Args:
        image_features_norm: Image features forwarded to
            ``calculate_similaity_score`` for the "MONET" run.
        image_features_norm_vanilla: Same, for the "vanilla" run.
        prompt_info: Mapping ``prompt_prefix + concept_name`` -> dict holding
            ``"prompt_target_embedding_norm"`` and ``"prompt_ref_embedding_norm"``.
        prompt_info_vanilla: Same mapping for the "vanilla" run.
        prompt_prefix: String prepended to the concept name for the prompt lookup.
        group_info: Mapping group key -> boolean row mask; an extra ``"all"``
            group that uses ``valid_idx`` is always evaluated first.
        dataset_name: Forwarded to ``get_concept_bool_from_metadata``.
        dataloader: Forwarded to ``get_training_data`` (manual method only).
        metadata_all: Metadata table used for the labels and for training data.
        concept_list: Iterable of concept names to evaluate.
        valid_idx: Boolean mask over all rows selecting the usable samples.
        random_seed_range: Iterable of seeds for the train/test split (manual only).
        method_list: Methods to run; entries other than ``"automatic"`` and
            ``"manual"`` are ignored.  The default list is never mutated.

    Returns:
        list of dicts, one per evaluated (concept, method, model-or-seed, group)
        combination.  Automatic records carry ``"trained"``; manual records carry
        ``"random_seed"`` and the ``num_train*`` counts instead.  ``"auc"`` is NaN
        when the group does not contain both classes.
    """
    # Temperature passed to calculate_similaity_score for both models.
    # NOTE(review): 4.5944 is presumably the log of the CLIP logit scale — confirm.
    similarity_temp = 1 / np.exp(4.5944)
    # Minimum number of positives (within valid_idx) for a concept to be evaluated.
    min_positives = 30

    group_keys = ["all"] + list(group_info.keys())
    record_dict_list_concept = []

    for concept_name in concept_list:
        concept_bool = get_concept_bool_from_metadata(dataset_name=dataset_name,
                                                      metadata_all=metadata_all,
                                                      concept_name=concept_name)
        concept_true = concept_bool["concept_bool_true"]
        concept_false = concept_bool["concept_bool_false"]

        # Every valid row must be labelled either positive or negative.
        assert len(concept_true) == len(concept_false)
        assert len(concept_true[valid_idx]) == concept_true[valid_idx].sum() + concept_false[valid_idx].sum()

        if concept_true[valid_idx].sum() < min_positives:
            print(concept_name, "!!!!!!!!!!!!!!!!!!!!!!!! SKIPPED !!!!!!!!!!!!!!!!!!!!!!!!")
            continue

        y_pos = concept_true.copy().fillna(0).values

        for manual_automatic in method_list:
            if manual_automatic == "manual":
                for random_seed in random_seed_range:
                    # Split the positions of the valid rows, then turn both halves
                    # back into boolean masks over all rows.
                    train_positions, test_positions = train_test_split(
                        np.arange(len(valid_idx))[valid_idx],
                        test_size=0.2,
                        random_state=random_seed)
                    subset_idx_train = np.zeros(len(valid_idx), dtype=bool)
                    subset_idx_train[train_positions] = True
                    subset_idx_test = np.zeros(len(valid_idx), dtype=bool)
                    subset_idx_test[test_positions] = True

                    # A split without positives on either side cannot be trained / scored.
                    if y_pos[subset_idx_train].sum() == 0 or y_pos[subset_idx_test].sum() == 0:
                        continue

                    train_dataloader, val_dataloader, test_dataloader = get_training_data(
                        dataloader=dataloader,
                        metadata_all=metadata_all,
                        valid_idx=valid_idx,
                        y_pos=y_pos,
                        subset_idx_train=subset_idx_train,
                        subset_idx_test=subset_idx_test,
                        n_px=None)

                    # Only the test logits are used; the AUC is recomputed per group below.
                    (_overall_auc, _unused, logits_list,
                     _label_list, _metadata_list) = train_classifier(
                        train_dataloader=train_dataloader,
                        val_dataloader=val_dataloader,
                        test_dataloader=test_dataloader,
                        n_epochs=20,
                        classifier_type="resnet", verbose=True)

                    train_mask = valid_idx & subset_idx_train
                    test_mask = valid_idx & subset_idx_test
                    # assumes the logits come back in the row order of test_mask — TODO confirm
                    test_logits = np.vstack(logits_list)

                    # Split sizes do not depend on the group, so compute them once.
                    num_train_pos = y_pos[train_mask].sum()
                    num_train_neg = (1 - y_pos[train_mask]).sum()
                    num_test_pos = y_pos[test_mask].sum()
                    num_test_neg = (1 - y_pos[test_mask]).sum()

                    for group_key in group_keys:
                        group_idx = valid_idx if group_key == "all" else group_info[group_key]
                        in_group = group_idx[test_mask]

                        auc = _roc_auc_or_nan(y_true=y_pos[test_mask][in_group],
                                              y_score=test_logits[in_group])

                        record_dict_list_concept.append(
                            {"manual_automatic": manual_automatic,
                             "group_key": group_key,
                             "concept_name": concept_name,
                             "random_seed": random_seed,
                             "num_train": num_train_pos + num_train_neg,
                             "num_train_pos": num_train_pos,
                             "num_train_neg": num_train_neg,
                             "num_test": num_test_pos + num_test_neg,
                             "num_test_pos": num_test_pos,
                             "num_test_neg": num_test_neg,
                             "auc": auc,
                             }
                        )
                        print(record_dict_list_concept[-1])

            elif manual_automatic == "automatic":
                model_inputs = (
                    ("MONET", image_features_norm, prompt_info),
                    ("vanilla", image_features_norm_vanilla, prompt_info_vanilla),
                )
                for trained, features, prompts in model_inputs:
                    concept_prompts = prompts[prompt_prefix + concept_name]
                    similarity_score = calculate_similaity_score(
                        image_features_norm=features,
                        prompt_target_embedding_norm=concept_prompts["prompt_target_embedding_norm"],
                        prompt_ref_embedding_norm=concept_prompts["prompt_ref_embedding_norm"],
                        temp=similarity_temp,
                        normalize=True)

                    for group_key in group_keys:
                        group_idx = valid_idx if group_key == "all" else group_info[group_key]
                        group_mask = valid_idx & group_idx

                        y_true = y_pos[group_mask]
                        y_score = similarity_score[group_mask]

                        auc = _roc_auc_or_nan(y_true=y_true, y_score=y_score)

                        record_dict_list_concept.append(
                            {"manual_automatic": manual_automatic,
                             "group_key": group_key,
                             "trained": trained,
                             "concept_name": concept_name,
                             "num_test": (y_true).sum() + (~y_true).sum(),
                             "num_test_pos": (y_true).sum(),
                             "num_test_neg": (~y_true).sum(),
                             "auc": auc,
                             }
                        )
                        print(record_dict_list_concept[-1])
    return record_dict_list_concept

In [None]:
def get_concept_embedding(dataset_name, concept_list, clip_model):
    prompt_info={}
    
    for concept_name in concept_list:
        if "clinical_fd_clean" in dataset_name:
            if concept_name=='skintone_light':
                prompt_template_list=["This is skin image of {}", "This is dermatology image of {}", "This is image of {}"]
                prompt_target=[[prompt_template.format(term) for term in ["white skin", "light skin tone"]] for prompt_template in prompt_template_list]
                prompt_ref = [["This is skin image"], ["This is dermatology image"], ["This is image"]]                
            elif concept_name=='skintone_dark':
                prompt_template_list=["This is skin image of {}", "This is dermatology image of {}", "This is image of {}"]
#                 prompt_target=[[prompt_template.format(term) for term in ["brown skin", "black skin", "dark skin tone"]] for prompt_template in prompt_template_list]
                prompt_target=[[prompt_template.format(term) for term in ["dark"]] for prompt_template in prompt_template_list]
                prompt_ref=[[prompt_template.format(term) for term in ["light"]] for prompt_template in prompt_template_list]
#                 prompt_target=[[prompt_template.format(term) for term in ["fitzpatrick type 5", "fitzpatrick type 6"]] for prompt_template in prompt_template_list]
#                 prompt_ref = [["This is skin image"], ["This is dermatology image"], ["This is image"]]                
                prompt_target=[[prompt_template.format(term) for term in ["brown skin", 
                                                                          "black skin", 
                                                                       "dark skin tone", 
                                                                       "Fitzpatrick V skin tone",
                                                                       "Fitzpatrick VI skin tone",
                                                                      ]] for prompt_template in prompt_template_list]
    
                prompt_ref=[[prompt_template.format(term) for term in ["white skin", 
                                                                       "light skin tone", 
                                                                       "Fitzpatrick I skin tone",
                                                                       "Fitzpatrick II skin tone",
                                                                       "Fitzpatrick III skin tone",
                                                                       "Fitzpatrick IV skin tone",
                                                                      ]] for prompt_template in prompt_template_list]
    
                
            elif concept_name.startswith("cbm_"):  
                cbm_concept_name=concept_name[4:]
                prompt_template_list=["This is skin image of {}", "This is dermatology image of {}", "This is image of {}"]
                prompt_target=[[prompt_template.format(term) for term in [cbm_concept_name]] for prompt_template in prompt_template_list]
                prompt_ref = [["This is skin image"], ["This is dermatology image"], ["This is image"]]
            elif concept_name.startswith("skincon_"):
                prompt_dict, text_counter = concept_to_prompt(concept_name[8:])
                prompt_engineered_list = []
                for k, v in prompt_dict.items():
                    if k != "original":
                        prompt_engineered_list += v    
                concept_term_list = list(set([prompt.replace("This is ", "").replace("This photo is ", "").replace("This lesion is ", "").replace("skin has become ", "").lower()
                                          for prompt in prompt_engineered_list]))
                prompt_template_list=["This is skin image of {}", "This is dermatology image of {}", "This is image of {}"]
                #prompt_target=[prompt_template.format(term) for prompt_template in prompt_template_list for term in concept_term_list]            
                prompt_target=[[prompt_template.format(term) for term in concept_term_list] for prompt_template in prompt_template_list]

                prompt_ref = [["This is skin image"], ["This is dermatology image"], ["This is image"]]          
            elif concept_name.startswith("ddidisease_"):
                ddi_concept_name=concept_name[11:]
                prompt_template_list=["This is skin image of {}", "This is dermatology image of {}", "This is image of {}"]
                prompt_target=[[prompt_template.format(term) for term in [ddi_concept_name.replace('-',' ')]] for prompt_template in prompt_template_list]
                prompt_ref = [["This is skin image"], ["This is dermatology image"], ["This is image"]]                
                
            elif concept_name.startswith("fitzdisease_"):
                fitz_concept_name=concept_name[12:]
                prompt_template_list=["This is skin image of {}", "This is dermatology image of {}", "This is image of {}"]
                prompt_target=[[prompt_template.format(term) for term in [fitz_concept_name]] for prompt_template in prompt_template_list]
                prompt_ref = [["This is skin image"], ["This is dermatology image"], ["This is image"]]                                
            
        
        elif "isic" in dataset_name:
            if concept_name.startswith("skincon_"):
                prompt_dict, text_counter = concept_to_prompt(concept_name[8:])
                prompt_engineered_list = []
                for k, v in prompt_dict.items():
                    if k != "original":
                        prompt_engineered_list += v

                concept_term_list = list(set([prompt.replace("This is ", "").replace("This photo is ", "").replace("This lesion is ", "").replace("skin has become ", "").lower()
                                          for prompt in prompt_engineered_list]))
                prompt_template_list=["This is dermatoscopy of {}", "This is dermoscopy of {}"]
                prompt_target=[prompt_template.format(term) for prompt_template in prompt_template_list for term in concept_term_list]
                prompt_ref = ["This is dermatoscopy", "This is dermoscopy"]
                prompt_target=[[prompt_template.format(term) for term in concept_term_list] for prompt_template in prompt_template_list]
                prompt_ref = [["This is dermatoscopy"], ["This is dermoscopy"]] 
                
                prompt_target=[[prompt_template.format(term) for prompt_template in prompt_template_list for term in concept_term_list]]
                prompt_ref = [["This is dermatoscopy", "This is dermoscopy"]] 
                
                
            elif concept_name.startswith("derm7ptconcept_"):
                derm7ptconcept=concept_name[15:]
                if derm7ptconcept=="pigment network":
                    concept_term_list=["pigment network", "brown lines forming a grid-like reticular pattern"]
                    concept_term_list=["pigment network", "intersecting brown lines"]
                elif derm7ptconcept=="typical pigment network":
                    concept_term_list=["typical pigment network", "regularly meshed pigment network",]
                elif derm7ptconcept=="atypical pigment network":
#                     concept_term_list=["pigment network", "atypical pigment network", "irregularly meshed pigment network"]
                    #concept_term_list=["atypical pigment network", "irregularly meshed pigment network", "branched streaks"]
                    concept_term_list=["atypical pigment network", "irregularly meshed pigment network"]
                elif derm7ptconcept=="regression structure":
                    concept_term_list=["regression structure"]
                elif derm7ptconcept=="pigmentation":
#                     concept_term_list=["pigmented", "pigmented lesion"]
                    concept_term_list=["pigmented", "pigmented lesion", "colored lesion"]    
                elif derm7ptconcept=="regular pigmentation":
                    concept_term_list=["regular pigmentation", "uniform and consistent coloration"]
                elif derm7ptconcept=="irregular pigmentation":
                    concept_term_list=["irregular pigmentation"]
                elif derm7ptconcept=="blue whitish veil":
                    concept_term_list=["blue whitish veil","blue white veil"]
                elif derm7ptconcept=="vascular structures":
                    concept_term_list=["vascular structures"]
                    concept_term_list=["vascular structures", "Hairpin vessels", "Comma vessels", "dotted vessels", "arborizing vessels"]
                elif derm7ptconcept=="typical vascular structures":
                    concept_term_list=["typical vascular structures"]
                elif derm7ptconcept=="atypical vascular structures":
                    concept_term_list=["atypical vascular structures"]
                elif derm7ptconcept=="streaks":
                    concept_term_list=["streaks", "starburst", "linear patterns", "radially oriented linear projections", "regular, pigmented linear extensions", "irregular, pigmented linear extensions"]
                elif derm7ptconcept=="regular streaks":
                    concept_term_list=["regular streaks", "uniformly spaced linear patterns"]
                elif derm7ptconcept=="irregular streaks":
                    concept_term_list=["irregular streaks"]
                elif derm7ptconcept=="dots and globules":
                    #concept_term_list=["dots and globules", "tiny, pinpoint pigmented specks", "Small, darkly pigmented dots"]
                    concept_term_list=["tiny dots", "globules", "dot clusters", "globule clusters"]
                    concept_term_list=["dots and globules", "scattered globules"]#, "dots and globules clusters"] 0.57
                    concept_term_list=["black dots and globules", "brown dots and globules", "scattered globules"] #0.
                elif derm7ptconcept=="regular dots and globules":
                    concept_term_list=["regular dots and globules"]
                elif derm7ptconcept=="irregular dots and globules":
                    concept_term_list=["irregular dots and globules"]
                else:
                    raise ValueError(derm7ptconcept)         
                    
                prompt_template_list=["This is dermatoscopy of {}", "This is dermoscopy of {}"]
                prompt_target=[[prompt_template.format(term) for prompt_template in prompt_template_list for term in concept_term_list]]
                prompt_ref = [["This is dermatoscopy", "This is dermoscopy"]]                     
                
            elif concept_name.startswith("isicconcept_"):
                isicconcept=concept_name[12:]
                if isicconcept=="pigment_network":
                    concept_term_list=["pigment network"]
                elif isicconcept=="negative_network":
                    concept_term_list=["negative network"]
                elif isicconcept=="milia_like_cyst":
                    concept_term_list=["milia like cyst"]
                    concept_term_list=["seborrheic keratosis"]
                elif isicconcept=="streaks":
                    concept_term_list=["streaks", "starburst", "linear patterns", "radially oriented linear projections", "regular, pigmented linear extensions", "irregular, pigmented linear extensions"]
                elif isicconcept=="globules":
                    concept_term_list=["globules"]
                else:
                    raise ValueError(isicconcept)                
            
                prompt_template_list=["This is dermatoscopy of {}", "This is dermoscopy of {}"]
                prompt_target=[[prompt_template.format(term) for prompt_template in prompt_template_list for term in concept_term_list]]
                prompt_ref = [["This is dermatoscopy", "This is dermoscopy"]]                 
                
                
            elif concept_name.startswith("cbm_"):  
                cbm_concept_name=concept_name[4:]
                prompt_target=[[f"This is dermatoscopy of {cbm_concept_name}"],
                               [f"This is dermoscopy of {cbm_concept_name}"]] 
                prompt_ref = [["This is skin image"], ["This is dermatology image"], ["This is image"]]                  
                
                
            elif concept_name.startswith("isicdisease_"):  
                if concept_name=="isicdisease_AIMP":
                    disease_name=concept_name[12:]
                    prompt_target=[["This is dermatoscopy of AIMP",
                                    "This is dermatoscopy of Atypical intraepidermal melanocytic proliferation"],
                                   ["This is dermoscopy of AIMP",
                                    "This is dermoscopy of Atypical intraepidermal melanocytic proliferation"]]
                    prompt_ref = [["This is dermatoscopy"], ["This is dermoscopy"]]
                else:
                    disease_name=concept_name[12:]
                    prompt_target=[[f"This is dermatoscopy of {disease_name}"],
                                   [f"This is dermoscopy of {disease_name}"]] 
                    prompt_ref = [["This is dermatoscopy"], ["This is dermoscopy"]]
                
            else:
                if concept_name=="gel":
                    #concept_term_list=["water drop", 'gel', "fluid"]
                    prompt_target=[["This is dermatoscopy of water drop", "This is dermatoscopy of gel", "This is dermatoscopy of dermoscopy liquid"],
                                   ["This is dermoscopy of water drop", "This is dermoscopy of gel", "This is dermoscopy of dermoscopy liquid"],
                                  ]
                    prompt_target=[["This is dermatoscopy of gel"],
                                   ["This is dermoscopy of gel"],
                                  ]                    
                    
                    prompt_ref = [["This is dermatoscopy"], 
                                  ["This is dermoscopy"]]
                elif concept_name=="dermoscope border":
                    concept_term_list=["dermoscope"]
                    prompt_target=["This is hole"]
                    prompt_target=["This is scope hole", "This is circle", "This is dermoscope"]
                    #prompt_target=[["This is dermatoscopy of dermoscope", "This is dermatoscopy of dermoscopy"]]
                    prompt_target=[["This is dermatoscopy of dermoscopy"]]
                    prompt_ref = [["This is dermatoscopy"]]
                    
                else:
                    concept_term_list=[concept_name]
                    prompt_template_list=["This is dermatoscopy of {}", "This is dermoscopy of {}"]
                    prompt_target=[prompt_template.format(term) for prompt_template in prompt_template_list for term in concept_term_list]
                    
                    prompt_ref = [["This is dermatoscopy"], ["This is dermoscopy"]]
                    
        elif "derm7pt_derm" in dataset_name:
            if concept_name.startswith("skincon_"):
                prompt_dict, text_counter = concept_to_prompt(concept_name[8:])
                prompt_engineered_list = []
                for k, v in prompt_dict.items():
                    if k != "original":
                        prompt_engineered_list += v

                concept_term_list = list(set([prompt.replace("This is ", "").replace("This photo is ", "").replace("This lesion is ", "").replace("skin has become ", "").lower()
                                          for prompt in prompt_engineered_list]))
                prompt_template_list=["This is dermatoscopy of {}", "This is dermoscopy of {}"]
                prompt_target=[prompt_template.format(term) for prompt_template in prompt_template_list for term in concept_term_list]
                prompt_ref = ["This is dermatoscopy", "This is dermoscopy"]
                prompt_target=[[prompt_template.format(term) for term in concept_term_list] for prompt_template in prompt_template_list]
                prompt_ref = [["This is dermatoscopy"], ["This is dermoscopy"]] 
                
                prompt_target=[[prompt_template.format(term) for prompt_template in prompt_template_list for term in concept_term_list]]
                prompt_ref = [["This is dermatoscopy", "This is dermoscopy"]] 
                
            elif concept_name.startswith("derm7ptconcept_"):
                derm7ptconcept=concept_name[15:]
                if derm7ptconcept=="pigment network":
                    concept_term_list=["pigment network", "brown lines forming a grid-like reticular pattern"]
                    concept_term_list=["pigment network", "intersecting brown lines"]
                elif derm7ptconcept=="typical pigment network":
                    concept_term_list=["typical pigment network", "regularly meshed pigment network",]
                elif derm7ptconcept=="atypical pigment network":
#                     concept_term_list=["pigment network", "atypical pigment network", "irregularly meshed pigment network"]
                    #concept_term_list=["atypical pigment network", "irregularly meshed pigment network", "branched streaks"]
                    concept_term_list=["atypical pigment network", "irregularly meshed pigment network"]
                elif derm7ptconcept=="regression structure":
                    concept_term_list=["regression structure"]
                elif derm7ptconcept=="pigmentation":
#                     concept_term_list=["pigmented", "pigmented lesion"]
                    concept_term_list=["pigmented", "pigmented lesion", "colored lesion"]    
                elif derm7ptconcept=="regular pigmentation":
                    concept_term_list=["regular pigmentation", "uniform and consistent coloration"]
                elif derm7ptconcept=="irregular pigmentation":
                    concept_term_list=["irregular pigmentation"]
                elif derm7ptconcept=="blue whitish veil":
                    concept_term_list=["blue whitish veil","blue white veil"]
                elif derm7ptconcept=="vascular structures":
                    concept_term_list=["vascular structures"]
                    concept_term_list=["vascular structures", "Hairpin vessels", "Comma vessels", "dotted vessels", "arborizing vessels"]
                elif derm7ptconcept=="typical vascular structures":
                    concept_term_list=["typical vascular structures"]
                elif derm7ptconcept=="atypical vascular structures":
                    concept_term_list=["atypical vascular structures"]
                elif derm7ptconcept=="streaks":
                    concept_term_list=["streaks", "starburst", "linear patterns", "radially oriented linear projections", "regular, pigmented linear extensions", "irregular, pigmented linear extensions"]
                elif derm7ptconcept=="regular streaks":
                    concept_term_list=["regular streaks", "uniformly spaced linear patterns"]
                elif derm7ptconcept=="irregular streaks":
                    concept_term_list=["irregular streaks"]
                elif derm7ptconcept=="dots and globules":
                    #concept_term_list=["dots and globules", "tiny, pinpoint pigmented specks", "Small, darkly pigmented dots"]
                    concept_term_list=["tiny dots", "globules", "dot clusters", "globule clusters"]
                    concept_term_list=["dots and globules", "scattered globules"]#, "dots and globules clusters"] 0.57
                    concept_term_list=["black dots and globules", "brown dots and globules", "scattered globules"] #0.
                elif derm7ptconcept=="regular dots and globules":
                    concept_term_list=["regular dots and globules"]
                elif derm7ptconcept=="irregular dots and globules":
                    concept_term_list=["irregular dots and globules"]
                else:
                    raise ValueError(derm7ptconcept)
    
                prompt_template_list=["This is dermatoscopy of {}", "This is dermoscopy of {}"]
                prompt_target=[[prompt_template.format(term) for prompt_template in prompt_template_list for term in concept_term_list]]
                prompt_ref = [["This is dermatoscopy", "This is dermoscopy"]] 
                
            elif concept_name.startswith("isicconcept_"):
                isicconcept=concept_name[12:]
                if isicconcept=="pigment_network":
                    concept_term_list=["pigment network"]
                elif isicconcept=="negative_network":
                    concept_term_list=["negative network"]
                elif isicconcept=="milia_like_cyst":
                    concept_term_list=["milia like cyst"]
                    concept_term_list=["seborrheic keratosis"]
                elif isicconcept=="streaks":
                    concept_term_list=["streaks", "starburst", "linear patterns", "radially oriented linear projections", "regular, pigmented linear extensions", "irregular, pigmented linear extensions"]
                elif isicconcept=="globules":
                    concept_term_list=["globules"]
                else:
                    raise ValueError(isicconcept)                
            
                prompt_template_list=["This is dermatoscopy of {}", "This is dermoscopy of {}"]
                prompt_target=[[prompt_template.format(term) for prompt_template in prompt_template_list for term in concept_term_list]]
                prompt_ref = [["This is dermatoscopy", "This is dermoscopy"]]                  
                
            elif concept_name.startswith("cbm_"):  
                cbm_concept_name=concept_name[4:]
                prompt_target=[[f"This is dermatoscopy of {cbm_concept_name}"],
                               [f"This is dermoscopy of {cbm_concept_name}"]] 
                prompt_ref = [["This is skin image"], ["This is dermatology image"], ["This is image"]]                  
                
            elif concept_name.startswith("derm7ptdisease_"):  
                disease_name=concept_name[15:]
                prompt_target=[[f"This is dermatoscopy of {disease_name}"],
                               [f"This is dermoscopy of {disease_name}"]] 
                prompt_ref = [["This is dermatoscopy"], ["This is dermoscopy"]]
                
            else:
                if concept_name=="gel":
                    #concept_term_list=["water drop", 'gel', "fluid"]
                    prompt_target=[["This is dermatoscopy of water drop", "This is dermatoscopy of gel", "This is dermatoscopy of dermoscopy liquid"],
                                   ["This is dermoscopy of water drop", "This is dermoscopy of gel", "This is dermoscopy of dermoscopy liquid"],
                                  ]
                    prompt_target=[["This is dermatoscopy of gel"],
                                   ["This is dermoscopy of gel"],
                                  ]                    
                    
                    prompt_ref = [["This is dermatoscopy"], 
                                  ["This is dermoscopy"]]
                elif concept_name=="dermoscope border":
                    concept_term_list=["dermoscope"]
                    prompt_target=["This is hole"]
                    prompt_target=["This is scope hole", "This is circle", "This is dermoscope"]
                    #prompt_target=[["This is dermatoscopy of dermoscope", "This is dermatoscopy of dermoscopy"]]
                    prompt_target=[["This is dermatoscopy of dermoscopy"]]
                    prompt_ref = [["This is dermatoscopy"]]
                    
                else:
                    concept_term_list=[concept_name]
                    prompt_template_list=["This is dermatoscopy of {}", "This is dermoscopy of {}"]
                    prompt_target=[prompt_template.format(term) for prompt_template in prompt_template_list for term in concept_term_list]
                    
                    prompt_ref = [["This is dermatoscopy"], ["This is dermoscopy"]]                    
                
        
        #print(prompt_target, prompt_ref)
        # target embedding
        prompt_target_tokenized=[clip.tokenize(prompt_list, truncate=True) for prompt_list in prompt_target]
        with torch.no_grad():
            prompt_target_embedding = torch.stack([clip_model.model_step_with_text({"text": prompt_tokenized.to(model_device)})[
                    "text_features"].detach().cpu() for prompt_tokenized in prompt_target_tokenized])
        prompt_target_embedding_norm=prompt_target_embedding/prompt_target_embedding.norm(dim=2, keepdim=True)          

        # reference embedding
        prompt_ref_tokenized=[clip.tokenize(prompt_list, truncate=True) for prompt_list in prompt_ref]
        with torch.no_grad():
            prompt_ref_embedding = torch.stack([clip_model.model_step_with_text({"text": prompt_tokenized.to(model_device)})[
                    "text_features"].detach().cpu() for prompt_tokenized in prompt_ref_tokenized])
        prompt_ref_embedding_norm=prompt_ref_embedding/prompt_ref_embedding.norm(dim=2, keepdim=True)                

        prompt_info[concept_name]={"prompt_ref_embedding_norm":prompt_ref_embedding_norm,
                                   "prompt_target_embedding_norm":prompt_target_embedding_norm,
                                  }
        print(dataset_name, concept_name, "\n" , prompt_target, prompt_ref)
        print('-----------')
        #print(prompt_target_embedding_norm)
        #print(prompt_ref_embedding_norm)
        del prompt_ref
        del prompt_target
    
    return {"prompt_info": prompt_info}

# For each listed dataset, compute the concept prompt embeddings twice -- once with
# `model_vanilla` and once with `model` -- and cache them in `variable_dict`
# under "prompt_info_vanilla" and "prompt_info" respectively.
for dataset_name in ("clinical_fd_clean_nodup_nooverlap",):
    variable_dict[dataset_name].update(
        prompt_info_vanilla=get_concept_embedding(
            dataset_name,
            concept_list=variable_dict[dataset_name]["concept_list"],
            clip_model=model_vanilla,
        )["prompt_info"]
    )
    variable_dict[dataset_name].update(
        prompt_info=get_concept_embedding(
            dataset_name,
            concept_list=variable_dict[dataset_name]["concept_list"],
            clip_model=model,
        )["prompt_info"]
    )
        

In [None]:
# Evaluate the "skintone_dark" concept on the clinical dataset with both the
# "automatic" and "manual" methods over 20 random seeds.
dataset_name="clinical_fd_clean_nodup_nooverlap"

concept_list=["skintone_dark"]
# Exclude images whose "skincon_Do not consider this image" flag is set to 1.
valid_idx=(variable_dict[dataset_name]["metadata_all"]["skincon_Do not consider this image"]!=1)
print(valid_idx.sum())
# Query the concept labels once and reuse the result for both masks.
# NOTE(review): assumes get_concept_bool_from_metadata is deterministic for fixed arguments -- confirm.
skintone_dark_bool=get_concept_bool_from_metadata(dataset_name=dataset_name,
                                 metadata_all=variable_dict[dataset_name]["metadata_all"],
                                 concept_name="skintone_dark")
# Keep only images labelled either positive or negative for the concept.
valid_idx=valid_idx&((skintone_dark_bool["concept_bool_true"]|\
                      skintone_dark_bool["concept_bool_false"]).values)
print(valid_idx.sum())
prompt_prefix=""

# No subgroup breakdown for this evaluation.
group_info={}

record_dict_list_concept=evaluate_concept(
                 image_features_norm=variable_dict[dataset_name]["image_features_norm"],
                 image_features_norm_vanilla=variable_dict[dataset_name]["image_features_vanilla_norm"],
                 prompt_info=variable_dict[dataset_name]["prompt_info"],
                 prompt_info_vanilla=variable_dict[dataset_name]["prompt_info_vanilla"],
                 prompt_prefix=prompt_prefix,
                 group_info=group_info,

                 dataset_name=dataset_name,
                 dataloader=variable_dict[dataset_name]["dataloader"],
                 metadata_all=variable_dict[dataset_name]["metadata_all"],
                 concept_list=concept_list,
                 valid_idx=valid_idx,
                 random_seed_range=list(range(20)),
                 method_list=["automatic", "manual"])

In [None]:
# Cache the latest evaluation records under "skintone_eval" for the current dataset.
variable_dict[dataset_name]["skintone_eval"]=record_dict_list_concept

In [None]:
# Display which dataset key is currently selected.
dataset_name

In [None]:
# Inspect the first evaluation record to see which fields each record carries.
record_dict_list_concept[0]

In [None]:
# Tabulate all evaluation records as a DataFrame for inspection.
pd.DataFrame(record_dict_list_concept)

In [None]:
# Arithmetic mean of five hard-coded values.
# NOTE(review): the numbers look hand-copied from earlier output -- confirm their source
# or compute the mean from the records instead.
(0.895+0.843+0.876+0.894+0.860)/5


In [None]:
# Display the element-wise OR of the "concept_bool_true" and "concept_bool_false"
# results for "skintone_dark" (the same expression is used to build valid_idx above).
(get_concept_bool_from_metadata(dataset_name=dataset_name, 
                                 metadata_all=variable_dict[dataset_name]["metadata_all"], 
                                 concept_name="skintone_dark")["concept_bool_true"]|\
get_concept_bool_from_metadata(dataset_name=dataset_name, 
                                 metadata_all=variable_dict[dataset_name]["metadata_all"], 
                                 concept_name="skintone_dark")["concept_bool_false"])

In [None]:
# Display the metadata for the current dataset. The stray trailing comma was removed:
# it wrapped the object in a 1-tuple, which is shown as a plain-text tuple repr
# instead of the object's own rich display.
variable_dict[dataset_name]["metadata_all"]

In [None]:
# Select the clinical dataset for the inspection cells that follow.
dataset_name="clinical_fd_clean_nodup_nooverlap"

In [None]:
# Display the concept list stored for the current dataset.
variable_dict[dataset_name]["concept_list"]

In [None]:
# Boolean mask of rows whose "skincon_Do not consider this image" value is not 1.
(variable_dict[dataset_name]["metadata_all"]["skincon_Do not consider this image"]!=1)

In [None]:
# Boolean mask of rows where "skincon_Cyst" is 0 or 1; missing values are mapped
# to the sentinel -9 first so they evaluate to False.
variable_dict[dataset_name]["metadata_all"]["skincon_Cyst"].fillna(-9).isin([0,1])

In [None]:
# Display the validity mask currently held in the kernel.
valid_idx

In [None]:
# Evaluate every SkinCon concept on the clinical dataset, additionally broken down
# by three skin-tone subgroups ("12", "34", "56").
dataset_name="clinical_fd_clean_nodup_nooverlap"

concept_list=skincon_cols
# Exclude images whose "skincon_Do not consider this image" flag is set to 1.
valid_idx=(variable_dict[dataset_name]["metadata_all"]["skincon_Do not consider this image"]!=1)
print(valid_idx.sum())
# Keep only rows where "skincon_Cyst" is 0 or 1.
# NOTE(review): presumably used as a proxy for "row has SkinCon annotations" -- confirm.
valid_idx=valid_idx&(variable_dict[dataset_name]["metadata_all"]["skincon_Cyst"].fillna(-9).isin([0,1]))
print(valid_idx.sum())
prompt_prefix=""

# For each subgroup: the `fitzpatrick_scale` values and the `skin_tone` codes that select it.
# A lookup table replaces the if/elif chain whose fallback branch was a bare
# `ValueError` expression (it never raised, so an unknown split would have gone unnoticed).
tone_split_codes={
    "12": ([1,2], [12]),
    "34": ([3,4], [34]),
    "56": ([5,6], [56]),
}
# Missing values become the sentinel -9 so the columns can be cast to int and match no subgroup.
fitzpatrick_scale_int=variable_dict[dataset_name]["metadata_all"]["fitzpatrick_scale"].fillna(-9).astype(int)
skin_tone_int=variable_dict[dataset_name]["metadata_all"]["skin_tone"].fillna(-9).astype(int)

group_info={}
for tone_split, (fitzpatrick_values, skin_tone_values) in tone_split_codes.items():
    # A row belongs to the subgroup if either annotation column places it there.
    tone_idx=(fitzpatrick_scale_int.isin(fitzpatrick_values)|\
              skin_tone_int.isin(skin_tone_values)).values
    group_info[tone_split]=tone_idx

record_dict_list_concept=evaluate_concept(
                 image_features_norm=variable_dict[dataset_name]["image_features_norm"],
                 image_features_norm_vanilla=variable_dict[dataset_name]["image_features_vanilla_norm"],
                 prompt_info=variable_dict[dataset_name]["prompt_info"],
                 prompt_info_vanilla=variable_dict[dataset_name]["prompt_info_vanilla"],
                 prompt_prefix=prompt_prefix,
                 group_info=group_info,

                 dataset_name=dataset_name,
                 dataloader=variable_dict[dataset_name]["dataloader"],
                 metadata_all=variable_dict[dataset_name]["metadata_all"],
                 concept_list=concept_list,
                 valid_idx=valid_idx,
                 random_seed_range=range(20))

In [None]:
# Cache the latest evaluation records under "concept_eval" for the clinical dataset.
variable_dict["clinical_fd_clean_nodup_nooverlap"]["concept_eval"]=record_dict_list_concept

In [None]:
# Tabulate the evaluation records for inspection.
# NOTE(review): the original argument was the undefined placeholder `xxxx` (NameError on run);
# `record_dict_list_concept` is assumed to be the intended object -- confirm.
pd.DataFrame(record_dict_list_concept)

In [None]:
# Summarise per-concept AUC on the clinical dataset for every
# (subset_name, manual_automatic, trained, group_key) combination.
record_dict_list_concept_df=pd.DataFrame(
                            variable_dict["clinical_fd_clean_nodup_nooverlap"]["concept_eval"])
record_dict_list_concept_df["subset_name"]="temp"

min_test_pos=30    # a concept is kept if any "MONET" record has at least this many test positives
auc_threshold=0.7  # per-concept mean AUC at or above this value is counted by "count AUC"

concept_name_filtered=record_dict_list_concept_df[
    (record_dict_list_concept_df["trained"]=="MONET")&
    (record_dict_list_concept_df["num_test_pos"]>=min_test_pos)
]["concept_name"].unique().tolist()

print(len(concept_name_filtered))

# Build the filtered, grouped view once and reuse it for every statistic.
# fillna("null") keeps rows whose grouping keys are missing (groupby drops NaN keys by default).
auc_grouped=record_dict_list_concept_df[
    (record_dict_list_concept_df["concept_name"].isin(concept_name_filtered))&
    (~record_dict_list_concept_df["auc"].isnull())
].fillna("null").groupby(["subset_name", "manual_automatic", "trained", "group_key"])

# Each reducer receives the per-concept mean AUC of one group (averaged over that concept's records).
auc_summaries=[
    ("mean AUC", lambda per_concept_auc: per_concept_auc.mean()),
    ("std AUC", lambda per_concept_auc: per_concept_auc.std()),
    ("count AUC", lambda per_concept_auc: (per_concept_auc>=auc_threshold).sum()),
    # Number of concepts in the group (the denominator for "count AUC").
    ("total AUC", lambda per_concept_auc: len(per_concept_auc)),
]

for summary_title, summarize in auc_summaries:
    print(summary_title)
    print(auc_grouped.apply(lambda rows: summarize(rows.groupby("concept_name")["auc"].mean())))
    print('\n\n')

In [None]:
# Average test-set size per group: mean over concepts of each concept's mean "num_test",
# restricted to the filtered concepts and to records that have an AUC.
print(
    record_dict_list_concept_df[
        record_dict_list_concept_df["concept_name"].isin(concept_name_filtered)
        & ~record_dict_list_concept_df["auc"].isnull()
    ]
    .fillna("null")
    .groupby(["subset_name", "manual_automatic", "trained", "group_key"])
    .apply(lambda group_rows: group_rows.groupby("concept_name")["num_test"].mean().mean())
)

In [None]:
# Count how many rows carry each "fitzpatrick_scale" value in the clinical dataset's metadata.
variable_dict["clinical_fd_clean_nodup_nooverlap"]["dataloader"].dataset.metadata_all["fitzpatrick_scale"].value_counts()

In [None]:
# Evaluate seven derm7pt concepts on the "derm7pt_derm_nodup" dataset over 20 random seeds.
dataset_name="derm7pt_derm_nodup"
concept_list=[
    f"derm7ptconcept_{derm7pt_term}"
    for derm7pt_term in (
        "pigment network",
        "regression structure",
        "pigmentation",
        "blue whitish veil",
        "vascular structures",
        "streaks",
        "dots and globules",
    )
]
# Reuse the validity mask stored alongside the dataset.
valid_idx=variable_dict[dataset_name]["valid_idx"]
print(valid_idx.sum())

prompt_prefix=""

# No subgroup breakdown for this evaluation.
group_info={}

record_dict_list_concept=evaluate_concept(
    image_features_norm=variable_dict[dataset_name]["image_features_norm"],
    image_features_norm_vanilla=variable_dict[dataset_name]["image_features_vanilla_norm"],
    prompt_info=variable_dict[dataset_name]["prompt_info"],
    prompt_info_vanilla=variable_dict[dataset_name]["prompt_info_vanilla"],
    prompt_prefix=prompt_prefix,
    group_info=group_info,
    dataset_name=dataset_name,
    dataloader=variable_dict[dataset_name]["dataloader"],
    metadata_all=variable_dict[dataset_name]["metadata_all"],
    concept_list=concept_list,
    valid_idx=valid_idx,
    random_seed_range=range(20),
)

In [None]:
# Inspect the first evaluation record to see which fields each record carries.
record_dict_list_concept[0]

In [None]:
# Cache the latest evaluation records under "concept_eval" for the derm7pt dataset.
variable_dict["derm7pt_derm_nodup"]["concept_eval"]=record_dict_list_concept

In [None]:
# Display the cached derm7pt evaluation records.
variable_dict["derm7pt_derm_nodup"]["concept_eval"]

In [None]:
# Summarise per-concept AUC of the current evaluation records for every
# (subset_name, manual_automatic, trained, group_key) combination.
record_dict_list_concept_df=pd.DataFrame(
                            record_dict_list_concept)
record_dict_list_concept_df["subset_name"]="temp"

min_test_pos=30    # a concept is kept if any "MONET" record has at least this many test positives
auc_threshold=0.7  # per-concept mean AUC at or above this value is counted by "count AUC"

concept_name_filtered=record_dict_list_concept_df[
    (record_dict_list_concept_df["trained"]=="MONET")&
    (record_dict_list_concept_df["num_test_pos"]>=min_test_pos)
]["concept_name"].unique().tolist()

print(len(concept_name_filtered))

# Build the filtered, grouped view once and reuse it for every statistic.
# fillna("null") keeps rows whose grouping keys are missing (groupby drops NaN keys by default).
auc_grouped=record_dict_list_concept_df[
    (record_dict_list_concept_df["concept_name"].isin(concept_name_filtered))&
    (~record_dict_list_concept_df["auc"].isnull())
].fillna("null").groupby(["subset_name", "manual_automatic", "trained", "group_key"])

# Each reducer receives the per-concept mean AUC of one group (averaged over that concept's records).
auc_summaries=[
    ("mean AUC", lambda per_concept_auc: per_concept_auc.mean()),
    ("std AUC", lambda per_concept_auc: per_concept_auc.std()),
    ("count AUC", lambda per_concept_auc: (per_concept_auc>=auc_threshold).sum()),
    # Number of concepts in the group (the denominator for "count AUC").
    ("total AUC", lambda per_concept_auc: len(per_concept_auc)),
]

for summary_title, summarize in auc_summaries:
    print(summary_title)
    print(auc_grouped.apply(lambda rows: summarize(rows.groupby("concept_name")["auc"].mean())))
    print('\n\n')

# The bare text ".783±0.085" at the end of this cell was a SyntaxError ('±' is not valid Python)
# and stopped the whole cell from running; it is preserved here as a comment.
# NOTE(review): presumably a previously recorded mean ± std AUC -- confirm: .783±0.085

In [None]:
# IPython introspection: show the source of evaluate_concept (development aid, not valid plain Python).
evaluate_concept??

In [None]:
# IPython introspection: show the source of get_concept_bool_from_metadata (development aid, not valid plain Python).
get_concept_bool_from_metadata??

In [None]:
# Display the concept list stored for the derm7pt dataset.
variable_dict["derm7pt_derm_nodup"]["concept_list"]

In [None]:
# Rebuild the results DataFrame directly from the raw evaluation records.
# The commented-out tail below is an earlier inline aggregation attempt.
record_dict_list_concept_df=pd.DataFrame(record_dict_list_concept)#.fillna("null")
#.groupby(["trained", "group_key"]).mean()

In [None]:
# Mean of every numeric column per (trained, group_key), over records that have an AUC.
# fillna("null") keeps rows whose grouping keys are missing (groupby drops NaN keys by default).
# numeric_only=True is required because the frame contains string columns, on which
# GroupBy.mean raises TypeError in pandas >= 2.0.
record_dict_list_concept_df[~record_dict_list_concept_df["auc"].isnull()]\
.fillna("null")\
.groupby(["trained", "group_key"]).mean(numeric_only=True)

In [None]:
# Persist the evaluation records to "<dataset_name>.pt" in the current working directory.
torch.save(record_dict_list_concept, f"{dataset_name}.pt")

In [None]:
# Display which dataset key is currently selected.
dataset_name

In [None]:
# Shell escape: run the external `gpustat` command (requires it to be installed on the host).
!gpustat

In [None]:
# Scatter the "fitzpatrick_scale" column against the "fitzpatrick_centaur" column, one point per row.
# Axis labels make the figure readable on its own; the trailing ';' suppresses the artist repr.
plt.scatter(variable_dict[dataset_name]["metadata_all"]["fitzpatrick_scale"].values,
        variable_dict[dataset_name]["metadata_all"]["fitzpatrick_centaur"])
plt.xlabel("fitzpatrick_scale")
plt.ylabel("fitzpatrick_centaur");

In [None]:
# Display the raw "fitzpatrick_scale" values. The stray trailing comma was removed:
# it wrapped the array in a 1-tuple before display.
variable_dict[dataset_name]["metadata_all"]["fitzpatrick_scale"].values

In [None]:
# Display the raw "fitzpatrick_centaur" values. The stray trailing comma was removed:
# it wrapped the array in a 1-tuple before display.
variable_dict[dataset_name]["metadata_all"]["fitzpatrick_centaur"].values

In [None]:
# Count rows per "fitzpatrick_centaur" value in the current dataset.
variable_dict[dataset_name]["metadata_all"]["fitzpatrick_centaur"].value_counts()

In [None]:
# Count rows per "fitzpatrick_scale" value in the current dataset.
variable_dict[dataset_name]["metadata_all"]["fitzpatrick_scale"].value_counts()

In [None]:
# Count rows per "fitzpatrick_scale" value in the "fitzpatrick17k_skincon" dataset.
variable_dict["fitzpatrick17k_skincon"]["metadata_all"]["fitzpatrick_scale"].value_counts()

In [None]:
# Download the Fitzpatrick17k CSV from its GitHub repository (requires network access)
# to cross-check the label distributions against the locally loaded metadata.
df_temp=pd.read_csv("https://raw.githubusercontent.com/mattgroh/fitzpatrick17k/main/fitzpatrick17k.csv")

In [None]:
# Count rows per "fitzpatrick_centaur" value in the downloaded CSV.
df_temp["fitzpatrick_centaur"].value_counts()

In [None]:
# Count rows per "fitzpatrick_scale" value in the downloaded CSV.
df_temp["fitzpatrick_scale"].value_counts()

In [None]:
# List the dataset keys currently loaded in variable_dict.
variable_dict.keys()

In [None]:
# Count rows per "skin_tone" value in the current dataset.
variable_dict[dataset_name]["metadata_all"]["skin_tone"].value_counts()

In [None]:
# List the dataset keys currently loaded in variable_dict.
variable_dict.keys()

In [None]:
# Sum of the `concept_name` column in the "fitzpatrick17k_clean_threelabel" metadata.
# NOTE(review): `concept_name` is not set in this cell; it relies on whatever value an
# earlier cell or loop left in the kernel -- confirm which concept is intended.
variable_dict["fitzpatrick17k_clean_threelabel"]["dataloader"].dataset.metadata_all[concept_name].sum()

In [None]:
# Collect, for every column in skincon_cols, the de-duplicated lower-cased concept terms
# derived from its engineered prompts; the result feeds a LaTeX table.
prompt_info_latex=[]
for concept_name in skincon_cols:
    # concept_name[8:] drops the first 8 characters.
    # NOTE(review): presumably the "skincon_" prefix, as in get_concept_embedding -- confirm.
    prompt_dict, _text_counter = concept_to_prompt(concept_name[8:])
    # Gather every prompt except those stored under the "original" key.
    prompt_engineered_list = [
        prompt
        for prompt_key, prompt_group in prompt_dict.items()
        if prompt_key != "original"
        for prompt in prompt_group
    ]
    # Strip the sentence scaffolding so only the term remains. Sorting the unique terms makes
    # the table reproducible: plain set iteration order for strings varies between kernel runs.
    concept_term_list = sorted({
        prompt.replace("This is ", "").replace("This photo is ", "").replace("This lesion is ", "").replace("skin has become ", "").lower()
        for prompt in prompt_engineered_list
    })

    prompt_info_latex.append({"concept_name": concept_name[8:],
                              "term": concept_term_list
                             })

In [None]:
# IPython introspection: show the docstring of DataFrame.to_latex (development aid, not valid plain Python).
pd.DataFrame(prompt_info_latex).to_latex?

In [None]:
# Render the concept/term table as LaTeX, sorted by concept name, and strip the
# square brackets that come from the list repr in the "term" column.
print(
    pd.DataFrame(prompt_info_latex)
    .sort_values("concept_name")
    .to_latex(index=False)
    .replace('[','')
    .replace(']','')
)