In [None]:
import os

import hydra
import omegaconf
import pyrootutils

# Resolve the project root starting from this notebook's path. `pythonpath=True`
# presumably also puts that root on sys.path (see pyrootutils docs — TODO confirm).
root = pyrootutils.setup_root(os.path.abspath("compare_manual-230410.ipynb"), pythonpath=True)
import os  # NOTE(review): duplicate of the `import os` at the top of this cell

# Work from the project root: later cells use relative paths such as "logs/..." and "data/...".
os.chdir(root)

In [None]:
import sys

# Make `<root>/src` importable. Guarded so that re-running this cell does not
# keep appending the same entry to sys.path.
src_dir = str(root / "src")
if src_dir not in sys.path:
    sys.path.append(src_dir)

In [None]:
import os
import time
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib import gridspec
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn.metrics
import torch
import tqdm
from IPython import display
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

import clip

In [None]:
# !cp /System/Library/Fonts/Supplemental ~/.local/share/fonts/
# rm -fr ~/.cache/matplotlib
from matplotlib import font_manager
from matplotlib.lines import Line2D
from cycler import cycler
from matplotlib.ticker import MultipleLocator, AutoMinorLocator

# Font lookups: both return values are discarded, so these calls only serve as
# a check that matplotlib can scan system fonts and resolve "Arial".
font_manager.findSystemFonts(fontpaths=None, fontext="ttf")
font_manager.findfont("Arial")  # Test with "Special Elite" too

# Global matplotlib defaults used by every figure in this notebook.
plt.rcParams.update(
    {
        "font.family": "sans-serif",
        "font.sans-serif": "Arial",
        "legend.fancybox": False,
        "legend.edgecolor": "1.0",
        "legend.framealpha": 0,
    }
)

# https://github.com/dsc/colorbrewer-python/blob/master/colorbrewer.py

# "Set1" palette as [R, G, B] lists in 0-255, keyed by number of classes (3-9).
# Every n-class palette is the first n colours of the 9-class one, so the table
# is built by slicing; each entry gets its own freshly created lists.
_SET1_9 = (
    (228, 26, 28),
    (55, 126, 184),
    (77, 175, 74),
    (152, 78, 163),
    (255, 127, 0),
    (255, 255, 51),
    (166, 86, 40),
    (247, 129, 191),
    (153, 153, 153),
)
Set1 = {n_classes: [list(rgb) for rgb in _SET1_9[:n_classes]] for n_classes in range(3, 10)}

# "Paired" palette as [R, G, B] lists in 0-255, keyed by number of classes (3-12).
# Every n-class palette is the first n colours of the 12-class one, so the table
# is built by slicing. All colours are lists: the original table had one stray
# tuple as the first colour of Paired[3], inconsistent with every other entry.
_PAIRED_12 = (
    (166, 206, 227),
    (31, 120, 180),
    (178, 223, 138),
    (51, 160, 44),
    (251, 154, 153),
    (227, 26, 28),
    (253, 191, 111),
    (255, 127, 0),
    (202, 178, 214),
    (106, 61, 154),
    (255, 255, 153),
    (177, 89, 40),
)
Paired = {n_classes: [list(rgb) for rgb in _PAIRED_12[:n_classes]] for n_classes in range(3, 13)}

# Seven-entry qualitative palette. Six entries are hex strings; the fifth is an
# RGB tuple of fractions (components divided by 256, as originally written).
color_qual_7 = [
    "#F53345",
    "#87D303",
    "#04CBCC",
    "#8650CD",
    (160 / 256, 95 / 256, 0),
    "#F5A637",
    "#DBD783",
]

pd.set_option('display.max_rows', 500)

In [None]:
import scipy.special
import tqdm.contrib.concurrent

In [None]:
from MONET.datamodules.multiplex_datamodule import MultiplexDatamodule
from MONET.utils.loader import custom_collate_per_key, dataloader_apply_func
from MONET.utils.metrics import skincon_calcualte_auc_all
from MONET.utils.static import (
    concept_to_prompt,
    fitzpatrick17k_disease_label,
    fitzpatrick17k_ninelabel,
    fitzpatrick17k_threelabel,
    skincon_cols,
)
from MONET.utils.text_processing import generate_prompt_token_from_concept

In [None]:
def wandb_to_exppath(wandb, log_path="/gscratch/cse/chanwkim/MONET_log/train/runs"):
    """Return the experiment directory under ``log_path`` that holds a wandb run.

    Every child of ``log_path`` is treated as an experiment. An experiment
    matches when its ``wandb`` sub-directory contains an entry whose name
    starts with ``"run"`` and whose last 8 characters equal ``wandb``.

    Args:
        wandb: Run id compared against the last 8 characters of each run entry.
        log_path: Directory whose children are experiment directories.

    Returns:
        ``pathlib.Path`` of the first matching experiment (in sorted name order).

    Raises:
        RuntimeError: If no experiment contains a matching run entry.
    """
    log_path = Path(log_path)
    # Sorted so the result does not depend on the file system's listing order.
    for experiment in sorted(os.listdir(log_path)):
        wandb_dir = log_path / experiment / "wandb"
        if not wandb_dir.is_dir():
            continue
        # Consider every run entry, and tolerate a wandb directory that has none
        # (indexing the first entry used to raise IndexError in that case).
        run_ids = {entry[-8:] for entry in os.listdir(wandb_dir) if entry.startswith("run")}
        if wandb in run_ids:
            return log_path / experiment
    raise RuntimeError("not found")


# Locate the experiment directory of wandb run "baqqmm5v" and print the full
# path of every file in its checkpoints folder.
exppath = wandb_to_exppath(
    wandb="baqqmm5v", log_path="/projects/leelab2/chanwkim/dermatology_datasets/logs/train/runs"
)
print([exppath / "checkpoints" / ckpt for ckpt in os.listdir(exppath / "checkpoints/")])

In [None]:
# Per-dataset state shared by the cells below: dataset name -> dict that later
# cells fill in. Re-running this cell discards everything stored so far.
variable_dict={}

In [None]:
# Notebook-level dataset name -> value written to `dataset_name_test` in the
# multiplex datamodule config.
_DATASET_TEST_SPECS = {
    "clinical_fd_clean_nodup_nooverlap": "clinical_fd_clean_nodup_nooverlap=all",
    "derm7pt_derm_nodup": "derm7pt_derm_nodup=all",
    "derm7pt_clinical_nodup": "derm7pt_clinical_nodup=all",
    "derm7pt": "derm7pt=all",
    "allpubmedtextbook": "pubmed=all,textbook=all",
    "isic_nodup_nooverlap": "isic_nodup_nooverlap=all",
    "proveai": "proveai=all",
}


def setup_dataloader(dataset_name, data_dir="/sdata/chanwkim/dermatology_datasets", split_seed=42):
    """Build the test dataloader for one of the known evaluation datasets.

    Loads ``configs/datamodule/multiplex.yaml`` under the project ``root``,
    overrides ``data_dir``, ``dataset_name_test`` and ``split_seed``,
    instantiates the datamodule with hydra, runs ``setup()`` and returns its
    ``test_dataloader()``.

    Args:
        dataset_name: One of the keys of ``_DATASET_TEST_SPECS``.
        data_dir: Value written to ``cfg.data_dir``; defaults to the location
            that used to be hard-coded in every branch.
        split_seed: Value written to ``cfg.split_seed`` (previously always 42).

    Returns:
        ``{"dataloader": <test dataloader>}``.

    Raises:
        ValueError: If ``dataset_name`` is not a known dataset. The previous
            if/elif chain fell through and failed with UnboundLocalError.
    """
    try:
        dataset_name_test = _DATASET_TEST_SPECS[dataset_name]
    except KeyError:
        raise ValueError(
            f"unknown dataset_name {dataset_name!r}; expected one of {sorted(_DATASET_TEST_SPECS)}"
        ) from None

    cfg_dm = omegaconf.OmegaConf.load(root / "configs" / "datamodule" / "multiplex.yaml")
    cfg_dm.data_dir = data_dir
    cfg_dm.dataset_name_test = dataset_name_test
    cfg_dm.split_seed = split_seed

    dm = hydra.utils.instantiate(cfg_dm)
    dm.setup()

    return {"dataloader": dm.test_dataloader()}

In [None]:
# Create (if needed) the per-dataset entry and store its test dataloader in it.
for dataset_name in [
    "clinical_fd_clean_nodup_nooverlap",
    "derm7pt_derm_nodup",
    "isic_nodup_nooverlap",
    "proveai",
]:
    variable_dict.setdefault(dataset_name, {}).update(setup_dataloader(dataset_name))

In [None]:
!gpustat

In [None]:
# `model_name` is the key looked up in `model_path_dir` when loading a checkpoint;
# `model_device` is the device the model and the image batches are moved to.
model_name = "zt0n2xd0"
model_device = "cuda:5"

In [None]:
# Load the contrastive-model config and point it at the ViT-L/14 backbone on `model_device`.
cfg_model = omegaconf.OmegaConf.load(root / "configs" / "model" / "contrastive.yaml")
cfg_model.net.model_name_or_path = "ViT-L/14"
cfg_model.net.device = model_device
cfg_model

In [None]:
# Instantiate the model from the config, move it to `model_device`, switch to eval mode.
model = hydra.utils.instantiate(cfg_model)
model.to(model_device)
model.eval()

In [None]:
# Checkpoint path for each known `model_name`. The path is relative, so it is
# resolved against the current working directory.
model_path_dir = {
    "zt0n2xd0": "logs/train/runs/2023-01-17_20-58-15/checkpoints/last.ckpt",
}
# "ViT-L/14" means "keep the weights the model was instantiated with"; any other
# name loads the checkpoint's "state_dict" entry into `model`.
if model_name != "ViT-L/14":
    model_path = model_path_dir[model_name]
    loaded = torch.load(model_path, map_location=model_device)
    model.load_state_dict(loaded["state_dict"])
    print('loaded')

In [None]:
# Build a second model from the same config; no checkpoint is loaded into it in
# this cell, and it is kept under the name `model_vanilla`.
# NOTE(review): `model_name` / `model_device` repeat the values set in an earlier cell.
model_name = "zt0n2xd0"
model_device = "cuda:5"

cfg_model = omegaconf.OmegaConf.load(root / "configs" / "model" / "contrastive.yaml")
cfg_model.net.model_name_or_path = "ViT-L/14"
cfg_model.net.device = model_device
cfg_model  # not the cell's last expression, so nothing is displayed here

model_vanilla = hydra.utils.instantiate(cfg_model)
model_vanilla.to(model_device)
model_vanilla.eval()

# Test

In [None]:
# Keep a second reference to the current `model`; the cells that follow rebind
# `model` to a new instance configured for the CPU.
model_loaded=model

In [None]:
# Same config as before, but with the network device set to "cpu".
cfg_model = omegaconf.OmegaConf.load(root / "configs" / "model" / "contrastive.yaml")
cfg_model.net.model_name_or_path = "ViT-L/14"
cfg_model.net.device = "cpu"
cfg_model

In [None]:
# Rebind `model` to an instance built from the CPU config (no `.to(...)` call here).
# NOTE(review): `batch_func` further down moves images to `model_device` and calls
# `model`; when cells run top to bottom that `model` is this instance — confirm the intent.
model = hydra.utils.instantiate(cfg_model)
model.eval()

In [None]:
# Load the checkpoint into the `model` that the preceding cells configured with
# net.device = "cpu".
model_path_dir = {
    "zt0n2xd0": "logs/train/runs/2023-01-17_20-58-15/checkpoints/last.ckpt",
}
if model_name != "ViT-L/14":
    model_path = model_path_dir[model_name]
    # map_location="cpu": the target model is configured for the CPU, so the
    # checkpoint does not need to be staged on `model_device` (a GPU) first.
    # load_state_dict copies the values into the model's own parameters either way.
    loaded = torch.load(model_path, map_location="cpu")
    model.load_state_dict(loaded["state_dict"])
    print('loaded')

In [None]:
# Reference ViT-L/14 model and its preprocessing, loaded through the `clip` package on the CPU (non-JIT).
model_test, preprocess_test = clip.load("ViT-L/14", device="cpu", jit=False)

In [None]:
model_test.eval()

In [None]:
# Take the state dict of the inner model held by `model_loaded` and load it
# into the reference `model_test`.
clip_weights=model_loaded.net.model.state_dict()
model_test.load_state_dict(clip_weights)

In [None]:
# Write the extracted state dict next to the original checkpoint.
# NOTE(review): this path is relative while the cells below read an absolute
# /projects/... path — presumably the same file via the working directory; verify.
torch.save(clip_weights, "logs/train/runs/2023-01-17_20-58-15/checkpoints/last_clip.ckpt")

In [None]:
!ls /projects/leelab2/chanwkim/dermatology_datasets/logs/train/runs/2023-01-17_20-58-15/checkpoints/last_clip.ckpt

In [None]:
# Read the saved state dict back so it can be compared with `clip_weights`.
# No map_location is given, so tensors are restored to the device they were saved from.
clip_weights_=torch.load("/projects/leelab2/chanwkim/dermatology_datasets/logs/train/runs/2023-01-17_20-58-15/checkpoints/last_clip.ckpt")

In [None]:
clip_weights["visual.transformer.resblocks.1.ln_2.weight"]

In [None]:
clip_weights_["visual.transformer.resblocks.1.ln_2.weight"]

In [None]:
# Sanity check: the reference model (`model_test`) and the project wrapper
# (`model`) must produce identical image and text features.
# The misspelled `*_feautures_test` names are kept because later cells read them.
with torch.no_grad():
    # Fetch the sample once so both models receive exactly the same tensor, even
    # if the dataset applies a non-deterministic transform on each access.
    sample_image = (
        variable_dict["clinical_fd_clean_nodup_nooverlap"]["dataloader"].dataset[1]["image"].unsqueeze(0)
    )
    image_feautures_test = model_test.encode_image(sample_image)
    image_features = model.model_step_with_image({"image": sample_image})

    assert (image_feautures_test == image_features["image_features"]).all()

    # Tokenise once and feed the same tokens to both models.
    sample_tokens = clip.tokenize(["This is a diagram"])
    text_feautures_test = model_test.encode_text(sample_tokens)
    text_features = model.model_step_with_text({"text": sample_tokens})

    assert (text_feautures_test == text_features["text_features"]).all()

In [None]:
text_feautures_test.shape

In [None]:
(text_feautures_test==text_features["text_features"]).all()

In [None]:
# NOTE(review): this cell held a bare `with torch.no_grad():` with no body, which
# is a SyntaxError and stops "Run All". The placeholder keeps it runnable;
# delete the cell if it is no longer needed.
with torch.no_grad():
    pass

In [None]:
model.model_step_with_image??

In [None]:
image_feautures_test.shape

In [None]:
variable_dict["clinical_fd_clean_nodup_nooverlap"]["dataloader"].dataset[1]["image"].shape

In [None]:
model_test.forward

In [None]:
#dir(model)


In [None]:
!ls -lh logs/train/runs/2023-01-17_20-58-15/checkpoints/

In [None]:
model_test

In [None]:
!wget https://aimslab.cs.washington.edu/MONET/weight.pt

In [None]:
# Load the file fetched by the `!wget` cell above.
# NOTE(review): torch.load unpickles its input, so only load files from a trusted
# source; consider `weights_only=True` if the installed torch supports it — TODO confirm.
weight_test=torch.load("weight.pt")

In [None]:
weight_test.keys()

In [None]:
# Relative log directory; `setup_features` reads cached image features from below it.
log_dir = Path("logs")

In [None]:
def batch_func(batch):
    """Embed one batch of images with both `model` and `model_vanilla`.

    The batch's "image" entry is moved to `model_device` in place, because both
    model calls read the image from `batch` itself.

    Args:
        batch: Mapping with at least "image" and "metadata" entries.

    Returns:
        Dict with "image_features" and "image_features_vanilla" (detached, on
        the CPU) and the batch's "metadata" passed through unchanged.
    """
    # NOTE(review): the image is sent to `model_device`; confirm the global
    # `model` lives on that device at the time this runs.
    with torch.no_grad():
        batch["image"] = batch["image"].to(model_device)
        feats_model = model.model_step_with_image(batch)["image_features"]
        feats_vanilla = model_vanilla.model_step_with_image(batch)["image_features"]
    return dict(
        image_features=feats_model.detach().cpu(),
        image_features_vanilla=feats_vanilla.detach().cpu(),
        metadata=batch["metadata"],
    )

def setup_features(dataset_name, dataloader):
    """Return image features and metadata for one dataset.

    For "isic_nodup_nooverlap" the features are read from the cached file
    ``log_dir/"image_features"/"isic_nodup_nooverlap.pt"``; that result has no
    "image_features_vanilla" entry. Every other dataset is embedded by running
    `batch_func` over `dataloader`.

    Args:
        dataset_name: Dataset identifier; only selects cached vs. computed.
        dataloader: Dataloader to embed (unused for the cached dataset).

    Returns:
        Dict with "image_features", "metadata_all" and, for computed datasets,
        "image_features_vanilla".
    """
    if dataset_name != "isic_nodup_nooverlap":
        applied = dataloader_apply_func(
            dataloader=dataloader,
            func=batch_func,
            collate_fn=custom_collate_per_key,
        )
        return {
            "image_features": applied["image_features"].cpu(),
            "image_features_vanilla": applied["image_features_vanilla"].cpu(),
            "metadata_all": applied["metadata"],
        }

    cached = torch.load(log_dir/"image_features"/"isic_nodup_nooverlap.pt", map_location="cpu")
    return {
        "image_features": cached["image_features"].cpu(),
        "metadata_all": cached["metadata_all"],
    }

In [None]:
# Compute (or load) image features for every dataset and merge them into its entry.
for dataset_name in [
    "clinical_fd_clean_nodup_nooverlap",
    "derm7pt_derm_nodup",
    "isic_nodup_nooverlap",
    "proveai",
]:
    print(dataset_name)
    variable_dict[dataset_name].update(
        setup_features(dataset_name, variable_dict[dataset_name]["dataloader"])
    )

In [None]:
# Re-run feature extraction for "proveai" only.
for dataset_name in ["proveai"]:
    print(dataset_name)
    variable_dict[dataset_name].update(
        setup_features(dataset_name, variable_dict[dataset_name]["dataloader"])
    )

In [None]:
# NOTE(review): same work as the preceding cell — feature extraction for "proveai".
for dataset_name in ["proveai"]:
    print(dataset_name)
    variable_dict[dataset_name].update(
        setup_features(dataset_name, variable_dict[dataset_name]["dataloader"])
    )

In [None]:
# Re-run feature extraction for the clinical dataset only.
for dataset_name in ["clinical_fd_clean_nodup_nooverlap"]:
    print(dataset_name)
    variable_dict[dataset_name].update(
        setup_features(dataset_name, variable_dict[dataset_name]["dataloader"])
    )

In [None]:
!gpustat

In [None]:
import torchvision
def get_layer_feature(model, feature_layer_name, image):
    """Run ``model`` on ``image`` and return the output captured at one child module.

    A forward hook on ``model._modules[feature_layer_name]`` records that
    layer's output during a single forward pass. The output must have shape
    ``(batch, channels, 1, 1)``; the two trailing singleton axes are dropped.

    Args:
        model: Module with a direct child registered under ``feature_layer_name``.
        feature_layer_name: Key into ``model._modules``.
        image: Batched input, passed to ``model`` as-is (the caller places it on
            the right device).

    Returns:
        Tensor of shape ``(batch, channels)``.

    Raises:
        ValueError: If ``model`` has no direct child named ``feature_layer_name``.
        RuntimeError: If the layer was not executed during the forward pass.
        AssertionError: If the captured output is not ``(batch, C, 1, 1)``.
    """
    feature_layer = model._modules.get(feature_layer_name)
    if feature_layer is None:
        raise ValueError(f"model has no direct child module named {feature_layer_name!r}")

    captured = []

    def copy_output(module, inputs, output):
        captured.append(output.detach())

    handle = feature_layer.register_forward_hook(copy_output)
    try:
        model(image)
    finally:
        # Always detach the hook, even when the forward pass raises, so the
        # model is not left with a stale hook attached.
        handle.remove()

    if not captured:
        raise RuntimeError(f"layer {feature_layer_name!r} was not executed during the forward pass")

    embedding = captured[0]
    assert embedding.shape[0] == image.shape[0], f"{embedding.shape[0]} != {image.shape[0]}"
    assert embedding.shape[2] == 1, f"{embedding.shape[2]} != 1"
    # The original re-checked shape[2] here, so a non-singleton width went unnoticed.
    assert embedding.shape[3] == 1, f"{embedding.shape[3]} != 1"
    return embedding[:, :, 0, 0]

def batch_func_efficientnet(batch):
    """Extract the "avgpool" features of the global `efficientnet` for one batch.

    Args:
        batch: Mapping with at least "image" and "metadata" entries.

    Returns:
        Dict with "efficientnet_feature" (detached, on the CPU) and the batch's
        "metadata" passed through unchanged.
    """
    with torch.no_grad():
        images_on_device = batch["image"].to(efficientnet_device)
        pooled = get_layer_feature(efficientnet, "avgpool", images_on_device)

    return dict(
        efficientnet_feature=pooled.detach().cpu(),
        metadata=batch["metadata"],
    )

def setup_features_efficientnet(dataset_name, dataloader):
    """Run `batch_func_efficientnet` over a dataloader and collect the features.

    Args:
        dataset_name: Accepted for symmetry with `setup_features`; not used here.
        dataloader: Dataloader whose batches are embedded.

    Returns:
        ``{"efficientnet_feature": <features on the CPU>}``.
    """
    applied = dataloader_apply_func(
        dataloader=dataloader,
        func=batch_func_efficientnet,
        collate_fn=custom_collate_per_key,
    )
    return dict(efficientnet_feature=applied["efficientnet_feature"].cpu())

In [None]:
!gpustat

In [None]:
import torchvision

In [None]:
# torchvision EfficientNetV2-S with the IMAGENET1K_V1 weights, placed on its own
# device and switched to eval mode; read as a global by `batch_func_efficientnet`.
efficientnet_device="cuda:2"
efficientnet = torchvision.models.efficientnet_v2_s(
    weights=torchvision.models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
).to(efficientnet_device)
efficientnet.eval()

In [None]:
# EfficientNet features for "proveai" only.
for dataset_name in ["proveai"]:
    print(dataset_name)
    variable_dict[dataset_name].update(
        setup_features_efficientnet(dataset_name, variable_dict[dataset_name]["dataloader"])
    )

In [None]:
# EfficientNet features for every dataset, merged into each dataset's entry.
for dataset_name in [
    "clinical_fd_clean_nodup_nooverlap",
    "derm7pt_derm_nodup",
    "isic_nodup_nooverlap",
    "proveai",
]:
    print(dataset_name)
    variable_dict[dataset_name].update(
        setup_features_efficientnet(dataset_name, variable_dict[dataset_name]["dataloader"])
    )

In [None]:
# ISIC diagnosis -> "benign" / "malignant" / "indeterminate".
# Trailing `#` / `#??` markers are carried over from the original table.
isic_diagnosis_malignant_mapping = {
    'AIMP': 'indeterminate',
    'acrochordon': 'benign',
    'actinic keratosis': 'benign',  #
    'angiofibroma or fibrous papule': 'benign',
    'angiokeratoma': 'benign',
    'angioma': 'benign',
    'atypical melanocytic proliferation': 'indeterminate',
    'atypical spitz tumor': 'indeterminate',  #
    'basal cell carcinoma': 'malignant',  #
    'cafe-au-lait macule': 'benign',
    'clear cell acanthoma': 'benign',
    'dermatofibroma': 'benign',  #
    'lentigo NOS': 'benign',
    'lentigo simplex': 'benign',
    'lichenoid keratosis': 'benign',
    'melanoma': 'malignant',
    'melanoma metastasis': 'malignant',
    'neurofibroma': 'benign',
    'nevus': 'benign',
    'other': 'indeterminate',
    'pigmented benign keratosis': 'benign',  #??
    'scar': 'benign',
    'seborrheic keratosis': 'benign',
    'solar lentigo': 'benign',
    'squamous cell carcinoma': 'malignant',
    'vascular lesion': 'benign',  #
    'verruca': 'benign',
}


def isic_map_diagnosis_malignant(diagnosis, benign_malignant):
    """Return the benign/malignant label for one ISIC record.

    Args:
        diagnosis: Diagnosis string, or a missing value (NaN / None).
        benign_malignant: Label already present in the record; when it is a
            string it is returned unchanged and ``diagnosis`` is ignored.

    Returns:
        ``benign_malignant`` if it is a string; otherwise the mapped label for
        ``diagnosis``; ``"indeterminate"`` when ``diagnosis`` is missing.

    Raises:
        RuntimeError: If ``diagnosis`` is neither missing nor in the mapping.
    """
    if isinstance(benign_malignant, str):
        return benign_malignant
    if diagnosis in isic_diagnosis_malignant_mapping:
        return isic_diagnosis_malignant_mapping[diagnosis]
    # np.isnan raises TypeError on strings, which made the RuntimeError below
    # unreachable for unmapped diagnoses; pd.isna accepts any scalar.
    if not isinstance(diagnosis, str) and pd.isna(diagnosis):
        return "indeterminate"
    raise RuntimeError(f"unmapped ISIC diagnosis: {diagnosis!r}")

In [None]:
# derm7pt diagnosis -> "benign" / "malignant" / "unknown".
# Trailing `#` markers are carried over from the original table.
derm7pt_diagnosis_malignant_mapping = {
    'basal cell carcinoma': 'malignant',
    'blue nevus': 'benign',  #
    'clark nevus': 'benign',  #
    'combined nevus': 'benign',  #
    'congenital nevus': 'benign',  #
    'dermal nevus': 'benign',
    'dermatofibroma': 'benign',
    'lentigo': 'benign',
    'melanoma (in situ)': 'malignant',
    'melanoma (less than 0.76 mm)': 'malignant',
    'melanoma (0.76 to 1.5 mm)': 'malignant',
    'melanoma (more than 1.5 mm)': 'malignant',
    'melanoma metastasis': 'malignant',
    'melanosis': 'benign',  #
    'miscellaneous': 'unknown',  #
    'recurrent nevus': 'benign',  #
    'reed or spitz nevus': 'benign',  #
    'seborrheic keratosis': 'benign',
    'vascular lesion': 'benign',
    'melanoma': 'malignant',
}


def derm7pt_map_diagnosis_malignant(diagnosis):
    """Return the benign/malignant label for one derm7pt diagnosis.

    Args:
        diagnosis: Diagnosis string, or a missing value (NaN / None).

    Returns:
        The mapped label, or ``"indeterminate"`` when ``diagnosis`` is missing.

    Raises:
        RuntimeError: If ``diagnosis`` is neither missing nor in the mapping.
    """
    if diagnosis in derm7pt_diagnosis_malignant_mapping:
        return derm7pt_diagnosis_malignant_mapping[diagnosis]
    # np.isnan raises TypeError on strings, which made the RuntimeError below
    # unreachable for unmapped diagnoses; pd.isna accepts any scalar.
    if not isinstance(diagnosis, str) and pd.isna(diagnosis):
        return "indeterminate"
    raise RuntimeError(f"unmapped derm7pt diagnosis: {diagnosis!r}")

In [None]:
# Extra free-text concepts appended after the SkinCon columns for dermoscopy datasets.
_EXTRA_CONCEPTS = [
    "purple pen",
    "finger",
    "nail",
    "pinkish",
    "red",
    "hair",
    "orange sticker",
    "blue sticker",
    "red sticker",
    "dermoscope border",
    "gel",
    "malignant",
    "melanoma",
]

# Concept names that receive the "derm7ptconcept_" prefix.
_DERM7PT_CONCEPTS = [
    "pigment network", "typical pigment network", "atypical pigment network",
    "regression structure",
    "pigmentation", "regular pigmentation", "irregular pigmentation",
    "blue whitish veil",
    "vascular structures", "typical vascular structures", "atypical vascular structures",
    "streaks", "regular streaks", "irregular streaks",
    "dots and globules", "regular dots and globules", "irregular dots and globules",
]

# Concept names that receive the "isicconcept_" prefix.
_ISIC_CONCEPTS = ["pigment_network", "negative_network", "milia_like_cyst", "streaks", "globules"]

# Disease names (prefixed with "disease_") used for the ISIC and ProveAI datasets.
_ISIC_DISEASE_NAMES = [
    'seborrheic keratosis', 'nevus', 'squamous cell carcinoma',
    'melanoma', 'lichenoid keratosis', 'lentigo',
    'actinic keratosis', 'basal cell carcinoma', 'dermatofibroma',
    'atypical melanocytic proliferation', 'verruca',
    'clear cell acanthoma', 'angiofibroma or fibrous papule', 'scar',
    'angioma', 'atypical spitz tumor', 'solar lentigo', 'AIMP',
    'neurofibroma', 'lentigo simplex', 'acrochordon',
    'angiokeratoma', 'vascular lesion', 'cafe-au-lait macule',
    'pigmented benign keratosis',
]

# Disease names (prefixed with "disease_") used for the derm7pt datasets.
_DERM7PT_DISEASE_NAMES = [
    'basal cell carcinoma', 'blue nevus', 'clark nevus',
    'combined nevus', 'congenital nevus', 'dermal nevus',
    'dermatofibroma', 'lentigo', 'melanoma', 'melanosis',
    'recurrent nevus', 'reed or spitz nevus',
    'seborrheic keratosis', 'vascular lesion',
]


def _dermoscopy_concept_list(disease_names):
    """SkinCon columns + extra concepts + prefixed derm7pt/ISIC concepts + prefixed diseases."""
    return (
        list(skincon_cols)
        + _EXTRA_CONCEPTS
        + [f"derm7ptconcept_{name}" for name in _DERM7PT_CONCEPTS]
        + [f"isicconcept_{name}" for name in _ISIC_CONCEPTS]
        + [f"disease_{name}" for name in disease_names]
    )


def set_config(dataset_name, metadata_all):
    """Derive malignancy labels, a validity mask and the concept list for a dataset.

    The dataset family is chosen by substring match on ``dataset_name``, tested
    in this order: "clinical_fd_clean", "isic", "proveai", "derm7pt".

    For the "isic" and "derm7pt" families, ``metadata_all`` is modified in
    place: the columns "benign_malignant_full" and "benign_malignant_bool" are
    added. The "proveai" family reads "data/proveai/isic_upd_rev.csv" relative
    to the working directory.

    Args:
        dataset_name: Dataset identifier, e.g. "isic_nodup_nooverlap".
        metadata_all: Metadata DataFrame of the dataset.

    Returns:
        Dict with "valid_idx" (boolean array), "y_pos" (label array),
        "metadata_all" (the same DataFrame object) and "concept_list".

    Raises:
        ValueError: If ``dataset_name`` matches none of the known families. The
            previous if/elif chain fell through and failed with UnboundLocalError.
    """
    if "clinical_fd_clean" in dataset_name:
        fitz_malignant = (metadata_all["source"] == "fitz") & (
            metadata_all["three_partition_label"] == "malignant"
        )
        ddi_malignant = (metadata_all["source"] == "ddi") & (metadata_all["malignant"] == True)
        y_pos = (fitz_malignant | ddi_malignant).values

        valid_idx = (metadata_all["skincon_Do not consider this image"] != 1).values

        concept_list = skincon_cols

    elif "isic" in dataset_name:
        metadata_all["benign_malignant_full"] = metadata_all.apply(
            lambda row: isic_map_diagnosis_malignant(row["diagnosis"], row["benign_malignant"]),
            axis=1,
        )
        metadata_all["benign_malignant_bool"] = metadata_all["benign_malignant_full"].str.contains("malignant")

        y_pos = metadata_all["benign_malignant_bool"].values

        # Valid rows are those whose label mentions "malignant" or "benign".
        valid_idx = (
            metadata_all["benign_malignant_full"].str.contains("malignant")
            | metadata_all["benign_malignant_full"].str.contains("benign")
        ).values

        concept_list = _dermoscopy_concept_list(_ISIC_DISEASE_NAMES)

    elif "proveai" in dataset_name:
        prove_logits_true = pd.read_csv("data/proveai/isic_upd_rev.csv", index_col=0).rename(
            columns={"truth": "y_pos"}
        )

        # Align the labels with the metadata rows through the image name.
        y_pos = prove_logits_true.set_index("image_name")["y_pos"].loc[metadata_all.index].values

        valid_idx = ~np.isnan(y_pos)

        concept_list = _dermoscopy_concept_list(_ISIC_DISEASE_NAMES)

    elif "derm7pt" in dataset_name:
        metadata_all["benign_malignant_full"] = metadata_all.apply(
            lambda row: derm7pt_map_diagnosis_malignant(row["diagnosis"]), axis=1
        )
        metadata_all["benign_malignant_bool"] = metadata_all["benign_malignant_full"].str.contains("malignant")

        y_pos = metadata_all["benign_malignant_bool"].values

        valid_idx = (~metadata_all["diagnosis"].isnull()).values

        concept_list = _dermoscopy_concept_list(_DERM7PT_DISEASE_NAMES)

    else:
        raise ValueError(
            f"unknown dataset_name {dataset_name!r}; expected a name containing "
            "'clinical_fd_clean', 'isic', 'proveai' or 'derm7pt'"
        )

    return {"valid_idx": valid_idx,
            "y_pos": y_pos,
            "metadata_all": metadata_all,
            "concept_list": concept_list}

# Labels, validity mask and concept list for the clinical dataset only.
for dataset_name in ("clinical_fd_clean_nodup_nooverlap",):
    variable_dict[dataset_name].update(set_config(dataset_name, variable_dict[dataset_name]["metadata_all"]))

In [None]:
variable_dict.keys()

In [None]:
pd.read_csv("data/proveai/isic_upd.csv", index_col=0)#.rename(columns={"truth": "y_pos"})

In [None]:
pd.read_csv("data/proveai/isic_upd_rev.csv", index_col=0)#.rename(columns={"truth": "y_pos"})

In [None]:
# Load the ProveAI table as a notebook global; `set_config` reads the same file
# itself into a local of the same name.
prove_logits_true=pd.read_csv("data/proveai/isic_upd_rev.csv", index_col=0)

In [None]:
# Labels, validity mask and concept list for every dataset.
for dataset_name in [
    "clinical_fd_clean_nodup_nooverlap",
    "derm7pt_derm_nodup",
    "isic_nodup_nooverlap",
    "proveai",
]:
    variable_dict[dataset_name].update(set_config(dataset_name, variable_dict[dataset_name]["metadata_all"]))

In [None]:
def normalize_embedding(dataset_name, image_features):
    """Scale every row of ``image_features`` to unit L2 norm.

    Args:
        dataset_name: Accepted for call-site symmetry; not used.
        image_features: 2-D tensor; the norm is taken along dimension 1.

    Returns:
        ``{"image_features_norm": <row-normalised tensor>}``.
    """
    row_norms = image_features.norm(dim=1, keepdim=True)
    return dict(image_features_norm=image_features / row_norms)

In [None]:
dataset_name

In [None]:
# Store row-normalised image features for every dataset. The vanilla-model
# features are normalised only for the clinical and derm7pt datasets.
for dataset_name in [
    "clinical_fd_clean_nodup_nooverlap",
    "derm7pt_derm_nodup",
    "isic_nodup_nooverlap",
    "proveai",
]:
    if ("clinical_fd_clean" in dataset_name) or ("derm7pt" in dataset_name):
        variable_dict[dataset_name]["image_features_vanilla_norm"] = normalize_embedding(
            dataset_name, variable_dict[dataset_name]["image_features_vanilla"]
        )["image_features_norm"]
    variable_dict[dataset_name]["image_features_norm"] = normalize_embedding(
        dataset_name, variable_dict[dataset_name]["image_features"]
    )["image_features_norm"]

In [None]:
# Re-run the normalisation for "proveai" only. The vanilla branch is kept for
# parity with the all-datasets cell; its condition is false for this name.
for dataset_name in ["proveai"]:
    if ("clinical_fd_clean" in dataset_name) or ("derm7pt" in dataset_name):
        variable_dict[dataset_name]["image_features_vanilla_norm"] = normalize_embedding(
            dataset_name, variable_dict[dataset_name]["image_features_vanilla"]
        )["image_features_norm"]
    variable_dict[dataset_name]["image_features_norm"] = normalize_embedding(
        dataset_name, variable_dict[dataset_name]["image_features"]
    )["image_features_norm"]

In [None]:
# Prompt templates and their concept-free reference prompts, per imaging modality.
_CLINICAL_TEMPLATES = ["This is skin image of {}", "This is dermatology image of {}", "This is image of {}"]
_CLINICAL_REFERENCES = ["This is skin image", "This is dermatology image", "This is image"]
_DERMOSCOPY_TEMPLATES = ["This is dermatoscopy of {}", "This is dermoscopy of {}"]
_DERMOSCOPY_REFERENCES = ["This is dermatoscopy", "This is dermoscopy"]

# Substrings stripped (in this order) from the engineered SkinCon prompts to recover bare terms.
_SKINCON_PROMPT_PREFIXES = ("This is ", "This photo is ", "This lesion is ", "skin has become ")

# Terms used to describe each derm7pt concept (key = concept name without "derm7ptconcept_").
_DERM7PT_CONCEPT_TERMS = {
    "pigment network": ["pigment network", "intersecting brown lines"],
    "typical pigment network": ["typical pigment network", "regularly meshed pigment network"],
    "atypical pigment network": ["atypical pigment network", "irregularly meshed pigment network"],
    "regression structure": ["regression structure"],
    "pigmentation": ["pigmented", "pigmented lesion", "colored lesion"],
    "regular pigmentation": ["regular pigmentation", "uniform and consistent coloration"],
    "irregular pigmentation": ["irregular pigmentation"],
    "blue whitish veil": ["blue whitish veil", "blue white veil"],
    "vascular structures": ["vascular structures", "Hairpin vessels", "Comma vessels",
                            "dotted vessels", "arborizing vessels"],
    "typical vascular structures": ["typical vascular structures"],
    "atypical vascular structures": ["atypical vascular structures"],
    "streaks": ["streaks", "starburst", "linear patterns", "radially oriented linear projections",
                "regular, pigmented linear extensions", "irregular, pigmented linear extensions"],
    "regular streaks": ["regular streaks", "uniformly spaced linear patterns"],
    "irregular streaks": ["irregular streaks"],
    "dots and globules": ["black dots and globules", "brown dots and globules", "scattered globules"],
    "regular dots and globules": ["regular dots and globules"],
    "irregular dots and globules": ["irregular dots and globules"],
}

# Terms used to describe each ISIC concept (key = concept name without "isicconcept_").
_ISIC_CONCEPT_TERMS = {
    "pigment_network": ["pigment network"],
    "negative_network": ["negative network"],
    # NOTE(review): the original code assigned ["milia like cyst"] and immediately overwrote it
    # with ["seborrheic keratosis"]; the effective value is kept here — confirm it is intended.
    "milia_like_cyst": ["seborrheic keratosis"],
    "streaks": ["streaks", "starburst", "linear patterns", "radially oriented linear projections",
                "regular, pigmented linear extensions", "irregular, pigmented linear extensions"],
    "globules": ["globules"],
}


def _skincon_terms(concept_name):
    """Return the bare, lower-cased, de-duplicated terms for a "skincon_"-prefixed concept.

    The first 8 characters of ``concept_name`` are dropped and the rest is passed to the
    notebook-level ``concept_to_prompt`` helper; every prompt it returns under a key other
    than "original" is stripped of the known prompt prefixes and lower-cased.
    """
    prompt_dict, _ = concept_to_prompt(concept_name[8:])
    terms = set()
    for key, prompts in prompt_dict.items():
        if key == "original":
            continue
        for prompt in prompts:
            for prefix in _SKINCON_PROMPT_PREFIXES:
                prompt = prompt.replace(prefix, "")
            terms.add(prompt.lower())
    # Sorted so the prompt order (and hence embedding row order) does not depend on set
    # iteration order.
    return sorted(terms)


def _dermoscopy_prompts(concept_name, expand_aimp):
    """Build ``(prompt_target, prompt_ref)`` prompt groups for a dermoscopy dataset.

    Both values are lists of prompt groups (lists of strings); each group becomes one slice
    along the first axis of the stacked embedding tensor.

    Args:
        concept_name: Concept identifier; its prefix selects how the terms are obtained.
        expand_aimp: When True, "disease_AIMP" is described by both the abbreviation and
            the spelled-out name instead of the abbreviation alone.

    Raises:
        ValueError: For an unknown "derm7ptconcept_" / "isicconcept_" concept.
    """
    per_template_ref = [[reference] for reference in _DERMOSCOPY_REFERENCES]

    if concept_name.startswith("skincon_"):
        terms = _skincon_terms(concept_name)
    elif concept_name.startswith("derm7ptconcept_"):
        derm7ptconcept = concept_name[15:]
        if derm7ptconcept not in _DERM7PT_CONCEPT_TERMS:
            raise ValueError(derm7ptconcept)
        terms = _DERM7PT_CONCEPT_TERMS[derm7ptconcept]
    elif concept_name.startswith("isicconcept_"):
        isicconcept = concept_name[12:]
        if isicconcept not in _ISIC_CONCEPT_TERMS:
            raise ValueError(isicconcept)
        terms = _ISIC_CONCEPT_TERMS[isicconcept]
    elif concept_name.startswith("disease_"):
        if expand_aimp and concept_name == "disease_AIMP":
            names = ["AIMP", "Atypical intraepidermal melanocytic proliferation"]
        else:
            names = [concept_name[8:]]
        # One group per template, each paired with that template's own reference prompt.
        return ([[template.format(name) for name in names] for template in _DERMOSCOPY_TEMPLATES],
                per_template_ref)
    elif concept_name == "gel":
        return [[template.format("gel")] for template in _DERMOSCOPY_TEMPLATES], per_template_ref
    elif concept_name == "dermoscope border":
        return [["This is dermatoscopy of dermoscopy"]], [["This is dermatoscopy"]]
    else:
        # Any other concept name is used verbatim as the term. The original built a flat list
        # here; each string was tokenized on its own, which yields the same one-prompt groups.
        return ([[template.format(concept_name)] for template in _DERMOSCOPY_TEMPLATES],
                per_template_ref)

    # skincon / derm7pt / isic concepts: every template x term prompt is pooled into a single
    # group, compared against a single group holding both reference prompts.
    pooled_target = [[template.format(term)
                      for template in _DERMOSCOPY_TEMPLATES for term in terms]]
    return pooled_target, [list(_DERMOSCOPY_REFERENCES)]


def _build_concept_prompts(dataset_name, concept_name):
    """Select the prompt construction matching ``dataset_name`` (substring match, in order)."""
    if "clinical_fd_clean" in dataset_name:
        terms = _skincon_terms(concept_name)
        prompt_target = [[template.format(term) for term in terms] for template in _CLINICAL_TEMPLATES]
        prompt_ref = [[reference] for reference in _CLINICAL_REFERENCES]
        return prompt_target, prompt_ref
    if "isic" in dataset_name or "proveai" in dataset_name:
        return _dermoscopy_prompts(concept_name, expand_aimp=True)
    if "derm7pt_derm" in dataset_name:
        return _dermoscopy_prompts(concept_name, expand_aimp=False)
    # Previously this case fell through and failed later with an unbound `prompt_target`.
    raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")


def _embed_prompt_groups(prompt_groups, clip_model):
    """Tokenize and encode each prompt group, stack the groups and normalise along dim 2.

    Uses the notebook-level ``clip`` module and ``model_device``. All groups must contain the
    same number of prompts, otherwise ``torch.stack`` fails.
    """
    tokenized_groups = [clip.tokenize(prompt_list, truncate=True) for prompt_list in prompt_groups]
    with torch.no_grad():
        embedding = torch.stack([
            clip_model.model_step_with_text({"text": tokens.to(model_device)})["text_features"].detach().cpu()
            for tokens in tokenized_groups
        ])
    return embedding / embedding.norm(dim=2, keepdim=True)


def get_concept_embedding(dataset_name, concept_list, clip_model):
    """Compute normalised text embeddings of target and reference prompts for each concept.

    Args:
        dataset_name: Dataset identifier. Matched by substring, in this order:
            "clinical_fd_clean", "isic", "proveai", "derm7pt_derm".
        concept_list: Iterable of concept names. For "clinical_fd_clean" datasets every name
            is treated as "skincon_"-prefixed; otherwise the prefix ("skincon_",
            "derm7ptconcept_", "isicconcept_", "disease_") or the literal names "gel" /
            "dermoscope border" select the prompts, and any other name is used verbatim.
        clip_model: Object exposing ``model_step_with_text({"text": tokens})`` that returns a
            mapping with a "text_features" tensor.

    Returns:
        ``{"prompt_info": {concept_name: {"prompt_ref_embedding_norm": Tensor,
        "prompt_target_embedding_norm": Tensor}}}``. Each tensor stacks one slice per prompt
        group along dim 0 and is normalised along dim 2.

    Raises:
        ValueError: If ``dataset_name`` matches no supported dataset, or a
            "derm7ptconcept_" / "isicconcept_" concept is unknown.

    Relies on the notebook-level names ``clip``, ``torch``, ``model_device`` and
    ``concept_to_prompt``. Prints the prompts used for every concept.
    """
    prompt_info = {}

    for concept_name in concept_list:
        prompt_target, prompt_ref = _build_concept_prompts(dataset_name, concept_name)

        # Target prompts are encoded before reference prompts, as in the original code.
        prompt_target_embedding_norm = _embed_prompt_groups(prompt_target, clip_model)
        prompt_ref_embedding_norm = _embed_prompt_groups(prompt_ref, clip_model)

        prompt_info[concept_name] = {"prompt_ref_embedding_norm": prompt_ref_embedding_norm,
                                     "prompt_target_embedding_norm": prompt_target_embedding_norm,
                                     }
        print(dataset_name, concept_name, "\n", prompt_target, prompt_ref)
        print('-----------')

    return {"prompt_info": prompt_info}

In [None]:
# Inspect the shape of the normalised target-prompt embedding stored for "skincon_Macule"
# under the dataset currently bound to `dataset_name`.
# NOTE(review): this cell reads `prompt_info` before the cells below that populate it, so it
# presumably depends on state from an earlier or out-of-order execution — confirm it runs on
# a fresh top-to-bottom pass.
variable_dict[dataset_name]["prompt_info"]["skincon_Macule"]["prompt_target_embedding_norm"].shape

In [None]:
# Compute the per-concept prompt embeddings for every dataset and store them on its entry.
for dataset_name in (
    "clinical_fd_clean_nodup_nooverlap",
    "derm7pt_derm_nodup",
    "isic_nodup_nooverlap",
    "proveai",
):
    # Only the clinical_fd_clean / derm7pt entries also get embeddings from `model_vanilla`.
    if any(tag in dataset_name for tag in ("clinical_fd_clean", "derm7pt")):
        variable_dict[dataset_name]["prompt_info_vanilla"] = get_concept_embedding(
            dataset_name,
            concept_list=variable_dict[dataset_name]["concept_list"],
            clip_model=model_vanilla,
        )["prompt_info"]

    variable_dict[dataset_name]["prompt_info"] = get_concept_embedding(
        dataset_name,
        concept_list=variable_dict[dataset_name]["concept_list"],
        clip_model=model,
    )["prompt_info"]

In [None]:
# Recompute the per-concept prompt embeddings for the "proveai" entry only.
# The former `if ("clinical_fd_clean" in dataset_name) or ("derm7pt" in dataset_name)` branch
# was removed: "proveai" contains neither substring, so that branch could never run here.
# The loop form is kept so that `dataset_name` is still rebound to "proveai" afterwards.
for dataset_name in ["proveai"]:
    variable_dict[dataset_name].update(
        {"prompt_info": get_concept_embedding(
            dataset_name,
            concept_list=variable_dict[dataset_name]["concept_list"],
            clip_model=model)["prompt_info"]})

In [None]:
def calculate_similaity_score(image_features_norm, 
                              prompt_target_embedding_norm,
                              prompt_ref_embedding_norm,
                              temp=1,
                              normalize=True):
    """Score each image for a concept from target/reference prompt embeddings.

    Parameters
    ----------
    image_features_norm : torch.Tensor
        Image embeddings; transposed and right-multiplied by the prompt
        embeddings, so laid out as (num_images, num_features).
    prompt_target_embedding_norm : torch.Tensor
        Target-prompt embeddings, (template, terms, num_features) per the
        original inline shape comments.
    prompt_ref_embedding_norm : torch.Tensor
        Reference-prompt embeddings with the same leading/trailing layout.
    temp : float, default 1
        Temperature; similarities are divided by it before the softmax.
        Ignored when ``normalize`` is False.
    normalize : bool, default True
        If True, take a two-way softmax (target vs. reference) per template
        and image, keep the target probability and average over templates.
        If False, return the template-averaged raw target similarity.

    Returns
    -------
    numpy.ndarray of shape (num_images,) when ``normalize`` is True,
    otherwise a torch.Tensor of shape (num_images,).
    """
    # Imported locally so the function does not depend on a module-level
    # `import scipy` having been executed in some other cell.
    import scipy.special

    image_features_t = image_features_norm.T.float()
    target_similarity = prompt_target_embedding_norm.float() @ image_features_t
    ref_similarity = prompt_ref_embedding_norm.float() @ image_features_t

    # Average over the `terms` axis: (template, terms, num_images) -> (template, num_images)
    target_similarity_mean = target_similarity.mean(dim=1)
    ref_similarity_mean = ref_similarity.mean(dim=1)

    if not normalize:
        return target_similarity_mean.mean(dim=0)

    # detach()/cpu() make the NumPy conversion safe for GPU or grad-tracking tensors.
    stacked = np.stack([target_similarity_mean.detach().cpu().numpy() / temp,
                        ref_similarity_mean.detach().cpu().numpy() / temp])
    # Softmax across the (target, reference) pair, keep P(target), then
    # average over templates: (template, num_images) -> (num_images,)
    return scipy.special.softmax(stacked, axis=0)[0].mean(axis=0)

In [None]:
def calculate_similaity_score_(image_features_norm, 
                              prompt_target_embedding_norm,
                              prompt_ref_embedding_norm,
                              temp=1/np.exp(4.5944),
                              normalize=True):
    """Variant of ``calculate_similaity_score`` with a sharper default temperature.

    The computation is the same; only the default ``temp`` differs
    (``1/np.exp(4.5944)``, i.e. similarities are multiplied by ~98.9).
    NOTE(review): 4.5944 is presumably the model's learned ``logit_scale`` —
    confirm against the loaded checkpoint.

    Parameters
    ----------
    image_features_norm : torch.Tensor
        Image embeddings, (num_images, num_features).
    prompt_target_embedding_norm : torch.Tensor
        Target-prompt embeddings, (template, terms, num_features).
    prompt_ref_embedding_norm : torch.Tensor
        Reference-prompt embeddings, (template, terms, num_features).
    temp : float
        Temperature; similarities are divided by it before the softmax.
        Ignored when ``normalize`` is False.
    normalize : bool, default True
        If True, return the template-averaged softmax probability of the
        target prompt versus the reference prompt; if False, return the
        template-averaged raw target similarity.

    Returns
    -------
    numpy.ndarray of shape (num_images,) when ``normalize`` is True,
    otherwise a torch.Tensor of shape (num_images,).
    """
    # Imported locally so the function does not depend on a module-level
    # `import scipy` having been executed in some other cell.
    import scipy.special

    image_features_t = image_features_norm.T.float()
    # (template, terms, num_features) @ (num_features, num_images) -> (template, terms, num_images)
    target_similarity = prompt_target_embedding_norm.float() @ image_features_t
    ref_similarity = prompt_ref_embedding_norm.float() @ image_features_t

    # Average over the `terms` axis -> (template, num_images)
    target_similarity_mean = target_similarity.mean(dim=1)
    ref_similarity_mean = ref_similarity.mean(dim=1)

    if not normalize:
        return target_similarity_mean.mean(dim=0)

    # detach()/cpu() make the NumPy conversion safe for GPU or grad-tracking tensors.
    stacked = np.stack([target_similarity_mean.detach().cpu().numpy() / temp,
                        ref_similarity_mean.detach().cpu().numpy() / temp])
    # Softmax across the (target, reference) pair, keep P(target), then
    # average over templates -> (num_images,)
    return scipy.special.softmax(stacked, axis=0)[0].mean(axis=0)

In [None]:
# Display the `logit_scale` attribute of the loaded model.
# NOTE(review): presumably the source of the 4.5944 constant used as 1/np.exp(4.5944) in the
# similarity-score temperature — confirm.
model_loaded.net.model.logit_scale

In [None]:
from sklearn.model_selection import train_test_split
from MONET.datamodules.components.base_dataset import BaseDataset

In [None]:
import torchvision
from torch import nn
from torch.nn import functional as F
from torchmetrics import AUROC

class Classifier(nn.Module):
    """ResNet-50 (ImageNet-1K V1 weights) feature extractor with a fresh linear head.

    The network's original ``fc`` layer is swapped for ``nn.Identity`` so the
    backbone emits pooled features; ``head`` maps them to ``output_dim`` outputs.
    """

    def __init__(self, output_dim):
        super().__init__()
        resnet = torchvision.models.resnet50(weights="ResNet50_Weights.IMAGENET1K_V1")

        # Fine-tune the whole backbone: every parameter stays trainable.
        resnet.requires_grad_(True)

        feature_dim = resnet.fc.in_features
        resnet.fc = nn.Identity()

        self.backbone = resnet
        self.head = nn.Linear(feature_dim, output_dim)

    def forward(self, x):
        """Return the head's output for a batch of images ``x``."""
        return self.head(self.backbone(x))


class EarlyStopper:
    """Signal when the validation loss has stopped improving.

    ``counter`` counts calls whose loss exceeded the best loss seen so far by
    more than ``min_delta``; it is reset to zero whenever a new best is found.
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience, self.min_delta = patience, min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        """Record ``validation_loss``; return True once ``patience`` is exhausted."""
        best_so_far = self.min_validation_loss

        # New best: remember it and start counting afresh.
        if validation_loss < best_so_far:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        # Worse than the best by more than the tolerance: one more strike.
        if validation_loss > (best_so_far + self.min_delta):
            self.counter += 1
            return self.counter >= self.patience

        # Within tolerance of the best: neither an improvement nor a strike.
        return False
    
def find_thres_best_f1(y_test, y_test_pred):
    """Return the decision threshold that maximises F1 on the given data.

    Parameters
    ----------
    y_test : array-like
        Ground-truth binary labels.
    y_test_pred : array-like
        Predicted scores, passed unchanged to
        ``sklearn.metrics.precision_recall_curve``.

    Returns
    -------
    The entry of ``thresholds`` at which F1 = 2PR / (P + R) is largest
    (the first such entry on ties).
    """
    precision, recall, thresholds = sklearn.metrics.precision_recall_curve(y_test, y_test_pred)
    numerator = 2 * recall * precision
    denom = recall + precision
    # Define F1 as 0 wherever precision + recall == 0 instead of dividing by zero.
    f1_scores = np.divide(numerator, denom, out=np.zeros_like(denom), where=(denom != 0))
    # precision/recall carry one more entry than `thresholds` (the trailing
    # point has no threshold), so only search positions that can index it.
    best_idx = np.argmax(f1_scores[:len(thresholds)])
    return thresholds[best_idx]

def train_classifier(train_dataloader, val_dataloader, test_dataloader, classifier_type="resnet", verbose=True,
                     device="cuda:3", max_epochs=20):
    """Train a binary image classifier and evaluate it on the test set.

    Each batch is a dict with an ``"image"`` tensor and a ``"label"`` tensor;
    test batches must also carry ``"metadata"``. The positive class is
    ``label == 1``. Training uses Adam (lr=1e-3), ReduceLROnPlateau on the
    summed validation loss, and early stopping with patience 5.

    Parameters
    ----------
    train_dataloader, val_dataloader, test_dataloader
        Data loaders for the three splits.
    classifier_type : {"resnet", "inception"}
        Which model class to instantiate (``Classifier`` or ``Inception``).
    verbose : bool
        Show tqdm progress bars and per-epoch / test metrics.
    device : str
        Device the model and batches are moved to. Defaults to the
        previously hard-coded ``"cuda:3"``.
    max_epochs : int
        Upper bound on the number of training epochs (previously fixed at 20).

    Returns
    -------
    tuple
        ``(test_auroc, classifier, logits_list, label_list, metadata_list,
        max_f1_thresh)`` where the three lists hold one entry per test batch
        and ``max_f1_thresh`` is the F1-optimal threshold on the validation
        logits of the last epoch that was run (``None`` if no epoch ran).

    Raises
    ------
    ValueError
        If ``classifier_type`` is not one of the supported values.
    """
    if classifier_type=="resnet":
        classifier = Classifier(output_dim=1)
    elif classifier_type=="inception":
        classifier = Inception(output_dim=1)
    else:
        raise ValueError(f"Unknown classifier_type: {classifier_type!r}")
    classifier.to(device)
    optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=2, verbose=True)
    early_stopper = EarlyStopper(patience=5, min_delta=0)

    def _progress(loader):
        # Wrap the loader in a progress bar only when verbose output is requested.
        return tqdm.tqdm(loader) if verbose else loader

    def _evaluate(loader, auroc, collect_metadata=False):
        # Run the classifier over `loader` without gradients; return the summed
        # loss plus per-batch logits, labels and (optionally) metadata.
        total_loss = 0
        logits_batches, label_batches, metadata_batches = [], [], []
        classifier.eval()
        with torch.no_grad():
            for batch in _progress(loader):
                image, label = batch["image"].to(device), batch["label"].to(device)
                logits = classifier(image)
                loss = F.binary_cross_entropy_with_logits(
                    input=logits[:, 0], target=(label == 1).float()
                )
                # loss is a batch mean; weight by batch size to get a per-sample sum
                total_loss += loss.item() * image.size(0)
                auroc.update(logits[:, 0], (label == 1))
                logits_batches.append(logits.detach().cpu().numpy())
                label_batches.append(label.detach().cpu().numpy())
                if collect_metadata:
                    metadata_batches.append(batch["metadata"])
        return total_loss, logits_batches, label_batches, metadata_batches

    train_auroc = AUROC(task="binary")
    val_auroc = AUROC(task="binary")
    max_f1_thresh = None
    for epoch in range(max_epochs):
        train_loss = 0
        classifier.train()
        for batch in _progress(train_dataloader):
            image, label = batch["image"].to(device), batch["label"].to(device)
            logits = classifier(image)
            loss = F.binary_cross_entropy_with_logits(
                input=logits[:, 0], target=(label == 1).float()
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * image.size(0)
            # Feed the metric the same 1-D logits the loss sees, detached from the graph.
            train_auroc.update(logits[:, 0].detach(), (label == 1))

        val_loss, val_logits, val_labels, _ = _evaluate(val_dataloader, val_auroc)
        if verbose:
            print(
                f"Train loss: {train_loss/len(train_dataloader.dataset):.3f} AUROC: {train_auroc.compute():.3f} Val loss: {val_loss/len(val_dataloader.dataset):.3f} AUROC: {val_auroc.compute():.3f}"
            )
        train_auroc.reset()
        val_auroc.reset()

        # Compute the threshold BEFORE the early-stopping check so the value
        # returned always matches the weights of the epoch that just finished.
        max_f1_thresh=find_thres_best_f1(y_test=np.hstack(val_labels), y_test_pred=np.concatenate(val_logits)[:,0])
        print(max_f1_thresh)

        scheduler.step(val_loss)
        if early_stopper.early_stop(val_loss):
            print("break")
            break

    test_auroc = AUROC(task="binary")
    test_loss, logits_list, label_list, metadata_list = _evaluate(
        test_dataloader, test_auroc, collect_metadata=True
    )

    if verbose:
        print(
            f"Test loss: {test_loss/len(test_dataloader.dataset):.3f} AUROC: {test_auroc.compute():.3f}"
        )   
    return test_auroc.compute(), classifier, logits_list, label_list, metadata_list, max_f1_thresh

In [None]:
def fdrcorrection(pvals, alpha=0.05, method='indep', is_sorted=False):
    '''
    Adjust p-values to control the false discovery rate.

    Implements the Benjamini/Hochberg step-up procedure (independent or
    positively correlated tests) and the Benjamini/Yekutieli variant
    (general or negatively correlated tests).

    Parameters
    ----------
    pvals : array_like, 1d
        p-values of the individual tests.
    alpha : float, optional
        Level against which the rank-scaled p-values are compared.
        Defaults to ``0.05``.
    method : {'i', 'indep', 'p', 'poscorr', 'n', 'negcorr'}, optional
        ``'i'``, ``'indep'``, ``'p'``, ``'poscorr'`` select Benjamini/Hochberg;
        ``'n'``, ``'negcorr'`` select Benjamini/Yekutieli, which additionally
        divides the rank fractions by the harmonic sum ``sum(1/i)``.
        Defaults to ``'indep'``.
    is_sorted : bool, optional
        If True, ``pvals`` is taken to be in ascending order already and the
        results are returned in that same order. If False (default), the
        values are sorted internally and the results are mapped back to the
        input order.

    Returns
    -------
    rejected : ndarray, bool
        True where the hypothesis is rejected.
    pvalue-corrected : ndarray
        Adjusted p-values, capped at 1.

    Raises
    ------
    ValueError
        If ``method`` is not one of the recognised names.
    '''
    pvals = np.asarray(pvals)
    assert pvals.ndim == 1, "pvals must be 1-dimensional, that is of shape (n,)"

    if is_sorted:
        order = None
        sorted_p = pvals  # alias, no copy
    else:
        order = np.argsort(pvals)
        sorted_p = pvals[order]

    # Rank fractions i/n for i = 1..n (a plain empirical CDF).
    n_tests = len(sorted_p)
    rank_fraction = np.arange(1, n_tests + 1) / float(n_tests)

    if method in ('i', 'indep', 'p', 'poscorr'):
        scale = rank_fraction
    elif method in ('n', 'negcorr'):
        harmonic_sum = np.sum(1. / np.arange(1, n_tests + 1))
        scale = rank_fraction / harmonic_sum
    else:
        raise ValueError('only indep and negcorr implemented')

    # Step-up rule: everything ranked below the largest passing p-value is rejected too.
    rejected = sorted_p <= scale * alpha
    passing = np.nonzero(rejected)[0]
    if passing.size:
        rejected[:passing[-1]] = True

    # Enforce monotonicity from the largest p-value downwards, then cap at 1.
    adjusted = np.minimum.accumulate((sorted_p / scale)[::-1])[::-1]
    adjusted[adjusted > 1] = 1

    if is_sorted:
        return rejected, adjusted

    # Scatter the results back into the caller's original ordering.
    adjusted_out = np.empty_like(adjusted)
    adjusted_out[order] = adjusted
    rejected_out = np.empty_like(rejected)
    rejected_out[order] = rejected
    return rejected_out, adjusted_out

In [None]:
from scipy.special import xlogy
def log_loss(
    y_true, y_pred, *, eps="auto", normalize=True, sample_weight=None, labels=None
):
    
    r"""Per-sample log loss (logistic / cross-entropy loss).

    This appears to be adapted from ``sklearn.metrics.log_loss``, with one
    important difference: the final aggregation step is commented out, so the
    function returns the loss of EACH sample instead of a single mean or sum.

    For a single sample with true label :math:`y \in \{0,1\}` and
    a probability estimate :math:`p = \operatorname{Pr}(y = 1)`, the log
    loss is:

    .. math::
        L_{\log}(y, p) = -(y \log (p) + (1 - y) \log (1 - p))

    Parameters
    ----------
    y_true : array-like or label indicator matrix
        Ground truth labels for n_samples samples. They are binarized with
        ``sklearn.preprocessing.LabelBinarizer``.

    y_pred : array-like of float, shape = (n_samples, n_classes) or (n_samples,)
        Predicted probabilities. A 1-D array, or a 2-D array with a single
        column, is treated as the probability of the positive class and is
        expanded to two columns ``[1 - p, p]``.

    eps : float or "auto", default="auto"
        Probabilities are clipped to ``[eps, 1 - eps]`` before the logarithm.
        ``"auto"`` uses ``np.finfo(y_pred.dtype).eps``.

    normalize : bool, default=True
        Accepted for signature compatibility only; it has no effect on the
        returned value because the aggregation step is disabled.

    sample_weight : array-like of shape (n_samples,), default=None
        Only checked for consistent length with ``y_pred`` and ``y_true``;
        it is not applied to the returned losses.

    labels : array-like, default=None
        If given, the label binarizer is fitted on ``labels`` instead of
        ``y_true``.

    Returns
    -------
    loss : ndarray of shape (n_samples,)
        Log loss of each individual sample (natural logarithm).

    Raises
    ------
    ValueError
        If only one distinct label is available, or if the number of classes
        does not match the number of columns of ``y_pred``.

    Notes
    -----
    For ``y_true = ["spam", "ham", "ham", "spam"]`` and
    ``y_pred = [[.1, .9], [.9, .1], [.8, .2], [.35, .65]]`` the per-sample
    losses are approximately ``[0.105, 0.105, 0.223, 0.431]``; their mean,
    0.21616..., is what the aggregated sklearn function would report.
    """
    
    # Aggregation helper kept from the original implementation; it is not
    # called anywhere in this function because the final return is per-sample.
    def _weighted_sum(sample_score, sample_weight, normalize=False):
        if normalize:
            return np.average(sample_score, weights=sample_weight)
        elif sample_weight is not None:
            return np.dot(sample_score, sample_weight)
        else:
            return sample_score.sum()    
    y_pred = sklearn.utils.check_array(
        y_pred, ensure_2d=False, dtype=[np.float64, np.float32, np.float16]
    )
    eps = np.finfo(y_pred.dtype).eps if eps == "auto" else eps

    sklearn.utils.check_consistent_length(y_pred, y_true, sample_weight)
    lb = sklearn.preprocessing.LabelBinarizer()
    if labels is not None:
        lb.fit(labels)
    else:
        lb.fit(y_true)

    if len(lb.classes_) == 1:
        if labels is None:
            raise ValueError(
                "y_true contains only one label ({0}). Please "
                "provide the true labels explicitly through the "
                "labels argument.".format(lb.classes_[0])
            )
        else:
            raise ValueError(
                "The labels array needs to contain at least two "
                "labels for log_loss, "
                "got {0}.".format(lb.classes_)
            )

    transformed_labels = lb.transform(y_true)

    # Binary case: the binarizer yields a single column; expand it to
    # explicit [negative, positive] indicator columns.
    if transformed_labels.shape[1] == 1:
        transformed_labels = np.append(
            1 - transformed_labels, transformed_labels, axis=1
        )

    # Clipping
    y_pred = np.clip(y_pred, eps, 1 - eps)

    # If y_pred is of single dimension, assume y_true to be binary
    # and then check.
    if y_pred.ndim == 1:
        y_pred = y_pred[:, np.newaxis]
    if y_pred.shape[1] == 1:
        y_pred = np.append(1 - y_pred, y_pred, axis=1)

    # Check if dimensions are consistent.
    transformed_labels = sklearn.utils.check_array(transformed_labels)
    if len(lb.classes_) != y_pred.shape[1]:
        if labels is None:
            raise ValueError(
                "y_true and y_pred contain different number of "
                "classes {0}, {1}. Please provide the true "
                "labels explicitly through the labels argument. "
                "Classes found in "
                "y_true: {2}".format(
                    transformed_labels.shape[1], y_pred.shape[1], lb.classes_
                )
            )
        else:
            raise ValueError(
                "The number of classes in labels is different "
                "from that in y_pred. Classes found in "
                "labels: {0}".format(lb.classes_)
            )

    # Renormalize so each row of probabilities sums to 1 (clipping may have
    # shifted the totals slightly).
    y_pred_sum = y_pred.sum(axis=1)
    y_pred = y_pred / y_pred_sum[:, np.newaxis]
    # xlogy(label, p) = label * log(p), defined as 0 where label == 0; summing
    # over the class axis leaves one loss value per sample.
    loss = -xlogy(transformed_labels, y_pred).sum(axis=1)
    # Deliberately return the per-sample losses; the aggregated form from the
    # original implementation is kept below for reference.
    return loss
#     return _weighted_sum(loss, sample_weight, normalize)

In [None]:
def similarity_matrix(embeddings,
                      prompt_info,
                     idx):
    """Build an (image x concept) table of normalized concept similarity scores.

    For every concept in ``prompt_info`` the images in ``embeddings`` are scored
    with ``calculate_similaity_score`` (softmax-normalized, temperature
    ``1/np.exp(4.5944)``). Each concept becomes one column named after the
    concept; rows are labelled with ``idx``.

    Parameters
    ----------
    embeddings
        Image features passed on as ``image_features_norm``.
    prompt_info : dict
        Maps concept name to a dict holding ``"prompt_target_embedding_norm"``
        and ``"prompt_ref_embedding_norm"``.
    idx
        Index used for every per-concept ``pd.Series``.

    Returns
    -------
    pandas.DataFrame
        One column per concept, concatenated along ``axis=1``.
    """
    per_concept_scores = [
        pd.Series(
            calculate_similaity_score(
                image_features_norm=embeddings,
                prompt_target_embedding_norm=concept_prompts["prompt_target_embedding_norm"],
                prompt_ref_embedding_norm=concept_prompts["prompt_ref_embedding_norm"],
                temp=1/np.exp(4.5944),
                normalize=True,
            ),
            index=idx,
            name=concept_name,
        )
        for concept_name, concept_prompts in prompt_info.items()
    ]
    return pd.concat(per_concept_scores, axis=1)

In [None]:
# Compute the image-by-concept similarity table for every dataset and store it in
# `variable_dict`. Datasets whose name contains "clinical_fd_clean" or "derm7pt" also get a
# "similarity_matrix_vanilla" built from their "*_vanilla" features and prompt embeddings.
for dataset_name in (
    "clinical_fd_clean_nodup_nooverlap",
    "derm7pt_derm_nodup",
    "isic_nodup_nooverlap",
    "proveai",
):
    if any(tag in dataset_name for tag in ("clinical_fd_clean", "derm7pt")):
        variable_dict[dataset_name]["similarity_matrix_vanilla"] = similarity_matrix(
            embeddings=variable_dict[dataset_name]["image_features_vanilla_norm"],
            prompt_info=variable_dict[dataset_name]["prompt_info_vanilla"],
            idx=variable_dict[dataset_name]["metadata_all"].index,
        )
    variable_dict[dataset_name]["similarity_matrix"] = similarity_matrix(
        embeddings=variable_dict[dataset_name]["image_features_norm"],
        prompt_info=variable_dict[dataset_name]["prompt_info"],
        idx=variable_dict[dataset_name]["metadata_all"].index,
    )

In [None]:
# Display the derm7pt similarity table built from the "*_vanilla" features and prompts.
variable_dict["derm7pt_derm_nodup"]["similarity_matrix_vanilla"]

In [None]:
# Display the derm7pt similarity table built from "image_features_norm" and "prompt_info",
# for side-by-side comparison with the vanilla table.
variable_dict["derm7pt_derm_nodup"]["similarity_matrix"]

In [None]:
# Sanity check: show the derm7pt image features stored under "image_features_vanilla_norm".
# NOTE(review): this prints the whole object; consider `.shape` if the output is too large.
variable_dict["derm7pt_derm_nodup"]["image_features_vanilla_norm"]

In [None]:
# Sanity check: show the derm7pt image features stored under "image_features_norm".
# NOTE(review): this prints the whole object; consider `.shape` if the output is too large.
variable_dict["derm7pt_derm_nodup"]["image_features_norm"]

In [None]:
# Sanity check: show the per-concept prompt embeddings computed with `model_vanilla` for derm7pt.
# NOTE(review): this dumps the full dictionary of embeddings; the output can be very long.
variable_dict["derm7pt_derm_nodup"]["prompt_info_vanilla"]

In [None]:
# Sanity check: show the per-concept prompt embeddings computed with `model` for derm7pt.
# NOTE(review): this dumps the full dictionary of embeddings; the output can be very long.
variable_dict["derm7pt_derm_nodup"]["prompt_info"]

In [None]:
def check_concept_name(dataset_name, concept_name):
    """Decide whether a concept should be kept for the given dataset.

    Only datasets whose name contains ``"isic"`` are supported. For those,
    a concept is rejected (``False``) when it

    * starts with ``"disease"`` or ``"isicconcept_"``;
    * starts with ``"derm7ptconcept_"`` and contains ``"typical"``,
      ``"regular"`` or ``"pigmentation"`` as a substring (this also matches
      "atypical" / "irregular");
    * is one of ``"melanoma"``, ``"malignant"``, ``"finger"``, ``"red sticker"``.

    Every other concept is accepted (``True``).

    Raises
    ------
    NotImplementedError
        If ``dataset_name`` does not contain ``"isic"``. (The previous code
        called the ``NotImplemented`` constant, which is not an exception and
        produced an unrelated ``TypeError``.)
    """
    if "isic" not in dataset_name:
        raise NotImplementedError(dataset_name)

    if concept_name.startswith(("disease", "isicconcept_")):
        return False

    if concept_name.startswith("derm7ptconcept_"):
        excluded_fragments = ("typical", "regular", "pigmentation")
        return not any(fragment in concept_name for fragment in excluded_fragments)

    return concept_name not in ["melanoma", "malignant", "finger", "red sticker"]

In [None]:
#check
from sklearn.model_selection import train_test_split
from MONET.datamodules.components.base_dataset import BaseDataset

def get_training_data_idx(dataloader, valid_idx, y_pos, subset_idx_train, subset_idx_test, n_px=None):
    """Build train / validation / test dataloaders from index subsets of an existing loader.

    A copy of ``dataloader.dataset.metadata_all`` gets a ``"label"`` column set
    to ``y_pos.astype(int)`` and is filtered with ``valid_idx``. The unique
    values of ``subset_idx_train`` are split 80/20 (``random_state=42``) into
    train and validation groups; every entry of ``subset_idx_train`` (repeats
    included, original order kept) is then routed to the group its value
    belongs to. ``subset_idx_test`` is used as given. Rows are selected by
    label with ``.loc``.

    Parameters
    ----------
    dataloader
        Existing loader whose dataset supplies the images, metadata and
        normalization statistics.
    valid_idx
        Selector applied to the metadata before index lookup.
    y_pos
        Per-row labels; must support ``.astype(int)``.
    subset_idx_train, subset_idx_test
        Index labels for the training pool and the test set.
    n_px : int, optional
        Image size for the new datasets; defaults to ``dataloader.dataset.n_px``.

    Returns
    -------
    tuple
        ``(train_dataloader, val_dataloader, test_dataloader)``, all with
        batch size 64 and 4 workers; only the training loader is shuffled.
    """
    source_dataset = dataloader.dataset
    metadata_all_new = source_dataset.metadata_all.copy()
    metadata_all_new["label"]=y_pos.astype(int)

    metadata_valid = metadata_all_new[valid_idx]

    train_idx, val_idx = train_test_split(np.unique(subset_idx_train), test_size=0.2, random_state=42)

    # Set lookups replace the original `i in ndarray` test, which scanned the
    # whole array for every element (quadratic in the subset size).
    train_members = set(train_idx.tolist())
    val_members = set(val_idx.tolist())
    metadata_train = metadata_valid.loc[[i for i in subset_idx_train if i in train_members]]
    metadata_val = metadata_valid.loc[[i for i in subset_idx_train if i in val_members]]
    metadata_test = metadata_valid.loc[subset_idx_test]

    print("train:", len(metadata_train))
    print("val:", len(metadata_val))
    print("test:", len(metadata_test))

    if n_px is None:
        n_px = source_dataset.n_px

    # Reuse the normalization statistics of the source dataset's second
    # after-tensor transform.
    normalize_transform = source_dataset.transforms_aftertensor.transforms[1]

    def _build_dataset(metadata):
        # Same images and preprocessing as the source dataset, restricted to `metadata`.
        return BaseDataset(
            image_path_or_binary_dict=source_dataset.image_path_dict,
            n_px=n_px,
            norm_mean=normalize_transform.mean,
            norm_std=normalize_transform.std,
            augment=False,
            metadata_all=metadata,
            integrity_level="weak",
            return_label=["label"],
        )

    data_train = _build_dataset(metadata_train)
    data_val = _build_dataset(metadata_val)
    data_test = _build_dataset(metadata_test)

    from MONET.utils.loader import custom_collate

    def _build_loader(dataset, shuffle):
        # All three loaders share every setting except shuffling.
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=64,
            num_workers=4,
            shuffle=shuffle,
            collate_fn=custom_collate,
        )

    train_dataloader = _build_loader(data_train, shuffle=True)
    val_dataloader = _build_loader(data_val, shuffle=False)
    test_dataloader = _build_loader(data_test, shuffle=False)

    return train_dataloader, val_dataloader, test_dataloader

# forced direction

In [None]:
# derm7pt concepts: concept name -> (metadata column, predicate applied to that column).
_DERM7PT_CONCEPT_RULES = {
    "derm7ptconcept_pigment network": ("pigment_network", lambda col: col != "absent"),
    "derm7ptconcept_typical pigment network": ("pigment_network", lambda col: col == "typical"),
    "derm7ptconcept_atypical pigment network": ("pigment_network", lambda col: col == "atypical"),
    "derm7ptconcept_regression structure": ("regression_structures", lambda col: col != "absent"),
    "derm7ptconcept_pigmentation": ("pigmentation", lambda col: col != "absent"),
    # The leading space keeps " regular" from matching inside "irregular".
    "derm7ptconcept_regular pigmentation": ("pigmentation", lambda col: col.str.contains(" regular")),
    "derm7ptconcept_irregular pigmentation": ("pigmentation", lambda col: col.str.contains(" irregular")),
    "derm7ptconcept_blue whitish veil": ("blue_whitish_veil", lambda col: col != "absent"),
    "derm7ptconcept_vascular structures": ("vascular_structures", lambda col: col != "absent"),
    "derm7ptconcept_typical vascular structures": (
        "vascular_structures",
        lambda col: col.isin(["within regression", "arborizing", "comma", "hairpin", "wreath"]),
    ),
    "derm7ptconcept_atypical vascular structures": (
        "vascular_structures",
        lambda col: col.isin(["dotted", "linear irregular"]),
    ),
    "derm7ptconcept_streaks": ("streaks", lambda col: col != "absent"),
    "derm7ptconcept_regular streaks": ("streaks", lambda col: col == "regular"),
    "derm7ptconcept_irregular streaks": ("streaks", lambda col: col == "irregular"),
    "derm7ptconcept_dots and globules": ("dots_and_globules", lambda col: col != "absent"),
    "derm7ptconcept_regular dots and globules": ("dots_and_globules", lambda col: col == "regular"),
    "derm7ptconcept_irregular dots and globules": ("dots_and_globules", lambda col: col == "irregular"),
}

# ISIC concepts: concept name -> numeric metadata column.
_ISIC_CONCEPT_COLUMNS = {
    "isicconcept_pigment_network": "pigment_network",
    "isicconcept_negative_network": "negative_network",
    "isicconcept_milia_like_cyst": "milia_like_cyst",
    "isicconcept_streaks": "streaks",
    "isicconcept_globules": "globules",
}


def get_concept_bool_from_metadata(dataset_name, metadata_all, concept_name):
    """Derive per-row "concept present" / "concept absent" masks from metadata.

    The dataset family is chosen by substring match on ``dataset_name``,
    checked in this order: ``"derm7pt_derm"``, ``"isic"``,
    ``"clinical_fd_clean"``.

    * derm7pt: the mask comes from a categorical column (see
      ``_DERM7PT_CONCEPT_RULES``); "absent" is the plain negation of "present".
      NOTE(review): rows with a missing annotation are not excluded here, so
      callers are expected to pre-filter them — confirm.
    * isic: a concept is present when its numeric column is ``> 30``. The
      value ``-9`` marks an invalid row, which is excluded from the "absent"
      mask (it can never be in the "present" mask). Previously this branch
      never assigned its results and failed with ``UnboundLocalError``.
    * clinical_fd_clean: ``concept_name`` (must start with ``"skincon_"``) is
      itself the column; present is ``== 1`` and absent is ``== 0``.

    Returns
    -------
    dict
        ``{"concept_bool_true": Series, "concept_bool_false": Series}``.

    Raises
    ------
    ValueError
        If the concept is unknown for the dataset family, or if
        ``dataset_name`` matches no supported family.
    """
    if "derm7pt_derm" in dataset_name:
        if concept_name not in _DERM7PT_CONCEPT_RULES:
            raise ValueError(concept_name)
        column, is_present = _DERM7PT_CONCEPT_RULES[concept_name]
        concept_bool_true = is_present(metadata_all[column])
        concept_bool_false = (~concept_bool_true)

    elif "isic" in dataset_name:
        if concept_name not in _ISIC_CONCEPT_COLUMNS:
            raise ValueError(concept_name)
        label = metadata_all[_ISIC_CONCEPT_COLUMNS[concept_name]]
        valid_idx = (label != -9).values
        concept_bool_true = (label > 30)
        # Only rows with a valid annotation count as "concept absent".
        concept_bool_false = (~concept_bool_true) & valid_idx

    elif "clinical_fd_clean" in dataset_name:
        if not concept_name.startswith("skincon_"):
            raise ValueError(concept_name)
        concept_bool_true = (metadata_all[concept_name] == 1)
        concept_bool_false = (metadata_all[concept_name] == 0)

    else:
        raise ValueError(dataset_name)

    return {"concept_bool_true": concept_bool_true,
            "concept_bool_false": concept_bool_false,
           }

In [None]:
def _split_by_label_and_concept(metadata, concept_true, concept_false):
    """Split ``metadata`` into its four (label, concept) cells.

    ``metadata`` must carry a ``"y_pos"`` column. The returned tuple is ordered
    (y=1 with concept, y=1 without concept, y=0 with concept, y=0 without concept).
    """
    # `== True` / `== False` are deliberate element-wise comparisons on the column.
    is_pos = metadata["y_pos"] == True
    is_neg = metadata["y_pos"] == False
    return (
        metadata[concept_true & is_pos],
        metadata[concept_false & is_pos],
        metadata[concept_true & is_neg],
        metadata[concept_false & is_neg],
    )


def _sample_cells(cells, sizes, random_seed):
    """Draw ``sizes[i]`` index labels (with replacement) from ``cells[i]`` and concatenate them."""
    sampled_idx = []
    for cell, size in zip(cells, sizes):
        # Every cell is sampled with the same seed, as in the original implementation.
        sampled_idx += cell.sample(n=size, replace=True, random_state=random_seed).index.tolist()
    return sampled_idx


def forced_training(dataset_name, dataloader, concept_list, metadata_all, y_pos, random_seed_range=[]):
    """Train classifiers on splits where a concept is deliberately tied to the label.

    For every concept and every seed, the valid rows are split 60/40 into train and
    test pools. From each pool 500 positive and 500 negative rows are resampled
    (with replacement) so that, for ``proportion == 1``, all training positives show
    the concept and all training negatives do not, while the test set has the
    opposite association. A classifier is trained via ``train_classifier`` and its
    test logits/labels are recorded.

    Parameters
    ----------
    dataset_name : str
        Must contain ``"clinical_fd_clean"`` or ``"derm7pt"``; this selects the
        row-validity filter.
    dataloader
        Passed through to ``get_training_data_idx``.
    concept_list : iterable of str
        Concept names understood by ``get_concept_bool_from_metadata``.
    metadata_all : pandas.DataFrame
        Per-sample metadata; indexed positionally with ``.iloc`` and by label with ``.loc``.
    y_pos
        Per-sample binary target, indexable with an integer position array.
    random_seed_range : list of int
        Seeds to run. The default empty list means no model is trained (the default
        list is never mutated).

    Returns
    -------
    dict
        ``{"model_auditing_simulation_data": [...]}`` with one record per
        (concept, proportion, seed) that had at least 30 rows in every
        (label, concept) cell of both the train and the test pool.

    Raises
    ------
    ValueError
        If ``dataset_name`` matches no known row-validity rule. Previously this
        case failed on an unbound ``concept_idx``.
    """
    min_cell_size = 30  # minimum rows required in each (label, concept) cell

    num_train_pos = 500
    num_train_neg = 500
    num_test_pos = 500
    num_test_neg = 500

    simulation_data_list = []

    for concept_name in concept_list:
        # Boolean mask of rows that are usable for this dataset.
        if "clinical_fd_clean" in dataset_name:
            concept_idx = (metadata_all["skincon_Do not consider this image"] == 0).values
        elif "derm7pt" in dataset_name:
            concept_idx = (~metadata_all["regression_structures"].isnull()).values
        else:
            raise ValueError(dataset_name)

        concept_bool_valid = get_concept_bool_from_metadata(dataset_name,
                                                            metadata_all[concept_idx],
                                                            concept_name)
        if concept_bool_valid["concept_bool_true"].values.astype(bool).sum() < min_cell_size:
            print(concept_name, "!!!!!!!!!!!!!!!!!!!!!!!! SKIPPED !!!!!!!!!!!!!!!!!!!!!!!!")
            continue

        # Only the fully confounded setting is run; other proportions (0.8, 0.2, 0.0) were disabled.
        for proportion in [1]:
            num_train_pos_with = int(num_train_pos * proportion)
            num_train_pos_without = num_train_pos - num_train_pos_with
            num_train_neg_with = int(num_train_neg * (1 - proportion))
            num_train_neg_without = num_train_neg - num_train_neg_with

            # The test set carries the reversed concept/label association.
            num_test_pos_with = int(num_test_pos * (1 - proportion))
            num_test_pos_without = num_test_pos - num_test_pos_with
            num_test_neg_with = int(num_test_neg * (proportion))
            num_test_neg_without = num_test_neg - num_test_neg_with

            for random_seed in random_seed_range:
                subset_idx_train_, subset_idx_test_ = train_test_split(
                    np.arange(len(concept_idx))[concept_idx],
                    test_size=0.4,
                    random_state=random_seed)

                # ---- training pool ----
                metadata_all_train_ = metadata_all.iloc[subset_idx_train_]
                y_pos_train_ = y_pos[subset_idx_train_]
                metadata_all_train_ = metadata_all_train_.copy()
                metadata_all_train_["y_pos"] = y_pos_train_

                concept_bool_train = get_concept_bool_from_metadata(dataset_name,
                                                                    metadata_all_train_,
                                                                    concept_name)
                train_cells = _split_by_label_and_concept(metadata_all_train_,
                                                          concept_bool_train["concept_bool_true"],
                                                          concept_bool_train["concept_bool_false"])
                print(*(len(cell) for cell in train_cells))

                if min(len(cell) for cell in train_cells) < min_cell_size:
                    print("Train Not found")
                    continue

                train_idx = _sample_cells(
                    train_cells,
                    (num_train_pos_with, num_train_pos_without,
                     num_train_neg_with, num_train_neg_without),
                    random_seed)
                metadata_all_train = metadata_all_train_.loc[train_idx]

                # ---- test pool ----
                metadata_all_test_ = metadata_all.iloc[subset_idx_test_]
                y_pos_test_ = y_pos[subset_idx_test_]
                metadata_all_test_ = metadata_all_test_.copy()
                metadata_all_test_["y_pos"] = y_pos_test_

                concept_bool_test = get_concept_bool_from_metadata(dataset_name,
                                                                   metadata_all_test_,
                                                                   concept_name)
                test_cells = _split_by_label_and_concept(metadata_all_test_,
                                                         concept_bool_test["concept_bool_true"],
                                                         concept_bool_test["concept_bool_false"])

                if min(len(cell) for cell in test_cells) < min_cell_size:
                    print("Test Not found")
                    continue

                test_idx = _sample_cells(
                    test_cells,
                    (num_test_pos_with, num_test_pos_without,
                     num_test_neg_with, num_test_neg_without),
                    random_seed)
                metadata_all_test = metadata_all_test_.loc[test_idx]
                print(len(train_idx), len(test_idx))

                train_dataloader, val_dataloader, test_dataloader = get_training_data_idx(
                    dataloader=dataloader,
                    valid_idx=concept_idx,
                    y_pos=y_pos,
                    subset_idx_train=metadata_all_train.index,
                    subset_idx_test=metadata_all_test.index,
                    n_px=None)

                # Only the test logits/labels/metadata and the threshold are kept;
                # the first two return values are unused here.
                _auc, _unused, logits_test, label_test, metadata_test, max_f1_thres = train_classifier(
                    train_dataloader=train_dataloader,
                    val_dataloader=val_dataloader,
                    test_dataloader=test_dataloader, verbose=True)

                metadata_test = pd.concat(metadata_test)
                label_test = pd.Series(np.hstack(label_test), index=metadata_test.index)
                logit_test = pd.Series(np.concatenate(logits_test)[:, 0], index=metadata_test.index)

                simulation_data_list.append({"concept_name": concept_name,
                                             "random_seed": random_seed,
                                             "proportion": proportion,
                                             "label_test": label_test,
                                             "logit_test": logit_test,
                                             "metadata_all_train": metadata_all_train,
                                             "metadata_all_test": metadata_all_test,
                                             "metadata_test": metadata_test,
                                             "max_f1_thres": max_f1_thres,
                                            })
            print(concept_name)
    return {"model_auditing_simulation_data": simulation_data_list}
 

In [None]:
# Run the forced-confounding simulation on derm7pt for all seven concepts with seeds 0..19.
# The result is only bound to `x`; storing it in `variable_dict` is left commented out below.
dataset_name = "derm7pt_derm_nodup"
x = forced_training(
    dataset_name=dataset_name,
    dataloader=variable_dict[dataset_name]["dataloader"],
    concept_list=[
        'derm7ptconcept_pigment network',
        'derm7ptconcept_regression structure',
        'derm7ptconcept_pigmentation',
        'derm7ptconcept_blue whitish veil',
        'derm7ptconcept_vascular structures',
        'derm7ptconcept_streaks',
        'derm7ptconcept_dots and globules',
    ],
    metadata_all=variable_dict[dataset_name]["metadata_all"],
    y_pos=variable_dict[dataset_name]["y_pos"],
    random_seed_range=list(range(0, 20)),
)
# variable_dict[dataset_name].update()

In [None]:
# Same derm7pt simulation as the 20-seed run, restricted to seed 0 (quick check).
# The result is only bound to `x`; storing it in `variable_dict` is left commented out below.
dataset_name = "derm7pt_derm_nodup"
x = forced_training(
    dataset_name=dataset_name,
    dataloader=variable_dict[dataset_name]["dataloader"],
    concept_list=[
        'derm7ptconcept_pigment network',
        'derm7ptconcept_regression structure',
        'derm7ptconcept_pigmentation',
        'derm7ptconcept_blue whitish veil',
        'derm7ptconcept_vascular structures',
        'derm7ptconcept_streaks',
        'derm7ptconcept_dots and globules',
    ],
    metadata_all=variable_dict[dataset_name]["metadata_all"],
    y_pos=variable_dict[dataset_name]["y_pos"],
    random_seed_range=[0],
)
# variable_dict[dataset_name].update()

In [None]:
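# Scratch notes: hand-recorded value pairs (their meaning is not documented in this notebook).
# Kept as comments so the cell no longer raises a SyntaxError on Run All.
# 0.936 0.37
# 0.925 0.645
# 0.934 0.590
# 0.941 0.379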

In [None]:
# derm7pt simulation called without `random_seed_range`.
# NOTE(review): forced_training defaults `random_seed_range` to an empty list, so this call
# only runs the per-concept skip check and prints; no classifier is trained and the
# returned record list is empty. Pass explicit seeds if training is intended.
dataset_name = "derm7pt_derm_nodup"
x = forced_training(
    dataset_name=dataset_name,
    dataloader=variable_dict[dataset_name]["dataloader"],
    concept_list=[
        'derm7ptconcept_pigment network',
        'derm7ptconcept_regression structure',
        'derm7ptconcept_pigmentation',
        'derm7ptconcept_blue whitish veil',
        'derm7ptconcept_vascular structures',
        'derm7ptconcept_streaks',
        'derm7ptconcept_dots and globules',
    ],
    metadata_all=variable_dict[dataset_name]["metadata_all"],
    y_pos=variable_dict[dataset_name]["y_pos"],
)
# variable_dict[dataset_name].update()

In [None]:
# Clinical-image simulation over every SkinCon concept column (`skincon_cols`).
# NOTE(review): `random_seed_range` is not passed, so forced_training uses its default
# empty list and trains nothing; the returned record list is empty.
dataset_name = "clinical_fd_clean_nodup_nooverlap"
x = forced_training(
    dataset_name=dataset_name,
    dataloader=variable_dict[dataset_name]["dataloader"],
    concept_list=skincon_cols,
    metadata_all=variable_dict[dataset_name]["metadata_all"],
    y_pos=variable_dict[dataset_name]["y_pos"],
)
# variable_dict[dataset_name].update()

In [None]:
# Count how many simulation records exist per concept for derm7pt.
# NOTE(review): needs the "model_auditing_simulation_data" key in variable_dict; the
# forced_training cells in this section never store their result there — presumably it
# comes from a saved checkpoint; confirm the intended execution order.
pd.DataFrame(variable_dict["derm7pt_derm_nodup"]["model_auditing_simulation_data"])["concept_name"].value_counts()



In [None]:
# Note: pigment network, regression structure, pigment, streaks, vascular structure

In [None]:
# Count how many simulation records exist per concept for the clinical dataset.
# NOTE(review): needs the "model_auditing_simulation_data" key to be present in variable_dict.
pd.DataFrame(variable_dict["clinical_fd_clean_nodup_nooverlap"]["model_auditing_simulation_data"])["concept_name"].value_counts()



In [None]:
# NOTE(review): stray scratch expression; no definition of `p` is visible nearby, so this
# cell likely raises NameError on a fresh kernel. Candidate for deletion.
p

In [None]:
# List the entries stored for the derm7pt dataset.
variable_dict["derm7pt_derm_nodup"].keys()

In [None]:
# Inspect the keys of `x`, the value returned by the most recent forced_training call.
x.keys()

In [None]:
# Scratch arithmetic (the meaning of the two factors is not documented).
100*20

In [None]:
# Scratch formula: precision, TP/(TP+FP), multiplied by the ratio of actual positives
# (TP+FN) to actual negatives (TN+FP).
# NOTE(review): TP, FP, FN and TN must already be defined for this cell to run.
TP / (TP+FP) * (TP+FN)/(TN+FP)

In [None]:
# Display the stored "valid_idx" entry for the clinical dataset.
variable_dict["clinical_fd_clean_nodup_nooverlap"]["valid_idx"]

In [None]:
# Display the full derm7pt metadata table (consider .head() to keep the output small).
variable_dict["derm7pt_derm_nodup"]["metadata_all"]

In [None]:
# torch.save(variable_dict, "logs/experiment_results/model_auditing_0827.pt")

In [None]:
!ls logs/experiment_results/ -trl

In [None]:
# Replace the in-memory `variable_dict` with the checkpoint saved on disk (tensors mapped to CPU).
# NOTE(review): torch.load unpickles arbitrary objects — only load files you trust.
variable_dict=torch.load("logs/experiment_results/model_auditing_0827.pt", map_location="cpu")

In [None]:
# Count simulation records per concept for derm7pt in the currently loaded `variable_dict`.
pd.DataFrame(variable_dict["derm7pt_derm_nodup"]["model_auditing_simulation_data"])["concept_name"].value_counts()

In [None]:
# NOTE(review): bare name; in the visible code "model_auditing_simulation_data" only appears
# as a dictionary key, so this cell likely raises NameError unless a global of that name
# was defined elsewhere — confirm or delete.
model_auditing_simulation_data

In [None]:
# Display the stored "evaluation_model_audit" entry for derm7pt.
variable_dict["derm7pt_derm_nodup"]["evaluation_model_audit"]

In [None]:
# torch.save(simulation_data_list, "logs/experiment_results/model_audit_benchmark_0525.pt")

In [None]:
# simulation_data_list=torch.load("logs/experiment_results/model_audit_benchmark_0525.pt")

In [None]:
# Load the benchmark simulation records from disk (tensors mapped to CPU).
# NOTE(review): torch.load unpickles arbitrary objects — only load files you trust.
simulation_data_list=torch.load("logs/experiment_results/model_audit_benchmark_0526.pt", map_location="cpu")

In [None]:
# torch.save(variable_dict, "logs/experiment_results/model_audit_data_0526.pt")

In [None]:
# Load `variable_dict` from the 0526 checkpoint.
# NOTE(review): this overwrites whatever `variable_dict` held before, including the 0827
# checkpoint loaded in an earlier cell; confirm which file is the intended source.
# torch.load unpickles arbitrary objects — only load files you trust.
variable_dict=torch.load("logs/experiment_results/model_audit_data_0526.pt", map_location="cpu")

In [None]:
# Tabulate the loaded simulation records, one row per record.
pd.DataFrame(simulation_data_list)

In [None]:
# logit_test

# log_loss(y_pred=logit_test,
#         y_true=label_test).mean()

In [None]:
# Let pandas display up to 500 columns (global display option for the rest of the session).
pd.set_option('display.max_columns', 500)

In [None]:
# List the entries stored for the current `dataset_name`.
# NOTE(review): `dataset_name` is whatever an earlier cell last assigned (hidden state).
variable_dict[dataset_name].keys()

current method+CLIP
DOMINO+MONET

In [None]:
# similarity_info_copy=cluster_concept_test(similarity_info=concept_similarity_all,
#                                    ground_truth=variable_dict["clinical_fd_clean_nodup"]["metadata_all"][skincon_cols],
#                                    clustering_features=concept_similarity_all,
#                                    labels=label_test, logits=logit_test,
#                                    threshold=0,
#                                    score_threshold=0.8, accuracy_diff=0.1)

In [None]:
from scipy.stats import fisher_exact
def fisher_test_df(concept_list_bool_dict, y_pos):
    """Run Fisher's exact test of each concept against the binary target.

    Parameters
    ----------
    concept_list_bool_dict : dict
        Maps concept name to a dict holding boolean masks under the keys
        ``"concept_bool_true"`` and ``"concept_bool_false"``.
    y_pos
        Boolean target compared element-wise against ``True`` / ``False``.

    Returns
    -------
    pandas.DataFrame
        Indexed by concept name and sorted by the Fisher statistic. Holds the four
        contingency counts, the "direction" columns, the raw p-value and statistic,
        the FDR-corrected results, and Bonferroni-adjusted p-values (capped at 1)
        with rejection at the 0.01 level.
    """
    is_pos = (y_pos==True)
    is_neg = (y_pos==False)

    rows = []
    for concept_name, concept_bool in concept_list_bool_dict.items():
        has_concept = concept_bool["concept_bool_true"]
        lacks_concept = concept_bool["concept_bool_false"]

        # 2x2 contingency counts: target (y) by concept presence (c).
        y_1_c_1 = (is_pos&(has_concept)).sum()
        y_1_c_0 = (is_pos&(lacks_concept)).sum()
        y_0_c_1 = (is_neg&(has_concept)).sum()
        y_0_c_0 = (is_neg&(lacks_concept)).sum()

        res = fisher_exact([[y_1_c_1, y_1_c_0],
                            [y_0_c_1, y_0_c_0]])

        # Within-row differences of the contingency table and their product.
        gap_pos = y_1_c_1-y_1_c_0
        gap_neg = y_0_c_1-y_0_c_0

        rows.append({
            "name": concept_name,
            "y=1,c=1": y_1_c_1,
            "y=1,c=0": y_1_c_0,
            "y=0,c=1": y_0_c_1,
            "y=0,c=0": y_0_c_0,
            "direction": gap_pos*gap_neg,
            "direction1": gap_pos,
            "direction2": gap_neg,
            "pvalue": res.pvalue,
            "statistic": res.statistic,
        })

    res_df = pd.DataFrame(rows).sort_values('statistic').set_index('name')

    # Benjamini-Hochberg style correction via the externally provided `fdrcorrection`.
    fdr_result = fdrcorrection(res_df["pvalue"])
    res_df["FDR_rejected"] = fdr_result[0]
    res_df["FDR_pvalue_adjusted"] = fdr_result[1]

    # Bonferroni: scale by the number of tests, then cap at 1.
    def _cap_at_one(p):
        return 1 if p>1 else p

    res_df["bof_pvalue_adjusted"] = res_df["pvalue"]*len(res_df)
    res_df["bof_pvalue_adjusted"] = res_df["bof_pvalue_adjusted"].map(_cap_at_one)
    res_df["bof_rejected"] = res_df["bof_pvalue_adjusted"]<0.01
    return res_df

In [None]:
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import random
from scipy.stats import mode

def cluster_concept_test_real(similarity_info, clustering_features, fixed_answer,
                         labels, logits, threshold,
                         metric_diff=0.5, metric_over=0,
                         n_clusters=40, random_state=42, return_only_highperforming=True):
    """Cluster samples, then contrast each cluster's concept scores with a better-performing neighbour.

    Steps, as implemented below:
      1. Seed torch / random / numpy with ``random_state``.
      2. Reduce ``clustering_features`` with PCA (10 components when there are fewer
         than 50 feature columns, otherwise 50) and run KMeans.
      3. Per cluster, compute size, mean accuracy (``labels == (logits > threshold)``),
         mean "loss" and the most frequent label.
      4. For each cluster (ascending accuracy), pick as reference the nearest cluster
         centre whose accuracy exceeds the overall mean by more than ``metric_diff`` and
         is at least ``metric_over``; then, per concept column of ``similarity_info``,
         compute mean(target cluster) - mean(reference cluster).

    Parameters
    ----------
    similarity_info : pandas.DataFrame
        Per-sample concept scores; its columns are the concepts that get ranked.
    clustering_features : pandas.DataFrame
        Per-sample features used for PCA + KMeans; must share the index of ``similarity_info``.
    fixed_answer
        Unused in this function body.
    labels, logits : pandas.Series
        Per-sample target and classifier logit; their index must match the focus rows
        (enforced by the assertions below).
    threshold
        Logit cut-off used to binarise predictions for the accuracy column.
    metric_diff, metric_over
        Constraints on the accuracy of a reference cluster (see step 4).
    n_clusters, random_state
        KMeans configuration; ``random_state`` also seeds the global RNGs.
    return_only_highperforming : bool
        When True, clusters whose accuracy is at or above the overall mean are skipped,
        i.e. only below-average clusters produce records.
        NOTE(review): the parameter name suggests the opposite; confirm the intent.

    Returns
    -------
    (list of dict, pandas.DataFrame)
        One record per analysed cluster (ranked concept lists, the statistics frame,
        and per-sample distance/accuracy for target and reference clusters), plus the
        per-cluster summary frame from the last loop pass.

    NOTE(review): ``sorted_idx[0]`` raises IndexError when no cluster satisfies the
    reference constraints.
    """
    
    # Seed every RNG that may be involved so clustering is repeatable.
    torch.manual_seed(random_state)
    random.seed(random_state)
    np.random.seed(random_state)    
    
    record_list=[]
    
    # Hard-coded switch: the per-label branch below is currently never taken.
    per_label=False
    
    if per_label:
        labels_unique=np.unique(labels)
    else:
        labels_unique=[None]
        
    for label in labels_unique:
        if label is not None:
            # Restrict every input to the samples carrying this label.
            focus_idx=labels[labels==label].index

            similarity_info_focus=similarity_info.loc[focus_idx].copy()
            clustering_features_focus=clustering_features.loc[focus_idx].copy()
            labels_focus=labels[labels==label].copy()
            logits_focus=logits[labels==label].copy()
        else:
            # Keep samples whose integer label is greater than -9.
            focus_idx=labels[labels.astype(int)>-9].index

            similarity_info_focus=similarity_info.loc[focus_idx].copy()
            clustering_features_focus=clustering_features.loc[focus_idx].copy()
            # NOTE(review): labels/logits are copied unfiltered here; if any label is
            # <= -9 the index checks below will not pass.
            labels_focus=labels.copy()
            logits_focus=logits.copy()            
            
            
        # All four inputs must be aligned row by row.
        assert (similarity_info_focus.index==clustering_features_focus.index).all()
        assert (similarity_info_focus.index==labels_focus.index).all()
        assert (similarity_info_focus.index==logits_focus.index).all()

        # Fewer PCA components when the feature matrix is narrow.
        if clustering_features_focus.shape[1]<50:
            pca = PCA(n_components=10)
        else:
            pca = PCA(n_components=50)

        clustering_features_focus_pc=pca.fit_transform(clustering_features_focus)
    
        # The cluster budget is divided across the label groups being processed.
        kmeans = KMeans(n_clusters=n_clusters//len(labels_unique), random_state=random_state, n_init="auto").fit(clustering_features_focus_pc)
        # Pairwise distances between cluster centres (used to find the nearest reference cluster).
        kmeans_dist=sklearn.metrics.pairwise_distances(kmeans.cluster_centers_)
    
        similarity_info_focus_copy=similarity_info_focus.copy()
        similarity_info_focus_copy["kmeans_label"]=kmeans.labels_
        # Note: this column is the squared distance of each sample to its own centre,
        # not the centre-to-centre matrix held in the `kmeans_dist` variable above.
        similarity_info_focus_copy["kmeans_dist"]=((clustering_features_focus_pc-kmeans.cluster_centers_[kmeans.labels_])**2).sum(axis=1)
        similarity_info_focus_copy["accuracy"]=(labels_focus==(logits_focus>threshold))
        # Logits are passed through a sigmoid before scoring.
        # NOTE(review): `log_loss` is not defined in this cell; if it is
        # sklearn.metrics.log_loss it returns one scalar, making this column constant — confirm.
        similarity_info_focus_copy["loss"]=-log_loss(y_true=labels_focus, y_pred=logits_focus.map(lambda x: 1/(1+np.exp(-x))), labels=[0,1])
        similarity_info_focus_copy["label"]=labels_focus
        similarity_info_focus_copy["logit"]=logits_focus
        
        # One row per cluster; each concept cell holds the array of that cluster's sample values.
        similarity_info_focus_copy_group=similarity_info_focus_copy.groupby("kmeans_label")[similarity_info.columns.tolist()].apply(lambda x: pd.Series([x[i].values for i in x.columns], index=x.columns))
        similarity_info_focus_copy_group["count"]=similarity_info_focus_copy.groupby("kmeans_label").apply(len)
        similarity_info_focus_copy_group["accuracy"]=similarity_info_focus_copy.groupby("kmeans_label")["accuracy"].mean()
        similarity_info_focus_copy_group["loss"]=similarity_info_focus_copy.groupby("kmeans_label")["loss"].mean()
        similarity_info_focus_copy_group["label_frequent"]=similarity_info_focus_copy.groupby("kmeans_label")["label"].apply(lambda x: mode(x, keepdims=False).mode)

        # Metric used both for ordering clusters and for choosing the reference cluster.
        metric_use="accuracy"

        # Visit clusters from worst to best accuracy.
        for count, (idx, row) in enumerate(similarity_info_focus_copy_group.sort_values(metric_use, ascending=True).iterrows()):
            if return_only_highperforming:
                # Skip clusters at or above the overall mean of the metric.
                if row[metric_use]>=similarity_info_focus_copy[metric_use].mean():
                    continue
            

            # Other clusters ordered by centre distance from the current one (nearest first) ...
            sorted_idx=pd.Series(kmeans_dist[idx], index=sorted(np.unique(kmeans.labels_))).sort_values(ascending=True).index
            # ... keeping only those that clear both metric constraints.
            # NOTE(review): the current cluster itself (distance 0) is not excluded and can
            # be selected as its own reference when it satisfies the constraints.
            sorted_idx=[i for i in sorted_idx if (similarity_info_focus_copy_group.loc[i][metric_use]>(similarity_info_focus_copy[metric_use].mean()+metric_diff)) and \
                        (similarity_info_focus_copy_group.loc[i][metric_use]>=(metric_over))
                       ]
            
            # Reference = nearest qualifying cluster; IndexError here if none qualifies.
            similarity_info_focus_copy_group_diff_plus=similarity_info_focus_copy_group.copy().loc[[sorted_idx[0]]]
#             print(similarity_info_focus_copy_group_diff_plus)
            # Per concept: mean(current cluster) - mean(reference cluster).
            similarity_info_focus_copy_group_diff_plus[similarity_info.columns.tolist()]=similarity_info_focus_copy_group_diff_plus[similarity_info.columns.tolist()].apply(lambda x: pd.Series([(np.mean(row[i])-np.mean(x.loc[i])) for i in x.index], index=x.index), axis=1)
    
            # Statistics frame: one row per concept with the difference and the current cluster's mean.
            x=pd.concat([
                similarity_info_focus_copy_group_diff_plus[similarity_info.columns.tolist()].loc[sorted_idx[0]].rename('diff_magnitude'),
                row[similarity_info.columns.tolist()].map(lambda x: np.mean(x)).rename("mean_value")
            ],
                axis=1)
#             print(x.sort_values("diff_magnitude", ascending=False))    

#             print(f"cluster_idx {count} / \
# Target metric: {row[metric_use]:.3f} / \
# Ref metric: {similarity_info_focus_copy_group.loc[sorted_idx[0]][metric_use]:.3f}/ \
# Target concept: {np.mean(row['purple pen']):.3f}/ \
# Ref concept: {np.mean(similarity_info_focus_copy_group.loc[sorted_idx[0]]['purple pen']):.3f}/ \
# Mean metric:  {similarity_info_focus_copy[metric_use].mean():.3f}\
# ")
#             print(np.mean(row['purple pen']),
#                   np.mean(similarity_info_focus_copy_group.loc[sorted_idx[0]]['purple pen'])
#                  ) 
            #sdsd
                        
            # Sign-flipped differences (reference - current).
            # NOTE(review): this frame is not used by the record built below.
            similarity_info_focus_copy_group_diff_minus=similarity_info_focus_copy_group.copy().loc[[sorted_idx[0]]]
            similarity_info_focus_copy_group_diff_minus[similarity_info.columns.tolist()]=similarity_info_focus_copy_group_diff_minus[similarity_info.columns.tolist()].apply(lambda x: pd.Series([-(np.mean(row[i])-np.mean(x.loc[i])) for i in x.index], index=x.index), axis=1)            
            #import ipdb
            #ipdb.set_trace()
            
#             print('-------')
#             print(x.sort_values('diff_magnitude', ascending=True).index==similarity_info_focus_copy_group_diff_minus[similarity_info.columns.tolist()].loc[sorted_idx[0]].sort_values(ascending=False).index.tolist())
#             print(x.sort_values('diff_magnitude', ascending=True))
#             print(idx, sorted_idx)
            
            
            
            # Ranked concept lists only include concepts whose mean in the current cluster
            # exceeds 0.5: "plus" = more present than in the reference (largest difference
            # first), "minus" = less present (most negative difference first).
            record_list.append(
                { 
#                  "on_the_spot_plus_pred": similarity_info_focus_copy_group_diff_plus[similarity_info.columns.tolist()].loc[sorted_idx[0]].sort_values(ascending=False).index.tolist(),
#                  "on_the_spot_minus_pred": similarity_info_focus_copy_group_diff_minus[similarity_info.columns.tolist()].loc[sorted_idx[0]].sort_values(ascending=False).index.tolist(),
                 "on_the_spot_plus_pred": x[(x["mean_value"]>0.5)&(x["diff_magnitude"]>0)].sort_values("diff_magnitude", ascending=False).index.tolist(),
                 "on_the_spot_minus_pred": x[(x["mean_value"]>0.5)&(x["diff_magnitude"]<0)].sort_values("diff_magnitude", ascending=True).index.tolist(),                    
                 "statistics": x,
                 "labels": similarity_info_focus_copy[(similarity_info_focus_copy["kmeans_label"]==idx)][["kmeans_dist", metric_use]],
                 "labels_ref": similarity_info_focus_copy[(similarity_info_focus_copy["kmeans_label"]==sorted_idx[0])][["kmeans_dist", metric_use]]                 
                })
#             print(record_list[-1]["statistics"].sort_values("diff_magnitude", ascending=False).index==record_list[-1]["on_the_spot_minus_pred"]).all()
    
    # The summary frame returned is the one from the last label group processed.
    return record_list, similarity_info_focus_copy_group

In [None]:
# Cluster the current dataset's EfficientNet features into 10 groups and contrast concept
# scores between clusters, with no minimum accuracy for the reference cluster (metric_over=0).
# The similarity matrix is restricted to the columns accepted by check_concept_name.
# NOTE(review): depends on notebook globals set elsewhere: dataset_name,
# label_subset_proveai, logits_subset_proveai, threshold_select.
record_list_temp, similarity_info_focus_copy_group_temp =\
\
cluster_concept_test_real(similarity_info=variable_dict[dataset_name]["similarity_matrix"]
[variable_dict[dataset_name]["similarity_matrix"].columns[variable_dict[dataset_name]["similarity_matrix"].columns.map(lambda x: check_concept_name(dataset_name, x))]]
                                                     , 
                          clustering_features=pd.DataFrame(variable_dict[dataset_name]["efficientnet_feature"].numpy(),
                                                             index=variable_dict[dataset_name]["metadata_all"].index,
                                                            ), 
                          fixed_answer=["red"],
                         labels=label_subset_proveai, 
                          logits=logits_subset_proveai, 
                          threshold=threshold_select,
                         metric_diff=0,
                          metric_over=0,
                         n_clusters=10, random_state=42, return_only_highperforming=False)

In [None]:
# Same clustering analysis, but the reference cluster must reach at least 0.5 accuracy
# (metric_over=0.5). This rebinds record_list_temp / similarity_info_focus_copy_group_temp.
# NOTE(review): depends on notebook globals set elsewhere: dataset_name,
# label_subset_proveai, logits_subset_proveai, threshold_select.
record_list_temp, similarity_info_focus_copy_group_temp =\
\
cluster_concept_test_real(similarity_info=variable_dict[dataset_name]["similarity_matrix"]
[variable_dict[dataset_name]["similarity_matrix"].columns[variable_dict[dataset_name]["similarity_matrix"].columns.map(lambda x: check_concept_name(dataset_name, x))]]
                                                     , 
                          clustering_features=pd.DataFrame(variable_dict[dataset_name]["efficientnet_feature"].numpy(),
                                                             index=variable_dict[dataset_name]["metadata_all"].index,
                                                            ), 
                          fixed_answer=["red"],
                         labels=label_subset_proveai, 
                          logits=logits_subset_proveai, 
                          threshold=threshold_select,
                         metric_diff=0,
                          metric_over=0.5,
                         n_clusters=10, random_state=42, return_only_highperforming=False)

In [None]:
# Per-cluster summary, worst accuracy first: size, accuracy, loss and most frequent label.
similarity_info_focus_copy_group_temp.sort_values("accuracy")[
    ["count","accuracy","loss", "label_frequent"]
]

In [None]:
# NOTE(review): in the visible code this name is only a local variable inside
# cluster_concept_test_real, so this cell presumably relies on leftover debugging state
# and fails on a fresh kernel — confirm or delete.
similarity_info_focus_copy_group_diff_plus.iloc[0]

In [None]:
# Look up the "skincon_Vesicle" entry of the first row of the same frame.
# NOTE(review): relies on the same leftover debugging variable; see the caveat above the
# name's only visible assignment inside cluster_concept_test_real.
similarity_info_focus_copy_group_diff_plus.iloc[0].loc["skincon_Vesicle"]

In [None]:
# Widen the notebook's page container to 80% of the browser window via injected CSS.
# Note: this import rebinds `display` to the function, replacing the `display` module
# bound by `from IPython import display` at the top of the notebook.
from IPython.display import display, HTML
display(HTML("<style>.container { width:80% !important; }</style>"))

In [None]:
import math

def evaluate_model_audit(dataset_name,
                         simulation_data_list,
                         similarity_matrix,
                         similarity_matrix_vanilla,
                         efficientnet_feature,
                         metadata_all,
                         concept_list
                        ):
    """Score how often the top-N concepts of a cluster audit hit the expected concepts.

    For every entry of ``simulation_data_list`` the function

    1. runs ``fisher_test_df`` on the train and on the test metadata and keeps the
       concepts whose ``direction``/``statistic`` pattern differs between the two
       splits; their union is the ``fixed_answer``;
    2. runs ``cluster_concept_test_real`` once with ``similarity_matrix`` (recorded
       as model ``"MONET"``) and once with ``similarity_matrix_vanilla`` (recorded
       as model ``"CLIP"``);
    3. for each returned cluster result and each rank cut-off 1..5, records whether
       the top-ranked concepts contain an expected concept.

    Parameters
    ----------
    dataset_name
        Forwarded to ``get_concept_bool_from_metadata``.
    simulation_data_list
        Iterable of mappings providing the keys ``"concept_name"``, ``"label_test"``,
        ``"logit_test"``, ``"metadata_all_train"``, ``"metadata_all_test"``,
        ``"max_f1_thres"`` and ``"random_seed"``.
    similarity_matrix, similarity_matrix_vanilla
        Objects that support column selection with a list of concept names.
    efficientnet_feature
        Object exposing ``.numpy()``; wrapped in a DataFrame indexed by
        ``metadata_all.index`` and used as clustering features.
    metadata_all
        Metadata for all samples; provides the index and the per-concept ground truth.
    concept_list
        Concept names to test. Must be a ``list`` when it has fewer than 47 entries,
        because it is concatenated with a list of filler concepts.

    Returns
    -------
    list of dict
        One record per (simulation, model, method, rank). ``method`` is one of
        ``"on_the_spot_plus"``, ``"on_the_spot_minus"``, ``"fixed_answer_plus"``,
        ``"fixed_answer_minus"``. Only the ``*_plus`` records carry
        ``"random_performance"``.

    Notes
    -----
    Relies on the notebook-level names ``fisher_test_df``,
    ``get_concept_bool_from_metadata``, ``cluster_concept_test_real``,
    ``skincon_cols`` and ``tqdm``.
    """
    max_rank = 5  # rank cut-offs 1..max_rank are scored

    def get_ground_truth_on_the_spot(ground_truth, target_idx, reference_idx):
        """Compare majority concept presence between a target and a reference group.

        A concept counts as present in a group when more than half of the group's
        ``"concept_bool_true"`` values are true. Returns the concepts present only
        in the target (``"more_present"``) and only in the reference
        (``"less_present"``) as pandas Index objects.
        """
        def majority_present(idx):
            per_concept = {name: concept_bool["concept_bool_true"].loc[idx]
                           for name, concept_bool in ground_truth.items()}
            return (pd.DataFrame(per_concept).mean(axis=0) > 0.5).astype(int)

        presence_diff = majority_present(target_idx) - majority_present(reference_idx)
        return {"more_present": presence_diff[presence_diff > 0].index,
                "less_present": presence_diff[presence_diff < 0].index}

    def run_cluster_test(similarity, candidate_concepts, fixed_answer, labels, logits, threshold):
        """Run cluster_concept_test_real on the candidate columns and return its first output."""
        return cluster_concept_test_real(
            similarity_info=similarity[candidate_concepts],
            clustering_features=pd.DataFrame(efficientnet_feature.numpy(),
                                             index=metadata_all.index),
            fixed_answer=fixed_answer,
            labels=labels, logits=logits,
            threshold=threshold,
            metric_diff=0,
            n_clusters=40)[0]

    # Depends only on the function arguments, so it is computed once instead of once
    # per cluster result (assumes get_concept_bool_from_metadata has no side effects).
    ground_truth_all = {concept: get_concept_bool_from_metadata(dataset_name, metadata_all, concept)
                        for concept in concept_list}

    record_dict_all = []
    for simulation_data in tqdm.tqdm(simulation_data_list):
        concept_name = simulation_data["concept_name"]
        label_test = simulation_data["label_test"]
        logit_test = simulation_data["logit_test"]
        metadata_train = simulation_data["metadata_all_train"]
        metadata_test = simulation_data["metadata_all_test"]
        max_f1_thres = simulation_data["max_f1_thres"]
        random_seed = simulation_data["random_seed"]

        print("metadata_train", metadata_train.shape)
        # Fixed: this line used to print metadata_train.shape under the test label.
        print("metadata_test", metadata_test.shape)
        print("label_test", label_test.shape)
        print("logit_test", logit_test.shape)

        fisher_pvals_train = fisher_test_df(
            concept_list_bool_dict={concept: get_concept_bool_from_metadata(dataset_name, metadata_train, concept)
                                    for concept in concept_list},
            y_pos=metadata_train["y_pos"])

        fisher_pvals_test = fisher_test_df(
            concept_list_bool_dict={concept: get_concept_bool_from_metadata(dataset_name, metadata_test, concept)
                                    for concept in concept_list},
            y_pos=metadata_test["y_pos"])

        # Concepts whose `statistic` lies on opposite sides of 1 in train vs. test
        # (both with direction < 0).
        # NOTE(review): the meaning of `direction`/`statistic` is defined by
        # fisher_test_df, which is not visible here -- confirm.
        test_less_represented = fisher_pvals_train[(fisher_pvals_train["direction"] < 0) & (fisher_pvals_train["statistic"] > 1)].index\
            .intersection(fisher_pvals_test[(fisher_pvals_test["direction"] < 0) & (fisher_pvals_test["statistic"] < 1)].index)

        test_more_represented = fisher_pvals_train[(fisher_pvals_train["direction"] < 0) & (fisher_pvals_train["statistic"] < 1)].index\
            .intersection(fisher_pvals_test[(fisher_pvals_test["direction"] < 0) & (fisher_pvals_test["statistic"] > 1)].index)

        # Alternative selection based on direction1/direction2; only printed for comparison.
        test_more_represented_ = fisher_pvals_train[(fisher_pvals_train["direction1"] > 0) & (fisher_pvals_train["direction2"] < 0)].index\
            .intersection(fisher_pvals_test[(fisher_pvals_test["direction1"] < 0) & (fisher_pvals_test["direction2"] > 0)].index)

        test_less_represented_ = fisher_pvals_train[(fisher_pvals_train["direction1"] < 0) & (fisher_pvals_train["direction2"] > 0)].index\
            .intersection(fisher_pvals_test[(fisher_pvals_test["direction1"] > 0) & (fisher_pvals_test["direction2"] < 0)].index)

        print(concept_name, "Train-/Test+", test_more_represented.tolist(), "Train+/Test-", test_less_represented.tolist())
        print("Test", concept_name, "Train-/Test+", test_more_represented_.tolist(), "Train+/Test-", test_less_represented_.tolist())

        fixed_answer = test_more_represented.tolist() + test_less_represented.tolist()

        # Pad short concept lists with randomly drawn skincon concepts (seeded per
        # simulation) so that 48 candidates are ranked.
        # NOTE(review): the guard is `< 47` while the target size is 48 -- a
        # 47-element list is left unpadded; confirm this is intended.
        if len(concept_list) < 47:
            filler_pool = [i for i in skincon_cols if i not in ['skincon_Brown(Hyperpigmentation)', 'skincon_White(Hypopigmentation)', 'skincon_Blue', 'skincon_Pigmented']]
            concept_name_test_list = concept_list + np.random.RandomState(random_seed).choice(
                filler_pool, size=48 - len(concept_list), replace=False).tolist()
        else:
            concept_name_test_list = concept_list

        n_candidates = len(concept_name_test_list)
        n_fixed = len(set(fixed_answer))

        # Iterating this mapping replaces the former if/elif chain whose fallback
        # was a bare `raise` (which itself fails with "No active exception").
        results_by_model = {
            "MONET": run_cluster_test(similarity_matrix, concept_name_test_list,
                                      fixed_answer, label_test, logit_test, max_f1_thres),
            "CLIP": run_cluster_test(similarity_matrix_vanilla, concept_name_test_list,
                                     fixed_answer, label_test, logit_test, max_f1_thres),
        }

        for model, test_result_list in results_by_model.items():
            for test_result in test_result_list:
                ground_truth_on_the_spot = get_ground_truth_on_the_spot(
                    ground_truth=ground_truth_all,
                    target_idx=test_result["labels"].index,
                    reference_idx=test_result["labels_ref"].index)
                more_present = set(ground_truth_on_the_spot["more_present"])
                less_present = set(ground_truth_on_the_spot["less_present"])

                plus_pred = test_result["on_the_spot_plus_pred"]
                for i in range(1, min(max_rank, len(plus_pred)) + 1):
                    record_dict_all.append({
                        "model": model,
                        "method": "on_the_spot_plus",
                        "rank_n": i,
                        "metric": len(more_present.intersection(plus_pred[:i])) != 0,
                        "answer_length": len(more_present),
                        # Chance of at least one hit when drawing i concepts without replacement.
                        "random_performance": 1 - (math.perm(n_candidates - len(more_present), i) / math.perm(n_candidates, i)),
                        "target_group_size": len(test_result["labels"]),
                        "random_seed": random_seed,
                    })

                minus_pred = test_result["on_the_spot_minus_pred"]
                for i in range(1, min(max_rank, len(minus_pred)) + 1):
                    record_dict_all.append({
                        "model": model,
                        "method": "on_the_spot_minus",
                        "rank_n": i,
                        "metric": len(less_present.intersection(minus_pred[:i])) != 0,
                        "answer_length": len(less_present),
                        "target_group_size": len(test_result["labels"]),
                        "random_seed": random_seed,
                    })

            for i in range(1, max_rank + 1):
                record_dict_all.append({
                    "model": model,
                    "method": "fixed_answer_plus",
                    "answer_length": n_fixed,
                    "count": len(test_result_list),
                    "rank_n": i,
                    # Hit if any cluster's top-i "plus" concepts contain a fixed-answer concept.
                    "metric": len(set(fixed_answer).intersection([p for test_result in test_result_list for p in test_result["on_the_spot_plus_pred"][:i]])) != 0,
                    # Chance level: 1 - P(no hit in one cluster) ** number of clusters.
                    "random_performance": 1 - (math.comb(n_candidates - n_fixed, i) / math.comb(n_candidates, i)) ** len(test_result_list),
                    "random_seed": random_seed,
                })

            for i in range(1, max_rank + 1):
                record_dict_all.append({
                    "model": model,
                    "method": "fixed_answer_minus",
                    "answer_length": n_fixed,
                    "count": len(test_result_list),
                    "rank_n": i,
                    "metric": len(set(fixed_answer).intersection([p for test_result in test_result_list for p in test_result["on_the_spot_minus_pred"][:i]])) != 0,
                    "random_seed": random_seed,
                })

    return record_dict_all

In [None]:
# Run the audit benchmark on derm7pt with its seven annotated concepts.
# The result is only kept in `x`; the persisting step below is disabled.
for dataset_name in ["derm7pt_derm_nodup"]:
    x = evaluate_model_audit(
        dataset_name=dataset_name,
        simulation_data_list=variable_dict[dataset_name]["model_auditing_simulation_data"],
        similarity_matrix=variable_dict[dataset_name]["similarity_matrix"],
        similarity_matrix_vanilla=variable_dict[dataset_name]["similarity_matrix_vanilla"],
        efficientnet_feature=variable_dict[dataset_name]["efficientnet_feature"],
        metadata_all=variable_dict[dataset_name]["metadata_all"],
        concept_list=[
            'derm7ptconcept_pigment network',
            'derm7ptconcept_regression structure',
            'derm7ptconcept_pigmentation',
            'derm7ptconcept_blue whitish veil',
            'derm7ptconcept_vascular structures',
            'derm7ptconcept_streaks',
            'derm7ptconcept_dots and globules',
        ],
    )
#     variable_dict[dataset_name].update(
#         {"evaluation_model_audit": x})

In [None]:
# Run the audit benchmark on the clinical dataset with all skincon concepts.
# The result is only kept in `x`; the persisting step below is disabled.
for dataset_name in ["clinical_fd_clean_nodup_nooverlap"]:
    x = evaluate_model_audit(
        dataset_name=dataset_name,
        simulation_data_list=variable_dict[dataset_name]["model_auditing_simulation_data"],
        similarity_matrix=variable_dict[dataset_name]["similarity_matrix"],
        similarity_matrix_vanilla=variable_dict[dataset_name]["similarity_matrix_vanilla"],
        efficientnet_feature=variable_dict[dataset_name]["efficientnet_feature"],
        metadata_all=variable_dict[dataset_name]["metadata_all"],
        concept_list=skincon_cols,
    )
#     variable_dict[dataset_name].update(
#         {"evaluation_model_audit": x})

In [None]:
# List the top-level keys of variable_dict (indexed by dataset name elsewhere in this notebook).
variable_dict.keys()

In [None]:
# Print the keys of the first simulation entry (iterating a mapping yields its keys).
# Note: the loop variable `i` stays defined at notebook scope afterwards.
for i in variable_dict["clinical_fd_clean_nodup_nooverlap"]["model_auditing_simulation_data"][0]:
    print(i)

In [None]:
# Fixed: `simulation_data_list` was passed positionally and therefore bound to the
# first parameter (`dataset_name`), leaving `simulation_data_list` itself missing
# (TypeError). Both are now passed by keyword.
for dataset_name in ["clinical_fd_clean_nodup_nooverlap"]:
    evaluate_model_audit(dataset_name=dataset_name,
                         simulation_data_list=simulation_data_list,
                         concept_list=variable_dict[dataset_name]["concept_list"],
                         similarity_matrix=variable_dict[dataset_name]["similarity_matrix"],
                         similarity_matrix_vanilla=variable_dict[dataset_name]["similarity_matrix_vanilla"],
                         efficientnet_feature=variable_dict[dataset_name]["efficientnet_feature"],
                         metadata_all=variable_dict[dataset_name]["metadata_all"],
                        )

In [None]:
# Display the similarity matrix stored for the clinical dataset.
variable_dict["clinical_fd_clean_nodup_nooverlap"]["similarity_matrix"]

In [None]:
# Display the similarity matrix stored for the clinical dataset.
variable_dict["clinical_fd_clean_nodup_nooverlap"]["similarity_matrix"]

In [None]:
# List the top-level keys of variable_dict.
variable_dict.keys()

In [None]:
# Display the similarity matrix stored for the derm7pt dataset.
variable_dict["derm7pt_derm_nodup"]["similarity_matrix"]

In [None]:
# Select the dataset that the following inspection cells index variable_dict with.
dataset_name="clinical_fd_clean_nodup_nooverlap"

In [None]:
# Display the stored EfficientNet features for the selected dataset.
variable_dict[dataset_name]["efficientnet_feature"]

In [None]:
# Show both shapes as one tuple. Previously the trailing comma ended the first
# statement, so it was evaluated and discarded and only the second shape was displayed.
(
    variable_dict[dataset_name]["similarity_matrix"].shape,
    variable_dict[dataset_name]["similarity_matrix_vanilla"].shape,
)

In [None]:
# Display the metadata table of the selected dataset.
variable_dict[dataset_name]["metadata_all"]

In [None]:
# Display the similarity matrix stored for the clinical dataset.
variable_dict["clinical_fd_clean_nodup_nooverlap"]["similarity_matrix"]

In [None]:
# Number of entries in the clinical dataset's stored concept list.
len(variable_dict["clinical_fd_clean_nodup_nooverlap"]["concept_list"])

In [None]:
# NOTE(review): evaluate_model_audit as defined in this notebook also requires
# dataset_name, similarity_matrix, similarity_matrix_vanilla, efficientnet_feature,
# metadata_all and concept_list. This call supplies only simulation_data_list and
# raises TypeError against that definition -- looks like a leftover from an earlier
# signature; update the arguments or delete the cell.
evaluate_model_audit(simulation_data_list=simulation_data_list[:])

In [None]:
# Bare reference to the pandas module; the cell only shows the module repr.
pd

In [None]:
# Fixed: `eval_result_df[]` (empty subscript) is a SyntaxError; preview the frame instead.
eval_result_df.head()

In [None]:
# Chance level with the same formula as the fixed_answer_plus `random_performance`
# in evaluate_model_audit: 48 candidate concepts, 2 answer concepts, top-1, 20 clusters.
1-(math.comb(48-2, 1) / math.comb(48, 1))**20

In [None]:
# Same chance-level formula for top-2 (48 candidates, 2 answers, 20 clusters).
1-(math.comb(48-2, 2) / math.comb(48, 2))**20

In [None]:
# Same chance-level formula for top-4 (48 candidates, 2 answers, 20 clusters).
1-(math.comb(48-2, 4) / math.comb(48, 4))**20

In [None]:
# Mean of the numeric columns per (method, rank_n) for the fixed_answer_plus records.
# numeric_only=True keeps the string column `model` out of the mean, which
# otherwise raises TypeError on pandas >= 2.0.
eval_result_df[eval_result_df["method"]=="fixed_answer_plus"].groupby(["method", "rank_n"]).mean(numeric_only=True)

In [None]:
# Chance level for one on-the-spot top-i prediction with 48 candidate concepts.
# NOTE(review): in the visible code `ground_truth_on_the_spot` is only assigned
# inside evaluate_model_audit, and `i` here is whatever an earlier cell-level loop
# left behind -- this scratch cell likely fails on a fresh kernel; confirm or remove.
1-(math.perm(48-len(set(ground_truth_on_the_spot["more_present"])), i) / math.perm(48, i))

In [None]:
# Mean of the numeric columns per (rank_n, method) for the fixed_answer_plus records.
# numeric_only=True keeps the string column `model` out of the mean, which
# otherwise raises TypeError on pandas >= 2.0.
eval_result_df[(eval_result_df["method"]=="fixed_answer_plus")].groupby(["rank_n", "method"]).mean(numeric_only=True)

# generate_df

In [None]:
import math

In [None]:
# Plot one chance-level column against another.
# NOTE(review): evaluate_model_audit as shown in this notebook emits
# "random_performance" but no "random_performance_" column -- this presumably
# relies on a column added in a cell that no longer exists; confirm or remove.
eval_result_df.plot(x="random_performance", y="random_performance_")

In [None]:
# Access the stored audit records of both datasets. Only the last expression of a
# cell is rendered, so just the clinical dataset's records are shown.
# NOTE(review): the cells that would store "evaluation_model_audit" have that step
# commented out in the visible code -- presumably the key is populated elsewhere; verify.
variable_dict["derm7pt_derm_nodup"]["evaluation_model_audit"]
variable_dict["clinical_fd_clean_nodup_nooverlap"]["evaluation_model_audit"]

In [None]:
# Build the evaluation table for derm7pt from the stored audit records.
# (Alternative source: the in-memory result of the last run, `pd.DataFrame(x)`.)
eval_result_df = pd.DataFrame(
    variable_dict["derm7pt_derm_nodup"]["evaluation_model_audit"]
)
# Ratio of the observed hit (0/1) to the chance level; NaN for records that carry
# no "random_performance" value.
eval_result_df["metric_random_ratio"] = (
    eval_result_df["metric"].astype(int).div(eval_result_df["random_performance"])
)

In [None]:
# Concept names of the simulation entries at positions 0, 20, 40, 60 and 80.
tuple(
    simulation_data_list[position]["concept_name"]
    for position in (0, 20, 40, 60, 80)
)

In [None]:
# variable_dict["clinical_fd_clean_nodup"]["metadata_all"][
# (variable_dict["clinical_fd_clean_nodup"]["metadata_all"]["skincon_Crust"]==0)
# &(variable_dict["clinical_fd_clean_nodup"]["y_pos"]==False)
# ].iloc[5:]
# variable_dict["clinical_fd_clean_nodup"]["dataloader"].dataset.getitem(
# variable_dict["clinical_fd_clean_nodup"]["dataloader"].dataset.metadata_all.index.tolist().index(
# "7d2f3fa05f4f362299c1ed148e7fc719.jpg")
# )["image"]

# one ground truth

In [None]:
# Display the evaluation table.
eval_result_df

In [None]:
# Single-panel bar chart: how often the top-N concepts recover the fixed answer,
# MONET vs. CLIP, for N <= 3.

# Colour cycle: the odd-indexed entries of the 12-class Paired palette, divided by 256.
plt.rcParams["axes.prop_cycle"] = cycler(
    'color',
    [np.array(Paired[12][k]) / 256 for k in (1, 3, 5, 7, 9, 11)],
)
# fig = plt.figure(constrained_layout=True, figsize=(15, 6))
fig = plt.figure(figsize=(18, 5))
subfigs = fig.subfigures(1, 1)

axes = subfigs.subplots(1, 1, gridspec_kw={"wspace": 0.3})

axd = {'fixed': axes}

plot_key = "fixed"

# Mean hit rate per (rank_n, model) over simulations that have a non-empty answer.
# numeric_only=True keeps the string column `method` out of the mean, which
# otherwise raises TypeError on pandas >= 2.0.
eval_result_df_mean_fixed = eval_result_df[
    (eval_result_df["method"] == "fixed_answer_plus")
    & (eval_result_df["answer_length"] != 0)
    & (eval_result_df["rank_n"] <= 3)
].groupby(["rank_n", "model"]).mean(numeric_only=True)
print(eval_result_df_mean_fixed)
sns.barplot(
    x="rank_n", y="metric", hue="model", hue_order=["MONET", "CLIP"],
    data=(eval_result_df_mean_fixed).reset_index(),
    width=0.5,
    ax=axd[plot_key]
)

# Thicker frame, with the right and top spines hidden.
for axis in ['top', 'bottom', 'left', 'right']:
    axd[plot_key].spines[axis].set_linewidth(1.5)
axd[plot_key].spines['right'].set_visible(False)
axd[plot_key].spines['top'].set_visible(False)

# Horizontal grid every 0.1 on a fixed 0-1 frequency axis.
axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.1))
axd[plot_key].yaxis.grid(True, which='major', linewidth=0.4, alpha=0.4)

axd[plot_key].set_ylim(0, 1)

axd[plot_key].tick_params(axis='both', which='major', labelsize=16)
axd[plot_key].tick_params(axis='both', which='minor', labelsize=16)

axd[plot_key].set_xlabel("Top-N", fontsize=18)
axd[plot_key].set_ylabel("Freq. of recovering spurious corr.", fontsize=18)

# Black outline around every bar.
for patch in axd[plot_key].patches:
    patch.set_linewidth(1)
    patch.set_edgecolor("black")

# This variant is drawn without a legend; the panel letter sits above the axes.
axd[plot_key].get_legend().remove()
axd[plot_key].text(x=-0.1,
                   y=1.09,
                   transform=axd[plot_key].transAxes,
                   s=" B.", fontsize=23, weight='bold')

axd[plot_key].set_title("Do the top-N concept explanations recover spurious correlations?\n(across all low-performing clusters)",
                        fontsize=18)


# fig.savefig(log_dir/"plots"/"model_audit_main_benchmark.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"model_audit_main_benchmark.jpg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"model_audit_main_benchmark.svg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"model_audit_main_benchmark.pdf", bbox_inches='tight')

In [None]:
# Two-panel figure: left = fixed-answer recovery rate (MONET vs. CLIP, N <= 3);
# the right panel is kept as an empty placeholder.

# Colour cycle: the odd-indexed entries of the 12-class Paired palette, divided by 256.
plt.rcParams["axes.prop_cycle"] = cycler(
    'color',
    [np.array(Paired[12][k]) / 256 for k in (1, 3, 5, 7, 9, 11)],
)
# fig = plt.figure(constrained_layout=True, figsize=(15, 6))
fig = plt.figure(figsize=(18, 5))
subfigs = fig.subfigures(1, 1)

axes = subfigs.subplots(1, 2, gridspec_kw={"wspace": 0.3})

axd = {'fixed': axes[0], "on_the_spot": axes[1]}

plot_key = "fixed"

# Mean hit rate per (rank_n, model, random_seed); the bar plot then aggregates
# the per-seed means. numeric_only=True keeps the string column `method` out of
# the mean, which otherwise raises TypeError on pandas >= 2.0.
eval_result_df_mean_fixed = eval_result_df[
    (eval_result_df["method"] == "fixed_answer_plus")
    & (eval_result_df["answer_length"] != 0)
    & (eval_result_df["rank_n"] <= 3)
].groupby(["rank_n", "model", "random_seed"]).mean(numeric_only=True)
print(eval_result_df_mean_fixed)
sns.barplot(
    x="rank_n", y="metric", hue="model", hue_order=["MONET", "CLIP"],
    data=(eval_result_df_mean_fixed).reset_index(),
    width=0.7,
    ax=axd[plot_key]
)

# Thicker frame, with the right and top spines hidden.
for axis in ['top', 'bottom', 'left', 'right']:
    axd[plot_key].spines[axis].set_linewidth(1.5)
axd[plot_key].spines['right'].set_visible(False)
axd[plot_key].spines['top'].set_visible(False)

# Horizontal grid every 0.1 on a fixed 0-1 frequency axis.
axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.1))
axd[plot_key].yaxis.grid(True, which='major', linewidth=0.4, alpha=0.4)

axd[plot_key].set_ylim(0, 1)

axd[plot_key].tick_params(axis='both', which='major', labelsize=16)
axd[plot_key].tick_params(axis='both', which='minor', labelsize=16)

axd[plot_key].set_xlabel("Top-N", fontsize=18)
axd[plot_key].set_ylabel("Freq. of recovering spurious corr.", fontsize=18)

# Black outline around every bar.
for patch in axd[plot_key].patches:
    patch.set_linewidth(1)
    patch.set_edgecolor("black")

# Two-column legend without a title, centred below the axes.
leg = axd[plot_key].legend(fontsize=16, facecolor='white', framealpha=0.5, ncols=2,
                           loc='upper center', bbox_to_anchor=(0.487, -0.13, 0, 0),
                           )
leg.set_title("", prop={"size": 16})

axd[plot_key].text(x=-0.2,
                   y=1.09,
                   transform=axd[plot_key].transAxes,
                   s=" B.", fontsize=23, weight='bold')

axd[plot_key].set_title("Do the top-N concept explanations recover spurious correlations?\n(across all low-performing clusters)",
                        fontsize=18)

plot_key = "on_the_spot"

# Disabled alternative for the right panel:
# sns.barplot(x="rank_n", y="metric_random_ratio", hue="model",
# data=eval_result_df[
#     (eval_result_df["method"]=="on_the_spot_plus")
#     &(eval_result_df["answer_length"]!=0)
#     &(eval_result_df["rank_n"]<=3)], ax=axd[plot_key])

# Blank out the right panel: no spines, ticks or tick labels.
for axis in ['top', 'bottom', 'left', 'right']:
    axd[plot_key].spines[axis].set_visible(False)

axd[plot_key].tick_params(left=False, right=False, labelleft=False,
                          labelbottom=False, bottom=False)

# fig.savefig(log_dir/"plots"/"model_audit_benchmark_derm7pt.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"model_audit_benchmark_derm7pt.jpg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"model_audit_benchmark_derm7pt.svg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"model_audit_benchmark_derm7pt.pdf", bbox_inches='tight')

In [None]:
# Mean hit rate per (rank_n, model), drawn on the axes currently selected by `plot_key`.
# numeric_only=True keeps the string column `method` out of the mean, which
# otherwise raises TypeError on pandas >= 2.0.
eval_result_df_mean_fixed = eval_result_df[
    (eval_result_df["method"] == "fixed_answer_plus")
    & (eval_result_df["answer_length"] != 0)
    & (eval_result_df["rank_n"] <= 3)
].groupby(["rank_n", "model"]).mean(numeric_only=True)
print(eval_result_df_mean_fixed)
sns.barplot(
    x="rank_n", y="metric", hue="model", hue_order=["MONET", "CLIP"],
    data=(eval_result_df_mean_fixed).reset_index(),
    width=0.7,
    ax=axd[plot_key]
)


In [None]:
# Per-seed mean of the numeric columns for the fixed_answer_plus records (N <= 3).
# numeric_only=True keeps the string column `method` out of the mean, which
# otherwise raises TypeError on pandas >= 2.0.
eval_result_df[
    (eval_result_df["method"] == "fixed_answer_plus")
    & (eval_result_df["answer_length"] != 0)
    & (eval_result_df["rank_n"] <= 3)
].groupby(["rank_n", "model", "random_seed"]).mean(numeric_only=True)

In [None]:
# Display the evaluation table.
eval_result_df

In [None]:
# List which artefacts are stored for the clinical dataset.
variable_dict["clinical_fd_clean_nodup_nooverlap"].keys()

In [None]:
# Build the evaluation table for the clinical (skincon) dataset from the stored
# audit records.
eval_result_df = pd.DataFrame(
    variable_dict["clinical_fd_clean_nodup_nooverlap"]["evaluation_model_audit"]
)
# Ratio of the observed hit (0/1) to the chance level; NaN for records that carry
# no "random_performance" value.
eval_result_df["metric_random_ratio"] = (
    eval_result_df["metric"].astype(int).div(eval_result_df["random_performance"])
)

In [None]:
# Two-panel figure for the skincon benchmark: left = fixed-answer recovery rate
# (MONET vs. CLIP, N <= 3); the right panel is kept as an empty placeholder.
# The figure is saved in four formats at the end.

# Colour cycle: the odd-indexed entries of the 12-class Paired palette, divided by 256.
plt.rcParams["axes.prop_cycle"] = cycler(
    'color',
    [np.array(Paired[12][k]) / 256 for k in (1, 3, 5, 7, 9, 11)],
)
# fig = plt.figure(constrained_layout=True, figsize=(15, 6))
fig = plt.figure(figsize=(18, 5))
subfigs = fig.subfigures(1, 1)

axes = subfigs.subplots(1, 2, gridspec_kw={"wspace": 0.3})

axd = {'fixed': axes[0], "on_the_spot": axes[1]}

plot_key = "fixed"

# Mean hit rate per (rank_n, model) over simulations that have a non-empty answer.
# numeric_only=True keeps the string column `method` out of the mean, which
# otherwise raises TypeError on pandas >= 2.0.
eval_result_df_mean_fixed = eval_result_df[
    (eval_result_df["method"] == "fixed_answer_plus")
    & (eval_result_df["answer_length"] != 0)
    & (eval_result_df["rank_n"] <= 3)
].groupby(["rank_n", "model"]).mean(numeric_only=True)
print(eval_result_df_mean_fixed)
sns.barplot(
    x="rank_n", y="metric", hue="model", hue_order=["MONET", "CLIP"],
    data=(eval_result_df_mean_fixed).reset_index(),
    width=0.7,
    ax=axd[plot_key]
)

# Thicker frame, with the right and top spines hidden.
for axis in ['top', 'bottom', 'left', 'right']:
    axd[plot_key].spines[axis].set_linewidth(1.5)
axd[plot_key].spines['right'].set_visible(False)
axd[plot_key].spines['top'].set_visible(False)

# Horizontal grid every 0.1 on a fixed 0-1 frequency axis.
axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.1))
axd[plot_key].yaxis.grid(True, which='major', linewidth=0.4, alpha=0.4)

axd[plot_key].set_ylim(0, 1)

axd[plot_key].tick_params(axis='both', which='major', labelsize=16)
axd[plot_key].tick_params(axis='both', which='minor', labelsize=16)

axd[plot_key].set_xlabel("Top-N", fontsize=18)
axd[plot_key].set_ylabel("Freq. of recovering spurious corr.", fontsize=18)

# Black outline around every bar.
for patch in axd[plot_key].patches:
    patch.set_linewidth(1)
    patch.set_edgecolor("black")

# Two-column legend without a title, centred below the axes.
leg = axd[plot_key].legend(fontsize=16, facecolor='white', framealpha=0.5, ncols=2,
                           loc='upper center', bbox_to_anchor=(0.487, -0.13, 0, 0),
                           )
leg.set_title("", prop={"size": 16})

axd[plot_key].text(x=-0.2,
                   y=1.09,
                   transform=axd[plot_key].transAxes,
                   s=" B.", fontsize=23, weight='bold')

axd[plot_key].set_title("Do the top-N concept explanations recover spurious correlations?\n(across all low-performing clusters)",
                        fontsize=18)

plot_key = "on_the_spot"

# Disabled alternative for the right panel:
# sns.barplot(x="rank_n", y="metric_random_ratio", hue="model",
# data=eval_result_df[
#     (eval_result_df["method"]=="on_the_spot_plus")
#     &(eval_result_df["answer_length"]!=0)
#     &(eval_result_df["rank_n"]<=3)], ax=axd[plot_key])

# Blank out the right panel: no spines, ticks or tick labels.
for axis in ['top', 'bottom', 'left', 'right']:
    axd[plot_key].spines[axis].set_visible(False)

axd[plot_key].tick_params(left=False, right=False, labelleft=False,
                          labelbottom=False, bottom=False)

# NOTE(review): `log_dir` is assigned in the cell *after* this one; on a fresh
# kernel these calls fail unless it was also defined earlier in the notebook --
# confirm and move the definition up if needed.
fig.savefig(log_dir/"plots"/"model_audit_benchmark_skincon.png", bbox_inches='tight')
fig.savefig(log_dir/"plots"/"model_audit_benchmark_skincon.jpg", bbox_inches='tight')
fig.savefig(log_dir/"plots"/"model_audit_benchmark_skincon.svg", bbox_inches='tight')
fig.savefig(log_dir/"plots"/"model_audit_benchmark_skincon.pdf", bbox_inches='tight')

In [None]:
# Base directory for saved artefacts (relative path "logs").
# NOTE(review): the preceding plotting cell already calls fig.savefig(log_dir/...),
# so on a top-to-bottom run this assignment comes too late unless log_dir is
# also defined earlier in the notebook -- consider moving it to the config cell.
log_dir = Path("logs")

# two ground truth

In [None]:
# Benchmark figure, panels B and C: bar charts (MONET vs. CLIP) of how often the
# top-N concept explanations hit the ground truth, for the two ground-truth
# definitions stored in eval_result_df ("fixed_answer_plus" and "on_the_spot_plus").

# Every second colour of the 12-class "Paired" palette becomes the colour cycle.
plt.rcParams["axes.prop_cycle"]=cycler(
    'color', [np.array(Paired[12][color_idx])/256 for color_idx in (1, 3, 5, 7, 9, 11)]
)

fig = plt.figure(figsize=(18, 5))
subfigs = fig.subfigures(1, 1)
axes = subfigs.subplots(1,2, gridspec_kw={"wspace":0.3})
axd={'fixed': axes[0], "on_the_spot": axes[1] }


def _mean_metrics_by_rank_and_model(result_df, method, max_rank=3):
    """Average the numeric columns of ``result_df`` per (rank_n, model).

    Only rows of the given ``method`` with a non-zero ``answer_length`` and
    ``rank_n <= max_rank`` are used. ``numeric_only=True`` keeps string columns
    such as ``method`` out of the mean, which otherwise raises on pandas >= 2.0.
    """
    selected = result_df[
        (result_df["method"]==method)
        &(result_df["answer_length"]!=0)
        &(result_df["rank_n"]<=max_rank)
    ]
    return selected.groupby(["rank_n", "model"]).mean(numeric_only=True)


def _style_benchmark_axis(ax, ylabel, title, panel_label, panel_label_x):
    """Apply the styling shared by both panels; call it after the bars are drawn."""
    for side in ['top','bottom','left','right']:
        ax.spines[side].set_linewidth(1.5)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    ax.tick_params(axis='both', which='major', labelsize=16)
    ax.tick_params(axis='both', which='minor', labelsize=16)

    ax.set_xlabel("Top-N", fontsize=18)
    ax.set_ylabel(ylabel, fontsize=18)

    # Black outline around every bar.
    for bar_patch in ax.patches:
        bar_patch.set_linewidth(1)
        bar_patch.set_edgecolor("black")

    # Bold panel letter above the upper-left corner of the axes.
    ax.text(x=panel_label_x, y=1.09, transform=ax.transAxes,
            s=panel_label, fontsize=23, weight='bold')
    ax.set_title(title, fontsize=18)


# ---- Panel B: "fixed_answer_plus" rows, mean of the "metric" column ----------
plot_key="fixed"

eval_result_df_mean_fixed=_mean_metrics_by_rank_and_model(eval_result_df, "fixed_answer_plus")
print(eval_result_df_mean_fixed)
sns.barplot(
    x="rank_n", y="metric", hue="model", hue_order=["MONET", "CLIP"],
    data=eval_result_df_mean_fixed.reset_index(),
    width=0.7,
    ax=axd[plot_key]
)

axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.1))
axd[plot_key].yaxis.grid(True, which='major', linewidth=0.4, alpha=0.4)
axd[plot_key].set_ylim(0,1)

_style_benchmark_axis(
    axd[plot_key],
    ylabel="Freq. of recovering spurious corr.",
    title="Do the top-N concept explanations recover spurious correlations?\n(across all low-performing clusters)",
    panel_label=" B.",
    panel_label_x=-0.2,
)
# The legend is drawn once, from panel C.
axd[plot_key].get_legend().remove()

# ---- Panel C: "on_the_spot_plus" rows, metric relative to random_performance --
plot_key="on_the_spot"

eval_result_df_mean_spot=_mean_metrics_by_rank_and_model(eval_result_df, "on_the_spot_plus")
print(eval_result_df_mean_spot)
# Ratio of the mean metric to the mean random_performance, per (rank_n, model).
metric_random_ratio=(
    eval_result_df_mean_spot["metric"]/eval_result_df_mean_spot["random_performance"]
).rename("metric_random_ratio")
sns.barplot(
    x="rank_n", y="metric_random_ratio", hue="model", hue_order=["MONET", "CLIP"],
    data=metric_random_ratio.reset_index(),
    width=0.7,
    ax=axd[plot_key]
)

axd[plot_key].yaxis.set_major_locator(MultipleLocator(1))
axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.5))
axd[plot_key].yaxis.grid(True, which='major', linewidth=0.4, alpha=0.4)
axd[plot_key].yaxis.grid(True, which='minor', linewidth=0.2, alpha=0.4)

_style_benchmark_axis(
    axd[plot_key],
    ylabel="Ratio of freq. of including ground truth\nto that in random ordering",
    title="Do the top-N concept explanations include ground truth\ndefined per low-performing cluster?",
    panel_label="C.",
    panel_label_x=-0.19,
)

# Shared legend, anchored below the gap between the two panels. It is created
# after the bar outlines are set so the legend handles carry the same outline.
leg=axd[plot_key].legend(fontsize = 16, facecolor='white', framealpha=0.5, ncols=2,
                         loc='upper center', bbox_to_anchor=(-0.2, -0.1, 0, 0))
leg.set_title("", prop={"size":16})

# Saving is intentionally disabled; uncomment to write the figure files.
# fig.savefig(log_dir/"plots"/"model_audit_main_benchmark.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"model_audit_main_benchmark.jpg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"model_audit_main_benchmark.svg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"model_audit_main_benchmark.pdf", bbox_inches='tight')

In [None]:
# Display the per-(rank_n, model) means of the "on_the_spot_plus" rows computed in the plotting cell above.
eval_result_df_mean_spot

In [None]:
# Probability that at least one of `n_draws` independent random draws of `top_n`
# items out of `pool_size` contains at least one of `num_true` marked items.
# NOTE(review): presumably the chance level for the top-N benchmark above
# (48 candidate concepts, 20 repetitions) -- confirm what 48 and 20 stand for.
import math  # not among the imports of the notebook's top import cell

top_n=3
num_true=2
pool_size=48
n_draws=20

# Probability that a single random draw misses every marked item.
p_miss_once=math.comb(pool_size-num_true, top_n) / math.comb(pool_size, top_n)
1-p_miss_once**n_draws

In [None]:
# variable_dict_isic=torch.load("logs/experiment_results/data_audit_new_0429.pt", map_location='cpu')

# real

In [None]:
# Load the saved experiment-results file onto the CPU.
# NOTE: torch.load unpickles arbitrary Python objects -- only load files from a trusted source.
variable_dict_classifier=torch.load("logs/experiment_results/concept_annotation_data_auditing_0826.pt", map_location="cpu")

In [None]:
# List the entries stored under the "isic_nodup_nooverlap" key of the loaded results.
variable_dict_classifier["isic_nodup_nooverlap"].keys()

In [None]:
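# NOTE(review): the lines below are a truncated fragment of pasted cell output
# (pandas Index reprs), not code; kept as comments so the cell no longer raises.
# dex(['51b3c4a7cc25da63c438edc9d2d5e749.jpg'], dtype='object'),
#  Index(['b80678b6ff26b79f53e079c2b853af9a.jpg'], dtype='object'),
#  Index(['e89438538c2e7c4fe86a1d6112fcd599.jpg'], dtype='object')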

In [None]:
def get_subset_index(dataset_name, metadata_all, attribution):
    """Build a boolean row mask over ``metadata_all`` selecting one data subset.

    Args:
        dataset_name: Dataset identifier. Names containing ``"isic"`` are handled
            first, then names containing ``"clinical_fd_clean"``.
        metadata_all: DataFrame with one row per image. ISIC subsets read the
            ``"attribution"`` and ``"collection_65"`` columns; clinical subsets read
            ``"fitzpatrick_scale"`` and ``"skin_tone"``.
        attribution: ``"all"`` selects every row. For ISIC, ``"barcelona_all"`` and
            ``"mskcc_all"`` match either of two spellings of the source name, and any
            other value is compared to the ``"attribution"`` column verbatim; all
            non-``"all"`` ISIC subsets are further restricted to rows with
            ``collection_65 == 1``. For clinical data, ``"dark"`` and ``"light"``
            select rows by skin-tone codes.

    Returns:
        A boolean numpy array with one entry per row of ``metadata_all``.

    Raises:
        ValueError: unsupported ``attribution`` for a clinical dataset.
        NotImplementedError: unsupported ``dataset_name``.
    """
    if "isic" in dataset_name:
        if attribution=="all":
            return np.ones(len(metadata_all), dtype=bool)

        in_collection_65=(metadata_all["collection_65"]==1).values
        source=metadata_all["attribution"]
        if attribution=="barcelona_all":
            in_source=source.isin(["Department of Dermatology, Hospital Clínic de Barcelona",
                                   "Hospital Clínic de Barcelona"]).values
        elif attribution=="mskcc_all":
            in_source=source.isin(["MSKCC", "Memorial Sloan Kettering Cancer Center"]).values
        else:
            in_source=(source==attribution).values
        return in_source&in_collection_65

    if "clinical_fd_clean" in dataset_name:
        if attribution=="all":
            return np.ones(len(metadata_all), dtype=bool)

        if attribution=="dark":
            fitzpatrick_codes, skin_tone_codes=[5,6], [56]
        elif attribution=="light":
            fitzpatrick_codes, skin_tone_codes=[1,2,3,4], [12,34]
        else:
            # Previously fell through to an UnboundLocalError on `subset_idx`.
            raise ValueError(attribution)

        # Missing values are mapped to -9 so that they never match a code.
        fitzpatrick=metadata_all["fitzpatrick_scale"].fillna(-9).astype(int)
        skin_tone=metadata_all["skin_tone"].fillna(-9).astype(int)
        return (fitzpatrick.isin(fitzpatrick_codes)|skin_tone.isin(skin_tone_codes)).values

    # Same convention as check_concept_name for unknown datasets.
    raise NotImplementedError(dataset_name)

In [None]:
def check_image(dataset_name, idx):
    """Return False for images on the exclusion list, True otherwise.

    Args:
        dataset_name: Dataset identifier; names containing
            ``"clinical_fd_clean_nodup"`` are matched against the list of image file
            names, names containing ``"isic"`` against the list of ISIC ids. Any
            other dataset has nothing excluded.
        idx: Image identifier to look up (file name or ISIC id).

    Returns:
        bool: True when the image may be used.

    NOTE(review): a later cell defines ``check_image`` again with a shorter list;
    whichever cell runs last wins.
    """
    if "clinical_fd_clean_nodup" in dataset_name:
        excluded=("58b4bc079ca94e6e9377a42ca7564b40.jpg",
                  "720cf31558966c82c118ab75b50632eb.jpg",
                  "5f046cda32a3cc547205662e7be774f9.jpg",
                  "d8bf377acc45a3beb0c6e81bf7ac1ff5.jpg",
                  "109dc569333a2fa8490e098c95c6c3ca.jpg",
                  "97445c91fd5215758e2c1a77c3fd1c12.jpg",
                  "51b3c4a7cc25da63c438edc9d2d5e749.jpg",
                  "b80678b6ff26b79f53e079c2b853af9a.jpg",
                  # Was misspelled with a ".jps" extension, so it was never excluded.
                  "e89438538c2e7c4fe86a1d6112fcd599.jpg",
                 )
    elif "isic" in dataset_name:
        excluded=('ISIC_0053863',
                  'ISIC_0072697',
                  'ISIC_0062196',
                  'ISIC_0064559',
                  'ISIC_0062301',
                  'ISIC_0068895',
                  'ISIC_0070853',
                  'ISIC_0054985',
                  'ISIC_0059954',
                  'ISIC_0058819',
                 )
    else:
        return True
    return idx not in excluded

In [None]:
def check_image(dataset_name, idx):
    """Return False for the four excluded clinical images, True for everything else.

    Only dataset names containing ``"clinical_fd_clean_nodup"`` have exclusions.

    NOTE(review): this redefines ``check_image`` from the previous cell, which has
    a longer clinical list plus an ISIC list; whichever cell runs last wins.
    """
    if "clinical_fd_clean_nodup" not in dataset_name:
        return True
    excluded_images=[
        "58b4bc079ca94e6e9377a42ca7564b40.jpg",
        "720cf31558966c82c118ab75b50632eb.jpg",
        "5f046cda32a3cc547205662e7be774f9.jpg",
        "d8bf377acc45a3beb0c6e81bf7ac1ff5.jpg",
    ]
    return idx not in excluded_images

In [None]:
def check_concept_name(dataset_name, concept_name):
    """Tell whether ``concept_name`` is kept for the given dataset.

    Dataset names containing ``"isic"`` or ``"proveai"`` share one rule set:
    concepts starting with ``"disease"`` or ``"isicconcept_"`` are dropped;
    ``"derm7ptconcept_"`` concepts are dropped when their name contains
    ``"typical"``, ``"regular"`` or ``"pigmentation"``; and a short list of literal
    names is dropped. Names containing ``"clinical_fd_clean"`` only drop concepts
    starting with ``"disease"``.

    Returns:
        bool: True when the concept is kept.

    Raises:
        NotImplementedError: ``dataset_name`` matches none of the known datasets.
    """
    if "isic" in dataset_name or "proveai" in dataset_name:
        if concept_name.startswith(("disease", "isicconcept_")):
            return False
        if concept_name.startswith("derm7ptconcept_"):
            blocked_words=("typical", "regular", "pigmentation")
            return not any(word in concept_name for word in blocked_words)
        return concept_name not in ["melanoma", "malignant", "finger", "red sticker", "blue sticker"]

    if "clinical_fd_clean" in dataset_name:
        return not concept_name.startswith("disease")

    raise NotImplementedError(dataset_name)

def shorten_concept_name(concept_name, strict=True):
    """Map an internal concept identifier to the label shown in figures.

    Resolution order:
      1. names starting with ``"disease_"`` have every ``"disease_"`` removed;
      2. names listed in the lookup table below get their fixed display label;
      3. other names starting with ``"skincon_"`` lose that 8-character prefix;
      4. anything else raises ``NotImplementedError`` when ``strict`` is true and
         is returned unchanged otherwise.
    """
    display_names={
        "skincon_Erythema": "Erythema",
        "skincon_Bulla": "Bulla",
        "skincon_Lichenification": "Lichenification",
        "skincon_Pustule": "Pustule",
        "skincon_Ulcer": "Ulcer",
        "skincon_Warty/Papillomatous": "Warty",
        "skincon_White(Hypopigmentation)": "Hypopigmentation",
        "skincon_Brown(Hyperpigmentation)": "Hyperpigmentation",
        "skincon_Exophytic/Fungating": "Fungating",
        "purple pen": "Purple pen",
        "nail": "Nail",
        "orange sticker": "Orange sticker",
        "hair": "Hair",
        "gel": "Gel",
        "red": "Red",
        "blue sticker": "Blue sticker",
        "red sticker": "Red sticker",
        "dermoscope border": "Dermoscopic border",
        "pinkish": "Pink",
        "derm7ptconcept_pigment network": "Pigment network",
        "derm7ptconcept_regression structure": "Regression structure",
        "derm7ptconcept_pigmentation": "Pigmentation",
        "derm7ptconcept_blue whitish veil": "Blue whitish veil",
        "derm7ptconcept_vascular structures": "Vascular structures",
        "derm7ptconcept_streaks": "Streaks",
        "derm7ptconcept_dots and globules": "Dots and globules",
    }

    if concept_name.startswith("disease_"):
        return concept_name.replace("disease_", "")
    if concept_name in display_names:
        return display_names[concept_name]
    if concept_name.startswith("skincon_"):
        return concept_name[8:]
    if strict:
        raise NotImplementedError(concept_name)
    return concept_name

def shorten_hospital_name(hospital_name, strict=True):
    """Map a subset identifier (hospital attribution or skin-tone group) to a display label.

    Args:
        hospital_name: One of the identifiers in the table below.
        strict: When true (default), unknown identifiers raise; when false they
            are returned unchanged. Mirrors ``shorten_concept_name``.

    Returns:
        str: The display label.

    Raises:
        NotImplementedError: Unknown ``hospital_name`` with ``strict`` true
            (previously this surfaced as an UnboundLocalError).
    """
    display_names={
        "Hospital Clínic de Barcelona": "Hospital Clínic de Barcelona",
        "ViDIR Group, Department of Dermatology, Medical University of Vienna": "Medical University of Vienna",
        "dark": "Dark",
        "light": "Light",
    }
    if hospital_name in display_names:
        return display_names[hospital_name]
    if strict:
        raise NotImplementedError(hospital_name)
    return hospital_name

In [None]:
from scipy.stats import mode as sci_mode
def _task_class_names(task_type):
    """Return the (positive, negative) class names used in the legends for ``task_type``."""
    if task_type=="malignancy":
        return "Malignant", "Benign"
    if task_type=="melanoma":
        return "Melanoma", "Non-melanoma"
    raise ValueError(task_type)


def plot_slice_figure(dataset_name,
                      data_dict,
                      prompt_info, 
                      row_per_slice=30, 
                      example_per_row=30, 
                      normalize=True, 
                      show_small_box=True, 
                      print_alphabet=True,
                      print_legend_color=True,
                      print_legend_color_idx=2,
                      print_legend_number=True,
                      task_type="malignancy",
                      fontsize=32,
                      true_pred_count_fontsize=25,
                      slice_title_fontsize=32,
                      skip_section=0,
                      figure_title=None, debug=False):
    """Draw a grid of example images for every slice of every experiment.

    The figure stacks one section per entry of ``data_dict``; each section stacks
    one ``row_per_slice`` x ``example_per_row`` image grid per slice. Above each
    grid the true / predicted class counts of the slice are printed, and each
    image can carry a small marker showing its true label (square) and predicted
    label (triangle below/right of it).

    Reads notebook-level state that is not passed in: ``variable_dict`` (entries
    ``"metadata_all"``, ``"dataloader"`` and ``"y_pos"`` under ``dataset_name``) and
    ``check_image``; the ``rank_num == 100`` branch additionally reads ``diff_dict``,
    ``check_concept_name`` and ``shorten_concept_name``.

    Args:
        dataset_name: Key into ``variable_dict``; also passed to ``check_image``.
        data_dict: Mapping experiment name -> dict with keys ``"targets"`` and
            ``"preds"`` (indexed with ``.loc`` by sample id), ``"main_title"``
            (element 0 is drawn above the first slice, element 1 above the second)
            and ``"sample_list_list"`` (one collection of sample ids per slice).
        prompt_info: Unused; kept for call compatibility.
        row_per_slice: Number of image rows per slice grid.
        example_per_row: Number of images per row.
        normalize: Unused; kept for call compatibility.
        show_small_box: Draw the true/predicted markers on each image; when false
            the image's metadata name is shown as its title instead.
        print_alphabet: Prefix the first slice of each section with a bold letter
            ("A.", "B.", ...), offset by ``skip_section``.
        print_legend_color: Draw the colour legend below the last slice.
        print_legend_color_idx: Grid cell (within the last slice) the colour legend
            is attached to.
        print_legend_number: Draw the text explaining the count format below the
            last grid cell of the last slice.
        task_type: ``"malignancy"`` or ``"melanoma"``; selects the legend wording.
        fontsize: Font size of the section title (and of ``figure_title[1]``).
        true_pred_count_fontsize: Font size of the count line; legends use 2 less.
        slice_title_fontsize: Font size of slice titles; section letters use 3 more.
        skip_section: Offset added to the section index when choosing its letter.
        figure_title: Optional pair; element 0 is drawn bold and element 1 next to
            it above the very first image.
        debug: Also show each image's metadata name as its title.

    Returns:
        The matplotlib figure.

    Raises:
        ValueError: ``task_type`` is not supported and a legend has to be drawn.
    """
    positive_color=np.array([212,17,89])
    negative_color=np.array([26,133,255])
    marker_edge_color=np.array((0,0,0, 120))/256
    section_letters="ABCDEFGHIJKLMNOPQRSTUVWXYZ"

    cells_per_slice=row_per_slice*example_per_row
    n_sections=len(data_dict)
    total_slices=sum(len(data_dict[exp_name]["sample_list_list"]) for exp_name in data_dict.keys())

    # 3 inches per image, plus room for the gaps between sections and between slices.
    fig = plt.figure(figsize=(3*(example_per_row), 
                              3*(row_per_slice)*total_slices+\
                              0.4*(n_sections-1)+\
                              0.3*((total_slices-1)-(n_sections-1))
                             )
                    )

    # Nested grids: sections -> slices within a section -> image cells within a slice.
    section_grid = gridspec.GridSpec(n_sections, 1,
                                     wspace=0.0,
                                     hspace=0.4)

    axd={}
    for idx1, exp_name in enumerate(data_dict.keys()):
        slice_lists=data_dict[exp_name]["sample_list_list"]
        slice_grid = gridspec.GridSpecFromSubplotSpec(len(slice_lists), 1,
                        subplot_spec=section_grid[idx1], wspace=0.0, hspace=0.3)

        for idx2 in range(len(slice_lists)):
            cell_grid = gridspec.GridSpecFromSubplotSpec(row_per_slice, example_per_row,
                                                         subplot_spec=slice_grid[idx2], wspace=0, hspace=0.05)
            for rank_num in range(cells_per_slice):
                ax=plt.Subplot(fig, cell_grid[rank_num])
                fig.add_subplot(ax)
                axd[(idx1, idx2, rank_num)]=ax

    metadata_all=variable_dict[dataset_name]["metadata_all"]
    # True labels aligned to the metadata index; only needed for the small markers.
    true_labels=(pd.Series(variable_dict[dataset_name]["y_pos"], index=metadata_all.index)
                 if show_small_box else None)

    for idx1, exp_name in enumerate(data_dict.keys()):

        targets=data_dict[exp_name]["targets"]
        preds=data_dict[exp_name]["preds"]
        main_title=data_dict[exp_name]["main_title"]
        slice_lists=data_dict[exp_name]["sample_list_list"]

        for idx2, sample_list in enumerate(slice_lists):
            print('--------------------------------------------------')
            is_last_slice=idx1==n_sections-1 and idx2==len(slice_lists)-1

            # ---- Pass 1: decorate the (still empty) grid cells of this slice ----
            for rank_num in range(cells_per_slice):
                ax=axd[(idx1, idx2, rank_num)]
                ax.set_xticks([])
                ax.set_yticks([])

                # Title of the second slice of each section.
                if rank_num==0 and idx2==1:
                    ax.text(x=0.0, y=1.1,
                            transform=ax.transAxes,
                            s=main_title[1],
                            fontsize=slice_title_fontsize)

                # True / predicted class counts, right-aligned above the first row.
                if rank_num==example_per_row-1:
                    title= f"True: {targets.loc[sample_list].sum()} / {(1-targets.loc[sample_list]).sum()} → Pred: {(preds.loc[sample_list]==1).sum()} / {(preds.loc[sample_list]==0).sum()} "
                    print(title)
                    ax.text(x=1.0, y=1.1, transform=ax.transAxes,
                            s=title, fontsize=true_pred_count_fontsize, color="black",
                            horizontalalignment="right")

                # Explanation of the count format, below the very last grid cell.
                if print_legend_number and is_last_slice and rank_num==cells_per_slice-1:
                    positive_name, negative_name=_task_class_names(task_type)
                    title= f"True: # {positive_name} / # {negative_name} → Pred: # {positive_name} / # {negative_name}"
                    ax.text(x=1.0, y=-0.25, transform=ax.transAxes,
                            s=title, fontsize=true_pred_count_fontsize-2, color="black",
                            horizontalalignment="right")

                # Colour legend for the small markers, below the last slice.
                if print_legend_color and is_last_slice and rank_num==print_legend_color_idx:
                    positive_name, negative_name=_task_class_names(task_type)
                    legend_elements = [Line2D([0], [0], marker='s', color=(1,1,1,1), 
                                              markerfacecolor=positive_color/256, 
                                              markeredgecolor=np.array((0,0,0))/256, 
                                              markersize=30, 
                                              label=positive_name),
                                       Line2D([0], [0], marker='s', color=(1,1,1,1), 
                                              markerfacecolor=negative_color/256, 
                                              markeredgecolor=np.array((0,0,0))/256, 
                                              markersize=30,
                                              label=f"{negative_name}   (Upper left: True, Lower right: Pred)"),]

                    ax.legend(handles=legend_elements, 
                              ncol=2, 
                              handlelength=3,
                              handletextpad=-0.1, 
                              columnspacing=1.5,
                              fontsize=true_pred_count_fontsize-2,
                              loc='lower center', 
                              bbox_to_anchor=(1, -0.45))

            # ---- Pass 2: fill the cells with the slice's images -----------------
            image_idx_list=metadata_all.index.get_indexer(sample_list)

            # rank_num counts drawn images only; excluded images do not take a cell.
            # Iterating over the index list (rather than indexing it with a separate
            # counter) cannot run past its end when images are skipped.
            rank_num=0
            for image_idx in image_idx_list:
                if rank_num>=cells_per_slice:
                    break
                if not check_image(dataset_name, metadata_all.index[image_idx]):
                    continue

                plot_key=(idx1, idx2, rank_num)
                ax=axd[plot_key]

                item=variable_dict[dataset_name]["dataloader"].dataset.getitem(image_idx)
                image=item["image"]
                ax.imshow(image.resize((300, 300)))

                if debug:
                    ax.set_title(item["metadata"].name)

                if show_small_box:
                    # Square: true label (no square if it is neither True nor False).
                    true_label=true_labels.loc[item["metadata"].name]
                    if true_label==True:
                        true_color=positive_color
                    elif true_label==False:
                        true_color=negative_color
                    else:
                        true_color=None
                    if true_color is not None:
                        ax.scatter(x=[0.905], y=[0.905], s=650, 
                                   linewidths=1.5,
                                   edgecolor=marker_edge_color,
                                   color=true_color/256,
                                   marker="s",
                                   transform=ax.transAxes)

                    # Triangle over the lower-right half of the square: predicted label.
                    x1=0.82
                    x2=0.99
                    pred_color=positive_color if preds.loc[item["metadata"].name]==1 else negative_color
                    ax.fill([x1, x2, x2, x1], [x1, x2, x1, x1], 
                            color=pred_color/256,
                            transform=ax.transAxes)
                else:
                    ax.set_title(item["metadata"].name, fontsize=10)

                ax.set_xticks([])
                ax.set_yticks([])

                if rank_num==100:
                    # NOTE(review): only reachable when a slice shows more than 100
                    # images, and it reads the notebook-level `diff_dict`, which is not
                    # a parameter of this function -- confirm it is still wanted.
                    diff_dict_df=pd.DataFrame(diff_dict)
                    diff_dict_df=diff_dict_df[diff_dict_df["concept_name"].map(lambda x: check_concept_name("isic", x))]

                    ranked_concepts=diff_dict_df.sort_values("concept_presence_score", ascending=False)
                    concept_str=ranked_concepts.iloc[:5]["concept_name"]
                    concept_str=concept_str.map(shorten_concept_name)
                    concept_str=", ".join(concept_str.str.replace("skincon_",""))

                    print(concept_str)
                    print(ranked_concepts.iloc[:20])

                    ax.text(x=-0., y=1.1, transform=ax.transAxes,
                            s=concept_str, fontsize=true_pred_count_fontsize, color="black")

                for axis in ['top','bottom','left','right']:
                    ax.spines[axis].set_linewidth(1)

                # Overall figure title, above the very first image.
                if rank_num==0 and idx1==0 and idx2==0 and figure_title is not None:
                    print(figure_title)
                    ax.text(x=-0.33, y=1.4,
                            transform=ax.transAxes,
                            s=figure_title[0], 
                            fontsize=35, weight='bold')
                    ax.text(x=0.01, y=1.4, 
                            transform=ax.transAxes,
                            s=figure_title[1],
                            fontsize=fontsize)

                # Section title (optionally with a bold letter), above the first slice.
                if rank_num==0 and idx2==0:
                    if print_alphabet and skip_section+idx1<26:
                        ax.text(x=-0.3, y=1.1, 
                                transform=ax.transAxes,
                                s=section_letters[skip_section+idx1]+".", 
                                fontsize=slice_title_fontsize+3, weight='bold')
                        ax.text(x=0.05, y=1.1, 
                                transform=ax.transAxes,
                                s=main_title[0], 
                                fontsize=fontsize)
                    else:
                        ax.text(x=0.0, y=1.1, 
                                transform=ax.transAxes,
                                s=main_title[0], 
                                fontsize=slice_title_fontsize)

                rank_num+=1

    return fig

In [None]:
# Inspect the first dataset item (plot_slice_figure reads its "image" and "metadata" entries).
# NOTE(review): dataset_name is only assigned in a later cell ("run ISIC"), and
# variable_dict is not defined in this part of the notebook -- this cell depends on kernel state.
variable_dict[dataset_name]["dataloader"].dataset.getitem(0)

In [None]:
# Show the shape of the "labels" entry of every stored test result.
# NOTE(review): test_result_list_from1_to2_concept_only is not defined in this part of the notebook.
[i["labels"].shape for i in test_result_list_from1_to2_concept_only]

In [None]:
def select_subset(image_features_norm, metadata_all, 
                  logits, labels, subset_idx):
    """Restrict features, logits and labels to the rows of ``metadata_all`` picked by ``subset_idx``.

    Args:
        image_features_norm: Positionally indexable container whose rows follow the row
            order of ``metadata_all``.
        metadata_all: DataFrame whose index identifies every sample.
        logits: pandas object indexed by sample id (any row order).
        labels: pandas object indexed by sample id (any row order).
        subset_idx: Indexer applied as ``metadata_all[subset_idx]``.
            # presumably a boolean mask aligned with ``metadata_all`` — confirm against callers

    Returns:
        ``(image_features_norm_subset, logits_subset, label_subset)``, all ordered like the
        selected rows of ``metadata_all``.

    Raises:
        KeyError: If ``logits`` or ``labels`` lacks one of the selected sample ids.
            ``Index.get_indexer`` reports such ids as ``-1``, which ``.iloc`` would otherwise
            silently resolve to the *last* row.
    """
    # Computed once; the original rebuilt this index for every lookup.
    subset_index = metadata_all[subset_idx].index

    def _positions(index, name):
        # Translate the selected ids into integer positions within `index`.
        positions = index.get_indexer(subset_index)
        not_found = positions < 0
        if not_found.any():
            missing = subset_index[not_found]
            raise KeyError(
                f"{name} index is missing {len(missing)} selected row(s), e.g. {list(missing[:5])}"
            )
        return positions

    image_features_norm_subset = image_features_norm[_positions(metadata_all.index, "metadata_all")]
    logits_subset = logits.iloc[_positions(logits.index, "logits")]
    label_subset = labels.iloc[_positions(labels.index, "labels")]

    return image_features_norm_subset, logits_subset, label_subset

# run ISIC

In [None]:
# For every dataset, the two attribution values that are contrasted against each other.
attribution_dict = {
    "isic_nodup_nooverlap": [
        "ViDIR Group, Department of Dermatology, Medical University of Vienna",
        "Hospital Clínic de Barcelona",
    ],
    "clinical_fd_clean_nodup_nooverlap": [
        "light",
        "dark",
    ],
}

# Dataset analysed by the cells that follow.
dataset_name = "isic_nodup_nooverlap"

# Unpack the pair and look up this dataset's cluster-count setting.
hospital_1, hospital_2 = attribution_dict[dataset_name]
n_clusters_param = n_clusters_param_dict[dataset_name]

In [None]:
# For each hospital's classifier, choose the score threshold that maximises F1 on the split
# stored at position [1] of that hospital's dataloader entry.
# Scores are used exactly as stored in "logits_with_index"; no sigmoid is applied.
# The loop replaces two copy-pasted stanzas; after it finishes, every intermediate name
# (y_test, y_test_predicted_probas, precision, ...) holds the hospital_2 values, as before.
max_f1_thres_isic={}

for hospital in (hospital_1, hospital_2):
    classifier_val_idx=variable_dict_classifier[dataset_name][f"classifier_dataloader_{hospital}"][1].dataset.metadata_all.index
    y_test=variable_dict_classifier[dataset_name]["classifier_dataloader_all"].dataset.metadata_all["label"].loc[classifier_val_idx]
    y_test_predicted_probas=variable_dict_classifier[dataset_name][f"classifier_model_{hospital}_eval"]["logits_with_index"].loc[classifier_val_idx]

    precision, recall, thresholds = sklearn.metrics.precision_recall_curve(y_test, y_test_predicted_probas)
    numerator = 2 * recall * precision
    denom = recall + precision
    # F1 is set to 0 wherever precision + recall is 0, which avoids dividing by zero.
    f1_scores = np.divide(numerator, denom, out=np.zeros_like(denom), where=(denom!=0))
    max_f1 = np.max(f1_scores)
    max_f1_thresh = thresholds[np.argmax(f1_scores)]
    max_f1_thres_isic[hospital]=max_f1_thresh
    print(max_f1_thresh)

In [None]:
def calculate_test_auc(metadata_all, 
                       logits, 
                       idx):
    """Return the ROC AUC of ``logits`` against ``metadata_all["label"]`` on the rows ``idx``.

    Args:
        metadata_all: DataFrame with a "label" column; its index is attached to ``logits``.
        logits: One score per row of ``metadata_all``, in the same row order
            (it is wrapped as ``pd.Series(logits, index=metadata_all.index)``).
        idx: Index labels of the rows to evaluate.

    Returns:
        The value of ``sklearn.metrics.roc_auc_score`` for the selected rows.
    """
    # Build the indexed score series once (the original built it twice and discarded the first).
    scores = pd.Series(logits, index=metadata_all.index)
    return sklearn.metrics.roc_auc_score(
        y_true=metadata_all["label"].loc[idx].values.astype(int),
        y_score=scores.loc[idx].values,
    )

# hospital_1 model scored on hospital_2's split ...
print(calculate_test_auc(metadata_all=variable_dict_classifier[dataset_name][f"classifier_model_{hospital_1}_eval"]["metadata"],
                  logits=variable_dict_classifier[dataset_name][f"classifier_model_{hospital_1}_eval"]["logits"],
                   idx=variable_dict_classifier[dataset_name][f"classifier_dataloader_{hospital_2}"][1].dataset.metadata_all.index))

# ... and on its own split.
print(calculate_test_auc(metadata_all=variable_dict_classifier[dataset_name][f"classifier_model_{hospital_1}_eval"]["metadata"],
                  logits=variable_dict_classifier[dataset_name][f"classifier_model_{hospital_1}_eval"]["logits"],
                   idx=variable_dict_classifier[dataset_name][f"classifier_dataloader_{hospital_1}"][1].dataset.metadata_all.index))


In [None]:
# `calculate_test_auc` is defined in the preceding cell; the identical re-definition that
# used to live here was removed so the notebook has a single definition to maintain.

# hospital_2 model scored on hospital_1's split ...
print(calculate_test_auc(metadata_all=variable_dict_classifier[dataset_name][f"classifier_model_{hospital_2}_eval"]["metadata"],
                  logits=variable_dict_classifier[dataset_name][f"classifier_model_{hospital_2}_eval"]["logits"],
                   idx=variable_dict_classifier[dataset_name][f"classifier_dataloader_{hospital_1}"][1].dataset.metadata_all.index))

# ... and on its own split.
print(calculate_test_auc(metadata_all=variable_dict_classifier[dataset_name][f"classifier_model_{hospital_2}_eval"]["metadata"],
                  logits=variable_dict_classifier[dataset_name][f"classifier_model_{hospital_2}_eval"]["logits"],
                   idx=variable_dict_classifier[dataset_name][f"classifier_dataloader_{hospital_2}"][1].dataset.metadata_all.index))


In [None]:
# NOTE(review): no visible cell assigns a global `logits` (it is only a function parameter
# in this section) — presumably leftover kernel state; confirm or remove.
logits

In [None]:
# ROC AUC for whatever `y_test` / `y_test_predicted_probas` currently hold; in the threshold
# cell above they are last assigned for hospital_2.
sklearn.metrics.roc_auc_score(y_true=y_test, y_score=y_test_predicted_probas)

In [None]:
# Direction 1 -> 2: logits of the hospital_1 classifier, restricted to rows attributed to
# hospital_2 that also pass the "valid_idx" mask.
classifier_vars_from1_to2 = variable_dict_classifier[dataset_name]
subset_mask_from1_to2 = get_subset_index(
    dataset_name=dataset_name,
    metadata_all=classifier_vars_from1_to2["metadata_all"],
    attribution=hospital_2,
) & (classifier_vars_from1_to2["valid_idx"])

(
    image_features_norm_subset_from1_to2,
    logits_subset_from1_to2,
    label_subset_from1_to2,
) = select_subset(
    image_features_norm=classifier_vars_from1_to2["image_features_norm"],
    metadata_all=classifier_vars_from1_to2["metadata_all"],
    logits=classifier_vars_from1_to2[f"classifier_model_{hospital_1}_eval"]["logits_with_index"],
    labels=classifier_vars_from1_to2["classifier_dataloader_all"].dataset.metadata_all["label"],
    subset_idx=subset_mask_from1_to2,
)

# Keep only the similarity-matrix columns accepted by check_concept_name.
similarity_matrix_from1_to2 = variable_dict[dataset_name]["similarity_matrix"]
concept_columns_from1_to2 = similarity_matrix_from1_to2.columns[
    similarity_matrix_from1_to2.columns.map(lambda x: check_concept_name(dataset_name, x))
]

# An earlier, now-disabled variant ("..._with_disease") passed the unfiltered similarity
# matrix with metric_diff=0.1 and n_clusters=80.
test_result_list_from1_to2_concept_only = cluster_concept_test_real(
    similarity_info=similarity_matrix_from1_to2[concept_columns_from1_to2],
    clustering_features=pd.DataFrame(
        variable_dict[dataset_name]["efficientnet_feature"].numpy(),
        index=variable_dict[dataset_name]["metadata_all"].index,
    ),
    fixed_answer=["red"],
    labels=label_subset_from1_to2,
    logits=logits_subset_from1_to2,
    threshold=max_f1_thres_isic[hospital_1],
    metric_diff=0,
    n_clusters=40,
    random_state=42,
)

print(len(label_subset_from1_to2))

In [None]:
# Build the plotting inputs for the 1 -> 2 direction: the first 5 results feed the main
# figure, the first 15 the supplementary figure.
similarity_thres=variable_dict[dataset_name]["similarity_matrix"].quantile(0.5, axis=0)
data_dict_main={}
data_dict_supple={}
count=0
for test_result in test_result_list_from1_to2_concept_only:
    # Top-5 "plus" concepts of this result.
    concept_name_list_plus=test_result["on_the_spot_plus_pred"][:5]
    # Order samples by ascending "accuracy", then by descending summed similarity to the plus concepts.
    concept_score=variable_dict[dataset_name]["similarity_matrix"][concept_name_list_plus].sum(axis=1).loc[test_result["labels"].index].rename("concept")
    sampe_list_plus=pd.concat([concept_score, test_result["labels"]], axis=1).sort_values(["accuracy", "concept"], ascending=[True, False]).index
    sub_title_plus=", ".join([shorten_concept_name(i, strict=False) for i in concept_name_list_plus])
    print(concept_name_list_plus)

    # "Minus" concepts: negative diff_magnitude with mean_value above 0.3, five most negative first.
    # (The former `on_the_spot_minus_pred[:3]` assignment was overwritten immediately and is dropped.)
    statistics=test_result["statistics"]
    concept_name_list_minus=statistics[(statistics["diff_magnitude"]<0)&(statistics["mean_value"]>0.3)].sort_values("diff_magnitude", ascending=True).iloc[:5].index.tolist()
    sub_title_minus=", ".join([shorten_concept_name(i, strict=False) for i in concept_name_list_minus])

    # One fresh entry per destination dict, exactly as the two duplicated literals produced.
    for max_slices, destination in ((5, data_dict_main), (15, data_dict_supple)):
        if count<max_slices:
            destination[count]={
                "targets": label_subset_from1_to2.astype(int),
                "preds": (logits_subset_from1_to2>max_f1_thres_isic[hospital_1]).astype(int),
                "sample_list_list": [sampe_list_plus],
                "main_title": [sub_title_plus, sub_title_minus],
            }

    count+=1

In [None]:
# Comma-separated rows: a numeric id followed by up to four ISIC image ids (blank when unused).
x = (
    "959,ISIC_0061234,ISIC_0053863,ISIC_0072697,,\n"
    "961,ISIC_0058303,ISIC_0062196,,,\n"
    "962,ISIC_0062827,ISIC_0064559,ISIC_0062301,,\n"
    "963,ISIC_0070137,ISIC_0068895,,,\n"
    "964,ISIC_0065605,ISIC_0070853,ISIC_0054985,ISIC_0059954,\n"
    "965,ISIC_0061470,ISIC_0058819,,,"
)

# Per row: drop the leading numeric id and every empty field.
x_ = [
    [field for field in row.split(",")[1:] if field != ""]
    for row in x.split("\n")
]

In [None]:
# Flatten x_ into one list, skipping the first image id of every row.
[j for i in x_ for j in i[1:]]

In [None]:
# Keys stored for supplementary slice 13.
data_dict_supple[13].keys()

In [None]:
def temp():
    """Print the sample list of supplementary slice 0 and its label / prediction tallies."""
    slice_entry = data_dict_supple[0]
    targets = slice_entry['targets']
    preds = slice_entry['preds']
    sample_list = slice_entry["sample_list_list"][0]
    print(sample_list)

    # Restrict both series to the samples of this slice before counting.
    slice_targets = targets.loc[sample_list]
    slice_preds = preds.loc[sample_list]

    label_str = f"True Malignant: {slice_targets.sum()} Neg={(1-slice_targets).sum()}"
    print(label_str)
    predicted_str = f" Pred +={(slice_preds==1).sum()} Neg={(slice_preds==0).sum()}"
    print(predicted_str)


temp()


In [None]:
# Target-value frequencies among the samples of supplementary slice 0.
(data_dict_supple[0]['targets'].loc[data_dict_supple[0]["sample_list_list"][0]]).value_counts()

In [None]:
# IPython introspection (`??` shows the source); development aid, not part of the analysis.
plot_slice_figure??

In [None]:
# Targets of the samples in supplementary slice 0, in sample-list order.
data_dict_supple[0]['targets'].loc[data_dict_supple[0]["sample_list_list"][0]]

In [None]:
                    
                    # Same two statements as in temp() above, but at cell level.
                    # NOTE(review): `targets`, `preds` and `sample_list` are not assigned as globals in
                    # any visible cell — presumably a scratch paste from a function body; confirm or remove.
                    label_str=f"True Malignant: {targets.loc[sample_list].sum()} Neg={(1-targets.loc[sample_list]).sum()}"
                    predicted_str=f" Pred +={(preds.loc[sample_list]==1).sum()} Neg={(preds.loc[sample_list]==0).sum()}"                    
                    

In [None]:
# Sum of the 0/1 predictions over the samples of supplementary slice 0.
data_dict_supple[0]['preds'].loc[data_dict_supple[0]["sample_list_list"][0]].sum()

In [None]:
# Prediction frequencies among the slice-0 samples whose target equals 0.
data_dict_supple[0]['preds'].loc[
    data_dict_supple[0]['targets'].loc[data_dict_supple[0]["sample_list_list"][0]]\
    [data_dict_supple[0]['targets'].loc[data_dict_supple[0]["sample_list_list"][0]]==0].index
].value_counts()

In [None]:
# Scratch arithmetic. NOTE(review): the operands look like counts copied from an earlier
# output — confirm their origin or remove the cell.
183+115

In [None]:
# Sum of the integer targets over the samples of supplementary slice 7.
data_dict_supple[7]['targets'].loc[data_dict_supple[7]["sample_list_list"][0]].sum()

In [None]:
# Sum of the 0/1 predictions over the samples of supplementary slice 7.
data_dict_supple[7]['preds'].loc[data_dict_supple[7]["sample_list_list"][0]].sum()

In [None]:
# Sum of predictions over the samples whose target equals 1.
# NOTE(review): 'preds'/'targets' are read from entry 7 but the sample list from entry 13.
# Every entry is built from the same label/logit subset, so this counts slice 13 — presumably
# a copy-paste slip; confirm which slice was intended.
data_dict_supple[7]['preds'].loc[
    data_dict_supple[7]['targets'].loc[data_dict_supple[13]["sample_list_list"][0]]\
    [data_dict_supple[7]['targets'].loc[data_dict_supple[13]["sample_list_list"][0]]==1].index
].sum()

In [None]:

# Sum of predictions over slice 13's samples; the second `.loc` re-selects by the index of
# the targets taken for that same sample list.
data_dict_supple[13]['preds'].loc[data_dict_supple[13]["sample_list_list"][0]].loc[
    data_dict_supple[13]['targets'].loc[data_dict_supple[13]["sample_list_list"][0]].index
].sum()

In [None]:
# Predictions for the samples of supplementary slice 13, in sample-list order.
data_dict_supple[13]['preds'].loc[data_dict_supple[13]["sample_list_list"][0]]

In [None]:
# Ordered sample ids of supplementary slice 13.
data_dict_supple[13]["sample_list_list"][0]

In [None]:
# Supplementary figure for the 1 -> 2 direction: 10 examples per row, one row per slice,
# with panel letters and both legends switched on.
fig = plot_slice_figure(
    dataset_name=dataset_name,
    data_dict=data_dict_supple,
    prompt_info=variable_dict[dataset_name]["prompt_info"],
    example_per_row=10,
    row_per_slice=1,
    normalize=True,
    show_small_box=True,
    skip_section=0,
    print_alphabet=True,
    print_legend_number=True,
    print_legend_color=True,
    figure_title=None,
    debug=False,
)
# The PDF export lives in the next cell.

In [None]:
# Write the current figure as a tightly cropped PDF under <log_dir>/plots.
supple_pdf_from1_to2 = log_dir / "plots" / "model_audit_from1_to2_supple.pdf"
fig.savefig(supple_pdf_from1_to2, bbox_inches='tight')

In [None]:
# Main-text panel for the 1 -> 2 direction: 5 examples per row, no panel letters or legends.
fig = plot_slice_figure(
    dataset_name=dataset_name,
    data_dict=data_dict_main,
    prompt_info=variable_dict[dataset_name]["prompt_info"],
    example_per_row=5,
    row_per_slice=1,
    normalize=True,
    show_small_box=True,
    skip_section=0,
    print_alphabet=False,
    print_legend_number=False,
    print_legend_color=False,
    slice_title_fontsize=27,
    figure_title=("D. ", "Trained at Med U. Vienna / Tested at Hosp. Barcelona "),
)
# Export is currently switched off for this panel:
# fig.savefig(log_dir/"plots"/"model_audit_from1_to2_main.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"model_audit_from1_to2_main.pdf", bbox_inches='tight')

In [None]:
# Direction 2 -> 1: logits of the hospital_2 classifier, restricted to rows attributed to
# hospital_1 that also pass the "valid_idx" mask.
classifier_vars_from2_to1 = variable_dict_classifier[dataset_name]
subset_mask_from2_to1 = get_subset_index(
    dataset_name=dataset_name,
    metadata_all=classifier_vars_from2_to1["metadata_all"],
    attribution=hospital_1,
) & (classifier_vars_from2_to1["valid_idx"])

(
    image_features_norm_subset_from2_to1,
    logits_subset_from2_to1,
    label_subset_from2_to1,
) = select_subset(
    image_features_norm=classifier_vars_from2_to1["image_features_norm"],
    metadata_all=classifier_vars_from2_to1["metadata_all"],
    logits=classifier_vars_from2_to1[f"classifier_model_{hospital_2}_eval"]["logits_with_index"],
    labels=classifier_vars_from2_to1["classifier_dataloader_all"].dataset.metadata_all["label"],
    subset_idx=subset_mask_from2_to1,
)

# Keep only the similarity-matrix columns accepted by check_concept_name.
similarity_matrix_from2_to1 = variable_dict[dataset_name]["similarity_matrix"]
concept_columns_from2_to1 = similarity_matrix_from2_to1.columns[
    similarity_matrix_from2_to1.columns.map(lambda x: check_concept_name(dataset_name, x))
]

# An earlier, now-disabled variant ("..._with_disease") passed the unfiltered similarity
# matrix with metric_diff=0.1 and n_clusters=80.
test_result_list_from2_to1_concept_only = cluster_concept_test_real(
    similarity_info=similarity_matrix_from2_to1[concept_columns_from2_to1],
    clustering_features=pd.DataFrame(
        variable_dict[dataset_name]["efficientnet_feature"].numpy(),
        index=variable_dict[dataset_name]["metadata_all"].index,
    ),
    fixed_answer=["red"],
    labels=label_subset_from2_to1,
    logits=logits_subset_from2_to1,
    threshold=max_f1_thres_isic[hospital_2],
    metric_diff=0,
    n_clusters=40,
    random_state=42,
)

print(len(label_subset_from2_to1))

In [None]:
# Build the plotting inputs for the 2 -> 1 direction: the first 5 results feed the main
# figure, the first 15 the supplementary figure.
similarity_thres=variable_dict[dataset_name]["similarity_matrix"].quantile(0.5, axis=0)
data_dict_main={}
data_dict_supple={}
count=0
for test_result in test_result_list_from2_to1_concept_only:
    print(test_result["statistics"].sort_values("diff_magnitude", ascending=False))
    # Top-5 "plus" concepts of this result.
    concept_name_list_plus=test_result["on_the_spot_plus_pred"][:5]
    # Order samples by ascending "accuracy", then by descending summed similarity to the plus concepts.
    concept_score=variable_dict[dataset_name]["similarity_matrix"][concept_name_list_plus].sum(axis=1).loc[test_result["labels"].index].rename("concept")
    sampe_list_plus=pd.concat([concept_score, test_result["labels"]], axis=1).sort_values(["accuracy", "concept"], ascending=[True, False]).index
    sub_title_plus=", ".join([shorten_concept_name(i, strict=False) for i in concept_name_list_plus])

    # "Minus" concepts: negative diff_magnitude with mean_value above 0.3, five most negative first.
    # (The former `on_the_spot_minus_pred[:3]` assignment was overwritten immediately and is dropped.)
    statistics=test_result["statistics"]
    concept_name_list_minus=statistics[(statistics["diff_magnitude"]<0)&(statistics["mean_value"]>0.3)].sort_values("diff_magnitude", ascending=True).iloc[:5].index.tolist()
    sub_title_minus=", ".join([shorten_concept_name(i, strict=False) for i in concept_name_list_minus])

    # One fresh entry per destination dict, exactly as the two duplicated literals produced.
    for max_slices, destination in ((5, data_dict_main), (15, data_dict_supple)):
        if count<max_slices:
            destination[count]={
                "targets": label_subset_from2_to1.astype(int),
                "preds": (logits_subset_from2_to1>max_f1_thres_isic[hospital_2]).astype(int),
                "sample_list_list": [sampe_list_plus],
                "main_title": [sub_title_plus, sub_title_minus],
            }

    count+=1

In [None]:
# Supplementary figure for the 2 -> 1 direction: 10 examples per row, one row per slice,
# with panel letters and both legends switched on.
fig = plot_slice_figure(
    dataset_name=dataset_name,
    data_dict=data_dict_supple,
    prompt_info=variable_dict[dataset_name]["prompt_info"],
    example_per_row=10,
    row_per_slice=1,
    normalize=True,
    show_small_box=True,
    skip_section=0,
    print_alphabet=True,
    print_legend_number=True,
    print_legend_color=True,
    figure_title=None,
)
# The PNG/PDF export lives in the next cell.

In [None]:
# Write the current figure as tightly cropped PNG and PDF under <log_dir>/plots.
fig.savefig(log_dir / "plots" / "model_audit_from2_to1_supple.png", bbox_inches='tight')
fig.savefig(log_dir / "plots" / "model_audit_from2_to1_supple.pdf", bbox_inches='tight')

In [None]:
# Main-text panel for the 2 -> 1 direction: 5 examples per row, no panel letters or legends.
fig = plot_slice_figure(
    dataset_name=dataset_name,
    data_dict=data_dict_main,
    prompt_info=variable_dict[dataset_name]["prompt_info"],
    example_per_row=5,
    row_per_slice=1,
    normalize=True,
    show_small_box=True,
    skip_section=0,
    print_alphabet=False,
    print_legend_number=False,
    print_legend_color=False,
    slice_title_fontsize=27,
    figure_title=("E. ", "Trained at Hosp. Barcelona / Tested at Med U. Vienna"),
)
# Export the panel as PNG and PDF.
fig.savefig(log_dir / "plots" / "model_audit_from2_to1_main.png", bbox_inches='tight')
fig.savefig(log_dir / "plots" / "model_audit_from2_to1_main.pdf", bbox_inches='tight')

In [None]:
# Trained at Med U. Vienna / Tested at Hosp. Barcelona
# Trained on light skin / Tested on dark skin

In [None]:
# Single-slice figure with the colour legend enabled (print_legend_color_idx=4).
# NOTE(review): data_dict_main was last rebuilt for the 2 -> 1 direction, while the title
# names the 1 -> 2 direction — presumably only the legend is taken from this figure; confirm.
fig = plot_slice_figure(
    dataset_name=dataset_name,
    data_dict={"a": data_dict_main[0]},
    prompt_info=variable_dict[dataset_name]["prompt_info"],
    example_per_row=10,
    row_per_slice=1,
    normalize=True,
    show_small_box=True,
    skip_section=0,
    print_alphabet=False,
    print_legend_number=False,
    print_legend_color=True,
    print_legend_color_idx=4,
    figure_title=("E. ", "Trained at Med U. Vienna / Tested at Hosp. Barcelona "),
)
# Export is currently switched off for the legend figure:
# fig.savefig(log_dir/"plots"/"model_audit_main_legend.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"model_audit_main_legend.pdf", bbox_inches='tight')

# run proveai

In [None]:
# Switch the notebook-wide dataset selector; later cells read `dataset_name` as a global.
dataset_name="proveai"

In [None]:
# Per-image table for the proveai analysis; later cells read its "truth", "scores",
# "prediction", "target" and "image_name" columns. The path is relative to the working
# directory, which the first notebook cell sets to the project root.
prove_logits_true=pd.read_csv("data/proveai/isic_upd_rev.csv",index_col=0)

In [None]:
# ROC AUC of the "scores" column against "truth" over the whole table.
sklearn.metrics.roc_auc_score(y_true=prove_logits_true["truth"],
                             y_score=prove_logits_true["scores"])

In [None]:
# classification_report compares discrete labels, so it is fed the "prediction" column;
# the continuous "scores" column (thresholded in later cells) is not a valid y_pred.
print(sklearn.metrics.classification_report(y_true=prove_logits_true["truth"],
                                      y_pred=prove_logits_true["prediction"]))

In [None]:
# Scores of the rows predicted positive. An explicit boolean mask is used (as in the
# min/max cells that follow) so the column's values are never read as index labels.
prove_logits_true["scores"][prove_logits_true["prediction"]==1]

In [None]:
# Lowest score among rows with prediction == 1.
prove_logits_true["scores"][prove_logits_true["prediction"]==1].min()

In [None]:
# Highest score among rows with prediction == 0.
prove_logits_true["scores"][prove_logits_true["prediction"]==0].max()

In [None]:
# Whole-table metrics using the stored "prediction" column as the decision.
# NOTE(review): the original cell carried a bare "0.001" note — presumably the cutoff
# behind "prediction"; confirm.
y_true = prove_logits_true["truth"]
y_score = prove_logits_true["scores"]
y_pred = prove_logits_true["prediction"]

tn, fp, fn, tp = sklearn.metrics.confusion_matrix(y_true=y_true, y_pred=y_pred).ravel()
print("sensitivity", tp / (tp + fn))
print("specificity", tn / (tn + fp))
print("accuracy", (y_true == y_pred).mean())
print("auroc", sklearn.metrics.roc_auc_score(y_true=y_true, y_score=y_score))

In [None]:
# Whole-table metrics with "scores" thresholded at 0.00097 (annotated "0.001" by the author).
y_true = prove_logits_true["truth"]
y_score = prove_logits_true["scores"]
y_pred = prove_logits_true["scores"] > 0.00097

tn, fp, fn, tp = sklearn.metrics.confusion_matrix(y_true=y_true, y_pred=y_pred).ravel()
print("sensitivity", tp / (tp + fn))
print("specificity", tn / (tn + fp))
print("accuracy", (y_true == y_pred).mean())
print("auroc", sklearn.metrics.roc_auc_score(y_true=y_true, y_score=y_score))

In [None]:
# Whole-table metrics with "scores" thresholded at 0.0085.
y_true = prove_logits_true["truth"]
y_score = prove_logits_true["scores"]
y_pred = prove_logits_true["scores"] > 0.0085

tn, fp, fn, tp = sklearn.metrics.confusion_matrix(y_true=y_true, y_pred=y_pred).ravel()
print("sensitivity", tp / (tp + fn))
print("specificity", tn / (tn + fp))
print("accuracy", (y_true == y_pred).mean())
print("auroc", sklearn.metrics.roc_auc_score(y_true=y_true, y_score=y_score))

In [None]:
# Whole-table metrics with "scores" thresholded at 0.5.
y_true = prove_logits_true["truth"]
y_score = prove_logits_true["scores"]
y_pred = prove_logits_true["scores"] > 0.5

tn, fp, fn, tp = sklearn.metrics.confusion_matrix(y_true=y_true, y_pred=y_pred).ravel()
print("sensitivity", tp / (tp + fn))
print("specificity", tn / (tn + fp))
print("accuracy", (y_true == y_pred).mean())
print("auroc", sklearn.metrics.roc_auc_score(y_true=y_true, y_score=y_score))

In [None]:
# Score threshold that maximises F1 over the whole proveai table.
precision, recall, thresholds = sklearn.metrics.precision_recall_curve(
    prove_logits_true["truth"],
    prove_logits_true["scores"],
)
denom = recall + precision
numerator = 2 * recall * precision
# F1 is set to 0 wherever precision + recall is 0, which avoids dividing by zero.
f1_scores = np.divide(numerator, denom, out=np.zeros_like(denom), where=(denom != 0))
max_f1_idx = np.argmax(f1_scores)
max_f1 = f1_scores[max_f1_idx]
max_f1_thresh = thresholds[max_f1_idx]
max_f1_thresh

In [None]:
# Whole-table metrics with "scores" thresholded at the hard-coded value 0.0865900212694926.
# NOTE(review): presumably the max-F1 threshold printed by the previous cell — confirm.
y_true = prove_logits_true["truth"]
y_score = prove_logits_true["scores"]
y_pred = prove_logits_true["scores"] > 0.0865900212694926

tn, fp, fn, tp = sklearn.metrics.confusion_matrix(y_true=y_true, y_pred=y_pred).ravel()
print("sensitivity", tp / (tp + fn))
print("specificity", tn / (tn + fp))
print("accuracy", (y_true == y_pred).mean())
print("auroc", sklearn.metrics.roc_auc_score(y_true=y_true, y_score=y_score))

In [None]:
# True positives and the positive total (tp + fn) from the most recent confusion matrix.
tp, (tp+fn)

In [None]:
# Sum of the "truth" column over the whole table.
prove_logits_true["truth"].sum()

In [None]:
# Per-slice metrics with the "target" column thresholded at 0.00103.
# Slices where every prediction already matches the truth are skipped.
proveai_by_image = prove_logits_true.set_index("image_name")  # re-indexed once, not per lookup
for test_result in test_result_list_proveai_concept_only:
    slice_index = test_result["labels"].index
    y_true = proveai_by_image["truth"].loc[slice_index]
    y_score = proveai_by_image["target"].loc[slice_index]
    y_pred = y_score > 0.00103

    if (y_true==y_pred).all():
        continue

    tn, fp, fn, tp = (sklearn.metrics.confusion_matrix(y_true=y_true, y_pred=y_pred).ravel())
    print(tn, fp, fn, tp)
    print("sensitivity", tp / (tp+fn)) # tp / P
    print("specificity", tn / (tn+fp)) # tn / N
    # AUROC needs both classes. `nunique` is used instead of `~y_true`, because `~` on an
    # integer 0/1 Series is bitwise NOT (-1 / -2), which would always look "all true" and
    # suppress the AUROC line for every slice.
    if y_true.nunique() > 1:
        print("auroc", sklearn.metrics.roc_auc_score(y_true=y_true,
                                     y_score=y_score))
    print("accuracy", (y_true==y_pred).mean())
    print()

In [None]:
# Per-slice metrics with the "target" column thresholded at 0.1906984259977523.
# Slices where every prediction already matches the truth are skipped.
proveai_by_image = prove_logits_true.set_index("image_name")  # re-indexed once, not per lookup
for test_result in test_result_list_proveai_concept_only:
    slice_index = test_result["labels"].index
    y_true = proveai_by_image["truth"].loc[slice_index]
    y_score = proveai_by_image["target"].loc[slice_index]
    y_pred = y_score > 0.1906984259977523

    if (y_true==y_pred).all():
        continue

    tn, fp, fn, tp = (sklearn.metrics.confusion_matrix(y_true=y_true, y_pred=y_pred).ravel())
    print(tn, fp, fn, tp)
    print("sensitivity", tp / (tp+fn)) # tp / P
    print("specificity", tn / (tn+fp)) # tn / N
    # AUROC needs both classes. `nunique` is used instead of `~y_true`, because `~` on an
    # integer 0/1 Series is bitwise NOT (-1 / -2), which would always look "all true" and
    # suppress the AUROC line for every slice.
    if y_true.nunique() > 1:
        print("auroc", sklearn.metrics.roc_auc_score(y_true=y_true,
                                     y_score=y_score))
    print("accuracy", (y_true==y_pred).mean())
    print()

In [None]:
# Per-slice metrics with the "target" column thresholded at 0.05.
# Slices where every prediction already matches the truth are skipped.
proveai_by_image = prove_logits_true.set_index("image_name")  # re-indexed once, not per lookup
for test_result in test_result_list_proveai_concept_only:
    slice_index = test_result["labels"].index
    y_true = proveai_by_image["truth"].loc[slice_index]
    y_score = proveai_by_image["target"].loc[slice_index]
    y_pred = y_score > 0.05

    if (y_true==y_pred).all():
        continue

    tn, fp, fn, tp = (sklearn.metrics.confusion_matrix(y_true=y_true, y_pred=y_pred).ravel())
    print(tn, fp, fn, tp)
    print("sensitivity", tp / (tp+fn)) # tp / P
    print("specificity", tn / (tn+fp)) # tn / N
    # AUROC needs both classes. `nunique` is used instead of `~y_true`, because `~` on an
    # integer 0/1 Series is bitwise NOT (-1 / -2), which would always look "all true" and
    # suppress the AUROC line for every slice.
    if y_true.nunique() > 1:
        print("auroc", sklearn.metrics.roc_auc_score(y_true=y_true,
                                     y_score=y_score))
    print("accuracy", (y_true==y_pred).mean())
    print()

In [None]:
# Leftover `y_pred` from whichever cell assigned it last (state-dependent inspection).
y_pred

In [None]:
# Leftover `y_true` from whichever cell assigned it last (state-dependent inspection).
y_true

In [None]:
# Leftover `y_score` from whichever cell assigned it last (state-dependent inspection).
y_score

In [None]:
# Sample ids of the `test_result` left over from the last executed loop iteration.
test_result["labels"].index

In [None]:
# Sensitivity / specificity over the whole table with the "target" column thresholded at 0.0085.
tn, fp, fn, tp = sklearn.metrics.confusion_matrix(
    y_true=prove_logits_true["truth"],
    y_pred=prove_logits_true["target"] > 0.0085,
).ravel()
print("sensitivity", tp / (tp + fn))
print("specificity", tn / (tn + fp))

In [None]:
# Flattened confusion matrix of the current y_true / y_pred. The function is reached through
# `sklearn.metrics`, the only way it is imported in this notebook's import cell.
sklearn.metrics.confusion_matrix(y_true, y_pred).ravel()

In [None]:
# sensitivity
print(sklearn.metrics.recall_score(y_true=prove_logits_true["truth"],
    y_pred=prove_logits_true["prediction"]))

In [None]:
# specificity
print(sklearn.metrics.recall_score(y_true=prove_logits_true["truth"],
    y_pred=prove_logits_true["prediction"]))

In [None]:
# Reference formulas for the metrics printed above.
# NOTE(review): TP / TN / FP / FN are not assigned in any visible cell (the computed counts
# are lower-case tp / tn / fp / fn), so this cell presumably fails if executed, and it would
# rebind `recall` and `precision` — confirm whether it was meant as a note only.
sensitivity= recall = TP / (TP+FN)
specificity= TN / (TN+FP)
precision= TP / (TP+FP)

In [None]:
# Sample ids (index of "metadata_all") for the currently selected dataset.
variable_dict[dataset_name]["metadata_all"].index

In [None]:
# IPython introspection (`??` shows the source); development aid, not part of the analysis.
cluster_concept_test_real??

In [None]:
# IPython introspection (`??` shows the source); development aid, not part of the analysis.
select_subset??

In [None]:
# Configuration for the ProveAI audit.
dataset_name = "proveai"
threshold_select = 0.00097
# Other operating points tried previously: 0.5, 0.0865900212694926

proveai_vars = variable_dict[dataset_name]
proveai_index = proveai_vars["metadata_all"].index

# Restrict features / scores / labels to the images whose y_pos label is 0.
image_features_norm_subset_proveai, logits_subset_proveai, label_subset_proveai = select_subset(
    image_features_norm=proveai_vars["image_features_norm"],
    metadata_all=proveai_vars["metadata_all"],
    logits=prove_logits_true.set_index("image_name").loc[proveai_index]["scores"],
    labels=pd.Series(proveai_vars["y_pos"], index=proveai_index),
    subset_idx=pd.Series(proveai_vars["y_pos"], index=proveai_index) == 0,
)

# Run the clustering-based concept test twice with identical settings: first
# with the function's default filtering, then with
# return_only_highperforming=False. All inputs are rebuilt for each run so the
# two calls never share objects.
concept_only_runs = [
    cluster_concept_test_real(
        # keep only the similarity columns accepted by check_concept_name
        similarity_info=proveai_vars["similarity_matrix"][
            proveai_vars["similarity_matrix"].columns[
                proveai_vars["similarity_matrix"].columns.map(
                    lambda x: check_concept_name(dataset_name, x)
                )
            ]
        ],
        clustering_features=pd.DataFrame(
            proveai_vars["efficientnet_feature"].numpy(),
            index=proveai_index,
        ),
        fixed_answer=["red"],
        labels=label_subset_proveai,
        logits=logits_subset_proveai,
        threshold=threshold_select,
        metric_diff=0,
        metric_over=0.5,
        n_clusters=10,
        random_state=42,
        **extra_kwargs,
    )
    for extra_kwargs in ({}, {"return_only_highperforming": False})
]
(
    (test_result_list_proveai_concept_only,
     similarity_info_focus_copy_group_proveai_concept_only),
    (test_result_list_proveai_concept_only_,
     similarity_info_focus_copy_group_proveai_concept_only_),
) = concept_only_runs

print(len(label_subset_proveai))

In [None]:
check_concept_name??

In [None]:
len(test_result_list_proveai_concept_only_)

In [None]:
len(test_result_list_proveai_concept_only)

In [None]:
((logits_subset_proveai>threshold_select)==(label_subset_proveai==1)).mean()

In [None]:
((prove_logits_true.set_index("image_name")["target"].loc[test_result["labels"].index])>0.00103).sum()

In [None]:
(prove_logits_true["truth"]).sum(),\
(prove_logits_true["prediction"]).sum()

In [None]:
(prove_logits_true["truth"]==1)[(prove_logits_true["prediction"])==0].sum()

In [None]:
(prove_logits_true["truth"]).sum()

In [None]:
len(test_result_list_proveai_concept_only_)

In [None]:
threshold_select

In [None]:
(prove_logits_true.set_index("image_name")["truth"]==\
 (prove_logits_true.set_index("image_name")["target"]>threshold_select))\
.mean()

In [None]:
len(test_result_list_proveai_concept_only_)

In [None]:
test_result_list_proveai_concept_only_[0].keys()

In [None]:
on_the_spot_minus_pred

In [None]:
test_result_list_proveai_concept_only_[0]["on_the_spot_minus_pred"]

In [None]:
test_result_list_proveai_concept_only_[1]["on_the_spot_minus_pred"]

In [None]:
test_result_list_proveai_concept_only_[0]["labels_ref"]

In [None]:
test_result_list_proveai_concept_only_[1]["labels_ref"]

In [None]:
# Collect per-slice confusion counts and metrics (M = truth 1, B = truth 0)
# for every slice returned with return_only_highperforming=False, using the
# "scores" column thresholded at threshold_select.
record_dict_list = []
prove_by_image = prove_logits_true.set_index("image_name")  # loop-invariant: index once
for test_result in test_result_list_proveai_concept_only_:
    slice_index = test_result["labels"].index
    y_true = prove_by_image["truth"].loc[slice_index]
    y_score = prove_by_image["scores"].loc[slice_index]
    y_pred = y_score > threshold_select

    # Slices that are classified perfectly are neither recorded nor printed.
    if (y_true == y_pred).all():
        continue

    tn, fp, fn, tp = sklearn.metrics.confusion_matrix(
        y_true=y_true, y_pred=y_pred, labels=[0, 1]
    ).ravel()
    sensitivity = tp / (tp + fn)  # tp / P
    specificity = tn / (tn + fp)  # tn / N
    accuracy = (y_true == y_pred).mean()

    record_dict_list.append({
        "M->M": tp,
        "M->B": fn,
        "B->M": fp,
        "B->B": tn,
        "sensitivity": sensitivity,
        "specificity": specificity,
        "accuracy": accuracy,
    })

    print("M -> M", tp)
    print("M -> B", fn)
    print("B -> M", fp)
    print("B -> B", tn)
    print("sensitivity", sensitivity)
    print("specificity", specificity)
    # AUROC needs both classes. `(~y_true).all()` cannot be used for this
    # check: on integer labels `~` is a bitwise NOT whose results (-1 / -2) are
    # always truthy, so the AUROC branch would never run.
    if y_true.nunique() > 1:
        print("auroc", sklearn.metrics.roc_auc_score(y_true=y_true, y_score=y_score))
    print("accuracy", accuracy)
    print()

In [None]:
pd.DataFrame(record_dict_list)

In [None]:
prove_logits_true[(prove_logits_true["truth"]==0)&(prove_logits_true["prediction"]==1)].shape#.set_index("")

In [None]:
prove_logits_true[(prove_logits_true["truth"]==0)&(prove_logits_true["prediction"]==0)].shape#.set_index("")

In [None]:
prove_logits_true[(prove_logits_true["truth"]==0)&(prove_logits_true["prediction"]==0)]["image_name"]

In [None]:
pd.set_option('display.max_columns', 200)

In [None]:
# Mean concept similarity of the false positives (truth 0, prediction 1) minus
# that of the true negatives (truth 0, prediction 0), sorted ascending.
fp_image_names = prove_logits_true[
    (prove_logits_true["truth"] == 0) & (prove_logits_true["prediction"] == 1)
]["image_name"].values
tn_image_names = prove_logits_true[
    (prove_logits_true["truth"] == 0) & (prove_logits_true["prediction"] == 0)
]["image_name"].values

fp_concept_mean = variable_dict[dataset_name]["similarity_matrix"].loc[fp_image_names].mean(axis=0)
tn_concept_mean = variable_dict[dataset_name]["similarity_matrix"].loc[tn_image_names].mean(axis=0)
(fp_concept_mean - tn_concept_mean).sort_values()

In [None]:
variable_dict[dataset_name]["similarity_matrix"].loc[
prove_logits_true[(prove_logits_true["truth"]==0)&(prove_logits_true["prediction"]==1)]["image_name"].values
]["purple pen"].mean()

In [None]:
variable_dict[dataset_name]["similarity_matrix"].loc[
prove_logits_true[(prove_logits_true["truth"]==0)&(prove_logits_true["prediction"]==0)]["image_name"].values
]["purple pen"]

In [None]:
test_result["labels_ref"]

In [None]:
# Per-image summed similarity to the selected "plus" concepts, shown side by
# side with the slice's label table. The original cell stopped in the middle of
# the pd.concat([...]) call and could not be parsed; the call is closed here.
pd.concat([variable_dict[dataset_name]["similarity_matrix"][concept_name_list_plus].sum(axis=1).loc[test_result["labels"].index].rename("concept"),
           test_result["labels"]], axis=1)

In [None]:
concept_name_list_plus

In [None]:
pd.concat([variable_dict[dataset_name]["similarity_matrix"][concept_name_list_plus].sum(axis=1).loc[test_result["labels"].index].rename("concept"),
test_result["labels"]], axis=1).sort_values(["accuracy", "concept"], ascending=[True, False]).index

In [None]:
test_result["statistics"]

In [None]:
test_result["labels"].shape, test_result["labels_ref"].shape

In [None]:
    sampe_list_minus=pd.concat([variable_dict[dataset_name]["similarity_matrix"][concept_name_list_minus].sum(axis=1).loc[test_result["labels_ref"].index].rename("concept"),
test_result["labels_ref"]], axis=1).sort_values(["accuracy", "concept"], ascending=[True, False]).index

# Top 5 concepts per slice, with images ordered by concept similarity

In [None]:
# Build the plotting inputs for each concept-only slice. Images are ordered by
# the "accuracy" column first and then by summed concept similarity
# (NOTE(review): "accuracy" is presumably a column of test_result["labels"] /
# test_result["labels_ref"] — verify).
similarity_thres = variable_dict[dataset_name]["similarity_matrix"].quantile(0.5, axis=0)
data_dict_main = {}
data_dict_supple = {}
count = 0
for slice_no, test_result in enumerate(test_result_list_proveai_concept_only):
    print(test_result["statistics"].sort_values("diff_magnitude", ascending=False))

    # --- "plus" side: first five entries of on_the_spot_plus_pred -----------
    concept_name_list_plus = test_result["on_the_spot_plus_pred"][:5]
    plus_table = pd.concat(
        [
            variable_dict[dataset_name]["similarity_matrix"][concept_name_list_plus]
            .sum(axis=1)
            .loc[test_result["labels"].index]
            .rename("concept"),
            test_result["labels"],
        ],
        axis=1,
    )
    # lowest "accuracy" first, then highest concept similarity
    sampe_list_plus = plus_table.sort_values(
        ["accuracy", "concept"], ascending=[True, False]
    ).index
    sub_title_plus = ", ".join(
        shorten_concept_name(name, strict=False) for name in concept_name_list_plus
    )

    # --- "minus" side: first five entries of on_the_spot_minus_pred ---------
    concept_name_list_minus = test_result["on_the_spot_minus_pred"][:5]
    minus_table = pd.concat(
        [
            variable_dict[dataset_name]["similarity_matrix"][concept_name_list_minus]
            .sum(axis=1)
            .loc[test_result["labels_ref"].index]
            .rename("concept"),
            test_result["labels_ref"],
        ],
        axis=1,
    )
    # highest "accuracy" first, then highest concept similarity
    sampe_list_minus = minus_table.sort_values(
        ["accuracy", "concept"], ascending=[False, False]
    ).index
    sub_title_minus = ", ".join(
        shorten_concept_name(name, strict=False) for name in concept_name_list_minus
    )

    # The first 5 slices go to the main figure and the first 20 to the
    # supplementary one; each destination receives its own freshly built entry.
    for slice_limit, destination in ((5, data_dict_main), (20, data_dict_supple)):
        if slice_no < slice_limit:
            destination[slice_no] = {
                "targets": label_subset_proveai.astype(int),
                "preds": (logits_subset_proveai > threshold_select).astype(int),
                "sample_list_list": [sampe_list_plus, sampe_list_minus],
                "main_title": [sub_title_plus, sub_title_minus],
            }

    count = slice_no + 1

In [None]:
# Metrics for the first ("plus") image list of every supplementary slice.
for data_dict in data_dict_supple.values():
    slice_images = data_dict["sample_list_list"][0]
    y_true = data_dict["targets"].loc[slice_images]
    y_pred = data_dict["preds"].loc[slice_images]
    # Scores for exactly these images. Previously `y_score` was whatever an
    # earlier cell had left in the kernel, i.e. scores of a different slice.
    y_score = logits_subset_proveai.loc[slice_images]

    tn, fp, fn, tp = sklearn.metrics.confusion_matrix(
        y_true=y_true, y_pred=y_pred, labels=[0, 1]
    ).ravel()
    print(tn, fp, fn, tp)
    print((y_true == 1).sum(), "/", (y_true == 0).sum(), "->", (y_pred == 1).sum(), "/", (y_pred == 0).sum())
    print("sensitivity", tp / (tp + fn))  # tp / P
    print("specificity", tn / (tn + fp))  # tn / N
    # AUROC is only defined when both classes occur. "targets" is an integer
    # series, for which `(~y_true).all()` is always True (bitwise NOT yields
    # -1 / -2), so the single-class check uses nunique() instead.
    if y_true.nunique() > 1:
        print("auroc", sklearn.metrics.roc_auc_score(y_true=y_true, y_score=y_score))
    print("accuracy", (y_true == y_pred).mean())
    print()
    

In [None]:
# Supplementary audit figure: one row of five example images per slice.
layout_options = dict(example_per_row=5, row_per_slice=1, normalize=True,
                      show_small_box=True, skip_section=0)
legend_options = dict(print_alphabet=True, print_legend_number=True,
                      print_legend_color=True)
font_options = dict(true_pred_count_fontsize=16, fontsize=16,
                    slice_title_fontsize=16)
fig = plot_slice_figure(
    dataset_name=dataset_name,
    data_dict=data_dict_supple,
    prompt_info=variable_dict[dataset_name]["prompt_info"],
    figure_title=None,
    **layout_options,
    **legend_options,
    **font_options,
)
# fig.savefig(log_dir/"plots"/f"model_audit_from2_to1_supple.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/f"model_audit_from2_to1_supple.pdf", bbox_inches='tight')

In [None]:
# Supplementary audit figure: two rows of ten example images per slice.
layout_options = dict(example_per_row=10, row_per_slice=2, normalize=True,
                      show_small_box=True, skip_section=0)
legend_options = dict(print_alphabet=True, print_legend_number=True,
                      print_legend_color=True)
font_options = dict(true_pred_count_fontsize=16, fontsize=16,
                    slice_title_fontsize=16)
fig = plot_slice_figure(
    dataset_name=dataset_name,
    data_dict=data_dict_supple,
    prompt_info=variable_dict[dataset_name]["prompt_info"],
    figure_title=None,
    **layout_options,
    **legend_options,
    **font_options,
)
# fig.savefig(log_dir/"plots"/f"model_audit_from2_to1_supple.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/f"model_audit_from2_to1_supple.pdf", bbox_inches='tight')

In [None]:
fig.savefig(log_dir/"plots"/f"model_audit_ADAE_test_conceptordered.png", bbox_inches='tight')
fig.savefig(log_dir/"plots"/f"model_audit_ADAE_test_conceptordered.pdf", bbox_inches='tight')

# Top 5 concepts per slice, without ordering images by concept similarity

In [None]:
# Build the plotting inputs for each concept-only slice. Images are ordered by
# the "accuracy" column only; the summed concept similarity is still attached
# as the "concept" column but is not used as a sort key here.
similarity_thres = variable_dict[dataset_name]["similarity_matrix"].quantile(0.5, axis=0)
data_dict_main = {}
data_dict_supple = {}
count = 0
for slice_no, test_result in enumerate(test_result_list_proveai_concept_only):
    print(test_result["statistics"].sort_values("diff_magnitude", ascending=False))

    # --- "plus" side: first five entries of on_the_spot_plus_pred -----------
    concept_name_list_plus = test_result["on_the_spot_plus_pred"][:5]
    plus_table = pd.concat(
        [
            variable_dict[dataset_name]["similarity_matrix"][concept_name_list_plus]
            .sum(axis=1)
            .loc[test_result["labels"].index]
            .rename("concept"),
            test_result["labels"],
        ],
        axis=1,
    )
    # lowest "accuracy" first
    sampe_list_plus = plus_table.sort_values(["accuracy"], ascending=[True]).index
    sub_title_plus = ", ".join(
        shorten_concept_name(name, strict=False) for name in concept_name_list_plus
    )

    # --- "minus" side: first five entries of on_the_spot_minus_pred ---------
    concept_name_list_minus = test_result["on_the_spot_minus_pred"][:5]
    minus_table = pd.concat(
        [
            variable_dict[dataset_name]["similarity_matrix"][concept_name_list_minus]
            .sum(axis=1)
            .loc[test_result["labels_ref"].index]
            .rename("concept"),
            test_result["labels_ref"],
        ],
        axis=1,
    )
    # highest "accuracy" first
    sampe_list_minus = minus_table.sort_values(["accuracy"], ascending=[False]).index
    sub_title_minus = ", ".join(
        shorten_concept_name(name, strict=False) for name in concept_name_list_minus
    )

    # The first 5 slices go to the main figure and the first 20 to the
    # supplementary one; each destination receives its own freshly built entry.
    for slice_limit, destination in ((5, data_dict_main), (20, data_dict_supple)):
        if slice_no < slice_limit:
            destination[slice_no] = {
                "targets": label_subset_proveai.astype(int),
                "preds": (logits_subset_proveai > threshold_select).astype(int),
                "sample_list_list": [sampe_list_plus, sampe_list_minus],
                "main_title": [sub_title_plus, sub_title_minus],
            }

    count = slice_no + 1

In [None]:
# Supplementary audit figure: one row of five example images per slice.
layout_options = dict(example_per_row=5, row_per_slice=1, normalize=True,
                      show_small_box=True, skip_section=0)
legend_options = dict(print_alphabet=True, print_legend_number=True,
                      print_legend_color=True)
font_options = dict(true_pred_count_fontsize=16, fontsize=16,
                    slice_title_fontsize=16)
fig = plot_slice_figure(
    dataset_name=dataset_name,
    data_dict=data_dict_supple,
    prompt_info=variable_dict[dataset_name]["prompt_info"],
    figure_title=None,
    **layout_options,
    **legend_options,
    **font_options,
)
# fig.savefig(log_dir/"plots"/f"model_audit_from2_to1_supple.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/f"model_audit_from2_to1_supple.pdf", bbox_inches='tight')

In [None]:
# Supplementary audit figure: two rows of ten example images per slice.
layout_options = dict(example_per_row=10, row_per_slice=2, normalize=True,
                      show_small_box=True, skip_section=0)
legend_options = dict(print_alphabet=True, print_legend_number=True,
                      print_legend_color=True)
font_options = dict(true_pred_count_fontsize=16, fontsize=16,
                    slice_title_fontsize=16)
fig = plot_slice_figure(
    dataset_name=dataset_name,
    data_dict=data_dict_supple,
    prompt_info=variable_dict[dataset_name]["prompt_info"],
    figure_title=None,
    **layout_options,
    **legend_options,
    **font_options,
)
# fig.savefig(log_dir/"plots"/f"model_audit_from2_to1_supple.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/f"model_audit_from2_to1_supple.pdf", bbox_inches='tight')

In [None]:
fig.savefig(log_dir/"plots"/f"model_audit_ADAE_test.png", bbox_inches='tight')
fig.savefig(log_dir/"plots"/f"model_audit_ADAE_test.pdf", bbox_inches='tight')

# purple pen

In [None]:
variable_dict[dataset_name]["similarity_matrix"]["purple pen"].sort_values(ascending=False).index[:50]

In [None]:
variable_dict[dataset_name]["similarity_matrix"]["purple pen"]

In [None]:
idx_select

In [None]:
label_subset_proveai

In [None]:
logits_subset_proveai.shape

In [None]:
(logits_subset_proveai>threshold_select).astype(int).sum()

In [None]:
1-335/508

In [None]:
label_subset_proveai

In [None]:
idx_benign

In [None]:
idx_all

In [None]:
threshold_select

In [None]:
def get_metrics(y_true, y_pred, y_score=None):
    """Return sensitivity, specificity, AUROC and accuracy for binary labels.

    Parameters
    ----------
    y_true : array-like of 0/1 (or bool) ground-truth labels.
    y_pred : array-like of 0/1 (or bool) hard predictions, aligned with y_true.
    y_score : array-like of continuous scores aligned with y_true, optional.
        Must be passed explicitly; the function no longer reads a global
        ``y_score`` variable, which could belong to an unrelated subset.

    Returns
    -------
    dict with keys "sensitivity", "specificity", "auroc", "accuracy".
    "auroc" is NaN when ``y_score`` is not given or when ``y_true`` contains a
    single class (AUROC is undefined in that case).
    """
    # labels=[0, 1] keeps the matrix 2x2 even if a class is absent, so the
    # four-way unpacking below cannot fail.
    tn, fp, fn, tp = sklearn.metrics.confusion_matrix(
        y_true=y_true, y_pred=y_pred, labels=[0, 1]
    ).ravel()

    if y_score is None or len(np.unique(y_true)) < 2:
        auroc = float("nan")
    else:
        auroc = sklearn.metrics.roc_auc_score(y_true=y_true, y_score=y_score)

    return {
        "sensitivity": tp / (tp+fn),
        "specificity": tn / (tn+fp),
        "auroc": auroc,
        "accuracy": (y_true==y_pred).mean()
    }

In [None]:
prove_logits_true

In [None]:
concept_name

In [None]:
# Rank the selected images by their similarity to one concept.
concept_name = "skincon_Poikiloderma"
concept_scores_selected = variable_dict[dataset_name]["similarity_matrix"][concept_name].loc[idx_select]
ranked_high_to_low = concept_scores_selected.sort_values(ascending=False).index

top_idx = ranked_high_to_low[:100]            # 100 most similar images
bottom_idx = concept_scores_selected.sort_values(ascending=True).index[:100]  # 100 least similar
below_top_idx = ranked_high_to_low[100:]      # everything outside the top 100

In [None]:
# 0/1 indicator per image: is the score above threshold_select?
pred_positive = (prove_logits_true.set_index("image_name")["scores"] > threshold_select).astype(int)

num_top = pred_positive.loc[top_idx]
# Keep the per-image series here. The former trailing `.mean()` reduced
# num_bottom to a single float, so `num_bottom.sum()` / `(1-num_bottom).sum()`
# in the Fisher-test cells were fractions rather than image counts.
num_bottom = pred_positive.loc[below_top_idx]

In [None]:
num_top

In [None]:
fisher_exact([
    [num_top.sum(), (1-num_top).sum()],
    [num_bottom.sum(), (1-num_bottom).sum()]
], alternative='two-sided')

In [None]:
fisher_exact([
    [num_top.sum(), (1-num_top).sum()],
    [num_bottom.sum(), (1-num_bottom).sum()]
], alternative='greater')

In [None]:
(num_top.sum()/ (1-num_top).sum())/(num_bottom.sum()/(1-num_bottom).sum())

In [None]:
fisher_exact([
    [num_top.sum(), (1-num_top).sum()],
    [num_bottom.sum(), (1-num_bottom).sum()]
], alternative='greater')

In [None]:
fisher_exact([
    [num_top.sum(), (1-num_top).sum()],
    [num_bottom.sum(), (1-num_bottom).sum()]
], alternative='less')

In [None]:
concept_name="skincon_Poikiloderma"

In [None]:
# One-sided Fisher exact test for the current concept_name: top-100 images
# (by concept similarity) versus all remaining selected images.
concept_scores_selected = variable_dict[dataset_name]["similarity_matrix"][concept_name].loc[idx_select]
ranked_high_to_low = concept_scores_selected.sort_values(ascending=False).index
top_idx = ranked_high_to_low[:100]
bottom_idx = concept_scores_selected.sort_values(ascending=True).index[:100]
below_top_idx = ranked_high_to_low[100:]

# 0/1 indicator per image: score above threshold_select
num_top = (prove_logits_true.set_index("image_name")["scores"] > threshold_select).astype(int).loc[top_idx]
num_bottom = (prove_logits_true.set_index("image_name")["scores"] > threshold_select).astype(int).loc[below_top_idx]

# rows: top-100 / rest; columns: above threshold / not above threshold
contingency_table = [
    [num_top.sum(), (1 - num_top).sum()],
    [num_bottom.sum(), (1 - num_bottom).sum()],
]
num_test = fisher_exact(contingency_table, alternative='greater')

In [None]:
num_test

In [None]:
fisher_exact([
    [num_top.sum(), num_bottom.sum()],
    [(1-num_top).sum(), (1-num_bottom).sum()]
])

In [None]:
[
    [num_top.sum(), num_bottom.sum()],
    [(1-num_top).sum(), (1-num_bottom).sum()]
]

In [None]:
num_test.statistic

In [None]:
[
    [num_top.sum(), (1-num_top).sum()],
    [num_bottom.sum(), (1-num_bottom).sum()]
]

In [None]:
    num_top=(prove_logits_true.set_index("image_name")["scores"]>threshold_select).astype(int).loc[top_idx]
    num_bottom=(prove_logits_true.set_index("image_name")["scores"]>threshold_select).astype(int).loc[below_top_idx]

In [None]:
num_top.shape

In [None]:
# For every concept column: compare the 100 selected images most similar to the
# concept against the remaining selected images, in terms of (a) how often the
# score exceeds threshold_select and (b) how often the prediction is correct.
record_dict_list_temp = []

y_pos_by_image = pd.Series(variable_dict[dataset_name]["y_pos"], index=variable_dict[dataset_name]["metadata_all"].index)
idx_benign = y_pos_by_image[y_pos_by_image == 0].index  # images whose y_pos label is 0

idx_all = variable_dict[dataset_name]["metadata_all"].index

idx_select = idx_benign

# Loop-invariant series, computed once instead of several times per concept.
prove_by_image = prove_logits_true.set_index("image_name")
pred_positive = (prove_by_image["scores"] > threshold_select).astype(int)  # 1 = above threshold
pred_correct = (pred_positive == prove_by_image["truth"]).astype(int)      # 1 = prediction matches truth

for concept_name in variable_dict[dataset_name]["similarity_matrix"].columns:

    concept_scores_selected = variable_dict[dataset_name]["similarity_matrix"][concept_name].loc[idx_select]
    ranked_high_to_low = concept_scores_selected.sort_values(ascending=False).index
    top_idx = ranked_high_to_low[:100]
    bottom_idx = concept_scores_selected.sort_values(ascending=True).index[:100]
    below_top_idx = ranked_high_to_low[100:]

    num_top = pred_positive.loc[top_idx]
    num_bottom = pred_positive.loc[below_top_idx]

    # rows: top-100 / rest; columns: above threshold / not above threshold.
    # alternative='greater' tests whether the odds of being above threshold
    # are higher in the top-100 group.
    num_test = fisher_exact([
            [num_top.sum(), (1-num_top).sum()],
            [num_bottom.sum(), (1-num_bottom).sum()]
        ], alternative='greater')

    correct_top = pred_correct.loc[top_idx]
    correct_bottom = pred_correct.loc[below_top_idx]

    record_dict_list_temp.append(
        {
            "concept_name": concept_name,
            # fraction of images NOT above threshold in each group
            "prop_top": (1-num_top.mean()),
            "prop_bottom": 1-num_bottom.mean(),
            "prop_diff": -num_top.mean()+num_bottom.mean(),
            "num_test_statistic": num_test.statistic,
            "num_test_pval": num_test.pvalue,
            "accuracy_top": correct_top.mean(),
            "accuracy_bottom": correct_bottom.mean(),
            "accuracy_diff": correct_top.mean()-correct_bottom.mean(),
        }
    )

In [None]:
record_dict_list_temp_df=pd.DataFrame(record_dict_list_temp)
record_dict_list_temp_df[record_dict_list_temp_df["concept_name"].map(lambda x: check_concept_name('proveai', x))]\
.sort_values("prop_diff", ascending=True)\
.rename(columns={"prop_top": "specificity_top100",
                 "prop_bottom": "specificity_rest",
                 "prop_diff": "specificity_diff",
                 "num_test_statistic": "odd ratio",
                 "num_test_pval": "fisher_pval",
                })\
[['concept_name', 'specificity_top100', 'specificity_rest',
       'specificity_diff', 'odd ratio', 'fisher_pval']]\
.iloc[:]

In [None]:
# Imported in this cell because the notebook's own `multitest` import cell is
# located further down, so a top-to-bottom run would otherwise hit a NameError.
from statsmodels.stats import multitest

record_dict_list_temp_df = pd.DataFrame(record_dict_list_temp)

# Keep only the concepts accepted by check_concept_name for 'proveai'.
# .copy() makes the filtered frame independent, so the column assignments
# below do not write into a slice of the unfiltered frame.
record_dict_list_temp_df = record_dict_list_temp_df[
    record_dict_list_temp_df["concept_name"].map(lambda x: check_concept_name('proveai', x))
].copy()

# Bonferroni correction over the retained concepts. multipletests returns
# (reject, corrected p-values, alphacSidak, alphacBonf).
multiple_testing_correction = multitest.multipletests(
    pvals=record_dict_list_temp_df["num_test_pval"].values,
    method="bonferroni"
)
record_dict_list_temp_df["reject"] = multiple_testing_correction[0]
record_dict_list_temp_df["num_test_pval"] = multiple_testing_correction[1]  # now the corrected p-value
record_dict_list_temp_df["alphacSidak"] = multiple_testing_correction[2]
record_dict_list_temp_df["alphacBonf"] = multiple_testing_correction[3]

# Presentation copy: shortened concept names, scientific-notation p-values,
# other numeric columns rounded to three decimals.
record_dict_list_temp_df_latex = record_dict_list_temp_df.copy()
record_dict_list_temp_df_latex["concept_name"] = record_dict_list_temp_df_latex["concept_name"].map(shorten_concept_name)
record_dict_list_temp_df_latex["num_test_pval"] = record_dict_list_temp_df_latex["num_test_pval"].map(lambda x: f"{x:.3e}")

rounded_columns = ["prop_top", "prop_bottom", "num_test_statistic"]
record_dict_list_temp_df_latex[rounded_columns] = record_dict_list_temp_df_latex[rounded_columns].round(3)

table_latex = record_dict_list_temp_df_latex.sort_values("prop_diff", ascending=True)\
.rename(columns={"prop_top": "specificity_top100",
                 "prop_bottom": "specificity_rest",
                 "prop_diff": "specificity_diff",
                 "num_test_statistic": "odd ratio",
                 "num_test_pval": "fisher_pval",
                })\
[['concept_name', 'specificity_top100', "specificity_rest", "odd ratio", 'fisher_pval']].to_latex(index=False)

In [None]:
print(table_latex)

In [None]:
print(table_latex)

In [None]:
record_dict_list_temp_df_latex.sort_values("prop_diff", ascending=True)

In [None]:
print(table_latex)#.replace("e-"," \\times 10^{-"))

In [None]:
record_dict_list_temp_df_latex.sort_values("prop_diff", ascending=True)

In [None]:
0.05/62

In [None]:
multitest.multipletests(
pvals=record_dict_list_temp_df["num_test_pval"].values,
method="bonferroni"
)

In [None]:
record_dict_list_temp_df_latex.shape

In [None]:
record_dict_list_temp_df.shape

In [None]:
from statsmodels.stats import multitest

In [None]:
multitest.multipletests(
pvals=record_dict_list_temp_df["num_test_pval"].values,
method="bonferroni"
)

In [None]:
record_dict_list_temp_df["num_test_pval"]

In [None]:
record_dict_list_temp_df=pd.DataFrame(record_dict_list_temp)
record_dict_list_temp_df[record_dict_list_temp_df["concept_name"].map(lambda x: check_concept_name('proveai', x))]\
.sort_values("accuracy_diff", ascending=True)

In [None]:
# The records built in the per-concept loop have no "num_diff" key; the
# difference column they carry is "prop_diff".
pd.DataFrame(record_dict_list_temp).sort_values("prop_diff", ascending=False)

In [None]:
precision

In [None]:
shorten_concept_name(concept_name, strict=False)

In [None]:
# Build the "top 100" and "bottom 100" image slices for one concept and
# register each slice under the same key in data_dict_main and data_dict_supple.
#
# The two slice dicts differ only in their title prefix and in which end of
# the similarity ranking they take, so they are produced by one loop. The
# disabled `if False:` variant (both panels in a single slice) was dead code
# and has been removed.
similarity_thres = variable_dict[dataset_name]["similarity_matrix"].quantile(0.5, axis=0)
data_dict_main = {}
data_dict_supple = {}
count = 0

idx_select = idx_benign

# concept_name = "skincon_Poikiloderma"
concept_name = "pinkish"

# Image names in idx_select, ordered by descending similarity to the concept.
ranked_image_names = (
    variable_dict[dataset_name]["similarity_matrix"][concept_name]
    .loc[idx_select]
    .sort_values(ascending=False)
    .index
)
# Disabled alternative for a reference panel (100 images drawn from outside the top 100):
#   np.random.RandomState(seed=42).choice(ranked_image_names[100:], 100, replace=False)

logits_by_image = prove_logits_true.set_index("image_name")

for title_prefix, slice_image_names in [
    ("The top 100 images for ", ranked_image_names[:100]),
    ("The bottom 100 images for ", ranked_image_names[-100:]),
]:
    data_test = {
        "targets": logits_by_image["truth"].loc[idx_select].astype(int),
        "preds": (logits_by_image["scores"] > threshold_select).loc[idx_select].astype(int),
        "sample_list_list": [slice_image_names],
        "main_title": [title_prefix + shorten_concept_name(concept_name, strict=False)],
    }
    # Both dictionaries reference the same slice dict object.
    data_dict_main[count] = data_test
    data_dict_supple[count] = data_test
    count += 1

In [None]:
# Display the shortened form of the current concept name (non-strict mode).
shorten_concept_name(concept_name, strict=False)

In [None]:
# Display the metadata_all entry for the current dataset.
variable_dict[dataset_name]["metadata_all"]

In [None]:
# Look up the metadata entry with index label "ISIC_6410859".
variable_dict[dataset_name]["metadata_all"].loc["ISIC_6410859"]

In [None]:
# Count how often each copyright_license value occurs in the metadata.
variable_dict[dataset_name]["metadata_all"]["copyright_license"].value_counts()

In [None]:
# Metadata rows for every image listed in one slice of data_dict_supple.
# The slice key is spelled out here (`i` is not assigned in this cell), and the
# per-panel indexes in "sample_list_list" are flattened into one list of image
# names, because `.loc` needs labels rather than a list of Index objects.
slice_key = 0
shown_image_names = [
    image_name
    for panel_image_names in data_dict_supple[slice_key]["sample_list_list"]
    for image_name in panel_image_names
]
variable_dict[dataset_name]["metadata_all"].loc[shown_image_names]

In [None]:
# Summarise every slice: its key, its panel titles, and the number of images per panel.
# (The previous loop body referenced the undefined name `sds` and raised NameError.)
for slice_key, slice_data in data_dict_supple.items():
    print(
        slice_key,
        slice_data["main_title"],
        [len(panel_image_names) for panel_image_names in slice_data["sample_list_list"]],
    )

In [None]:
# Render the slices currently held in data_dict_supple on a 10-per-row,
# 10-rows-per-slice grid, with panel letters and both legends enabled.
fig = plot_slice_figure(
    dataset_name=dataset_name,
    data_dict=data_dict_supple,
    prompt_info=variable_dict[dataset_name]["prompt_info"],
    # layout
    example_per_row=10,
    row_per_slice=10,
    normalize=True,
    show_small_box=True,
    skip_section=0,
    # annotations and legends
    print_alphabet=True,
    print_legend_number=True,
    print_legend_color=True,
    print_legend_color_idx=2 + 90,
    task_type="melanoma",
    # typography
    true_pred_count_fontsize=16,
    fontsize=16,
    slice_title_fontsize=16,
    figure_title=None,
)
# fig.savefig(log_dir/"plots"/f"model_audit_ADAE_top_bottom_{concept_name}.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/f"model_audit_ADAE_top_bottom_{concept_name}.pdf", bbox_inches='tight')

In [None]:
# Save the current figure as PNG and PDF under log_dir/plots, named after the concept.
plot_dir = log_dir / "plots"
# savefig does not create missing directories, so make sure the target exists.
plot_dir.mkdir(parents=True, exist_ok=True)
for extension in ("png", "pdf"):
    fig.savefig(plot_dir / f"model_audit_ADAE_top_bottom_{concept_name}.{extension}", bbox_inches='tight')

In [None]:
# IPython introspection: show the source of plot_slice_figure (development aid only).
plot_slice_figure??

In [None]:
# Candidate values for `concept_name`:
#   skincon_Atrophy
#   skincon_Poikiloderma
#   pinkish

In [None]:
# Integer-cast labels for the entries selected by top_idx.
label_subset_proveai.astype(int).loc[top_idx]

In [None]:
# Index labels of the 99 entries in idx_select with the lowest similarity to the concept.
# NOTE(review): the slice cells elsewhere take 100 entries — confirm whether 99 is intended.
variable_dict[dataset_name]["similarity_matrix"][concept_name].loc[idx_select].sort_values(ascending=True).index[:99]    

In [None]:
# Print every column name of the similarity matrix, one per line.
# Note: this leaves the loop variable `i` bound to the last column name.
for i in variable_dict[dataset_name]["similarity_matrix"].columns:
    print(i)

In [None]:


# Draw 100 distinct image names from idx_select as a random reference sample.
# The draw is seeded and without replacement, following the
# RandomState(seed=42).choice(..., replace=False) pattern that appears elsewhere
# in this notebook; the unseeded default changed on every run and could repeat images.
# Note: replace=False raises ValueError if idx_select holds fewer than 100 images.
np.random.RandomState(seed=42).choice(
    variable_dict[dataset_name]["similarity_matrix"][concept_name].loc[idx_select].index,
    size=100,
    replace=False,
)

In [None]:
# Sum of the "prediction" column over the images in idx_select.
prove_logits_true.set_index("image_name").loc[idx_select, "prediction"].sum()

In [None]:
# Build a single slice with two panels: the first 100 images scored above
# threshold_select (titled "wrong") and the first 100 scored at or below it
# (titled "correct"). The slice is stored under key 0 in both dictionaries.
similarity_thres = variable_dict[dataset_name]["similarity_matrix"].quantile(0.5, axis=0)
data_dict_main = {}
data_dict_supple = {}
count = 0

# Index labels whose y_pos entry equals 0.
y_pos_by_image = pd.Series(
    variable_dict[dataset_name]["y_pos"],
    index=variable_dict[dataset_name]["metadata_all"].index,
)
idx_select = y_pos_by_image[y_pos_by_image == 0].index

above_threshold = logits_subset_proveai[(logits_subset_proveai > threshold_select)].index[:100]
at_or_below_threshold = logits_subset_proveai[(logits_subset_proveai <= threshold_select)].index[:100]

data_test = {
    "targets": label_subset_proveai.astype(int),
    "preds": (logits_subset_proveai > threshold_select).astype(int),
    "sample_list_list": [above_threshold, at_or_below_threshold],
    "main_title": ["wrong", "correct"],
}

# Both dictionaries reference the same slice dict object.
data_dict_main[count] = data_test
data_dict_supple[count] = data_test

count += 1

In [None]:
# Render the slices currently held in data_dict_supple on a 10-per-row,
# 10-rows-per-slice grid, with panel letters and both legends enabled.
fig = plot_slice_figure(
    dataset_name=dataset_name,
    data_dict=data_dict_supple,
    prompt_info=variable_dict[dataset_name]["prompt_info"],
    # layout
    example_per_row=10,
    row_per_slice=10,
    normalize=True,
    show_small_box=True,
    skip_section=0,
    # annotations and legends
    print_alphabet=True,
    print_legend_number=True,
    print_legend_color=True,
    # typography
    true_pred_count_fontsize=16,
    fontsize=16,
    slice_title_fontsize=16,
    figure_title=None,
)
# fig.savefig(log_dir/"plots"/f"model_audit_from2_to1_supple.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/f"model_audit_from2_to1_supple.pdf", bbox_inches='tight')

In [None]:
# Save the current figure as PNG and PDF under log_dir/plots.
plot_dir = log_dir / "plots"
# savefig does not create missing directories, so make sure the target exists.
plot_dir.mkdir(parents=True, exist_ok=True)
for extension in ("png", "pdf"):
    fig.savefig(plot_dir / f"model_audit_ADAE_test_sorted_wrong_correct.{extension}", bbox_inches='tight')

In [None]:
# Display the current value of `sub_title_plus` (assigned in another cell).
sub_title_plus

In [None]:
# Candidate value for `concept_name`: skincon_Poikiloderma

In [None]:
# Display the per-column 0.5 quantile (median) of the similarity matrix.
similarity_thres

In [None]:
# Show count / accuracy / loss / label_frequent per group, lowest accuracy first.
(
    similarity_info_focus_copy_group_proveai_concept_only
    .sort_values("accuracy")
    [["count", "accuracy", "loss", "label_frequent"]]
)

In [None]:
# Pasted log output (kept for reference; not executable code):
# cluster_idx 0 / metric: 0.047619047619047616 / 0.37662337662337664/ mean metric: 0.3405511811023622
# cluster_idx 1 / metric: 0.1388888888888889 / 0.37662337662337664/ mean metric: 0.3405511811023622
# cluster_idx 2 / metric: 0.21212121212121213 / 0.6140350877192983/ mean metric: 0.3405511811023622
# cluster_idx 3 / metric: 0.2682926829268293 / 0.3488372093023256/ mean metric: 0.3405511811023622
# cluster_idx 0 / metric: 0.047619047619047616 / 0.37662337662337664/ mean metric: 0.3405511811023622
# cluster_idx 1 / metric: 0.1388888888888889 / 0.37662337662337664/ mean metric: 0.3405511811023622
# cluster_idx 2 / metric: 0.21212121212121213 / 0.6140350877192983/ mean metric: 0.3405511811023622
# cluster_idx 3 / metric: 0.2682926829268293 / 0.3488372093023256/ mean metric: 0.3405511811023622
# cluster_idx 4 / metric: 0.3488372093023256 / 0.3488372093023256/ mean metric: 0.3405511811023622
# cluster_idx 5 / metric: 0.3684210526315789 / 0.3684210526315789/ mean metric: 0.3405511811023622
# cluster_idx 6 / metric: 0.37662337662337664 / 0.37662337662337664/ mean metric: 0.3405511811023622
# cluster_idx 7 / metric: 0.3793103448275862 / 0.3793103448275862/ mean metric: 0.3405511811023622
# cluster_idx 8 / metric: 0.4230769230769231 / 0.4230769230769231/ mean metric: 0.3405511811023622
# cluster_idx 9 / metric: 0.6140350877192983 / 0.6140350877192983/ mean metric: 0.3405511811023622
# 508

In [None]:
# Pasted log output (kept for reference; not executable code):
# cluster_idx 0 / metric: 0.047619047619047616 /  mean metric: 0.3405511811023622
# cluster_idx 1 / metric: 0.1388888888888889 /  mean metric: 0.3405511811023622
# cluster_idx 2 / metric: 0.21212121212121213 /  mean metric: 0.3405511811023622
# cluster_idx 3 / metric: 0.2682926829268293 /  mean metric: 0.3405511811023622

In [None]:
# Write the current figure to temp.png in the working directory (tight bounding box).
fig.savefig("temp.png", bbox_inches='tight')

In [None]:
# Render the slices in data_dict_supple compactly: one row of five examples
# per slice, with panel letters and both legends enabled.
fig = plot_slice_figure(
    dataset_name=dataset_name,
    data_dict=data_dict_supple,
    prompt_info=variable_dict[dataset_name]["prompt_info"],
    # layout
    example_per_row=5,
    row_per_slice=1,
    normalize=True,
    show_small_box=True,
    skip_section=0,
    # annotations and legends
    print_alphabet=True,
    print_legend_number=True,
    print_legend_color=True,
    # typography
    true_pred_count_fontsize=16,
    fontsize=16,
    slice_title_fontsize=16,
    figure_title=None,
)
# fig.savefig(log_dir/"plots"/f"model_audit_from2_to1_supple.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/f"model_audit_from2_to1_supple.pdf", bbox_inches='tight')

In [None]:
# Display the full slice dictionary (keys -> targets / preds / sample lists / titles).
data_dict_supple

In [None]:
# IPython introspection: show the source of plot_slice_figure (development aid only).
plot_slice_figure??

In [None]:
# Inspect the first entry of the result list.
# (The original empty subscript `[]` was a SyntaxError; index 0 is assumed — adjust as needed.)
test_result_list_proveai_concept_only[0]

In [None]:
# (sum of targets, sum of preds) for slice 0.
tuple(data_dict_supple[0][field].sum() for field in ("targets", "preds"))

In [None]:
# Shape of the targets stored for slice 0.
data_dict_supple[0]["targets"].shape

In [None]:
# (sum of targets, sum of preds) for slice 1.
tuple(data_dict_supple[1][field].sum() for field in ("targets", "preds"))

In [None]:
# (sum of targets, sum of preds) for slice 2.
tuple(data_dict_supple[2][field].sum() for field in ("targets", "preds"))

In [None]:
# Display the targets stored for slice 2.
# NOTE(review): requires a data_dict_supple with a key 2 — confirm which cell is expected to have built it.
data_dict_supple[2]["targets"]

In [None]:
# Sum of the preds stored for slice 0.
data_dict_supple[0]["preds"].sum()

In [None]:
# Display the whole result list.
test_result_list_proveai_concept_only

In [None]:
# Print classification metrics for every entry of test_result_list_proveai_concept_only:
# confusion-matrix counts, sensitivity, specificity, AUROC (only when both
# classes are present) and accuracy. Entries whose predictions all match the
# ground truth are skipped.
logits_by_image = prove_logits_true.set_index("image_name")  # identical for every entry, so built once

for test_result in test_result_list_proveai_concept_only:
    result_image_names = test_result["labels"].index
    y_true = logits_by_image["truth"].loc[result_image_names].astype(int)
    y_score = logits_by_image["target"].loc[result_image_names]
    # NOTE(review): 0.05 is a hard-coded decision threshold on the "target" column —
    # confirm it agrees with the threshold used in the rest of the notebook.
    y_pred = (y_score > 0.05).astype(int)
    # y_pred = logits_by_image["prediction"].loc[result_image_names]

    if (y_true == y_pred).all():
        continue

    # labels=[0, 1] pins the matrix to 2x2 so the four-way unpacking cannot fail
    # when a class is absent.
    tn, fp, fn, tp = sklearn.metrics.confusion_matrix(
        y_true=y_true, y_pred=y_pred, labels=[0, 1]
    ).ravel()
    print(tn, fp, fn, tp)
    # Report nan instead of dividing by zero when there are no positives / negatives.
    print("sensitivity", tp / (tp + fn) if (tp + fn) > 0 else float("nan"))  # tp / P
    print("specificity", tn / (tn + fp) if (tn + fp) > 0 else float("nan"))  # tn / N
    # AUROC needs both classes. nunique() is used because `~y_true` is a bitwise
    # NOT on integer labels (0 -> -1, 1 -> -2), which is always truthy and would
    # suppress the AUROC for every entry.
    if y_true.nunique() > 1:
        print("auroc", sklearn.metrics.roc_auc_score(y_true=y_true, y_score=y_score))
    print("accuracy", (y_true == y_pred).mean())
    print()

In [None]:
# Render the main-text figure from data_dict_main: one row of five examples
# per slice, with panel letters and legends disabled, and a two-part figure title.
fig = plot_slice_figure(
    dataset_name=dataset_name,
    data_dict=data_dict_main,
    prompt_info=variable_dict[dataset_name]["prompt_info"],
    # layout
    example_per_row=5,
    row_per_slice=1,
    normalize=True,
    show_small_box=True,
    skip_section=0,
    # annotations and legends
    print_alphabet=False,
    print_legend_number=False,
    print_legend_color=False,
    # typography and title
    slice_title_fontsize=27,
    figure_title=("E. ", "Trained at Hosp. Barcelona / Tested at Med U. Vienna"),
)
# fig.savefig(log_dir/"plots"/f"model_audit_from2_to1_main.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/f"model_audit_from2_to1_main.pdf", bbox_inches='tight')