In [None]:
import os

import hydra
import omegaconf
import pyrootutils

# Resolve the project root starting from this notebook's (cwd-relative) path.
# pythonpath=True asks pyrootutils to also put the root on sys.path.
root = pyrootutils.setup_root(os.path.abspath("data_auditing.ipynb"), pythonpath=True)

# Work from the project root so the relative paths used further down
# (e.g. "logs/...", "configs/...") resolve against it.
os.chdir(root)

In [None]:
import sys

# Put <root>/src on the import path (presumably where the MONET package imported below lives - verify).
# NOTE(review): re-running this cell appends the same entry again.
sys.path.append(str(root / "src"))

In [None]:
import os
import time
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib import gridspec
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn.metrics
import torch
import tqdm
from IPython import display
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

import clip

In [None]:
# !cp /System/Library/Fonts/Supplemental ~/.local/share/fonts/
# rm -fr ~/.cache/matplotlib
from matplotlib import font_manager
from matplotlib.lines import Line2D
from cycler import cycler
from matplotlib.ticker import MultipleLocator, AutoMinorLocator

# Font lookup calls; their return values are discarded here.
font_manager.findSystemFonts(fontpaths=None, fontext="ttf")
font_manager.findfont("Arial") # Test with "Special Elite" too
# Use Arial as the sans-serif face for all figures.
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = 'Arial'

# Legend styling: square-cornered box, white edge colour, fully transparent frame.
plt.rcParams['legend.fancybox'] = False
plt.rcParams['legend.edgecolor']='1.0'
plt.rcParams['legend.framealpha']=0

# https://github.com/dsc/colorbrewer-python/blob/master/colorbrewer.py

# ColorBrewer "Set1" qualitative palette.
# Set1[n] holds the first n colours (n = 3..9), each an [R, G, B] list of 0-255 integers.
Set1 = {
    size: [
        list(rgb)
        for rgb in [
            (228, 26, 28), (55, 126, 184), (77, 175, 74),
            (152, 78, 163), (255, 127, 0), (255, 255, 51),
            (166, 86, 40), (247, 129, 191), (153, 153, 153),
        ][:size]
    ]
    for size in range(3, 10)
}

# ColorBrewer "Paired" qualitative palette.
# Paired[n] holds the first n colours (n = 3..12), each an [R, G, B] list of 0-255 integers.
# Every entry is built from the same base list so all colours share one type
# (the hand-written table had a stray tuple as the first colour of Paired[3]).
Paired = {
    size: [
        list(rgb)
        for rgb in [
            (166, 206, 227), (31, 120, 180), (178, 223, 138), (51, 160, 44),
            (251, 154, 153), (227, 26, 28), (253, 191, 111), (255, 127, 0),
            (202, 178, 214), (106, 61, 154), (255, 255, 153), (177, 89, 40),
        ][:size]
    ]
    for size in range(3, 13)
}

# Seven-colour qualitative palette: hex strings plus one RGB tuple of floats.
color_qual_7 = [
    '#F53345', '#87D303', '#04CBCC', '#8650CD',
    (160/256, 95/256, 0),
    '#F5A637', '#DBD783',
]

# Let pandas print up to 500 rows before truncating.
pd.options.display.max_rows = 500

In [None]:
import scipy.special
import tqdm.contrib.concurrent

In [None]:
# import importlib
# importlib.reload(sys.modules["MONET.utils.static"])
# from MONET.utils.static import (
#     concept_to_prompt)

In [None]:
from MONET.datamodules.multiplex_datamodule import MultiplexDatamodule
from MONET.utils.loader import custom_collate_per_key, dataloader_apply_func
from MONET.utils.metrics import skincon_calcualte_auc_all
from MONET.utils.static import (
    concept_to_prompt,
    fitzpatrick17k_disease_label,
    fitzpatrick17k_ninelabel,
    fitzpatrick17k_threelabel,
    skincon_cols,
)
from MONET.utils.text_processing import generate_prompt_token_from_concept

In [None]:
def wandb_to_exppath(wandb, log_path="/gscratch/cse/chanwkim/MONET_log/train/runs"):
    """Find the experiment directory that contains a given W&B run.

    Each sub-directory of ``log_path`` is treated as one experiment. An
    experiment matches when its ``wandb/`` folder has an entry whose name
    starts with ``"run"`` and whose last 8 characters equal ``wandb``.

    Args:
        wandb: 8-character W&B run id to look for.
        log_path: Directory (str or Path) holding one sub-directory per experiment.

    Returns:
        ``Path(log_path) / <experiment>`` for the first matching experiment
        in sorted name order.

    Raises:
        RuntimeError: If no experiment contains the requested run id.
    """
    log_path = Path(log_path)
    # Sorted so the result does not depend on the arbitrary os.listdir order.
    for experiment in sorted(os.listdir(log_path)):
        wandb_dir = log_path / experiment / "wandb"
        if not wandb_dir.is_dir():
            continue
        # Compare against every run entry; an experiment without any "run*"
        # entry is simply skipped instead of raising IndexError.
        run_ids = [entry[-8:] for entry in os.listdir(wandb_dir) if entry.startswith("run")]
        if wandb in run_ids:
            return log_path / experiment
    raise RuntimeError(f"not found: wandb run {wandb!r} under {log_path}")


# Look up the experiment directory of W&B run "baqqmm5v" under the given log root
# and print the paths of the checkpoint files stored in it.
exppath = wandb_to_exppath(
    wandb="baqqmm5v", log_path="/projects/leelab2/chanwkim/dermatology_datasets/logs/train/runs"
)
print([exppath / "checkpoints" / ckpt for ckpt in os.listdir(exppath / "checkpoints/")])

In [None]:
# Per-dataset store: dataset name -> dict that the cells below fill in
# (dataloader, image features, metadata, labels, concept list, ...).
variable_dict={}

In [None]:
def setup_dataloader(dataset_name, data_dir="/sdata/chanwkim/dermatology_datasets", split_seed=42):
    """Build the test dataloader for one of the supported datasets.

    Loads ``configs/datamodule/multiplex.yaml`` from the project root, points
    it at ``data_dir``, selects ``"<dataset_name>=all"`` as the test dataset,
    instantiates the datamodule through hydra and calls its ``setup()``.

    Args:
        dataset_name: One of the names listed in ``supported`` below.
        data_dir: Root directory of the dermatology datasets.
        split_seed: Seed written to the datamodule config.

    Returns:
        ``{"dataloader": <test dataloader of the datamodule>}``.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    supported = (
        "clinical_fd_clean_nodup_nooverlap",
        "derm7pt_derm_nodup",
        "derm7pt_clinical_nodup",
        "derm7pt",
        "isic_nodup_nooverlap",
    )
    # Fail early and explicitly; previously an unknown name fell through every
    # branch and crashed on the unbound ``dataloader`` variable.
    if dataset_name not in supported:
        raise ValueError(f"unknown dataset_name {dataset_name!r}; expected one of {supported}")

    cfg_dm = omegaconf.OmegaConf.load(root / "configs" / "datamodule" / "multiplex.yaml")
    cfg_dm.data_dir = data_dir
    cfg_dm.dataset_name_test = f"{dataset_name}=all"
    cfg_dm.split_seed = split_seed

    dm = hydra.utils.instantiate(cfg_dm)
    dm.setup()

    return {"dataloader": dm.test_dataloader()}

In [None]:
# Create the test dataloader of every dataset analysed here and store it
# under variable_dict[<dataset name>]["dataloader"].
for dataset_name in (
    "clinical_fd_clean_nodup_nooverlap",
    "derm7pt_derm_nodup",
    "isic_nodup_nooverlap",
):
    variable_dict.setdefault(dataset_name, {}).update(setup_dataloader(dataset_name))

In [None]:
# for dataset_name in ["clinical_fd_clean_nodup_nooverlap", "isic"]:  
#     variable_dict.setdefault(dataset_name, {})
#     variable_dict[dataset_name].update(setup_dataloader(dataset_name))

In [None]:
!gpustat

# initialize model

In [None]:
# Key into `model_path_dir` (checkpoint lookup further down); the value "ViT-L/14" skips checkpoint loading.
model_name = "zt0n2xd0"
# Device string used for the model and for the image batches in batch_func.
model_device = "cuda:2"

In [None]:
# Load the contrastive-model config and override backbone name and device.
cfg_model = omegaconf.OmegaConf.load(root / "configs" / "model" / "contrastive.yaml")
cfg_model.net.model_name_or_path = "ViT-L/14"
cfg_model.net.device = model_device
# Last expression of the cell: displays the resulting config.
cfg_model

In [None]:
# Instantiate the model from the config, move it to the chosen device and
# switch it to evaluation mode.
model = hydra.utils.instantiate(cfg_model)
model.to(model_device)
model.eval()

In [None]:
# Checkpoint path per model name. The path is relative, so it resolves against
# the working directory (set to the project root at the top of the notebook).
model_path_dir = {
    "zt0n2xd0": "logs/train/runs/2023-01-17_20-58-15/checkpoints/last.ckpt",
}
# "ViT-L/14" means: keep the weights the model was instantiated with.
if model_name != "ViT-L/14":
    model_path = model_path_dir[model_name]
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    loaded = torch.load(model_path, map_location=model_device)
    model.load_state_dict(loaded["state_dict"])
    print('loaded')

In [None]:
# torch.save(loaded["state_dict"], "/cse/web/research/aimslab/MONET/weight.pt")

# Image features

In [None]:
def batch_func(batch):
    """Compute image features for one batch with the notebook-level ``model``.

    The ``"image"`` entry of ``batch`` is replaced in place by its copy on
    ``model_device`` before the batch is passed to
    ``model.model_step_with_image``; gradients are disabled for both steps.

    Returns:
        dict with ``"image_features"`` (detached, on CPU) and the batch's
        ``"metadata"`` entry passed through unchanged.
    """
    with torch.no_grad():
        batch["image"] = batch["image"].to(model_device)
        step_output = model.model_step_with_image(batch)
    features_cpu = step_output["image_features"].detach().cpu()
    return {"image_features": features_cpu, "metadata": batch["metadata"]}

In [None]:
def setup_features(dataset_name, dataloader):
    """Return image features and metadata for a dataset.

    For ``"isic_nodup_nooverlap"`` the features are read from the cached file
    ``log_dir / "image_features" / "isic_nodup_nooverlap.pt"`` (metadata under
    the key ``"metadata_all"``). For every other dataset they are computed by
    running ``batch_func`` over ``dataloader`` (metadata under ``"metadata"``).

    Returns:
        dict with ``"image_features"`` (on CPU) and ``"metadata_all"``.
    """
    use_cache = dataset_name == "isic_nodup_nooverlap"
    if use_cache:
        cached = torch.load(
            log_dir / "image_features" / "isic_nodup_nooverlap.pt", map_location="cpu"
        )
        features = cached["image_features"].cpu()
        metadata = cached["metadata_all"]
    else:
        computed = dataloader_apply_func(
            dataloader=dataloader,
            func=batch_func,
            collate_fn=custom_collate_per_key,
        )
        features = computed["image_features"].cpu()
        metadata = computed["metadata"]

    return {"image_features": features, "metadata_all": metadata}

In [None]:
# Log directory, relative to the working directory; setup_features reads the
# cached ISIC image features from below it.
log_dir = Path("logs")

In [None]:
# Add "image_features" and "metadata_all" to each dataset's entry, using the
# dataloader stored there earlier.
for dataset_name in ["clinical_fd_clean_nodup_nooverlap", "derm7pt_derm_nodup",
                     "isic_nodup_nooverlap"
                    ]:
    print(dataset_name)
    variable_dict[dataset_name].update(setup_features(dataset_name, variable_dict[dataset_name]["dataloader"])) 

In [None]:
# Fallback malignancy label per diagnosis, used when a row has no explicit
# benign/malignant annotation. Trailing "#" / "#??" marks are kept from the
# original author; their meaning is not recorded here.
# NOTE(review): 'vascular lesion' maps to 'unknown' while other undecided
# diagnoses map to 'indeterminate' - confirm this is intended.
diagnosis_malignant_mapping = {
    'AIMP': 'indeterminate',
    'acrochordon': 'benign',
    'actinic keratosis': 'benign',  #
    'angiofibroma or fibrous papule': 'benign',
    'angiokeratoma': 'benign',
    'angioma': 'benign',
    'atypical melanocytic proliferation': 'indeterminate',
    'atypical spitz tumor': 'indeterminate',  #
    'basal cell carcinoma': 'malignant',  #
    'cafe-au-lait macule': 'benign',
    'clear cell acanthoma': 'benign',
    'dermatofibroma': 'benign',  #
    'lentigo NOS': 'benign',
    'lentigo simplex': 'benign',
    'lichenoid keratosis': 'benign',
    'melanoma': 'malignant',
    'melanoma metastasis': 'malignant',
    'neurofibroma': 'benign',
    'nevus': 'benign',
    'other': 'indeterminate',
    'pigmented benign keratosis': 'benign',  #??
    'scar': 'benign',
    'seborrheic keratosis': 'benign',
    'solar lentigo': 'benign',
    'squamous cell carcinoma': 'malignant',
    'vascular lesion': 'unknown',  #
    'verruca': 'benign',
}


def map_diagnosis_malignant(diagnosis, benign_malignant):
    """Return a malignancy label for one metadata row.

    Args:
        diagnosis: Diagnosis name, or a missing value (NaN / None).
        benign_malignant: Explicit label; used as-is whenever it is a string.

    Returns:
        ``benign_malignant`` if it is a string; otherwise the entry of
        ``diagnosis_malignant_mapping`` for ``diagnosis``; otherwise
        ``"indeterminate"`` when ``diagnosis`` is missing.

    Raises:
        RuntimeError: If ``diagnosis`` is present but has no mapping entry.
    """
    if isinstance(benign_malignant, str):
        return benign_malignant
    if diagnosis in diagnosis_malignant_mapping:
        return diagnosis_malignant_mapping[diagnosis]
    # pd.isna accepts any scalar; np.isnan raised TypeError for an unmapped
    # string (or None), which hid the intended RuntimeError below.
    if pd.isna(diagnosis):
        return "indeterminate"
    raise RuntimeError(f"no malignancy mapping for diagnosis {diagnosis!r}")

In [None]:
# Concept vocabularies shared by the dermoscopy datasets (ISIC and derm7pt).
_ARTIFACT_CONCEPTS = [
    "purple pen",
    "nail",
    "pinkish",
    "red",
    "hair",
    "orange sticker",
    "dermoscope border",
    "gel",
]

_DERM7PT_CONCEPTS = [
    "pigment network", "typical pigment network", "atypical pigment network",
    "regression structure",
    "pigmentation", "regular pigmentation", "irregular pigmentation",
    "blue whitish veil",
    "vascular structures", "typical vascular structures", "atypical vascular structures",
    "streaks", "regular streaks", "irregular streaks",
    "dots and globules", "regular dots and globules", "irregular dots and globules",
]

_ISIC_CONCEPTS = [
    "pigment_network",
    "negative_network",
    "milia_like_cyst",
    "streaks",
    "globules",
]

_ISIC_DISEASES = [
    'seborrheic keratosis', 'nevus', 'squamous cell carcinoma',
    'melanoma', 'lichenoid keratosis', 'lentigo',
    'actinic keratosis', 'basal cell carcinoma', 'dermatofibroma',
    'atypical melanocytic proliferation', 'verruca',
    'clear cell acanthoma', 'angiofibroma or fibrous papule', 'scar',
    'angioma', 'atypical spitz tumor', 'solar lentigo', 'AIMP',
    'neurofibroma', 'lentigo simplex', 'acrochordon',
    'angiokeratoma', 'vascular lesion', 'cafe-au-lait macule',
    'pigmented benign keratosis',
]

_DERM7PT_DISEASES = [
    'basal cell carcinoma', 'blue nevus', 'clark nevus',
    'combined nevus', 'congenital nevus', 'dermal nevus',
    'dermatofibroma', 'lentigo', 'melanoma', 'melanosis',
    'recurrent nevus', 'reed or spitz nevus',
    'seborrheic keratosis', 'vascular lesion',
]


def set_config(dataset_name, metadata_all):
    """Derive labels, a validity mask and the concept list for one dataset.

    The dataset family is chosen by substring, checked in this order:
    ``"clinical_fd_clean"``, ``"isic"``, ``"derm7pt"``.

    For the ISIC and derm7pt families the columns ``benign_malignant_full``
    and ``benign_malignant_bool`` are added to ``metadata_all`` in place.

    Args:
        dataset_name: Dataset name containing one of the substrings above.
        metadata_all: pandas DataFrame with the per-image metadata.

    Returns:
        dict with
        ``"valid_idx"``: array marking the rows to keep,
        ``"y_pos"``: array of malignancy labels,
        ``"metadata_all"``: the (possibly extended) input frame,
        ``"concept_list"``: list of concept names for this dataset.

    Raises:
        ValueError: If ``dataset_name`` matches none of the known families.
    """
    if "clinical_fd_clean" in dataset_name:
        # Malignant = fitz rows labelled "malignant" or ddi rows flagged malignant.
        fitz_malignant = (metadata_all["source"] == "fitz") & (
            metadata_all["three_partition_label"] == "malignant"
        )
        ddi_malignant = (metadata_all["source"] == "ddi") & (metadata_all["malignant"] == True)
        y_pos = (fitz_malignant | ddi_malignant).values

        valid_idx = (metadata_all["skincon_Do not consider this image"] != 1).values

        concept_list = skincon_cols

    elif "isic" in dataset_name:
        # Fill in missing benign/malignant labels from the diagnosis.
        metadata_all["benign_malignant_full"] = metadata_all.apply(
            lambda row: map_diagnosis_malignant(row["diagnosis"], row["benign_malignant"]), axis=1
        )
        metadata_all["benign_malignant_bool"] = metadata_all["benign_malignant_full"].str.contains("malignant")

        y_pos = metadata_all["benign_malignant_bool"].values

        # Keep only rows whose label mentions "malignant" or "benign".
        valid_idx = (
            metadata_all["benign_malignant_full"].str.contains("malignant")
            | metadata_all["benign_malignant_full"].str.contains("benign")
        ).values

        concept_list = (
            skincon_cols
            + _ARTIFACT_CONCEPTS
            + [f"derm7ptconcept_{concept}" for concept in _DERM7PT_CONCEPTS]
            + [f"isicconcept_{concept}" for concept in _ISIC_CONCEPTS]
            + [f"disease_{disease}" for disease in _ISIC_DISEASES]
        )

    elif "derm7pt" in dataset_name:
        # NOTE(review): the label is "diagnosis contains 'malignant'"; this
        # assumes the diagnosis column carries that word - confirm.
        metadata_all["benign_malignant_full"] = metadata_all["diagnosis"]
        metadata_all["benign_malignant_bool"] = metadata_all["benign_malignant_full"].str.contains("malignant")

        y_pos = metadata_all["benign_malignant_bool"].values

        valid_idx = (~metadata_all["diagnosis"].isnull()).values

        concept_list = (
            skincon_cols
            + _ARTIFACT_CONCEPTS
            + [f"derm7ptconcept_{concept}" for concept in _DERM7PT_CONCEPTS]
            + [f"disease_{disease}" for disease in _DERM7PT_DISEASES]
        )

    else:
        # Previously an unknown name crashed on unbound locals at the return.
        raise ValueError(
            f"unknown dataset_name {dataset_name!r}; expected a name containing "
            "'clinical_fd_clean', 'isic' or 'derm7pt'"
        )

    return {"valid_idx": valid_idx,
            "y_pos": y_pos,
            "metadata_all": metadata_all,
            "concept_list": concept_list}

In [None]:
# Add "valid_idx", "y_pos" and "concept_list" to each dataset's entry and
# store the metadata frame returned by set_config back under "metadata_all".
for dataset_name in ["clinical_fd_clean_nodup_nooverlap", "derm7pt_derm_nodup",
                     "isic_nodup_nooverlap"
                    ]:
    variable_dict[dataset_name].update(
        set_config(dataset_name, variable_dict[dataset_name]["metadata_all"])
    )

In [None]:
# ISIC rows whose derived label differs from the raw "benign_malignant"
# column, counted per diagnosis.
(
    variable_dict["isic_nodup_nooverlap"]["metadata_all"]
    .loc[lambda frame: frame["benign_malignant"] != frame["benign_malignant_full"], "diagnosis"]
    .value_counts()
)

In [None]:
def normalize_embedding(dataset_name, image_features):
    """Scale every row of ``image_features`` to unit norm.

    Args:
        dataset_name: Not used by the computation.
        image_features: 2-D tensor; the norm is taken along ``dim=1``.

    Returns:
        ``{"image_features_norm": <row-normalised tensor>}``; the input
        tensor is left unmodified.
    """
    row_norms = image_features.norm(dim=1, keepdim=True)
    unit_rows = torch.div(image_features, row_norms)
    return {"image_features_norm": unit_rows}

In [None]:
# Add the row-normalised features ("image_features_norm") to each dataset's entry.
for dataset_name in ["clinical_fd_clean_nodup_nooverlap","derm7pt_derm_nodup",
                     "isic_nodup_nooverlap"
                    ]:
    variable_dict[dataset_name].update(
    normalize_embedding(dataset_name, 
                        variable_dict[dataset_name]["image_features"])
    )

In [None]:
def get_concept_embedding(dataset_name, concept_list):
    prompt_info={}
    
    for concept_name in concept_list:
        if "clinical_fd_clean" in dataset_name:
            prompt_dict, text_counter = concept_to_prompt(concept_name[8:])
            prompt_engineered_list = []
            for k, v in prompt_dict.items():
                if k != "original":
                    prompt_engineered_list += v    
            concept_term_list = list(set([prompt.replace("This is ", "").replace("This photo is ", "").replace("This lesion is ", "").replace("skin has become ", "").lower()
                                      for prompt in prompt_engineered_list]))
            prompt_template_list=["This is skin image of {}", "This is dermatology image of {}", "This is image of {}"]
            #prompt_target=[prompt_template.format(term) for prompt_template in prompt_template_list for term in concept_term_list]            
            prompt_target=[[prompt_template.format(term) for term in concept_term_list] for prompt_template in prompt_template_list]
            
            prompt_ref = [["This is skin image"], ["This is dermatology image"], ["This is image"]]        
        
        elif "isic" in dataset_name:
            if concept_name.startswith("skincon_"):
                prompt_dict, text_counter = concept_to_prompt(concept_name[8:])
                prompt_engineered_list = []
                for k, v in prompt_dict.items():
                    if k != "original":
                        prompt_engineered_list += v

                concept_term_list = list(set([prompt.replace("This is ", "").replace("This photo is ", "").replace("This lesion is ", "").replace("skin has become ", "").lower()
                                          for prompt in prompt_engineered_list]))
                prompt_template_list=["This is dermatoscopy of {}", "This is dermoscopy of {}"]
                prompt_target=[prompt_template.format(term) for prompt_template in prompt_template_list for term in concept_term_list]
                prompt_ref = ["This is dermatoscopy", "This is dermoscopy"]
                prompt_target=[[prompt_template.format(term) for term in concept_term_list] for prompt_template in prompt_template_list]
                prompt_ref = [["This is dermatoscopy"], ["This is dermoscopy"]] 
                
                prompt_target=[[prompt_template.format(term) for prompt_template in prompt_template_list for term in concept_term_list]]
                prompt_ref = [["This is dermatoscopy", "This is dermoscopy"]] 
                
                
            elif concept_name.startswith("derm7ptconcept_"):
                derm7ptconcept=concept_name[15:]
                if derm7ptconcept=="pigment network":
                    concept_term_list=["pigment network", "brown lines forming a grid-like reticular pattern"]
                    concept_term_list=["pigment network", "intersecting brown lines"]
                elif derm7ptconcept=="typical pigment network":
                    concept_term_list=["typical pigment network", "regularly meshed pigment network",]
                elif derm7ptconcept=="atypical pigment network":
#                     concept_term_list=["pigment network", "atypical pigment network", "irregularly meshed pigment network"]
                    #concept_term_list=["atypical pigment network", "irregularly meshed pigment network", "branched streaks"]
                    concept_term_list=["atypical pigment network", "irregularly meshed pigment network"]
                elif derm7ptconcept=="regression structure":
                    concept_term_list=["regression structure"]
                elif derm7ptconcept=="pigmentation":
#                     concept_term_list=["pigmented", "pigmented lesion"]
                    concept_term_list=["pigmented", "pigmented lesion", "colored lesion"]    
                elif derm7ptconcept=="regular pigmentation":
                    concept_term_list=["regular pigmentation", "uniform and consistent coloration"]
                elif derm7ptconcept=="irregular pigmentation":
                    concept_term_list=["irregular pigmentation"]
                elif derm7ptconcept=="blue whitish veil":
                    concept_term_list=["blue whitish veil","blue white veil"]
                elif derm7ptconcept=="vascular structures":
                    concept_term_list=["vascular structures"]
                    concept_term_list=["vascular structures", "Hairpin vessels", "Comma vessels", "dotted vessels", "arborizing vessels"]
                elif derm7ptconcept=="typical vascular structures":
                    concept_term_list=["typical vascular structures"]
                elif derm7ptconcept=="atypical vascular structures":
                    concept_term_list=["atypical vascular structures"]
                elif derm7ptconcept=="streaks":
                    concept_term_list=["streaks", "starburst", "linear patterns", "radially oriented linear projections", "regular, pigmented linear extensions", "irregular, pigmented linear extensions"]
                elif derm7ptconcept=="regular streaks":
                    concept_term_list=["regular streaks", "uniformly spaced linear patterns"]
                elif derm7ptconcept=="irregular streaks":
                    concept_term_list=["irregular streaks"]
                elif derm7ptconcept=="dots and globules":
                    #concept_term_list=["dots and globules", "tiny, pinpoint pigmented specks", "Small, darkly pigmented dots"]
                    concept_term_list=["tiny dots", "globules", "dot clusters", "globule clusters"]
                    concept_term_list=["dots and globules", "scattered globules"]#, "dots and globules clusters"] 0.57
                    concept_term_list=["black dots and globules", "brown dots and globules", "scattered globules"] #0.
                elif derm7ptconcept=="regular dots and globules":
                    concept_term_list=["regular dots and globules"]
                elif derm7ptconcept=="irregular dots and globules":
                    concept_term_list=["irregular dots and globules"]
                else:
                    raise ValueError(derm7ptconcept)         
                    
                prompt_template_list=["This is dermatoscopy of {}", "This is dermoscopy of {}"]
                prompt_target=[[prompt_template.format(term) for prompt_template in prompt_template_list for term in concept_term_list]]
                prompt_ref = [["This is dermatoscopy", "This is dermoscopy"]]                  
                
            elif concept_name.startswith("isicconcept_"):
                isicconcept=concept_name[12:]
                if isicconcept=="pigment_network":
                    concept_term_list=["pigment network"]
                elif isicconcept=="negative_network":
                    concept_term_list=["negative network"]
                elif isicconcept=="milia_like_cyst":
                    concept_term_list=["milia like cyst"]
                    concept_term_list=["seborrheic keratosis"]
                elif isicconcept=="streaks":
                    concept_term_list=["streaks", "starburst", "linear patterns", "radially oriented linear projections", "regular, pigmented linear extensions", "irregular, pigmented linear extensions"]
                elif isicconcept=="globules":
                    concept_term_list=["globules"]
                else:
                    raise ValueError(isicconcept)                
            
                prompt_template_list=["This is dermatoscopy of {}", "This is dermoscopy of {}"]
                prompt_target=[[prompt_template.format(term) for prompt_template in prompt_template_list for term in concept_term_list]]
                prompt_ref = [["This is dermatoscopy", "This is dermoscopy"]]                 
                
            elif concept_name.startswith("disease_"):  
                if concept_name=="disease_AIMP":
                    disease_name=concept_name[8:]
                    prompt_target=[["This is dermatoscopy of AIMP",
                                    "This is dermatoscopy of Atypical intraepidermal melanocytic proliferation"],
                                   ["This is dermoscopy of AIMP",
                                    "This is dermoscopy of Atypical intraepidermal melanocytic proliferation"]]
                    prompt_ref = [["This is dermatoscopy"], ["This is dermoscopy"]]
                else:
                    disease_name=concept_name[8:]
                    prompt_target=[[f"This is dermatoscopy of {disease_name}"],
                                   [f"This is dermoscopy of {disease_name}"]] 
                    prompt_ref = [["This is dermatoscopy"], ["This is dermoscopy"]]
                
            else:
                if concept_name=="gel":
                    #concept_term_list=["water drop", 'gel', "fluid"]
                    prompt_target=[["This is dermatoscopy of water drop", "This is dermatoscopy of gel", "This is dermatoscopy of dermoscopy liquid"],
                                   ["This is dermoscopy of water drop", "This is dermoscopy of gel", "This is dermoscopy of dermoscopy liquid"],
                                  ]
                    prompt_target=[["This is dermatoscopy of gel"],
                                   ["This is dermoscopy of gel"],
                                  ]                    
                    
                    prompt_ref = [["This is dermatoscopy"], 
                                  ["This is dermoscopy"]]
                elif concept_name=="dermoscope border":
                    concept_term_list=["dermoscope"]
                    prompt_target=["This is hole"]
                    prompt_target=["This is scope hole", "This is circle", "This is dermoscope"]
                    #prompt_target=[["This is dermatoscopy of dermoscope", "This is dermatoscopy of dermoscopy"]]
                    prompt_target=[["This is dermatoscopy of dermoscopy"]]
                    prompt_ref = [["This is dermatoscopy"]]
                    
                else:
                    concept_term_list=[concept_name]
                    prompt_template_list=["This is dermatoscopy of {}", "This is dermoscopy of {}"]
                    prompt_target=[prompt_template.format(term) for prompt_template in prompt_template_list for term in concept_term_list]
                    
                    prompt_ref = [["This is dermatoscopy"], ["This is dermoscopy"]]
                    
        elif "derm7pt_derm" in dataset_name:
            if concept_name.startswith("skincon_"):
                prompt_dict, text_counter = concept_to_prompt(concept_name[8:])
                prompt_engineered_list = []
                for k, v in prompt_dict.items():
                    if k != "original":
                        prompt_engineered_list += v

                concept_term_list = list(set([prompt.replace("This is ", "").replace("This photo is ", "").replace("This lesion is ", "").replace("skin has become ", "").lower()
                                          for prompt in prompt_engineered_list]))
                prompt_template_list=["This is dermatoscopy of {}", "This is dermoscopy of {}"]
                prompt_target=[prompt_template.format(term) for prompt_template in prompt_template_list for term in concept_term_list]
                prompt_ref = ["This is dermatoscopy", "This is dermoscopy"]
                prompt_target=[[prompt_template.format(term) for term in concept_term_list] for prompt_template in prompt_template_list]
                prompt_ref = [["This is dermatoscopy"], ["This is dermoscopy"]] 
                
                prompt_target=[[prompt_template.format(term) for prompt_template in prompt_template_list for term in concept_term_list]]
                prompt_ref = [["This is dermatoscopy", "This is dermoscopy"]] 
            elif concept_name.startswith("derm7ptconcept_"):
                derm7ptconcept=concept_name[15:]
                if derm7ptconcept=="pigment network":
                    concept_term_list=["pigment network", "brown lines forming a grid-like reticular pattern"]
                    concept_term_list=["pigment network", "intersecting brown lines"]
                elif derm7ptconcept=="typical pigment network":
                    concept_term_list=["typical pigment network", "regularly meshed pigment network",]
                elif derm7ptconcept=="atypical pigment network":
#                     concept_term_list=["pigment network", "atypical pigment network", "irregularly meshed pigment network"]
                    #concept_term_list=["atypical pigment network", "irregularly meshed pigment network", "branched streaks"]
                    concept_term_list=["atypical pigment network", "irregularly meshed pigment network"]
                elif derm7ptconcept=="regression structure":
                    concept_term_list=["regression structure"]
                elif derm7ptconcept=="pigmentation":
#                     concept_term_list=["pigmented", "pigmented lesion"]
                    concept_term_list=["pigmented", "pigmented lesion", "colored lesion"]    
                elif derm7ptconcept=="regular pigmentation":
                    concept_term_list=["regular pigmentation", "uniform and consistent coloration"]
                elif derm7ptconcept=="irregular pigmentation":
                    concept_term_list=["irregular pigmentation"]
                elif derm7ptconcept=="blue whitish veil":
                    concept_term_list=["blue whitish veil","blue white veil"]
                elif derm7ptconcept=="vascular structures":
                    concept_term_list=["vascular structures"]
                    concept_term_list=["vascular structures", "Hairpin vessels", "Comma vessels", "dotted vessels", "arborizing vessels"]
                elif derm7ptconcept=="typical vascular structures":
                    concept_term_list=["typical vascular structures"]
                elif derm7ptconcept=="atypical vascular structures":
                    concept_term_list=["atypical vascular structures"]
                elif derm7ptconcept=="streaks":
                    concept_term_list=["streaks", "starburst", "linear patterns", "radially oriented linear projections", "regular, pigmented linear extensions", "irregular, pigmented linear extensions"]
                elif derm7ptconcept=="regular streaks":
                    concept_term_list=["regular streaks", "uniformly spaced linear patterns"]
                elif derm7ptconcept=="irregular streaks":
                    concept_term_list=["irregular streaks"]
                elif derm7ptconcept=="dots and globules":
                    #concept_term_list=["dots and globules", "tiny, pinpoint pigmented specks", "Small, darkly pigmented dots"]
                    concept_term_list=["tiny dots", "globules", "dot clusters", "globule clusters"]
                    concept_term_list=["dots and globules", "scattered globules"]#, "dots and globules clusters"] 0.57
                    concept_term_list=["black dots and globules", "brown dots and globules", "scattered globules"] #0.
                elif derm7ptconcept=="regular dots and globules":
                    concept_term_list=["regular dots and globules"]
                elif derm7ptconcept=="irregular dots and globules":
                    concept_term_list=["irregular dots and globules"]
                else:
                    raise ValueError(derm7ptconcept)
    
                prompt_template_list=["This is dermatoscopy of {}", "This is dermoscopy of {}"]
                prompt_target=[[prompt_template.format(term) for prompt_template in prompt_template_list for term in concept_term_list]]
                prompt_ref = [["This is dermatoscopy", "This is dermoscopy"]] 
                
            elif concept_name.startswith("disease_"):  
                disease_name=concept_name[8:]
                prompt_target=[[f"This is dermatoscopy of {disease_name}"],
                               [f"This is dermoscopy of {disease_name}"]] 
                prompt_ref = [["This is dermatoscopy"], ["This is dermoscopy"]]
                
            else:
                if concept_name=="gel":
                    #concept_term_list=["water drop", 'gel', "fluid"]
                    prompt_target=[["This is dermatoscopy of water drop", "This is dermatoscopy of gel", "This is dermatoscopy of dermoscopy liquid"],
                                   ["This is dermoscopy of water drop", "This is dermoscopy of gel", "This is dermoscopy of dermoscopy liquid"],
                                  ]
                    prompt_target=[["This is dermatoscopy of gel"],
                                   ["This is dermoscopy of gel"],
                                  ]                    
                    
                    prompt_ref = [["This is dermatoscopy"], 
                                  ["This is dermoscopy"]]
                elif concept_name=="dermoscope border":
                    concept_term_list=["dermoscope"]
                    prompt_target=["This is hole"]
                    prompt_target=["This is scope hole", "This is circle", "This is dermoscope"]
                    #prompt_target=[["This is dermatoscopy of dermoscope", "This is dermatoscopy of dermoscopy"]]
                    prompt_target=[["This is dermatoscopy of dermoscopy"]]
                    prompt_ref = [["This is dermatoscopy"]]
                    
                else:
                    concept_term_list=[concept_name]
                    prompt_template_list=["This is dermatoscopy of {}", "This is dermoscopy of {}"]
                    prompt_target=[prompt_template.format(term) for prompt_template in prompt_template_list for term in concept_term_list]
                    
                    prompt_ref = [["This is dermatoscopy"], ["This is dermoscopy"]]                    
                
        
        #print(prompt_target, prompt_ref)
        # target embedding
        prompt_target_tokenized=[clip.tokenize(prompt_list, truncate=True) for prompt_list in prompt_target]
        with torch.no_grad():
            prompt_target_embedding = torch.stack([model.model_step_with_text({"text": prompt_tokenized.to(model_device)})[
                    "text_features"].detach().cpu() for prompt_tokenized in prompt_target_tokenized])
        prompt_target_embedding_norm=prompt_target_embedding/prompt_target_embedding.norm(dim=2, keepdim=True)          

        # reference embedding
        prompt_ref_tokenized=[clip.tokenize(prompt_list, truncate=True) for prompt_list in prompt_ref]
        with torch.no_grad():
            prompt_ref_embedding = torch.stack([model.model_step_with_text({"text": prompt_tokenized.to(model_device)})[
                    "text_features"].detach().cpu() for prompt_tokenized in prompt_ref_tokenized])
        prompt_ref_embedding_norm=prompt_ref_embedding/prompt_ref_embedding.norm(dim=2, keepdim=True)                

        prompt_info[concept_name]={"prompt_ref_embedding_norm":prompt_ref_embedding_norm,
                                   "prompt_target_embedding_norm":prompt_target_embedding_norm,
                                  }
        print(dataset_name, concept_name, prompt_target, prompt_ref)
        print(prompt_target_embedding_norm)
        print(prompt_ref_embedding_norm)
        del prompt_ref
        del prompt_target
    
    return {"prompt_info": prompt_info}

# Build the prompt embeddings for every concept of each dataset and merge the returned
# {"prompt_info": ...} dict into that dataset's entry of variable_dict.
# NOTE: `dataset_name` stays bound at notebook scope after the loop finishes.
for dataset_name in ["clinical_fd_clean_nodup_nooverlap","derm7pt_derm_nodup",
                     "isic_nodup_nooverlap"]:
# for dataset_name in ["isic_nodup_nooverlap"]:
    variable_dict[dataset_name].update(
        get_concept_embedding(dataset_name, 
                      concept_list=variable_dict[dataset_name]["concept_list"]))

In [None]:
!ls logs/experiment_results/data_audit.pt

In [None]:
# Restore a previously saved variable_dict, mapping tensors onto the CPU.
# NOTE(review): this rebinds variable_dict, and a later cell rebinds it again from
# data_audit_new.pt when the notebook is run top to bottom — confirm which file is intended.
variable_dict=torch.load("logs/experiment_results/data_audit_new_0429.pt", map_location="cpu")

In [None]:
#torch.save(variable_dict, "logs/experiment_results/data_audit_new.pt")

In [None]:
# torch.save(variable_dict, "logs/experiment_results/data_audit_new_0429.pt")

In [None]:
# Restore variable_dict from data_audit_new.pt, mapping tensors onto the CPU.
# NOTE(review): this rebinds variable_dict, so anything loaded or computed into it by
# earlier cells is replaced by the contents of this file.
variable_dict=torch.load("logs/experiment_results/data_audit_new.pt", map_location="cpu")

In [None]:
variable_dict.keys()

In [None]:
# Display the stored image features of the "isic" entry.
# NOTE(review): the surrounding cells key this dataset as "isic_nodup_nooverlap";
# presumably the loaded file also has an "isic" key — verify, otherwise this is a KeyError.
variable_dict["isic"]["image_features_norm"]

In [None]:
# Display the normalised target-prompt embedding stored for the "skincon_Vesicle" concept.
# NOTE(review): the surrounding cells key this dataset as "isic_nodup_nooverlap";
# presumably the loaded file also has an "isic" key — verify, otherwise this is a KeyError.
variable_dict["isic"]["prompt_info"]["skincon_Vesicle"]["prompt_target_embedding_norm"]

# Concept retrieval

In [None]:
def calculate_similaity_score(image_features_norm, 
                              prompt_target_embedding_norm,
                              prompt_ref_embedding_norm,
                              temp=1,
                              normalize=True):
    """Score every image against a concept described by target and reference prompts.

    Both prompt tensors are matrix-multiplied with the transposed image features and the
    similarities are averaged over dimension 1 (the prompts inside one group).

    Assumes the prompt embeddings are 3-D ``(n_groups, n_prompts, dim)`` tensors, as
    produced by ``torch.stack`` + ``norm(dim=2)`` in ``get_concept_embedding``, and that
    ``image_features_norm`` is ``(n_images, dim)`` — TODO confirm the image-feature layout.

    Args:
        image_features_norm: image feature tensor.
        prompt_target_embedding_norm: embeddings of the prompts that mention the concept.
        prompt_ref_embedding_norm: embeddings of the concept-free reference prompts.
        temp: divisor applied to the similarities before the softmax.
        normalize: when true, return the softmax weight of "target" versus "reference";
            when false, return the plain mean target similarity.

    Returns:
        One score per image. A NumPy array when ``normalize`` is true, a torch tensor
        otherwise (this asymmetry is inherited and kept for existing callers).
    """
    # Imported here because the notebook's import cells never bind the name `scipy`;
    # without this the normalised branch fails with NameError on a fresh kernel.
    import scipy.special

    image_features_t = image_features_norm.T.float()

    target_similarity = prompt_target_embedding_norm.float() @ image_features_t
    target_similarity_mean = target_similarity.mean(dim=1)

    if not normalize:
        # Raw mode only needs the target prompts: average over the prompt groups.
        return target_similarity_mean.mean(dim=0)

    ref_similarity = prompt_ref_embedding_norm.float() @ image_features_t
    ref_similarity_mean = ref_similarity.mean(dim=1)

    # Two-way softmax (target vs. reference) per prompt group and image; keep the
    # target probability (index 0) and average it over the prompt groups.
    target_probability = scipy.special.softmax(
        [target_similarity_mean.numpy() / temp, ref_similarity_mean.numpy() / temp],
        axis=0,
    )[0, :]
    return target_probability.mean(axis=0)

In [None]:
def shorten_concept_name(concept_name):
    """Return a short display label for a concept identifier.

    Names listed in the explicit table are mapped to a hand-written label. Any other name
    that starts with ``skincon_``, ``derm7ptconcept_`` or ``isicconcept_`` has that prefix
    removed. Everything else is returned unchanged.

    NOTE: a later cell of this notebook defines another ``shorten_concept_name``; once
    that cell has run, it replaces this definition.
    """
    explicit_names = {
        "skincon_Warty/Papillomatous": "Warty",
        # The prefixed key was missing, so the SkinCon concept was never shortened to
        # "Hypopigmentation"; the bare key is kept for backward compatibility.
        "skincon_White(Hypopigmentation)": "Hypopigmentation",
        "White(Hypopigmentation)": "Hypopigmentation",
        "purple pen": "Purple pen",
        "nail": "Nail",
        "orange sticker": "Orange sticker",
        "hair": "Hair",
        "gel": "Gel",
        "dermoscope border": "Dermoscopic border",
        "pinkish": "Pink",
    }
    if concept_name in explicit_names:
        return explicit_names[concept_name]

    # Generic rule: drop the dataset prefix (this also covers skincon_Erythema, etc.).
    for prefix in ("skincon_", "derm7ptconcept_", "isicconcept_"):
        if concept_name.startswith(prefix):
            return concept_name[len(prefix):]

    return concept_name

In [None]:
def check_image(dataset_name, idx):
    """Tell whether image ``idx`` of ``dataset_name`` should be kept.

    Only datasets whose name contains "clinical_fd_clean_nodup" have exclusions: four
    specific file names are rejected. Every other image is accepted.
    """
    excluded_files = (
        "58b4bc079ca94e6e9377a42ca7564b40.jpg",
        "720cf31558966c82c118ab75b50632eb.jpg",
        "5f046cda32a3cc547205662e7be774f9.jpg",
        "d8bf377acc45a3beb0c6e81bf7ac1ff5.jpg",
    )
    is_excluded = "clinical_fd_clean_nodup" in dataset_name and idx in excluded_files
    return not is_excluded

In [None]:
from collections import OrderedDict
# Collection id -> collection title, in the order used for the "collection_<id>" columns.
collection_id_mapping = OrderedDict(
    [
        (61, "Challenge 2016: Test"),
        (74, "Challenge 2016: Training"),
        (69, "Challenge 2017: Test"),
        (60, "Challenge 2017: Training"),
        (71, "Challenge 2017: Validation"),
        (63, "Challenge 2018: Task 1-2: Training"),
        (66, "Challenge 2018: Task 3: Training"),
        (65, "Challenge 2019: Training"),
        (70, "Challenge 2020: Training"),
        (97, "Collection for ISBI 2016: 100 Lesion Classification"),
        (162, "Consecutive biopsies for melanoma across year 2020"),
        (75, "Consumer AI apps"),
        (166, "EASY Dermoscopy Expert Agreement Study"),
        (168, "Longitudinal overview images of posterior trunks"),
        (163, "MSKCC Consecutive biopsies across year 2020_cohort"),
        (77, "Melanocytic lesions used for dermoscopic feature annotations"),
        (170, "PROVe-AI"),
        # (172, "screenshot_public_230207"),  # left out of the mapping
    ]
)
def shorten_collection_name(name):
    """Strip every occurrence of the substring "Challenge " from a collection title."""
    pieces = name.split("Challenge ")
    return "".join(pieces)
def shorten_attribution_name(name, very_short=True):
    """Map an attribution string to a short label; unknown names are returned as-is.

    ``very_short`` selects the most compact label set; otherwise slightly longer labels
    are used for the Vienna and the second Barcelona attribution.
    """
    vienna = 'ViDIR Group, Department of Dermatology, Medical University of Vienna'
    barcelona_department = 'Department of Dermatology, Hospital Clínic de Barcelona'

    # Labels that are the same in both modes.
    short_labels = {
        'Hospital Clínic de Barcelona': "Barcelona",
        'Pascale Guitera': "Pascale Guitera",
        'MSKCC': "MSKCC1",
        'Memorial Sloan Kettering Cancer Center': "MSKCC2",
    }
    # Labels that depend on the requested verbosity.
    if very_short:
        short_labels[vienna] = "Vienna"
        short_labels[barcelona_department] = "Barcelona2"
    else:
        short_labels[vienna] = "Med Univ. of Vienna"
        short_labels[barcelona_department] = "Barcelona Hospital2"

    return short_labels.get(name, name)

In [None]:
def make_collection_statistics(dataset_name, metadata_all, valid_idx):    
    """Tabulate "malignant / benign" image counts per attribution and collection.

    Args:
        dataset_name: only "isic_nodup_nooverlap" is supported.
        metadata_all: metadata frame with an "attribution" column, a boolean
            "benign_malignant_bool" column and one "collection_<id>" column per key of
            ``collection_id_mapping``.
        valid_idx: boolean selector applied to ``metadata_all`` before counting.

    Returns:
        ``{"collection_statistics": frame}`` where the frame is indexed by attribution,
        has one column per collection and holds strings of the form "<malignant> / <benign>".

    Raises:
        NotImplementedError: for any other ``dataset_name`` (previously this surfaced as
            an UnboundLocalError at the return statement).

    Side effects:
        ``metadata_all["attribution"]`` is rewritten in place so that "ViDIR group" becomes
        "ViDIR Group". This is kept because other cells of the notebook select rows using
        the "ViDIR Group" spelling on the same frame.
    """
    if dataset_name != "isic_nodup_nooverlap":
        raise NotImplementedError(dataset_name)

    metadata_all["attribution"] = metadata_all["attribution"].str.replace(
        "ViDIR group", "ViDIR Group", regex=False)
    columns_collection = ["collection_" + str(key) for key in collection_id_mapping.keys()]

    metadata_valid = metadata_all[valid_idx]
    is_malignant = metadata_valid["benign_malignant_bool"]

    def _count_per_attribution(rows):
        # Sum the collection indicator columns within every attribution.
        return rows.groupby("attribution").apply(lambda x: x[columns_collection].sum())

    collection_statistics_malignant = _count_per_attribution(metadata_valid[is_malignant])
    collection_statistics_benign = _count_per_attribution(metadata_valid[~is_malignant])

    # Put both tables on the same rows/columns first. Adding strings with a missing row
    # yields NaN, which used to be reported as "0 / 0" even when the other class had images.
    attributions = collection_statistics_malignant.index.union(collection_statistics_benign.index)
    collection_statistics_malignant = collection_statistics_malignant.reindex(
        index=attributions, columns=columns_collection, fill_value=0)
    collection_statistics_benign = collection_statistics_benign.reindex(
        index=attributions, columns=columns_collection, fill_value=0)

    collection_statistics = (
        collection_statistics_malignant.astype(str)
        + " / "
        + collection_statistics_benign.astype(str)
    ).fillna("0 / 0")

    return {"collection_statistics": collection_statistics}

In [None]:
# Compute the per-attribution collection statistics for the ISIC dataset and store the
# result under variable_dict[dataset_name]["collection_statistics"].
# NOTE: `dataset_name` and `x` remain bound at notebook scope after this cell.
for dataset_name in ["isic_nodup_nooverlap"]: 
    x=make_collection_statistics(
        dataset_name=dataset_name,
        metadata_all=variable_dict[dataset_name]["metadata_all"],
        valid_idx=variable_dict[dataset_name]["valid_idx"])
    variable_dict[dataset_name].update(x)
# Optional: rename the "collection_<id>" columns to their titles.
# variable_dict[dataset_name]["collection_statistics"].rename(columns={f"collection_{key}":value for key, value in collection_id_mapping.items()})    

In [None]:
variable_dict.keys()

In [None]:
#variable_dict['isic']["collection_statistics"]

In [None]:
def get_subset_index(dataset_name, metadata_all, attribution):
    """Build a boolean row mask over ``metadata_all`` for one attribution.

    Args:
        dataset_name: only "isic_nodup_nooverlap" is supported.
        metadata_all: metadata frame with "attribution" and "collection_65" columns.
        attribution: "all" selects every row. "barcelona_all" and "mskcc_all" select the
            two spellings of the respective institution. Any other value is compared
            verbatim with the "attribution" column.

    Returns:
        A NumPy boolean array with one entry per row. Except for "all", a row is selected
        only if it also has ``collection_65 == 1``.

    Raises:
        NotImplementedError: for any other ``dataset_name`` (previously this surfaced as
            an UnboundLocalError at the return statement).
    """
    if dataset_name != "isic_nodup_nooverlap":
        raise NotImplementedError(dataset_name)

    if attribution == "all":
        # np.ones keeps the dtype boolean even for an empty frame.
        return np.ones(len(metadata_all), dtype=bool)

    if attribution == "barcelona_all":
        attribution_names = ["Department of Dermatology, Hospital Clínic de Barcelona",
                             "Hospital Clínic de Barcelona"]
    elif attribution == "mskcc_all":
        attribution_names = ["MSKCC", "Memorial Sloan Kettering Cancer Center"]
    else:
        attribution_names = [attribution]

    from_attribution = metadata_all["attribution"].isin(attribution_names).values
    in_collection_65 = (metadata_all["collection_65"] == 1).values
    return from_attribution & in_collection_65

In [None]:
def diff_test(image_features_norm, embedding_dict, idx_pos, idx_neg):
    """Project the mean-feature gap between two image groups onto concept embeddings.

    The mean feature vector of the rows selected by ``idx_neg`` is subtracted from the
    mean of the rows selected by ``idx_pos``. Each concept embedding is averaged over its
    first two axes and dotted with that difference.

    Returns:
        dict mapping every key of ``embedding_dict`` to a Python float score.
    """
    def _mean_feature(selector):
        # Mean over the selected rows, as a NumPy vector.
        return image_features_norm[selector].float().numpy().mean(axis=0)

    feature_gap = _mean_feature(idx_pos) - _mean_feature(idx_neg)

    return {
        name: (feature_gap @ embedding.numpy().mean(axis=(0, 1))).item()
        for name, embedding in embedding_dict.items()
    }

In [None]:
def make_concept_diff_sole(dataset_name, image_features_norm, metadata_all, y_pos, valid_idx, prompt_info, attribution_list):
    """Compute concept scores of positive vs. negative images within each attribution.

    For every entry of ``attribution_list`` the rows are restricted with
    ``get_subset_index``; valid rows with ``y_pos`` true form the positive group and valid
    rows with ``y_pos`` false the negative group. ``diff_test`` then scores every concept
    of ``prompt_info`` using its "prompt_target_embedding_norm". Attributions for which
    either group is empty are reported with a leading "-" and skipped.

    Returns:
        ``{"concept_diff_sole": {attribution: {concept_name: score}}}``.

    Raises:
        NotImplementedError: for a ``dataset_name`` other than "isic_nodup_nooverlap"
            (previously this surfaced as an UnboundLocalError at the return statement).
    """
    if dataset_name != "isic_nodup_nooverlap":
        raise NotImplementedError(dataset_name)

    # The concept embeddings do not depend on the attribution, so build the dict once.
    embedding_dict = {concept_name: concept_info["prompt_target_embedding_norm"]
                      for concept_name, concept_info in prompt_info.items()}

    diff_score_dict = {}
    for attribution in attribution_list:
        subset_idx = get_subset_index(dataset_name=dataset_name, metadata_all=metadata_all, attribution=attribution)

        # Element-wise comparisons are intentional: these are array-like masks.
        idx_pos = (y_pos == True) & (valid_idx == True) & (subset_idx)
        idx_neg = (y_pos == False) & (valid_idx == True) & (subset_idx)

        if idx_pos.sum() == 0 or idx_neg.sum() == 0:
            print(f"- {attribution} pos: {(idx_pos.sum())} neg: {(idx_neg.sum())}")
            continue
        print(f"+ {attribution} pos: {(idx_pos.sum())} neg: {(idx_neg.sum())}")

        diff_score_dict[attribution] = diff_test(image_features_norm=image_features_norm,
                                                 embedding_dict=embedding_dict,
                                                 idx_pos=idx_pos, idx_neg=idx_neg)

    return {"concept_diff_sole": diff_score_dict}

# Score every concept for the whole ISIC set ("all"), for each listed attribution and for
# the two merged institution aliases, then store the result under
# variable_dict[dataset_name]["concept_diff_sole"].
# NOTE: `dataset_name` and `x` remain bound at notebook scope after this cell.
for dataset_name in ["isic_nodup_nooverlap"]: 
    x=make_concept_diff_sole(dataset_name=dataset_name, 
                      image_features_norm=variable_dict[dataset_name]["image_features_norm"],                      
                      metadata_all=variable_dict[dataset_name]["metadata_all"],
                      y_pos=variable_dict[dataset_name]["y_pos"],
                      valid_idx=variable_dict[dataset_name]["valid_idx"],
                      prompt_info=variable_dict[dataset_name]["prompt_info"],   
                      attribution_list=["all"]+['Memorial Sloan Kettering Cancer Center',
                                       'Attributed to Konstantinos Liopyris',
                                       'For educational purpose only',
                                       'Department of Dermatology, Hospital Clínic de Barcelona',
                                       'Pascale Guitera',
                                       'The University of Queensland Diamantina Institute, The University of Queensland, Dermatology Research Centre',
                                       'ViDIR Group, Department of Dermatology, Medical University of Vienna',
                                       'MSKCC', 'Hospital Clínic de Barcelona', 'Anonymous',
                                       'Dermoscopedia']+["barcelona_all", "mskcc_all"])
    variable_dict[dataset_name].update(x)

In [None]:
# Scratch arithmetic; presumably malignant + benign image counts — TODO confirm or remove.
6097+62055

In [None]:
# Scratch arithmetic; the operands match the malignant (1,824) and benign (8,049) counts
# written in a note cell further down — TODO confirm or remove.
1824+8049

In [None]:
# Scratch arithmetic; the meaning of the operands is not stated — TODO confirm or remove.
9990+60525

In [None]:
# Scratch arithmetic; the meaning of the operands is not stated — TODO confirm or remove.
6021+6250

In [None]:
# Scratch arithmetic; the meaning of the operands is not stated — TODO confirm or remove.
1824+8187

In [None]:
# na, malignant: n = 1,824 / benign: n = 8,049

In [None]:
# (malignant: n = 6,097 / benign: n = 6,205)

In [None]:
def make_concept_diff_cross(dataset_name, image_features_norm, metadata_all, y_pos, valid_idx, prompt_info, attribution_pair_list):
    """Compute concept scores between the images of two attributions.

    For every ``(attribution_name_1, attribution_name_2)`` pair, the valid rows of the
    first attribution form the "positive" group and the valid rows of the second the
    "negative" group. ``diff_test`` then scores every concept of ``prompt_info`` using its
    "prompt_target_embedding_norm". Pairs for which either group is empty are reported
    with a leading "-" and skipped.

    Args:
        y_pos: accepted for signature parity with ``make_concept_diff_sole``; it is not
            used here because the groups are defined by attribution, not by label.

    Returns:
        ``{"concept_diff_cross": {(name_1, name_2): {concept_name: score}}}``.
    """
    # The concept embeddings do not depend on the pair, so build the dict once.
    embedding_dict = {concept_name: concept_info["prompt_target_embedding_norm"]
                      for concept_name, concept_info in prompt_info.items()}

    diff_score_dict = {}
    for attribution_name_1, attribution_name_2 in attribution_pair_list:
        subset_idx_1 = get_subset_index(dataset_name=dataset_name, metadata_all=metadata_all, attribution=attribution_name_1)
        subset_idx_2 = get_subset_index(dataset_name=dataset_name, metadata_all=metadata_all, attribution=attribution_name_2)

        # Element-wise comparison is intentional: valid_idx is an array-like mask.
        idx_pos = (valid_idx == True) & (subset_idx_1)
        idx_neg = (valid_idx == True) & (subset_idx_2)

        if idx_pos.sum() == 0 or idx_neg.sum() == 0:
            print(f"- {(attribution_name_1, attribution_name_2)} pos: {(idx_pos.sum())} neg: {(idx_neg.sum())}")
            continue
        print(f"+ {(attribution_name_1, attribution_name_2)} pos: {(idx_pos.sum())} neg: {(idx_neg.sum())}")        

        diff_score_dict[(attribution_name_1, attribution_name_2)] = diff_test(
            image_features_norm=image_features_norm,
            embedding_dict=embedding_dict,
            idx_pos=idx_pos, idx_neg=idx_neg)

    return {"concept_diff_cross": diff_score_dict}

# Compare the Barcelona images (first of the pair) with the Vienna images (second of the
# pair) and store the result under variable_dict[dataset_name]["concept_diff_cross"].
# `y_pos` is passed for signature parity; make_concept_diff_cross does not read it.
# NOTE: `dataset_name` and `x` remain bound at notebook scope after this cell.
for dataset_name in ["isic_nodup_nooverlap"]: 
    x=make_concept_diff_cross(dataset_name=dataset_name, 
                      image_features_norm=variable_dict[dataset_name]["image_features_norm"],                      
                      metadata_all=variable_dict[dataset_name]["metadata_all"],
                      y_pos=variable_dict[dataset_name]["y_pos"],
                      valid_idx=variable_dict[dataset_name]["valid_idx"],
                      prompt_info=variable_dict[dataset_name]["prompt_info"],   
                      attribution_pair_list=[("Hospital Clínic de Barcelona", "ViDIR Group, Department of Dermatology, Medical University of Vienna")],
                     )
    variable_dict[dataset_name].update(x)    

In [None]:
# Per-concept difference between the Barcelona and the Vienna "concept_diff_sole" scores.
# NOTE(review): `dataset_name` is not set in this cell; it is whatever value the last
# executed `for dataset_name in ...` loop left behind, so the result depends on run order.
diff_temp=pd.Series(variable_dict[dataset_name]["concept_diff_sole"]['Hospital Clínic de Barcelona'])\
-pd.Series(variable_dict[dataset_name]["concept_diff_sole"]['ViDIR Group, Department of Dermatology, Medical University of Vienna'])

In [None]:
# mskcc_all: orange sticker (-) red sticker (-)
# Memorial Sloan Kettering Cancer Center (162): purple pen (-) orange sticker (-)
# MSKCC (70): orange sticker (-)
    
# Vienna (65): purple pen (-) orange sticker (-)

# barcelona_all: red sticker(+)
# 65 
# 70

# variable_dict["isic"]["diff_score_collections"]["mskcc_all"].sort_values()
# variable_dict["isic"]["diff_score_collections"]["MSKCC"].sort_values()
# variable_dict["isic"]["diff_score_collections"]["Memorial Sloan Kettering Cancer Center"].sort_values()
# variable_dict["isic"]["diff_score_collections"]["ViDIR Group, Department of Dermatology, Medical University of Vienna"].sort_values()
# variable_dict["isic"]["diff_score_collections"]["barcelona_all"].sort_values()
# variable_dict["isic"]["diff_score_collections"]["Hospital Clínic de Barcelona"].sort_values()
# variable_dict["isic"]["diff_score_collections"]["Department of Dermatology, Hospital Clínic de Barcelona"].sort_values()

In [None]:
# def check_concept_name(dataset_name, concept_name):
#     if dataset_name=="isic":
#         if concept_name.startswith("disease"):
#             return False
#         elif concept_name in ["melanoma", "red sticker"]:
#             return False
#         else:
#             return True
#     else:
#         raise NotImplemented(dataset_name)
        
def check_concept_name(dataset_name, concept_name):
    """Decide whether a concept is kept when reporting results for a dataset.

    For dataset names containing "isic":
      * names starting with "disease" or "isicconcept_" are dropped;
      * "derm7ptconcept_" names are dropped when they contain "typical", "regular" or
        "pigmentation" (substring match, so "atypical"/"irregular" are dropped too);
      * "melanoma" and "red sticker" are dropped;
      * everything else is kept.
    For dataset names containing "clinical_fd_clean", only names starting with "disease"
    are dropped.

    Returns:
        True when the concept is kept, False when it is filtered out.

    Raises:
        NotImplementedError: for any other dataset name. The previous code called the
            ``NotImplemented`` constant, which is not an exception class and produced a
            TypeError instead.
    """
    if "isic" in dataset_name:
        if concept_name.startswith(("disease", "isicconcept_")):
            return False
        if concept_name.startswith("derm7ptconcept_"):
            excluded_keywords = ("typical", "regular", "pigmentation")
            return not any(keyword in concept_name for keyword in excluded_keywords)
        return concept_name not in ["melanoma", "red sticker"]

    if "clinical_fd_clean" in dataset_name:
        return not concept_name.startswith("disease")

    raise NotImplementedError(dataset_name)

# def shorten_concept_name(concept_name):
#     if concept_name=="skincon_Erythema":
#         short_name="Erythema"
#     elif concept_name=="skincon_Bulla":
#         short_name="Bulla"
#     elif concept_name=="skincon_Lichenification":
#         short_name="Lichenification"
#     elif concept_name=="skincon_Pustule":
#         short_name="Pustule"
#     elif concept_name=="skincon_Ulcer":
#         short_name="Ulcer"
#     elif concept_name=="skincon_Warty/Papillomatous":
#         short_name="Warty"
#     elif concept_name=="skincon_White(Hypopigmentation)":
#         short_name="Hypopigmentation"
#     elif concept_name=="skincon_Brown(Hyperpigmentation)":
#         short_name="Hyperpigmentation"
#     elif concept_name=="purple pen":
#         short_name="Purple pen"
#     elif concept_name=="nail":
#         short_name="Nail"  
#     elif concept_name=="orange sticker":
#         short_name="Orange sticker"          
#     elif concept_name=="hair":
#         short_name="Hair"          
#     elif concept_name=="gel":
#         short_name="Gel"
#     elif concept_name=="red":
#         short_name="Red"      
#     elif concept_name=="dermoscope border":
#         short_name="Dermoscope border"
#     elif concept_name=="pinkish":
#         short_name="Pinkish"        
#     else:
#         if concept_name.startswith("skincon_"):
#             short_name=concept_name[8:]
#         else:
#             raise NotImplementedError(concept_name)
            
#     return short_name

def shorten_concept_name(concept_name):
    """Return the display label of a concept identifier.

    Names in the explicit table map to a hand-written label. Any other name starting with
    "skincon_" has that prefix removed (this covers skincon_Erythema, skincon_Bulla, ...).

    Raises:
        NotImplementedError: when the name is neither in the table nor a "skincon_" name.

    NOTE: this definition replaces the ``shorten_concept_name`` defined in an earlier cell.
    """
    explicit_names = {
        "skincon_Warty/Papillomatous": "Warty",
        "skincon_White(Hypopigmentation)": "Hypopigmentation",
        "skincon_Brown(Hyperpigmentation)": "Hyperpigmentation",
        "purple pen": "Purple pen",
        "nail": "Nail",
        "orange sticker": "Orange sticker",
        "hair": "Hair",
        "gel": "Gel",
        # A second, unreachable "pinkish" -> "Pinkish" branch was removed; "Pink" is the
        # label the function has always returned.
        "pinkish": "Pink",
        "red": "Red",
        "dermoscope border": "Dermoscope border",
        "derm7ptconcept_blue whitish veil": "Blue whitish veil",
        "derm7ptconcept_pigment network": "Pigment network",
        "derm7ptconcept_vascular structures": "Vascular structures",
        "derm7ptconcept_regression structure": "Regression structure",
    }
    if concept_name in explicit_names:
        return explicit_names[concept_name]

    if concept_name.startswith("skincon_"):
        return concept_name[len("skincon_"):]

    raise NotImplementedError(concept_name)

def shorten_hospital_name(hospital_name):
    """Return the display name of an institution.

    The two known attributions map to fixed labels. Any other name is returned unchanged,
    where the previous code failed with an UnboundLocalError.
    """
    short_names = {
        "Hospital Clínic de Barcelona": "Hospital Clínic de Barcelona",
        "ViDIR Group, Department of Dermatology, Medical University of Vienna": "Medical University of Vienna",
    }
    return short_names.get(hospital_name, hospital_name)

In [None]:
# Concept scores of malignant vs. benign images over the whole ISIC set.
concept_diff_all=pd.Series(variable_dict["isic_nodup_nooverlap"]["concept_diff_sole"]["all"])
# Keep only the concepts that check_concept_name accepts for this dataset.
concept_diff_all_filtered=concept_diff_all[concept_diff_all.index.map(lambda x: check_concept_name("isic_nodup_nooverlap", x))]
# Drop the five concepts with the smallest absolute score.
concept_diff_all_filtered=concept_diff_all_filtered.loc[concept_diff_all_filtered.abs().sort_values(ascending=False).iloc[:-5].index]
# Five highest followed by five lowest scores (the two slices overlap if fewer than ten remain).
concept_diff_all_filtered_top=pd.concat([concept_diff_all_filtered.sort_values(ascending=False).iloc[:5], 
                                     concept_diff_all_filtered.sort_values(ascending=False).iloc[-5:]]) #concept_diff_isic_all_top
# Replace the concept identifiers with their display labels.
concept_diff_all_filtered_top.index=concept_diff_all_filtered_top.index.map(shorten_concept_name)
concept_diff_all_filtered_top

In [None]:
concept_diff_all.index

In [None]:
# The two institutions compared in the following cell; hospital_1 becomes the first
# column of concept_diff_institution and hospital_2 the second.
hospital_1="ViDIR Group, Department of Dermatology, Medical University of Vienna"
hospital_2="Hospital Clínic de Barcelona"

In [None]:
# Alternative ordering of the two institutions (disabled):
# hospital_1="Hospital Clínic de Barcelona"
# hospital_2="ViDIR Group, Department of Dermatology, Medical University of Vienna"

# One row per concept, one column per institution, holding the "concept_diff_sole" scores.
concept_diff_institution=pd.concat([
    pd.Series(variable_dict["isic_nodup_nooverlap"]["concept_diff_sole"][hospital_1]).rename(hospital_1),
    pd.Series(variable_dict["isic_nodup_nooverlap"]["concept_diff_sole"][hospital_2]).rename(hospital_2)
], axis=1)

# Split the concepts by the sign of their score at (hospital_1, hospital_2).
# The comparisons are strict, so a score of exactly 0 lands in none of the four groups.
concept_diff_institution_pos_pos=concept_diff_institution[(concept_diff_institution[hospital_1]>0)&(concept_diff_institution[hospital_2]>0)]
concept_diff_institution_pos_neg=concept_diff_institution[(concept_diff_institution[hospital_1]>0)&(concept_diff_institution[hospital_2]<0)]
concept_diff_institution_neg_pos=concept_diff_institution[(concept_diff_institution[hospital_1]<0)&(concept_diff_institution[hospital_2]>0)]
concept_diff_institution_neg_neg=concept_diff_institution[(concept_diff_institution[hospital_1]<0)&(concept_diff_institution[hospital_2]<0)]
# Disabled earlier variant: keep sign-flipping concepts and rank them by absolute gap.
# concept_diff_institution_top=concept_diff_institution.copy()
# concept_diff_institution_top=concept_diff_institution_top[(concept_diff_institution_top["Hospital Clínic de Barcelona"]*concept_diff_institution_top["Medical University of Vienna"])<0]
# concept_diff_institution_top["diff_abs"]=(concept_diff_institution_top["Hospital Clínic de Barcelona"]-concept_diff_institution_top["Medical University of Vienna"]).abs()
# concept_diff_institution_top=concept_diff_institution_top.sort_values("diff_abs")
# concept_diff_institution_top

In [None]:
concept_diff_institution

In [None]:
# List the attributions for which "concept_diff_sole" scores are stored.
# NOTE(review): the cells above store these scores under "isic_nodup_nooverlap";
# presumably variable_dict also has an "isic" key — verify, otherwise this is a KeyError.
variable_dict["isic"]["concept_diff_sole"].keys()

In [None]:
# Flag, per row, whether the concept passes check_concept_name for `dataset_name`.
# NOTE(review): in the visible cells `concept_diff_institution_top` is only assigned in
# commented-out code, and no "concept_name" column is created — this cell presumably
# fails on a fresh kernel; confirm whether it is still needed.
concept_diff_institution_top["concept_name"].map(lambda x: check_concept_name(dataset_name=dataset_name, concept_name=x))





In [None]:
!ls logs/experiment_results/

In [None]:
# Restore variable_dict from the concept_annotation_data_auditing_0930.pt file, mapping
# tensors onto the CPU.
# NOTE(review): this rebinds variable_dict, so results computed into it by earlier cells
# are replaced by the contents of this file.
variable_dict=torch.load("logs/experiment_results/concept_annotation_data_auditing_0930.pt", map_location="cpu")

In [None]:
from matplotlib.patches import Patch

In [None]:
# Relative heights of the three stacked stages ("overview", "institution",
# "investigation"); shared by the figure size and the outer GridSpec below so
# the two stay in sync.
stage_height_ratios = [2, 7, 4]

# Width 3*10; height is 3 per height-ratio unit plus 0.3*2 of extra room
# (presumably for the two gaps between the stages — TODO confirm).
fig = plt.figure(figsize=(3*10, 3*sum(stage_height_ratios) + 0.3*2))

example_per_concept = 10                 # example images per (attribution, concept) pair
dataset_name = "isic_nodup_nooverlap"    # key into variable_dict
normalize = True                         # forwarded to calculate_similaity_score

# The two institutions compared in the figure.
hospital_1 = "ViDIR Group, Department of Dermatology, Medical University of Vienna"
hospital_2 = "Hospital Clínic de Barcelona"

# (attribution_name, concept_name) pairs for which example images are drawn.
attribution_concept_list = [(hospital_1, "red"),
                            (hospital_2, "red")]

# RGB (0-255) colours of the malignant / benign markers. The variable names are
# historical; the previously used values are kept in the trailing comments.
red_color = np.array([212, 17, 89])      # np.array((222,40,40))
green_color = np.array([26, 133, 255])   # np.array((40,200,40))

# RGB colours for hospital_1 / hospital_2. Pairs tried before:
#   [31, 120, 180], [51, 160, 44]
#   [90, 0, 220],   [51, 160, 44]
two_color = [np.array([60, 50, 180]), np.array([51, 160, 44])]

# Make the two hospital colours the default colour cycle. This modifies the
# global rcParams, so it also affects anything plotted after this cell.
plt.rcParams["axes.prop_cycle"] = cycler('color', [np.array(i)/256 for i in two_color])

# Outer layout: one row per stage.
box1 = gridspec.GridSpec(3, 1,
                         height_ratios=stage_height_ratios,
                         wspace=0.0,
                         hspace=0.3)

# Create every Axes of the figure up front and register it in `axd` under a
# string key; the drawing code further down looks the Axes up by these keys.
# The loop variable names are deliberately left unchanged: in a notebook they
# remain in the global namespace after the loops and later code may read them.
axd = {}
for idx1, stage in enumerate(["overview", "institution", "investigation"]):
    if stage == "overview":
        # A single Axes filling the first row.
        plot_key = stage
        ax = plt.Subplot(fig, box1[idx1])
        fig.add_subplot(ax)
        axd[plot_key] = ax

    elif stage == "institution":
        # Second row: the "all" and "between" panels, each preceded by a
        # narrower "empty_*" column.
        box2 = gridspec.GridSpecFromSubplotSpec(1, 4,
                                                subplot_spec=box1[idx1],
                                                wspace=0., hspace=0,
                                                width_ratios=[0.3, 1, 0.4, 1])

        for idx2, instituion_type in enumerate(["empty_all", "all", "empty_between", "between"]):
            plot_key = f"{stage}_{instituion_type}"
            ax = plt.Subplot(fig, box2[idx2])
            fig.add_subplot(ax)
            axd[plot_key] = ax

    elif stage == "investigation":
        # Third row: example-image grids in the first column, statistics in the second.
        box2 = gridspec.GridSpecFromSubplotSpec(1, 2,
                                                subplot_spec=box1[idx1],
                                                wspace=0.2, hspace=0.)

        for idx2, investigation_type in enumerate(["example", "statistics"]):
            if investigation_type == "example":
                # One image grid per (attribution, concept) pair, stacked vertically.
                box3 = gridspec.GridSpecFromSubplotSpec(len(attribution_concept_list), 1,
                                                        subplot_spec=box2[idx2],
                                                        wspace=0.0, hspace=0.15)

                # Five images per row. Ceil division guarantees enough grid cells
                # for `example_per_concept` images even when it is not a multiple
                # of the column count (floor division would make box4[rank_num]
                # run out of cells).
                example_cols = 5
                example_rows = -(-example_per_concept // example_cols)

                for idx3, (attribution_name, concept_name) in enumerate(attribution_concept_list):
                    box4 = gridspec.GridSpecFromSubplotSpec(example_rows, example_cols,
                                                            subplot_spec=box3[idx3],
                                                            wspace=0, hspace=0.1)
                    for rank_num in range(example_per_concept):
                        plot_key = f"{stage}_{attribution_name}_{concept_name}_{rank_num}"
                        ax = plt.Subplot(fig, box4[rank_num])
                        fig.add_subplot(ax)
                        axd[plot_key] = ax

            elif investigation_type == "statistics":
                # A "main" panel plus an "empty" Axes; the 999:1 height ratio
                # leaves the latter almost no height.
                box3 = gridspec.GridSpecFromSubplotSpec(2, 1,
                                                        subplot_spec=box2[idx2],
                                                        wspace=0.0, hspace=0.25,
                                                        height_ratios=[999, 1])
                for idx3, statistics_split in enumerate(["main", "empty"]):
                    plot_key = f"{stage}_{investigation_type}_{statistics_split}"
                    ax = plt.Subplot(fig, box3[idx3])
                    fig.add_subplot(ax)
                    axd[plot_key] = ax
    
  
    
# Axes whose key contains 'overview' or 'empty' carry no data axes of their
# own: remove their ticks and make the frame invisible (zero line width).
for plot_key in axd:
    if not ('overview' in plot_key or 'empty' in plot_key):
        continue

    axd[plot_key].set_xticks([])
    axd[plot_key].set_yticks([])
    for axis in ('top', 'bottom', 'left', 'right'):
        axd[plot_key].spines[axis].set_linewidth(0)
            
            
for idx1, stage in enumerate(["overview", "institution", "investigation"]):
    if stage=="overview":
        plot_key=stage
        axd[plot_key].text(x=-0.01, y=1.0, transform=axd[plot_key].transAxes,
                             s="A", fontsize=35, weight='bold')            
        axd[plot_key].text(x=0.5, y=1.0, transform=axd[plot_key].transAxes,
                             s="B", fontsize=35, weight='bold')                    
        
    elif stage=="institution":
        box2 = gridspec.GridSpecFromSubplotSpec(1, 2,
                        subplot_spec=box1[idx1], wspace=0.0, hspace=0.1)
        
        for idx2, instituion_type in enumerate(["empty_all", "all", "empty_between", "between"]):
            plot_key=f"{stage}_{instituion_type}"
            
            if instituion_type=="all":
                concept_diff_all_top=pd.concat([concept_diff_all.sort_values(ascending=False).iloc[:20],concept_diff_all.sort_values(ascending=False).iloc[-14:]])
                concept_diff_all_top=concept_diff_all_top.reset_index(name="diff")
                concept_diff_all_top=concept_diff_all_top.rename(columns={"index": "concept_name"})
                
                concept_diff_all_top=concept_diff_all_top[concept_diff_all_top["concept_name"].map(lambda x: check_concept_name(dataset_name=dataset_name, concept_name=x))]
                concept_diff_all_top["concept_name"]=concept_diff_all_top["concept_name"].map(shorten_concept_name)
                
                sns.barplot(y="concept_name", x="diff", 
                            data=concept_diff_all_top, 
                            edgecolor='black',
                            linewidth=2,
                            width=0.35,
                            color=np.array(Paired[12][7])/256,
                            #color=(0.8, 0.8, 0.8, 0),
                            #palette="vlag",
                            ax=axd[plot_key])
                
                axd[plot_key].set_title("All institutions", pad=20, fontsize=35)
                axd[plot_key].set_xlim(-0.022, 0.022)
                
                
                #3################################################################################
#                 axd[plot_key+"_inner"] = axd[plot_key].inset_axes(
#                                     #loc='lower center',
#                                     #bbox_to_anchor=(0.5,-1,0,0),
#                                     bounds=(0.25, -0.12, 0.5, 0.02),
#                                     transform=axd[plot_key].transAxes)
                
# #                 axd[plot_key+"_inner"] = inset_axes(axd[plot_key],
# #                                                     width="100%",  # width = 5% of parent_bbox width
# #                                                     height="10%",
# #                                                     #loc='lower center',
# #                                                     #bbox_to_anchor=(0.5,-1,0,0),
# #                                                     bounds=(0,0,0.8,0.1),
# #                                                     bbox_transform=axd[plot_key].transAxes)

#                 cbar=fig.colorbar(mpl.cm.ScalarMappable(norm=mpl.colors.Normalize(vmin=-1, vmax=1), cmap=sns.color_palette("icefire", as_cmap=True)),
#                                   cax=axd[plot_key+"_inner"],
#                                   format=ticker.FuncFormatter(lambda x, pos: f'{int(x):2d}'),
#                                   orientation='horizontal')
#                 cbar.outline.set_linewidth(0.3)
#                 cbar.set_ticks([-2, -1, 0, 1, 2])
                
                
            elif instituion_type=="between":
                
                # Rank the concepts by the absolute gap between their values at the two
                # hospitals and keep the 26 with the largest gap. The helper column is
                # added with `.assign`, which works on a copy, so the shared
                # `concept_diff_institution` frame is left unmodified and re-running
                # this cell is idempotent.
                concept_diff_institution_top=(
                    concept_diff_institution
                    .assign(diff_abs=(concept_diff_institution[hospital_1]-concept_diff_institution[hospital_2]).abs())
                    .sort_values("diff_abs", ascending=False)
                    .iloc[:26]
                    .drop(columns=["diff_abs"])
                )
                #concept_diff_institution_top.plot.barh(ax=axd[plot_key])
                
                concept_diff_institution_top_1=concept_diff_institution_top[hospital_1].reset_index(name="diff")
                concept_diff_institution_top_1["source"]=hospital_1
                concept_diff_institution_top_2=concept_diff_institution_top[hospital_2].reset_index(name="diff")
                concept_diff_institution_top_2["source"]=hospital_2
                concept_diff_institution_top=pd.concat([concept_diff_institution_top_1, concept_diff_institution_top_2])                
                concept_diff_institution_top=concept_diff_institution_top.rename(columns={"index": "concept_name"})
                
                concept_diff_institution_top=concept_diff_institution_top[concept_diff_institution_top["concept_name"].map(lambda x: check_concept_name(dataset_name=dataset_name, concept_name=x))]
                concept_diff_institution_top["concept_name"]=concept_diff_institution_top["concept_name"].map(shorten_concept_name)
                
                sns.barplot(y="concept_name", x="diff", 
                            hue="source",
                            data=concept_diff_institution_top, 
                            hue_order=[hospital_1, hospital_2],
                            edgecolor='black',
                            linewidth=2,
                            width=0.5,
                            ax=axd[plot_key])                
                
                
                axd[plot_key].set_title("Per institution", pad=20, fontsize=35)
                axd[plot_key].set_xlim(-0.029, 0.029)
                
                legend_elements=[Patch(facecolor=two_color[0]/256, 
                                       edgecolor="black", linewidth=2, label=shorten_hospital_name(hospital_1)),
                                 Patch(facecolor=two_color[1]/256, 
                                       edgecolor="black", linewidth=2, label=shorten_hospital_name(hospital_2))]
#                 legend_elements = [Line2D([0], [0], color=np.array(Paired[12][1])/256, linestyle=['-'][0], linewidth=8, label=shorten_hospital_name(hospital_1)),
#                                    Line2D([0], [0], color=np.array(Paired[12][3])/256, linestyle=['-'][0], linewidth=8, label=shorten_hospital_name(hospital_2)),
#                                   ]
                axd[plot_key].legend(handles=legend_elements, 
                            ncol=1, 
                            handlelength=3,
                            handletextpad=0.6, 
                            columnspacing=1.5,
                            fontsize=23,
                            loc='lower center', bbox_to_anchor=(0.5, -0.15))                   
                
                
            if instituion_type=="all" or instituion_type=="between":
                axd[plot_key].axvline(x=0, ymin=0, ymax=1, color='black', alpha=0.7, linewidth=5, zorder=-5)

                for axis in ['top','bottom','left','right']:
                    axd[plot_key].spines[axis].set_linewidth(3)                

                axd[plot_key].set_xlabel('Expression difference', fontsize=25)

                axd[plot_key].spines['right'].set_visible(False)
                axd[plot_key].spines['left'].set_visible(False)
                axd[plot_key].spines['top'].set_visible(False)   

                axd[plot_key].xaxis.set_major_locator(MultipleLocator(0.010))
                axd[plot_key].xaxis.set_minor_locator(MultipleLocator(0.005))
                axd[plot_key].xaxis.grid(True, which='major', linewidth=1, alpha=0.6)
                axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.2)

                axd[plot_key].yaxis.grid(True, which='major', linewidth=0.5, alpha=0.6)

                axd[plot_key].tick_params(axis='x', which='major', labelsize=25)
                #axd[plot_key].tick_params(axis='y', which='major', labelsize=25, pad=-100)
                axd[plot_key].tick_params(axis='y', which='major', labelsize=25)
                axd[plot_key].set_ylabel(None)
                #axd[plot_key].yaxis.grid(True, which='minor', linewidth=0.5, alpha=0.2)            
            
            
            if instituion_type=="empty_all":
                pass
#                 axd[plot_key].text(x=-0.1, y=1.0, transform=axd[plot_key].transAxes,
#                                      s="A", fontsize=35, weight='bold')              
                
                #axd[plot_key].set_ylabel('Concepts', fontsize=30)
                
                
            if instituion_type=="empty_between":
                pass
#                 axd[plot_key].text(x=-0.0, y=1.0, transform=axd[plot_key].transAxes,
#                                      s="B", fontsize=35, weight='bold')  
        
        
    elif stage=="investigation":       
        for idx2, investigation_type in enumerate(["example", "statistics"]):
            
            if investigation_type=="example":
                for idx3, (attribution_name, concept_name) in enumerate(attribution_concept_list):            

                    subset_idx=get_subset_index(dataset_name=dataset_name, metadata_all=variable_dict[dataset_name]["metadata_all"], attribution=attribution_name)
#                     subset_idx=get_subset_index(dataset_name=dataset_name, metadata_all=variable_dict[dataset_name]["metadata_all"], attribution="all")
                    valid_idx=variable_dict[dataset_name]["valid_idx"]
                    subset_idx_select=(valid_idx==True)&(subset_idx)
                    #idx_neg=(y_pos==False)&(valid_idx==True)&(subset_idx)            

                    print(subset_idx)
                    similaity_score=calculate_similaity_score(
                                    image_features_norm=variable_dict[dataset_name]["image_features_norm"],
                                    prompt_target_embedding_norm=variable_dict[dataset_name]["prompt_info"][concept_name]["prompt_target_embedding_norm"],
                                    prompt_ref_embedding_norm=variable_dict[dataset_name]["prompt_info"][concept_name]["prompt_ref_embedding_norm"],
                                    normalize=normalize)

                    # Order the selected images by similarity score, highest first; the index
                    # values of the sorted Series are used as image indices below.
                    image_idx_list=pd.Series(similaity_score)[subset_idx_select
                                                             ].sort_values(ascending=False).index.tolist()
                    # Keep the 100 top-ranked images and shuffle them with a fixed seed so the
                    # examples shown are reproducible.
                    image_idx_list=np.random.default_rng(10).permutation(image_idx_list[:100])
                    # Second shuffle with an identically seeded generator; kept so the
                    # examples that end up in the figure stay exactly the same.
                    image_idx_list=np.random.default_rng(10).permutation(image_idx_list)

                    # `plot_key` is assigned inside the loop below, once `rank_num` is known
                    # (assigning it here would read `rank_num` before it is initialised for
                    # this attribution).
                    count=0      # position in image_idx_list, including skipped images
                    rank_num=0   # number of images drawn so far
                    while rank_num<example_per_concept:
                        if check_image(dataset_name, image_idx_list[count]):
                            pass
                        else:
                            count+=1
                            continue

                        plot_key=f"{stage}_{attribution_name}_{concept_name}_{rank_num}"                

                        item=variable_dict[dataset_name]["dataloader"].dataset.getitem(image_idx_list[count])
                        #image=
                        axd[plot_key].imshow(item["image"].resize((300, 300)))

                        if variable_dict["isic_nodup_nooverlap"]["metadata_all"].loc[item["metadata"].name]["benign_malignant_bool"]==True:
        #                     axd[plot_key].scatter(x=[0.9], y=[0.9], s=400, 
        #                                linewidths=2,
        #                                edgecolor=np.array((0,0,0))/256,
        #                                color=np.array((222,40,40))/256,
        #                                marker="o",
        #                                transform=axd[plot_key].transAxes)  

        #                     axd[plot_key].scatter(x=[0.9], y=[0.9], s=600, 
        #                                linewidths=1.5,
        #                                edgecolor=np.array((1,1,1))/256,
        #                                color=np.array((1,1,1,0))/256,
        #                                marker="o",
        #                                transform=axd[plot_key].transAxes) 
        #                     axd[plot_key].scatter(x=[0.9], y=[0.9], s=300, 
        #                                linewidths=5,
        #                                edgecolor=np.array((222,40,40))/256,
        #                                color=np.array((1,1,1,0))/256,
        #                                marker="o",
        #                                transform=axd[plot_key].transAxes) 

        #                     axd[plot_key].scatter(x=[0.9], y=[0.9], s=100, 
        #                                linewidths=1,
        #                                edgecolor=np.array((1,1,1))/256,
        #                                color=np.array((1,1,1,0))/256,
        #                                marker="o",
        #                                transform=axd[plot_key].transAxes)  

        #                     axd[plot_key].scatter(x=[0.9], y=[0.9], s=1800, 
        #                                linewidths=1,
        #                                edgecolor=np.array((0,0,0))/256,
        #                                color=np.array((222,40,40))/256,
        #                                marker=".",
        #                                transform=axd[plot_key].transAxes) 

                            axd[plot_key].scatter(x=[0.92], y=[0.92], s=400, 
                                       linewidths=1.5,
                                       edgecolor=np.array((0,0,0, 120))/256,
                                       #color=np.array((222,40,40))/256,
                                        color=red_color/256,
                                       marker="s",
                                       transform=axd[plot_key].transAxes)     




                        elif variable_dict["isic_nodup_nooverlap"]["metadata_all"].loc[item["metadata"].name]["benign_malignant_bool"]==False:
        #                     axd[plot_key].scatter(x=[0.9], y=[0.9], s=600, 
        #                                linewidths=1,
        #                                 edgecolor=np.array((0,0,0))/256,
        #                                 color=np.array((40,200,40))/256,
        #                                marker="X",
        #                                transform=axd[plot_key].transAxes)

                            axd[plot_key].scatter(x=[0.92], y=[0.92], s=400, 
                                       linewidths=1.5,
                                       edgecolor=np.array((0,0,0, 120))/256,
                                       #color=np.array((40,200,40))/256,
                                       color=green_color/256,
                                       marker="s",
                                       transform=axd[plot_key].transAxes)                 

                        else:
                            raise NotImplementedError

                        #.scatter(x=[1,1], y=[1, 1], c='r', s=40)
                        #print(axd[plot_key].get_legend())
                        axd[plot_key].set_xticks([])
                        axd[plot_key].set_yticks([])        

                        if rank_num==2:   
                            #axd[plot_key].set_ylabel(shorten_attribution_name(attribution_name, very_short=True), fontsize=30, zorder=-10)
                            axd[plot_key].set_title(shorten_hospital_name(attribution_name), fontsize=30, zorder=-10)
                            #shorten_hospital_name(hospital_1)
                        if rank_num==0 and idx3==0:
                            axd[plot_key].text(x=-0.33, y=1.05, transform=axd[plot_key].transAxes,
                                                 s="C", fontsize=35, weight='bold')



        #                     legend_elements=[Line2D([0], [0], marker='x', color=np.array((222,40,40))/256,), 
        #                                      Line2D([0], [0], marker='x', color=np.array((40,200,40))/256),]


                            # Legend handles for the square diagnosis markers drawn on the
                            # example images (red_color / green_color).
                            legend_elements = [Line2D([0], [0],
                                                      marker='s', #marker='o'
                                                      color=(1,1,1,1),
                                                      markerfacecolor=red_color/256,
                                                      markeredgecolor=np.array((0,0,0))/256,
                                                      markersize=30,
                                                      label="Malignant"),
                                               Line2D([0], [0],
                                                      marker='s', #marker='o'
                                                      color=(1,1,1,1),
                                                      markerfacecolor=green_color/256,
                                                      markeredgecolor=np.array((0,0,0))/256,
                                                      markersize=30, label="Benign"),]

                            axd[plot_key].legend(handles=legend_elements, 
                                                ncol=2, 
                                                handlelength=3,
                                                handletextpad=-0.1, 
                                                columnspacing=1.5,
                                                fontsize=23,
                                                loc='lower center', 
                                                bbox_to_anchor=(3., -4))  
#                             axd[plot_key].legend(handles=legend_elements, 
#                                                 ncol=1, 
#                                                 handlelength=3,
#                                                 handletextpad=-0.1, 
#                                                 columnspacing=1.5,
#                                                 fontsize=23,
#                                                 loc='lower center', 
#                                                 bbox_to_anchor=(3., -4))  

                        for axis in ['top','bottom','left','right']:
                            axd[plot_key].spines[axis].set_linewidth(1)      

                        rank_num+=1
                        count+=1

#                     if idx3==0 and rank_num==0:
#                         axd[plot_key].text(x=-0.0, y=1.0, transform=axd[plot_key].transAxes,
#                                          s="D", fontsize=35, weight='bold')  
                        
            if investigation_type=="statistics":           
                for idx3, statistics_split in enumerate(["main", "empty"]):
                    if statistics_split=="empty":
                        continue
                    plot_key=f"{stage}_{investigation_type}_{statistics_split}"

                    # Similarity score for the prompt embeddings of `concept_name`
                    # (calculate_similaity_score is defined elsewhere; presumably one score
                    # per image — TODO confirm).
                    # NOTE(review): `concept_name` is not assigned in this branch before this
                    # call; it is whatever an earlier loop over attribution_concept_list left
                    # behind, and the per-attribution loop below reuses this single score for
                    # every entry. That is only equivalent while all entries share the same
                    # concept — confirm this is intended.
                    similaity_score=calculate_similaity_score(
                                    image_features_norm=variable_dict[dataset_name]["image_features_norm"],
                                    prompt_target_embedding_norm=variable_dict[dataset_name]["prompt_info"][concept_name]["prompt_target_embedding_norm"],
                                    prompt_ref_embedding_norm=variable_dict[dataset_name]["prompt_info"][concept_name]["prompt_ref_embedding_norm"],
                                    normalize=normalize)

                    #axd[plot_key].plot(np.arange(0,1,0.01),np.arange(0,1,0.01),c=(0.8,0.8,0.8,0.8),linewidth=4.0,linestyle='--')

                    top_record_list=[]
                    for idx4, (attribution_name, concept_name) in enumerate(attribution_concept_list):
                        subset_idx=get_subset_index(dataset_name=dataset_name, metadata_all=variable_dict[dataset_name]["metadata_all"], attribution=attribution_name)
                        valid_idx=variable_dict[dataset_name]["valid_idx"]
                        subset_idx_select=(valid_idx==True)&(subset_idx)



                        metadata_all_statistics_select=variable_dict[dataset_name]["metadata_all"][subset_idx_select]
                        similaity_select=similaity_score[subset_idx_select]

    #                     fpr, tpr, thresholds = sklearn.metrics.roc_curve(y_true=metadata_all_statistics_select["benign_malignant_bool"], y_score=similaity_select, drop_intermediate=False)
    #                     print(len(metadata_all_statistics_select),len(tpr))
    #                     axd[plot_key].plot(fpr, tpr, 
    #                                        color=[np.array(Paired[12][1])/256, np.array(Paired[12][3])/256][idx3], linewidth=4)


                        # Precision-recall curve of the similarity score against the
                        # `benign_malignant_bool` label within this attribution's subset.
                        # Both arrays are passed positionally because the keyword of the score
                        # argument differs between scikit-learn releases (`probas_pred` in
                        # older ones, `y_score` in newer ones).
                        precision, recall, thresholds = sklearn.metrics.precision_recall_curve(
                            metadata_all_statistics_select["benign_malignant_bool"],
                            similaity_select)

                        assert (len(precision)-1)==(len(recall)-1)==len(thresholds), f"{len(precision)}, {len(recall)}, {len(thresholds)}"
                        # Drop the last precision/recall pair, which has no matching threshold,
                        # so that all three arrays have the same length.
                        precision, recall=precision[:-1], recall[:-1]

                        print(attribution_name, )
                        print(precision)
                        print(recall)
                        print(thresholds)

                        # Line colour per attribution: the first entry uses a blue that differs
                        # slightly from two_color[0]; the second entry uses two_color[1].
                        curve_color=[np.array([ 40,  30, 180])/256,
                                     two_color[1]/256][idx4]
                        axd[plot_key].plot(recall, precision, color=curve_color, linewidth=4)

                        # Highlight the curve points 100, 500 and 1000 entries from the end of
                        # the arrays; the marker style encodes `top_n`.
                        sns.scatterplot(x="recall", y="precision",
                                        style="top_n",
                                        color="red",
                                        s=500,
                                        alpha=0.9,
                                        zorder=10,
                                        edgecolor='black',
                                        data=pd.DataFrame([{"precision": precision[-top_n], "recall": recall[-top_n], "top_n": top_n} for top_n in [100, 500, 1000]]),
                                        ax=axd[plot_key])


    #                     fpr, tpr, thresholds = sklearn.metrics.precision_recall_curve(y_true=metadata_all_statistics_select["benign_malignant_bool"], probas_pred=similaity_select)
    #                     axd[plot_key].plot(fpr, tpr, 
    #                                        color=[np.array(Paired[12][1])/256, np.array(Paired[12][3])/256][idx3], linewidth=4)                    

                        #print(len(metadata_all_statistics_select),len(tpr))

    #                     sns.scatterplot(x="fpr", y="tpr", 
    #                                     s=400,
    #                                 data=pd.DataFrame([{"fpr": fpr[top_n], "tpr": tpr[top_n], "top_n": top_n} for top_n in [100, 500, 1000]]),
    #                                    ax=axd[plot_key])



                        print(attribution_name, metadata_all_statistics_select["benign_malignant_bool"].sum()/len(metadata_all_statistics_select["benign_malignant_bool"]))


                        # Share of `benign_malignant_bool == True` among the `num_top`
                        # highest-scoring images of this subset, for several cut-offs.
                        # np.argsort yields positions, so the labels are picked with `.iloc`
                        # (plain `[]` on a Series is label-based and would only fall back to
                        # positions for non-integer indexes).
                        top_order=np.argsort(similaity_select)[::-1]
                        num_top_list=[]
                        num_true_list=[]
                        for num_top in [50, 100, 200, 500, 1000, 2000]:
                            proportion=metadata_all_statistics_select["benign_malignant_bool"].iloc[top_order[:num_top]].sum()/num_top
                            num_top_list.append(num_top)
                            num_true_list.append(proportion)
                            top_record_list.append({"num_top": num_top,
                                                    "proportion": proportion,
                                                    "attribution_name": attribution_name
                                                   })
                    #b=sns.barplot(y="proportion", x="num_top", hue="attribution_name", data=pd.DataFrame(top_record_list), ax=axd[plot_key])
                    #b.legend_.remove()
                    #axd[plot_key].plot(num_top_list, num_true_list)

                    legend_elements = [Line2D([0], [0], marker='.', color=(1,1,1,1), 
                                              markerfacecolor="red", 
                                              markeredgecolor=np.array((0,0,0))/256, 
                                              markersize=40, 
                                              label="Top 100"),
                                       Line2D([0], [0], marker='X', color=(1,1,1,1), 
                                              markerfacecolor="red", 
                                              markeredgecolor=np.array((0,0,0))/256, 
                                              markersize=20, label="Top 500"),
                                       Line2D([0], [0], marker='s', color=(1,1,1,1), 
                                              markerfacecolor="red", 
                                              markeredgecolor=np.array((0,0,0))/256, 
                                              markersize=15, label="Top 1000"),                                   

                                      ]        

                    leg=axd[plot_key].legend(handles=legend_elements, 
                                        ncol=1, 
                                        handlelength=3,
                                        handletextpad=-0.1, 
                                        columnspacing=1.5,
                                        fontsize=23,
                                        loc='lower center', 
                                        #bbox_to_anchor=(0.85, 0.75)
                                        bbox_to_anchor=(0.85, 0.73),
                                        facecolor='white', framealpha=0.5
                                            )     

                    axd[plot_key].add_artist(leg)



                    legend_elements=[Patch(facecolor=two_color[0]/256, 
                                           edgecolor="black", linewidth=2, label=shorten_hospital_name(hospital_1)),
                                     Patch(facecolor=two_color[1]/256, 
                                           edgecolor="black", linewidth=2, label=shorten_hospital_name(hospital_2))]
    #                 legend_elements = [Line2D([0], [0], color=np.array(Paired[12][1])/256, linestyle=['-'][0], linewidth=8, label=shorten_hospital_name(hospital_1)),
    #                                    Line2D([0], [0], color=np.array(Paired[12][3])/256, linestyle=['-'][0], linewidth=8, label=shorten_hospital_name(hospital_2)),
    #                                   ]
                    axd[plot_key].legend(handles=legend_elements, 
                                ncol=1, 
                                handlelength=3,
                                handletextpad=0.6, 
                                columnspacing=1.5,
                                fontsize=23,
                                loc='lower center', bbox_to_anchor=(0.5, -0.3))   #facecolor='white', framealpha=0.5               

                    for axis in ['top','bottom','left','right']:
                        axd[plot_key].spines[axis].set_linewidth(3)                


                    axd[plot_key].set_ylim(-0.02, 1.05)
                    axd[plot_key].set_xlim(-0.02, 1.05)
                    axd[plot_key].set_xlabel('Recall rate', fontsize=25)
                    axd[plot_key].set_ylabel('Precision', fontsize=25)

                    axd[plot_key].spines['top'].set_visible(False)
                    axd[plot_key].spines['right'].set_visible(False)   

                    axd[plot_key].xaxis.set_major_locator(MultipleLocator(0.2))
    #                 axd[plot_key].xaxis.set_minor_locator(MultipleLocator(0.05))
    #                 axd[plot_key].xaxis.grid(True, which='major', linewidth=1, alpha=0.6)
    #                 axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.2)

                    axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.2))
                    #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.05))                
                    axd[plot_key].yaxis.grid(True, which='major', linewidth=0.5, alpha=0.6)                
                    axd[plot_key].yaxis.grid(True, which='minor', linewidth=0.5, alpha=0.2)   

                    axd[plot_key].tick_params(axis='x', which='major', labelsize=25)
                    axd[plot_key].tick_params(axis='y', which='major', labelsize=25)                

                    axd[plot_key].set_title("")


                    axd[plot_key].text(x=-0.1, y=1.0, transform=axd[plot_key].transAxes,
                                             s="D", fontsize=35, weight='bold')                  
                        
# fig.savefig(log_dir/"plots"/"data_audit_main.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"data_audit_main.jpg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"data_audit_main.svg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"data_audit_main.pdf", bbox_inches='tight')
#plt.close(fig)                    

In [None]:
# Load a saved experiment-results file, mapping any stored tensors onto the CPU.
# NOTE(review): torch.load may unpickle arbitrary objects — only load files from a trusted source.
# NOTE(review): the result is bound to `variable_dict_` (trailing underscore) while the cells
# below read `variable_dict`; confirm which name is intended.
variable_dict_=torch.load("logs/experiment_results/concept_annotation_data_auditing_0930.pt", map_location="cpu")

In [None]:
# Export the current figure as data_audit_main.<ext> in every required format.
# NOTE(review): within this part of the notebook `log_dir` is assigned in the *following* cell;
# make sure it is defined before this cell on a fresh "Restart & Run All".
audit_plot_dir = log_dir/"plots"
# savefig does not create missing parent directories, so make sure the target exists.
audit_plot_dir.mkdir(parents=True, exist_ok=True)
for audit_plot_ext in ("png", "jpg", "svg", "pdf"):
    fig.savefig(audit_plot_dir/f"data_audit_main.{audit_plot_ext}", bbox_inches='tight')

In [None]:
# Root directory for outputs, relative to the working directory; figures are saved under log_dir/"plots".
log_dir = Path("logs")

In [None]:
# Inspection cell: display the "metadata" entry of `item`.
item["metadata"]

In [None]:
# Inspection cell: display the first 20 rows of concept_diff_institution after sorting
# by "diff_abs" in descending order.
concept_diff_institution.sort_values("diff_abs", ascending=False).iloc[:20]

In [None]:
# Inspection cell: display Paired[12] entries 1 and 3 as arrays divided by 256.
# NOTE(review): 8-bit colour channels are usually normalised by 255; dividing by 256 maps
# 255 to ~0.996 — confirm this is intended.
np.array(Paired[12][1])/256, np.array(Paired[12][3])/256

In [None]:
# Inspection cell: display Paired[12] entries 1 and 3 as a tuple, without any scaling.
(Paired[12][1],Paired[12][3])

In [None]:
# Inspection cell: display the "benchmark_dict_list" entry stored under "clinical_fd_clean".
variable_dict["clinical_fd_clean"]["benchmark_dict_list"]

In [None]:
# Inspection cell: display skincon_cols.
skincon_cols

In [None]:
# Inspection cell: display concept_diff_all_top.
concept_diff_all_top

In [None]:
# Inspection cell: display skincon_cols.
# NOTE(review): an identical inspection cell appears a few cells earlier; one of them can likely go.
skincon_cols

In [None]:
# Inspection cell: display thresholds.
thresholds

In [None]:
from matplotlib.patches import Patch

In [None]:
from matplotlib.lines import Line2D

In [None]:
from sklearn.model_selection import train_test_split
from MONET.datamodules.components.base_dataset import BaseDataset

def get_test_data(dataloader, metadata_all, valid_idx, y_pos, n_px=None):
    """Build a non-shuffled DataLoader over every row of the source dataset, labelled with ``y_pos``.

    Args:
        dataloader: Existing loader; its ``dataset`` supplies the metadata table, the image
            path dictionary, the default ``n_px`` and the ``mean``/``std`` passed on as
            ``norm_mean``/``norm_std``.
        metadata_all: Not used by this function; kept so existing calls keep working.
        valid_idx: Not used by this function; kept so existing calls keep working.
        y_pos: Labels; ``y_pos.astype(int)`` is written to a new ``"label"`` column of a
            copy of the dataset's metadata.
        n_px: Value forwarded to ``BaseDataset`` as ``n_px``. ``None`` falls back to
            ``dataloader.dataset.n_px``.

    Returns:
        A ``torch.utils.data.DataLoader`` with batch size 64, 4 workers and ``shuffle=False``.
    """
    source = dataloader.dataset

    # Work on a copy so the source dataset's metadata table is left untouched.
    labelled_metadata = source.metadata_all.copy()
    labelled_metadata["label"] = y_pos.astype(int)
    print("test:", len(labelled_metadata))

    image_size = source.n_px if n_px is None else n_px

    # The second "after tensor" transform carries the mean/std reused for the new dataset.
    normalize = source.transforms_aftertensor.transforms[1]
    full_dataset = BaseDataset(
        image_path_or_binary_dict=source.image_path_dict,
        n_px=image_size,
        norm_mean=normalize.mean,
        norm_std=normalize.std,
        augment=False,
        metadata_all=labelled_metadata,
        integrity_level="weak",
        return_label=["label"],
    )

    from MONET.utils.loader import custom_collate

    return torch.utils.data.DataLoader(
        dataset=full_dataset,
        batch_size=64,
        num_workers=4,
        pin_memory=False,
        persistent_workers=False,
        shuffle=False,
        collate_fn=custom_collate,
    )

In [None]:
# Build the evaluation loader over the full dataset (default image size) and store it with
# the dataset's other objects. The loop variable is used throughout, so adding another
# dataset name to the list cannot silently reuse the "isic" inputs.
for dataset_name in ["isic"]:
    dataset_vars = variable_dict[dataset_name]
    dataset_vars["classifier_dataloader_all"] = get_test_data(
        dataloader=dataset_vars["dataloader"],
        metadata_all=dataset_vars["metadata_all"],
        valid_idx=dataset_vars["valid_idx"],
        y_pos=dataset_vars["y_pos"],
    )

In [None]:
# Same full-dataset evaluation loader, but with n_px=299 (stored under a separate key).
# The loop variable is used throughout, so adding another dataset name to the list cannot
# silently reuse the "isic" inputs.
for dataset_name in ["isic"]:
    dataset_vars = variable_dict[dataset_name]
    dataset_vars["classifier_dataloader299_all"] = get_test_data(
        dataloader=dataset_vars["dataloader"],
        metadata_all=dataset_vars["metadata_all"],
        valid_idx=dataset_vars["valid_idx"],
        y_pos=dataset_vars["y_pos"],
        n_px=299,
    )

In [None]:
from sklearn.model_selection import train_test_split
from MONET.datamodules.components.base_dataset import BaseDataset

def get_training_data(dataloader, metadata_all, valid_idx, y_pos, subset_idx_train, n_px=None):
    """Build train and validation DataLoaders for one subset of the source dataset.

    Rows selected by ``valid_idx & subset_idx_train`` are split 80/20 with a fixed
    ``random_state`` of 42; both parts are wrapped in a ``BaseDataset`` without augmentation.

    Args:
        dataloader: Existing loader; its ``dataset`` supplies the metadata table, the image
            path dictionary, the default ``n_px`` and the ``mean``/``std`` passed on as
            ``norm_mean``/``norm_std``.
        metadata_all: Not used by this function; kept so existing calls keep working.
        valid_idx: Boolean row selector, combined with ``subset_idx_train`` via ``&``.
        y_pos: Labels; ``y_pos.astype(int)`` is written to a new ``"label"`` column of a
            copy of the dataset's metadata.
        subset_idx_train: Boolean row selector for the subset to train on.
        n_px: Value forwarded to ``BaseDataset`` as ``n_px``. ``None`` falls back to
            ``dataloader.dataset.n_px``.

    Returns:
        ``(train_dataloader, val_dataloader)``; only the training loader shuffles.
    """
    source = dataloader.dataset

    # Work on a copy so the source dataset's metadata table is left untouched.
    labelled_metadata = source.metadata_all.copy()
    labelled_metadata["label"] = y_pos.astype(int)
    subset_metadata = labelled_metadata[valid_idx & subset_idx_train]

    train_idx, val_idx = train_test_split(subset_metadata.index, test_size=0.2, random_state=42)

    # Reports the size of the whole subset, i.e. before the train/validation split.
    print("train:", len(subset_metadata))

    image_size = source.n_px if n_px is None else n_px

    def make_dataset(row_index):
        # Dataset restricted to the given index labels, reusing the source normalisation.
        normalize = source.transforms_aftertensor.transforms[1]
        return BaseDataset(
            image_path_or_binary_dict=source.image_path_dict,
            n_px=image_size,
            norm_mean=normalize.mean,
            norm_std=normalize.std,
            augment=False,
            metadata_all=subset_metadata.loc[row_index],
            integrity_level="weak",
            return_label=["label"],
        )

    train_dataset = make_dataset(train_idx)
    val_dataset = make_dataset(val_idx)

    from MONET.utils.loader import custom_collate

    def make_loader(dataset, shuffle):
        # Both loaders share every setting except whether batches are shuffled.
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=64,
            num_workers=4,
            pin_memory=False,
            persistent_workers=False,
            shuffle=shuffle,
            collate_fn=custom_collate,
        )

    return make_loader(train_dataset, True), make_loader(val_dataset, False)

In [None]:
# Build one (train, val) loader pair per hospital at the dataset's default image size and
# store each pair under "classifier_dataloader_<hospital>". The loop variables are used
# throughout, so extending either list cannot silently reuse the wrong inputs.
for dataset_name in ["isic"]:
    dataset_vars = variable_dict[dataset_name]
    for hospital_name in (hospital_1, hospital_2):
        dataset_vars[f"classifier_dataloader_{hospital_name}"] = get_training_data(
            dataloader=dataset_vars["dataloader"],
            metadata_all=dataset_vars["metadata_all"],
            valid_idx=dataset_vars["valid_idx"],
            y_pos=dataset_vars["y_pos"],
            subset_idx_train=get_subset_index(
                dataset_name=dataset_name,
                metadata_all=dataset_vars["metadata_all"],
                attribution=hospital_name,
            ),
        )

In [None]:
# Build one (train, val) loader pair per hospital with n_px=299 and store each pair under
# "classifier_dataloader299_<hospital>". The loop variables are used throughout, so
# extending either list cannot silently reuse the wrong inputs.
for dataset_name in ["isic"]:
    dataset_vars = variable_dict[dataset_name]
    for hospital_name in (hospital_1, hospital_2):
        dataset_vars[f"classifier_dataloader299_{hospital_name}"] = get_training_data(
            dataloader=dataset_vars["dataloader"],
            metadata_all=dataset_vars["metadata_all"],
            valid_idx=dataset_vars["valid_idx"],
            y_pos=dataset_vars["y_pos"],
            subset_idx_train=get_subset_index(
                dataset_name=dataset_name,
                metadata_all=dataset_vars["metadata_all"],
                attribution=hospital_name,
            ),
            n_px=299,
        )

In [None]:
from MONET.datamodules.components.base_dataset import BaseDataset

In [None]:
!gpustat

In [None]:
import torchvision
from torch import nn
from torch.nn import functional as F
from torchmetrics import AUROC

class Classifier(nn.Module):
    """ResNet-50 with ImageNet weights whose final layer is replaced by a new linear head.

    Args:
        output_dim: Number of outputs produced by the linear head.
    """

    def __init__(self, output_dim):
        super().__init__()
        resnet = torchvision.models.resnet50(weights="ResNet50_Weights.IMAGENET1K_V1")

        # Keep every backbone parameter trainable (the whole network is fine-tuned).
        resnet.requires_grad_(True)

        # Replace the original final layer with a pass-through so the backbone returns
        # what used to be that layer's input; remember its width for the new head.
        feature_dim = resnet.fc.in_features
        resnet.fc = nn.Identity()

        self.backbone = resnet
        self.head = nn.Linear(feature_dim, output_dim)

    def forward(self, x):
        """Apply the backbone, then the linear head, and return the head's output."""
        return self.head(self.backbone(x))
    
class Inception(nn.Module):
    """Inception-v3 with ImageNet weights, followed by a new linear layer on its 2048 features.

    The wrapped torchvision model's sub-modules are applied one by one instead of calling
    its own ``forward``. The auxiliary classifier (``AuxLogits``) is deliberately not
    evaluated: its output is never used by this model, so running it in training mode
    would only cost compute.

    NOTE(review): because the wrapped model's ``forward`` is bypassed, any input transform
    it would apply there is skipped as well — confirm the inputs are normalised as intended.

    Args:
        output_dim: Number of outputs produced by the final linear layer.
    """

    # Sub-modules of the wrapped model, in application order. The trailing comments give
    # the activation shape after each step for a N x 3 x 299 x 299 input.
    _TRUNK_LAYERS = (
        "Conv2d_1a_3x3",  # N x 32 x 149 x 149
        "Conv2d_2a_3x3",  # N x 32 x 147 x 147
        "Conv2d_2b_3x3",  # N x 64 x 147 x 147
        "maxpool1",       # N x 64 x 73 x 73
        "Conv2d_3b_1x1",  # N x 80 x 73 x 73
        "Conv2d_4a_3x3",  # N x 192 x 71 x 71
        "maxpool2",       # N x 192 x 35 x 35
        "Mixed_5b",       # N x 256 x 35 x 35
        "Mixed_5c",       # N x 288 x 35 x 35
        "Mixed_5d",       # N x 288 x 35 x 35
        "Mixed_6a",       # N x 768 x 17 x 17
        "Mixed_6b",       # N x 768 x 17 x 17
        "Mixed_6c",       # N x 768 x 17 x 17
        "Mixed_6d",       # N x 768 x 17 x 17
        "Mixed_6e",       # N x 768 x 17 x 17
        "Mixed_7a",       # N x 1280 x 8 x 8
        "Mixed_7b",       # N x 2048 x 8 x 8
        "Mixed_7c",       # N x 2048 x 8 x 8
        "avgpool",        # N x 2048 x 1 x 1 (adaptive average pooling)
        "dropout",        # N x 2048 x 1 x 1
    )

    def __init__(self, output_dim):
        super().__init__()
        self.inception = torchvision.models.inception_v3(weights="Inception_V3_Weights.IMAGENET1K_V1")
        self.fc = nn.Linear(2048, output_dim)

    def forward(self, x):
        """Return the linear layer's output (N x output_dim) for a batch of images."""
        x = self.input_to_representation(x)
        return self.fc(x)

    def input_to_representation(self, x):
        """Run the trunk layers in order and return the flattened features (N x 2048)."""
        for layer_name in self._TRUNK_LAYERS:
            x = getattr(self.inception, layer_name)(x)
        return torch.flatten(x, 1)


class EarlyStopper:
    """Remembers the lowest validation loss seen and signals when training should stop.

    Args:
        patience: Number of worsening checks, counted since the last new minimum, at which
            ``early_stop`` starts returning ``True``.
        min_delta: A loss only counts as worsening when it exceeds the best loss by more
            than this margin.
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        """Record ``validation_loss`` and return ``True`` once patience is exhausted."""
        # A new minimum resets the worsening counter.
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        # Losses inside the tolerance band neither reset nor advance the counter.
        tolerated_limit = self.min_validation_loss + self.min_delta
        if not (validation_loss > tolerated_limit):
            return False

        self.counter += 1
        if self.counter >= self.patience:
            return True
        return False
        
def _evaluate_classifier(classifier, dataloader, device):
    """Evaluate ``classifier`` on ``dataloader`` without gradients.

    Returns:
        ``(summed_loss, auroc)``: the per-batch mean BCE loss weighted by batch size and
        summed over all batches, and the binary AUROC computed over all samples.
    """
    auroc = AUROC(task="binary")
    summed_loss = 0
    classifier.eval()
    with torch.no_grad():
        for batch in tqdm.tqdm(dataloader):
            image, label = batch["image"].to(device), batch["label"].to(device)
            logits = classifier(image)[:, 0]
            target = label == 1
            loss = F.binary_cross_entropy_with_logits(input=logits, target=target.float())
            summed_loss += loss.item() * image.size(0)
            # Same 1-D scores as the loss; explicit sigmoid so every batch is on the
            # probability scale regardless of the metric's logit auto-detection.
            auroc.update(torch.sigmoid(logits), target)
    return summed_loss, auroc.compute()


def train_classifier(train_dataloader, val_dataloader, test_dataloader, classifier_type="resnet",
                     device="cuda:3", max_epochs=20):
    """Fine-tune a binary image classifier, then report its loss and AUROC on a test loader.

    Training uses Adam (lr 1e-3), ``ReduceLROnPlateau`` on the summed validation loss
    (patience 2) and ``EarlyStopper`` (patience 5). Batches are dicts with ``"image"`` and
    ``"label"`` entries; a sample is positive when its label equals 1. The returned model
    holds the weights from the last epoch that ran (no best-checkpoint restore).

    Args:
        train_dataloader: Loader iterated once per epoch for optimisation.
        val_dataloader: Loader used after every epoch for scheduling and early stopping.
        test_dataloader: Loader evaluated once after training ends.
        classifier_type: ``"resnet"`` (``Classifier``) or ``"inception"`` (``Inception``).
        device: Device the model and every batch are moved to. Defaults to ``"cuda:3"``,
            the value that used to be hard-coded.
        max_epochs: Upper bound on the number of epochs. Defaults to 20.

    Returns:
        The trained classifier module (still on ``device``).

    Raises:
        ValueError: If ``classifier_type`` is not one of the supported names.
    """
    if classifier_type == "resnet":
        classifier = Classifier(output_dim=1)
    elif classifier_type == "inception":
        classifier = Inception(output_dim=1)
    else:
        # Previously an unknown name surfaced later as an unbound-variable error.
        raise ValueError(
            f"Unknown classifier_type {classifier_type!r}; expected 'resnet' or 'inception'"
        )
    classifier.to(device)
    optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=2, verbose=True)
    early_stopper = EarlyStopper(patience=5, min_delta=0)

    train_auroc = AUROC(task="binary")
    for epoch in range(max_epochs):
        train_loss = 0
        train_auroc.reset()
        classifier.train()
        for batch in tqdm.tqdm(train_dataloader):
            image, label = batch["image"].to(device), batch["label"].to(device)
            logits = classifier(image)[:, 0]
            target = label == 1
            loss = F.binary_cross_entropy_with_logits(input=logits, target=target.float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * image.size(0)
            # Detach so the metric never holds on to the autograd graph.
            train_auroc.update(torch.sigmoid(logits.detach()), target)

        val_loss, val_auroc = _evaluate_classifier(classifier, val_dataloader, device)

        print(
            f"Train loss: {train_loss/len(train_dataloader.dataset):.3f} AUROC: {train_auroc.compute():.3f} Val loss: {val_loss/len(val_dataloader.dataset):.3f} AUROC: {val_auroc:.3f}"
        )
        # Scheduler and early stopping both track the summed (not averaged) validation loss.
        scheduler.step(val_loss)
        if early_stopper.early_stop(val_loss):
            print("break")
            break

    test_loss, test_auroc = _evaluate_classifier(classifier, test_dataloader, device)

    print(
        f"Test loss: {test_loss/len(test_dataloader.dataset):.3f} AUROC: {test_auroc:.3f}"
    )
    return classifier