# set working directory

In [None]:
import os

import hydra
import omegaconf
import pyrootutils

# Resolve the project root starting from this notebook's absolute path.
# NOTE(review): `pythonpath=True` presumably also adds the root to sys.path — confirm
# against the pyrootutils documentation.
root = pyrootutils.setup_root(os.path.abspath("inherently_interpretable_model.ipynb"), pythonpath=True)
# NOTE(review): `os` is already imported at the top of this cell; this re-import is redundant.
import os

# Run the rest of the notebook from the project root so relative paths
# (e.g. the `logs` directory used later) resolve against it.
os.chdir(root)

# set python path

In [None]:
import sys

# Make the packages under <root>/src importable (the MONET imports below rely on this).
sys.path.append(str(root / "src"))

# import packages

In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import tqdm
from scipy.special import softmax
from scipy.stats import norm, pearsonr
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from sklearn.model_selection import train_test_split
from torchvision import models, transforms

import clip
from MONET.datamodules.multiplex_datamodule import MultiplexDatamodule
from MONET.utils.loader import custom_collate_per_key, dataloader_apply_func
from MONET.utils.static import (
    concept_to_prompt,
    fitzpatrick17k_disease_label,
    fitzpatrick17k_ninelabel,
    fitzpatrick17k_threelabel,
    skincon_cols,
)
from MONET.utils.text_processing import generate_prompt_token_from_concept
from MONET.utils.io import load_pkl
from PIL import Image

In [None]:
def wandb_to_exppath(wandb, log_path="/gscratch/cse/chanwkim/MONET_log/train/runs"):
    """Locate the experiment directory that belongs to a W&B run ID.

    Every entry of ``log_path`` is treated as one experiment. An experiment
    matches when its ``wandb`` sub-directory contains an entry whose name
    starts with ``run`` and whose last eight characters equal ``wandb``.

    Args:
        wandb: Run ID, compared against the last 8 characters of each run entry.
        log_path: Directory holding one sub-directory per experiment.

    Returns:
        ``Path`` of the first matching experiment (experiments are scanned in
        sorted order so the result is deterministic).

    Raises:
        RuntimeError: If no experiment contains a matching run entry.
    """
    log_path = Path(log_path)
    for experiment in sorted(os.listdir(log_path)):
        wandb_dir = log_path / experiment / "wandb"
        if not wandb_dir.is_dir():
            continue
        # Compare against every run entry: a wandb folder may hold zero or several
        # runs, and os.listdir() gives no ordering guarantee.
        run_ids = {name[-8:] for name in os.listdir(wandb_dir) if name.startswith("run")}
        if wandb in run_ids:
            return log_path / experiment
    raise RuntimeError("not found")


# exppath = wandb_to_exppath(
#     wandb="15eh81uv", log_path="/projects/leelab2/chanwkim/dermatology_datasets/logs/train/runs"
# )
# print([exppath / "checkpoints" / ckpt for ckpt in os.listdir(exppath / "checkpoints/")])

In [None]:
log_dir = Path("logs")

In [None]:
!gpustat

# Initialize Model

In [None]:
# Key into `model_path_dir` (next cell) selecting which checkpoint to load into `model`.
model_name = "zt0n2xd0"
# Device used both for the models and for the image batches in `batch_func`.
model_device = "cuda:6"

# Start from the contrastive model config and override the backbone and device.
cfg_model = omegaconf.OmegaConf.load(root / "configs" / "model" / "contrastive.yaml")
cfg_model.net.model_name_or_path = "ViT-L/14"
cfg_model.net.device = model_device
cfg_model

# Instantiate the model described by the config and switch it to eval mode.
model = hydra.utils.instantiate(cfg_model)
model.to(model_device)
model.eval()

In [None]:
# Checkpoint file for each supported `model_name`.
model_path_dir = {
    "zt0n2xd0": "/projects/leelab2/chanwkim/dermatology_datasets/logs/train/runs/2023-01-17_20-58-15/checkpoints/last.ckpt",
}

# "ViT-L/14" means "keep the weights the model was instantiated with";
# any other name loads the checkpoint's state dict into `model`.
if model_name != "ViT-L/14":
    model_path = model_path_dir[model_name]
    # NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
    loaded = torch.load(model_path, map_location=model_device)
    model.load_state_dict(loaded["state_dict"])
    print('loaded')

In [None]:
# NOTE(review): these two assignments repeat the values set in the model cell above.
model_name = "zt0n2xd0"
model_device = "cuda:6"

# Same config and overrides as for `model`.
cfg_model = omegaconf.OmegaConf.load(root / "configs" / "model" / "contrastive.yaml")
cfg_model.net.model_name_or_path = "ViT-L/14"
cfg_model.net.device = model_device
cfg_model

# Second instance of the same architecture; no checkpoint is loaded into it in this
# notebook section, so it serves as the baseline ("vanilla") model.
model_vanilla = hydra.utils.instantiate(cfg_model)
model_vanilla.to(model_device)
model_vanilla.eval()

In [None]:
variable_dict={}

In [None]:
# Value written to `dataset_name_test` in the multiplex datamodule config
# for every dataset name accepted by `setup_dataloader`.
_DATASET_NAME_TO_TEST_SPEC = {
    "clinical_fd_clean_nodup_nooverlap": "clinical_fd_clean_nodup_nooverlap=all",
    "derm7pt_derm_nodup": "derm7pt_derm_nodup=all",
    "derm7pt_clinical_nodup": "derm7pt_clinical_nodup=all",
    "derm7pt": "derm7pt=all",
    "allpubmedtextbook": "pubmed=all,textbook=all",
    "isic_nodup_nooverlap": "isic_nodup_nooverlap=all",
}


def setup_dataloader(dataset_name):
    """Build the test dataloader for one of the supported datasets.

    Loads ``configs/datamodule/multiplex.yaml``, points it at the dataset's
    test spec, instantiates the datamodule, runs ``setup()`` and returns its
    test dataloader.

    Args:
        dataset_name: One of the keys of ``_DATASET_NAME_TO_TEST_SPEC``.

    Returns:
        ``{"dataloader": <test dataloader>}``.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    # Validate first so an unknown name fails fast with a clear message instead of
    # an UnboundLocalError at the return statement.
    try:
        dataset_name_test = _DATASET_NAME_TO_TEST_SPEC[dataset_name]
    except KeyError:
        raise ValueError(
            f"unknown dataset_name {dataset_name!r}; "
            f"expected one of {sorted(_DATASET_NAME_TO_TEST_SPEC)}"
        ) from None

    cfg_dm = omegaconf.OmegaConf.load(root / "configs" / "datamodule" / "multiplex.yaml")
    cfg_dm.data_dir = "/sdata/chanwkim/dermatology_datasets"
    cfg_dm.dataset_name_test = dataset_name_test
    cfg_dm.split_seed = 42

    dm = hydra.utils.instantiate(cfg_dm)
    dm.setup()

    return {"dataloader": dm.test_dataloader()}

In [None]:
# Create one entry per dataset in `variable_dict` and store its test dataloader there.
# NOTE: `dataset_name` stays bound to the last list element after the loop and is
# read again by a later inspection cell.
for dataset_name in ["clinical_fd_clean_nodup_nooverlap", "derm7pt_derm_nodup",
                     "isic_nodup_nooverlap", 
                    ]:
    variable_dict.setdefault(dataset_name, {})
    variable_dict[dataset_name].update(setup_dataloader(dataset_name))

In [None]:
def batch_func(batch):
    """Encode one batch of images with both models.

    The batch's ``"image"`` entry is replaced in place by a copy on
    ``model_device``; both ``model`` and ``model_vanilla`` then embed the same
    batch without tracking gradients.

    Returns:
        Dict with the CPU image features of each model (``"image_features"``,
        ``"image_features_vanilla"``) and the batch's ``"metadata"`` passed through.
    """
    with torch.no_grad():
        batch["image"] = batch["image"].to(model_device)
        features_by_key = {
            "image_features": model.model_step_with_image(batch)["image_features"],
            "image_features_vanilla": model_vanilla.model_step_with_image(batch)["image_features"],
        }

    result = {key: features.detach().cpu() for key, features in features_by_key.items()}
    result["metadata"] = batch["metadata"]
    return result

def setup_features(dataset_name, dataloader):
    """Collect image features (both models) and metadata for a dataset.

    For ``"isic_nodup_nooverlap"`` the features are read from two files saved
    under ``log_dir / "image_features"`` and ``dataloader`` is not used; for any
    other dataset ``batch_func`` is applied to every batch of ``dataloader``.

    Returns:
        Dict with keys ``"image_features"``, ``"image_features_vanilla"`` and
        ``"metadata_all"``.
    """
    if dataset_name != "isic_nodup_nooverlap":
        applied = dataloader_apply_func(
            dataloader=dataloader,
            func=batch_func,
            collate_fn=custom_collate_per_key,
        )
        features = applied["image_features"].cpu()
        features_vanilla = applied["image_features_vanilla"].cpu()
        metadata = applied["metadata"]
    else:
        feature_dir = log_dir / "image_features"

        saved = torch.load(feature_dir / "isic_nodup_nooverlap.pt", map_location="cpu")
        features = saved["image_features"].cpu()
        metadata = saved["metadata_all"]

        saved_vanilla = torch.load(feature_dir / "isic_nodup_nooverlap_vanilla.pt", map_location="cpu")
        features_vanilla = saved_vanilla["image_features_vanilla"].cpu()
        metadata_vanilla = saved_vanilla["metadata_all"]

        # Both files must describe the same samples in the same order.
        assert np.all(metadata.index == metadata_vanilla.index)

    return {
        "image_features": features,
        "image_features_vanilla": features_vanilla,
        "metadata_all": metadata,
    }

# def setup_features(dataset_name, dataloader):

#     loader_applied = dataloader_apply_func(
#         dataloader=dataloader,
#         func=batch_func,
#         collate_fn=custom_collate_per_key,
#     )
#     image_features = loader_applied["image_features"].cpu()
#     image_features_vanilla = loader_applied["image_features_vanilla"].cpu()
#     metadata_all = loader_applied["metadata"]

#     return {"image_features":image_features, 
#             "image_features_vanilla":image_features_vanilla,
#             "metadata_all": metadata_all}
# Compute (or load) the image features of every dataset and store them, together with
# the metadata, next to the dataloader in `variable_dict`.
for dataset_name in ["clinical_fd_clean_nodup_nooverlap", "derm7pt_derm_nodup",
                     "isic_nodup_nooverlap", 
                    ]:
    print(dataset_name)
    variable_dict[dataset_name].update(setup_features(dataset_name, variable_dict[dataset_name]["dataloader"])) 

In [None]:
def normalize_embedding(dataset_name, image_features):
    """Scale every row of ``image_features`` to unit L2 norm.

    Args:
        dataset_name: Not used by the computation; kept for the callers' signature.
        image_features: 2-D tensor with one embedding per row.

    Returns:
        ``{"image_features_norm": <row-normalised tensor>}``.
    """
    row_norms = image_features.norm(dim=1, keepdim=True)
    normalized = image_features / row_norms
    return {"image_features_norm": normalized}

In [None]:
# Store row-normalised copies of both feature sets for every dataset:
# "image_features_vanilla_norm" (from `model_vanilla`) and "image_features_norm" (from `model`).
for dataset_name in ["clinical_fd_clean_nodup_nooverlap", "derm7pt_derm_nodup",
                     "isic_nodup_nooverlap", 
                    ]:
    variable_dict[dataset_name].update(
        {"image_features_vanilla_norm":normalize_embedding(dataset_name, 
                        variable_dict[dataset_name]["image_features_vanilla"])["image_features_norm"]}
    )           
    variable_dict[dataset_name].update(
        {"image_features_norm":normalize_embedding(dataset_name, 
                        variable_dict[dataset_name]["image_features"])["image_features_norm"]}
    )  
#     variable_dict[dataset_name].update(
#         {"image_features_vanilla_norm":normalize_embedding(dataset_name, 
#                         variable_dict[dataset_name]["image_features_vanilla"])["image_features_norm"]}
#     )    

In [None]:
variable_dict[dataset_name].keys()

In [None]:
# Fallback benign / malignant / indeterminate label per ISIC diagnosis, used when a
# sample has no explicit `benign_malignant` annotation.
isic_diagnosis_malignant_mapping = {
    'AIMP': 'indeterminate',
    'acrochordon': 'benign',
    'actinic keratosis': 'benign',
    'angiofibroma or fibrous papule': 'benign',
    'angiokeratoma': 'benign',
    'angioma': 'benign',
    'atypical melanocytic proliferation': 'indeterminate',
    'atypical spitz tumor': 'indeterminate',
    'basal cell carcinoma': 'malignant',
    'cafe-au-lait macule': 'benign',
    'clear cell acanthoma': 'benign',
    'dermatofibroma': 'benign',
    'lentigo NOS': 'benign',
    'lentigo simplex': 'benign',
    'lichenoid keratosis': 'benign',
    'melanoma': 'malignant',
    'melanoma metastasis': 'malignant',
    'neurofibroma': 'benign',
    'nevus': 'benign',
    'other': 'indeterminate',
    # NOTE(review): the original author marked this entry as uncertain ("??").
    'pigmented benign keratosis': 'benign',
    'scar': 'benign',
    'seborrheic keratosis': 'benign',
    'solar lentigo': 'benign',
    'squamous cell carcinoma': 'malignant',
    'vascular lesion': 'benign',
    'verruca': 'benign',
}


def isic_map_diagnosis_malignant(diagnosis, benign_malignant):
    """Return the benign/malignant label of one ISIC sample.

    Args:
        diagnosis: Diagnosis name, or a missing value (NaN/None).
        benign_malignant: Explicit annotation; used verbatim when it is a string.

    Returns:
        ``benign_malignant`` if it is a string, otherwise the label mapped from
        ``diagnosis``, or ``"indeterminate"`` when ``diagnosis`` is missing.

    Raises:
        RuntimeError: If ``diagnosis`` is present but has no mapping entry.
    """
    if isinstance(benign_malignant, str):
        return benign_malignant
    if diagnosis in isic_diagnosis_malignant_mapping:
        return isic_diagnosis_malignant_mapping[diagnosis]
    # pd.isna accepts any scalar; np.isnan raised TypeError for unmapped strings,
    # which hid the intended RuntimeError below.
    if pd.isna(diagnosis):
        return "indeterminate"
    raise RuntimeError(f"no benign/malignant mapping for ISIC diagnosis {diagnosis!r}")

In [None]:
# Benign / malignant / unknown label per derm7pt diagnosis.
derm7pt_diagnosis_malignant_mapping = {
    'basal cell carcinoma': 'malignant',
    'blue nevus': 'benign',
    'clark nevus': 'benign',
    'combined nevus': 'benign',
    'congenital nevus': 'benign',
    'dermal nevus': 'benign',
    'dermatofibroma': 'benign',
    'lentigo': 'benign',
    'melanoma (in situ)': 'malignant',
    'melanoma (less than 0.76 mm)': 'malignant',
    'melanoma (0.76 to 1.5 mm)': 'malignant',
    'melanoma (more than 1.5 mm)': 'malignant',
    'melanoma metastasis': 'malignant',
    'melanosis': 'benign',
    'miscellaneous': 'unknown',
    'recurrent nevus': 'benign',
    'reed or spitz nevus': 'benign',
    'seborrheic keratosis': 'benign',
    'vascular lesion': 'benign',
    'melanoma': 'malignant',
}


def derm7pt_map_diagnosis_malignant(diagnosis):
    """Return the benign/malignant label of one derm7pt diagnosis.

    Args:
        diagnosis: Diagnosis name, or a missing value (NaN/None).

    Returns:
        The mapped label, or ``"indeterminate"`` when ``diagnosis`` is missing.

    Raises:
        RuntimeError: If ``diagnosis`` is present but has no mapping entry.
    """
    if diagnosis in derm7pt_diagnosis_malignant_mapping:
        return derm7pt_diagnosis_malignant_mapping[diagnosis]
    # pd.isna accepts any scalar; np.isnan raised TypeError for unmapped strings,
    # which hid the intended RuntimeError below.
    if pd.isna(diagnosis):
        return "indeterminate"
    raise RuntimeError(f"no benign/malignant mapping for derm7pt diagnosis {diagnosis!r}")

In [None]:
# Groups DDI disease names into "melanoma" and "melanoma look-alike".
# `set_config` uses it for rows whose `source` is "ddi": the keys define which
# diseases take part in the melanoma task, the values define the positive class.
ddi_map = {
    "acral-melanotic-macule": "melanoma look-alike",
    "atypical-spindle-cell-nevus-of-reed": "melanoma look-alike",
    "benign-keratosis": "melanoma look-alike",
    "blue-nevus": "melanoma look-alike",
    "dermatofibroma": "melanoma look-alike",
    "dysplastic-nevus": "melanoma look-alike",
    "epidermal-nevus": "melanoma look-alike",
    "hyperpigmentation": "melanoma look-alike",
    "keloid": "melanoma look-alike",
    "inverted-follicular-keratosis": "melanoma look-alike",
    "melanocytic-nevi": "melanoma look-alike",
    "melanoma": "melanoma",
    "melanoma-acral-lentiginous": "melanoma",
    "melanoma-in-situ": "melanoma",
    "nevus-lipomatosus-superficialis": "melanoma look-alike",
    "nodular-melanoma-(nm)": "melanoma",
    "pigmented-spindle-cell-nevus-of-reed": "melanoma look-alike",
    "seborrheic-keratosis": "melanoma look-alike",
    "seborrheic-keratosis-irritated": "melanoma look-alike",
    "solar-lentigo": "melanoma look-alike",
}

In [None]:
ddi_map.keys()

In [None]:
# Generic "normal skin" descriptors; the default reference (negative) terms of a concept.
normal_skin = ['clean', "smooth", 'Healthy', 'normal', 'soft', 'flat']

# concept name -> reference terms the concept is contrasted with.
concept_reference_dict = {
    "Asymmetry": ["Symmetry", "Regular", "Uniform"],
    "Irregular": ["Regular", "Smooth"],
    "Black": ["White", "Creamy", "Colorless", "Unpigmented"],
    "Blue": ["Green", "Red"],
    "White": ["Black", "Colored", "Pigmented"],
    "Brown": ["Pale", "White"],
    "Erosion": ["Deposition", "Buildup"],
    "Multiple Colors": ["Single Color", "Unicolor"],
    "Tiny": ["Large", "Big"],
    "Regular": ["Irregular"],
}

# Every derm7pt concept is contrasted with the generic normal-skin terms.
concept_reference_dict.update({
    f"derm7ptconcept_{derm7pt_concept}": list(normal_skin)
    for derm7pt_concept in [
        "pigment network",
        "regression structure",
        "pigmentation",
        "blue whitish veil",
        "vascular structures",
        "streaks",
        "dots and globules",
    ]
})

# SkinCon concepts with hand-picked reference terms; all others use `normal_skin`.
skincon_negative_terms = {
    "skincon_Patch": ["Spotted"],
    "skincon_Exudate": ["Absence"],
    "skincon_Xerosis": ["Moisturized"],
    "skincon_Warty/Papillomatous": ["Smooth"],
    "skincon_Dome-shaped": ["Flat"],
    "skincon_Brown(Hyperpigmentation)": ["Hypopigmentation"],
    "skincon_Translucent": ["Opaque"],
    "skincon_White(Hypopigmentation)": ["Hyperpigmentation"],
    "skincon_Purple": ["Yellow"],
    "skincon_Yellow": ["Purple"],
    "skincon_Black": ["White", "Creamy", "Colorless", "Unpigmented"],
    "skincon_Lichenification": ["Softening"],
    "skincon_Blue": ["Orange"],
    "skincon_Gray": ["Colorful"],
}

for concept_name in skincon_cols:
    # Copy so no two entries share one list object.
    negative_terms = list(skincon_negative_terms.get(concept_name, normal_skin))
    concept_reference_dict[concept_name] = negative_terms


def _with_cbm_prefix(name):
    """Prefix ``name`` with ``cbm_`` unless it is a skincon_/derm7ptconcept_ name."""
    if ("skincon_" in name) or ("derm7ptconcept_" in name):
        return name
    return f"cbm_{name}"


# Apply the prefix to keys and to every reference term. The previous version put
# the whole term list (`value`) in place of the single term (`v`) whenever a
# reference term already carried a skincon_/derm7ptconcept_ prefix.
concept_reference_dict = {
    _with_cbm_prefix(key): [_with_cbm_prefix(v) for v in value]
    for key, value in concept_reference_dict.items()
}

print()
for key in concept_reference_dict.keys():
    print(f"{key}: {concept_reference_dict[key]}")

In [None]:
# Free-text terms that become lower-cased `cbm_` concepts for the clinical dataset.
# Duplicates are harmless: the list is de-duplicated with np.unique.
_CBM_TERM_GROUPS = [
    ['clean', "smooth", 'Healthy', 'normal', 'soft', 'flat'],
    ["Asymmetry", "Symmetry", "Regular", "Uniform"],
    ["Irregular", "Regular", "Smooth"],
    ["Black", "White", "Creamy", "Colorless", "Unpigmented"],
    ["Blue", "Green", "Red"],
    ["White", "Black", "Colored", "Pigmented"],
    ["Brown", "Pale", "White"],
    ["Erosion", "Deposition", "Buildup"],
    ["Multiple colors", "Single Color", "Unicolor"],
    ["Tiny", "Large", "Big"],
    ["Regular", "Irregular"],
]

# Un-prefixed concepts shared by the ISIC and derm7pt datasets.
_SHARED_PLAIN_CONCEPTS = [
    "purple pen",
    "finger",
    "nail",
    "pinkish",
    "red",
    "hair",
    "orange sticker",
    "blue sticker",
    "red sticker",
    "dermoscope border",
    "gel",
    "malignant",
    "melanoma",
]

# Names that receive the `derm7ptconcept_` prefix.
_DERM7PT_CONCEPTS = [
    "pigment network", "typical pigment network", "atypical pigment network",
    "regression structure",
    "pigmentation", "regular pigmentation", "irregular pigmentation",
    "blue whitish veil",
    "vascular structures", "typical vascular structures", "atypical vascular structures",
    "streaks", "regular streaks", "irregular streaks",
    "dots and globules", "regular dots and globules", "irregular dots and globules",
]

# Names that receive the `isicconcept_` prefix.
_ISIC_CONCEPTS = [
    "pigment_network",
    "negative_network",
    "milia_like_cyst",
    "streaks",
    "globules",
]

# Disease names that receive the `disease_` prefix, per dataset family.
_ISIC_DISEASE_NAMES = [
    'seborrheic keratosis', 'nevus', 'squamous cell carcinoma',
    'melanoma', 'lichenoid keratosis', 'lentigo',
    'actinic keratosis', 'basal cell carcinoma', 'dermatofibroma',
    'atypical melanocytic proliferation', 'verruca',
    'clear cell acanthoma', 'angiofibroma or fibrous papule', 'scar',
    'angioma', 'atypical spitz tumor', 'solar lentigo', 'AIMP',
    'neurofibroma', 'lentigo simplex', 'acrochordon',
    'angiokeratoma', 'vascular lesion', 'cafe-au-lait macule',
    'pigmented benign keratosis',
]
_DERM7PT_DISEASE_NAMES = [
    'basal cell carcinoma', 'blue nevus', 'clark nevus',
    'combined nevus', 'congenital nevus', 'dermal nevus',
    'dermatofibroma', 'lentigo', 'melanoma', 'melanosis',
    'recurrent nevus', 'reed or spitz nevus',
    'seborrheic keratosis', 'vascular lesion',
]


def _reference_concept_names():
    """Sorted unique non-skincon names used as keys or terms in concept_reference_dict."""
    names = list(concept_reference_dict.keys()) + [
        term for terms in concept_reference_dict.values() for term in terms
    ]
    return [name for name in np.unique(names) if "skincon_" not in name]


def _clinical_concept_list():
    """Concept list for the clinical (Fitzpatrick17k + DDI) dataset."""
    cbm_lowercase = np.unique(
        [f"cbm_{term.lower()}" for group in _CBM_TERM_GROUPS for term in group]
    ).tolist()
    concepts = list(skincon_cols) + cbm_lowercase + _reference_concept_names()
    # derm7pt concepts are not evaluated on the clinical dataset.
    return [name for name in concepts if "derm7ptconcept" not in name]


def _dermoscopy_concept_list(disease_names):
    """Concept list shared by ISIC and derm7pt; only the disease names differ."""
    return (
        list(skincon_cols)
        + list(_SHARED_PLAIN_CONCEPTS)
        + [f"derm7ptconcept_{name}" for name in _DERM7PT_CONCEPTS]
        + [f"isicconcept_{name}" for name in _ISIC_CONCEPTS]
        + [f"disease_{name}" for name in disease_names]
        + _reference_concept_names()
    )


def _clinical_labels(metadata_all):
    """Labels and validity masks for the clinical (Fitzpatrick17k + DDI) dataset."""
    is_fitz = metadata_all["source"] == "fitz"
    is_ddi = metadata_all["source"] == "ddi"
    not_excluded = metadata_all["skincon_Do not consider this image"] != 1
    nine_label = metadata_all["nine_partition_label"]
    ddi_melanoma_diseases = [disease for disease, group in ddi_map.items() if group == "melanoma"]

    y_pos_malignant = (
        (is_fitz & (metadata_all["three_partition_label"] == "malignant"))
        # Element-wise comparison on a Series; `is True` would not work here.
        | (is_ddi & (metadata_all["malignant"] == True))  # noqa: E712
    )
    # `isin` replaces a row-wise apply; missing or unmapped diseases are False.
    y_pos_melanoma = (
        (is_fitz & (nine_label == "malignant melanoma"))
        | (is_ddi & metadata_all["disease"].isin(ddi_melanoma_diseases))
    )
    # The melanoma task keeps melanoma plus its look-alikes only.
    fitz_melanoma_task = is_fitz & (
        (nine_label == 'malignant melanoma')
        | (nine_label == 'benign melanocyte')
        | (metadata_all["label"] == 'seborrheic keratosis')
        | (metadata_all["label"] == 'dermatofibroma')
    )
    ddi_melanoma_task = is_ddi & metadata_all["disease"].isin(list(ddi_map))

    return {
        "y_pos_malignant": y_pos_malignant.values,
        "y_pos_melanoma": y_pos_melanoma.values,
        "valid_idx_malignant": not_excluded.values,
        "valid_idx_melanoma": (not_excluded & (fitz_melanoma_task | ddi_melanoma_task)).values,
    }


def _dermoscopy_labels(metadata_all):
    """Labels and masks for ISIC/derm7pt; expects ``benign_malignant_full`` to be set.

    Adds the ``benign_malignant_bool`` column to ``metadata_all`` in place.
    """
    label_full = metadata_all["benign_malignant_full"]
    metadata_all["benign_malignant_bool"] = label_full.str.contains("malignant")
    diagnosis = metadata_all["diagnosis"]
    return {
        "y_pos_malignant": metadata_all["benign_malignant_bool"].values,
        # Fill missing diagnoses first so the result is a plain boolean array; the
        # derm7pt branch used to skip this and produced NaN labels.
        "y_pos_melanoma": diagnosis.fillna('null').str.contains("melanoma").values,
        "valid_idx_malignant": (
            label_full.str.contains("malignant") | label_full.str.contains("benign")
        ).values,
        "valid_idx_melanoma": (~diagnosis.isnull()).values,
    }


def set_config(dataset_name, metadata_all):
    """Derive task labels, validity masks and the concept list for a dataset.

    The dataset family is chosen by substring, in this order:
    ``"clinical_fd_clean"``, ``"isic"``, ``"derm7pt"``. For ISIC and derm7pt the
    columns ``benign_malignant_full`` and ``benign_malignant_bool`` are added to
    ``metadata_all`` in place. Summary counts of every array are printed.

    Args:
        dataset_name: Dataset identifier containing one of the substrings above.
        metadata_all: Per-sample metadata DataFrame of that dataset.

    Returns:
        Dict with ``valid_idx_malignant``, ``valid_idx_melanoma``,
        ``y_pos_malignant``, ``y_pos_melanoma`` (arrays with one entry per row),
        ``metadata_all`` (the same, possibly extended, DataFrame) and
        ``concept_list``.

    Raises:
        ValueError: If ``dataset_name`` matches no known dataset family.
    """
    if "clinical_fd_clean" in dataset_name:
        labels = _clinical_labels(metadata_all)
        concept_list = _clinical_concept_list()
    elif "isic" in dataset_name:
        metadata_all["benign_malignant_full"] = metadata_all.apply(
            lambda row: isic_map_diagnosis_malignant(row["diagnosis"], row["benign_malignant"]),
            axis=1,
        )
        labels = _dermoscopy_labels(metadata_all)
        concept_list = _dermoscopy_concept_list(_ISIC_DISEASE_NAMES)
    elif "derm7pt" in dataset_name:
        metadata_all["benign_malignant_full"] = metadata_all.apply(
            lambda row: derm7pt_map_diagnosis_malignant(row["diagnosis"]),
            axis=1,
        )
        labels = _dermoscopy_labels(metadata_all)
        concept_list = _dermoscopy_concept_list(_DERM7PT_DISEASE_NAMES)
    else:
        # Previously this fell through to an UnboundLocalError in the prints below.
        raise ValueError(
            f"unknown dataset_name {dataset_name!r}; expected a name containing "
            "'clinical_fd_clean', 'isic' or 'derm7pt'"
        )

    y_pos_malignant = labels["y_pos_malignant"]
    y_pos_melanoma = labels["y_pos_melanoma"]
    valid_idx_malignant = labels["valid_idx_malignant"]
    valid_idx_melanoma = labels["valid_idx_melanoma"]

    print("y_pos_malignant", y_pos_malignant.sum(), "null", np.isnan(y_pos_malignant).sum())
    print("valid_idx_malignant", valid_idx_malignant.sum(), np.isnan(valid_idx_malignant).sum())
    print("y_pos_melanoma", y_pos_melanoma.sum(), np.isnan(y_pos_melanoma).sum())
    print("valid_idx_melanoma", valid_idx_melanoma.sum(), np.isnan(valid_idx_melanoma).sum())

    return {"valid_idx_malignant": valid_idx_malignant,
            "valid_idx_melanoma": valid_idx_melanoma,
            "y_pos_malignant": y_pos_malignant,
            "y_pos_melanoma": y_pos_melanoma,
            "metadata_all": metadata_all,
            "concept_list": concept_list}

In [None]:
# Derive labels, validity masks and the concept list of every dataset and merge them
# into its `variable_dict` entry ("metadata_all" is replaced by the returned frame).
for dataset_name in ["clinical_fd_clean_nodup_nooverlap", "derm7pt_derm_nodup", "isic_nodup_nooverlap"]:
    print(dataset_name)
    variable_dict[dataset_name].update(set_config(dataset_name, variable_dict[dataset_name]["metadata_all"]))
    print()

In [None]:
# Number of clinical samples usable for the malignancy task.
variable_dict["clinical_fd_clean_nodup_nooverlap"]["valid_idx_malignant"].sum()

In [None]:
# Number of clinical samples usable for the melanoma task.
variable_dict["clinical_fd_clean_nodup_nooverlap"]["valid_idx_melanoma"].sum()

In [None]:
# Frequency of each `md5hash` value (missing values counted as "null") among the
# samples valid for the malignancy task.
variable_dict["clinical_fd_clean_nodup_nooverlap"]["metadata_all"]\
[variable_dict["clinical_fd_clean_nodup_nooverlap"]["valid_idx_malignant"]]["md5hash"].fillna("null").value_counts()

In [None]:
# Total of those `md5hash` counts among the samples valid for the melanoma task.
variable_dict["clinical_fd_clean_nodup_nooverlap"]["metadata_all"]\
[variable_dict["clinical_fd_clean_nodup_nooverlap"]["valid_idx_melanoma"]]["md5hash"].fillna("null").value_counts().sum()

In [None]:
# NOTE(review): scratch arithmetic; presumably relates to the sample counts above — confirm or remove.
769-275

In [None]:
# NOTE(review): scratch arithmetic; presumably relates to the sample counts above — confirm or remove.
494+275

In [None]:
# Number of derm7pt dermoscopy samples usable for the malignancy task.
variable_dict["derm7pt_derm_nodup"]["valid_idx_malignant"].sum()

In [None]:
# Prompt wording for clinical photographs.
_CLINICAL_PROMPT_TEMPLATES = (
    "This is skin image of {}",
    "This is dermatology image of {}",
    "This is image of {}",
)
_CLINICAL_REFERENCE_PROMPTS = (
    "This is skin image",
    "This is dermatology image",
    "This is image",
)

# Prompt wording for dermoscopic images (shared by the isic and derm7pt datasets).
_DERMOSCOPY_PROMPT_TEMPLATES = (
    "This is dermatoscopy of {}",
    "This is dermoscopy of {}",
)
_DERMOSCOPY_REFERENCE_PROMPTS = (
    "This is dermatoscopy",
    "This is dermoscopy",
)

# Terms substituted into the dermoscopy templates for each "derm7ptconcept_*" name.
_DERM7PT_CONCEPT_TERMS = {
    "pigment network": ["pigment network", "intersecting brown lines"],
    "typical pigment network": ["typical pigment network", "regularly meshed pigment network"],
    "atypical pigment network": ["atypical pigment network", "irregularly meshed pigment network"],
    "regression structure": ["regression structure"],
    "pigmentation": ["pigmented", "pigmented lesion", "colored lesion"],
    "regular pigmentation": ["regular pigmentation", "uniform and consistent coloration"],
    "irregular pigmentation": ["irregular pigmentation"],
    "blue whitish veil": ["blue whitish veil", "blue white veil"],
    "vascular structures": [
        "vascular structures",
        "Hairpin vessels",
        "Comma vessels",
        "dotted vessels",
        "arborizing vessels",
    ],
    "typical vascular structures": ["typical vascular structures"],
    "atypical vascular structures": ["atypical vascular structures"],
    "streaks": [
        "streaks",
        "starburst",
        "linear patterns",
        "radially oriented linear projections",
        "regular, pigmented linear extensions",
        "irregular, pigmented linear extensions",
    ],
    "regular streaks": ["regular streaks", "uniformly spaced linear patterns"],
    "irregular streaks": ["irregular streaks"],
    "dots and globules": ["black dots and globules", "brown dots and globules", "scattered globules"],
    "regular dots and globules": ["regular dots and globules"],
    "irregular dots and globules": ["irregular dots and globules"],
}

# Terms substituted into the dermoscopy templates for each "isicconcept_*" name.
_ISIC_CONCEPT_TERMS = {
    "pigment_network": ["pigment network"],
    "negative_network": ["negative network"],
    # Prompted via the associated diagnosis rather than the literal concept name.
    "milia_like_cyst": ["seborrheic keratosis"],
    "streaks": [
        "streaks",
        "starburst",
        "linear patterns",
        "radially oriented linear projections",
        "regular, pigmented linear extensions",
        "irregular, pigmented linear extensions",
    ],
    "globules": ["globules"],
}


def _skincon_concept_terms(concept_name):
    """Return the unique, lower-cased engineered terms for a concept name with an 8-character prefix."""
    prompt_dict, _ = concept_to_prompt(concept_name[8:])
    engineered = [
        prompt
        for key, prompts in prompt_dict.items()
        if key != "original"
        for prompt in prompts
    ]
    # Strip the sentence scaffolding so only the descriptive term remains.
    return list(
        {
            prompt.replace("This is ", "")
            .replace("This photo is ", "")
            .replace("This lesion is ", "")
            .replace("skin has become ", "")
            .lower()
            for prompt in engineered
        }
    )


def _one_group_per_template(templates, terms):
    """Return one prompt group per template, each holding every term."""
    return [[template.format(term) for term in terms] for template in templates]


def _one_pooled_group(templates, terms):
    """Return a single prompt group holding every template/term combination (template-major)."""
    return [[template.format(term) for template in templates for term in terms]]


def _clinical_prompts(concept_name):
    """Build (target groups, reference groups) for a clinical-photo dataset."""
    if concept_name.startswith("cbm_"):
        terms = [concept_name[4:]]
    else:
        # Every non-"cbm_" concept is treated as "skincon_<name>".
        terms = _skincon_concept_terms(concept_name)
    prompt_target = _one_group_per_template(_CLINICAL_PROMPT_TEMPLATES, terms)
    prompt_ref = [[prompt] for prompt in _CLINICAL_REFERENCE_PROMPTS]
    return prompt_target, prompt_ref


def _dermoscopy_prompts(concept_name, expand_aimp):
    """Build (target groups, reference groups) for a dermoscopy dataset.

    ``expand_aimp`` enables the extra spelled-out prompt for "disease_AIMP".
    """
    templates = _DERMOSCOPY_PROMPT_TEMPLATES
    pooled_ref = [list(_DERMOSCOPY_REFERENCE_PROMPTS)]
    per_template_ref = [[prompt] for prompt in _DERMOSCOPY_REFERENCE_PROMPTS]

    if concept_name.startswith("skincon_"):
        return _one_pooled_group(templates, _skincon_concept_terms(concept_name)), pooled_ref

    if concept_name.startswith("derm7ptconcept_"):
        derm7ptconcept = concept_name[15:]
        if derm7ptconcept not in _DERM7PT_CONCEPT_TERMS:
            raise ValueError(derm7ptconcept)
        return _one_pooled_group(templates, _DERM7PT_CONCEPT_TERMS[derm7ptconcept]), pooled_ref

    if concept_name.startswith("isicconcept_"):
        isicconcept = concept_name[12:]
        if isicconcept not in _ISIC_CONCEPT_TERMS:
            raise ValueError(isicconcept)
        return _one_pooled_group(templates, _ISIC_CONCEPT_TERMS[isicconcept]), pooled_ref

    if concept_name.startswith("cbm_"):
        # NOTE(review): two dermoscopy target groups are paired with the three
        # clinical reference groups here; kept exactly as in the original —
        # TODO confirm this mismatch is intended.
        return (
            _one_group_per_template(templates, [concept_name[4:]]),
            [[prompt] for prompt in _CLINICAL_REFERENCE_PROMPTS],
        )

    if concept_name.startswith("disease_"):
        if expand_aimp and concept_name == "disease_AIMP":
            terms = ["AIMP", "Atypical intraepidermal melanocytic proliferation"]
        else:
            terms = [concept_name[8:]]
        return _one_group_per_template(templates, terms), per_template_ref

    if concept_name == "gel":
        return _one_group_per_template(templates, ["gel"]), per_template_ref

    if concept_name == "dermoscope border":
        return [["This is dermatoscopy of dermoscopy"]], [["This is dermatoscopy"]]

    # Any other concept name is used verbatim as the term.
    return _one_group_per_template(templates, [concept_name]), per_template_ref


def _build_concept_prompts(dataset_name, concept_name):
    """Choose the prompt builder from substrings of ``dataset_name``."""
    if "clinical_fd_clean" in dataset_name:
        return _clinical_prompts(concept_name)
    if "isic" in dataset_name:
        return _dermoscopy_prompts(concept_name, expand_aimp=True)
    if "derm7pt_derm" in dataset_name:
        return _dermoscopy_prompts(concept_name, expand_aimp=False)
    raise ValueError(dataset_name)


def _embed_prompt_groups(prompt_groups, clip_model):
    """Tokenize and encode each prompt group, then normalize along dim 2.

    Relies on the notebook-level ``model_device`` for tensor placement.
    """
    tokenized_groups = [clip.tokenize(group, truncate=True) for group in prompt_groups]
    with torch.no_grad():
        embedding = torch.stack(
            [
                clip_model.model_step_with_text({"text": tokens.to(model_device)})[
                    "text_features"
                ].detach().cpu()
                for tokens in tokenized_groups
            ]
        )
    return embedding / embedding.norm(dim=2, keepdim=True)


def get_concept_embedding(dataset_name, concept_list, clip_model):
    """Embed target and reference text prompts for every concept.

    For each concept a list of *target* prompt groups (prompts mentioning the
    concept) and *reference* prompt groups (prompts without it) is built; the
    wording depends on the dataset family and on the concept-name prefix. Each
    group is tokenized with ``clip.tokenize`` and encoded with
    ``clip_model.model_step_with_text``; the stacked features are normalized
    along dim 2. The prompts used are printed for every concept.

    Args:
        dataset_name: Dataset identifier. It is matched by substring, in this
            order: "clinical_fd_clean", "isic", "derm7pt_derm".
        concept_list: Iterable of concept names (e.g. "skincon_...",
            "derm7ptconcept_...", "isicconcept_...", "cbm_...", "disease_...").
        clip_model: Object exposing ``model_step_with_text(batch)`` whose
            result contains a "text_features" tensor.

    Returns:
        ``{"prompt_info": {concept_name: {"prompt_ref_embedding_norm": tensor,
        "prompt_target_embedding_norm": tensor}}}``.

    Raises:
        ValueError: for an unknown "derm7ptconcept_"/"isicconcept_" name, or
            when ``dataset_name`` matches none of the supported families
            (previously this surfaced as an UnboundLocalError).

    Note:
        Tokens are moved to the notebook-level ``model_device``; it is assumed
        to be the device ``clip_model`` lives on (TODO confirm). The fallback
        branch for otherwise-unmatched dermoscopy concepts now emits nested
        prompt groups like every other branch; tokenizing a bare string and a
        one-element list is expected to give the same tokens, so only the
        printed representation changes.
    """
    prompt_info = {}

    for concept_name in concept_list:
        prompt_target, prompt_ref = _build_concept_prompts(dataset_name, concept_name)

        prompt_target_embedding_norm = _embed_prompt_groups(prompt_target, clip_model)
        prompt_ref_embedding_norm = _embed_prompt_groups(prompt_ref, clip_model)

        prompt_info[concept_name] = {
            "prompt_ref_embedding_norm": prompt_ref_embedding_norm,
            "prompt_target_embedding_norm": prompt_target_embedding_norm,
        }
        print(dataset_name, concept_name, "\n", prompt_target, prompt_ref)
        print('-----------')

    return {"prompt_info": prompt_info}

In [None]:
# For each dataset, embed the prompts of its concept_list twice — once with
# model_vanilla (stored under "prompt_info_vanilla") and once with model
# (stored under "prompt_info").
for dataset_name in ["clinical_fd_clean_nodup_nooverlap", "derm7pt_derm_nodup",
                     "isic_nodup_nooverlap", 
                    ]:

    variable_dict[dataset_name].update(
        {"prompt_info_vanilla":get_concept_embedding(dataset_name, 
                      concept_list=variable_dict[dataset_name]["concept_list"],
                             clip_model=model_vanilla)["prompt_info"]})  
        
    variable_dict[dataset_name].update(
        {"prompt_info":get_concept_embedding(dataset_name, 
                      concept_list=variable_dict[dataset_name]["concept_list"],
                             clip_model=model)["prompt_info"]})  

In [None]:
# IPython shell escape: long listing of logs/experiment_results/, oldest first (newest last).
!ls logs/experiment_results/ -trl

In [None]:
# torch.save(variable_dict, "logs/experiment_results/cbm_variables_1003.pt")

In [None]:
# Load the cached variable_dict from disk, mapping any tensors onto the CPU.
# NOTE(review): torch.load unpickles the file, so only load files from a trusted source.
# NOTE(review): recent PyTorch releases default to weights_only=True, which may reject
# non-tensor payloads such as the pandas metadata kept in this dict — TODO confirm and
# pass weights_only explicitly if loading fails.
variable_dict = torch.load("logs/experiment_results/cbm_variables_1003.pt", map_location="cpu")

In [None]:
# def calculate_similaity_score(image_features_norm, 
#                               prompt_target_embedding_norm,
#                               prompt_ref_embedding_norm,
#                               temp=1,
#                               normalize=True):

#     target_similarity=prompt_target_embedding_norm.float()@image_features_norm.T.float()
#     ref_similarity=prompt_ref_embedding_norm.float()@image_features_norm.T.float()
    
    
#     target_similarity_mean=target_similarity.mean(dim=[1])
#     ref_similarity_mean=ref_similarity.mean(axis=1)
         
#     if normalize:
#         similarity_score=scipy.special.softmax([target_similarity_mean.numpy()/temp, 
#                             ref_similarity_mean.numpy()/temp], axis=0)[0,:].mean(axis=0)   
#     else:
#         similarity_score=target_similarity_mean.mean(axis=0)

    
#     return similarity_score

In [None]:
# Show the top-level keys of variable_dict.
variable_dict.keys()

In [None]:
# Value counts of "skincon_Pigmented" over the clinical rows where it is not null.
# The fillna(-9) is a no-op here because null rows were already filtered out.
variable_dict["clinical_fd_clean_nodup_nooverlap"]["metadata_all"][
~variable_dict["clinical_fd_clean_nodup_nooverlap"]["metadata_all"]["skincon_Pigmented"].isnull()]\
["skincon_Pigmented"].fillna(-9).value_counts()

# manual annotation

In [None]:
# derm7pt annotation rules: concept name -> (metadata column, rule kind, rule argument).
#   "present"  : value differs from "absent"
#   "equals"   : value equals the argument
#   "contains" : value contains the argument as a substring
#   "isin"     : value is one of the listed arguments
_DERM7PT_ANNOTATION_RULES = {
    "derm7ptconcept_pigment network": ("pigment_network", "present", None),
    "derm7ptconcept_typical pigment network": ("pigment_network", "equals", "typical"),
    "derm7ptconcept_atypical pigment network": ("pigment_network", "equals", "atypical"),
    "derm7ptconcept_regression structure": ("regression_structures", "present", None),
    "derm7ptconcept_pigmentation": ("pigmentation", "present", None),
    # The leading space keeps " regular" from matching inside "irregular".
    "derm7ptconcept_regular pigmentation": ("pigmentation", "contains", " regular"),
    "derm7ptconcept_irregular pigmentation": ("pigmentation", "contains", " irregular"),
    "derm7ptconcept_blue whitish veil": ("blue_whitish_veil", "present", None),
    "derm7ptconcept_vascular structures": ("vascular_structures", "present", None),
    "derm7ptconcept_typical vascular structures": (
        "vascular_structures",
        "isin",
        ["within regression", "arborizing", "comma", "hairpin", "wreath"],
    ),
    "derm7ptconcept_atypical vascular structures": (
        "vascular_structures",
        "isin",
        ["dotted", "linear irregular"],
    ),
    "derm7ptconcept_streaks": ("streaks", "present", None),
    "derm7ptconcept_regular streaks": ("streaks", "equals", "regular"),
    "derm7ptconcept_irregular streaks": ("streaks", "equals", "irregular"),
    "derm7ptconcept_dots and globules": ("dots_and_globules", "present", None),
    "derm7ptconcept_regular dots and globules": ("dots_and_globules", "equals", "regular"),
    "derm7ptconcept_irregular dots and globules": ("dots_and_globules", "equals", "irregular"),
}

# ISIC concept name -> metadata column holding its numeric annotation.
_ISIC_ANNOTATION_COLUMNS = {
    "isicconcept_pigment_network": "pigment_network",
    "isicconcept_negative_network": "negative_network",
    "isicconcept_milia_like_cyst": "milia_like_cyst",
    "isicconcept_streaks": "streaks",
    "isicconcept_globules": "globules",
}

# An ISIC concept counts as present when its annotation value exceeds this.
_ISIC_PRESENCE_THRESHOLD = 30


def _derm7pt_concept_present(metadata_all, concept_name):
    """Return the boolean Series marking rows where a derm7pt concept is annotated."""
    rule = _DERM7PT_ANNOTATION_RULES.get(concept_name)
    if rule is None:
        raise ValueError(concept_name)
    column, kind, argument = rule
    values = metadata_all[column]
    if kind == "present":
        return values != "absent"
    if kind == "equals":
        return values == argument
    if kind == "contains":
        # na=False keeps the result strictly boolean so it can be negated
        # even when the column has missing entries.
        return values.str.contains(argument, na=False)
    return values.isin(argument)


def get_concept_bool_from_metadata(dataset_name, metadata_all, concept_name):
    """Derive positive/negative masks for a manually annotated concept.

    Args:
        dataset_name: Dataset identifier, matched by substring in this order:
            "derm7pt_derm", "isic", "clinical_fd_clean".
        metadata_all: pandas DataFrame holding the annotation columns.
        concept_name: Concept to look up ("derm7ptconcept_...",
            "isicconcept_..." or "skincon_...", depending on the dataset).

    Returns:
        ``{"concept_bool_true": Series, "concept_bool_false": Series}``.
        For derm7pt and ISIC the two masks are complements of each other; for
        the clinical dataset they are ``column == 1`` and ``column == 0``, so
        rows with a missing annotation are in neither.

    Raises:
        ValueError: if the concept is not known for the dataset, or if
            ``dataset_name`` matches none of the supported datasets.

    Notes:
        The ISIC branch previously never assigned the returned masks and so
        always failed with UnboundLocalError; it now follows the derm7pt
        convention (negative mask = complement of the positive mask).
        NOTE(review): rows with a missing annotation are not excluded in the
        derm7pt/ISIC branches — a missing derm7pt value compares unequal to
        "absent" and is therefore counted as present by the "present" rules,
        and ISIC rows holding the -9 placeholder fall into the negative mask.
        The original computed a ``valid_idx`` mask for this but never applied
        it; confirm whether such rows should be dropped.
    """
    if "derm7pt_derm" in dataset_name:
        concept_bool = _derm7pt_concept_present(metadata_all, concept_name)
        concept_bool_true = concept_bool
        concept_bool_false = ~concept_bool

    elif "isic" in dataset_name:
        column = _ISIC_ANNOTATION_COLUMNS.get(concept_name)
        if column is None:
            raise ValueError(concept_name)
        concept_bool = metadata_all[column] > _ISIC_PRESENCE_THRESHOLD
        concept_bool_true = concept_bool
        concept_bool_false = ~concept_bool

    elif "clinical_fd_clean" in dataset_name:
        if not concept_name.startswith("skincon_"):
            raise ValueError(concept_name)
        concept_bool_true = metadata_all[concept_name] == 1
        concept_bool_false = metadata_all[concept_name] == 0

    else:
        raise ValueError(dataset_name)

    return {
        "concept_bool_true": concept_bool_true,
        "concept_bool_false": concept_bool_false,
    }

In [None]:
def _manual_concept_matrix(dataset_name, metadata, concept_list):
    """Stack the positive-annotation mask of every concept into a 2-D feature array."""
    columns = []
    for concept_name in concept_list:
        concept_bool = get_concept_bool_from_metadata(
            dataset_name=dataset_name,
            metadata_all=metadata,
            concept_name=concept_name,
        )
        assert len(concept_bool["concept_bool_true"]) == len(concept_bool["concept_bool_false"])
        # The original also evaluated (and discarded) a comparison of the row
        # count against true.sum() + false.sum(); it had no effect and is omitted.
        columns.append(concept_bool["concept_bool_true"])
    return pd.concat(columns, axis=1).values


def train_using_manual_labels(dataset_name,
                              metadata_train,
                              y_train,
                              metadata_test,
                              y_test,
                              concept_list,
                              alpha=0.001,
                              random_state=None):
    """Fit an L1-regularised logistic model on manual concept annotations.

    One feature per concept is taken from the "concept_bool_true" mask
    returned by ``get_concept_bool_from_metadata``; the model is an
    ``SGDClassifier`` with log loss and an L1 penalty.

    Args:
        dataset_name: Dataset identifier forwarded to
            ``get_concept_bool_from_metadata``.
        metadata_train: Metadata of the training samples.
        y_train: Training targets.
        metadata_test: Metadata of the test samples.
        y_test: Test targets used for the AUROC.
        concept_list: Concept names used as features, in column order.
        alpha: Regularisation strength passed to ``SGDClassifier``.
        random_state: Optional seed passed to ``SGDClassifier``; the default
            ``None`` keeps the previous unseeded behaviour.

    Returns:
        Tuple ``(auc, classifier, test_scores)`` where ``test_scores`` are the
        predicted probabilities of the second class on the test set and
        ``auc`` is their ROC AUC against ``y_test``.
    """
    x_train = _manual_concept_matrix(dataset_name, metadata_train, concept_list)
    x_test = _manual_concept_matrix(dataset_name, metadata_test, concept_list)

    clf_manual_labels = SGDClassifier(loss="log_loss", penalty="l1", alpha=alpha, random_state=random_state)
    clf_manual_labels.fit(x_train, y_train)

    # Score the test set once and reuse the result for both the AUROC and the return value.
    test_scores = clf_manual_labels.predict_proba(x_test)[:, 1]
    auc = roc_auc_score(y_test, test_scores)
    return auc, clf_manual_labels, test_scores


# supervised model

In [None]:
# IPython shell escape: run the external gpustat command.
!gpustat

In [None]:
import torchvision
from torch import nn
from torch.nn import functional as F
from torchmetrics import AUROC

class Classifier(nn.Module):
    """ImageNet-pretrained ResNet-50 backbone followed by a linear head.

    The backbone's own ``fc`` layer is replaced by ``nn.Identity`` so that the
    pooled features feed directly into ``self.head``.

    Args:
        output_dim: Number of outputs produced by the linear head.
        freeze_backbone: When truthy, gradients are disabled for every
            backbone parameter; otherwise they are enabled.
    """

    def __init__(self, output_dim, freeze_backbone=False):
        super().__init__()
        self.backbone = torchvision.models.resnet50(weights="ResNet50_Weights.IMAGENET1K_V1")

        # Explicitly set the flag in both cases rather than relying on defaults.
        trainable = not freeze_backbone
        for backbone_param in self.backbone.parameters():
            backbone_param.requires_grad = trainable

        # Read the feature width before the original fc layer is discarded.
        feature_dim = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        self.head = nn.Linear(feature_dim, output_dim)

    def forward(self, x):
        """Return the head output computed from the backbone features of ``x``."""
        return self.head(self.backbone(x))


class EarlyStopper:
    """Signal when the validation loss has stopped improving.

    Args:
        patience: Number of worsening observations tolerated before stopping.
        min_delta: Margin above the best loss that an observation must exceed
            to count as worsening.
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience, self.min_delta = patience, min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        """Record ``validation_loss`` and return True once patience is used up."""
        best_so_far = self.min_validation_loss

        # A new best loss resets the counter.
        if validation_loss < best_so_far:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        # Within the tolerated margin: neither an improvement nor a strike.
        if not (validation_loss > best_so_far + self.min_delta):
            return False

        self.counter += 1
        return self.counter >= self.patience
        
def _evaluate_classifier(classifier, dataloader, device, verbose, collect_logits=False):
    """Run ``classifier`` over ``dataloader`` in eval mode without gradients.

    Returns:
        Tuple ``(loss_sum, auroc, logits)``. ``loss_sum`` is the per-batch BCE
        loss multiplied by the batch size and summed over batches, ``auroc``
        is the updated binary ``AUROC`` metric object, and ``logits`` is the
        list of raw model outputs (empty unless ``collect_logits`` is True).
    """
    auroc = AUROC(task="binary")
    loss_sum = 0
    collected = []
    classifier.eval()
    with torch.no_grad():
        for batch in (tqdm.tqdm(dataloader) if verbose else dataloader):
            image, label = batch["image"].to(device), batch["label"].to(device)
            logits = classifier(image)
            loss = F.binary_cross_entropy_with_logits(
                input=logits[:, 0], target=(label == 1).float()
            )
            loss_sum += loss.item() * image.size(0)
            auroc.update(logits, (label == 1))
            if collect_logits:
                collected += logits.detach().cpu().numpy().tolist()
    return loss_sum, auroc, collected


def train_classifier(train_dataloader, val_dataloader, test_dataloader, freeze_backbone, lr, verbose, n_epoch=20,
                     device="cuda:4"):
    """Train a binary ``Classifier`` and evaluate it on the test set.

    Training uses Adam with ``ReduceLROnPlateau`` on the summed validation
    loss and stops early once ``EarlyStopper(patience=5)`` fires. Samples whose
    ``label`` equals 1 are the positive class.

    Args:
        train_dataloader: Batches are dicts with ``"image"`` and ``"label"``.
        val_dataloader: Same batch format; used for scheduling/early stopping.
        test_dataloader: Same batch format; scored once after training.
        freeze_backbone: Forwarded to ``Classifier``.
        lr: Adam learning rate.
        verbose: Show tqdm progress bars and the early-stop message. The
            per-epoch and test summaries are printed regardless.
        n_epoch: Maximum number of epochs.
        device: Device the model and batches are moved to. Defaults to the
            previously hard-coded ``"cuda:4"``.

    Returns:
        Tuple ``(classifier, test_auroc, test_preds, val_auroc)`` where
        ``test_preds`` is the list of test logits and ``val_auroc`` is the
        validation AUROC of the last epoch that ran (NaN tensor if no epoch
        ran).
    """
    classifier = Classifier(output_dim=1, freeze_backbone=freeze_backbone)
    classifier.to(device)
    optimizer = torch.optim.Adam(classifier.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=2, verbose=True)
    early_stopper = EarlyStopper(patience=5, min_delta=0)

    train_auroc = AUROC(task="binary")
    # Defined up front so the return value exists even when n_epoch == 0.
    val_auroc_compute = torch.tensor(float("nan"))
    for epoch in range(n_epoch):
        train_loss = 0
        classifier.train()
        for batch in (tqdm.tqdm(train_dataloader) if verbose else train_dataloader):
            image, label = batch["image"].to(device), batch["label"].to(device)
            logits = classifier(image)
            loss = F.binary_cross_entropy_with_logits(
                input=logits[:, 0], target=(label == 1).float()
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # loss is a batch mean; re-weight by batch size for a dataset mean.
            train_loss += loss.item() * image.size(0)
            train_auroc.update(logits, (label == 1))

        val_loss, val_auroc, _ = _evaluate_classifier(classifier, val_dataloader, device, verbose)
        val_auroc_compute = val_auroc.compute()
        print(
            f"Epoch {epoch}: Train loss: {train_loss/len(train_dataloader.dataset):.3f} AUROC: {train_auroc.compute():.3f} Val loss: {val_loss/len(val_dataloader.dataset):.3f} AUROC: {val_auroc_compute:.3f}"
        )
        train_auroc.reset()

        scheduler.step(val_loss)
        if early_stopper.early_stop(val_loss):
            if verbose:
                print("break")
            break

    test_loss, test_auroc, test_preds = _evaluate_classifier(
        classifier, test_dataloader, device, verbose, collect_logits=True
    )
    test_auroc_compute = test_auroc.compute()
    print(
        f"Test loss: {test_loss/len(test_dataloader.dataset):.3f} AUROC: {test_auroc_compute:.3f}"
    )

    return classifier, test_auroc_compute, test_preds, val_auroc_compute



from MONET.datamodules.components.base_dataset import BaseDataset

def generate_data_and_train_resnet(dataloader,
                                   metadata_train,
                                   y_select_train,
                                   metadata_test,
                                   y_select_test,
                                   random_seed,
                                   freeze_backbone=False, lr=1e-3, n_epoch=20, verbose=False):
    """Build train/val/test loaders from metadata and train the ResNet classifier.

    The training metadata is split 75/25 into train and validation rows with
    ``random_seed``; labels are attached as a ``"label"`` column on copies of
    the metadata frames. Images are looked up through
    ``dataloader.dataset.image_path_dict``.

    Returns:
        Tuple ``(test_auc, test_preds, val_auc)`` as produced by
        ``train_classifier`` (the two AUROC values converted with ``.item()``).
    """
    from MONET.utils.loader import custom_collate

    fit_rows, val_rows = train_test_split(
        np.arange(len(y_select_train)), random_state=random_seed, test_size=0.25, shuffle=True
    )

    labelled_train = metadata_train.copy()
    labelled_train["label"] = y_select_train
    labelled_test = metadata_test.copy()
    labelled_test["label"] = y_select_test

    def _make_dataset(metadata):
        # 224px inputs with the mean/std used by the ImageNet ResNet weights.
        return BaseDataset(
            image_path_or_binary_dict=dataloader.dataset.image_path_dict,
            n_px=224,
            norm_mean=(0.485, 0.456, 0.406),
            norm_std=(0.229, 0.224, 0.225),
            augment=False,
            metadata_all=metadata,
            integrity_level="weak",
            return_label=["label"],
        )

    def _make_loader(dataset, shuffle):
        return torch.utils.data.DataLoader(
            dataset, batch_size=32, shuffle=shuffle, pin_memory=True, collate_fn=custom_collate,
            drop_last=False, num_workers=4)

    data_train = _make_dataset(labelled_train.iloc[fit_rows])
    data_val = _make_dataset(labelled_train.iloc[val_rows])
    data_test = _make_dataset(labelled_test)

    train_dataloader_resnet = _make_loader(data_train, shuffle=True)
    valid_dataloader_resnet = _make_loader(data_val, shuffle=True)
    test_dataloader_resnet = _make_loader(data_test, shuffle=False)

    _, test_auc, test_preds, auc_valid = train_classifier(
        train_dataloader_resnet,
        valid_dataloader_resnet,
        test_dataloader_resnet,
        freeze_backbone=freeze_backbone, lr=lr, verbose=verbose, n_epoch=n_epoch,
    )

    return test_auc.item(), test_preds, auc_valid.item()
    

In [None]:
class BaseDatasetCache(BaseDataset):
    """``BaseDataset`` variant that memoises items by index.

    The first request for an index delegates to ``BaseDataset.__getitem__``;
    later requests for the same index return the stored object. The cache is
    created lazily on first access, so the parent constructor is untouched.
    """

    def __getitem__(self, idx):
        try:
            cache = self.item_cache
        except AttributeError:
            # Only a missing attribute means "not initialised yet"; any other
            # error should surface instead of being silently swallowed.
            cache = self.item_cache = {}

        if idx not in cache:
            cache[idx] = super().__getitem__(idx)

        return cache[idx]

In [None]:
!gpustat

In [None]:
# from sklearn.model_selection import train_test_split
# from MONET.datamodules.components.base_dataset import BaseDataset

# # def get_training_data(dataloader, metadata_all, valid_idx, y_pos, subset_idx_train, n_px=None):
# def generate_data_and_train_resnet(dataloader, ):
#     metadata_all_new = dataloader.dataset.metadata_all.copy()
#     metadata_all_new["label"]=y_pos.astype(int)
#     metadata_all_new_train=metadata_all_new[valid_idx&subset_idx_train]
#     # metadata_all_new=metadata_all_new.iloc[list(true_set.union(false_set))]

#     train_idx, val_idx = train_test_split(metadata_all_new_train.index, test_size=0.2, random_state=42)
    
#     print("train:", len(metadata_all_new_train))

#     if n_px is None:
#         n_px=dataloader.dataset.n_px
    
#     data_train = BaseDataset(
#         image_path_or_binary_dict=dataloader.dataset.image_path_dict,
#         n_px=n_px,
#         norm_mean=dataloader.dataset.transforms_aftertensor.transforms[1].mean,
#         norm_std=dataloader.dataset.transforms_aftertensor.transforms[1].std,
#         augment=False,
#         metadata_all=metadata_all_new_train.loc[train_idx],
#         integrity_level="weak",
#         return_label=["label"],
#     )

#     data_val = BaseDataset(
#         image_path_or_binary_dict=dataloader.dataset.image_path_dict,
#         n_px=n_px,
#         norm_mean=dataloader.dataset.transforms_aftertensor.transforms[1].mean,
#         norm_std=dataloader.dataset.transforms_aftertensor.transforms[1].std,
#         augment=False,
#         metadata_all=metadata_all_new_train.loc[val_idx],
#         integrity_level="weak",
#         return_label=["label"],
#     )

#     from MONET.utils.loader import custom_collate

#     train_dataloader = torch.utils.data.DataLoader(
#         dataset=data_train,
#         batch_size=64,
#         num_workers=4,
#         pin_memory=False,
#         persistent_workers=False,
#         shuffle=True,
#         collate_fn=custom_collate,
#     )
#     val_dataloader = torch.utils.data.DataLoader(
#         dataset=data_val,
#         batch_size=64,
#         num_workers=4,
#         pin_memory=False,
#         persistent_workers=False,
#         shuffle=False,
#         collate_fn=custom_collate,
#     )       
    
#     return train_dataloader, val_dataloader

# automatic

In [None]:
class SGDClassifierValidation(SGDClassifier):
    """``SGDClassifier`` with an overridable early-stopping validation split.

    The constructor repeats the parent's keyword arguments explicitly and
    forwards them unchanged. ``_make_validation_split`` is a local copy of the
    parent's splitting logic so that it can be customised here.
    """

    def __init__(self,
                 loss="hinge",
                 *,
                 penalty="l2",
                 alpha=0.0001,
                 l1_ratio=0.15,
                 fit_intercept=True,
                 max_iter=1000,
                 tol=1e-3,
                 shuffle=True,
                 verbose=0,
                 epsilon=0.1,
                 n_jobs=None,
                 random_state=None,
                 learning_rate="optimal",
                 eta0=0.0,
                 power_t=0.5,
                 early_stopping=False,
                 validation_fraction=0.1,
                 n_iter_no_change=5,
                 class_weight=None,
                 warm_start=False,
                 average=False,):
        # Forward every hyper-parameter to SGDClassifier unchanged.
        super().__init__(
            loss=loss,
            penalty=penalty,
            alpha=alpha,
            l1_ratio=l1_ratio,
            fit_intercept=fit_intercept,
            max_iter=max_iter,
            tol=tol,
            shuffle=shuffle,
            verbose=verbose,
            epsilon=epsilon,
            n_jobs=n_jobs,
            random_state=random_state,
            learning_rate=learning_rate,
            eta0=eta0,
            power_t=power_t,
            early_stopping=early_stopping,
            validation_fraction=validation_fraction,
            n_iter_no_change=n_iter_no_change,
            class_weight=class_weight,
            warm_start=warm_start,
            average=average,
        )

    def _make_validation_split(self, y, sample_mask):
        """Split the dataset between training set and validation set.

        Parameters
        ----------
        y : ndarray of shape (n_samples, )
            Target values.

        sample_mask : ndarray of shape (n_samples, )
            A boolean array indicating whether each sample should be included
            for validation set.

        Returns
        -------
        validation_mask : ndarray of shape (n_samples, )
            Equal to True on the validation set, False on the training set.

        Raises
        ------
        ValueError
            If every validation sample is masked out by ``sample_mask`` or if
            either side of the split is empty.
        """
        from sklearn.base import is_classifier
        from sklearn.model_selection import StratifiedShuffleSplit, ShuffleSplit

        n_samples = y.shape[0]
        validation_mask = np.zeros(n_samples, dtype=np.bool_)
        if not self.early_stopping:
            # use the full set for training, with an empty validation set
            return validation_mask

        # Stratify for classifiers. The previous code inspected an undefined
        # name ``estimator`` here and raised NameError.
        if is_classifier(self):
            splitter_type = StratifiedShuffleSplit
        else:
            splitter_type = ShuffleSplit
        cv = splitter_type(
            test_size=self.validation_fraction, random_state=self.random_state
        )
        # Only y matters for the split; a dummy X of matching length suffices.
        idx_train, idx_val = next(cv.split(np.zeros(shape=(y.shape[0], 1)), y))

        if not np.any(sample_mask[idx_val]):
            raise ValueError(
                "The sample weights for validation set are all zero, consider using a"
                " different random state."
            )

        if idx_train.shape[0] == 0 or idx_val.shape[0] == 0:
            raise ValueError(
                "Splitting %d samples into a train set and a validation set "
                "with validation_fraction=%r led to an empty set (%d and %d "
                "samples). Please either change validation_fraction, increase "
                "number of samples, or disable early_stopping."
                % (
                    n_samples,
                    self.validation_fraction,
                    idx_train.shape[0],
                    idx_val.shape[0],
                )
            )

        validation_mask[idx_val] = True
        return validation_mask

In [None]:
def train_with_best_temp_softmax(image_features_train_norm,
                                 image_features_test_norm,
                                 y_train, y_test,
                                 concept_list,
                                 prompt_info,
                                 prompt_info_vanilla,
                                 temp,
                                 random_seed,
                                 concept_reference_dict,
                                 with_validation=False,
                                 num_ref_concepts=5,
                                 alpha=0.001):
    """Train an L1 logistic model on prompt-similarity concept scores.

    For every concept in ``concept_list`` the mean similarity between the
    image features and the concept's target prompt embedding is computed,
    together with reference similarities: the concept's own
    ``"prompt_ref_embedding_norm"`` when ``concept_reference_dict`` is None,
    otherwise the target embeddings of up to ``num_ref_concepts`` concepts
    drawn without replacement from ``concept_reference_dict[concept]``.
    The feature for a concept is the target column of the temperature-scaled
    softmax over these similarities (or the plain scaled target similarity
    when ``num_ref_concepts`` is not positive).

    Notes:
        * ``prompt_info_vanilla`` is accepted but not used.
        * Reference concepts are drawn with ``np.random.choice`` from the
          global NumPy RNG; ``random_seed`` only controls the validation split.

    Returns:
        Tuple ``(clf, test_auc, test_scores, valid_auc)``. ``valid_auc`` is
        None unless ``with_validation`` is True, in which case the classifier
        is fit on 75% of the training rows and validated on the rest.
    """
    target_key = "prompt_target_embedding_norm"

    def _mean_similarity(concept, embedding_key, image_features_norm):
        # Average the prompt-image similarity over the first two axes.
        scores = prompt_info[concept][embedding_key].float() @ image_features_norm.T.float()
        return scores.mean(dim=[0, 1]).detach().cpu()

    def _similarity_pair(concept, embedding_key):
        return (
            _mean_similarity(concept, embedding_key, image_features_train_norm),
            _mean_similarity(concept, embedding_key, image_features_test_norm),
        )

    x_dict_train = {}
    x_dict_test = {}
    for concept_target in concept_list:
        # The target similarity always comes first so it ends up in column 0.
        pairs = [_similarity_pair(concept_target, target_key)]

        if concept_reference_dict is None:
            pairs.append(_similarity_pair(concept_target, "prompt_ref_embedding_norm"))
        else:
            candidates = concept_reference_dict[concept_target]
            concept_sampled = np.random.choice(
                a=candidates, size=min(num_ref_concepts, len(candidates)), replace=False
            ).tolist()
            for concept_ref in concept_sampled:
                pairs.append(_similarity_pair(concept_ref, target_key))

        x_dict_train[concept_target] = np.stack([train_sim for train_sim, _ in pairs]).T
        x_dict_test[concept_target] = np.stack([test_sim for _, test_sim in pairs]).T

    def _concept_features(x_dict):
        # One column per concept: target probability (or raw scaled score).
        if num_ref_concepts > 0:
            columns = [softmax(scores / temp, axis=1)[:, 0] for scores in x_dict.values()]
        else:
            columns = [(scores / temp)[:, 0] for scores in x_dict.values()]
        return np.array(columns).T

    x_softmax_train = _concept_features(x_dict_train)
    x_softmax_test = _concept_features(x_dict_test)

    if not with_validation:
        clf = SGDClassifier(loss="log_loss", penalty="l1", alpha=alpha)
        clf.fit(x_softmax_train, y_train)
        test_scores = clf.predict_proba(x_softmax_test)[:, 1]
        auc = roc_auc_score(y_test, test_scores)
        return clf, auc, test_scores, None

    fit_rows, valid_rows = train_test_split(np.arange(len(y_train)),
                                            random_state=random_seed,
                                            test_size=0.25, shuffle=True)

    clf = SGDClassifier(loss="log_loss", penalty="l1", alpha=alpha)
    clf.fit(x_softmax_train[fit_rows], y_train[fit_rows])

    auc_valid = roc_auc_score(y_train[valid_rows], clf.predict_proba(x_softmax_train[valid_rows])[:, 1])

    test_scores = clf.predict_proba(x_softmax_test)[:, 1]
    auc = roc_auc_score(y_test, test_scores)
    return clf, auc, test_scores, auc_valid

# Ensure the epoch and augment options are set correctly.

In [None]:
train_with_best_temp_softmax

In [None]:
def get_subset_index(dataset_name, metadata_all, attribution):
    """Return a boolean mask selecting the rows of one attribution subset.

    Args:
        dataset_name: Only ``"isic_nodup_nooverlap"`` is supported.
        metadata_all: DataFrame with ``"attribution"`` and ``"collection_65"``
            columns (the latter is only read when ``attribution != "all"``).
        attribution: ``"all"`` selects every row. ``"barcelona_all"`` and
            ``"mskcc_all"`` select rows whose attribution is any of the
            spellings grouped under that alias. Any other value is matched
            against the ``"attribution"`` column exactly. Every selection
            other than ``"all"`` is further restricted to rows with
            ``collection_65 == 1``.

    Returns:
        Boolean ``np.ndarray`` with one entry per row of ``metadata_all``.

    Raises:
        ValueError: If ``dataset_name`` is not supported (previously this
            surfaced as an ``UnboundLocalError`` at the return statement).
    """
    if dataset_name != "isic_nodup_nooverlap":
        raise ValueError(f"get_subset_index does not support dataset_name={dataset_name!r}")

    if attribution == "all":
        # Explicit bool dtype so the mask stays boolean even for empty metadata.
        return np.ones(len(metadata_all), dtype=bool)

    # Aliases that pool several spellings of the same institution.
    attribution_groups = {
        "barcelona_all": [
            "Department of Dermatology, Hospital Clínic de Barcelona",
            "Hospital Clínic de Barcelona",
        ],
        "mskcc_all": [
            "MSKCC",
            "Memorial Sloan Kettering Cancer Center",
        ],
    }

    if attribution in attribution_groups:
        subset_idx = metadata_all["attribution"].isin(attribution_groups[attribution]).values
    else:
        subset_idx = (metadata_all["attribution"] == attribution).values

    collection_65 = (metadata_all["collection_65"] == 1).values
    return subset_idx & collection_65

In [None]:
def get_setting(setting_name):
    """Return the experiment configuration registered under ``setting_name``.

    Supported names are ``"clinical"``, ``"clinical_manualmatching"``,
    ``"isic_all"``, ``"isic_barcelona_vienna"`` and
    ``"isic_vienna_barcelona"``. The two cross-site ISIC settings read the
    module-level ``variable_dict`` to build their preset masks.

    Returns:
        Tuple ``(dataset_name, method_list, manual_annotation_concepts,
        automatic_concept_info, train_test_idx_preset, with_validation,
        random_seed_range, n_epoch)``. ``train_test_idx_preset`` is None, or a
        two-element list of boolean masks from ``get_subset_index`` for the
        first and second site named in the setting.

    Raises:
        ValueError: If ``setting_name`` is not one of the supported names
            (previously this surfaced as an ``UnboundLocalError``).
    """
    curated_concepts = ['cbm_Asymmetry', 'cbm_Irregular', 'cbm_Black', 'cbm_Blue', 'cbm_White', 'cbm_Brown',
                        'cbm_Erosion', 'cbm_Multiple Colors', 'cbm_Tiny', 'cbm_Regular']
    derm7pt_concepts = ['derm7ptconcept_pigment network',
                        'derm7ptconcept_regression structure',
                        'derm7ptconcept_pigmentation',
                        'derm7ptconcept_blue whitish veil',
                        'derm7ptconcept_vascular structures',
                        'derm7ptconcept_streaks',
                        'derm7ptconcept_dots and globules']

    barcelona = "Hospital Clínic de Barcelona"
    vienna = "ViDIR Group, Department of Dermatology, Medical University of Vienna"
    # Order of the two attributions in the preset for each cross-site setting.
    cross_site_attributions = {
        "isic_barcelona_vienna": (barcelona, vienna),
        "isic_vienna_barcelona": (vienna, barcelona),
    }

    # Values shared by every setting.
    manual_annotation_concepts = skincon_cols
    with_validation = True
    train_test_idx_preset = None

    if setting_name in ("clinical", "clinical_manualmatching"):
        dataset_name = "clinical_fd_clean_nodup_nooverlap"
        method_list = ["manual_annotation", "automatic"]
        if setting_name == "clinical":
            method_list += ["resnet", "resnet_freeze_backbone"]
        automatic_concept_info = {
            "skincon": skincon_cols,
            "curated": curated_concepts,
        }
        random_seed_range = range(21, 21 + 1)
        n_epoch = 50

    elif setting_name == "isic_all" or setting_name in cross_site_attributions:
        dataset_name = "isic_nodup_nooverlap"
        method_list = ["automatic", "resnet", "resnet_freeze_backbone"]
        automatic_concept_info = {
            "skincon": skincon_cols,
            "derm7pt": derm7pt_concepts,
            "curated": curated_concepts,
        }
        n_epoch = 20

        if setting_name == "isic_all":
            random_seed_range = range(1, 20 + 1)
        else:
            # The split is fixed by site, so a single seed is used.
            metadata_all = variable_dict["isic_nodup_nooverlap"]["metadata_all"]
            train_test_idx_preset = [
                get_subset_index("isic_nodup_nooverlap", metadata_all, site)
                for site in cross_site_attributions[setting_name]
            ]
            random_seed_range = range(1, 1 + 1)

    else:
        raise ValueError(f"Unknown setting_name: {setting_name!r}")

    return dataset_name, method_list, manual_annotation_concepts, automatic_concept_info,\
            train_test_idx_preset, with_validation, random_seed_range, n_epoch





# dataset_name="derm7pt_derm_nodup"
# method_list=["manual_annotation", "automatic", "resnet", "resnet_freeze_backbone"]
# manual_annotation_concepts=['derm7ptconcept_pigment network',
#                     'derm7ptconcept_regression structure',
#                     'derm7ptconcept_pigmentation',
#                     'derm7ptconcept_blue whitish veil',
#                     'derm7ptconcept_vascular structures',
#                     'derm7ptconcept_streaks',
#                     'derm7ptconcept_dots and globules']
# automatic_concept_info={
#     "skincon": skincon_cols,
#     "derm7pt": ['derm7ptconcept_pigment network',
#                     'derm7ptconcept_regression structure',
#                     'derm7ptconcept_pigmentation',
#                     'derm7ptconcept_blue whitish veil',
#                     'derm7ptconcept_vascular structures',
#                     'derm7ptconcept_streaks',
#                     'derm7ptconcept_dots and globules'],
#     "curated": ['cbm_Asymmetry', 'cbm_Irregular', 'cbm_Black', 'cbm_Blue', 'cbm_White', 'cbm_Brown', 
#                   'cbm_Erosion','cbm_Multiple Colors', 'cbm_Tiny', 'cbm_Regular']
# }
# random_seed_range=range(1,20+1)
# train_test_idx_preset=None
# with_validation=False
# n_epoch=10

setting_name="isic_all"

(dataset_name, method_list, manual_annotation_concepts, automatic_concept_info, 
train_test_idx_preset, with_validation, random_seed_range, n_epoch)=get_setting(setting_name)
random_seed_range=range(20,20+1)
method_list=["resnet", "resnet_freeze_backbone"]
with_validation=True

record_all_list_temp=[]
for random_seed in tqdm.tqdm(random_seed_range):
#     for task in ["melanoma", "malignant"]:
    for task in ["melanoma"]:
        valid_idx_select=variable_dict[dataset_name][f"valid_idx_{task}"]
        y_select=variable_dict[dataset_name][f"y_pos_{task}"][valid_idx_select]
        metadata_select=variable_dict[dataset_name]["metadata_all"][valid_idx_select]
        image_norm_monet_select=variable_dict[dataset_name]["image_features_norm"][valid_idx_select]
        image_norm_vanilla_select=variable_dict[dataset_name]["image_features_vanilla_norm"][valid_idx_select]

        assert len(y_select)==len(metadata_select)==len(image_norm_monet_select)==len(image_norm_vanilla_select)

        if train_test_idx_preset is None:
            if "clinical_fd_clean" in dataset_name:
                train_idx, test_idx = train_test_split(np.arange(len(y_select)), 
                                                       train_size=int(len(y_select)*0.8),
                                                       random_state=random_seed,
                                                       shuffle=True,
                                                       stratify=metadata_select["skincon_Vesicle"].isnull().values)
                ####### This line
                if "manualmatching" in  setting_name:
                    test_idx=test_idx[~metadata_select.iloc[test_idx]["skincon_Vesicle"].isnull().values]            
            else:
                train_idx, test_idx = train_test_split(np.arange(len(y_select)), 
                                                       train_size=int(len(y_select)*0.8),
                                                       random_state=random_seed,
                                                       shuffle=True)
        else:
            assert len(variable_dict[dataset_name][f"y_pos_{task}"])==len(train_test_idx_preset[0])==len(train_test_idx_preset[1])
            train_idx=np.arange(len(y_select))[train_test_idx_preset[0][valid_idx_select]]
            test_idx=np.arange(len(y_select))[train_test_idx_preset[1][valid_idx_select]]    
        
        metadata_select_train=metadata_select.iloc[train_idx]
        metadata_select_test=metadata_select.iloc[test_idx]

        y_select_train=np.array(y_select)[train_idx]
        y_select_test=np.array(y_select)[test_idx] 
        
        image_norm_monet_select_train=image_norm_monet_select[train_idx]
        image_norm_monet_select_test=image_norm_monet_select[test_idx]
        
        image_norm_vanilla_select_train=image_norm_vanilla_select[train_idx]
        image_norm_vanilla_select_test=image_norm_vanilla_select[test_idx]
        
        
        for method in method_list:            
            if method=="manual_annotation":                
                if "clinical_fd_clean" in dataset_name:                
                    metadata_select_train_manual, y_select_train_manual = metadata_select_train[~metadata_select_train["skincon_Vesicle"].isnull()], y_select_train[~metadata_select_train["skincon_Vesicle"].isnull()]
                    metadata_select_test_manual, y_select_test_manual = metadata_select_test[~metadata_select_test["skincon_Vesicle"].isnull()], y_select_test[~metadata_select_test["skincon_Vesicle"].isnull()]
                else:
                    metadata_select_train_manual, y_select_train_manual = metadata_select_train[:], y_select_train[:]
                    metadata_select_test_manual, y_select_test_manual = metadata_select_test[:], y_select_test[:]
                
                # Ablations for the manual-annotation baseline.
                #
                # For every value of ``alpha`` in the sweep, three settings are run:
                #   * "full"         - all manual concepts, all manual training samples;
                #   * "less_concept" - random concept subsets of increasing size;
                #   * "less_sample"  - random training-sample subsets of increasing size.
                # Every training run appends one record (a dict) to ``record_all_list_temp``
                # and prints a truncated view of it.
                #
                # The sweep value of ``alpha`` is forwarded to *every* training call so the
                # "alpha" stored in a record is the value the model was actually trained
                # with. A fixed 0.001 in the two ablation modes would make the five sweep
                # iterations train identically configured models while labelling them with
                # different alphas, and would disagree with the "automatic" method, which
                # forwards ``alpha`` in all modes.
                # NOTE(review): alpha is presumably the regularisation strength used inside
                # train_using_manual_labels - confirm against that helper.
                # NOTE(review): the subsampling below draws from NumPy's global RNG
                # (np.random.choice); it is not seeded in this block.
                for test_mode in ["full", "less_concept", "less_sample"]:
                    for alpha in [0.0001, 0.0005, 0.001, 0.005, 0.01]:
                        if test_mode == "full":
                            # Reference run: every concept and every training sample.
                            auc, _, y_test_pred = train_using_manual_labels(
                                dataset_name=dataset_name,
                                metadata_train=metadata_select_train_manual,
                                metadata_test=metadata_select_test_manual,
                                y_train=y_select_train_manual,
                                y_test=y_select_test_manual,
                                concept_list=manual_annotation_concepts,
                                alpha=alpha,
                            )

                            record_all_list_temp.append({
                                "task": task,
                                "num_sample": len(y_select_train_manual) + len(y_select_test_manual),
                                "num_sample_pos": y_select_train_manual.sum() + y_select_test_manual.sum(),
                                "num_sample_train": len(y_select_train_manual),
                                "num_sample_train_pos": y_select_train_manual.sum(),
                                "num_sample_test": len(y_select_test_manual),
                                "num_sample_test_pos": y_select_test_manual.sum(),
                                "random_seed": random_seed,
                                "method": method + "_" + test_mode,
                                "auc": auc,
                                "y_test": y_select_test_manual,
                                "y_test_pred": y_test_pred,
                                "alpha": alpha,
                            })
                            print([(key, str(value)[:20]) for key, value in record_all_list_temp[-1].items()])

                        elif test_mode == "less_concept":
                            # Subset sizes: with more than 40 concepts use 1, every multiple
                            # of 5 below the total, and the total; otherwise every size
                            # from 1 up to the total.
                            if len(manual_annotation_concepts) > 40:
                                num_concept_list = (
                                    [1]
                                    + list(range(5, len(manual_annotation_concepts), 5))
                                    + [len(manual_annotation_concepts)]
                                )
                            else:
                                num_concept_list = np.arange(1, len(manual_annotation_concepts) + 1)

                            for num_concept in num_concept_list:
                                # Uniform draw without replacement from the manual concepts.
                                manual_annotation_concepts_sampled = np.random.choice(
                                    manual_annotation_concepts,
                                    size=num_concept,
                                    replace=False,
                                    p=None,
                                ).tolist()

                                auc, _, y_test_pred = train_using_manual_labels(
                                    dataset_name=dataset_name,
                                    metadata_train=metadata_select_train_manual,
                                    metadata_test=metadata_select_test_manual,
                                    y_train=y_select_train_manual,
                                    y_test=y_select_test_manual,
                                    concept_list=manual_annotation_concepts_sampled,
                                    alpha=alpha,  # was hard-coded to 0.001 while the record stored the sweep value
                                )

                                record_all_list_temp.append({
                                    "task": task,
                                    "num_sample": len(y_select_train_manual) + len(y_select_test_manual),
                                    "num_sample_pos": y_select_train_manual.sum() + y_select_test_manual.sum(),
                                    "num_sample_train": len(y_select_train_manual),
                                    "num_sample_train_pos": y_select_train_manual.sum(),
                                    "num_sample_test": len(y_select_test_manual),
                                    "num_sample_test_pos": y_select_test_manual.sum(),
                                    "random_seed": random_seed,
                                    "method": method + "_" + test_mode,
                                    "num_concept": num_concept,
                                    "auc": auc,
                                    "y_test": y_select_test_manual,
                                    "y_test_pred": y_test_pred,
                                    "alpha": alpha,
                                })
                                print([(key, str(value)[:20]) for key, value in record_all_list_temp[-1].items()])

                        elif test_mode == "less_sample":
                            # Training-set sizes: the step grows with the amount of data
                            # available (50 below 1000 samples, 100 below 2000, otherwise
                            # 100 followed by multiples of 500); the full training set is
                            # always included. np.unique sorts and removes duplicates.
                            if len(metadata_select_train_manual) < 1000:
                                num_sample_train_list = np.unique(
                                    list(range(50, len(metadata_select_train_manual), 50))
                                    + [len(metadata_select_train_manual)]
                                )
                            elif len(metadata_select_train_manual) < 2000:
                                num_sample_train_list = np.unique(
                                    list(range(100, len(metadata_select_train_manual), 100))
                                    + [len(metadata_select_train_manual)]
                                )
                            else:
                                num_sample_train_list = np.unique(
                                    [100]
                                    + list(range(500, len(metadata_select_train_manual), 500))
                                    + [len(metadata_select_train_manual)]
                                )

                            for num_sample_train in num_sample_train_list:
                                # Positional indices of the training rows kept for this run.
                                train_idx_select = np.random.choice(
                                    np.arange(len(metadata_select_train_manual)),
                                    size=num_sample_train,
                                    replace=False,
                                )

                                metadata_select_train_manual_less = metadata_select_train_manual.copy().iloc[train_idx_select]
                                y_select_train_manual_less = y_select_train_manual[train_idx_select]

                                auc, _, y_test_pred = train_using_manual_labels(
                                    dataset_name=dataset_name,
                                    metadata_train=metadata_select_train_manual_less,
                                    metadata_test=metadata_select_test_manual,
                                    y_train=y_select_train_manual_less,
                                    y_test=y_select_test_manual,
                                    concept_list=manual_annotation_concepts,
                                    alpha=alpha,  # was hard-coded to 0.001 while the record stored the sweep value
                                )

                                record_all_list_temp.append({
                                    "task": task,
                                    "num_sample": len(y_select_train_manual_less) + len(y_select_test_manual),
                                    "num_sample_pos": y_select_train_manual_less.sum() + y_select_test_manual.sum(),
                                    "num_sample_train": len(y_select_train_manual_less),
                                    "num_sample_train_pos": y_select_train_manual_less.sum(),
                                    "num_sample_test": len(y_select_test_manual),
                                    "num_sample_test_pos": y_select_test_manual.sum(),
                                    "random_seed": random_seed,
                                    "method": method + "_" + test_mode,
                                    "auc": auc,
                                    "y_test": y_select_test_manual,
                                    "y_test_pred": y_test_pred,
                                    "alpha": alpha,
                                })
                                print([(key, str(value)[:20]) for key, value in record_all_list_temp[-1].items()])

                        else:
                            raise ValueError(test_mode)
                        
                        
#             def generate_data_and_train_resnet(metadata_select, y_select, train_idx, test_idx, freeze_backbone=False, lr=1e-3):  
#                 train_idx_train, train_idx_valid=train_test_split(train_idx, random_state=random_seed, test_size=0.25, shuffle=True)

#                 train_dataset_resnet = ResNetDataset(image_paths_list=[image_path_dict_all[idx] for idx in metadata_select["index"].iloc[train_idx_train]], 
#                                                      labels=np.array(y_select)[train_idx_train])
#                 train_dataloader_resnet = torch.utils.data.DataLoader(
#                         train_dataset_resnet, batch_size=32, shuffle=True, pin_memory=True,
#                         drop_last=False, num_workers=4)

#                 valid_dataset_resnet = ResNetDataset(image_paths_list=[image_path_dict_all[idx] for idx in metadata_select["index"].iloc[train_idx_valid]], 
#                                                      labels=np.array(y_select)[train_idx_valid])
#                 valid_dataloader_resnet = torch.utils.data.DataLoader(
#                         valid_dataset_resnet, batch_size=32, shuffle=True, pin_memory=True,
#                         drop_last=False, num_workers=4)

#                 test_dataset_resnet = ResNetDataset(image_paths_list=[image_path_dict_all[idx] for idx in metadata_select["index"].iloc[test_idx]], 
#                                                      labels=np.array(y_select)[test_idx])
#                 test_dataloader_resnet = torch.utils.data.DataLoader(
#                         test_dataset_resnet, batch_size=32, shuffle=False, pin_memory=True,
#                         drop_last=False, num_workers=4)            

#                 classifier, auc_best, test_preds=train_classifier(train_dataloader_resnet, 
#                                                       valid_dataloader_resnet, 
#                                                       test_dataloader_resnet, freeze_backbone=freeze_backbone, lr=lr, verbose=False)

#                 return auc_best.item(), test_preds                        
                        
            elif method=="resnet": 
                auc, y_test_pred, auc_valid=generate_data_and_train_resnet(dataloader=variable_dict[dataset_name]["dataloader"], 
                                                                metadata_train=metadata_select_train.copy(),
                                                                y_select_train=y_select_train,
                                                                metadata_test=metadata_select_test.copy(),
                                                                y_select_test=y_select_test,
                                                                random_seed=random_seed,
                                                                freeze_backbone=False,
                                                                verbose=False,
                                                                n_epoch=n_epoch,
                                                               )

                record_all_list_temp.append({"task": task,
                                        "num_sample": len(y_select_train)+len(y_select_test),
                                        "num_sample_pos": y_select_train.sum()+y_select_test.sum(),
                                        "num_sample_train": len(y_select_train),
                                        "num_sample_train_pos": y_select_train.sum(),
                                        "num_sample_test": len(y_select_test),
                                        "num_sample_test_pos": y_select_test.sum(),                                          
                                        "random_seed": random_seed,                                            
                                        "method": method,
                                        "auc":auc,
                                        "auc_valid":auc_valid,
                                        "y_test":y_select_test,
                                        "y_test_pred":y_test_pred,                                             
                                       })  
                print([(key, str(value)[:20]) for key, value in record_all_list_temp[-1].items()])
                
            elif method=="resnet_freeze_backbone": 
                auc, y_test_pred, auc_valid=generate_data_and_train_resnet(dataloader=variable_dict[dataset_name]["dataloader"], 
                                                                metadata_train=metadata_select_train.copy(),
                                                                y_select_train=y_select_train,
                                                                metadata_test=metadata_select_test.copy(),
                                                                y_select_test=y_select_test,
                                                                random_seed=random_seed,
                                                                freeze_backbone=True,
                                                                verbose=False,
                                                                n_epoch=n_epoch,                                                                
                                                               )


                record_all_list_temp.append({"task": task,
                                        "num_sample": len(y_select_train)+len(y_select_test),
                                        "num_sample_pos": y_select_train.sum()+y_select_test.sum(),
                                        "num_sample_train": len(y_select_train),
                                        "num_sample_train_pos": y_select_train.sum(),
                                        "num_sample_test": len(y_select_test),
                                        "num_sample_test_pos": y_select_test.sum(),                                          
                                        "random_seed": random_seed,                                            
                                        "method": method,
                                        "auc":auc,
                                        "auc_valid":auc_valid,
                                        "y_test":y_select_test,
                                        "y_test_pred":y_test_pred,                                             
                                       })   
                print([(key, str(value)[:20]) for key, value in record_all_list_temp[-1].items()])
                

            elif method=="automatic":
                # Ablations for the automatically scored concepts.
                # For every (concept-list type, concept list) entry of automatic_concept_info
                # and for each feature source ("monet" / "vanilla"), sweep alpha and temp
                # over the active test modes. Each training run appends one record (a dict)
                # to record_all_list_temp and prints a truncated view of it.
                # NOTE(review): the random subsampling below uses NumPy's global RNG
                # (np.random.choice); random_seed is only forwarded to the training helper.
                for concept_list_type, concept_list_curated in automatic_concept_info.items():                  
                    for trained in ["monet", "vanilla"]:
                        # Select the pre-split train/test image features of this feature source.
                        if trained=="monet":
                            image_norm_select_train=image_norm_monet_select_train
                            image_norm_select_test=image_norm_monet_select_test
                        elif trained=="vanilla":
                            image_norm_select_train=image_norm_vanilla_select_train
                            image_norm_select_test=image_norm_vanilla_select_test
                        else:
                            raise ValueError(trained)
                            
                        # "less_reference" is left out of the active list, so the matching
                        # branch further down is not reached with the current settings.
                        #for test_mode in ["full", "less_concept", "less_reference", "less_sample"]:                
                        for test_mode in ["full", "less_concept", "less_sample"]:                
                            for alpha in [0.0001, 0.0005, 0.001, 0.005, 0.01]:
                                # Single temperature value; kept as a loop so more can be added.
                                for temp in [0.02]:
                                    if test_mode=="full":    
                                        # Reference run: all curated concepts, all training samples.
                                        clf, auc, y_test_pred, auc_valid = train_with_best_temp_softmax(
                                                        image_features_train_norm=image_norm_select_train,
                                                         image_features_test_norm=image_norm_select_test, 
                                                         y_train=y_select_train, 
                                                         y_test=y_select_test, 
                                                         concept_list=concept_list_curated, 
                                                         prompt_info=variable_dict[dataset_name]["prompt_info"],
                                                         prompt_info_vanilla=variable_dict[dataset_name]["prompt_info_vanilla"],
                                                         temp=temp, 
                                                         concept_reference_dict=concept_reference_dict,
                                                         random_seed=random_seed,
                                                         with_validation=with_validation,                                                                                    
                                                         num_ref_concepts=100, 
                                                         alpha=alpha)  

                                        record_all_list_temp.append({"task": task,
                                                                "num_sample": len(y_select_train)+len(y_select_test),
                                                                "num_sample_pos": y_select_train.sum()+y_select_test.sum(),
                                                                "num_sample_train": len(y_select_train),
                                                                "num_sample_train_pos": y_select_train.sum(),
                                                                "num_sample_test": len(y_select_test),
                                                                "num_sample_test_pos": y_select_test.sum(),                                                      
                                                                "random_seed": random_seed,                                            
                                                                "method": method+"_"+concept_list_type+"_"+trained+"_"+test_mode,
                                                                "auc": auc,
                                                                "auc_valid":auc_valid,
                                                                "y_test":y_select_test,
                                                                "y_test_pred":y_test_pred,                                                                       
                                                                "alpha": alpha,
                                                                "temp": temp,
                                                                "clf":clf
                                                               })  
                                        print([(key, str(value)[:20]) for key, value in record_all_list_temp[-1].items()])

                                    elif test_mode=="less_concept":
                                        # Subset sizes: with more than 40 concepts use 1, every
                                        # multiple of 5 below the total, and the total; otherwise
                                        # every size from 1 up to the total.
                                        if len(concept_list_curated)>40:
                                            num_concept_list=[1]+list(range(5, len(concept_list_curated), 5))+[len(concept_list_curated)]
                                        else:
                                            num_concept_list=list(range(1, len(concept_list_curated)+1))

                                        for num_concept in num_concept_list:
                                            # Uniform draw without replacement; the result is a
                                            # NumPy array (not converted to a list here).
                                            concept_list_curated_sampled=np.random.choice(concept_list_curated, 
                                                                                         size=num_concept, 
                                                                                         replace=False, p=None)

                                            clf, auc, y_test_pred, auc_valid = train_with_best_temp_softmax(
                                                            image_features_train_norm=image_norm_select_train,
                                                             image_features_test_norm=image_norm_select_test, 
                                                             y_train=y_select_train, 
                                                             y_test=y_select_test, 
                                                             concept_list=concept_list_curated_sampled, 
                                                             prompt_info=variable_dict[dataset_name]["prompt_info"],
                                                             prompt_info_vanilla=variable_dict[dataset_name]["prompt_info_vanilla"],
                                                             temp=temp, 
                                                             concept_reference_dict=concept_reference_dict,
                                                             random_seed=random_seed,
                                                             with_validation=with_validation,                                        
                                                             num_ref_concepts=100, 
                                                             alpha=alpha)  

                                            record_all_list_temp.append({"task": task,
                                                                    "num_sample": len(y_select_train)+len(y_select_test),
                                                                    "num_sample_pos": y_select_train.sum()+y_select_test.sum(),
                                                                    "num_sample_train": len(y_select_train),
                                                                    "num_sample_train_pos": y_select_train.sum(),
                                                                    "num_sample_test": len(y_select_test),
                                                                    "num_sample_test_pos": y_select_test.sum(),                                                      
                                                                    "random_seed": random_seed,                                            
                                                                    "method": method+"_"+concept_list_type+"_"+trained+"_"+test_mode,
                                                                    "auc": auc,
                                                                    "auc_valid":auc_valid,
                                                                    "num_concept": num_concept,
                                                                    "y_test":y_select_test,
                                                                    "y_test_pred":y_test_pred,  
                                                                    "alpha": alpha,
                                                                    "temp": temp,
                                                                    "clf":clf
                                                                   })
                                            print([(key, str(value)[:20]) for key, value in record_all_list_temp[-1].items()])


                                    elif test_mode=="less_reference":
                                        # Vary num_ref_concepts from 0 to 5 (the other modes pass 100).
                                        for num_ref_concepts in range(0, 5+1):
                                            clf, auc, y_test_pred, auc_valid = train_with_best_temp_softmax(
                                                            image_features_train_norm=image_norm_select_train,
                                                             image_features_test_norm=image_norm_select_test, 
                                                             y_train=y_select_train, 
                                                             y_test=y_select_test, 
                                                             concept_list=concept_list_curated, 
                                                             prompt_info=variable_dict[dataset_name]["prompt_info"],
                                                             prompt_info_vanilla=variable_dict[dataset_name]["prompt_info_vanilla"],
                                                             temp=temp, 
                                                             concept_reference_dict=concept_reference_dict,
                                                             random_seed=random_seed,
                                                             with_validation=with_validation,
                                                             num_ref_concepts=num_ref_concepts, 
                                                             alpha=alpha)                                     


                                            record_all_list_temp.append({"task": task,
                                                                    "num_sample": len(y_select_train)+len(y_select_test),
                                                                    "num_sample_pos": y_select_train.sum()+y_select_test.sum(),
                                                                    "num_sample_train": len(y_select_train),
                                                                    "num_sample_train_pos": y_select_train.sum(),
                                                                    "num_sample_test": len(y_select_test),
                                                                    "num_sample_test_pos": y_select_test.sum(),                                                      
                                                                    "random_seed": random_seed,                                            
                                                                    "method": method+"_"+concept_list_type+"_"+trained+"_"+test_mode,
                                                                    "auc": auc,
                                                                    "auc_valid":auc_valid,
                                                                    "num_ref_concepts": num_ref_concepts,
                                                                    "y_test":y_select_test,
                                                                    "y_test_pred":y_test_pred,  
                                                                    "alpha": alpha,
                                                                    "temp": temp,                                                                         
                                                                    "clf":clf
                                                                   })  
                                            print([(key, str(value)[:20]) for key, value in record_all_list_temp[-1].items()])

                                    elif test_mode=="less_sample":
                                        # Train on a random fraction of the training samples.
                                        for sample_prop in [0.1, 0.2, 0.4, 0.6, 0.8, 1]:

                                            # Positional indices of the kept training rows
                                            # (count is truncated to an integer).
                                            less_sample_idx=np.random.choice(np.arange(len(image_norm_select_train)), 
                                                                           size=int(len(image_norm_select_train)*sample_prop), replace=False)                                    

                                                
                                            # Skip subsets whose labels sum to zero
                                            # (presumably: no positive sample drawn - assumes 0/1 labels).
                                            if y_select_train[less_sample_idx].sum()==0:
                                                continue
                                                
                                            clf, auc, y_test_pred, auc_valid = train_with_best_temp_softmax(
                                                            image_features_train_norm=image_norm_select_train[less_sample_idx],
                                                             image_features_test_norm=image_norm_select_test, 
                                                             y_train=y_select_train[less_sample_idx], 
                                                             y_test=y_select_test, 
                                                             concept_list=concept_list_curated, 
                                                             prompt_info=variable_dict[dataset_name]["prompt_info"],
                                                             prompt_info_vanilla=variable_dict[dataset_name]["prompt_info_vanilla"],
                                                             temp=temp, 
                                                             concept_reference_dict=concept_reference_dict,                                        
                                                             random_seed=random_seed,
                                                             with_validation=with_validation,                                        
                                                             num_ref_concepts=100, 
                                                             alpha=alpha) 

                                            # Sample counts in the record refer to the subsampled training set.
                                            record_all_list_temp.append({"task": task,
                                                                    "num_sample": len(y_select_train[less_sample_idx])+len(y_select_test),
                                                                    "num_sample_pos": y_select_train[less_sample_idx].sum()+y_select_test.sum(),
                                                                    "num_sample_train": len(y_select_train[less_sample_idx]),
                                                                    "num_sample_train_pos": y_select_train[less_sample_idx].sum(),
                                                                    "num_sample_test": len(y_select_test),
                                                                    "num_sample_test_pos": y_select_test.sum(),                                                      
                                                                    "random_seed": random_seed,                                            
                                                                    "method": method+"_"+concept_list_type+"_"+trained+"_"+test_mode,
                                                                    "auc": auc,
                                                                    "auc_valid": auc_valid,
                                                                    "sample_prop": sample_prop,
                                                                    "y_test":y_select_test,
                                                                    "y_test_pred":y_test_pred, 
                                                                    "alpha": alpha,
                                                                    "temp": temp,                                                                         
                                                                    "clf":clf
                                                                   })
                                            print([(key, str(value)[:20]) for key, value in record_all_list_temp[-1].items()])
        torch.save(record_all_list_temp, 
               f"logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_{task}_{random_seed}.pt")

In [None]:
cbm_isic_nodup_nooverlap_all_1007_malignant_4.pt

In [None]:
# Write the accumulated result records to disk; the file name is suffixed with
# `random_seed_set`, which must already be defined in the session.
torch.save(record_all_list_temp, 
           f"logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_{random_seed_set}.pt")

In [None]:
# Run a single seed: range(k, k+1) contains only `random_seed_set` itself.
random_seed_set=2
random_seed_range=range(random_seed_set,random_seed_set+1)

# Start from an empty list of result records.
record_all_list_temp=[]

In [None]:
f"logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_{random_seed_set}.pt"

In [None]:
record_all_list_temp

In [None]:
np.random.choice(np.arange(1,10+1))

# recover confounder

In [None]:
variable_dict[dataset_name]["prompt_info"].keys()

In [None]:
# Confounder-recovery experiment on the ISIC data with a fixed
# institution-based train/test split.
#
# For every seed a random subset of the base concepts is drawn and two
# classifiers are trained with `train_with_best_temp_softmax`:
#   "without_concept": the sampled base concepts only
#   "with_concept":    the sampled base concepts followed by `concept_name`
# One dict per trained model is appended to `record_all_list_temp`.
dataset_name="isic_nodup_nooverlap"
method_list=["automatic"]
manual_annotation_concepts=skincon_cols

# Preset split: element 0 selects the training samples, element 1 the test
# samples (presumably boolean masks over all samples; they are indexed with
# `valid_idx_select` below -- TODO confirm). Swap the two institution names to
# run the opposite direction.
train_test_idx_preset=[
    get_subset_index(
        "isic_nodup_nooverlap",
        variable_dict["isic_nodup_nooverlap"]["metadata_all"],
        "ViDIR Group, Department of Dermatology, Medical University of Vienna",
    ),
    get_subset_index(
        "isic_nodup_nooverlap",
        variable_dict["isic_nodup_nooverlap"]["metadata_all"],
        "Hospital Clínic de Barcelona",
    ),
]

with_validation=True
random_seed_range=range(1,1000+1)
n_epoch=20

record_all_list_temp=[]

# Pool of base concepts; a random, non-empty subset is drawn for every seed.
base_concept_pool=['cbm_Asymmetry', 'cbm_Irregular', 'cbm_Black', 'cbm_Blue', 'cbm_White', 'cbm_Brown',
                   'cbm_Erosion','cbm_Multiple Colors', 'cbm_Tiny', 'cbm_Regular']

# Alternative list of candidate concepts for the outer loop:
# for concept_name in ["red", "skincon_Ulcer", "derm7ptconcept_blue whitish veil", "skincon_Erosion", "skincon_Warty/Papillomatous", "pinkish", 
#                  "skincon_Blue", "nail", "derm7ptconcept_vascular structures", "skincon_Crust", "skincon_Pedunculated", 
#                  "derm7ptconcept_pigment network", "skincon_Xerosis", "skincon_Salmon", "skincon_White(Hypopigmentation)", "orange sticker"]:
for concept_name in ["red"]:

    # Optional default reference concepts (disabled; the training call below
    # passes concept_reference_dict=None):
#     if concept_name not in concept_reference_dict.keys():
#         concept_reference_dict[concept_name]=['cbm_clean',
#                                               'cbm_smooth',
#                                               'cbm_Healthy',
#                                               'cbm_normal',
#                                               'cbm_soft',
#                                               'cbm_flat']

    for task in ["malignant"]:
        # Restrict labels, metadata and both feature sets to the samples that
        # are valid for this task.
        valid_idx_select=variable_dict[dataset_name][f"valid_idx_{task}"]
        y_select=variable_dict[dataset_name][f"y_pos_{task}"][valid_idx_select]
        metadata_select=variable_dict[dataset_name]["metadata_all"][valid_idx_select]
        image_norm_monet_select=variable_dict[dataset_name]["image_features_norm"][valid_idx_select]
        image_norm_vanilla_select=variable_dict[dataset_name]["image_features_vanilla_norm"][valid_idx_select]

        assert len(y_select)==len(metadata_select)==len(image_norm_monet_select)==len(image_norm_vanilla_select)

        for random_seed in tqdm.tqdm(random_seed_range):

            # Draw the base concepts from a generator seeded with `random_seed`
            # so the sampled subset can be reproduced from the recorded seed.
            # A local generator is used, so the global NumPy RNG state is untouched.
            concept_rng=np.random.RandomState(random_seed)
            num_sampled=concept_rng.choice(np.arange(1,len(base_concept_pool)+1))
            concept_sampled=concept_rng.choice(base_concept_pool,
                                               size=num_sampled,
                                               replace=False).tolist()

            # The concept under study is appended last in the "with_concept" list.
            automatic_concept_info={
                "without_concept": concept_sampled,
                "with_concept": concept_sampled+[concept_name]
            }

            if train_test_idx_preset is None:
                if "clinical_fd_clean" in dataset_name:
                    # Stratify on whether a "skincon_Vesicle" value is present.
                    train_idx, test_idx = train_test_split(np.arange(len(y_select)), 
                                                           train_size=int(len(y_select)*0.8),
                                                           random_state=random_seed,
                                                           shuffle=True,
                                                           stratify=metadata_select["skincon_Vesicle"].isnull().values)
                    # Keep only test samples whose "skincon_Vesicle" value is not null.
                    test_idx=test_idx[~metadata_select.iloc[test_idx]["skincon_Vesicle"].isnull().values]
                else:
                    train_idx, test_idx = train_test_split(np.arange(len(y_select)), 
                                                           train_size=int(len(y_select)*0.8),
                                                           random_state=random_seed,
                                                           shuffle=True)
            else:
                # The preset selectors cover all samples; restrict them to the
                # valid samples before turning them into positional indices.
                assert len(variable_dict[dataset_name][f"y_pos_{task}"])==len(train_test_idx_preset[0])==len(train_test_idx_preset[1])
                train_idx=np.arange(len(y_select))[train_test_idx_preset[0][valid_idx_select]]
                test_idx=np.arange(len(y_select))[train_test_idx_preset[1][valid_idx_select]]

            metadata_select_train=metadata_select.iloc[train_idx]
            metadata_select_test=metadata_select.iloc[test_idx]

            y_select_train=np.array(y_select)[train_idx]
            y_select_test=np.array(y_select)[test_idx]

            image_norm_monet_select_train=image_norm_monet_select[train_idx]
            image_norm_monet_select_test=image_norm_monet_select[test_idx]

            image_norm_vanilla_select_train=image_norm_vanilla_select[train_idx]
            image_norm_vanilla_select_test=image_norm_vanilla_select[test_idx]

            for method in method_list:

                if method=="automatic":
                    for concept_list_type, concept_list_curated in automatic_concept_info.items():
                        for trained in ["monet"]:
                            # Pick the image features matching the encoder.
                            if trained=="monet":
                                image_norm_select_train=image_norm_monet_select_train
                                image_norm_select_test=image_norm_monet_select_test
                            elif trained=="vanilla":
                                image_norm_select_train=image_norm_vanilla_select_train
                                image_norm_select_test=image_norm_vanilla_select_test
                            else:
                                raise ValueError(trained)

                            #for test_mode in ["full", "less_concept", "less_reference", "less_sample"]:
                            for test_mode in ["full"]:
                                if test_mode=="full":
                                    for alpha in [0.001]:
                                        for temp in [0.02]:
                                            clf, auc, y_test_pred, auc_valid = train_with_best_temp_softmax(
                                                image_features_train_norm=image_norm_select_train,
                                                image_features_test_norm=image_norm_select_test,
                                                y_train=y_select_train,
                                                y_test=y_select_test,
                                                concept_list=concept_list_curated,
                                                prompt_info=variable_dict[dataset_name]["prompt_info"],
                                                prompt_info_vanilla=variable_dict[dataset_name]["prompt_info_vanilla"],
                                                temp=temp,
                                                random_seed=random_seed,
                                                concept_reference_dict=None,
                                                with_validation=with_validation,
                                                num_ref_concepts=100,
                                                alpha=alpha)

                                            record_all_list_temp.append({"task": task,
                                                                    "num_sample": len(y_select_train)+len(y_select_test),
                                                                    "num_sample_pos": y_select_train.sum()+y_select_test.sum(),
                                                                    "num_sample_train": len(y_select_train),
                                                                    "num_sample_train_pos": y_select_train.sum(),
                                                                    "num_sample_test": len(y_select_test),
                                                                    "num_sample_test_pos": y_select_test.sum(),
                                                                    "random_seed": random_seed,
                                                                    "method": method+"_"+concept_list_type+"_"+trained+"_"+test_mode,
                                                                    "concept_name":concept_name,
                                                                    "auc": auc,
                                                                    "auc_valid":auc_valid,
                                                                    "y_test":y_select_test,
                                                                    "y_test_pred":y_test_pred,
                                                                    "alpha": alpha,
                                                                    "temp": temp,
                                                                    "clf":clf
                                                                   })
                                            # Compact progress line: every value truncated to 20 characters.
                                            print([(key, str(value)[:20]) for key, value in record_all_list_temp[-1].items()])


In [None]:
import scipy

In [None]:
# Tabulate the experiment records and derive per-model coefficient summaries.
record_all_list_temp_df=pd.DataFrame(record_all_list_temp)
# Number of concepts the classifier was fitted on (= number of coefficient columns).
record_all_list_temp_df["num_concept"]=record_all_list_temp_df["clf"].map(lambda clf: clf.coef_.shape[1])
# Coefficient in the last position; for "with_concept" models the concept under
# study was appended last to the concept list (presumably the coefficient order
# follows the concept list order -- TODO confirm in train_with_best_temp_softmax).
record_all_list_temp_df["last_weight"]=record_all_list_temp_df["clf"].map(lambda clf: clf.coef_[0,-1])
# Absolute weights of the sampled base concepts only (an array per row despite
# the column name; later cells stack these arrays and take mean/std).
#   "_with_" models:    drop the last coefficient (the appended concept)
#   "_without_" models: every coefficient is a base concept, so keep them all
# Both branches previously used [0,:-1], which discarded one base-concept weight
# for the "without" models.
record_all_list_temp_df["absolute_weight_mean"]=record_all_list_temp_df.apply(
    lambda row: (np.abs(row["clf"].coef_[0,:-1])
                 if "_with_" in row["method"]
                 else np.abs(row["clf"].coef_[0,:])),
    axis=1,
)

In [None]:
scipy.stats.mstats.ttest_onesamp()

In [None]:
len(record_all_list_temp_df)

In [None]:
def _last_weight_pvalue(group):
    """Return the p-value of a one-sample t-test on the group's last coefficients.

    The test is one-sided (alternative="less") against a population mean of 0.
    """
    last_weights = np.hstack(group["clf"].map(lambda clf: clf.coef_[0, -1]))
    test_result = scipy.stats.mstats.ttest_onesamp(
        last_weights,
        popmean=0,
        alternative="less",
    )
    return test_result[1]


# One p-value per (method, concept_name) group.
print(
    record_all_list_temp_df
    .groupby(["method", "concept_name"])
    .apply(_last_weight_pvalue)
)

In [None]:
20 
pinkish,  vascular structure, skincon_Pedunculated


In [None]:
scipy.stats.mstats.ttest_onesamp

In [None]:
# Per-(method, concept) means, two blank lines, then the per-method standard deviations.
# NOTE(review): the frame also holds non-numeric columns (e.g. "clf", "y_test");
# depending on the installed pandas version these aggregations may require
# numeric_only=True -- confirm.
print(record_all_list_temp_df.groupby(["method", "concept_name"]).mean())
print()
print()
print(record_all_list_temp_df.groupby(["method"]).std())

In [None]:
# Barcelona -> Vienna
# with AUC 0.739 (±0.032) -> 0.642 (±0.131)
# without AUC 0.708 (±0.079) -> 0.687 (±0.098)
# 2.027 (±1.800) / compared to 2.084 (±2.716)

# Vienna -> Barcelona

# with AUC 0.758 (±0.070) -> 0.614 (±0.128)
# without AUC 0.728 (±0.076) -> 0.686 (±0.094)
# -3.570 (±1.800) / compared to 2.550 (±3.272)

In [None]:
record_all_list_temp_df.groupby(["method"]).apply(lambda x: np.hstack(x["absolute_weight_mean"].values).mean())

In [None]:
record_all_list_temp_df.groupby(["method"]).apply(lambda x: np.hstack(x["absolute_weight_mean"].values).mean())

In [None]:
record_all_list_temp_df.groupby(["method"]).apply(lambda x: np.hstack(x["absolute_weight_mean"].values).std())

In [None]:
# Rebuild the frame from the raw records, drop rows whose method name contains
# "less", replace missing values with the sentinel -9 and show the mean per
# (task, method).
record_all_list_temp_df=pd.DataFrame(record_all_list_temp)
record_all_list_temp_df[~record_all_list_temp_df["method"].str.contains("less")]\
.fillna(-9).groupby(["task", "method"]).mean()

In [None]:
pd.DataFrame(record_all_list_temp).fillna("null").groupby(["task", "method"]).mean()

In [None]:
# Same summary as above in spirit: rows whose method name contains "less" are
# excluded, missing values are replaced by the string "null", and the mean per
# (task, method) is displayed.
record_all_list_temp_df=pd.DataFrame(record_all_list_temp)
record_all_list_temp_df[~record_all_list_temp_df["method"].str.contains("less")].fillna("null").groupby(["task", "method"]).mean()

In [None]:
record_all_list_temp_df

In [None]:
concept_reference_dict

In [None]:
# Save `record_all_list` under <log_dir>/experiment_results; the file name
# embeds the current `dataset_name`.
torch.save(
    record_all_list, log_dir/"experiment_results"/f"cbm_{dataset_name}_0829.pt"
)

In [None]:
!ls logs/experiment_results/cbm* -trl

In [None]:
manual_annotaion_full

In [None]:
# record_all_list_new=[]

# for record in record_all_list_dict["derm7pt_derm_nodup"]:
#     record["method"]=record["method"].replace("annotaion","annotation")
#     record_all_list_new.append(record)

In [None]:
# torch.save(record_all_list_new, "logs/experiment_results/cbm_derm7pt_derm_nodup_1001.pt")

In [None]:
!ls logs/experiment_results/ -lrt

In [None]:
# Load the saved malignant-task results with file suffixes 11, 2 and 6 onto the
# CPU and concatenate them with `+` (presumably lists of record dicts); the
# combined value is only displayed, not stored.
torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_11.pt", map_location="cpu")+\
torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_2.pt", map_location="cpu")+\
torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_6.pt", map_location="cpu")

In [None]:
cbm_isic_nodup_nooverlap_all_1007_malignant_2

In [None]:
import torch

In [None]:
# (setting name, saved result file) pairs, loaded in this order.
_record_files = (
    ("clinical_fd_clean_nodup_nooverlap",
     "logs/experiment_results/cbm_clinical_fd_clean_nodup_nooverlap_1005.pt"),
    ("clinical_fd_clean_nodup_nooverlap_manualmatch",
     "logs/experiment_results/cbm_clinical_fd_clean_nodup_nooverlap_manualmatch_1005.pt"),
    ("clinical_fd_clean_nodup_nooverlap_with_validation_automatic_manual",
     "logs/experiment_results/cbm_clinical_fd_clean_nodup_nooverlap_1007_automatic_manual_with_validation.pt"),
    ("clinical_fd_clean_nodup_nooverlap_with_validation_seed2040",
     "logs/experiment_results/cbm_clinical_fd_clean_nodup_nooverlap_1008_all_methods_with_validation_seed20_40.pt"),
    ("clinical_fd_clean_nodup_nooverlap_with_validation_manualmatch_seed2040",
     "logs/experiment_results/cbm_clinical_fd_clean_nodup_nooverlap_1008_manualmatch_with_validation_seed20_40.pt"),
    ("derm7pt_derm_nodup",
     "logs/experiment_results/cbm_derm7pt_derm_nodup_1001.pt"),
    ("cbm_isic_nodup_nooverlap_all",
     "logs/experiment_results/cbm_isic_nodup_nooverlap_all_1005.pt"),
    ("cbm_isic_barcelona_vienna",
     "logs/experiment_results/cbm_isic_barcelona_vienna_1006.pt"),
    ("cbm_isic_vienna_barcelona",
     "logs/experiment_results/cbm_isic_vienna_barcelona_1006.pt"),
    ("cbm_isic_nodup_nooverlap_all_automatic_with_validation",
     "logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_automatic_with_validation.pt"),
)

# Every saved result file is loaded onto the CPU and stored under its setting name.
record_all_list_dict = {
    setting: torch.load(result_file, map_location="cpu")
    for setting, result_file in _record_files
}

In [None]:
# record_all_list_dict["cbm_isic_nodup_nooverlap_all_more"]=\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_2.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_3.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_4.pt", map_location="cpu")+\
# torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_5.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_6.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_7.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_8.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_9.pt", map_location="cpu")+\
# torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_10.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_11.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_12.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_13.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_14.pt", map_location="cpu")+\
# torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_15.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_16.pt", map_location="cpu")+\
# torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_17.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_18.pt", map_location="cpu")+\
# torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_19.pt", map_location="cpu")+\
# torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_20.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_melanoma_3.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_melanoma_4.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_melanoma_5.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_melanoma_7.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_melanoma_8.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_melanoma_10.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_melanoma_12.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_melanoma_13.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_melanoma_14.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_melanoma_15.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_melanoma_16.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_melanoma_17.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_melanoma_18.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_melanoma_19.pt", map_location="cpu")+\
# # torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_melanoma_19.pt", map_location="cpu")
# # malignant 20
# # melanoma 2, 6, 9, 11, 20

In [None]:
cbm_isic_nodup_nooverlap_all_1007_melanoma_20.pt
cbm_isic_nodup_nooverlap_all_1007_malignant_20.pt

In [None]:
cbm_isic_nodup_nooverlap_all

In [None]:
# Result files whose contents are concatenated, in this order, into the
# "cbm_isic_nodup_nooverlap_all_more" entry.
_more_result_files = (
    "logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_5.pt",
    "logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_10.pt",
    "logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_15.pt",
    "logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_17.pt",
    "logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_19.pt",
    "logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_20.pt",
    "logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_melanoma_20.pt",
)

# Load the first file, then append each following file with `+`, left to right.
_more_records = torch.load(_more_result_files[0], map_location="cpu")
for _result_file in _more_result_files[1:]:
    _more_records = _more_records + torch.load(_result_file, map_location="cpu")

record_all_list_dict["cbm_isic_nodup_nooverlap_all_more"] = _more_records
# malignant 20
# melanoma 2, 6, 9, 11, 20

In [None]:
import glob

In [None]:
def _malignant_file_number(result_path):
    """Return the integer between the last underscore and the 3-character suffix."""
    return int(result_path.rsplit("_", 1)[-1][:-3])


# Walk the malignant result files in numeric order and report, per file, how
# often each (random_seed, task) pair occurs among the "resnet" records.
_malignant_files = glob.glob("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_malignant_*.pt")
for path_name in sorted(_malignant_files, key=_malignant_file_number):
    loaded = torch.load(path_name, map_location="cpu")
    df_temp = pd.DataFrame(loaded)
    is_resnet = df_temp["method"] == "resnet"
    print(path_name, df_temp[is_resnet][["random_seed", "task"]].value_counts())

In [None]:
print(path_name, df_temp[df_temp["method"]=="resnet"][["random_seed", "task"]].value_counts())

In [None]:
df_temp=pd.DataFrame(record_all_list_dict["cbm_isic_nodup_nooverlap_all"])
print(df_temp[df_temp["method"]=="resnet"]["random_seed"].value_counts())

In [None]:
def _melanoma_file_number(result_path):
    """Return the integer between the last underscore and the 3-character suffix."""
    return int(result_path.rsplit("_", 1)[-1][:-3])


# Walk the melanoma result files in numeric order and report, per file, how
# often each random seed occurs among the "resnet" records.
_melanoma_files = glob.glob("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_melanoma_*.pt")
for path_name in sorted(_melanoma_files, key=_melanoma_file_number):
    loaded = torch.load(path_name, map_location="cpu")
    df_temp = pd.DataFrame(loaded)
    is_resnet = df_temp["method"] == "resnet"
    print(path_name, df_temp[is_resnet]["random_seed"].value_counts())

In [None]:
record_all_list_dict["clinical_fd_clean_nodup_nooverlap_with_validation_automatic_manual_seed2040"]=torch.load("logs/experiment_results/cbm_clinical_fd_clean_nodup_nooverlap_1007_automatic_manual_with_validation_seed20_40.pt", map_location="cpu")



In [None]:
setting_name

In [None]:
len(record_all_list_temp)

In [None]:
# Copy every record whose method name contains "resnet" from the manual-match
# results into `record_all_list_temp`. The list is extended in place, so
# re-running this cell appends the same records again.
for record in record_all_list_dict["clinical_fd_clean_nodup_nooverlap_manualmatch"]:
    if "resnet" in record["method"]:
        record_all_list_temp.append(record)

In [None]:
len(record_all_list_temp)

In [None]:
# torch.save(record_all_list_temp, 
#            "logs/experiment_results/cbm_clinical_fd_clean_nodup_nooverlap_manualmatch_1005.pt")

In [None]:
df_temp=pd.DataFrame(record_all_list_dict["cbm_isic_nodup_nooverlap_all"])

In [None]:
# Load one saved result file into a frame and display the number of records per
# (task, method, random_seed) combination.
df_temp=pd.DataFrame(
    torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_melanoma_4.pt", map_location="cpu")
    #record_all_list_dict["cbm_isic_nodup_nooverlap_all_more"]

)
df_temp.groupby(["task","method", "random_seed"]).apply(len)#[["task","method", "random_seed"]].value_counts()

In [None]:
# melanoma 2, 6, 9, 11, 20

In [None]:
cbm_isic_nodup_nooverlap_all_automatic_with_validation

In [None]:
# Frame of all records for the "clinical_fd_clean_nodup_nooverlap" setting.
record_all_list_df=pd.DataFrame(record_all_list_dict["clinical_fd_clean_nodup_nooverlap"])

# Median per (task, method, num_concept, num_sample_train). Missing values are
# replaced by the string "null" first, presumably so that rows with a missing
# grouping key are kept as their own group -- confirm.
record_all_list_df_grouped=record_all_list_df.fillna("null").groupby(["task", "method", "num_concept", "num_sample_train"]).median()

In [None]:
# Turn the group keys back into ordinary columns, then display only the rows
# whose method name contains "full".
record_all_list_df_grouped_resetindex=record_all_list_df_grouped.reset_index()
record_all_list_df_grouped_resetindex[record_all_list_df_grouped_resetindex["method"].str.contains("full")]

In [None]:
record_all_list_df_grouped_resetindex[~record_all_list_df_grouped_resetindex["method"].str.contains("less")]

In [None]:
pd.set_option('display.max_rows', 500)

In [None]:
record_all_list_df_grouped

In [None]:
variable_dict[dataset_name].keys()

In [None]:
pd.DataFrame(record_all_list).groupby(["task", "is_clean", "method"]).median()

In [None]:
# malignancy
# Use negative 0.811498
# Use template 0.812509
# No reference 0.804630

# melanoma
# Use negative 0.907219
# Use template 0.877861
# No reference 0.832318

In [None]:
# Print one line per entry of `concept_dict`: the key, a colon, and all but the
# first element of its value joined with ", ".
for key, value in concept_dict.items():
    print(key+":",", ".join(value[1:]))

In [None]:
record_all_df=pd.DataFrame(record_all_list)

In [None]:
record_all_df.groupby(["task", "is_clean", "method"]).mean()

In [None]:
#record_all_df[record_all_df["method"].str.contains("skincon_manual")].groupby(["task", "is_clean", "method", "num_concept"]).mean()
#record_all_df[record_all_df["method"].str.contains("skincon_manual")].groupby(["task", "is_clean", "method", "num_sample_train"]).mean()
#record_all_df[record_all_df["method"].str.contains("automatic_skincon_monet")]
#record_all_df[record_all_df["method"].str.contains("automatic_skincon_monet")].groupby(["task", "is_clean", "method"]).mean()

# plotting

In [None]:
import scipy

In [None]:
from matplotlib import gridspec
from matplotlib.ticker import MultipleLocator, AutoMinorLocator

In [None]:
# Plot styling setup.
# Shell commands for installing extra fonts and clearing the matplotlib cache
# (kept for reference, not executed):
# !cp /System/Library/Fonts/Supplemental ~/.local/share/fonts/
# rm -fr ~/.cache/matplotlib
from matplotlib import font_manager
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from cycler import cycler
from matplotlib.ticker import MultipleLocator, AutoMinorLocator

# Font lookups; the return values are discarded here (they are only visible
# when one of these calls is the last expression of a cell).
font_manager.findSystemFonts(fontpaths=None, fontext="ttf")
font_manager.findfont("Arial") # Test with "Special Elite" too
# Use Arial as the sans-serif family for all text.
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = 'Arial'

# Legend frame: square corners, edge colour '1.0', frame alpha 0.
plt.rcParams['legend.fancybox'] = False
plt.rcParams['legend.edgecolor']='1.0'
plt.rcParams['legend.framealpha']=0

# https://github.com/dsc/colorbrewer-python/blob/master/colorbrewer.py

# The nine colours of the "Set1" scheme as (R, G, B) integer triples.
_SET1_COLORS = (
    (228, 26, 28),
    (55, 126, 184),
    (77, 175, 74),
    (152, 78, 163),
    (255, 127, 0),
    (255, 255, 51),
    (166, 86, 40),
    (247, 129, 191),
    (153, 153, 153),
)

# Set1[n] lists the first n colours as [R, G, B] lists, for n = 3..9.
# Each entry gets its own freshly created lists.
Set1 = {
    n_colors: [list(rgb) for rgb in _SET1_COLORS[:n_colors]]
    for n_colors in range(3, 10)
}

# The twelve colours of the "Paired" scheme as (R, G, B) integer triples.
_PAIRED_COLORS = (
    (166, 206, 227),
    (31, 120, 180),
    (178, 223, 138),
    (51, 160, 44),
    (251, 154, 153),
    (227, 26, 28),
    (253, 191, 111),
    (255, 127, 0),
    (202, 178, 214),
    (106, 61, 154),
    (255, 255, 153),
    (177, 89, 40),
)

# Paired[n] lists the first n colours as [R, G, B] lists, for n = 3..12.
# Every colour is a list: the first colour of Paired[3] used to be a tuple,
# unlike every other entry of this table.
Paired = {
    n_colors: [list(rgb) for rgb in _PAIRED_COLORS[:n_colors]]
    for n_colors in range(3, 13)
}

# Seven-colour qualitative palette; entries are hex strings except the fifth,
# which is an RGB tuple of floats.
color_qual_7=['#F53345',
            '#87D303',
            '#04CBCC',
            '#8650CD',
            (160/256, 95/256, 0),
            '#F5A637',              
            '#DBD783',            
             ]

# Allow up to 500 rows when displaying data frames.
pd.set_option('display.max_rows', 500)

In [None]:
def shorten_concept_name(concept_name, strict=True):
    """Strip the dataset prefix from a concept identifier for display.

    Args:
        concept_name: Concept identifier such as "cbm_Asymmetry" or
            "derm7ptconcept_pigment network".
        strict: If True (default), an identifier without a known prefix raises
            ValueError. If False, such an identifier is returned unchanged.

    Returns:
        The name without its "cbm_" or "derm7ptconcept_" prefix; for
        "derm7ptconcept_" names the first character is upper-cased.

    Raises:
        ValueError: If the prefix is unknown and ``strict`` is true.
    """
    # (prefix, whether the first remaining character is upper-cased)
    known_prefixes = (("cbm_", False), ("derm7ptconcept_", True))
    for prefix, capitalize in known_prefixes:
        if concept_name.startswith(prefix):
            short_name = concept_name[len(prefix):]
            if capitalize:
                # Slicing instead of indexing keeps an empty remainder from
                # raising IndexError.
                short_name = short_name[:1].upper() + short_name[1:]
            return short_name

    if strict:
        raise ValueError(concept_name)
    return concept_name

In [None]:
automatic_concept_info

In [None]:
def shorten_method_name(method_name):
    """Return the figure label for an internal method identifier.

    Raises:
        ValueError: If ``method_name`` has no label; the name is the
            exception argument.
    """
    # Identifiers without an entry here (e.g. "automatic_monet_full",
    # "skincon_manual", "automatic_vanilla_full") fall through to the error.
    labels = (
        ("automatic_curated_monet_full", "MONET+CBM (Curated)"),
        ("automatic_curated_vanilla_full", "CLIP+CBM (Curated)"),
        ("automatic_skincon_monet_full", "MONET+CBM (SkinCon)"),
        ("manual_annotation_full", "Manual Label"),
        ("resnet", "Supervised (ResNet-50)"),
        ("resnet_freeze_backbone", "Linear probing (ResNet-50)"),
    )
    for known_name, label in labels:
        if method_name == known_name:
            return label
    raise ValueError(method_name)

def very_shorten_method_name(method_name):
    """Return the compact figure label for an internal method identifier.

    Raises:
        ValueError: If ``method_name`` has no label; the name is the
            exception argument.
    """
    # Identifiers without an entry here (e.g. "automatic_monet_full",
    # "skincon_manual", "automatic_vanilla_full") fall through to the error.
    labels = (
        ("automatic_curated_monet_full", "MONET+CBM"),
        ("automatic_curated_vanilla_full", "CLIP+CBM"),
        ("automatic_skincon_monet_full", "MONET+CBM"),
        ("manual_annotation_full", "Manual Label"),
        ("resnet", "Supervised"),
        ("resnet_freeze_backbone", "Linear probing"),
    )
    for known_name, label in labels:
        if method_name == known_name:
            return label
    raise ValueError(method_name)

In [None]:
record_all_df_perf["method"].value_counts()

In [None]:
(method_name, record_all_df_perf_filtered_pvalue.loc[method_name][task])

In [None]:
#record_all_list=torch.load(f=log_dir/"experiment_results"/"cbm_complete_new.pt")

In [None]:
variable_dict.keys()

In [None]:
record_all_list_dict.keys()

In [None]:
clinical_fd_clean_nodup_nooverlap_manualmatch

In [None]:
def plot_cbm_results_oldv2(
    dataset_name,
    main_method_list,
    manual_annotation_concepts_name,
    automatic_concept_info,
    automatic_concept_info_plotorder,
    record_all_list,
    method_weight,
    stage_list,
    subpanel_offset,
    task_suffix,
    show_pval,
    stack_mode,
    alpha,
    temp,

):


    if "original"=="original":
        fig = plt.figure(figsize=(3*10, 3*(3 + 2.5 + 4 + 3 + 0.35*3)))

        box1 = gridspec.GridSpec(4,1,
                                 height_ratios=[3, 2.5, 4, 3],
                                 wspace=0.0,
#                                  hspace=0.35
                                 hspace=0.3
                                )

    elif stack_mode=="suggestion":  

        fig = plt.figure(figsize=(3*10, 3*(3 + 2.5 + 4 + 6 + 0.35*3)))

        box1 = gridspec.GridSpec(4,1,
                                 height_ratios=[3, 2.5, 4, 6],
                                 wspace=0.0,
                                 hspace=0.35)    


    # temp array([  nan, 0.02 , 0.01 , 0.005])
    # alpha array([0.001 , 0.0001,    nan])


    axd={}
    for idx1, stage in enumerate(["overview", "skincon", "performance", "weight"]):
        if stage=="overview":
            plot_key=stage
            ax=plt.Subplot(fig, box1[idx1])
            fig.add_subplot(ax) 
            axd[plot_key]=ax
        elif stage=="empty":
            plot_key=stage
            ax=plt.Subplot(fig, box1[idx1])
            fig.add_subplot(ax) 
            axd[plot_key]=ax        
        elif stage=="performance":
            box2 = gridspec.GridSpecFromSubplotSpec(1, 4,
                            subplot_spec=box1[idx1], 
                            width_ratios=[0.1, 1, 0.1, 1], wspace=0., hspace=0.)        
            for idx2, task in enumerate(["empty_malignant", "malignant", "empty_melanoma", "melanoma"]):
    #             elif investigation_type=="statistics":
                plot_key=f"{stage}_{task}"
                ax=plt.Subplot(fig, box2[idx2])
                fig.add_subplot(ax)
                axd[plot_key]=ax  

        elif stage=="weight":
            if "original"=="original":
                box2 = gridspec.GridSpecFromSubplotSpec(1, 4,
                                subplot_spec=box1[idx1], 
                                width_ratios=[0.15, 1, 0.15, 1], wspace=0.1, hspace=0.)        
                for idx2, task in enumerate(["empty_malignant", "malignant", "empty_melanoma", "melanoma"]):
        #             elif investigation_type=="statistics":
                    plot_key=f"{stage}_{task}"
                    ax=plt.Subplot(fig, box2[idx2])
                    fig.add_subplot(ax)
                    axd[plot_key]=ax  

            elif stack_mode=="suggestion":

                box2 = gridspec.GridSpecFromSubplotSpec(2, 2,
                                subplot_spec=box1[idx1], 
                                width_ratios=[0.07, 1], wspace=0.1, hspace=0.3)        

                print(box2)
                for idx2, task in enumerate(["empty_malignant", "malignant", "empty_melanoma", "melanoma"]):
        #             elif investigation_type=="statistics":
                    plot_key=f"{stage}_{task}"
                    ax=plt.Subplot(fig, box2[idx2])
                    fig.add_subplot(ax)
                    axd[plot_key]=ax              

        elif stage=="skincon":
            box2 = gridspec.GridSpecFromSubplotSpec(1, 8,
                            subplot_spec=box1[idx1], width_ratios=[0.15, 1, 0.15, 1,     0.27, 1, 0.15, 1], wspace=0.0, hspace=0.)        
            for idx2, variable in enumerate(["empty", "malignant_num_concept", "empty1", "malignant_num_sample", 
                                             "empty2", "melanoma_num_concept", "empty3", "melanoma_num_sample"]):
    #             elif investigation_type=="statistics":
                plot_key=f"{stage}_{variable}"
                ax=plt.Subplot(fig, box2[idx2])
                fig.add_subplot(ax)
                axd[plot_key]=ax                    




    for plot_key in axd.keys():
        if 'overview' in plot_key:
            axd[plot_key].set_xticks([])
            axd[plot_key].set_yticks([])
            for axis in ['top','bottom','left','right']:
                axd[plot_key].spines[axis].set_linewidth(0) 

        if 'empty' in plot_key:
            axd[plot_key].set_xticks([])
            axd[plot_key].set_yticks([])
            for axis in ['top','bottom','left','right']:
                axd[plot_key].spines[axis].set_linewidth(0)           

    for idx1, stage in enumerate(stage_list):
        if stage=="overview":
            plot_key=stage

            axd[plot_key].text(x=-0.0, y=1.0, transform=axd[plot_key].transAxes,
                                     s="A", fontsize=35, weight='bold')           

        elif stage=="performance":  
            
            plt.rcParams["axes.prop_cycle"]=cycler('color', [np.array(i)/256 for i in [Paired[12][1], 
#                                                                                             Paired[12][3],
                                                                                            Paired[12][5],
                                                                                            Paired[12][7],
                                                                                            Paired[12][9],
                                                                                            Paired[12][11]
                                                                                            ]])  
            
#             plt.rcParams["axes.prop_cycle"]=cycler('color', [np.array(i)/256 for i in [Paired[12][1], 
# #                                                                                             Paired[12][3],
#                                                                                        Paired[12][9],
#                                                                                             Paired[12][5],
#                                                                                             Paired[12][7],
                                                                                            
#                                                                                             Paired[12][11]
#                                                                                             ]])     


            plt.rcParams["axes.prop_cycle"]=cycler('color', [np.array(i)/256 for i in [Paired[12][1], 
                                                                                            Paired[12][3],
                                                                                            Paired[12][5],
                                                                                            Paired[12][7],
                                                                                            Paired[12][9],
                                                                                            Paired[12][11]
                                                                                            ]])  
            
            for idx2, task in enumerate(["empty_malignant", "malignant", "empty_melanoma", "melanoma"]):
    #             elif investigation_type=="statistics":
                plot_key=f"{stage}_{task}"

                if task=="malignant" or task=="melanoma":
                    record_all_df_perf=pd.DataFrame(record_all_list)
                    record_all_df_perf=record_all_df_perf[record_all_df_perf["random_seed"]<=20]

    #                 record_all_df_perf=record_all_df_perf[(record_all_df_perf["method"].str.contains("automatic")&(record_all_df_perf["alpha"]==alpha))|\
    #                                                      (~record_all_df_perf["method"].str.contains("automatic"))
    #                                                      ]
                    record_all_df_perf_filtered=record_all_df_perf.copy()
                    record_all_df_perf_filtered=record_all_df_perf_filtered[record_all_df_perf_filtered["method"].isin(main_method_list)]
    #                 print(record_all_df_perf_filtered["method"].value_counts())

    #                 anual_annotation_full            200
    #                 automatic_curated_monet_full      200
    #                 automatic_curated_vanilla_full    200
    #                 resnet                             40
    #                 resnet_freeze_backbone             40            
                    record_all_df_perf_filtered=record_all_df_perf_filtered[
                        ((record_all_df_perf_filtered["alpha"].isnull())|(record_all_df_perf_filtered["alpha"]==alpha))&
                        ((record_all_df_perf_filtered["temp"].isnull())|(record_all_df_perf_filtered["temp"]==temp))
                    ]
    #                 record_all_df_perf_filtered=record_all_df_perf_filtered[(  ~(record_all_df_perf_filtered["method"].str.contains("automatic") | record_all_df_perf_filtered["method"].str.contains("manual"))  )|\
    #                                                              (   (record_all_df_perf_filtered["method"].str.contains("automatic") | record_all_df_perf_filtered["method"].str.contains("manual"))   &(record_all_df_perf_filtered["alpha"]==alpha)&(record_all_df_perf_filtered["temp"]==temp))]
                    #dsds
    #                 record_all_df_new=pd.DataFrame(record_all_list_new)
    #                 record_all_df_filtered_new=record_all_df_new[record_all_df_new["is_clean"]=="clean_only"]
    #                 record_all_df_filtered_new=record_all_df_filtered_new[record_all_df_filtered_new["method"]=="automatic_skincon_monet_full"]
    #                 record_all_df_filtered=pd.concat([record_all_df_filtered, record_all_df_filtered_new], axis=0)
    #                 dsdsdsds
                    if task=="malignant":
    #                     dsds
                        record_all_df_perf_filtered=record_all_df_perf_filtered[record_all_df_perf_filtered["task"]=="malignant"]
                        axd[plot_key].set_ylim(0.49, 1.01)
                    elif task=="melanoma":
                        record_all_df_perf_filtered=record_all_df_perf_filtered[record_all_df_perf_filtered["task"]=="melanoma"]
                        axd[plot_key].set_ylim(0.49, 1.01)
                    else:
                        raise ValueError
    #                 sdsdsd
    #                 sdsd
                    b=sns.boxplot(x="method", y="auc", 
                                  order=main_method_list,
                                  width=0.5,
                                  linewidth=3,
                                  saturation=1.3,
                                  boxprops=dict(alpha=.9),
                                  data=record_all_df_perf_filtered, 
                                ax=axd[plot_key])


                    sns.swarmplot(x="method", y="auc", 
                                  order=main_method_list,
                                  color='black', 
                                  alpha=0.8,
                                  size=9,
                                  data=record_all_df_perf_filtered, ax=axd[plot_key])


                    record_all_df_perf_filtered_pvalue=record_all_df_perf_filtered.groupby("task")\
                    .apply(lambda x: x.groupby("method")\
                    .apply(lambda y: 
                    scipy.stats.ttest_rel(y.set_index('random_seed')["auc"], 
                    x[x["method"]=="automatic_curated_monet_full"].set_index("random_seed")["auc"], alternative="less").pvalue)).T

                    record_all_df_perf_filtered_diff=record_all_df_perf_filtered.groupby("task")\
                    .apply(lambda x: x.groupby("method")\
                    .apply(lambda y: 
                    (   ((x[x["method"]=="automatic_curated_monet_full"].set_index("random_seed")["auc"]-y.set_index('random_seed')["auc"])>0).sum(),  len((x[x["method"]=="automatic_curated_monet_full"].set_index("random_seed")["auc"]-y.set_index('random_seed')["auc"])>0))
                          )).T    
                    print(record_all_df_perf_filtered_diff)


                    if show_pval:
                        print(show_pval)
                        count=0
                        for method_name in main_method_list[::-1]:
                            print(method_name)
        #                     continue
                            if method_name=="automatic_curated_monet_full":
                                continue

                            pvalue_x1=main_method_list.index("automatic_curated_monet_full")
                            pvalue_x2=main_method_list.index(method_name)

                            pvalue_y=record_all_df_perf_filtered[record_all_df_perf_filtered["method"]=="automatic_curated_monet_full"]["auc"].max()
                            pvalue_y_=record_all_df_perf_filtered[record_all_df_perf_filtered["method"]==method_name]["auc"].max()



                            print(method_name, record_all_df_perf_filtered_pvalue.loc[method_name][task])
                            if record_all_df_perf_filtered_pvalue.loc[method_name][task]<0.001:
                                pvalue_str="***"
                            elif record_all_df_perf_filtered_pvalue.loc[method_name][task]<0.01:
                                pvalue_str="**"
                            elif record_all_df_perf_filtered_pvalue.loc[method_name][task]<0.05:
                                pvalue_str="*"                        
                            else:
                                pvalue_str="ns"  

                            axd[plot_key].text((pvalue_x2), 
                                               pvalue_y_+0.0035,
                                     s=pvalue_str, 
                                     fontsize=25,
                                     ha='center', 
                                     va='bottom', 
                                     color="k")
                            count+=1   
                            continue


                            axd[plot_key].plot([pvalue_x1, 
                                                pvalue_x1, 
                                                pvalue_x2, 
                                                pvalue_x2], 
                                               [pvalue_y+0.013+0.023*(count),
                                                pvalue_y+0.013+0.023*(count)+0.005, 
                                                pvalue_y+0.013+0.023*(count)+0.005, 
                                                pvalue_y+0.013+0.023*(count)], 
                                               lw=3, c='black')




                    print(task, record_all_df_perf_filtered.groupby("method")["auc"].apply(lambda x: {"mean": x.mean(),
                                                                          "std": x.std(),
                                                                          "q3": x.quantile(q=0.75),                                                                      
                                                                          "median": x.median(),
                                                                          "q1": x.quantile(q=0.25),
                                                                         }))

                    if task=="malignant":
                        axd[plot_key].set_title("Malignancy"+task_suffix, fontsize=30)
                    if task=="melanoma":
                        axd[plot_key].set_title("Melanoma"+task_suffix, fontsize=30)

                    for axis in ['top','bottom','left','right']:
                        axd[plot_key].spines[axis].set_linewidth(3)                

                    if idx2==1: 
                        axd[plot_key].set_ylabel('Area under the ROC curve', fontsize=25)
                    if idx2==3: 
                        axd[plot_key].set_ylabel('')                    

                    axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.05))
                    axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.01))            
                    axd[plot_key].yaxis.grid(True, which='major', linewidth=2, alpha=0.6)
                    axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
                    #axd[plot_key].yaxis.grid(True, which='minor', linewidth=2, alpha=0.1)

                    axd[plot_key].spines['right'].set_visible(False)
                    axd[plot_key].spines['top'].set_visible(False)                   
                    axd[plot_key].spines['bottom'].set_visible(False)    

                    axd[plot_key].tick_params(axis='x', which='major', labelsize=25)
                    axd[plot_key].tick_params(axis='y', which='major', labelsize=25)   

                    axd[plot_key].tick_params(
                        axis='x',          # changes apply to the x-axis
                        which='both',      # both major and minor ticks are affected
                        bottom=False,
                        labelbottom=False,      # ticks along the bottom edge are off
                        )            

                    axd[plot_key].set_xlabel(None)

                if task=="empty_malignant":
                    axd[plot_key].text(x=-0.05, y=1.02, transform=axd[plot_key].transAxes,
                                         s=["A", "B", "C", "D", "E", "F", "G", "H", "I"][subpanel_offset+5], fontsize=35, weight='bold') 

                if task=="malignant":

                    legend_elements=[Patch(facecolor=plt.rcParams["axes.prop_cycle"].by_key()['color'][method_idx], 
                                           edgecolor="black", linewidth=2, 
                                           label=shorten_method_name(method_name)) for method_idx, method_name in enumerate(main_method_list)]


                    axd[plot_key].legend(handles=legend_elements, 
                                ncol=5, 
                                handlelength=2.5,
                                handletextpad=0.4, 
                                columnspacing=1.3,
                                fontsize=23,
                                loc='lower center', bbox_to_anchor=(1., -0.1))                   

                    #axd[plot_key].set_ylabel('Concepts', fontsize=30)


                if task=="empty_melanoma":
                    axd[plot_key].text(x=-0.0, y=1.02, transform=axd[plot_key].transAxes,
                                         s=["A", "B", "C", "D", "E", "F", "G", "H", "I"][subpanel_offset+6], fontsize=35, weight='bold')  





        elif stage=="weight":  
            for idx2, task in enumerate(["empty_malignant", "malignant", "empty_melanoma", "melanoma"]):
                plot_key=f"{stage}_{task}"
                if task=="malignant" or task=="melanoma":
                    record_all_df_weight=pd.DataFrame(record_all_list)
                    record_all_df_weight=record_all_df_weight[record_all_df_weight["random_seed"]<=20]
                    record_all_df_weight_filtered=record_all_df_weight.copy()#[record_all_df_weight["is_clean"]=="clean_only"]

                    record_all_df_weight_filtered=record_all_df_weight_filtered[(record_all_df_weight_filtered["method"]==method_weight)&(record_all_df_weight_filtered["alpha"]==alpha)&(record_all_df_weight_filtered["temp"]==temp)]

    #                 record_all_df_filtered=record_all_df_filtered[record_all_df_filtered["method"].isin(main_method_list)]
    #                 record_all_df_filtered=record_all_df_filtered[(~record_all_df_filtered["method"].isin(["automatic_curated_vanilla_full", "automatic_curated_monet_full"]))|\
    #                                                              ((record_all_df_filtered["method"].isin(["automatic_curated_vanilla_full", "automatic_curated_monet_full"]))&(record_all_df_filtered["alpha"]==alpha)&(record_all_df_filtered["temp"]==temp))]                

                    #print(record_all_df_filtered["alpha"])

                    if task=="malignant":
                        record_all_df_weight_filtered=record_all_df_weight_filtered[record_all_df_weight_filtered["task"]=="malignant"]
                    elif task=="melanoma":
                        record_all_df_weight_filtered=record_all_df_weight_filtered[record_all_df_weight_filtered["task"]=="melanoma"]
                    else:
                        raise ValueError


                    print(record_all_df_weight_filtered)

                    coef_dict_list=[]
                    for clf_idx, clf in enumerate(record_all_df_weight_filtered["clf"]):
                        for concept_name, coef in zip(automatic_concept_info["curated"], clf.coef_[0]):
                            coef_dict_list.append({"concept_name": concept_name,
                                                   "coef": coef,
                                                   "clf_idx": clf_idx
                                                  })

    #                 print(coef_dict_list)
    #                 print(sdsd)
                    #pd.DataFrame(coef_dict_list).groupby("concept_name")["coef"].to_csv(log_dir/"plots"/f"main_cbm_a_{alpha:.1e}_t_{temp:.1e}_{task}.csv")
    #                 print(task, pd.DataFrame(coef_dict_list).groupby("concept_name")["coef"].mean())

                    #sdsd
                    coef_dict_list_df_barplot=pd.DataFrame(coef_dict_list)
                    coef_dict_list_df_barplot["concept_name"]=coef_dict_list_df_barplot["concept_name"].map(shorten_concept_name)

                    weight_bar=sns.barplot(y="concept_name", x="coef", 
#                                 color=np.array(Paired[12][7])/256,
                                           color=np.array([21, 156, 239])/256,
                                edgecolor='black',
                                linewidth=2,
                                width=0.5,    
                                order= list(map(shorten_concept_name,automatic_concept_info_plotorder["curated"])),
                                errwidth=5,

                                data=coef_dict_list_df_barplot, ax=axd[plot_key])




    #                 for container in weight_bar.containers:
    #                     axd[plot_key].bar_label(container)                

    #                 for p in weight_bar.patches:
    #                     _x = p.get_x() + p.get_width() / 2
    #                     _y = p.get_y() + p.get_height()
    #                     value = '{:.2f}'.format(p.get_height())
    #                     weight_bar.text(_x, _y, value, ha="center")  

                    #rint(axd[plot_key].get_yticks())

    #                 bar_labels=[]
    #                 for concept_name in concept_list_curated:
    #                     temp=pd.DataFrame(coef_dict_list)
    #                     temp=temp[temp["concept_name"]==concept_name]

    #                     if temp["coef"].mean()>3:
    #                         bar_labels.append()
    #                     else:
    #                         bar_labels.append("")



    #                 print(temp, bar_labels)
    #                 axd[plot_key].bar_label(weight_bar.containers[0], labels=bar_labels, fontsize=20)
                    #print(pd.DataFrame(coef_dict_list))
                    #dsd


                    if stack_mode=="original":

                        for p, concept_name in zip(weight_bar.patches, automatic_concept_info_plotorder["curated"]):
                            _x = p.get_x() + p.get_width() / 2
                            _y = p.get_y() + p.get_height()

                            coef_dict_list_df=pd.DataFrame(coef_dict_list)
        #                     dsdsd
                            coef_dict_list_df=coef_dict_list_df[coef_dict_list_df["concept_name"]==concept_name]

                            if coef_dict_list_df["coef"].mean()>3:
                                value=f"{coef_dict_list_df['coef'].mean():.2f} (±{1.96*coef_dict_list_df['coef'].std()/ np.sqrt(len(coef_dict_list_df)) :.2f})"
                                axd[plot_key].text(2.8, _y+0.6, value, ha="center", fontsize=20, zorder=100)

                        #value = '{:.2f}'.format(p.get_height())


    #                 for c in weight_bar.containers:
    #                     c_mean=c.datavalues.mean()
    #                     c_mean=np.round(c_mean,2)
    #                     ci=1.96*c.datavalues.std()/np.sqrt(len(c.datavalues))
    #                     ci=np.round(ci,2)
                        #axd[plot_key].bar_label(c, labels=[f"{c_mean:.2f} (±{ci})"], fontsize=20)

        #             sns.boxplot(x="method", y="auc", 
        #                         data=record_all_df_filtered, 
        #                         width=0.5,
        #                         linewidth=3,
        #                         ax=axd[plot_key])
        #             sns.swarmplot(x="method", y="auc", 
        #                           color='black', 
        #                           alpha=0.8,
        #                           size=10,
        #                           data=record_all_df_filtered, ax=axd[plot_key])


    #             if task=="empty_malignancy":
    #                 axd[plot_key].text(x=-0.0, y=1.0, transform=axd[plot_key].transAxes,
    #                                                          s="D", fontsize=35, weight='bold')
    #             if task=="empty_melanoma":
    #                 axd[plot_key].text(x=-0.0, y=1.0, transform=axd[plot_key].transAxes,
    #                                                          s="E", fontsize=35, weight='bold')

                ["A", "B", "C", "D", "E", "F", "G", "H", "I"][subpanel_offset+7]
                if task=="empty_malignant":
                    axd[plot_key].text(x=-0., y=1.05, transform=axd[plot_key].transAxes,
                                         s=["A", "B", "C", "D", "E", "F", "G", "H", "I"][subpanel_offset+7], fontsize=35, weight='bold')              

                    #axd[plot_key].set_ylabel('Concepts', fontsize=30)


                if task=="empty_melanoma":
                    axd[plot_key].text(x=-0.1, y=1.05, transform=axd[plot_key].transAxes,
                                         s=["A", "B", "C", "D", "E", "F", "G", "H", "I"][subpanel_offset+8], fontsize=35, weight='bold')    

                if task=="malignant":
                    axd[plot_key].set_title("Malignancy"+task_suffix, fontsize=30, pad=20)
                if task=="melanoma":
                    axd[plot_key].set_title("Melanoma"+task_suffix, fontsize=30, pad=20)

                if task=="malignant" or task=="melanoma":
                    if stack_mode=="original":
                        if task=="malignant":
                            axd[plot_key].set_xlim(-3,3)
                        elif task=="melanoma":
                            axd[plot_key].set_xlim(-3,3)
                        else:
                            raise ValueError 

    #                 axd[plot_key].set_xlim(-(coef_dict_list_df_barplot['coef'].max()+0.2),
    #                                        (coef_dict_list_df_barplot['coef'].max()+0.2)
    #                                       )


                    axd[plot_key].axvline(x=0, ymin=0, ymax=1, color='black', alpha=0.7, linewidth=5, zorder=-5)

                    for axis in ['top','bottom','left','right']:
                        axd[plot_key].spines[axis].set_linewidth(3)                

                    axd[plot_key].set_ylabel('Area under the curve', fontsize=25)

        #             axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.05))
        #             axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.01)) 
                    if stack_mode=="original":
                        axd[plot_key].xaxis.set_major_locator(MultipleLocator(0.5))
                        axd[plot_key].xaxis.set_minor_locator(MultipleLocator(0.1))
                    elif stack_mode=="suggestion":
                        axd[plot_key].xaxis.set_major_locator(MultipleLocator(1))
                        axd[plot_key].xaxis.set_minor_locator(MultipleLocator(0.5))                    

                    axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
                    axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)

                    axd[plot_key].yaxis.grid(True, which='major', linewidth=2, alpha=0.6)
                    #axd[plot_key].yaxis.grid(True, which='minor', linewidth=2, alpha=0.1)

                    axd[plot_key].spines['left'].set_visible(False)
                    axd[plot_key].spines['right'].set_visible(False)
                    axd[plot_key].spines['top'].set_visible(False)                   
                    #axd[plot_key].spines['bottom'].set_visible(False)    

                    axd[plot_key].tick_params(axis='x', which='major', labelsize=20)
                    axd[plot_key].tick_params(axis='y', which='major', labelsize=20)   

                    axd[plot_key].tick_params(
                        axis='x',          # changes apply to the x-axis
                        which='both',      # both major and minor ticks are affected
                        top=False,
                        labeltop=False,      # ticks along the bottom edge are off                    
                        bottom=True,
                        labelbottom=True,      # ticks along the bottom edge are off
                        )            

                    axd[plot_key].set_ylabel(None)
                    axd[plot_key].set_xlabel("Coefficients of linear model", fontsize=25, labelpad=5)
                    #axd[plot_key].xaxis.set_label_position('top') 
        elif stage=="skincon":  
            plt.rcParams["axes.prop_cycle"]=cycler('color', [np.array(i)/256 for i in [Paired[12][1], 
                                                                                       Paired[12][9],
                                                                                            Paired[12][3],
                                                                                            Paired[12][5],
                                                                                            Paired[12][7],
                                                                                            Paired[12][9],
                                                                                            Paired[12][11]
                                                                                            ]])               
            
            for idx2, variable in enumerate(["malignant_num_concept", "malignant_num_sample", 
                                             "melanoma_num_concept", "melanoma_num_sample", ]):
                plot_key=f"{stage}_{variable}"    

                record_all_df_skincon=pd.DataFrame(record_all_list)
                record_all_df_skincon=record_all_df_skincon[record_all_df_skincon["random_seed"]<20]
                record_all_df_skincon_filtered=record_all_df_skincon.copy()#[record_all_df_skincon["is_clean"]=="clean_only"]

    #             import ipdb
    #             ipdb.set_trace()            
                record_all_df_skincon_filtered=record_all_df_skincon_filtered[
                    record_all_df_skincon_filtered["method"].isin([f"automatic_{manual_annotation_concepts_name}_monet_full", "manual_annotation_less_sample", "manual_annotation_less_concept"])]
        
                record_all_df_skincon_filtered=record_all_df_skincon_filtered[
                    ((record_all_df_skincon_filtered["alpha"].isnull())|(record_all_df_skincon_filtered["alpha"]==alpha))&
                    ((record_all_df_skincon_filtered["temp"].isnull())|(record_all_df_skincon_filtered["temp"]==temp))
                ]        


                if variable.startswith("malignant"):
                    record_all_df_skincon_filtered=record_all_df_skincon_filtered[record_all_df_skincon_filtered["task"]=="malignant"]
                elif variable.startswith("melanoma"):
                    record_all_df_skincon_filtered=record_all_df_skincon_filtered[record_all_df_skincon_filtered["task"]=="melanoma"]
                else:
                    raise ValueError


                record_all_df_skincon_filtered_ref=record_all_df_skincon_filtered[record_all_df_skincon_filtered["method"]==f"automatic_{manual_annotation_concepts_name}_monet_full"]
                if variable.endswith("num_sample"):
                    record_all_df_skincon_filtered_obs=record_all_df_skincon_filtered[record_all_df_skincon_filtered["method"]=="manual_annotation_less_sample"]
                    if variable.startswith("melanoma"):
                        pass
    #                     sdsd
                    record_all_df_skincon_filtered_obs_sample_prop=record_all_df_skincon_filtered_obs.groupby("num_sample_train")["num_sample_train"].mean()
                    #record_all_df_skincon_filtered_obs["num_sample_train_pseudo"]=record_all_df_skincon_filtered_obs.apply(lambda x: record_all_df_skincon_filtered_obs_sample_prop[x["num_sample_train_select"]], axis=1)
                    #jjj
                    b=sns.lineplot(x="num_sample_train", y="auc", 
                                   color=plt.rcParams["axes.prop_cycle"].by_key()["color"][1], linewidth=4,
                                   data=record_all_df_skincon_filtered_obs, ax=axd[plot_key])
                    if b.legend_ is not None:
                        b.legend_.remove()
                    #axd[plot_key].axhline(y=ref_value, xmin=0, xmax=100000, color='red', alpha=0.7, linewidth=3, zorder=-5)


                    axd[plot_key].scatter(0, record_all_df_skincon_filtered_ref["auc"].mean(), s=300, marker='X', color=plt.rcParams["axes.prop_cycle"].by_key()["color"][0])
                    print(stage, variable, record_all_df_skincon_filtered_ref["auc"].mean())
    #                 record_all_df_skincon_filtered_ref_matched=pd.DataFrame(list(itertools.product(record_all_df_skincon_filtered_obs_sample_prop.values, 
    #                                                     record_all_df_skincon_filtered_ref["auc"].values)),
    #                             columns=["num_sample_train_pseudo", "auc"])                
    #                 b=sns.lineplot(x="num_sample_train_pseudo", y="auc", 
    #                                color=plt.rcParams["axes.prop_cycle"].by_key()["color"][0], linewidth=4, linestyle="--",
    #                                data=record_all_df_skincon_filtered_ref_matched, ax=axd[plot_key])
    #                 if b.legend_ is not None:
    #                     b.legend_.remove()
                    #axd[plot_key].set_xlim(record_all_df_skincon_filtered_obs["num_sample_train_pseudo"].min(), record_all_df_skincon_filtered_obs["num_sample_train_pseudo"].max())

                elif variable.endswith("num_concept"):
                    #sdsdd
                    record_all_df_skincon_filtered_obs=record_all_df_skincon_filtered[record_all_df_skincon_filtered["method"]=="manual_annotation_less_concept"]
                    b=sns.lineplot(x="num_concept", y="auc", 
                                   color=plt.rcParams["axes.prop_cycle"].by_key()["color"][1], linewidth=4,
                                   data=record_all_df_skincon_filtered_obs, ax=axd[plot_key])
                    if b.legend_ is not None:
                        b.legend_.remove()


                    axd[plot_key].scatter((record_all_df_skincon_filtered_ref["clf"].iloc[0].coef_).shape[1], 
                                          record_all_df_skincon_filtered_ref["auc"].mean(), s=300, marker='X', color=plt.rcParams["axes.prop_cycle"].by_key()["color"][0])
                    print(stage, variable, record_all_df_skincon_filtered_ref["auc"].mean())
                    #sds
                    #axd[plot_key].axhline(y=ref_value, xmin=0, xmax=100000, color='red', alpha=0.7, linewidth=3, zorder=-5)                    

    #                 record_all_df_skincon_filtered_ref_matched=pd.DataFrame(list(itertools.product(record_all_df_skincon_filtered_obs["num_concept"].unique(), 
    #                                                     record_all_df_skincon_filtered_ref["auc"].values)),
    #                             columns=["num_sample_train_pseudo", "auc"])                
    #                 b=sns.lineplot(x="num_sample_train_pseudo", y="auc", 
    #                                color=plt.rcParams["axes.prop_cycle"].by_key()["color"][0], linewidth=4, linestyle="--",
    #                                data=record_all_df_skincon_filtered_ref_matched, ax=axd[plot_key])                
    #                 if b.legend_ is not None:
    #                     b.legend_.remove()





    #             dsds  
    #             ref_value=record_all_df_skincon_filtered_ref[record_all_df_skincon_filtered_ref["task"]=="melanoma"]["auc"].mean()

    #             #

    #             #sns.lineplot(x="sample_prop", y="auc", hue="task", style="method", data=record_all_df_skincon_filtered, ax=axd[plot_key])
    #             dsdsd


    #             
    #             axd[plot_key].set_xlim(0, record_all_df_skincon_filtered["num_sample_train"].max())
    #                 #axd[plot_key].set_xlim(1-0.2, 11.5)

    #             record_all_df_skincon_filtered     


                axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.05))
                axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.01))            
                axd[plot_key].yaxis.grid(True, which='major', linewidth=2, alpha=0.6)
                axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)    
                axd[plot_key].tick_params(axis='y', which='major', labelsize=20)

                if variable.endswith("num_concept"):
                    if dataset_name=='clinical_fd_clean_nodup_nooverlap':
                        axd[plot_key].xaxis.set_major_locator(MultipleLocator(10))
                        axd[plot_key].xaxis.set_minor_locator(MultipleLocator(5))
                    elif dataset_name=='derm7pt_derm_nodup':
                        axd[plot_key].xaxis.set_major_locator(MultipleLocator(1))
                        axd[plot_key].xaxis.set_minor_locator(MultipleLocator(1))                    
                    else:
                        raise ValueError

                    axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
                    axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
                    axd[plot_key].tick_params(axis='x', which='major', labelsize=20)


                    axd[plot_key].set_xlabel("Num. of concepts", fontsize=25)
                    axd[plot_key].set_xlim(-0.1, 49)


    #                 axd[plot_key].tick_params(
    #                     axis='y',          # changes apply to the x-axis
    #                     which='both',      # both major and minor ticks are affected
    #                     labelleft=False)                

                elif variable=="num_reference":
                    axd[plot_key].xaxis.set_major_locator(MultipleLocator(1))
                    axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
                    axd[plot_key].tick_params(axis='x', which='major', labelsize=20)



                elif variable.endswith("num_sample"):
                    #axd[plot_key].set_xticks([0.05, 0.1 , 0.2 , 0.4 , 0.6 , 0.8 , 1.])
                    #.set_xticks([2,4,6,8,10])
                    dataset_name=='clinical_fd_clean_nodup_nooverlap'
                    #'derm7pt_derm_nodup', 
                    if dataset_name=='clinical_fd_clean_nodup_nooverlap':
                        if variable.startswith("malignant"):
                            axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
                            axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))                
                        elif variable.startswith("malignant"):
                            axd[plot_key].xaxis.set_major_locator(MultipleLocator(100))
                            axd[plot_key].xaxis.set_minor_locator(MultipleLocator(50))  
                    elif dataset_name=='derm7pt_derm_nodup':
                        if variable.startswith("malignant"):
                            axd[plot_key].xaxis.set_major_locator(MultipleLocator(200))
                            axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))                
                        elif variable.startswith("malignant"):
                            axd[plot_key].xaxis.set_major_locator(MultipleLocator(100))
                            axd[plot_key].xaxis.set_minor_locator(MultipleLocator(50))                      
                    else:
                        raise ValueError

                    axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
                    axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
                    axd[plot_key].tick_params(axis='x', which='major', labelsize=20)

                    axd[plot_key].set_xlabel("Num. of expert-labeled samples", fontsize=25)#, labelpad=-10)

                if idx2==0:
                    axd[plot_key].set_ylabel("Area under the ROC curve", fontsize=25)
                else:
                    axd[plot_key].set_ylabel(None)    


                    #axd[plot_key].tick_params(axis='x', which='major', left=False, labelleft=False)

                if variable.startswith("malignant"):
                    #axd[plot_key].set_ylim(0.61, 0.98)
                    #axd[plot_key].set_ylim(0.531, 0.881)
                    axd[plot_key].set_ylim(0.49, 1.01)
                    axd[plot_key].set_title("Malignancy"+task_suffix, fontsize=25, pad=20)
                elif variable.startswith("melanoma"):                
                    axd[plot_key].set_title("Melanoma"+task_suffix, fontsize=25, pad=20)
                    axd[plot_key].set_ylim(0.49, 1.01)

                if variable.endswith("num_sample"):
                    if dataset_name=='clinical_fd_clean_nodup_nooverlap':
                        if variable.startswith("malignant"):
                            axd[plot_key].set_xlim(left=-50)
                        elif variable.startswith("melanoma"):    
                            axd[plot_key].set_xlim(left=-10)
                    elif dataset_name=='derm7pt_derm_nodup':
                        if variable.startswith("malignant"):
                            axd[plot_key].set_xlim(left=-30)
                        elif variable.startswith("melanoma"):    
                            axd[plot_key].set_xlim(left=-30)                    

                    else:
                        raise ValueError                        

                if variable.endswith("num_concept"):
                    if dataset_name=='clinical_fd_clean_nodup_nooverlap':
                        axd[plot_key].set_xlim(left=-1)
                        axd[plot_key].set_xlim(right=50) 
                    elif dataset_name=='derm7pt_derm_nodup':
                        axd[plot_key].set_xlim(left=-0.3)
                        axd[plot_key].set_xlim(right=7.5) 
                    else:
                        raise ValueError                    


                for axis in ['top','bottom','left','right']:
                    axd[plot_key].spines[axis].set_linewidth(3)                     
                axd[plot_key].spines['right'].set_visible(False)
                axd[plot_key].spines['top'].set_visible(False) 

                axd[plot_key].text(x=-0.15, y=1.05, transform=axd[plot_key].transAxes,
                                         s=["A", "B", "C", "D", "E", "F", "G", "H", "I"][subpanel_offset+1+idx2], fontsize=35, weight='bold')  




                if idx2==1:
                    legend_elements=[Line2D([], [], color=plt.rcParams["axes.prop_cycle"].by_key()["color"][0], label="MONET+CBM (SkinCon)", linestyle='None', marker='X', markersize=20),
                                     Line2D([0], [0], color=plt.rcParams["axes.prop_cycle"].by_key()["color"][1], linewidth=10, label="Manual Label (SkinCon)")]
                    axd[plot_key].legend(handles=legend_elements, 
                                ncol=2, 
                                handlelength=3,
                                handletextpad=0.6, 
                                columnspacing=1.5,
                                fontsize=23,
                                loc='lower center', 
                                bbox_to_anchor=(1, -0.33)).set_zorder(100)              

                #record_all_df_filtered=record_all_df_filtered[record_all_df_filtered["method"]=="automatic_monet_full"]    

    # fig.savefig(log_dir/"plots"/f"main_cbm_a_{alpha:.1e}_t_{temp:.1e}.png", bbox_inches='tight')
    # fig.savefig(log_dir/"plots"/f"main_cbm_a_{alpha:.1e}_t_{temp:.1e}.jpg", bbox_inches='tight')
    # fig.savefig(log_dir/"plots"/f"main_cbm_a_{alpha:.1e}_t_{temp:.1e}.svg", bbox_inches='tight')
    # fig.savefig(log_dir/"plots"/f"main_cbm_a_{alpha:.1e}_t_{temp:.1e}.pdf", bbox_inches='tight')

    # fig.savefig(log_dir/"plots"/"main_cbm.png", bbox_inches='tight')
    # fig.savefig(log_dir/"plots"/"main_cbm.jpg", bbox_inches='tight')
    # fig.savefig(log_dir/"plots"/"main_cbm.svg", bbox_inches='tight')
    # fig.savefig(log_dir/"plots"/"main_cbm.pdf", bbox_inches='tight')
    # # plt.close(fig)
    return fig

# clinical and dermoscopic together

In [None]:
plot_cbm_results?

In [None]:
# List the keys currently available in record_all_list_dict.
record_all_list_dict.keys()

In [None]:
    # NOTE(review): these look like the entries of the `record_all_list_dict` mapping
    # (result-set key -> object returned by torch.load on CPU); a later cell indexes
    # `record_all_list_dict` with one of these keys. The enclosing
    # `record_all_list_dict = {` and closing `}` are not present in this cell —
    # confirm and restore them before running.
    "clinical_fd_clean_nodup_nooverlap": torch.load("logs/experiment_results/cbm_clinical_fd_clean_nodup_nooverlap_1005.pt", map_location="cpu"),
    
    "clinical_fd_clean_nodup_nooverlap_manualmatch": torch.load("logs/experiment_results/cbm_clinical_fd_clean_nodup_nooverlap_manualmatch_1005.pt", map_location="cpu"),
    "clinical_fd_clean_nodup_nooverlap_with_validation_automatic_manual": torch.load("logs/experiment_results/cbm_clinical_fd_clean_nodup_nooverlap_1007_automatic_manual_with_validation.pt", map_location="cpu"),    
    
    "clinical_fd_clean_nodup_nooverlap_with_validation_seed2040": torch.load("logs/experiment_results/cbm_clinical_fd_clean_nodup_nooverlap_1008_all_methods_with_validation_seed20_40.pt", map_location="cpu"),
    "clinical_fd_clean_nodup_nooverlap_with_validation_manualmatch_seed2040": torch.load("logs/experiment_results/cbm_clinical_fd_clean_nodup_nooverlap_1008_manualmatch_with_validation_seed20_40.pt", map_location="cpu"),    
    "derm7pt_derm_nodup": torch.load("logs/experiment_results/cbm_derm7pt_derm_nodup_1001.pt", map_location="cpu"),
    "cbm_isic_nodup_nooverlap_all": torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1005.pt", map_location="cpu"),
    "cbm_isic_barcelona_vienna": torch.load("logs/experiment_results/cbm_isic_barcelona_vienna_1006.pt", map_location="cpu"),
    "cbm_isic_vienna_barcelona": torch.load("logs/experiment_results/cbm_isic_vienna_barcelona_1006.pt", map_location="cpu"),
    "cbm_isic_nodup_nooverlap_all_automatic_with_validation": torch.load("logs/experiment_results/cbm_isic_nodup_nooverlap_all_1007_automatic_with_validation.pt", map_location="cpu"),


In [None]:
# Configuration for the clinical dataset figure; only the "skincon" stage is requested.
dataset_name = "clinical_fd_clean_nodup_nooverlap"
manual_annotation_concepts_name = "skincon"
task_suffix = ""
stack_mode = "suggestion"
show_pval = True
subpanel_offset = -1
stage_list = ["skincon"]
method_weight = "automatic_curated_vanilla_full"

# Methods handed to the plotting function ("manual_annotation_full" is disabled here).
main_method_list = [
    "automatic_curated_monet_full",
    "automatic_curated_vanilla_full",
    "resnet",
    "resnet_freeze_backbone",
]

# Concept columns per concept-set name.
automatic_concept_info = {
    "skincon": skincon_cols,
    "curated": [
        'cbm_Asymmetry',
        'cbm_Irregular',
        'cbm_Black',
        'cbm_Blue',
        'cbm_White',
        'cbm_Brown',
        'cbm_Erosion',
        'cbm_Multiple Colors',
        'cbm_Tiny',
        'cbm_Regular',
    ],
}

# Same concept sets, listed in the order used for plotting.
automatic_concept_info_plotorder = {
    "skincon": skincon_cols,
    "curated": [
        'cbm_Asymmetry',
        'cbm_Irregular',
        'cbm_Erosion',
        'cbm_Black',
        'cbm_Blue',
        'cbm_White',
        'cbm_Brown',
        'cbm_Multiple Colors',
        'cbm_Tiny',
        'cbm_Regular',
    ],
}

# Result records taken from the "manualmatch" (seed 20-40, with validation) entry.
record_all_list_manual = record_all_list_dict[
    "clinical_fd_clean_nodup_nooverlap_with_validation_manualmatch_seed2040"
]

# NOTE(review): record_all_list is passed as None; confirm that the "skincon" stage of
# the plot_cbm_results definition in effect reads record_all_list_manual instead.
fig = plot_cbm_results(
    dataset_name=dataset_name,
    main_method_list=main_method_list,
    manual_annotation_concepts_name=manual_annotation_concepts_name,
    automatic_concept_info=automatic_concept_info,
    automatic_concept_info_plotorder=automatic_concept_info_plotorder,
    record_all_list=None,
    record_all_list_manual=record_all_list_manual,
    method_weight=method_weight,
    stage_list=stage_list,
    subpanel_offset=subpanel_offset,
    task_suffix=task_suffix,
    show_pval=show_pval,
    stack_mode=stack_mode,
    alpha=0.001,
    temp=0.02,
    random_seed_range=list(range(21, 41)),  # seeds 21..40 inclusive
    skincon_height=3,
)
print()

# fig.savefig(log_dir/"plots"/"cbm_clinical_num_concepts_samples.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"cbm_clinical_num_concepts_samples.jpg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"cbm_clinical_num_concepts_samples.svg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"cbm_clinical_num_concepts_samples.pdf", bbox_inches='tight')

In [None]:
# List the keys currently available in record_all_list_dict.
record_all_list_dict.keys()

In [None]:
def plot_cbm_results(
    dataset_name,
    main_method_list,
    manual_annotation_concepts_name,
    automatic_concept_info,
    automatic_concept_info_plotorder,
    record_all_list,
    record_all_list_manual,
    method_weight,
    stage_list,
    subpanel_offset,
    task_suffix,
    show_pval,
    stack_mode,
    alpha,
    temp,
    random_seed_range,
    skincon_height=3,
    performance_height=4,
    swarm_size=9,
    hspace=0.35,
    improve_small_range=False,
    

):


    if "original"=="original":
        fig = plt.figure(figsize=(3*10, 3*(3 + 2 + performance_height + skincon_height + 0.35*3)))

        box1 = gridspec.GridSpec(4,1,
                                 height_ratios=[3, 2, performance_height, skincon_height],
                                 wspace=0.0,
                                 hspace=hspace
#                                  hspace=0.3
                                )

    elif stack_mode=="suggestion":  

        fig = plt.figure(figsize=(3*10, 3*(3 + 2.5 + 4 + 6 + 0.35*3)))

        box1 = gridspec.GridSpec(4,1,
                                 height_ratios=[3, 2.5, 4, 6],
                                 wspace=0.0,
                                 hspace=0.35)    


    # temp array([  nan, 0.02 , 0.01 , 0.005])
    # alpha array([0.001 , 0.0001,    nan])


    axd={}
    for idx1, stage in enumerate(["overview", "skincon", "performance", "weight"]):
        if stage=="overview":
            plot_key=stage
            ax=plt.Subplot(fig, box1[idx1])
            fig.add_subplot(ax) 
            axd[plot_key]=ax
        elif stage=="empty":
            plot_key=stage
            ax=plt.Subplot(fig, box1[idx1])
            fig.add_subplot(ax) 
            axd[plot_key]=ax        
        elif stage=="performance":
            box2 = gridspec.GridSpecFromSubplotSpec(1, 4,
                            subplot_spec=box1[idx1], 
                            width_ratios=[0.1, 1, 0.1, 1], wspace=0., hspace=0.)        
            for idx2, task in enumerate(["empty_malignant", "malignant", "empty_melanoma", "melanoma"]):
    #             elif investigation_type=="statistics":
                plot_key=f"{stage}_{task}"
                ax=plt.Subplot(fig, box2[idx2])
                fig.add_subplot(ax)
                axd[plot_key]=ax  

        elif stage=="weight":
            if "original"=="original":
                box2 = gridspec.GridSpecFromSubplotSpec(1, 4,
                                subplot_spec=box1[idx1], 
                                width_ratios=[0.15, 1, 0.15, 1], wspace=0.1, hspace=0.)        
                for idx2, task in enumerate(["empty_malignant", "malignant", "empty_melanoma", "melanoma"]):
        #             elif investigation_type=="statistics":
                    plot_key=f"{stage}_{task}"
                    ax=plt.Subplot(fig, box2[idx2])
                    fig.add_subplot(ax)
                    axd[plot_key]=ax  

            elif stack_mode=="suggestion":

                box2 = gridspec.GridSpecFromSubplotSpec(2, 2,
                                subplot_spec=box1[idx1], 
                                width_ratios=[0.07, 1], wspace=0.1, hspace=0.3)        

                print(box2)
                for idx2, task in enumerate(["empty_malignant", "malignant", "empty_melanoma", "melanoma"]):
        #             elif investigation_type=="statistics":
                    plot_key=f"{stage}_{task}"
                    ax=plt.Subplot(fig, box2[idx2])
                    fig.add_subplot(ax)
                    axd[plot_key]=ax              

        elif stage=="skincon":
            box2 = gridspec.GridSpecFromSubplotSpec(1, 8,
                            subplot_spec=box1[idx1], width_ratios=[0.15, 1, 0.15, 1,     0.27, 1, 0.15, 1], wspace=0.0, hspace=0.)        
            for idx2, variable in enumerate(["empty", "malignant_num_concept", "empty1", "malignant_num_sample", 
                                             "empty2", "melanoma_num_concept", "empty3", "melanoma_num_sample"]):
    #             elif investigation_type=="statistics":
                plot_key=f"{stage}_{variable}"
                ax=plt.Subplot(fig, box2[idx2])
                fig.add_subplot(ax)
                axd[plot_key]=ax                    




    for plot_key in axd.keys():
        if 'overview' in plot_key:
            axd[plot_key].set_xticks([])
            axd[plot_key].set_yticks([])
            for axis in ['top','bottom','left','right']:
                axd[plot_key].spines[axis].set_linewidth(0) 

        if 'empty' in plot_key:
            axd[plot_key].set_xticks([])
            axd[plot_key].set_yticks([])
            for axis in ['top','bottom','left','right']:
                axd[plot_key].spines[axis].set_linewidth(0)           

    for idx1, stage in enumerate(stage_list):
        if stage=="overview":
            plot_key=stage

            axd[plot_key].text(x=-0.0, y=1.0, transform=axd[plot_key].transAxes,
                                     s="A", fontsize=35, weight='bold')           

        elif stage=="performance":  
            
            plt.rcParams["axes.prop_cycle"]=cycler('color', [np.array(i)/256 for i in [Paired[12][1], 
#                                                                                             Paired[12][3],
                                                                                            Paired[12][5],
                                                                                            Paired[12][7],
                                                                                            Paired[12][9],
                                                                                            Paired[12][11]
                                                                                            ]])  
            
#             plt.rcParams["axes.prop_cycle"]=cycler('color', [np.array(i)/256 for i in [Paired[12][1], 
# #                                                                                             Paired[12][3],
#                                                                                        Paired[12][9],
#                                                                                             Paired[12][5],
#                                                                                             Paired[12][7],
                                                                                            
#                                                                                             Paired[12][11]
#                                                                                             ]])     


            plt.rcParams["axes.prop_cycle"]=cycler('color', [np.array(i)/256 for i in [Paired[12][1], 
                                                                                            Paired[12][3],
                                                                                            Paired[12][5],
                                                                                            Paired[12][7],
                                                                                            Paired[12][9],
                                                                                            Paired[12][11]
                                                                                            ]])  
            
            for idx2, task in enumerate(["empty_malignant", "malignant", "empty_melanoma", "melanoma"]):
    #             elif investigation_type=="statistics":
                plot_key=f"{stage}_{task}"

                if task=="malignant" or task=="melanoma":
                    record_all_df_perf=pd.DataFrame(record_all_list)
                    record_all_df_perf=record_all_df_perf[record_all_df_perf["random_seed"].isin(random_seed_range)]

    #                 record_all_df_perf=record_all_df_perf[(record_all_df_perf["method"].str.contains("automatic")&(record_all_df_perf["alpha"]==alpha))|\
    #                                                      (~record_all_df_perf["method"].str.contains("automatic"))
    #                                                      ]
                    record_all_df_perf_filtered=record_all_df_perf.copy()
                    record_all_df_perf_filtered=record_all_df_perf_filtered[record_all_df_perf_filtered["method"].isin(main_method_list)]
    #                 print(record_all_df_perf_filtered["method"].value_counts())

    #                 anual_annotation_full            200
    #                 automatic_curated_monet_full      200
    #                 automatic_curated_vanilla_full    200
    #                 resnet                             40
    #                 resnet_freeze_backbone             40            
                    record_all_df_perf_filtered=record_all_df_perf_filtered[
                        ((record_all_df_perf_filtered["alpha"].isnull())|(record_all_df_perf_filtered["alpha"]==alpha))&
                        ((record_all_df_perf_filtered["temp"].isnull())|(record_all_df_perf_filtered["temp"]==temp))
                    ]
    #                 record_all_df_perf_filtered=record_all_df_perf_filtered[(  ~(record_all_df_perf_filtered["method"].str.contains("automatic") | record_all_df_perf_filtered["method"].str.contains("manual"))  )|\
    #                                                              (   (record_all_df_perf_filtered["method"].str.contains("automatic") | record_all_df_perf_filtered["method"].str.contains("manual"))   &(record_all_df_perf_filtered["alpha"]==alpha)&(record_all_df_perf_filtered["temp"]==temp))]
                    #dsds
    #                 record_all_df_new=pd.DataFrame(record_all_list_new)
    #                 record_all_df_filtered_new=record_all_df_new[record_all_df_new["is_clean"]=="clean_only"]
    #                 record_all_df_filtered_new=record_all_df_filtered_new[record_all_df_filtered_new["method"]=="automatic_skincon_monet_full"]
    #                 record_all_df_filtered=pd.concat([record_all_df_filtered, record_all_df_filtered_new], axis=0)
    #                 dsdsdsds
                    if task=="malignant":
    #                     dsds
                        record_all_df_perf_filtered=record_all_df_perf_filtered[record_all_df_perf_filtered["task"]=="malignant"]
                        axd[plot_key].set_ylim(0.49, 1.01)
                    elif task=="melanoma":
                        record_all_df_perf_filtered=record_all_df_perf_filtered[record_all_df_perf_filtered["task"]=="melanoma"]
                        axd[plot_key].set_ylim(0.49, 1.01)
                    else:
                        raise ValueError
    #                 sdsdsd
    #                 sdsd
                    if improve_small_range:
                        b=sns.boxplot(x="method", y="auc", 
                                      order=main_method_list,
                                      width=0.7,
                                      linewidth=1.5,
                                      saturation=1.4,
                                      boxprops=dict(alpha=.9, linewidth=0.8),
                                      data=record_all_df_perf_filtered, 
                                      showfliers=False,
                                    ax=axd[plot_key])  
                        #import ipdb
                        #ipdb.set_trace()
                    else:
                        b=sns.boxplot(x="method", y="auc", 
                                      order=main_method_list,
                                      width=0.65,
                                      linewidth=2,
                                      saturation=1.4,
                                      boxprops=dict(alpha=.82, linewidth=0.8),
                                      data=record_all_df_perf_filtered, 
                                      showfliers=False,
                                    ax=axd[plot_key])

                    #print(record_all_df_perf_filtered)
                    sns.swarmplot(x="method", y="auc", 
                                  order=main_method_list,
                                  color='black', 
                                  alpha=0.8,
                                  size=swarm_size,
                                  data=record_all_df_perf_filtered, ax=axd[plot_key])


                    record_all_df_perf_filtered_pvalue=record_all_df_perf_filtered.groupby("task")\
                    .apply(lambda x: x.groupby("method")\
                    .apply(lambda y: 
                    scipy.stats.ttest_rel(y.set_index('random_seed')["auc"], 
                    x[x["method"]=="automatic_curated_monet_full"].set_index("random_seed")["auc"], alternative="less").pvalue)).T
#                     import ipdb
#                     ipdb.set_trace()


                    

                    
                    record_all_df_perf_filtered_diff=record_all_df_perf_filtered.groupby("task")\
                    .apply(lambda x: x.groupby("method")\
                    .apply(lambda y: 
                    (   ((x[x["method"]=="automatic_curated_monet_full"].set_index("random_seed")["auc"]-y.set_index('random_seed')["auc"])>0).sum(),  len((x[x["method"]=="automatic_curated_monet_full"].set_index("random_seed")["auc"]-y.set_index('random_seed')["auc"])>0))
                          )).T   
                    print(record_all_df_perf_filtered_pvalue)
                    def val_to_exp(val):
                        if np.isnan(val):
                            return val
                        mantissa, exponent=f"{val:.3e}".split('e')
                        pvalue_str_new = f"{mantissa} × 10^{{{int(exponent)}}}"
                        return pvalue_str_new
                    print(record_all_df_perf_filtered_pvalue.applymap(val_to_exp))
                    print(record_all_df_perf_filtered_diff)

                    print(record_all_df_perf_filtered.groupby("task").apply(lambda x: x.groupby("random_seed").apply(lambda x: x.sort_values("auc").iloc[-1]["method"]).value_counts()).T)
                    
                    
                    if improve_small_range:
                        if True:
                            for method_name in main_method_list[::-1]:

                                x_idx=main_method_list.index(method_name)
                                y_loc=record_all_df_perf_filtered[record_all_df_perf_filtered["method"]==method_name]["auc"].max()

                                text_temp=axd[plot_key].text((x_idx), 
                                                   y_loc+0.015,
                                         s=very_shorten_method_name(method_name), 
                                                   
                                         fontsize=22,
                                         ha='center', 
                                         va='bottom', 
                                         color="k")   
                                
                                text_temp.set_bbox(dict(facecolor='white', alpha=0.5, edgecolor="white"))
                            
                        
                    if show_pval:
                        # Write the p-value of each non-reference method (vs. the reference
                        # "automatic_curated_monet_full") above that method's column.
                        # The former significance-bracket drawing sat behind an unconditional
                        # `continue` and never ran; it and the variables that only fed it
                        # were removed.
                        for method_name in main_method_list[::-1]:
                            if method_name=="automatic_curated_monet_full":
                                continue

                            # x position = column of the method, y anchor = its highest AUC.
                            pvalue_x2=main_method_list.index(method_name)
                            pvalue_y_=record_all_df_perf_filtered[record_all_df_perf_filtered["method"]==method_name]["auc"].max()

                            # Look the p-value up once (rows = method, columns = task).
                            method_pvalue=record_all_df_perf_filtered_pvalue.loc[method_name][task]
                            print(method_name, method_pvalue)

                            # Star notation is only printed for inspection; the figure shows
                            # the numeric p-value instead.
                            if method_pvalue<0.001:
                                pvalue_str="***"
                            elif method_pvalue<0.01:
                                pvalue_str="**"
                            elif method_pvalue<0.05:
                                pvalue_str="*"
                            else:
                                pvalue_str="ns"

                            mantissa, exponent=f"{method_pvalue:.3e}".split('e')
                            pvalue_str_new = f"$p={mantissa} × 10^{{{int(exponent)}}}$"
                            print(pvalue_x2, pvalue_y_, pvalue_str)
                            axd[plot_key].text((pvalue_x2), 
                                               pvalue_y_+0.0665,
                                     s=pvalue_str_new, 
                                     fontsize=20,
                                     ha='center', 
                                     va='bottom', 
                                     color="k")


#                     dsdsdsds

                    print(task, record_all_df_perf_filtered.groupby("method")["auc"].apply(lambda x: {"mean": x.mean(),
                                                                          "std": x.std(),
                                                                          "max": x.max(),
                                                                          "q3": x.quantile(q=0.75),                                                                      
                                                                          "median": x.median(),
                                                                          "q1": x.quantile(q=0.25),
                                                                          "min": x.min(),            
#                                                                           "conf":scipy.stats.t.interval(0.95, len(x)-1, loc=x.mean(), scale=x.std()),
                                                                          "conf2":(x.mean()-1.96*x.std()/np.sqrt(len(x)), x.mean()+1.96*x.std()/np.sqrt(len(x)))
                                                                         }))

                    if task=="malignant":
                        axd[plot_key].set_title("Malignancy"+task_suffix, fontsize=30)
                    if task=="melanoma":
                        axd[plot_key].set_title("Melanoma"+task_suffix, fontsize=30)

                    for axis in ['top','bottom','left','right']:
                        axd[plot_key].spines[axis].set_linewidth(3)                

                    if idx2==1: 
                        axd[plot_key].set_ylabel('Area under the ROC curve', fontsize=25)
                    if idx2==3: 
                        axd[plot_key].set_ylabel('')                    

                    axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.05))
                    axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.01))            
                    axd[plot_key].yaxis.grid(True, which='major', linewidth=2, alpha=0.6)
                    axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
                    #axd[plot_key].yaxis.grid(True, which='minor', linewidth=2, alpha=0.1)

                    axd[plot_key].spines['right'].set_visible(False)
                    axd[plot_key].spines['top'].set_visible(False)                   
                    axd[plot_key].spines['bottom'].set_visible(False)    

                    axd[plot_key].tick_params(axis='x', which='major', labelsize=25)
                    axd[plot_key].tick_params(axis='y', which='major', labelsize=25)   

                    axd[plot_key].tick_params(
                        axis='x',          # changes apply to the x-axis
                        which='both',      # both major and minor ticks are affected
                        bottom=False,
                        labelbottom=False,      # ticks along the bottom edge are off
                        )            

                    axd[plot_key].set_xlabel(None)

                if task=="empty_malignant":
                    axd[plot_key].text(x=-0.05, y=1.015, transform=axd[plot_key].transAxes,
                                         s=["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M"][subpanel_offset+5], fontsize=35, weight='bold') 

                if task=="malignant":

                    legend_elements=[Patch(facecolor=plt.rcParams["axes.prop_cycle"].by_key()['color'][method_idx], 
                                           edgecolor="black", linewidth=2, alpha=0.95,
                                           label=shorten_method_name(method_name)) for method_idx, method_name in enumerate(main_method_list)]


                    axd[plot_key].legend(handles=legend_elements, 
                                ncol=5, 
                                handlelength=2.5,
                                handletextpad=0.4, 
                                columnspacing=1.3,
                                fontsize=23,
                                loc='lower center', bbox_to_anchor=(1., -0.12))                   

                    #axd[plot_key].set_ylabel('Concepts', fontsize=30)


                if task=="empty_melanoma":
                    axd[plot_key].text(x=-0.0, y=1.015, transform=axd[plot_key].transAxes,
                                         s=["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M"][subpanel_offset+6], fontsize=35, weight='bold')  





        elif stage=="weight":  
            for idx2, task in enumerate(["empty_malignant", "malignant", "empty_melanoma", "melanoma"]):
                plot_key=f"{stage}_{task}"
                if task=="malignant" or task=="melanoma":
                    # Collect all recorded runs and keep the requested random seeds.
                    record_all_df_weight=pd.DataFrame(record_all_list)
                    seed_selected=record_all_df_weight["random_seed"].isin(random_seed_range)
                    record_all_df_weight=record_all_df_weight[seed_selected]

                    # Narrow down to the plotted method and its hyper-parameters.
                    record_all_df_weight_filtered=record_all_df_weight.copy()
                    run_selected=(
                        (record_all_df_weight_filtered["method"]==method_weight)
                        &(record_all_df_weight_filtered["alpha"]==alpha)
                        &(record_all_df_weight_filtered["temp"]==temp)
                    )
                    record_all_df_weight_filtered=record_all_df_weight_filtered[run_selected]

                    # Only the two real tasks have records; anything else is a programming error.
                    if task not in ("malignant", "melanoma"):
                        raise ValueError
                    task_selected=record_all_df_weight_filtered["task"]==task
                    record_all_df_weight_filtered=record_all_df_weight_filtered[task_selected]

                    print(record_all_df_weight_filtered)

                    # Long-format table: one record per (classifier, concept) holding the
                    # coefficient from the first row of the classifier's coef_.
                    coef_dict_list=[]
                    for clf_idx, clf in enumerate(record_all_df_weight_filtered["clf"]):
                        coef_dict_list.extend(
                            {"concept_name": concept_name, "coef": coef, "clf_idx": clf_idx}
                            for concept_name, coef in zip(automatic_concept_info[method_weight.split('_')[1]], clf.coef_[0])
                        )

    #                 print(coef_dict_list)
    #                 print(sdsd)
                    #pd.DataFrame(coef_dict_list).groupby("concept_name")["coef"].to_csv(log_dir/"plots"/f"main_cbm_a_{alpha:.1e}_t_{temp:.1e}_{task}.csv")
    #                 print(task, pd.DataFrame(coef_dict_list).groupby("concept_name")["coef"].mean())

                    #sdsd
                    coef_dict_list_df_barplot=pd.DataFrame(coef_dict_list)
                    coef_dict_list_df_barplot["concept_name"]=coef_dict_list_df_barplot["concept_name"].map(shorten_concept_name)

                
                    errwidth=3
                    weight_bar=sns.barplot(y="concept_name", x="coef", 
#                                 color=np.array(Paired[12][7])/256,
                                           color=np.array([21, 156, 239])/256,
                                           alpha=0.9,
                                edgecolor='black',
                                linewidth=1,#2,
                                width=0.8,    
                                order= list(map(shorten_concept_name,automatic_concept_info_plotorder[method_weight.split('_')[1]])),
                                #err_kws={"zorder":100, "linewidth": 3, "color": 0.1,},
                                errwidth=errwidth,
                                errcolor=(0,0,0,1),
                                capsize=.4,
                                data=coef_dict_list_df_barplot, ax=axd[plot_key])
                    
                    for i in weight_bar.get_lines():
                        if i.get_linewidth()==errwidth:
                            #print(i.get_xdata(), i.get_ydata()[1]-i.get_ydata()[0], )
                            i.set_zorder(100)
                    
                    
#                     sns.stripplot(y="concept_name", 
#                                   x="coef",
#                                   order=list(map(shorten_concept_name,automatic_concept_info_plotorder[method_weight.split('_')[1]])),
#                                   color=(1,1,1,0),#"black",
#                                   #edgecolor=(0,0,0,1),
#                                   edgecolor=(0,0,0,1),
#                                   linewidth=0.8,
#                                   #alpha=0.3,
#                                   #zorder=-100,
#                                   size=7,
#                                   jitter=0.0, #0.3,
#                                   data=coef_dict_list_df_barplot, ax=axd[plot_key])

                    sns.stripplot(y="concept_name", 
                                  x="coef",
                                  order=list(map(shorten_concept_name,automatic_concept_info_plotorder[method_weight.split('_')[1]])),
                                  color=(1,1,1,0),#"black",
                                  #edgecolor=(0,0,0,1),
                                  edgecolor=(0,0,0,1),
                                  linewidth=0.8,
                                  #alpha=0.3,
                                  #zorder=-100,
                                  size=5.5,
                                  jitter=0.15, #0.3,
                                  data=coef_dict_list_df_barplot, ax=axd[plot_key])




    #                 for container in weight_bar.containers:
    #                     axd[plot_key].bar_label(container)                

    #                 for p in weight_bar.patches:
    #                     _x = p.get_x() + p.get_width() / 2
    #                     _y = p.get_y() + p.get_height()
    #                     value = '{:.2f}'.format(p.get_height())
    #                     weight_bar.text(_x, _y, value, ha="center")  

                    #rint(axd[plot_key].get_yticks())

    #                 bar_labels=[]
    #                 for concept_name in concept_list_curated:
    #                     temp=pd.DataFrame(coef_dict_list)
    #                     temp=temp[temp["concept_name"]==concept_name]

    #                     if temp["coef"].mean()>3:
    #                         bar_labels.append()
    #                     else:
    #                         bar_labels.append("")



    #                 print(temp, bar_labels)
    #                 axd[plot_key].bar_label(weight_bar.containers[0], labels=bar_labels, fontsize=20)
                    #print(pd.DataFrame(coef_dict_list))
                    #dsd


                    if stack_mode=="original":
                        # Label bars whose mean coefficient is above 3 with "mean (±95% CI)",
                        # using 1.96 * std / sqrt(n) as the half-width.
                        # The table is built once instead of once per bar.
                        coef_dict_list_df_all=pd.DataFrame(coef_dict_list)

                        # The bars were drawn in the plot order of the concept set selected by
                        # method_weight (the same key as the barplot `order=`), so the patches
                        # are paired with that list rather than a hard-coded "curated" one.
                        for p, concept_name in zip(weight_bar.patches, automatic_concept_info_plotorder[method_weight.split('_')[1]]):
                            _y = p.get_y() + p.get_height()

                            coef_dict_list_df=coef_dict_list_df_all[coef_dict_list_df_all["concept_name"]==concept_name]
                            coef_mean=coef_dict_list_df["coef"].mean()

                            if coef_mean>3:
                                coef_ci=1.96*coef_dict_list_df["coef"].std()/np.sqrt(len(coef_dict_list_df))
                                value=f"{coef_mean:.2f} (±{coef_ci:.2f})"
                                axd[plot_key].text(2.8, _y+0.6, value, ha="center", fontsize=20, zorder=100)

                        #value = '{:.2f}'.format(p.get_height())


    #                 for c in weight_bar.containers:
    #                     c_mean=c.datavalues.mean()
    #                     c_mean=np.round(c_mean,2)
    #                     ci=1.96*c.datavalues.std()/np.sqrt(len(c.datavalues))
    #                     ci=np.round(ci,2)
                        #axd[plot_key].bar_label(c, labels=[f"{c_mean:.2f} (±{ci})"], fontsize=20)

        #             sns.boxplot(x="method", y="auc", 
        #                         data=record_all_df_filtered, 
        #                         width=0.5,
        #                         linewidth=3,
        #                         ax=axd[plot_key])
        #             sns.swarmplot(x="method", y="auc", 
        #                           color='black', 
        #                           alpha=0.8,
        #                           size=10,
        #                           data=record_all_df_filtered, ax=axd[plot_key])


    #             if task=="empty_malignancy":
    #                 axd[plot_key].text(x=-0.0, y=1.0, transform=axd[plot_key].transAxes,
    #                                                          s="D", fontsize=35, weight='bold')
    #             if task=="empty_melanoma":
    #                 axd[plot_key].text(x=-0.0, y=1.0, transform=axd[plot_key].transAxes,
    #                                                          s="E", fontsize=35, weight='bold')

#                 ["A", "B", "C", "D", "E", "F", "G", "H", "I"][subpanel_offset+7]
                if task=="empty_malignant":
                    axd[plot_key].text(x=-0., y=1.05, transform=axd[plot_key].transAxes,
                                         s=["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M"][subpanel_offset+7], fontsize=35, weight='bold')              

                    #axd[plot_key].set_ylabel('Concepts', fontsize=30)


                if task=="empty_melanoma":
                    axd[plot_key].text(x=-0.1, y=1.05, transform=axd[plot_key].transAxes,
                                         s=["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M"][subpanel_offset+8], fontsize=35, weight='bold')    

                if task=="malignant":
                    axd[plot_key].set_title("Malignancy"+task_suffix, fontsize=30, pad=20)
                if task=="melanoma":
                    axd[plot_key].set_title("Melanoma"+task_suffix, fontsize=30, pad=20)

                if task=="malignant" or task=="melanoma":
                    # Axis styling for the coefficient bar plots.
                    # Both tasks use the same fixed coefficient range in the "original" layout.
                    if stack_mode=="original":
                        axd[plot_key].set_xlim(-3,3)

                    # Vertical reference line at a coefficient of zero, drawn behind the bars.
                    axd[plot_key].axvline(x=0, ymin=0, ymax=1, color='black', alpha=0.7, linewidth=5, zorder=-5)

                    for axis in ['top','bottom','left','right']:
                        axd[plot_key].spines[axis].set_linewidth(3)

                    # Tick spacing on the coefficient axis depends on the layout.
                    if stack_mode=="original":
                        axd[plot_key].xaxis.set_major_locator(MultipleLocator(0.5))
                        axd[plot_key].xaxis.set_minor_locator(MultipleLocator(0.1))
                    elif stack_mode=="suggestion":
                        axd[plot_key].xaxis.set_major_locator(MultipleLocator(1))
                        axd[plot_key].xaxis.set_minor_locator(MultipleLocator(0.5))

                    axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
                    axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)

                    axd[plot_key].yaxis.grid(True, which='major', linewidth=2, alpha=0.6)

                    # Only the bottom spine stays visible.
                    axd[plot_key].spines['left'].set_visible(False)
                    axd[plot_key].spines['right'].set_visible(False)
                    axd[plot_key].spines['top'].set_visible(False)

                    axd[plot_key].tick_params(axis='x', which='major', labelsize=20)
                    axd[plot_key].tick_params(axis='y', which='major', labelsize=20)

                    axd[plot_key].tick_params(
                        axis='x',          # changes apply to the x-axis
                        which='both',      # both major and minor ticks are affected
                        top=False,         # no ticks or labels along the top edge
                        labeltop=False,
                        bottom=True,       # ticks and labels along the bottom edge
                        labelbottom=True,
                        )

                    # The y axis lists concept names and carries no axis label; an earlier
                    # 'Area under the curve' label was overwritten here before it was ever
                    # drawn and has been dropped.
                    axd[plot_key].set_ylabel(None)
                    axd[plot_key].set_xlabel("Coefficients of linear model", fontsize=25, labelpad=5)
                    #axd[plot_key].xaxis.set_label_position('top') 
        elif stage=="skincon":  
            plt.rcParams["axes.prop_cycle"]=cycler('color', [np.array(i)/256 for i in [Paired[12][1], 
                                                                                       Paired[12][9],
                                                                                            Paired[12][3],
                                                                                            Paired[12][5],
                                                                                            Paired[12][7],
                                                                                            Paired[12][9],
                                                                                            Paired[12][11]
                                                                                            ]])               
            
            for idx2, variable in enumerate(["malignant_num_concept", "malignant_num_sample", 
                                             "melanoma_num_concept", "melanoma_num_sample", ]):
                plot_key=f"{stage}_{variable}"    

                record_all_df_skincon=pd.DataFrame(record_all_list_manual)
                record_all_df_skincon=record_all_df_skincon[record_all_df_skincon["random_seed"].isin(random_seed_range)]
                record_all_df_skincon_filtered=record_all_df_skincon.copy()#[record_all_df_skincon["is_clean"]=="clean_only"]

    #             import ipdb
    #             ipdb.set_trace()            
                record_all_df_skincon_filtered=record_all_df_skincon_filtered[
                    record_all_df_skincon_filtered["method"].isin([f"automatic_{manual_annotation_concepts_name}_monet_full", "manual_annotation_less_sample", "manual_annotation_less_concept"])]
        
                record_all_df_skincon_filtered=record_all_df_skincon_filtered[
                    ((record_all_df_skincon_filtered["alpha"].isnull())|(record_all_df_skincon_filtered["alpha"]==alpha))&
                    ((record_all_df_skincon_filtered["temp"].isnull())|(record_all_df_skincon_filtered["temp"]==temp))
                ]        


                if variable.startswith("malignant"):
                    record_all_df_skincon_filtered=record_all_df_skincon_filtered[record_all_df_skincon_filtered["task"]=="malignant"]
                elif variable.startswith("melanoma"):
                    record_all_df_skincon_filtered=record_all_df_skincon_filtered[record_all_df_skincon_filtered["task"]=="melanoma"]
                else:
                    raise ValueError


                record_all_df_skincon_filtered_ref=record_all_df_skincon_filtered[record_all_df_skincon_filtered["method"]==f"automatic_{manual_annotation_concepts_name}_monet_full"]
                if variable.endswith("num_sample"):
                    record_all_df_skincon_filtered_obs=record_all_df_skincon_filtered[record_all_df_skincon_filtered["method"]=="manual_annotation_less_sample"]
                    if variable.startswith("melanoma"):
                        pass
    #                     sdsd
                    record_all_df_skincon_filtered_obs_sample_prop=record_all_df_skincon_filtered_obs.groupby("num_sample_train")["num_sample_train"].mean()
                    #record_all_df_skincon_filtered_obs["num_sample_train_pseudo"]=record_all_df_skincon_filtered_obs.apply(lambda x: record_all_df_skincon_filtered_obs_sample_prop[x["num_sample_train_select"]], axis=1)
                    #jjj
                    b=sns.lineplot(x="num_sample_train", y="auc", 
                                   color=plt.rcParams["axes.prop_cycle"].by_key()["color"][1], linewidth=4,
                                   data=record_all_df_skincon_filtered_obs, ax=axd[plot_key])
                    if b.legend_ is not None:
                        b.legend_.remove()
                    #axd[plot_key].axhline(y=ref_value, xmin=0, xmax=100000, color='red', alpha=0.7, linewidth=3, zorder=-5)


                    axd[plot_key].scatter(0, record_all_df_skincon_filtered_ref["auc"].mean(), s=300, marker='X', color=plt.rcParams["axes.prop_cycle"].by_key()["color"][0])
                    print(stage, variable, record_all_df_skincon_filtered_ref["auc"].mean())
    #                 record_all_df_skincon_filtered_ref_matched=pd.DataFrame(list(itertools.product(record_all_df_skincon_filtered_obs_sample_prop.values, 
    #                                                     record_all_df_skincon_filtered_ref["auc"].values)),
    #                             columns=["num_sample_train_pseudo", "auc"])                
    #                 b=sns.lineplot(x="num_sample_train_pseudo", y="auc", 
    #                                color=plt.rcParams["axes.prop_cycle"].by_key()["color"][0], linewidth=4, linestyle="--",
    #                                data=record_all_df_skincon_filtered_ref_matched, ax=axd[plot_key])
    #                 if b.legend_ is not None:
    #                     b.legend_.remove()
                    #axd[plot_key].set_xlim(record_all_df_skincon_filtered_obs["num_sample_train_pseudo"].min(), record_all_df_skincon_filtered_obs["num_sample_train_pseudo"].max())

                elif variable.endswith("num_concept"):
                    #sdsdd
                    record_all_df_skincon_filtered_obs=record_all_df_skincon_filtered[record_all_df_skincon_filtered["method"]=="manual_annotation_less_concept"]
                    b=sns.lineplot(x="num_concept", y="auc", 
                                   color=plt.rcParams["axes.prop_cycle"].by_key()["color"][1], linewidth=4,
                                   data=record_all_df_skincon_filtered_obs, ax=axd[plot_key])
                    if b.legend_ is not None:
                        b.legend_.remove()


                    axd[plot_key].scatter((record_all_df_skincon_filtered_ref["clf"].iloc[0].coef_).shape[1], 
                                          record_all_df_skincon_filtered_ref["auc"].mean(), s=300, marker='X', color=plt.rcParams["axes.prop_cycle"].by_key()["color"][0])
                    print(stage, variable, record_all_df_skincon_filtered_ref["auc"].mean())
                    #sds
                    #axd[plot_key].axhline(y=ref_value, xmin=0, xmax=100000, color='red', alpha=0.7, linewidth=3, zorder=-5)                    

    #                 record_all_df_skincon_filtered_ref_matched=pd.DataFrame(list(itertools.product(record_all_df_skincon_filtered_obs["num_concept"].unique(), 
    #                                                     record_all_df_skincon_filtered_ref["auc"].values)),
    #                             columns=["num_sample_train_pseudo", "auc"])                
    #                 b=sns.lineplot(x="num_sample_train_pseudo", y="auc", 
    #                                color=plt.rcParams["axes.prop_cycle"].by_key()["color"][0], linewidth=4, linestyle="--",
    #                                data=record_all_df_skincon_filtered_ref_matched, ax=axd[plot_key])                
    #                 if b.legend_ is not None:
    #                     b.legend_.remove()





    #             dsds  
    #             ref_value=record_all_df_skincon_filtered_ref[record_all_df_skincon_filtered_ref["task"]=="melanoma"]["auc"].mean()

    #             #

    #             #sns.lineplot(x="sample_prop", y="auc", hue="task", style="method", data=record_all_df_skincon_filtered, ax=axd[plot_key])
    #             dsdsd


    #             
    #             axd[plot_key].set_xlim(0, record_all_df_skincon_filtered["num_sample_train"].max())
    #                 #axd[plot_key].set_xlim(1-0.2, 11.5)

    #             record_all_df_skincon_filtered     


                axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.05))
                axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.01))            
                axd[plot_key].yaxis.grid(True, which='major', linewidth=2, alpha=0.6)
                axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)    
                axd[plot_key].tick_params(axis='y', which='major', labelsize=20)

                # Configure the x axis according to the quantity shown in this panel.
                if variable.endswith("num_concept"):
                    # Tick spacing depends on the dataset; unknown datasets are rejected.
                    if dataset_name=='clinical_fd_clean_nodup_nooverlap':
                        axd[plot_key].xaxis.set_major_locator(MultipleLocator(10))
                        axd[plot_key].xaxis.set_minor_locator(MultipleLocator(5))
                    elif dataset_name=='derm7pt_derm_nodup':
                        axd[plot_key].xaxis.set_major_locator(MultipleLocator(1))
                        axd[plot_key].xaxis.set_minor_locator(MultipleLocator(1))
                    else:
                        raise ValueError

                    axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
                    axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
                    axd[plot_key].tick_params(axis='x', which='major', labelsize=20)

                    axd[plot_key].set_xlabel("Num. of concepts", fontsize=25)
                    axd[plot_key].set_xlim(-0.1, 49)

                elif variable=="num_reference":
                    axd[plot_key].xaxis.set_major_locator(MultipleLocator(1))
                    axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
                    axd[plot_key].tick_params(axis='x', which='major', labelsize=20)

                elif variable.endswith("num_sample"):
                    # Tick spacing depends on dataset and task. The melanoma branches
                    # used to repeat the "malignant" test and were unreachable, so
                    # melanoma panels silently kept matplotlib's default locators.
                    if dataset_name=='clinical_fd_clean_nodup_nooverlap':
                        if variable.startswith("malignant"):
                            axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
                            axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))
                        elif variable.startswith("melanoma"):
                            axd[plot_key].xaxis.set_major_locator(MultipleLocator(100))
                            axd[plot_key].xaxis.set_minor_locator(MultipleLocator(50))
                    elif dataset_name=='derm7pt_derm_nodup':
                        if variable.startswith("malignant"):
                            axd[plot_key].xaxis.set_major_locator(MultipleLocator(200))
                            axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))
                        elif variable.startswith("melanoma"):
                            axd[plot_key].xaxis.set_major_locator(MultipleLocator(100))
                            axd[plot_key].xaxis.set_minor_locator(MultipleLocator(50))
                    else:
                        raise ValueError

                    axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
                    axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
                    axd[plot_key].tick_params(axis='x', which='major', labelsize=20)

                    axd[plot_key].set_xlabel("Num. of expert-labeled samples", fontsize=25)#, labelpad=-10)

                if idx2==0:
                    axd[plot_key].set_ylabel("Area under the ROC curve", fontsize=25)
                else:
                    axd[plot_key].set_ylabel(None)    


                    #axd[plot_key].tick_params(axis='x', which='major', left=False, labelleft=False)

                if variable.startswith("malignant"):
                    #axd[plot_key].set_ylim(0.61, 0.98)
                    #axd[plot_key].set_ylim(0.531, 0.881)
                    axd[plot_key].set_ylim(0.49, 1.01)
                    axd[plot_key].set_title("Malignancy"+task_suffix, fontsize=25, pad=20)
                elif variable.startswith("melanoma"):                
                    axd[plot_key].set_title("Melanoma"+task_suffix, fontsize=25, pad=20)
                    axd[plot_key].set_ylim(0.49, 1.01)

                if variable.endswith("num_sample"):
                    if dataset_name=='clinical_fd_clean_nodup_nooverlap':
                        if variable.startswith("malignant"):
                            axd[plot_key].set_xlim(left=-50)
                        elif variable.startswith("melanoma"):    
                            axd[plot_key].set_xlim(left=-10)
                    elif dataset_name=='derm7pt_derm_nodup':
                        if variable.startswith("malignant"):
                            axd[plot_key].set_xlim(left=-30)
                        elif variable.startswith("melanoma"):    
                            axd[plot_key].set_xlim(left=-30)                    

                    else:
                        raise ValueError                        

                if variable.endswith("num_concept"):
                    if dataset_name=='clinical_fd_clean_nodup_nooverlap':
                        axd[plot_key].set_xlim(left=-1)
                        axd[plot_key].set_xlim(right=50) 
                    elif dataset_name=='derm7pt_derm_nodup':
                        axd[plot_key].set_xlim(left=-0.3)
                        axd[plot_key].set_xlim(right=7.5) 
                    else:
                        raise ValueError                    


                for axis in ['top','bottom','left','right']:
                    axd[plot_key].spines[axis].set_linewidth(3)                     
                axd[plot_key].spines['right'].set_visible(False)
                axd[plot_key].spines['top'].set_visible(False) 

                axd[plot_key].text(x=-0.15, y=1.05, transform=axd[plot_key].transAxes,
                                         s=["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L"][subpanel_offset+1+idx2], fontsize=35, weight='bold')  




                if idx2==1:
                    legend_elements=[Line2D([], [], color=plt.rcParams["axes.prop_cycle"].by_key()["color"][0], label="MONET+CBM (SkinCon)", linestyle='None', marker='X', markersize=20),
                                     Line2D([0], [0], color=plt.rcParams["axes.prop_cycle"].by_key()["color"][1], linewidth=10, label="Manual Label (SkinCon)")]
                    axd[plot_key].legend(handles=legend_elements, 
                                ncol=2, 
                                handlelength=3,
                                handletextpad=0.6, 
                                columnspacing=1.5,
                                fontsize=22.5,
                                loc='lower center', 
                                bbox_to_anchor=(1, -0.34)).set_zorder(100)              

                #record_all_df_filtered=record_all_df_filtered[record_all_df_filtered["method"]=="automatic_monet_full"]    

    # fig.savefig(log_dir/"plots"/f"main_cbm_a_{alpha:.1e}_t_{temp:.1e}.png", bbox_inches='tight')
    # fig.savefig(log_dir/"plots"/f"main_cbm_a_{alpha:.1e}_t_{temp:.1e}.jpg", bbox_inches='tight')
    # fig.savefig(log_dir/"plots"/f"main_cbm_a_{alpha:.1e}_t_{temp:.1e}.svg", bbox_inches='tight')
    # fig.savefig(log_dir/"plots"/f"main_cbm_a_{alpha:.1e}_t_{temp:.1e}.pdf", bbox_inches='tight')

    # fig.savefig(log_dir/"plots"/"main_cbm.png", bbox_inches='tight')
    # fig.savefig(log_dir/"plots"/"main_cbm.jpg", bbox_inches='tight')
    # fig.savefig(log_dir/"plots"/"main_cbm.svg", bbox_inches='tight')
    # fig.savefig(log_dir/"plots"/"main_cbm.pdf", bbox_inches='tight')
    # # plt.close(fig)
    return fig

In [None]:
# 4910*0.8

In [None]:
# 769*0.8

In [None]:
# pd.DataFrame(record_all_list_manual).groupby(["task", "method"]).mean()

In [None]:
# Clinical-image CBM figure: choose the methods and concept sets to compare,
# pick the experiment records, and render the requested stages with
# plot_cbm_results.
dataset_name = "clinical_fd_clean_nodup_nooverlap"

main_method_list = [
    "automatic_curated_monet_full",
    "automatic_curated_vanilla_full",
    # "manual_annotation_full",
    "resnet",
    "resnet_freeze_backbone",
]

manual_annotation_concepts_name = "skincon"

# Concept names per concept family, passed to plot_cbm_results.
automatic_concept_info = {
    "skincon": skincon_cols,
    "curated": [
        "cbm_Asymmetry", "cbm_Irregular", "cbm_Black", "cbm_Blue", "cbm_White", "cbm_Brown",
        "cbm_Erosion", "cbm_Multiple Colors", "cbm_Tiny", "cbm_Regular",
    ],
}

# Same concepts with "cbm_Erosion" moved ahead of the colour concepts; passed
# separately as the plotting order.
automatic_concept_info_plotorder = {
    "skincon": skincon_cols,
    "curated": [
        "cbm_Asymmetry", "cbm_Irregular", "cbm_Erosion", "cbm_Black", "cbm_Blue", "cbm_White", "cbm_Brown",
        "cbm_Multiple Colors", "cbm_Tiny", "cbm_Regular",
    ],
}

subpanel_offset = 0
show_pval = False
stack_mode = "suggestion"
stage_list = ["overview", "performance", "weight", "skincon"]
task_suffix = " (Clinical)"
method_weight = "automatic_curated_monet_full"

# record_all_list used to be assigned from two other record_all_list_dict
# entries and then overwritten before use; only the effective assignment is
# kept, so the cell no longer depends on those extra keys being present.
record_all_list = record_all_list_dict["clinical_fd_clean_nodup_nooverlap_with_validation_seed2040"]
record_all_list_manual = record_all_list_dict[
    "clinical_fd_clean_nodup_nooverlap_with_validation_manualmatch_seed2040"
]

# record_all_list_manual=record_all_list_dict["clinical_fd_clean_nodup_nooverlap_manualmatch"]

fig = plot_cbm_results(
    dataset_name=dataset_name,
    main_method_list=main_method_list,
    manual_annotation_concepts_name=manual_annotation_concepts_name,
    automatic_concept_info=automatic_concept_info,
    automatic_concept_info_plotorder=automatic_concept_info_plotorder,
    record_all_list=record_all_list,
    record_all_list_manual=record_all_list_manual,
    subpanel_offset=subpanel_offset,
    stage_list=stage_list,
    method_weight=method_weight,
    show_pval=show_pval,
    stack_mode=stack_mode,
    task_suffix=task_suffix,
    random_seed_range=list(range(21, 21 + 20)),
    alpha=0.001,
    temp=0.02,
    skincon_height=2,
    performance_height=3.5,
    hspace=0.3,
    swarm_size=8,
)
print()


# fig.savefig(log_dir/"plots"/"cbm_clinical_all.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"cbm_clinical_all.jpg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"cbm_clinical_all.svg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"cbm_clinical_all.pdf", bbox_inches='tight')

In [None]:
# Clinical-image CBM figure: choose the methods and concept sets to compare,
# pick the experiment records, and render the requested stages with
# plot_cbm_results.
# NOTE(review): this cell sets the same configuration as the clinical cell
# directly before it in the notebook; consider keeping only one of them.
dataset_name = "clinical_fd_clean_nodup_nooverlap"

main_method_list = [
    "automatic_curated_monet_full",
    "automatic_curated_vanilla_full",
    # "manual_annotation_full",
    "resnet",
    "resnet_freeze_backbone",
]

manual_annotation_concepts_name = "skincon"

# Concept names per concept family, passed to plot_cbm_results.
automatic_concept_info = {
    "skincon": skincon_cols,
    "curated": [
        "cbm_Asymmetry", "cbm_Irregular", "cbm_Black", "cbm_Blue", "cbm_White", "cbm_Brown",
        "cbm_Erosion", "cbm_Multiple Colors", "cbm_Tiny", "cbm_Regular",
    ],
}

# Same concepts with "cbm_Erosion" moved ahead of the colour concepts; passed
# separately as the plotting order.
automatic_concept_info_plotorder = {
    "skincon": skincon_cols,
    "curated": [
        "cbm_Asymmetry", "cbm_Irregular", "cbm_Erosion", "cbm_Black", "cbm_Blue", "cbm_White", "cbm_Brown",
        "cbm_Multiple Colors", "cbm_Tiny", "cbm_Regular",
    ],
}

subpanel_offset = 0
show_pval = False
stack_mode = "suggestion"
stage_list = ["overview", "performance", "weight", "skincon"]
task_suffix = " (Clinical)"
method_weight = "automatic_curated_monet_full"

# record_all_list used to be assigned from two other record_all_list_dict
# entries and then overwritten before use; only the effective assignment is
# kept, so the cell no longer depends on those extra keys being present.
record_all_list = record_all_list_dict["clinical_fd_clean_nodup_nooverlap_with_validation_seed2040"]
record_all_list_manual = record_all_list_dict[
    "clinical_fd_clean_nodup_nooverlap_with_validation_manualmatch_seed2040"
]

# record_all_list_manual=record_all_list_dict["clinical_fd_clean_nodup_nooverlap_manualmatch"]

fig = plot_cbm_results(
    dataset_name=dataset_name,
    main_method_list=main_method_list,
    manual_annotation_concepts_name=manual_annotation_concepts_name,
    automatic_concept_info=automatic_concept_info,
    automatic_concept_info_plotorder=automatic_concept_info_plotorder,
    record_all_list=record_all_list,
    record_all_list_manual=record_all_list_manual,
    subpanel_offset=subpanel_offset,
    stage_list=stage_list,
    method_weight=method_weight,
    show_pval=show_pval,
    stack_mode=stack_mode,
    task_suffix=task_suffix,
    random_seed_range=list(range(21, 21 + 20)),
    alpha=0.001,
    temp=0.02,
    skincon_height=2,
    performance_height=3.5,
    hspace=0.3,
    swarm_size=8,
)
print()


# fig.savefig(log_dir/"plots"/"cbm_clinical_all.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"cbm_clinical_all.jpg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"cbm_clinical_all.svg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"cbm_clinical_all.pdf", bbox_inches='tight')

In [None]:
def val_to_exp(val):
    """Format a number as ``"<mantissa> × 10^{<exponent>}"``.

    The mantissa is printed with three decimals (``.3e`` formatting) and the
    exponent is written as a plain integer without sign padding or leading
    zeros, e.g. ``0.00123 -> "1.230 × 10^{-3}"``.

    Args:
        val: A real scalar (Python or NumPy float/int).

    Returns:
        The formatted string, or ``val`` itself when it is NaN or infinite.
        Infinite values previously raised ``ValueError`` because their
        ``.3e`` rendering contains no ``e`` to split on.
    """
    if not np.isfinite(val):
        return val
    mantissa, _, exponent = f"{val:.3e}".partition("e")
    return f"{mantissa} × 10^{{{int(exponent)}}}"
# NOTE(review): `applymap` is a pandas DataFrame method, so this cell presumably
# expects `fig` to hold a DataFrame (e.g. a table of p-values) rather than a
# matplotlib figure — confirm what plot_cbm_results returned when this was run.
fig.applymap(val_to_exp)

In [None]:
fig

In [None]:
for i in fig.collections:
    print(i)

In [None]:
len(fig.collections)

In [None]:
dir(fig)

In [None]:
pd.DataFrame(record_all_list_dict["clinical_fd_clean_nodup_nooverlap_with_validation_seed2040"]).groupby(["task", "method"]).mean()

In [None]:
# Count missing vs. present "md5hash" values among the rows selected by
# "valid_idx_melanoma" in the clinical metadata table.
(
    variable_dict["clinical_fd_clean_nodup_nooverlap"]["metadata_all"]
    .loc[variable_dict["clinical_fd_clean_nodup_nooverlap"]["valid_idx_melanoma"]]["md5hash"]
    .isnull()
    .value_counts()
)

In [None]:
# Same count as for the "valid_idx_melanoma" rows, further restricted to rows
# that have a non-null "skincon_Cyst" value.
(
    variable_dict["clinical_fd_clean_nodup_nooverlap"]["metadata_all"]
    .loc[
        variable_dict["clinical_fd_clean_nodup_nooverlap"]["valid_idx_melanoma"]
        & variable_dict["clinical_fd_clean_nodup_nooverlap"]["metadata_all"]["skincon_Cyst"].notnull()
    ]["md5hash"]
    .isnull()
    .value_counts()
)

In [None]:
pd.DataFrame(record_all_list_dict["cbm_isic_barcelona_vienna"]).groupby(["task", "method"]).mean()

In [None]:
12271-6021

In [None]:
6021, 6250

In [None]:
10011-1824

In [None]:
# 1824 vs 8187

In [None]:
pd.DataFrame(record_all_list_dict["cbm_isic_vienna_barcelona"]).groupby(["task", "method"]).mean()

In [None]:
record_all_list_dict.keys()

In [None]:
    # For every task and method, run a one-sided paired t-test
    # (alternative="less") of that method's AUC against the AUC of
    # "automatic_curated_monet_full", and collect the p-values.
    # ttest_rel pairs observations by position, not by index label, so both
    # samples are sorted by random_seed to make sure equal seeds are paired.
    # The leftover `import ipdb; ipdb.set_trace()` breakpoint was removed so
    # the cell runs to completion.
    record_all_df_perf_filtered_pvalue = record_all_df_perf_filtered.groupby("task").apply(
        lambda task_df: task_df.groupby("method").apply(
            lambda method_df: scipy.stats.ttest_rel(
                method_df.set_index("random_seed")["auc"].sort_index(),
                task_df[task_df["method"] == "automatic_curated_monet_full"]
                .set_index("random_seed")["auc"]
                .sort_index(),
                alternative="less",
            ).pvalue
        )
    ).T

In [None]:
record_all_df_perf_filtered.groupby("task").apply(lambda x: x.groupby("random_seed").apply(lambda x: x.sort_values("auc").iloc[-1]["method"]).value_counts()).T





In [None]:
cbm_isic_nodup_nooverlap_all_automatic_with_validation

In [None]:
pd.DataFrame(record_all_list)

In [None]:
record_all_list_dict.keys()

In [None]:
pd.DataFrame(record_all_list)["random_seed"].value_counts()

In [None]:
record_all_list_dict.keys()

In [None]:
pd.DataFrame(record_all_list).duplicated()

In [None]:
pd.DataFrame(record_all_list)

In [None]:
pd.DataFrame(record_all_list).groupby(["task", "method", "random_seed"]).apply(len)

In [None]:
plot_cbm_results??

In [None]:
# Dermoscopic (ISIC) CBM figure: choose the methods and concept sets to
# compare, assemble the experiment records, and render the "performance" and
# "weight" stages with plot_cbm_results.
dataset_name = "isic_nodup_nooverlap"

main_method_list = [
    "automatic_curated_monet_full",
    "automatic_curated_vanilla_full",
    # "manual_annotation_full",
    "resnet",
    "resnet_freeze_backbone",
]

manual_annotation_concepts_name = "skincon"

# Concept names per concept family, passed to plot_cbm_results.
automatic_concept_info = {
    "skincon": skincon_cols,
    "derm7pt": [
        "derm7ptconcept_pigment network",
        "derm7ptconcept_regression structure",
        "derm7ptconcept_pigmentation",
        "derm7ptconcept_blue whitish veil",
        "derm7ptconcept_vascular structures",
        "derm7ptconcept_streaks",
        "derm7ptconcept_dots and globules",
    ],
    "curated": [
        "cbm_Asymmetry", "cbm_Irregular", "cbm_Black", "cbm_Blue", "cbm_White", "cbm_Brown",
        "cbm_Erosion", "cbm_Multiple Colors", "cbm_Tiny", "cbm_Regular",
    ],
}

# Same concepts; only the "curated" order differs ("cbm_Erosion" moved ahead
# of the colour concepts). Passed separately as the plotting order.
automatic_concept_info_plotorder = {
    "skincon": skincon_cols,
    "derm7pt": [
        "derm7ptconcept_pigment network",
        "derm7ptconcept_regression structure",
        "derm7ptconcept_pigmentation",
        "derm7ptconcept_blue whitish veil",
        "derm7ptconcept_vascular structures",
        "derm7ptconcept_streaks",
        "derm7ptconcept_dots and globules",
    ],
    "curated": [
        "cbm_Asymmetry", "cbm_Irregular", "cbm_Erosion", "cbm_Black", "cbm_Blue", "cbm_White", "cbm_Brown",
        "cbm_Multiple Colors", "cbm_Tiny", "cbm_Regular",
    ],
}

subpanel_offset = 4
show_pval = False
stack_mode = "suggestion"
stage_list = ["performance", "weight"]
# stage_list=["weight"]
task_suffix = " (Dermoscopic)"
method_weight = "automatic_curated_monet_full"
# method_weight="automatic_derm7pt_monet_full"

# Seeds 1..20, built once: the same list filters the validation records below
# and is passed to plot_cbm_results (it used to be rebuilt for every record
# inside the comprehension and spelled out again in the call).
isic_seed_range = list(range(1, 1 + 20))

# ResNet records come from the two "all" runs; the remaining records come from
# the validation run, restricted to the seeds above.
record_all_list = (
    [record for record in record_all_list_dict["cbm_isic_nodup_nooverlap_all"] if "resnet" in record["method"]]
    + [record for record in record_all_list_dict["cbm_isic_nodup_nooverlap_all_more"] if "resnet" in record["method"]]
    + [
        record
        for record in record_all_list_dict["cbm_isic_nodup_nooverlap_all_automatic_with_validation"]
        if record["random_seed"] in isic_seed_range
    ]
)

fig = plot_cbm_results(
    dataset_name=dataset_name,
    main_method_list=main_method_list,
    manual_annotation_concepts_name=manual_annotation_concepts_name,
    automatic_concept_info=automatic_concept_info,
    automatic_concept_info_plotorder=automatic_concept_info_plotorder,
    record_all_list=record_all_list,
    record_all_list_manual=None,
    subpanel_offset=subpanel_offset,
    stage_list=stage_list,
    method_weight=method_weight,
    show_pval=show_pval,
    stack_mode=stack_mode,
    task_suffix=task_suffix,
    alpha=0.001,
    temp=0.02,
    skincon_height=2,
    hspace=0.3,
    swarm_size=5,
    improve_small_range=True,
    random_seed_range=isic_seed_range,
)
print()

# fig.savefig(log_dir/"plots"/"cbm_dermoscopic_all_weight.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"cbm_dermoscopic_all_weight.jpg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"cbm_dermoscopic_all_weight.svg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"cbm_dermoscopic_all_weight.pdf", bbox_inches='tight')

In [None]:
fig

In [None]:
# 0.842783
#                                 std                                       0.001458
#                                 max                                       0.844596
#                                 q3                                        0.843575
#                                 median                                    0.842643
#                                 q1                                        0.841852
#                                 min                                        0.84125
#                                 conf      (0.8381433142650633, 0.847422858924

In [None]:
0.842783-1.96*0.001458/np.sqrt(19)

In [None]:
#  mean                                      0.842783
#                                 std                                       0.001458
#                                 max                                       0.844596
#                                 q3                                        0.843575
#                                 median                                    0.842643
#                                 q1                                        0.841852
#                                 min                                        0.84125
#                                 conf      (0.8381433142650633, 0.8474228589244678)

In [None]:
scipy.stats.t.interval(0.95, )

In [None]:
70515-9999

In [None]:
pd.DataFrame(record_all_list_dict["cbm_isic_nodup_nooverlap_all"]).groupby(["task", "method"]).mean()

In [None]:
43678-5900

In [None]:
# Barcelona -> Vienna ISIC split: render only the "weight" stage of the CBM
# figure, using the records stored under the dataset's own key.
dataset_name = "cbm_isic_barcelona_vienna"

main_method_list = [
    "automatic_curated_monet_full",
    # "manual_annotation_full",
    "resnet",
    "resnet_freeze_backbone",
    "automatic_curated_vanilla_full",
]

manual_annotation_concepts_name = "skincon"

# Concept names per concept family, passed to plot_cbm_results.
automatic_concept_info = dict(
    skincon=skincon_cols,
    curated=[
        "cbm_Asymmetry", "cbm_Irregular", "cbm_Black", "cbm_Blue", "cbm_White", "cbm_Brown",
        "cbm_Erosion", "cbm_Multiple Colors", "cbm_Tiny", "cbm_Regular",
    ],
)

# Same concepts with "cbm_Erosion" moved ahead of the colour concepts; passed
# separately as the plotting order.
automatic_concept_info_plotorder = dict(
    skincon=skincon_cols,
    curated=[
        "cbm_Asymmetry", "cbm_Irregular", "cbm_Erosion", "cbm_Black", "cbm_Blue", "cbm_White", "cbm_Brown",
        "cbm_Multiple Colors", "cbm_Tiny", "cbm_Regular",
    ],
)

subpanel_offset = -5
show_pval = True
stack_mode = "suggestion"
stage_list = ["weight"]
task_suffix = ""
method_weight = "automatic_curated_monet_full"

record_all_list = record_all_list_dict[dataset_name]

fig = plot_cbm_results(
    dataset_name=dataset_name,
    main_method_list=main_method_list,
    manual_annotation_concepts_name=manual_annotation_concepts_name,
    automatic_concept_info=automatic_concept_info,
    automatic_concept_info_plotorder=automatic_concept_info_plotorder,
    record_all_list=record_all_list,
    record_all_list_manual=None,
    subpanel_offset=subpanel_offset,
    stage_list=stage_list,
    method_weight=method_weight,
    show_pval=show_pval,
    stack_mode=stack_mode,
    task_suffix=task_suffix,
    alpha=0.001,
    temp=0.02,
    skincon_height=2,
    hspace=0.3,
    random_seed_range=list(range(1, 1 + 20)),
)
print()

# fig.savefig(log_dir/"plots"/"cbm_dermoscopic_barcelona_vienna_weight.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"cbm_dermoscopic_barcelona_vienna_weight.jpg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"cbm_dermoscopic_barcelona_vienna_weight.svg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"cbm_dermoscopic_barcelona_vienna_weight.pdf", bbox_inches='tight')

In [None]:
# Write the current figure to log_dir/"plots" in four formats (PNG, JPG, SVG,
# PDF), in that order.
for plot_filename in (
    "cbm_dermoscopic_barcelona_vienna_weight.png",
    "cbm_dermoscopic_barcelona_vienna_weight.jpg",
    "cbm_dermoscopic_barcelona_vienna_weight.svg",
    "cbm_dermoscopic_barcelona_vienna_weight.pdf",
):
    fig.savefig(log_dir / "plots" / plot_filename, bbox_inches="tight")

In [None]:
fig

In [None]:
# Tabulate the records and average them per (task, method, alpha).
record_all_list_temp_df = pd.DataFrame(record_all_list)

# "auc_valid" may hold torch tensors; unwrap those to plain Python numbers.
record_all_list_temp_df["auc_valid"] = record_all_list_temp_df["auc_valid"].map(
    lambda value: value if not isinstance(value, torch.Tensor) else value.item()
)

# NOTE(review): fillna(-9) presumably keeps rows whose group keys are NaN from
# being dropped by groupby, but it also replaces NaNs in every other column,
# so those -9 sentinels enter the means — confirm this is intended.
(
    record_all_list_temp_df
    .fillna(-9)
    .groupby(["task", "method", "alpha"])
    .mean()
)

In [None]:
# Vienna -> Barcelona ISIC split: render only the "weight" stage of the CBM
# figure, using the records stored under the dataset's own key.
dataset_name = "cbm_isic_vienna_barcelona"

main_method_list = [
    "automatic_curated_monet_full",
    # "manual_annotation_full",
    "resnet",
    "resnet_freeze_backbone",
    "automatic_curated_vanilla_full",
]

manual_annotation_concepts_name = "skincon"

# Concept names per concept family, passed to plot_cbm_results.
automatic_concept_info = dict(
    skincon=skincon_cols,
    curated=[
        "cbm_Asymmetry", "cbm_Irregular", "cbm_Black", "cbm_Blue", "cbm_White", "cbm_Brown",
        "cbm_Erosion", "cbm_Multiple Colors", "cbm_Tiny", "cbm_Regular",
    ],
)

# Same concepts with "cbm_Erosion" moved ahead of the colour concepts; passed
# separately as the plotting order.
automatic_concept_info_plotorder = dict(
    skincon=skincon_cols,
    curated=[
        "cbm_Asymmetry", "cbm_Irregular", "cbm_Erosion", "cbm_Black", "cbm_Blue", "cbm_White", "cbm_Brown",
        "cbm_Multiple Colors", "cbm_Tiny", "cbm_Regular",
    ],
)

subpanel_offset = -1
show_pval = True
stack_mode = "suggestion"
stage_list = ["weight"]
task_suffix = ""
method_weight = "automatic_curated_monet_full"

record_all_list = record_all_list_dict[dataset_name]

fig = plot_cbm_results(
    dataset_name=dataset_name,
    main_method_list=main_method_list,
    manual_annotation_concepts_name=manual_annotation_concepts_name,
    automatic_concept_info=automatic_concept_info,
    automatic_concept_info_plotorder=automatic_concept_info_plotorder,
    record_all_list=record_all_list,
    record_all_list_manual=None,
    subpanel_offset=subpanel_offset,
    stage_list=stage_list,
    method_weight=method_weight,
    show_pval=show_pval,
    stack_mode=stack_mode,
    task_suffix=task_suffix,
    alpha=0.001,
    temp=0.02,
    skincon_height=2,
    hspace=0.3,
    random_seed_range=list(range(1, 1 + 20)),
)
print()

# fig.savefig(log_dir/"plots"/"cbm_dermoscopic_vienna_barcelona_weight.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"cbm_dermoscopic_vienna_barcelona_weight.jpg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"cbm_dermoscopic_vienna_barcelona_weight.svg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"cbm_dermoscopic_vienna_barcelona_weight.pdf", bbox_inches='tight')

In [None]:
plot_cbm_results??

In [None]:
df_temp=pd.DataFrame(record_all_list)
df_temp

In [None]:
# Keep records that either use temp == 0.02 and alpha == 0.001 or have both
# hyper-parameters missing, restrict them to the methods in main_method_list,
# then print the rows and their per-(task, method) means.
df_temp = pd.DataFrame(record_all_list)
df_temp = df_temp.loc[
    (df_temp["temp"].eq(0.02) & df_temp["alpha"].eq(0.001))
    | (df_temp["temp"].isna() & df_temp["alpha"].isna())
]
df_temp = df_temp.loc[df_temp["method"].isin(main_method_list)]

print(df_temp)
print(df_temp.groupby(["task", "method"]).mean())

In [None]:
df_temp

In [None]:
df_temp["method"]

In [None]:
main_method_list

In [None]:
fig

In [None]:
# Write the current figure to log_dir/"plots" in four formats (PNG, JPG, SVG,
# PDF), in that order.
for plot_filename in (
    "cbm_dermoscopic_vienna_barcelona_weight.png",
    "cbm_dermoscopic_vienna_barcelona_weight.jpg",
    "cbm_dermoscopic_vienna_barcelona_weight.svg",
    "cbm_dermoscopic_vienna_barcelona_weight.pdf",
):
    fig.savefig(log_dir / "plots" / plot_filename, bbox_inches="tight")

In [None]:
# Tabulate the records and average them per (task, method, alpha).
record_all_list_temp_df = pd.DataFrame(record_all_list)

# "auc_valid" may hold torch tensors; unwrap those to plain Python numbers.
record_all_list_temp_df["auc_valid"] = record_all_list_temp_df["auc_valid"].map(
    lambda value: value if not isinstance(value, torch.Tensor) else value.item()
)

# NOTE(review): fillna(-9) presumably keeps rows whose group keys are NaN from
# being dropped by groupby, but it also replaces NaNs in every other column,
# so those -9 sentinels enter the means — confirm this is intended.
(
    record_all_list_temp_df
    .fillna(-9)
    .groupby(["task", "method", "alpha"])
    .mean()
)

In [None]:
def plot_concept_num(
    target_method,
    variable_method_list,
    subpanel_offset,
    record_all_list,
    task_suffix,
    panel_title,
    alpha,
    with_validation,
    with_validation_label,
):
    
    
    plt.rcParams["axes.prop_cycle"]=cycler('color', [np.array(i)/256 for i in [Paired[12][1], 
                                                                               Paired[12][7],
                                                                                    Paired[12][3],
                                                                                    Paired[12][5],
                                                                                    
                                                                                    Paired[12][9],
                                                                                    Paired[12][11]
                                                                                    ]])     
    
    fig = plt.figure(figsize=(3*10, 3*(2.5 + 0.35*0)))

    box1 = gridspec.GridSpec(1,1,
                             height_ratios=[2.5],
                             wspace=0.0,
                             hspace=0.35)  

    axd={}
    for idx1, stage in enumerate(["conceptcuration"]):     
        if stage=="conceptcuration":
            box2 = gridspec.GridSpecFromSubplotSpec(1, 4,
                            subplot_spec=box1[idx1], 
                            width_ratios=[0.1, 1, 0.1, 1], wspace=0., hspace=0.)        
            for idx2, task in enumerate(["empty_malignant", "malignant", "empty_melanoma", "melanoma"]):
    #             elif investigation_type=="statistics":
                plot_key=f"{stage}_{task}"
                ax=plt.Subplot(fig, box2[idx2])
                fig.add_subplot(ax)
                axd[plot_key]=ax  

    for plot_key in axd.keys():
        if 'overview' in plot_key:
            axd[plot_key].set_xticks([])
            axd[plot_key].set_yticks([])
            for axis in ['top','bottom','left','right']:
                axd[plot_key].spines[axis].set_linewidth(0) 

        if 'empty' in plot_key:
            axd[plot_key].set_xticks([])
            axd[plot_key].set_yticks([])
            for axis in ['top','bottom','left','right']:
                axd[plot_key].spines[axis].set_linewidth(0)        


    for idx1, stage in enumerate(["conceptcuration"]):    
        if stage=="conceptcuration":
            for idx2, task in enumerate(["empty_malignant", "malignant", "empty_melanoma", "melanoma"]):
    #             elif investigation_type=="statistics":
                plot_key=f"{stage}_{task}"

                if task=="malignant" or task=="melanoma":
                    record_all_df_conceptcuration=pd.DataFrame(record_all_list)


                    record_all_df_conceptcuration_filtered=record_all_df_conceptcuration[
                        (record_all_df_conceptcuration["task"]==task)&
                        (record_all_df_conceptcuration["alpha"]==alpha)

                    ]



                    record_all_df_conceptcuration_filtered_target=record_all_df_conceptcuration_filtered[record_all_df_conceptcuration_filtered["method"]==target_method]


                    for variable_method_idx, variable_method in enumerate(variable_method_list):

                        record_all_df_conceptcuration_filtered_variable=record_all_df_conceptcuration_filtered[record_all_df_conceptcuration_filtered["method"]==variable_method]

        #                 b=sns.lineplot(x="num_concept", y="auc", 
        #                                color=plt.rcParams["axes.prop_cycle"].by_key()["color"][1], linewidth=4,
        #                                data=record_all_df_conceptcuration_filtered_variable, 
        #                                ax=axd[plot_key])                 

                        #record_all_df_conceptcuration_filtered_variable_auc=record_all_df_conceptcuration_filtered_variable.copy().drop(columns=["auc_valid"])
                        #record_all_df_conceptcuration_filtered_variable_auc_valid=record_all_df_conceptcuration_filtered_variable.copy().drop(columns=["auc"])



                        

    
    
                        record_all_df_conceptcuration_filtered_variable_auc=record_all_df_conceptcuration_filtered_variable.copy()\
                        .drop(columns=["auc_valid"])
                        record_all_df_conceptcuration_filtered_variable_auc["split"]="test"    
    
                        if with_validation:
            
                            record_all_df_conceptcuration_filtered_variable_auc_valid=record_all_df_conceptcuration_filtered_variable.copy()\
                            .drop(columns=["auc"])\
                            .rename(columns={"auc_valid": "auc"})
                            record_all_df_conceptcuration_filtered_variable_auc_valid["split"]="valid"            

                            b=sns.lineplot(x="num_concept", y="auc", style="split",
                                           color=plt.rcParams["axes.prop_cycle"].by_key()["color"][variable_method_idx], linewidth=4,
                                           data=pd.concat([record_all_df_conceptcuration_filtered_variable_auc,
                                                           record_all_df_conceptcuration_filtered_variable_auc_valid
                                                          ]), 
                                           ax=axd[plot_key])   
                        else:
                            b=sns.lineplot(x="num_concept", y="auc", style="split",
                                           color=plt.rcParams["axes.prop_cycle"].by_key()["color"][variable_method_idx], linewidth=4,
                                           data=pd.concat([record_all_df_conceptcuration_filtered_variable_auc
                                                          ]), 
                                           ax=axd[plot_key])                             

                        if b.legend_ is not None:
                            b.legend_.remove()    

    #                 axd[plot_key].scatter((record_all_df_conceptcuration_filtered_target["clf"].iloc[0].coef_).shape[1], 
    #                                       record_all_df_conceptcuration_filtered_target["auc_valid"].mean(), 
    #                                       s=400, marker='o', 
    #                                       color=plt.rcParams["axes.prop_cycle"].by_key()["color"][0])                      

    #                 axd[plot_key].scatter((record_all_df_conceptcuration_filtered_target["clf"].iloc[0].coef_).shape[1], 
    #                                       record_all_df_conceptcuration_filtered_target["auc"].mean(), 
    #                                       s=400, marker='X', 
    #                                       color=plt.rcParams["axes.prop_cycle"].by_key()["color"][1])                    





    # #                 sns.lineplot(
    # #                 x="num_concept",
    # #                 y="auc",
    # #                 data=record_all_df_conceptcuration_filtered_variable
    # #                 )
    # #                 print(record_all_df_conceptcuration_filtered_target)                






    #                 record_all_df_perf=pd.DataFrame(record_all_list)
    #                 record_all_df_perf=record_all_df_perf[record_all_df_perf["random_seed"]<=20]
    #                 record_all_df_perf_filtered=record_all_df_perf.copy()
    #                 record_all_df_perf_filtered=record_all_df_perf_filtered[record_all_df_perf_filtered["method"].isin(main_method_list)]
    #                 record_all_df_perf_filtered=record_all_df_perf_filtered[(~record_all_df_perf_filtered["method"].isin(["automatic_curated_vanilla_full", "automatic_curated_monet_full"]))|\
    #                                                              ((record_all_df_perf_filtered["method"].isin(["automatic_curated_vanilla_full", "automatic_curated_monet_full"]))&(record_all_df_perf_filtered["alpha"]==alpha)&(record_all_df_perf_filtered["temp"]==temp))]

    #                 if task=="malignant":
    # #                     dsds
    #                     record_all_df_perf_filtered=record_all_df_perf_filtered[record_all_df_perf_filtered["task"]=="malignant"]
    #                     axd[plot_key].set_ylim(0.49, 1.01)
    #                 elif task=="melanoma":
    #                     record_all_df_perf_filtered=record_all_df_perf_filtered[record_all_df_perf_filtered["task"]=="melanoma"]
    #                     axd[plot_key].set_ylim(0.49, 1.01)
    #                 else:
    #                     raise ValueError


    #                 print(task, record_all_df_perf_filtered.groupby("method")["auc"].apply(lambda x: {"mean": x.mean(),
    #                                                                       "std": x.std(),
    #                                                                       "q3": x.quantile(q=0.75),                                                                      
    #                                                                       "median": x.median(),
    #                                                                       "q1": x.quantile(q=0.25),
    #                                                                      }))

                    if task=="malignant":
                        axd[plot_key].set_title("Malignancy"+task_suffix, fontsize=30)
                    if task=="melanoma":
                        axd[plot_key].set_title("Melanoma"+task_suffix, fontsize=30)

                    axd[plot_key].set_ylim(0.49, 1.01)
                    axd[plot_key].set_xlim(-0.5, 50)

                    for axis in ['top','bottom','left','right']:
                        axd[plot_key].spines[axis].set_linewidth(3)                

                    if idx2==1: 
                        axd[plot_key].set_ylabel('Area under the ROC curve', fontsize=25)
                    if idx2==3: 
                        axd[plot_key].set_ylabel('')                    

                    axd[plot_key].set_xlabel('Num. of concepts', fontsize=25)

                    axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.05))
                    axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.01))            
                    axd[plot_key].yaxis.grid(True, which='major', linewidth=2, alpha=0.6)
                    axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
                    #axd[plot_key].yaxis.grid(True, which='minor', linewidth=2, alpha=0.1)

                    axd[plot_key].spines['right'].set_visible(False)
                    axd[plot_key].spines['top'].set_visible(False)                      

                    axd[plot_key].tick_params(axis='x', which='major', labelsize=20)
                    axd[plot_key].tick_params(axis='y', which='major', labelsize=20)   


                    axd[plot_key].xaxis.set_major_locator(MultipleLocator(10))
                    axd[plot_key].xaxis.set_minor_locator(MultipleLocator(5))                
    #                 axd[plot_key].tick_params(
    #                     axis='x',          # changes apply to the x-axis
    #                     which='both',      # both major and minor ticks are affected
    #                     bottom=False,
    #                     labelbottom=False,      # ticks along the bottom edge are off
    #                     )            

    #                 axd[plot_key].set_xlabel(None)

                if task=="empty_malignant":
#                     axd[plot_key].text(x=-0.05, y=1.02, transform=axd[plot_key].transAxes,
#                                          s=["A", "B", "C", "D", "E", "F", "G", "H", "I", ][subpanel_offset+0], fontsize=35, weight='bold') 
                    axd[plot_key].text(x=-0.05, y=1.1, transform=axd[plot_key].transAxes,
                                         s=["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M"][subpanel_offset+0], fontsize=35, weight='bold')                 
                    axd[plot_key].text(x=0.27, y=1.1, transform=axd[plot_key].transAxes,
                                         s=panel_title, fontsize=35)

                if task=="malignant":
                    pass

    #                 legend_elements=[Patch(facecolor=plt.rcParams["axes.prop_cycle"].by_key()['color'][3*method_idx], 
    #                                        edgecolor="black", linewidth=2, 
    #                                        label=method_name) for method_idx, method_name in enumerate(["Curated", "SkinCon"])]+\
                    legend_elements=[

                        Line2D([0], [0], 
                            color=plt.rcParams["axes.prop_cycle"].by_key()['color'][variable_method_idx],
                            linewidth=7,
                            linestyle='-',
                            label={"curated": "Curated",
                                  "skincon": "SkinCon",
                                   "derm7pt": "Derm7pt",
                                  }[variable_method.split('_')[1]]
                              
                              ) for variable_method_idx, variable_method in enumerate(variable_method_list)

                    ]
        
                    if with_validation:
                       legend_elements+= [Line2D([0], [0], 
                                                color='grey',
                                                linewidth=7,
                                                linestyle='--',
                                                label=with_validation_label[0]) ,
                                            Line2D([0], [0], 
                                                color='grey',
                                                linewidth=7,
                                                linestyle='-',
                                                label=with_validation_label[1]),
                                    ] 




                    if len(legend_elements)==2:
                        axd[plot_key].legend(handles=legend_elements, 
                                    ncol=1, 
                                    handlelength=2.5,
                                    handletextpad=0.4, 
                                    columnspacing=4,
                                    fontsize=23,
                                    loc='lower center', bbox_to_anchor=(1.05, -0.3))  
                        
                    else:
                        
                        
                        axd[plot_key].legend(handles=legend_elements, 
                                    ncol=2, 
                                    handlelength=2.5,
                                    handletextpad=0.4, 
                                    columnspacing=4,
                                    fontsize=23,
                                    loc='lower center', bbox_to_anchor=(1.05, -0.32))                   

                    #axd[plot_key].set_ylabel('Concepts', fontsize=30)


                if task=="empty_melanoma":
                    pass
#                     axd[plot_key].text(x=-0.0, y=1.02, transform=axd[plot_key].transAxes,
#                                          s=["A", "B", "C", "D", "E", "F", "G", "H", "I"][subpanel_offset+1], fontsize=35, weight='bold')  
#                     axd[plot_key].text(x=-0.0, y=1.1, transform=axd[plot_key].transAxes,
#                                          s=["A", "B", "C", "D", "E", "F", "G", "H", "I"][subpanel_offset+1], fontsize=35, weight='bold')  


    return fig
#     fig.savefig(log_dir/"plots"/"cbm_dermoscopic_vienna_barcelona_num_concepts.png", bbox_inches='tight')
#     fig.savefig(log_dir/"plots"/"cbm_dermoscopic_vienna_barcelona_num_concepts.jpg", bbox_inches='tight')
#     fig.savefig(log_dir/"plots"/"cbm_dermoscopic_vienna_barcelona_num_concepts.svg", bbox_inches='tight')
#     fig.savefig(log_dir/"plots"/"cbm_dermoscopic_vienna_barcelona_num_concepts.pdf", bbox_inches='tight')    

In [None]:
# Number-of-concepts sweep for the clinical records (seed2040 run).
# plot_concept_num returns the matplotlib figure; nothing is saved in this cell.
target_method = "automatic_curated_monet_full"
variable_method_list = [
    "automatic_curated_monet_less_concept",
    "automatic_skincon_monet_less_concept",
]
# Not passed to plot_concept_num below; kept as a notebook-level variable.
variable_method = "automatic_curated_monet_less_concept"
subpanel_offset = 0
record_all_list = record_all_list_dict[
    "clinical_fd_clean_nodup_nooverlap_with_validation_seed2040"
]
alpha = 0.001

fig = plot_concept_num(
    target_method=target_method,
    variable_method_list=variable_method_list,
    subpanel_offset=subpanel_offset,
    record_all_list=record_all_list,
    task_suffix="",
    # Fixed spelling: the title previously read "Fizpatrick17k".
    panel_title=". Clinical (Fitzpatrick17k+DDI)",
    alpha=alpha,
    # The first label is used for the dashed legend entry, the second for the solid one.
    with_validation=True,
    with_validation_label=["Validation", "Test"],
)

In [None]:
# Render the number-of-concepts sweep for every evaluation setting and save
# each figure as png/jpg/svg/pdf under log_dir/"plots".
#
# The four settings only differ in the record set, the panel letter offset,
# the panel title, the legend labels and the output file stem, so they are
# described as data and rendered in one loop instead of four pasted copies.
target_method = "automatic_curated_monet_full"
variable_method_list = [
    "automatic_curated_monet_less_concept",
    "automatic_skincon_monet_less_concept",
]
# Not passed to plot_concept_num below; kept as a notebook-level variable.
variable_method = "automatic_curated_monet_less_concept"
alpha = 0.001

# (key in record_all_list_dict, subpanel_offset, panel_title,
#  [dashed-line label, solid-line label], output file stem)
concept_num_panels = [
    (
        "clinical_fd_clean_nodup_nooverlap_with_validation_seed2040",
        0,
        # Fixed spelling: the title previously read "Fizpatrick17k".
        ". Clinical (Fitzpatrick17k+DDI)",
        ["Validation", "Test"],
        "cbm_num_concepts_clinical",
    ),
    (
        "cbm_isic_nodup_nooverlap_all_automatic_with_validation",
        1,
        ". Dermoscopic (ISIC)",
        ["Validation", "Test"],
        "cbm_num_concepts_isic",
    ),
    (
        "cbm_isic_barcelona_vienna",
        2,
        ". Dermoscopic (ISIC: Hosp. Barcelona → Med U. Vienna)",
        ["Validation (Hosp. Barcelona)", "Test (Med U. Vienna)"],
        "cbm_num_concepts_barcelona_vienna",
    ),
    (
        "cbm_isic_vienna_barcelona",
        3,
        ". Dermoscopic (ISIC: Med U. Vienna → Hosp. Barcelona)",
        ["Validation (Med U. Vienna)", "Test (Hosp. Barcelona)"],
        "cbm_num_concepts_vienna_barcelona",
    ),
]

# Loop targets deliberately reuse the original notebook variable names
# (subpanel_offset, record_all_list, fig), so after this cell they hold the
# values of the last panel, exactly as before the refactor.
for record_key, subpanel_offset, panel_title, with_validation_label, file_stem in concept_num_panels:
    record_all_list = record_all_list_dict[record_key]

    fig = plot_concept_num(
        target_method=target_method,
        variable_method_list=variable_method_list,
        subpanel_offset=subpanel_offset,
        record_all_list=record_all_list,
        task_suffix="",
        panel_title=panel_title,
        alpha=alpha,
        with_validation=True,
        with_validation_label=with_validation_label,
    )

    for file_format in ("png", "jpg", "svg", "pdf"):
        fig.savefig(log_dir / "plots" / f"{file_stem}.{file_format}", bbox_inches="tight")

In [None]:
# Display the records stored under this key as a DataFrame.
pd.DataFrame(record_all_list_dict["cbm_isic_nodup_nooverlap_all_automatic_with_validation"])

In [None]:
# Display the keys available in record_all_list_dict.
record_all_list_dict.keys()

In [None]:
# Number of records for every (task, method, alpha) combination.
# NaNs are replaced by -9 before grouping — presumably so that rows without
# an alpha value are still counted instead of being dropped by groupby.
(
    pd.DataFrame(record_all_list)
    .fillna(-9)
    .groupby(["task", "method", "alpha"])
    .apply(len)
)

# separate

# clinical performance & weight

In [None]:
# Clinical figure: run plot_cbm_results for the stages in stage_list and save
# the resulting figure as png/jpg/svg/pdf under log_dir/"plots".
dataset_name = "clinical_fd_clean_nodup_nooverlap"

# "manual_annotation_full" is not included in this comparison.
main_method_list = [
    "automatic_curated_monet_full",
    "automatic_curated_vanilla_full",
    "resnet",
    "resnet_freeze_backbone",
]

manual_annotation_concepts_name = "skincon"

automatic_concept_info = {
    "skincon": skincon_cols,
    "curated": [
        "cbm_Asymmetry", "cbm_Irregular", "cbm_Black", "cbm_Blue", "cbm_White",
        "cbm_Brown", "cbm_Erosion", "cbm_Multiple Colors", "cbm_Tiny", "cbm_Regular",
    ],
}

# Same concepts as automatic_concept_info; only the order of "curated" differs.
automatic_concept_info_plotorder = {
    "skincon": skincon_cols,
    "curated": [
        "cbm_Asymmetry", "cbm_Irregular", "cbm_Erosion", "cbm_Black", "cbm_Blue",
        "cbm_White", "cbm_Brown", "cbm_Multiple Colors", "cbm_Tiny", "cbm_Regular",
    ],
}

subpanel_offset = 0
show_pval = False
stack_mode = "suggestion"
stage_list = ["overview", "performance", "weight", "skincon"]
task_suffix = " (Clinical)"
method_weight = "automatic_curated_monet_full"

# All records of the with-validation run, plus only the resnet records of the
# run without validation. (A preceding assignment of the full no-validation
# list was overwritten immediately and has been removed.)
record_all_list = record_all_list_dict["clinical_fd_clean_nodup_nooverlap_with_validation"] + [
    record
    for record in record_all_list_dict["clinical_fd_clean_nodup_nooverlap"]
    if "resnet" in record["method"]
]

record_all_list_manual = record_all_list_dict["clinical_fd_clean_nodup_nooverlap_manualmatch"]

fig = plot_cbm_results(
    dataset_name=dataset_name,
    main_method_list=main_method_list,
    manual_annotation_concepts_name=manual_annotation_concepts_name,
    automatic_concept_info=automatic_concept_info,
    automatic_concept_info_plotorder=automatic_concept_info_plotorder,
    record_all_list=record_all_list,
    record_all_list_manual=record_all_list_manual,
    subpanel_offset=subpanel_offset,
    stage_list=stage_list,
    method_weight=method_weight,
    show_pval=show_pval,
    stack_mode=stack_mode,
    task_suffix=task_suffix,
    alpha=0.001,
    temp=0.02,
)
print()

# One file per format; the stem is shared so the formats cannot drift apart.
for file_format in ("png", "jpg", "svg", "pdf"):
    fig.savefig(log_dir / "plots" / f"cbm_clinical_all.{file_format}", bbox_inches="tight")

In [None]:
# Derm7pt: run plot_cbm_results for the "skincon" stage only.
# The returned value is not kept and nothing is saved in this cell.
dataset_name = "derm7pt_derm_nodup"
record_all_list = record_all_list_dict[dataset_name]

manual_annotation_concepts_name = "derm7pt"
subpanel_offset = -1
show_pval = True
stack_mode = "suggestion"
stage_list = ["skincon"]
task_suffix = ""

# "manual_annotation_full" is not included in this comparison.
main_method_list = [
    "automatic_curated_monet_full",
    "resnet",
    "resnet_freeze_backbone",
    "automatic_curated_vanilla_full",
]

automatic_concept_info = {
    "skincon": skincon_cols,
    "derm7pt": [
        "derm7ptconcept_pigment network",
        "derm7ptconcept_regression structure",
        "derm7ptconcept_pigmentation",
        "derm7ptconcept_blue whitish veil",
        "derm7ptconcept_vascular structures",
        "derm7ptconcept_streaks",
        "derm7ptconcept_dots and globules",
    ],
    "curated": [
        "cbm_Asymmetry", "cbm_Irregular", "cbm_Black", "cbm_Blue", "cbm_White",
        "cbm_Brown", "cbm_Erosion", "cbm_Multiple Colors", "cbm_Tiny", "cbm_Regular",
    ],
}

# Same concepts as automatic_concept_info; only the order of "curated" differs.
automatic_concept_info_plotorder = {
    "skincon": skincon_cols,
    "derm7pt": [
        "derm7ptconcept_pigment network",
        "derm7ptconcept_regression structure",
        "derm7ptconcept_pigmentation",
        "derm7ptconcept_blue whitish veil",
        "derm7ptconcept_vascular structures",
        "derm7ptconcept_streaks",
        "derm7ptconcept_dots and globules",
    ],
    "curated": [
        "cbm_Asymmetry", "cbm_Irregular", "cbm_Erosion", "cbm_Black", "cbm_Blue",
        "cbm_White", "cbm_Brown", "cbm_Multiple Colors", "cbm_Tiny", "cbm_Regular",
    ],
}

# NOTE(review): unlike the other plot_cbm_results calls in this notebook, this
# one passes neither record_all_list_manual nor method_weight — confirm that
# the function's defaults are what is wanted here.
plot_cbm_results(
    dataset_name=dataset_name,
    main_method_list=main_method_list,
    manual_annotation_concepts_name=manual_annotation_concepts_name,
    automatic_concept_info=automatic_concept_info,
    automatic_concept_info_plotorder=automatic_concept_info_plotorder,
    record_all_list=record_all_list,
    subpanel_offset=subpanel_offset,
    stage_list=stage_list,
    show_pval=show_pval,
    stack_mode=stack_mode,
    task_suffix=task_suffix,
    alpha=0.001,
    temp=0.02,
)
print()

In [None]:
# Display the keys of variable_dict.
variable_dict.keys()

In [None]:
# Display the keys available in record_all_list_dict.
record_all_list_dict.keys()

In [None]:
# Scratch arithmetic. NOTE(review): the meaning of the two operands is not
# recorded here — presumably sample counts; confirm or delete this cell.
3928.000000+982.0

In [None]:
# Scratch check: displays the tuple (1129, 901.05 + 227.95) for comparison.
1129, (901.05+227.95)

In [None]:
# Scratch arithmetic. NOTE(review): operand meaning is not recorded — confirm or delete.
4910-1129

In [None]:
# NOTE(review): incomplete expression — the dangling `+` makes this cell a
# SyntaxError. Supply the missing operand or delete this scratch cell.
615+

In [None]:
# Per-(task, method) means of the numeric fields of the automatic/manual run.
# numeric_only=True is passed explicitly: pandas >= 2.0 no longer drops
# non-numeric columns silently and raises a TypeError instead, whereas older
# versions already restricted the mean to numeric columns by default.
(
    pd.DataFrame(
        record_all_list_dict["clinical_fd_clean_nodup_nooverlap_with_validation_automatic_manual"]
    )
    .groupby(["task", "method"])
    .mean(numeric_only=True)
)

# ISIC performance

In [None]:
# Dermoscopic (ISIC) figure: run plot_cbm_results for the "performance" and
# "weight" stages and save the figure as png/jpg/svg/pdf under log_dir/"plots".
dataset_name = "isic_nodup_nooverlap"

# "manual_annotation_full" is not included in this comparison.
main_method_list = [
    "automatic_curated_monet_full",
    "automatic_curated_vanilla_full",
    "resnet",
    "resnet_freeze_backbone",
]

manual_annotation_concepts_name = "skincon"

automatic_concept_info = {
    "skincon": skincon_cols,
    "derm7pt": [
        "derm7ptconcept_pigment network",
        "derm7ptconcept_regression structure",
        "derm7ptconcept_pigmentation",
        "derm7ptconcept_blue whitish veil",
        "derm7ptconcept_vascular structures",
        "derm7ptconcept_streaks",
        "derm7ptconcept_dots and globules",
    ],
    "curated": [
        "cbm_Asymmetry", "cbm_Irregular", "cbm_Black", "cbm_Blue", "cbm_White",
        "cbm_Brown", "cbm_Erosion", "cbm_Multiple Colors", "cbm_Tiny", "cbm_Regular",
    ],
}

# Same concepts as automatic_concept_info; only the order of "curated" differs.
automatic_concept_info_plotorder = {
    "skincon": skincon_cols,
    "derm7pt": [
        "derm7ptconcept_pigment network",
        "derm7ptconcept_regression structure",
        "derm7ptconcept_pigmentation",
        "derm7ptconcept_blue whitish veil",
        "derm7ptconcept_vascular structures",
        "derm7ptconcept_streaks",
        "derm7ptconcept_dots and globules",
    ],
    "curated": [
        "cbm_Asymmetry", "cbm_Irregular", "cbm_Erosion", "cbm_Black", "cbm_Blue",
        "cbm_White", "cbm_Brown", "cbm_Multiple Colors", "cbm_Tiny", "cbm_Regular",
    ],
}

subpanel_offset = 0
show_pval = True
stack_mode = "suggestion"
stage_list = ["performance", "weight"]
task_suffix = " (Dermoscopic)"
# Alternative previously tried here: "automatic_derm7pt_monet_full".
method_weight = "automatic_curated_monet_full"

# Note: the records come from a different key than dataset_name.
record_all_list = record_all_list_dict["cbm_isic_nodup_nooverlap_all"]

fig = plot_cbm_results(
    dataset_name=dataset_name,
    main_method_list=main_method_list,
    manual_annotation_concepts_name=manual_annotation_concepts_name,
    automatic_concept_info=automatic_concept_info,
    automatic_concept_info_plotorder=automatic_concept_info_plotorder,
    record_all_list=record_all_list,
    record_all_list_manual=None,
    subpanel_offset=subpanel_offset,
    stage_list=stage_list,
    method_weight=method_weight,
    show_pval=show_pval,
    stack_mode=stack_mode,
    task_suffix=task_suffix,
    alpha=0.001,
    temp=0.02,
)
print()

# One file per format; the stem is shared so the formats cannot drift apart.
for file_format in ("png", "jpg", "svg", "pdf"):
    fig.savefig(log_dir / "plots" / f"cbm_dermoscopic_all_weight.{file_format}", bbox_inches="tight")

In [None]:
# Display the keys available in record_all_list_dict.
record_all_list_dict.keys()

In [None]:
# IPython-only syntax: the trailing `?` shows help for shorten_concept_name.
# It is not valid plain Python and only works inside IPython/Jupyter.
shorten_concept_name?

In [None]:
# Display the current concept plot-order mapping.
automatic_concept_info_plotorder

# ISIC inter-institution

In [None]:
# Inter-institution figure for the "cbm_isic_barcelona_vienna" records: run
# plot_cbm_results for the "weight" stage and save the figure as
# png/jpg/svg/pdf under log_dir/"plots".
dataset_name = "cbm_isic_barcelona_vienna"

# "manual_annotation_full" is not included in this comparison.
main_method_list = [
    "automatic_curated_monet_full",
    "resnet",
    "resnet_freeze_backbone",
    "automatic_curated_vanilla_full",
]

manual_annotation_concepts_name = "skincon"

automatic_concept_info = {
    "skincon": skincon_cols,
    "curated": [
        "cbm_Asymmetry", "cbm_Irregular", "cbm_Black", "cbm_Blue", "cbm_White",
        "cbm_Brown", "cbm_Erosion", "cbm_Multiple Colors", "cbm_Tiny", "cbm_Regular",
    ],
}

# Same concepts as automatic_concept_info; only the order of "curated" differs.
automatic_concept_info_plotorder = {
    "skincon": skincon_cols,
    "curated": [
        "cbm_Asymmetry", "cbm_Irregular", "cbm_Erosion", "cbm_Black", "cbm_Blue",
        "cbm_White", "cbm_Brown", "cbm_Multiple Colors", "cbm_Tiny", "cbm_Regular",
    ],
}

subpanel_offset = -5
show_pval = True
stack_mode = "suggestion"
stage_list = ["weight"]
task_suffix = " (Dermoscopic)"
method_weight = "automatic_curated_monet_full"

record_all_list = record_all_list_dict[dataset_name]

fig = plot_cbm_results(
    dataset_name=dataset_name,
    main_method_list=main_method_list,
    manual_annotation_concepts_name=manual_annotation_concepts_name,
    automatic_concept_info=automatic_concept_info,
    automatic_concept_info_plotorder=automatic_concept_info_plotorder,
    record_all_list=record_all_list,
    record_all_list_manual=None,
    subpanel_offset=subpanel_offset,
    stage_list=stage_list,
    method_weight=method_weight,
    show_pval=show_pval,
    stack_mode=stack_mode,
    task_suffix=task_suffix,
    alpha=0.001,
    temp=0.02,
)
print()

# One file per format; the stem is shared so the formats cannot drift apart.
for file_format in ("png", "jpg", "svg", "pdf"):
    fig.savefig(
        log_dir / "plots" / f"cbm_dermoscopic_barcelona_vienna_weight.{file_format}",
        bbox_inches="tight",
    )

In [None]:
# Inter-institution figure for the "cbm_isic_vienna_barcelona" records: run
# plot_cbm_results for the "weight" stage and save the figure as
# png/jpg/svg/pdf under log_dir/"plots".
dataset_name = "cbm_isic_vienna_barcelona"

# "manual_annotation_full" is not included in this comparison.
main_method_list = [
    "automatic_curated_monet_full",
    "resnet",
    "resnet_freeze_backbone",
    "automatic_curated_vanilla_full",
]

manual_annotation_concepts_name = "skincon"

automatic_concept_info = {
    "skincon": skincon_cols,
    "curated": [
        "cbm_Asymmetry", "cbm_Irregular", "cbm_Black", "cbm_Blue", "cbm_White",
        "cbm_Brown", "cbm_Erosion", "cbm_Multiple Colors", "cbm_Tiny", "cbm_Regular",
    ],
}

# Same concepts as automatic_concept_info; only the order of "curated" differs.
automatic_concept_info_plotorder = {
    "skincon": skincon_cols,
    "curated": [
        "cbm_Asymmetry", "cbm_Irregular", "cbm_Erosion", "cbm_Black", "cbm_Blue",
        "cbm_White", "cbm_Brown", "cbm_Multiple Colors", "cbm_Tiny", "cbm_Regular",
    ],
}

subpanel_offset = -5
show_pval = True
stack_mode = "suggestion"
stage_list = ["weight"]
task_suffix = " (Dermoscopic)"
method_weight = "automatic_curated_monet_full"

record_all_list = record_all_list_dict[dataset_name]

fig = plot_cbm_results(
    dataset_name=dataset_name,
    main_method_list=main_method_list,
    manual_annotation_concepts_name=manual_annotation_concepts_name,
    automatic_concept_info=automatic_concept_info,
    automatic_concept_info_plotorder=automatic_concept_info_plotorder,
    record_all_list=record_all_list,
    record_all_list_manual=None,
    subpanel_offset=subpanel_offset,
    stage_list=stage_list,
    method_weight=method_weight,
    show_pval=show_pval,
    stack_mode=stack_mode,
    task_suffix=task_suffix,
    alpha=0.001,
    temp=0.02,
)
print()

# One file per format; the stem is shared so the formats cannot drift apart.
for file_format in ("png", "jpg", "svg", "pdf"):
    fig.savefig(
        log_dir / "plots" / f"cbm_dermoscopic_vienna_barcelona_weight.{file_format}",
        bbox_inches="tight",
    )

In [None]:
# Per-(task, method, alpha) means of the numeric record fields.
record_all_list_temp_df = pd.DataFrame(record_all_list)
# auc_valid entries may be torch tensors; unwrap those to plain Python numbers.
record_all_list_temp_df["auc_valid"] = record_all_list_temp_df["auc_valid"].map(
    lambda value: value.item() if isinstance(value, torch.Tensor) else value
)
# numeric_only=True is passed explicitly: pandas >= 2.0 raises a TypeError on
# non-numeric columns instead of dropping them, whereas older versions already
# restricted the mean to numeric columns by default.
# NOTE(review): fillna(-9) fills every column, so a missing auc/auc_valid value
# enters the mean as -9 — confirm that this is intended.
record_all_list_temp_df.fillna(-9).groupby(["task", "method", "alpha"]).mean(numeric_only=True)

In [None]:
# Rows of the "automatic_curated_vanilla_full" method with alpha == 0.001.
record_all_list_temp_df[
    (record_all_list_temp_df["method"] == "automatic_curated_vanilla_full")
    & (record_all_list_temp_df["alpha"] == 0.001)
]

In [None]:
# NOTE(review): bare identifier, not a string — raises NameError unless a
# variable with this name exists. Looks like leftover scratch; confirm or delete.
automatic_curated_vanilla_full

In [None]:
# Mean of the auc_valid and auc columns per (task, method, alpha).
# NOTE(review): fillna(-9) fills every column, so a missing auc/auc_valid value
# enters the mean as -9 — confirm that this is intended.
record_all_list_temp_df.fillna(-9).groupby(["task", "method", "alpha"])[["auc_valid", "auc"]].mean()

In [None]:
# Mean auc / auc_valid per (task, method, alpha), restricted to the
# "automatic_curated_monet_full" method.
(
    record_all_list_temp_df[record_all_list_temp_df["method"] == "automatic_curated_monet_full"]
    .groupby(["task", "method", "alpha"])
    .apply(lambda group: group[["auc", "auc_valid"]].mean())
)

In [None]:
# NOTE(review): bare identifier — raises NameError unless a variable named
# auc_valid exists. Looks like leftover scratch; confirm or delete.
auc_valid

In [None]:
# Rows whose method name contains "resnet".
record_all_list_temp_df[record_all_list_temp_df["method"].str.contains("resnet")]

In [None]:
# Per-(task, method, alpha) means of the numeric record fields.
record_all_list_temp_df = pd.DataFrame(record_all_list)
# auc_valid entries may be torch tensors; unwrap those to plain Python numbers.
record_all_list_temp_df["auc_valid"] = record_all_list_temp_df["auc_valid"].map(
    lambda value: value.item() if isinstance(value, torch.Tensor) else value
)
# numeric_only=True is passed explicitly: pandas >= 2.0 raises a TypeError on
# non-numeric columns instead of dropping them, whereas older versions already
# restricted the mean to numeric columns by default.
# NOTE(review): fillna(-9) fills every column, so a missing auc/auc_valid value
# enters the mean as -9 — confirm that this is intended.
record_all_list_temp_df.fillna(-9).groupby(["task", "method", "alpha"]).mean(numeric_only=True)

In [None]:
# Rebuild the scratch DataFrame from the current record_all_list.
# Note: this version does not convert tensor-valued auc_valid entries.
record_all_list_temp_df=pd.DataFrame(record_all_list)

In [None]:
# Display only the rows whose task is "malignant".
record_all_list_temp_df[record_all_list_temp_df["task"]=="malignant"]

In [None]:
# Configuration 2: settings for the "clinical_fd_clean_nodup_nooverlap" records.
dataset_name = "clinical_fd_clean_nodup_nooverlap"

# Methods shown in the comparison, in display order.
main_method_list = [
    "automatic_curated_monet_full",
    "manual_annotation_full",
    "resnet",
    "resnet_freeze_backbone",
    "automatic_curated_vanilla_full",
]

manual_annotation_concepts_name = "skincon"

# Concept names for each concept set.
automatic_concept_info = {
    "skincon": skincon_cols,
    "curated": [
        'cbm_Asymmetry', 'cbm_Irregular', 'cbm_Black', 'cbm_Blue', 'cbm_White',
        'cbm_Brown', 'cbm_Erosion', 'cbm_Multiple Colors', 'cbm_Tiny', 'cbm_Regular',
    ],
}

# Same concept sets; the "curated" list differs only in order
# ('cbm_Erosion' sits in third position here).
automatic_concept_info_plotorder = {
    "skincon": skincon_cols,
    "curated": [
        'cbm_Asymmetry', 'cbm_Irregular', 'cbm_Erosion', 'cbm_Black', 'cbm_Blue',
        'cbm_White', 'cbm_Brown', 'cbm_Multiple Colors', 'cbm_Tiny', 'cbm_Regular',
    ],
}

subpanel_offset = 0
show_pval = True
stack_mode = "suggestion"

# Records for the selected dataset.
record_all_list = record_all_list_dict[dataset_name]

##########################################################################################

# dataset_name="derm7pt_derm_nodup"

# main_method_list=["automatic_curated_monet_full", 
#                   "manual_annotation_full", 
#                   "resnet", 
#                   "resnet_freeze_backbone", 
#                   "automatic_curated_vanilla_full"]

# manual_annotation_concepts_name="derm7pt"

# automatic_concept_info={
#     "derm7pt": ['derm7ptconcept_pigment network',
#                 'derm7ptconcept_regression structure',
#                 'derm7ptconcept_pigmentation',
#                 'derm7ptconcept_blue whitish veil',
#                 'derm7ptconcept_vascular structures',
#                 'derm7ptconcept_streaks',
#                 'derm7ptconcept_dots and globules'],
#     "curated": ['cbm_Asymmetry', 'cbm_Irregular', 'cbm_Black', 'cbm_Blue', 'cbm_White', 'cbm_Brown', 
#                   'cbm_Erosion','cbm_Multiple Colors', 'cbm_Tiny', 'cbm_Regular']
# }

# automatic_concept_info_plotorder={
#     "derm7pt": ['derm7ptconcept_pigment network',
#                 'derm7ptconcept_regression structure',
#                 'derm7ptconcept_pigmentation',
#                 'derm7ptconcept_blue whitish veil',
#                 'derm7ptconcept_vascular structures',
#                 'derm7ptconcept_streaks',
#                 'derm7ptconcept_dots and globules'],
#     "curated": ['cbm_Asymmetry', 'cbm_Irregular', 'cbm_Erosion', 'cbm_Black', 'cbm_Blue', 'cbm_White', 'cbm_Brown', 
#                   'cbm_Multiple Colors', 'cbm_Tiny', 'cbm_Regular']
# }

# subpanel_offset=-1
# show_pval=True

# record_all_list=record_all_list_dict[dataset_name]

##########################################################################################

# dataset_name="isic_nodup_nooverlap"

# main_method_list=["automatic_curated_monet_full", 
#                   "resnet", 
#                   "resnet_freeze_backbone", 
#                   "automatic_curated_vanilla_full"]

# manual_annotation_concepts_name="derm7pt"

# automatic_concept_info={
#     "derm7pt": ['derm7ptconcept_pigment network',
#                 'derm7ptconcept_regression structure',
#                 'derm7ptconcept_pigmentation',
#                 'derm7ptconcept_blue whitish veil',
#                 'derm7ptconcept_vascular structures',
#                 'derm7ptconcept_streaks',
#                 'derm7ptconcept_dots and globules'],
#     "curated": ['cbm_Asymmetry', 'cbm_Irregular', 'cbm_Black', 'cbm_Blue', 'cbm_White', 'cbm_Brown', 
#                   'cbm_Erosion','cbm_Multiple Colors', 'cbm_Tiny', 'cbm_Regular']
# }

# automatic_concept_info_plotorder={
#     "derm7pt": ['derm7ptconcept_pigment network',
#                 'derm7ptconcept_regression structure',
#                 'derm7ptconcept_pigmentation',
#                 'derm7ptconcept_blue whitish veil',
#                 'derm7ptconcept_vascular structures',
#                 'derm7ptconcept_streaks',
#                 'derm7ptconcept_dots and globules'],
#     "curated": ['cbm_Asymmetry', 'cbm_Irregular', 'cbm_Erosion', 'cbm_Black', 'cbm_Blue', 'cbm_White', 'cbm_Brown', 
#                   'cbm_Multiple Colors', 'cbm_Tiny', 'cbm_Regular']
# }

# subpanel_offset=-1
# show_pval=False
# # stage_list["overview", "performance", "weight", "skincon"]
# stage_list=["performance", "weight"]

# record_all_list=record_all_list_dict["cbm_isic_barcelona_vienna"]

##########################################################################################
# Hyper-parameter values used when selecting runs.
alpha = 0.001
temp = 0.02

# Active configuration: "clinical_fd_clean_nodup_nooverlap" records.
dataset_name = "clinical_fd_clean_nodup_nooverlap"

# Methods shown in the comparison, in display order.
main_method_list = [
    "automatic_curated_monet_full",
    "manual_annotation_full",
    "resnet",
    "resnet_freeze_backbone",
    "automatic_curated_vanilla_full",
]

manual_annotation_concepts_name = "skincon"

# Concept names for each concept set.
automatic_concept_info = {
    "skincon": skincon_cols,
    "curated": [
        'cbm_Asymmetry', 'cbm_Irregular', 'cbm_Black', 'cbm_Blue', 'cbm_White',
        'cbm_Brown', 'cbm_Erosion', 'cbm_Multiple Colors', 'cbm_Tiny', 'cbm_Regular',
    ],
}

# Same concept sets; the "curated" list differs only in order
# ('cbm_Erosion' sits in third position here).
automatic_concept_info_plotorder = {
    "skincon": skincon_cols,
    "curated": [
        'cbm_Asymmetry', 'cbm_Irregular', 'cbm_Erosion', 'cbm_Black', 'cbm_Blue',
        'cbm_White', 'cbm_Brown', 'cbm_Multiple Colors', 'cbm_Tiny', 'cbm_Regular',
    ],
}

subpanel_offset = 0
show_pval = True
stack_mode = "suggestion"

# Records for the selected dataset.
record_all_list = record_all_list_dict[dataset_name]

In [None]:
# Performance table: keep the runs with random_seed <= 20 ...
record_all_df_perf = pd.DataFrame(record_all_list)
record_all_df_perf = record_all_df_perf.loc[record_all_df_perf["random_seed"] <= 20]

# ... then restrict an independent copy to the methods in main_method_list.
record_all_df_perf_filtered = record_all_df_perf.loc[
    record_all_df_perf["method"].isin(main_method_list)
].copy()

# No alpha/temp selection is applied in this cell.
# Previously pasted "method" value counts, kept for reference
# (first label restored from a truncated paste):
#     manual_annotation_full            200
#     automatic_curated_monet_full      200
#     automatic_curated_vanilla_full    200
#     resnet                             40
#     resnet_freeze_backbone             40

In [None]:
# Number of records per method in record_all_df_perf_filtered.
record_all_df_perf_filtered["method"].value_counts()

In [None]:
# Method counts among the records whose alpha is either missing or equal to the selected alpha.
_alpha_column = record_all_df_perf_filtered["alpha"]
record_all_df_perf_filtered[_alpha_column.isnull() | (_alpha_column == alpha)]["method"].value_counts()

In [None]:
# Method counts after hyper-parameter selection: methods whose name contains
# "automatic" or "manual" must match both alpha and temp; all other methods
# are kept unconditionally.
_method_names = record_all_df_perf_filtered["method"]
_is_concept_method = _method_names.str.contains("automatic") | _method_names.str.contains("manual")
_alpha_matches = record_all_df_perf_filtered["alpha"] == alpha
_temp_matches = record_all_df_perf_filtered["temp"] == temp
record_all_df_perf_filtered[
    (~_is_concept_method) | (_is_concept_method & _alpha_matches & _temp_matches)
]["method"].value_counts()




In [None]:
# Count how often each "temp" value occurs in record_all_df_perf (NaN is not counted by default).
record_all_df_perf["temp"].value_counts()

In [None]:
# Display record_all_df.
# NOTE(review): record_all_df is not assigned anywhere in this chunk — confirm it is defined by an earlier cell.
record_all_df

In [None]:
# Number of records per method in the currently selected record_all_list.
pd.DataFrame(record_all_list)["method"].value_counts()

In [None]:
# Number of records per method in the currently selected record_all_list.
pd.DataFrame(record_all_list)["method"].value_counts()

In [None]:
# NOTE(review): bare identifier with no assignment visible in this chunk; evaluating it
# raises NameError unless it is defined elsewhere. It has the same spelling as a
# method-name string — TODO confirm and remove.
automatic_curated_monet_less_concept

In [None]:
# List the record-set names available in record_all_list_dict.
record_all_list_dict.keys()

In [None]:
# plot_concept_num panel for the clinical records stored under the
# "..._with_validation_seed2040" key. The resulting figure is not saved here
# (the export lines below are commented out).
target_method="automatic_curated_monet_full"
variable_method_list=[
    "automatic_curated_monet_less_concept",
    "automatic_skincon_monet_less_concept"]
variable_method="automatic_curated_monet_less_concept"
subpanel_offset=0
record_all_list=record_all_list_dict["clinical_fd_clean_nodup_nooverlap_with_validation_seed2040"]
alpha=0.001

fig=plot_concept_num(
    target_method=target_method,
    variable_method_list=variable_method_list,
    subpanel_offset=subpanel_offset,
    record_all_list=record_all_list,
    task_suffix="",
    # Spelling fixed: "Fizpatrick17k" -> "Fitzpatrick17k" (dataset name shown in the panel title).
    panel_title=". Clinical (Fitzpatrick17k+DDI)",
    alpha=alpha,
#     with_validation=False,
#     with_validation_label=None,
    with_validation=True,
    with_validation_label=["Validation", "Test"],
)
# NOTE(review): this tallies the "..._with_validation" record set, not the
# "_seed2040" set plotted above, and the result is discarded (it is not the
# last expression of the cell) — confirm which set was intended.
pd.DataFrame(record_all_list_dict["clinical_fd_clean_nodup_nooverlap_with_validation"])["method"].value_counts()

# fig.savefig(log_dir/"plots"/"cbm_num_concepts_clinical.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"cbm_num_concepts_clinical.jpg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"cbm_num_concepts_clinical.svg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"cbm_num_concepts_clinical.pdf", bbox_inches='tight')


# plot_concept_num panel for the clinical records stored under the
# "clinical_fd_clean_nodup_nooverlap_with_validation" key. The resulting figure
# is not saved here (the export lines below are commented out).
target_method="automatic_curated_monet_full"
variable_method_list=[
    "automatic_curated_monet_less_concept",
    "automatic_skincon_monet_less_concept"]
variable_method="automatic_curated_monet_less_concept"
subpanel_offset=0
record_all_list=record_all_list_dict["clinical_fd_clean_nodup_nooverlap_with_validation"]
alpha=0.001

fig=plot_concept_num(
    target_method=target_method,
    variable_method_list=variable_method_list,
    subpanel_offset=subpanel_offset,
    record_all_list=record_all_list,
    task_suffix="",
    # Spelling fixed: "Fizpatrick17k" -> "Fitzpatrick17k" (dataset name shown in the panel title).
    panel_title=". Clinical (Fitzpatrick17k+DDI)",
    alpha=alpha,
#     with_validation=False,
#     with_validation_label=None,
    with_validation=True,
    with_validation_label=["Validation", "Test"],
)
# Records per method for the plotted record set; the value is discarded
# because this is not the last expression of the cell.
pd.DataFrame(record_all_list_dict["clinical_fd_clean_nodup_nooverlap_with_validation"])["method"].value_counts()

# fig.savefig(log_dir/"plots"/"cbm_num_concepts_clinical.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"cbm_num_concepts_clinical.jpg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"cbm_num_concepts_clinical.svg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"cbm_num_concepts_clinical.pdf", bbox_inches='tight')


# plot_concept_num panel for the ISIC records
# ("cbm_isic_nodup_nooverlap_all_automatic_with_validation"), saved in four formats.
target_method = "automatic_curated_monet_full"
variable_method_list = [
    "automatic_curated_monet_less_concept",
    "automatic_skincon_monet_less_concept",
]
variable_method = "automatic_curated_monet_less_concept"
subpanel_offset = 1
record_all_list = record_all_list_dict["cbm_isic_nodup_nooverlap_all_automatic_with_validation"]
alpha = 0.001

# Alternative kept from the original: with_validation=False, with_validation_label=None.
fig = plot_concept_num(
    target_method=target_method,
    variable_method_list=variable_method_list,
    subpanel_offset=subpanel_offset,
    record_all_list=record_all_list,
    task_suffix="",
    panel_title=". Dermoscopic (ISIC)",
    alpha=alpha,
    with_validation=True,
    with_validation_label=["Validation", "Test"],
)
# Records per method for the plotted record set (value discarded mid-cell).
pd.DataFrame(record_all_list)["method"].value_counts()

# One export per output format.
for _ext in ("png", "jpg", "svg", "pdf"):
    fig.savefig(log_dir/"plots"/f"cbm_num_concepts_isic.{_ext}", bbox_inches='tight')

# plot_concept_num panel for the "cbm_isic_barcelona_vienna" records, saved in four formats.
target_method = "automatic_curated_monet_full"
variable_method_list = [
    "automatic_curated_monet_less_concept",
    "automatic_skincon_monet_less_concept",
]
variable_method = "automatic_curated_monet_less_concept"
subpanel_offset = 2
record_all_list = record_all_list_dict["cbm_isic_barcelona_vienna"]
alpha = 0.001

fig = plot_concept_num(
    target_method=target_method,
    variable_method_list=variable_method_list,
    subpanel_offset=subpanel_offset,
    record_all_list=record_all_list,
    task_suffix="",
    panel_title=". Dermoscopic (ISIC: Hosp. Barcelona → Med U. Vienna)",
    alpha=alpha,
    with_validation=True,
    with_validation_label=["Validation (Hosp. Barcelona)", "Test (Med U. Vienna)"],
)
# Records per method for the plotted record set (value discarded mid-cell).
pd.DataFrame(record_all_list)["method"].value_counts()

# One export per output format.
for _ext in ("png", "jpg", "svg", "pdf"):
    fig.savefig(log_dir/"plots"/f"cbm_num_concepts_barcelona_vienna.{_ext}", bbox_inches='tight')

# plot_concept_num panel for the "cbm_isic_vienna_barcelona" records, saved in four formats.
target_method = "automatic_curated_monet_full"
variable_method_list = [
    "automatic_curated_monet_less_concept",
    "automatic_skincon_monet_less_concept",
]
variable_method = "automatic_curated_monet_less_concept"
subpanel_offset = 3
record_all_list = record_all_list_dict["cbm_isic_vienna_barcelona"]
alpha = 0.001

fig = plot_concept_num(
    target_method=target_method,
    variable_method_list=variable_method_list,
    subpanel_offset=subpanel_offset,
    record_all_list=record_all_list,
    task_suffix="",
    panel_title=". Dermoscopic (ISIC: Med U. Vienna → Hosp. Barcelona)",
    alpha=alpha,
    with_validation=True,
    with_validation_label=["Validation (Med U. Vienna)", "Test (Hosp. Barcelona)"],
)
# Records per method for the plotted record set (value discarded mid-cell).
pd.DataFrame(record_all_list)["method"].value_counts()

# One export per output format.
for _ext in ("png", "jpg", "svg", "pdf"):
    fig.savefig(log_dir/"plots"/f"cbm_num_concepts_vienna_barcelona.{_ext}", bbox_inches='tight')

In [None]:
import sklearn.metrics

In [None]:
# Recompute the ROC AUC of the first record from its stored "y_test" / "y_test_pred" values.
_first_record = pd.DataFrame(
    record_all_list_dict["clinical_fd_clean_nodup_nooverlap_with_validation"]
).iloc[0]
sklearn.metrics.roc_auc_score(
    y_true=_first_record["y_test"],
    y_score=_first_record["y_test_pred"],
)

In [None]:
# Number of records per method in the "clinical_fd_clean_nodup_nooverlap_with_validation" record set.
pd.DataFrame(record_all_list_dict["clinical_fd_clean_nodup_nooverlap_with_validation"])["method"].value_counts()

In [None]:
# Switch the working record list to the "cbm_isic_barcelona_vienna" set.
record_all_list = record_all_list_dict["cbm_isic_barcelona_vienna"]

# Shared settings: stacking mode plus the alpha/temp values used to select runs.
stack_mode, alpha, temp = "suggestion", 0.001, 0.02

In [None]:
# For each task, keep the concept-curation records at the selected alpha and
# split them into target-method rows and variable-method rows. After the loop
# the three frames hold the values of the last task ("malignant").
for task in ["melanoma", "malignant"]:
    _task_rows = record_all_df_conceptcuration["task"] == task
    _alpha_rows = record_all_df_conceptcuration["alpha"] == alpha
    record_all_df_conceptcuration_filtered = record_all_df_conceptcuration[_task_rows & _alpha_rows]

    _methods = record_all_df_conceptcuration_filtered["method"]
    record_all_df_conceptcuration_filtered_target = record_all_df_conceptcuration_filtered[
        _methods == target_method
    ]
    record_all_df_conceptcuration_filtered_variable = record_all_df_conceptcuration_filtered[
        _methods == variable_method
    ]

In [None]:
# Display the target-method subset of the concept-curation records.
record_all_df_conceptcuration_filtered_target

In [None]:
# Line plot of "auc" against "num_concept" for the variable-method records,
# followed by a printout of the target-method records.
sns.lineplot(data=record_all_df_conceptcuration_filtered_variable, x="num_concept", y="auc")
print(record_all_df_conceptcuration_filtered_target)

In [None]:
# Show the filtered concept-curation records of the "automatic_derm7pt_monet_full" method.
record_all_df_conceptcuration_filtered[
    record_all_df_conceptcuration_filtered["method"]=="automatic_derm7pt_monet_full"]

In [None]:
# Number of records per method in the filtered concept-curation table.
record_all_df_conceptcuration_filtered["method"].value_counts()

In [None]:
# Display setting_name.
# NOTE(review): setting_name is not assigned anywhere in this chunk — confirm it is defined by an earlier cell.
setting_name

In [None]:
# Display record_all_df_concept_select.
# NOTE(review): record_all_df_concept_select is not assigned anywhere in this chunk — confirm it is defined by an earlier cell.
record_all_df_concept_select

In [None]:
# List the record-set names available in record_all_list_dict.
record_all_list_dict.keys()

In [None]:
# Mean of every numeric column per (task, method) group of the current record list.
# numeric_only=True keeps the pandas 1.x behaviour of ignoring non-numeric
# columns; without it pandas >= 2.0 raises TypeError when such columns are present.
pd.DataFrame(record_all_list).groupby(["task", "method"]).mean(numeric_only=True)

In [None]:
# Scratch cell: prints the literal 1.
print(1)

In [None]:
# Configuration 1 (deprecated, original version of the figure).
dataset_name = "clinical_fd_clean_nodup_nooverlap"

# Methods shown in the comparison, in display order.
main_method_list = [
    "automatic_curated_monet_full",
    "manual_annotation_full",
    "resnet",
    "resnet_freeze_backbone",
    "automatic_curated_vanilla_full",
]
# NOTE(review): "manual_annotaion" is reproduced exactly as in the original;
# confirm whether any consumer of method_list expects this spelling.
method_list = ["manual_annotaion", "automatic", "resnet", "resnet_freeze_backbone"]

manual_annotation_concepts_name = "skincon"

# Concept names for each concept set.
automatic_concept_info = {
    "skincon": skincon_cols,
    "curated": [
        'cbm_Asymmetry', 'cbm_Irregular', 'cbm_Black', 'cbm_Blue', 'cbm_White',
        'cbm_Brown', 'cbm_Erosion', 'cbm_Multiple Colors', 'cbm_Tiny', 'cbm_Regular',
    ],
}

# Same concept sets; the "curated" list differs only in order
# ('cbm_Erosion' sits in third position here).
automatic_concept_info_plotorder = {
    "skincon": skincon_cols,
    "curated": [
        'cbm_Asymmetry', 'cbm_Irregular', 'cbm_Erosion', 'cbm_Black', 'cbm_Blue',
        'cbm_White', 'cbm_Brown', 'cbm_Multiple Colors', 'cbm_Tiny', 'cbm_Regular',
    ],
}

# dataset_name="derm7pt_derm_nodup"
# main_method_list=["automatic_curated_monet_full", 
#                   "manual_annotation_full", 
#                   "resnet", 
#                   "resnet_freeze_backbone", 
#                   "automatic_curated_vanilla_full"]


# method_list=["automatic"]
# manual_annotation_concepts=['derm7ptconcept_pigment network',
#                     'derm7ptconcept_regression structure',
#                     'derm7ptconcept_pigmentation',
#                     'derm7ptconcept_blue whitish veil',
#                     'derm7ptconcept_vascular structures',
#                     'derm7ptconcept_streaks',
#                     'derm7ptconcept_dots and globules']

# automatic_concept_info={
#     "derm7pt": ['derm7ptconcept_pigment network',
#                     'derm7ptconcept_regression structure',
#                     'derm7ptconcept_pigmentation',
#                     'derm7ptconcept_blue whitish veil',
#                     'derm7ptconcept_vascular structures',
#                     'derm7ptconcept_streaks',
#                     'derm7ptconcept_dots and globules'],
#     "curated": ['cbm_Asymmetry', 'cbm_Irregular', 'cbm_Black', 'cbm_Blue', 'cbm_White', 'cbm_Brown', 
#                   'cbm_Erosion','cbm_Multiple Colors', 'cbm_Tiny', 'cbm_Regular']
# }

# alpha/temp values used to select runs, and the records of the chosen dataset.
alpha, temp = 0.001, 0.02

record_all_list = record_all_list_dict[dataset_name]

# Colour cycle: entries 1, 3, 5, 7, 9 and 11 of Paired[12], each divided by 256.
plt.rcParams["axes.prop_cycle"] = cycler(
    'color',
    [np.array(Paired[12][color_idx]) / 256 for color_idx in range(1, 12, 2)],
)

# Outer layout: four stacked rows. The figure height is the sum of the row
# ratios plus three gaps of the hspace value, scaled by 3 (as is the width).
_row_heights = [3, 2.5, 4, 3]
_row_gap = 0.35

fig = plt.figure(figsize=(3 * 10, 3 * (sum(_row_heights) + _row_gap * 3)))

box1 = gridspec.GridSpec(
    4, 1,
    height_ratios=_row_heights,
    wspace=0.0,
    hspace=_row_gap,
)


# temp array([  nan, 0.02 , 0.01 , 0.005])
# alpha array([0.001 , 0.0001,    nan])


# Create one Axes per panel and register it in `axd` under a string key.
# "overview" occupies its whole GridSpec row; every other stage splits its row
# into side-by-side cells keyed "<stage>_<cell name>".
_task_cells = ["empty_malignant", "malignant", "empty_melanoma", "melanoma"]
_row_layouts = {
    "performance": (
        _task_cells,
        dict(width_ratios=[0.1, 1, 0.1, 1], wspace=0., hspace=0.),
    ),
    "weight": (
        _task_cells,
        dict(width_ratios=[0.15, 1, 0.15, 1], wspace=0.1, hspace=0.),
    ),
    "skincon": (
        ["empty", "malignant_num_concept", "empty1", "malignant_num_sample",
         "empty2", "melanoma_num_concept", "empty3", "melanoma_num_sample"],
        dict(width_ratios=[0.15, 1, 0.15, 1,     0.27, 1, 0.15, 1], wspace=0.0, hspace=0.),
    ),
}

axd = {}
for idx1, stage in enumerate(["overview", "skincon", "performance", "weight"]):
    if stage in ("overview", "empty"):
        # Single full-row panel, keyed by the stage name itself.
        plot_key = stage
        ax = plt.Subplot(fig, box1[idx1])
        fig.add_subplot(ax)
        axd[plot_key] = ax
        continue

    if stage not in _row_layouts:
        # Unknown stages create no axes.
        continue

    # Split this row into one cell per name; spacer cells carry "empty" in their name.
    _cell_names, _grid_kwargs = _row_layouts[stage]
    box2 = gridspec.GridSpecFromSubplotSpec(
        1, len(_cell_names),
        subplot_spec=box1[idx1],
        **_grid_kwargs,
    )
    for idx2, _cell_name in enumerate(_cell_names):
        plot_key = f"{stage}_{_cell_name}"
        ax = plt.Subplot(fig, box2[idx2])
        fig.add_subplot(ax)
        axd[plot_key] = ax
            
        
# Blank out placeholder panels: any key containing "overview" or "empty" loses
# its ticks and has all four spines set to zero line width.
for plot_key in axd:
    if not ('overview' in plot_key or 'empty' in plot_key):
        continue
    axd[plot_key].set_xticks([])
    axd[plot_key].set_yticks([])
    for axis in ['top','bottom','left','right']:
        axd[plot_key].spines[axis].set_linewidth(0)
        
for idx1, stage in enumerate(["overview", "performance", "weight", "skincon"]):
    if stage=="overview":
        plot_key=stage
        
        axd[plot_key].text(x=-0.0, y=1.0, transform=axd[plot_key].transAxes,
                                 s="A", fontsize=35, weight='bold')           

    elif stage=="performance":   
        for idx2, task in enumerate(["empty_malignant", "malignant", "empty_melanoma", "melanoma"]):
#             elif investigation_type=="statistics":
            plot_key=f"{stage}_{task}"
            
            if task=="malignant" or task=="melanoma":
                record_all_df_perf=pd.DataFrame(record_all_list)
                record_all_df_perf=record_all_df_perf[record_all_df_perf["random_seed"]<20]
                record_all_df_perf_filtered=record_all_df_perf.copy()
                record_all_df_perf_filtered=record_all_df_perf_filtered[record_all_df_perf_filtered["method"].isin(main_method_list)]
                record_all_df_perf_filtered=record_all_df_perf_filtered[(~record_all_df_perf_filtered["method"].isin(["automatic_curated_vanilla_full", "automatic_curated_monet_full"]))|\
                                                             ((record_all_df_perf_filtered["method"].isin(["automatic_curated_vanilla_full", "automatic_curated_monet_full"]))&(record_all_df_perf_filtered["alpha"]==alpha)&(record_all_df_perf_filtered["temp"]==temp))]
                #dsds
#                 record_all_df_new=pd.DataFrame(record_all_list_new)
#                 record_all_df_filtered_new=record_all_df_new[record_all_df_new["is_clean"]=="clean_only"]
#                 record_all_df_filtered_new=record_all_df_filtered_new[record_all_df_filtered_new["method"]=="automatic_skincon_monet_full"]
#                 record_all_df_filtered=pd.concat([record_all_df_filtered, record_all_df_filtered_new], axis=0)
#                 dsdsdsds
                if task=="malignant":
#                     dsds
                    record_all_df_perf_filtered=record_all_df_perf_filtered[record_all_df_perf_filtered["task"]=="malignant"]
                    axd[plot_key].set_ylim(0.49, 1.01)
                elif task=="melanoma":
                    record_all_df_perf_filtered=record_all_df_perf_filtered[record_all_df_perf_filtered["task"]=="melanoma"]
                    axd[plot_key].set_ylim(0.49, 1.01)
                else:
                    raise ValueError
#                 sdsdsd
#                 sdsd
                b=sns.boxplot(x="method", y="auc", 
                              order=main_method_list,
                              width=0.5,
                              linewidth=3,
                              saturation=1.3,
                              boxprops=dict(alpha=.9),
                              data=record_all_df_perf_filtered, 
                            ax=axd[plot_key])
                
                
                sns.swarmplot(x="method", y="auc", 
                              order=main_method_list,
                              color='black', 
                              alpha=0.8,
                              size=9,
                              data=record_all_df_perf_filtered, ax=axd[plot_key])
 

                record_all_df_perf_filtered_pvalue=record_all_df_perf_filtered.groupby("task")\
                .apply(lambda x: x.groupby("method")\
                .apply(lambda y: 
                scipy.stats.ttest_rel(y.set_index('random_seed')["auc"], 
                x[x["method"]=="automatic_curated_monet_full"].set_index("random_seed")["auc"], alternative="less").pvalue)).T

    
#                 count=0
#                 for method_name in main_method_list[::-1]:
# #                     continue
#                     if method_name=="automatic_curated_monet_full":
#                         continue
                        
#                     pvalue_x1=main_method_list.index("automatic_curated_monet_full")
#                     pvalue_x2=main_method_list.index(method_name)
                    
#                     pvalue_y=record_all_df_perf_filtered[record_all_df_perf_filtered["method"]=="automatic_curated_monet_full"]["auc"].max()
#                     pvalue_y_=record_all_df_perf_filtered[record_all_df_perf_filtered["method"]==method_name]["auc"].max()
                    
                    
#                     print(method_name, record_all_df_perf_filtered_pvalue.loc[method_name][task])
#                     if record_all_df_perf_filtered_pvalue.loc[method_name][task]<0.001:
#                         pvalue_str="***"
#                     elif record_all_df_perf_filtered_pvalue.loc[method_name][task]<0.01:
#                         pvalue_str="**"
#                     elif record_all_df_perf_filtered_pvalue.loc[method_name][task]<0.05:
#                         pvalue_str="*"                        
#                     else:
#                         pvalue_str="ns"  
                        
#                     axd[plot_key].text((pvalue_x2), 
#                                        pvalue_y_+0.0035,
#                              s=pvalue_str, 
#                              fontsize=25,
#                              ha='center', 
#                              va='bottom', 
#                              color="k")
#                     count+=1   
#                     continue
                    
                    
#                     axd[plot_key].plot([pvalue_x1, 
#                                         pvalue_x1, 
#                                         pvalue_x2, 
#                                         pvalue_x2], 
#                                        [pvalue_y+0.013+0.023*(count),
#                                         pvalue_y+0.013+0.023*(count)+0.005, 
#                                         pvalue_y+0.013+0.023*(count)+0.005, 
#                                         pvalue_y+0.013+0.023*(count)], 
#                                        lw=3, c='black')
                      
                    

                    
                print(task, record_all_df_perf_filtered.groupby("method")["auc"].apply(lambda x: {"mean": x.mean(),
                                                                      "std": x.std(),
                                                                      "q3": x.quantile(q=0.75),                                                                      
                                                                      "median": x.median(),
                                                                      "q1": x.quantile(q=0.25),
                                                                     }))

                if task=="malignant":
                    axd[plot_key].set_title("Malignant"+task_suffix, fontsize=30)
                if task=="melanoma":
                    axd[plot_key].set_title("Melanoma"+task_suffix, fontsize=30)

                for axis in ['top','bottom','left','right']:
                    axd[plot_key].spines[axis].set_linewidth(3)                

                if idx2==1: 
                    axd[plot_key].set_ylabel('Area under the ROC curve', fontsize=25)
                if idx2==3: 
                    axd[plot_key].set_ylabel('')                    

                axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.05))
                axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.01))            
                axd[plot_key].yaxis.grid(True, which='major', linewidth=2, alpha=0.6)
                axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
                #axd[plot_key].yaxis.grid(True, which='minor', linewidth=2, alpha=0.1)

                axd[plot_key].spines['right'].set_visible(False)
                axd[plot_key].spines['top'].set_visible(False)                   
                axd[plot_key].spines['bottom'].set_visible(False)    

                axd[plot_key].tick_params(axis='x', which='major', labelsize=25)
                axd[plot_key].tick_params(axis='y', which='major', labelsize=25)   

                axd[plot_key].tick_params(
                    axis='x',          # changes apply to the x-axis
                    which='both',      # both major and minor ticks are affected
                    bottom=False,
                    labelbottom=False,      # ticks along the bottom edge are off
                    )            

                axd[plot_key].set_xlabel(None)
            
            if task=="empty_malignant":
                axd[plot_key].text(x=-0.05, y=1.02, transform=axd[plot_key].transAxes,
                                     s="F", fontsize=35, weight='bold') 
                
            if task=="malignant":
                
                legend_elements=[Patch(facecolor=plt.rcParams["axes.prop_cycle"].by_key()['color'][method_idx], 
                                       edgecolor="black", linewidth=2, 
                                       label=shorten_method_name(method_name)) for method_idx, method_name in enumerate(main_method_list)]
                

                axd[plot_key].legend(handles=legend_elements, 
                            ncol=5, 
                            handlelength=2.5,
                            handletextpad=0.4, 
                            columnspacing=1.3,
                            fontsize=23,
                            loc='lower center', bbox_to_anchor=(1., -0.1))                   
                
                #axd[plot_key].set_ylabel('Concepts', fontsize=30)
                
                
            if task=="empty_melanoma":
                axd[plot_key].text(x=-0.0, y=1.02, transform=axd[plot_key].transAxes,
                                     s="G", fontsize=35, weight='bold')  
                
                
                
                
        
    elif stage=="weight":  
        for idx2, task in enumerate(["empty_malignant", "malignant", "empty_melanoma", "melanoma"]):
            plot_key=f"{stage}_{task}"
            if task=="malignant" or task=="melanoma":
                record_all_df_weight=pd.DataFrame(record_all_list)
                record_all_df_weight=record_all_df_weight[record_all_df_weight["random_seed"]<20]
                record_all_df_weight_filtered=record_all_df_weight.copy()#[record_all_df_weight["is_clean"]=="clean_only"]
                record_all_df_weight_filtered=record_all_df_weight_filtered[(record_all_df_weight_filtered["method"]=="automatic_curated_monet_full")&(record_all_df_weight_filtered["alpha"]==alpha)&(record_all_df_weight_filtered["temp"]==temp)]
                
#                 record_all_df_filtered=record_all_df_filtered[record_all_df_filtered["method"].isin(main_method_list)]
#                 record_all_df_filtered=record_all_df_filtered[(~record_all_df_filtered["method"].isin(["automatic_curated_vanilla_full", "automatic_curated_monet_full"]))|\
#                                                              ((record_all_df_filtered["method"].isin(["automatic_curated_vanilla_full", "automatic_curated_monet_full"]))&(record_all_df_filtered["alpha"]==alpha)&(record_all_df_filtered["temp"]==temp))]                
                
                #print(record_all_df_filtered["alpha"])

                if task=="malignant":
                    record_all_df_weight_filtered=record_all_df_weight_filtered[record_all_df_weight_filtered["task"]=="malignant"]
                elif task=="melanoma":
                    record_all_df_weight_filtered=record_all_df_weight_filtered[record_all_df_weight_filtered["task"]=="melanoma"]
                else:
                    raise ValueError

                    
                print(record_all_df_weight_filtered)
                    
                coef_dict_list=[]
                for clf_idx, clf in enumerate(record_all_df_weight_filtered["clf"]):
                    for concept_name, coef in zip(automatic_concept_info["curated"], clf.coef_[0]):
                        coef_dict_list.append({"concept_name": concept_name,
                                               "coef": coef,
                                               "clf_idx": clf_idx
                                              })

#                 print(coef_dict_list)
#                 print(sdsd)
                #pd.DataFrame(coef_dict_list).groupby("concept_name")["coef"].to_csv(log_dir/"plots"/f"main_cbm_a_{alpha:.1e}_t_{temp:.1e}_{task}.csv")
#                 print(task, pd.DataFrame(coef_dict_list).groupby("concept_name")["coef"].mean())
                
                #sdsd
                coef_dict_list_df_barplot=pd.DataFrame(coef_dict_list)
                coef_dict_list_df_barplot["concept_name"]=coef_dict_list_df_barplot["concept_name"].map(shorten_concept_name)
            
                weight_bar=sns.barplot(y="concept_name", x="coef", 
                            color=np.array(Paired[12][7])/256,
                            edgecolor='black',
                            linewidth=2,
                            width=0.5,    
                            order= list(map(shorten_concept_name,automatic_concept_info_plotorder["curated"])),
                            errwidth=5,
                                       
                            data=coef_dict_list_df_barplot, ax=axd[plot_key])
                
#                 for container in weight_bar.containers:
#                     axd[plot_key].bar_label(container)                
                    
#                 for p in weight_bar.patches:
#                     _x = p.get_x() + p.get_width() / 2
#                     _y = p.get_y() + p.get_height()
#                     value = '{:.2f}'.format(p.get_height())
#                     weight_bar.text(_x, _y, value, ha="center")  
                
                #rint(axd[plot_key].get_yticks())
        
#                 bar_labels=[]
#                 for concept_name in concept_list_curated:
#                     temp=pd.DataFrame(coef_dict_list)
#                     temp=temp[temp["concept_name"]==concept_name]
                    
#                     if temp["coef"].mean()>3:
#                         bar_labels.append()
#                     else:
#                         bar_labels.append("")
#                 print(temp, bar_labels)
#                 axd[plot_key].bar_label(weight_bar.containers[0], labels=bar_labels, fontsize=20)
                #print(pd.DataFrame(coef_dict_list))
                #dsd
                for p, concept_name in zip(weight_bar.patches, automatic_concept_info_plotorder["curated"]):
                    _x = p.get_x() + p.get_width() / 2
                    _y = p.get_y() + p.get_height()
                    
                    coef_dict_list_df=pd.DataFrame(coef_dict_list)
#                     dsdsd
                    coef_dict_list_df=coef_dict_list_df[coef_dict_list_df["concept_name"]==concept_name]
                    
                    if coef_dict_list_df["coef"].mean()>3:
                        value=f"{coef_dict_list_df['coef'].mean():.2f} (±{1.96*coef_dict_list_df['coef'].std()/ np.sqrt(len(coef_dict_list_df)) :.2f})"
                        axd[plot_key].text(2.8, _y+0.6, value, ha="center", fontsize=20, zorder=100)
                    
                    #value = '{:.2f}'.format(p.get_height())
                    
                
#                 for c in weight_bar.containers:
#                     c_mean=c.datavalues.mean()
#                     c_mean=np.round(c_mean,2)
#                     ci=1.96*c.datavalues.std()/np.sqrt(len(c.datavalues))
#                     ci=np.round(ci,2)
                    #axd[plot_key].bar_label(c, labels=[f"{c_mean:.2f} (±{ci})"], fontsize=20)

    #             sns.boxplot(x="method", y="auc", 
    #                         data=record_all_df_filtered, 
    #                         width=0.5,
    #                         linewidth=3,
    #                         ax=axd[plot_key])
    #             sns.swarmplot(x="method", y="auc", 
    #                           color='black', 
    #                           alpha=0.8,
    #                           size=10,
    #                           data=record_all_df_filtered, ax=axd[plot_key])


#             if task=="empty_malignancy":
#                 axd[plot_key].text(x=-0.0, y=1.0, transform=axd[plot_key].transAxes,
#                                                          s="D", fontsize=35, weight='bold')
#             if task=="empty_melanoma":
#                 axd[plot_key].text(x=-0.0, y=1.0, transform=axd[plot_key].transAxes,
#                                                          s="E", fontsize=35, weight='bold')


            if task=="empty_malignant":
                axd[plot_key].text(x=-0., y=1.05, transform=axd[plot_key].transAxes,
                                     s="H", fontsize=35, weight='bold')              
                
                #axd[plot_key].set_ylabel('Concepts', fontsize=30)
                
                
            if task=="empty_melanoma":
                axd[plot_key].text(x=-0.1, y=1.05, transform=axd[plot_key].transAxes,
                                     s="I", fontsize=35, weight='bold')    

            if task=="malignant":
                axd[plot_key].set_title("Malignant"+task_suffix, fontsize=30, pad=20)
            if task=="melanoma":
                axd[plot_key].set_title("Melanoma"+task_suffix, fontsize=30, pad=20)

            if task=="malignant" or task=="melanoma":
                if task=="malignant":
                    axd[plot_key].set_xlim(-3,3)
                elif task=="melanoma":
                    axd[plot_key].set_xlim(-3,3)
                else:
                    raise ValueError            
            
                axd[plot_key].axvline(x=0, ymin=0, ymax=1, color='black', alpha=0.7, linewidth=5, zorder=-5)

                for axis in ['top','bottom','left','right']:
                    axd[plot_key].spines[axis].set_linewidth(3)                

                axd[plot_key].set_ylabel('Area under the curve', fontsize=25)

    #             axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.05))
    #             axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.01))            
                axd[plot_key].xaxis.set_major_locator(MultipleLocator(0.5))
                axd[plot_key].xaxis.set_minor_locator(MultipleLocator(0.1))
                axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
                axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
        
                axd[plot_key].yaxis.grid(True, which='major', linewidth=2, alpha=0.6)
                #axd[plot_key].yaxis.grid(True, which='minor', linewidth=2, alpha=0.1)

                axd[plot_key].spines['left'].set_visible(False)
                axd[plot_key].spines['right'].set_visible(False)
                axd[plot_key].spines['top'].set_visible(False)                   
                #axd[plot_key].spines['bottom'].set_visible(False)    

                axd[plot_key].tick_params(axis='x', which='major', labelsize=20)
                axd[plot_key].tick_params(axis='y', which='major', labelsize=20)   

                axd[plot_key].tick_params(
                    axis='x',          # changes apply to the x-axis
                    which='both',      # both major and minor ticks are affected
                    top=False,
                    labeltop=False,      # ticks along the bottom edge are off                    
                    bottom=True,
                    labelbottom=True,      # ticks along the bottom edge are off
                    )            

                axd[plot_key].set_ylabel(None)
                axd[plot_key].set_xlabel("Coefficients of linear model", fontsize=25, labelpad=5)
                #axd[plot_key].xaxis.set_label_position('top') 
    elif stage=="skincon":       
        for idx2, variable in enumerate(["malignant_num_concept", "malignant_num_sample", 
                                         "melanoma_num_concept", "melanoma_num_sample", ]):
            plot_key=f"{stage}_{variable}"    
            
            record_all_df_skincon=pd.DataFrame(record_all_list)
            record_all_df_skincon=record_all_df_skincon[record_all_df_skincon["random_seed"]<20]
            record_all_df_skincon_filtered=record_all_df_skincon.copy()#[record_all_df_skincon["is_clean"]=="clean_only"]
            
#             import ipdb
#             ipdb.set_trace()            
            record_all_df_skincon_filtered=record_all_df_skincon_filtered[
                record_all_df_skincon_filtered["method"].isin([f"automatic_{manual_annotation_concepts_name}_monet_full", "manual_annotation_less_sample", "manual_annotation_less_concept"])]
            

            if variable.startswith("malignant"):
                record_all_df_skincon_filtered=record_all_df_skincon_filtered[record_all_df_skincon_filtered["task"]=="malignant"]
            elif variable.startswith("melanoma"):
                record_all_df_skincon_filtered=record_all_df_skincon_filtered[record_all_df_skincon_filtered["task"]=="melanoma"]
            else:
                raise ValueError
            

            record_all_df_skincon_filtered_ref=record_all_df_skincon_filtered[record_all_df_skincon_filtered["method"]==f"automatic_{manual_annotation_concepts_name}_monet_full"]
            if variable.endswith("num_sample"):
                record_all_df_skincon_filtered_obs=record_all_df_skincon_filtered[record_all_df_skincon_filtered["method"]=="manual_annotation_less_sample"]
                if variable.startswith("melanoma"):
                    pass
#                     sdsd
                record_all_df_skincon_filtered_obs_sample_prop=record_all_df_skincon_filtered_obs.groupby("num_sample_train")["num_sample_train"].mean()
                #record_all_df_skincon_filtered_obs["num_sample_train_pseudo"]=record_all_df_skincon_filtered_obs.apply(lambda x: record_all_df_skincon_filtered_obs_sample_prop[x["num_sample_train_select"]], axis=1)
                #jjj
                b=sns.lineplot(x="num_sample_train", y="auc", 
                               color=plt.rcParams["axes.prop_cycle"].by_key()["color"][1], linewidth=4,
                               data=record_all_df_skincon_filtered_obs, ax=axd[plot_key])
                if b.legend_ is not None:
                    b.legend_.remove()
                #axd[plot_key].axhline(y=ref_value, xmin=0, xmax=100000, color='red', alpha=0.7, linewidth=3, zorder=-5)
                
                
                axd[plot_key].scatter(0, record_all_df_skincon_filtered_ref["auc"].mean(), s=300, marker='X', color=plt.rcParams["axes.prop_cycle"].by_key()["color"][0])
                print(stage, variable, record_all_df_skincon_filtered_ref["auc"].mean())
#                 record_all_df_skincon_filtered_ref_matched=pd.DataFrame(list(itertools.product(record_all_df_skincon_filtered_obs_sample_prop.values, 
#                                                     record_all_df_skincon_filtered_ref["auc"].values)),
#                             columns=["num_sample_train_pseudo", "auc"])                
#                 b=sns.lineplot(x="num_sample_train_pseudo", y="auc", 
#                                color=plt.rcParams["axes.prop_cycle"].by_key()["color"][0], linewidth=4, linestyle="--",
#                                data=record_all_df_skincon_filtered_ref_matched, ax=axd[plot_key])
#                 if b.legend_ is not None:
#                     b.legend_.remove()
                #axd[plot_key].set_xlim(record_all_df_skincon_filtered_obs["num_sample_train_pseudo"].min(), record_all_df_skincon_filtered_obs["num_sample_train_pseudo"].max())
                
            elif variable.endswith("num_concept"):
                #sdsdd
                record_all_df_skincon_filtered_obs=record_all_df_skincon_filtered[record_all_df_skincon_filtered["method"]=="manual_annotation_less_concept"]
                b=sns.lineplot(x="num_concept", y="auc", 
                               color=plt.rcParams["axes.prop_cycle"].by_key()["color"][1], linewidth=4,
                               data=record_all_df_skincon_filtered_obs, ax=axd[plot_key])
                if b.legend_ is not None:
                    b.legend_.remove()
                    
                    
                axd[plot_key].scatter(48, record_all_df_skincon_filtered_ref["auc"].mean(), s=300, marker='X', color=plt.rcParams["axes.prop_cycle"].by_key()["color"][0])
                print(stage, variable, record_all_df_skincon_filtered_ref["auc"].mean())
                #sds
                #axd[plot_key].axhline(y=ref_value, xmin=0, xmax=100000, color='red', alpha=0.7, linewidth=3, zorder=-5)                    
                
#                 record_all_df_skincon_filtered_ref_matched=pd.DataFrame(list(itertools.product(record_all_df_skincon_filtered_obs["num_concept"].unique(), 
#                                                     record_all_df_skincon_filtered_ref["auc"].values)),
#                             columns=["num_sample_train_pseudo", "auc"])                
#                 b=sns.lineplot(x="num_sample_train_pseudo", y="auc", 
#                                color=plt.rcParams["axes.prop_cycle"].by_key()["color"][0], linewidth=4, linestyle="--",
#                                data=record_all_df_skincon_filtered_ref_matched, ax=axd[plot_key])                
#                 if b.legend_ is not None:
#                     b.legend_.remove()
                
                
                
            
        
#             dsds  
#             ref_value=record_all_df_skincon_filtered_ref[record_all_df_skincon_filtered_ref["task"]=="melanoma"]["auc"].mean()
            
#             #

#             #sns.lineplot(x="sample_prop", y="auc", hue="task", style="method", data=record_all_df_skincon_filtered, ax=axd[plot_key])
#             dsdsd


#             
#             axd[plot_key].set_xlim(0, record_all_df_skincon_filtered["num_sample_train"].max())
#                 #axd[plot_key].set_xlim(1-0.2, 11.5)

#             record_all_df_skincon_filtered     


            axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.05))
            axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.01))            
            axd[plot_key].yaxis.grid(True, which='major', linewidth=2, alpha=0.6)
            axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)    
            axd[plot_key].tick_params(axis='y', which='major', labelsize=20)
            
            if variable.endswith("num_concept"):
                axd[plot_key].xaxis.set_major_locator(MultipleLocator(10))
                axd[plot_key].xaxis.set_minor_locator(MultipleLocator(5))
                axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
                axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
                axd[plot_key].tick_params(axis='x', which='major', labelsize=20)
                
                
                axd[plot_key].set_xlabel("Num. of concepts", fontsize=25)
                axd[plot_key].set_xlim(-0.1, 49)
                
                
#                 axd[plot_key].tick_params(
#                     axis='y',          # changes apply to the x-axis
#                     which='both',      # both major and minor ticks are affected
#                     labelleft=False)                
                
            elif variable=="num_reference":
                axd[plot_key].xaxis.set_major_locator(MultipleLocator(1))
                axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
                axd[plot_key].tick_params(axis='x', which='major', labelsize=20)
                
                
                       
            elif variable.endswith("num_sample"):
                #axd[plot_key].set_xticks([0.05, 0.1 , 0.2 , 0.4 , 0.6 , 0.8 , 1.])
                #.set_xticks([2,4,6,8,10])
                # Tick spacing on the expert-labeled-sample axis is chosen per task.
                if variable.startswith("malignant"):
                    axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
                    axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))
                # This branch used to repeat the "malignant" test and was unreachable,
                # so the melanoma panel never received its own tick spacing.
                elif variable.startswith("melanoma"):
                    axd[plot_key].xaxis.set_major_locator(MultipleLocator(100))
                    axd[plot_key].xaxis.set_minor_locator(MultipleLocator(50))
                
                axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
                axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
                axd[plot_key].tick_params(axis='x', which='major', labelsize=20)
                
                axd[plot_key].set_xlabel("Num. of expert-labeled samples", fontsize=25)#, labelpad=-10)
            
            if idx2==0:
                axd[plot_key].set_ylabel("Area under the ROC curve", fontsize=25)
            else:
                axd[plot_key].set_ylabel(None)    
            
                
                #axd[plot_key].tick_params(axis='x', which='major', left=False, labelleft=False)
                
            if variable.startswith("malignant"):
                #axd[plot_key].set_ylim(0.61, 0.98)
                #axd[plot_key].set_ylim(0.531, 0.881)
                axd[plot_key].set_ylim(0.49, 1.01)
                axd[plot_key].set_title("malignancy"+task_suffix, fontsize=25, pad=20)
            elif variable.startswith("melanoma"):                
                axd[plot_key].set_title("Melanoma"+task_suffix, fontsize=25, pad=20)
                axd[plot_key].set_ylim(0.49, 1.01)
                
            if variable.endswith("num_sample"):
                if variable.startswith("malignant"):
                    axd[plot_key].set_xlim(left=-50)
                elif variable.startswith("melanoma"):    
                    axd[plot_key].set_xlim(left=-10)
                
            if variable.endswith("num_concept"):
                axd[plot_key].set_xlim(left=-1)
                axd[plot_key].set_xlim(right=50) 
                
            for axis in ['top','bottom','left','right']:
                axd[plot_key].spines[axis].set_linewidth(3)                     
            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 
            
            axd[plot_key].text(x=-0.15, y=1.05, transform=axd[plot_key].transAxes,
                                     s=["B", "C", "D", "E"][idx2], fontsize=35, weight='bold')  
            
            
            
            
            if idx2==1:
                legend_elements=[Line2D([], [], color=plt.rcParams["axes.prop_cycle"].by_key()["color"][0], label="MONET+CBM (SkinCon)", linestyle='None', marker='X', markersize=20),
                                 Line2D([0], [0], color=plt.rcParams["axes.prop_cycle"].by_key()["color"][1], linewidth=10, label="Manual Label (SkinCon)")]
                axd[plot_key].legend(handles=legend_elements, 
                            ncol=2, 
                            handlelength=3,
                            handletextpad=0.6, 
                            columnspacing=1.5,
                            fontsize=23,
                            loc='lower center', 
                            bbox_to_anchor=(1, -0.33)).set_zorder(100)              
            
            #record_all_df_filtered=record_all_df_filtered[record_all_df_filtered["method"]=="automatic_monet_full"]    

# fig.savefig(log_dir/"plots"/f"main_cbm_a_{alpha:.1e}_t_{temp:.1e}.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/f"main_cbm_a_{alpha:.1e}_t_{temp:.1e}.jpg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/f"main_cbm_a_{alpha:.1e}_t_{temp:.1e}.svg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/f"main_cbm_a_{alpha:.1e}_t_{temp:.1e}.pdf", bbox_inches='tight')

# fig.savefig(log_dir/"plots"/"main_cbm.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"main_cbm.jpg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"main_cbm.svg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"main_cbm.pdf", bbox_inches='tight')
# # plt.close(fig)

In [None]:
# Two paired samples of ten integer scores each. The following cell divides the
# mean of (x1 - x2) by the sample standard deviation of (x1 - x2).
x1, x2 = (
    np.array(scores)
    for scores in (
        [65, 75, 86, 69, 60, 81, 88, 53, 75, 73],
        [77, 98, 92, 77, 65, 77, 100, 73, 93, 75],
    )
)

In [None]:
# Mean of the paired differences divided by their sample standard deviation (ddof=1).
_paired_diff = x1 - x2
_paired_diff.mean() / _paired_diff.std(ddof=1)

In [None]:
# Compare every method against the reference method, separately for each task.
# Each cell of the resulting (method x task) table holds a tuple:
#   (p-value of the one-sided paired t-test with alternative "method AUC < reference AUC",
#    mean paired AUC difference divided by its sample std (ddof=1),
#    mean paired AUC difference).
# scipy.stats is imported here because the notebook's `import scipy` cell comes
# after this one, so the name would be unbound when running top to bottom.
import scipy.stats

_reference_method = "automatic_curated_monet_full"


def _paired_stats_vs_reference(method_rows, task_rows):
    """Return (p-value, standardized mean difference, mean difference) of AUC vs. the reference method.

    method_rows: rows of one method within one task.
    task_rows: all rows of that task (the reference method's rows are selected from it).
    """
    method_auc = method_rows.set_index("random_seed")["auc"]
    reference_auc = task_rows[task_rows["method"] == _reference_method].set_index("random_seed")["auc"]
    # Pair observations by random seed so the t-test uses exactly the same pairs
    # as the index-aligned difference below (ttest_rel itself pairs by position).
    method_auc, reference_auc = method_auc.align(reference_auc, join="inner")
    auc_diff = method_auc - reference_auc
    return (
        scipy.stats.ttest_rel(method_auc, reference_auc, alternative="less").pvalue,
        np.mean(auc_diff) / np.std(auc_diff, ddof=1),
        np.mean(auc_diff),
    )


record_all_df_perf_filtered_pvalue = record_all_df_perf_filtered.groupby("task").apply(
    lambda task_rows: task_rows.groupby("method").apply(
        lambda method_rows: _paired_stats_vs_reference(method_rows, task_rows)
    )
).T


In [None]:
# Statistics tuple for method "automatic_curated_vanilla_full" on the "melanoma" task.
record_all_df_perf_filtered_pvalue.loc["automatic_curated_vanilla_full", "melanoma"]

In [None]:
# Write the current figure to the plots directory once per export format.
for _filename in ("main_cbm.png", "main_cbm.jpg", "main_cbm.svg", "main_cbm.pdf"):
    fig.savefig(log_dir/"plots"/_filename, bbox_inches='tight')

In [None]:
# fig.savefig(log_dir/"plots"/"main_cbm.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"main_cbm.jpg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"main_cbm.svg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"main_cbm.pdf", bbox_inches='tight')

In [None]:
# Distinct concept names present in the collected coefficient records.
pd.DataFrame(coef_dict_list).loc[:, "concept_name"].unique()

In [None]:
# Re-export the current figure as PNG, JPG, SVG and PDF into the plots directory.
for _filename in ("main_cbm.png", "main_cbm.jpg", "main_cbm.svg", "main_cbm.pdf"):
    fig.savefig(log_dir/"plots"/_filename, bbox_inches='tight')

In [None]:
import scipy

In [None]:
# Display the filtered performance records for inspection.
record_all_df_perf_filtered

In [None]:
# Display the filtered performance records for inspection (repeat of the cell above).
record_all_df_perf_filtered

In [None]:
# Count, per method, the number of random seeds in which its AUC is higher than
# that of the reference method "automatic_curated_monet_full" on the selected task.

# Rebuild the performance table: seeds below 20, clean data only, main methods only.
record_all_df_perf=pd.DataFrame(record_all_list)
record_all_df_perf=record_all_df_perf[record_all_df_perf["random_seed"]<20]
record_all_df_perf_filtered=record_all_df_perf[record_all_df_perf["is_clean"]=="clean_only"]
record_all_df_perf_filtered=record_all_df_perf_filtered[record_all_df_perf_filtered["method"].isin(main_method_list)]

# For the two automatic CBM variants keep only rows matching the selected
# alpha and temp; rows of every other method are kept unconditionally.
_is_automatic_cbm=record_all_df_perf_filtered["method"].isin(["automatic_curated_vanilla_full", "automatic_curated_monet_full"])
_matches_selection=(record_all_df_perf_filtered["alpha"]==alpha)&(record_all_df_perf_filtered["temp"]==temp)
record_all_df_perf_filtered=record_all_df_perf_filtered[(~_is_automatic_cbm)|(_is_automatic_cbm&_matches_selection)]

# Select one task. The main figure cell filters the same records on
# "malignant" / "melanoma"; the previous "malignancy" label matched no rows.
record_all_df_perf_filtered=record_all_df_perf_filtered[record_all_df_perf_filtered["task"]=="malignant"]
# record_all_df_perf_filtered=record_all_df_perf_filtered[record_all_df_perf_filtered["task"]=="melanoma"]


def _wins_over_reference(method_rows, task_rows):
    """Number of seeds where this method's AUC exceeds the reference method's AUC (paired by seed)."""
    method_auc = method_rows.set_index('random_seed')["auc"]
    reference_auc = task_rows[task_rows["method"]=="automatic_curated_monet_full"].set_index("random_seed")["auc"]
    return ((method_auc - reference_auc)>0).sum()


record_all_df_perf_filtered.groupby("task").apply(
    lambda task_rows: task_rows.groupby("method").apply(
        lambda method_rows: _wins_over_reference(method_rows, task_rows)
    )
).T

In [None]:
def _one_minus_p_greater(method_rows, task_rows):
    """Return 1 - p for the paired t-test with alternative "method AUC > reference AUC"."""
    method_auc = method_rows.set_index('random_seed')["auc"]
    reference_auc = task_rows[task_rows["method"]=="automatic_curated_monet_full"].set_index("random_seed")["auc"]
    test_result = scipy.stats.ttest_rel(method_auc, reference_auc, alternative="greater")
    return 1-test_result.pvalue


# (method x task) table of 1 - p against the reference method "automatic_curated_monet_full".
record_all_df_perf_filtered.groupby("task").apply(
    lambda task_rows: task_rows.groupby("method").apply(
        lambda method_rows: _one_minus_p_greater(method_rows, task_rows)
    )
).T

In [None]:
def _p_less(method_rows, task_rows):
    """Return the p-value of the paired t-test with alternative "method AUC < reference AUC"."""
    method_auc = method_rows.set_index('random_seed')["auc"]
    reference_auc = task_rows[task_rows["method"]=="automatic_curated_monet_full"].set_index("random_seed")["auc"]
    test_result = scipy.stats.ttest_rel(method_auc, reference_auc, alternative="less")
    return test_result.pvalue


# (method x task) table of one-sided p-values against the reference method "automatic_curated_monet_full".
record_all_df_perf_filtered.groupby("task").apply(
    lambda task_rows: task_rows.groupby("method").apply(
        lambda method_rows: _p_less(method_rows, task_rows)
    )
).T

In [None]:
from scipy.stats import wilcoxon

In [None]:
# List the distinct method names left in the SkinCon comparison records.
record_all_df_skincon_filtered["method"].unique()

In [None]:
# Scatter of AUC against the "num_sample" column.
record_all_df_skincon_filtered_obs.plot(kind="scatter", x="num_sample", y="auc")

In [None]:
# Scatter of AUC against the "sample_prop" column.
record_all_df_skincon_filtered_obs.plot(kind="scatter", x="sample_prop", y="auc")

In [None]:
# Scatter of AUC against the "num_sample_train" column.
record_all_df_skincon_filtered_obs.plot(kind="scatter", x="num_sample_train", y="auc")

In [None]:
# Display the reference-method rows used as the comparison point in the SkinCon panels.
record_all_df_skincon_filtered_ref

In [None]:
#                 record_all_df_skincon_filtered_ref_matched=pd.DataFrame(list(itertools.product(record_all_df_skincon_filtered_obs_sample_prop.values, 
#                                                     record_all_df_skincon_filtered_ref["auc"].values)),
#                             columns=["num_sample_train_pseudo", "auc"])                
#                 b=sns.lineplot(x="num_sample_train_pseudo", y="auc", 
#                                color=plt.rcParams["axes.prop_cycle"].by_key()["color"][0], linewidth=4, linestyle="--",
#                                data=record_all_df_skincon_filtered_ref_matched, ax=axd[plot_key])
#                 if b.legend_ is not None:
#                     b.legend_.remove()

In [None]:
# Display the filtered performance records for inspection.
record_all_df_perf_filtered

In [None]:
# Number of rows recorded for each method.
record_all_df_perf_filtered.groupby(["method"]).apply(len)

In [None]:
# Display the method names currently bound to main_method_list.
main_method_list

In [None]:
import itertools

In [None]:
def _auc_summary(auc_values):
    """Mean, standard deviation and quartiles of one method's AUC values."""
    return {
        "mean": auc_values.mean(),
        "std": auc_values.std(),
        "q3": auc_values.quantile(q=0.75),
        "median": auc_values.median(),
        "q1": auc_values.quantile(q=0.25),
    }


# Per-method AUC summary statistics.
record_all_df_perf_filtered.groupby("method")["auc"].apply(_auc_summary)

In [None]:
# Save the mean coefficient of each concept to a CSV named after the current
# alpha, temp and task. A grouped Series has no `to_csv`, so the groups are
# aggregated first; the mean matches the per-concept summary printed (in a
# commented-out line) by the main figure cell.
pd.DataFrame(coef_dict_list).groupby("concept_name")["coef"].mean().to_csv(log_dir/"plots"/f"main_cbm_a_{alpha:.1e}_t_{temp:.1e}_{task}.csv")




In [None]:
# Mean AUC per method among the rows used for the coefficient plot.
record_all_df_weight_filtered.groupby("method")["auc"].agg("mean")

In [None]:
# Ablation figure with two side-by-side panels: AUC of the
# "automatic_curated_monet_less_*" runs versus (A) the number of concepts and
# (B) the proportion of training data, with one line per task.

# Default colour cycle: entries 1, 3, 5, 7, 9, 11 of the 12-colour Paired palette.
plt.rcParams["axes.prop_cycle"]=cycler('color', [np.array(Paired[12][color_idx])/256
                                                 for color_idx in (1, 3, 5, 7, 9, 11)])

# NOTE(review): this rebinds the notebook-level main_method_list although the
# list is not used anywhere in this cell; cells run afterwards will see these names.
main_method_list=["automatic_monet_full",
                  "skincon_manual",
                  "resnet",
                  "resnet_freeze_backbone",
                  "automatic_vanilla_full"]

fig = plt.figure(figsize=(3*10, 3*(2.5) + 0.25*3))

box1 = gridspec.GridSpec(1, 1,
                         height_ratios=[2.5],
                         wspace=0.0,
                         hspace=0.25)

ablation_variables = ["num_concept", "num_samples"]

# Per-variable plot settings: which method's rows to plot, the x column, the
# x-axis label, optional x limits, optional major tick spacing, and the x tick label size.
# "num_reference" is not in ablation_variables; its settings are kept so it can be re-enabled.
ablation_settings = {
    "num_concept": {"method": "automatic_curated_monet_less_concept",
                    "x": "num_concept",
                    "xlabel": "Num. of concepts",
                    "xlim": (1-0.2, 11.5),
                    "x_major": 2,
                    "x_labelsize": 25},
    "num_reference": {"method": "automatic_curated_monet_less_reference",
                      "x": "num_ref_concepts",
                      "xlabel": "Num. of reference concepts",
                      "xlim": (-0.2, 5.1),
                      "x_major": 1,
                      "x_labelsize": 25},
    "num_samples": {"method": "automatic_curated_monet_less_sample",
                    "x": "sample_prop",
                    "xlabel": "Proportion of training data",
                    "xlim": None,
                    "x_major": None,
                    "x_labelsize": 30},
}

# First pass: create one axes per panel and register it under "<stage>_<variable>".
axd={}
for idx1, stage in enumerate(["ablation"]):
    box2 = gridspec.GridSpecFromSubplotSpec(1, 2,
                    subplot_spec=box1[idx1], width_ratios=[1, 1], wspace=0.2, hspace=0.)
    for idx2, variable in enumerate(ablation_variables):
        plot_key=f"{stage}_{variable}"
        ax=plt.Subplot(fig, box2[idx2])
        fig.add_subplot(ax)
        axd[plot_key]=ax

# Second pass: draw and style each panel.
for idx1, stage in enumerate(["ablation"]):
    for idx2, variable in enumerate(ablation_variables):
        plot_key=f"{stage}_{variable}"
        if variable not in ablation_settings:
            raise ValueError
        settings=ablation_settings[variable]
        panel=axd[plot_key]

        # Rows of the relevant ablation method, clean data only.
        record_all_df=pd.DataFrame(record_all_list)
        record_all_df_filtered=record_all_df[record_all_df["is_clean"]=="clean_only"]
        record_all_df_filtered=record_all_df_filtered[record_all_df_filtered["method"]==settings["method"]]

        b=sns.lineplot(x=settings["x"], y="auc", hue="task", data=record_all_df_filtered, ax=panel)
        # Guard the removal the same way the main figure cell does, so a panel
        # for which seaborn created no legend does not raise AttributeError.
        if b.legend_ is not None:
            b.legend_.remove()

        panel.set_xlabel(settings["xlabel"], fontsize=30)
        if settings["xlim"] is not None:
            panel.set_xlim(*settings["xlim"])

        panel.set_ylim(0.61, 0.94)

        # y axis: major ticks every 0.05, minor every 0.01, with matching grids.
        panel.yaxis.set_major_locator(MultipleLocator(0.05))
        panel.yaxis.set_minor_locator(MultipleLocator(0.01))
        panel.yaxis.grid(True, which='major', linewidth=2, alpha=0.6)
        panel.yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
        panel.tick_params(axis='y', which='major', labelsize=25)

        # x axis: optional fixed major tick spacing, major grid, tick label size.
        if settings["x_major"] is not None:
            panel.xaxis.set_major_locator(MultipleLocator(settings["x_major"]))
        panel.xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
        panel.tick_params(axis='x', which='major', labelsize=settings["x_labelsize"])

        # Both panels carry the same y label.
        panel.set_ylabel("Area under the ROC curve", fontsize=30)

        # Thicker frame, with the top and right spines hidden.
        for axis in ['top','bottom','left','right']:
            panel.spines[axis].set_linewidth(3)
        panel.spines['right'].set_visible(False)
        panel.spines['top'].set_visible(False)

        # Panel letter placed in axes coordinates.
        panel.text(x=-0.15, y=1.0, transform=panel.transAxes,
                   s=["A", "B"][idx2], fontsize=35, weight='bold')

        # One shared legend, attached to the second panel.
        if idx2==1:
            legend_elements=[Line2D([0], [0], color=np.array(Paired[12][1])/256, linewidth=10, label="Malignancy"),
                             Line2D([0], [0], color=np.array(Paired[12][3])/256, linewidth=10, label="Melanoma")]
            panel.legend(handles=legend_elements,
                         ncol=2,
                         handlelength=3,
                         handletextpad=0.6,
                         columnspacing=1.5,
                         fontsize=30,
                         loc='lower center', bbox_to_anchor=(-0.15, -0.3))

# fig.savefig(log_dir/"plots"/"main_cbm_ablation.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"main_cbm_ablation.jpg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"main_cbm_ablation.svg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"main_cbm_ablation.pdf", bbox_inches='tight')
#plt.close(fig)

In [None]:
# List the distinct method names left in record_all_df_filtered after the last filtering step.
record_all_df_filtered["method"].unique()

In [None]:
import matplotlib.pyplot as plt

# Closed outlines (first vertex repeated at the end) of two triangles.
# Only the first triangle is drawn below; the second is defined but unused.
_triangle_vertices = [(1, 2), (10, 8), (7, 4), (1, 2)]
_triangle2_vertices = [(13, 5), (25, 7), (21, 14), (13, 5)]
trianglex, triangley = (list(coords) for coords in zip(*_triangle_vertices))
triangle2x, triangle2y = (list(coords) for coords in zip(*_triangle2_vertices))

plt.figure('Triangles')
# The same outline is plotted three times, then the triangle is filled.
i = 0
while i < 3:
    plt.plot(trianglex, triangley, 'o-')
    i += 1
i = 2  # leave the loop variable with the value a `for i in range(3)` loop would
plt.fill(trianglex, triangley)


plt.show()

In [None]:
# Display the nested grid spec created for the ablation panels.
box2

In [None]:
        # Write each patch's height (two decimals) as text at x = patch centre,
        # y = patch y + height.
        for patch in ax.patches:
            bar_height = patch.get_height()
            ax.text(
                patch.get_x() + patch.get_width() / 2,
                patch.get_y() + bar_height,
                f"{bar_height:.2f}",
                ha="center",
            )

In [None]:
# Print the height of every rectangle drawn by the coefficient bar plot.
# `weight_bar` is the Axes returned by sns.barplot and cannot be iterated
# directly; the bar rectangles are in its `.patches` list (as used by the
# main figure cell).
# NOTE(review): that bar plot is horizontal (y="concept_name", x="coef"), so the
# bar length is presumably get_width() rather than get_height(); confirm which is wanted.
for rect in weight_bar.patches:
    height = rect.get_height()
    print(height)
    #plt.text(rect.get_x() + rect.get_width() / 2.0, height, f'{height:.0f}', ha='center', va='bottom')


In [None]:
# Display the object returned by the coefficient sns.barplot call.
weight_bar

In [None]:
# Display the object returned by the coefficient sns.barplot call (repeat of the cell above).
weight_bar

In [None]:
# First colour of the currently configured matplotlib colour cycle.
plt.rcParams["axes.prop_cycle"].by_key()["color"][0]

In [None]:
# Scratch version of a "num_sample" panel: plot AUC against the mean number of
# training samples observed at each sample proportion.
# NOTE(review): the method names used here ("automatic_skincon_monet_full",
# "skincon_manual_less_sample") differ from the ones the main figure cell
# filters on; confirm they still occur in the records.
record_all_df_skincon_filtered_ref=record_all_df_skincon_filtered[record_all_df_skincon_filtered["method"]=="automatic_skincon_monet_full"]
if variable.endswith("num_sample"):
    # Copy the selection so the new column is added to an independent frame
    # rather than to a slice of record_all_df_skincon_filtered.
    record_all_df_skincon_filtered_obs=record_all_df_skincon_filtered[record_all_df_skincon_filtered["method"]=="skincon_manual_less_sample"].copy()
    # Mean training-set size for each sample proportion.
    record_all_df_skincon_filtered_obs_sample_prop=record_all_df_skincon_filtered_obs.groupby("sample_prop")["num_sample_train"].mean()
    # Vectorized lookup of that mean for every row (replaces a row-wise apply).
    record_all_df_skincon_filtered_obs["num_sample_train_pseudo"]=record_all_df_skincon_filtered_obs["sample_prop"].map(record_all_df_skincon_filtered_obs_sample_prop)
    b=sns.lineplot(x="num_sample_train_pseudo", y="auc", style="method", data=record_all_df_skincon_filtered_obs, ax=axd[plot_key])


In [None]:
import itertools

In [None]:
# Display the grouped mean of num_sample_train computed for the num_sample panel.
record_all_df_skincon_filtered_obs_sample_prop

In [None]:
# AUC values of the reference-method rows.
record_all_df_skincon_filtered_ref["auc"]

In [None]:
# Group keys (index) of the grouped num_sample_train means.
record_all_df_skincon_filtered_obs_sample_prop.index