In [None]:
import os

import hydra
import omegaconf
import pyrootutils

# Resolve the project root starting from this notebook's location
# (pythonpath=True presumably also puts the root on the import path).
root = pyrootutils.setup_root(os.path.abspath("model_auditing.ipynb"), pythonpath=True)
import os  # NOTE(review): duplicate of the import at the top of this cell; harmless

# Run the rest of the notebook from the project root so relative paths such as
# "configs/..." and "logs/..." used below resolve against it.
os.chdir(root)

In [None]:
import sys

# Make packages under <root>/src importable (presumably where the MONET package lives).
sys.path.append(str(root / "src"))

In [None]:
import os
import time
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib import gridspec
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn.metrics
import torch
import tqdm
from IPython import display
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

import clip

In [None]:
# !cp /System/Library/Fonts/Supplemental ~/.local/share/fonts/
# rm -fr ~/.cache/matplotlib
from matplotlib import font_manager
from matplotlib.lines import Line2D
from cycler import cycler
from matplotlib.ticker import MultipleLocator, AutoMinorLocator

# Look up system TrueType fonts and Arial (results unused; these calls only probe
# font availability), then make Arial the default sans-serif font.
font_manager.findSystemFonts(fontpaths=None, fontext="ttf")
font_manager.findfont("Arial")  # Test with "Special Elite" too
plt.rcParams.update(
    {
        "font.family": "sans-serif",
        "font.sans-serif": "Arial",
        # Plain, borderless, fully transparent legend frames.
        "legend.fancybox": False,
        "legend.edgecolor": "1.0",
        "legend.framealpha": 0,
    }
)

# https://github.com/dsc/colorbrewer-python/blob/master/colorbrewer.py

# ColorBrewer qualitative palettes as 0-255 RGB triplets, keyed by number of colours.
# Every k-colour palette is the first k colours of the largest one, so both dicts are
# generated from a single list. Each entry is a list (no tuples), and palettes do not
# share inner lists, so editing one palette cannot change another.
_SET1_FULL = [
    [228,26,28], [55,126,184], [77,175,74], [152,78,163], [255,127,0],
    [255,255,51], [166,86,40], [247,129,191], [153,153,153],
]
_PAIRED_FULL = [
    [166,206,227], [31,120,180], [178,223,138], [51,160,44], [251,154,153], [227,26,28],
    [253,191,111], [255,127,0], [202,178,214], [106,61,154], [255,255,153], [177,89,40],
]

Set1 = {k: [list(rgb) for rgb in _SET1_FULL[:k]] for k in range(3, len(_SET1_FULL) + 1)}
Paired = {k: [list(rgb) for rgb in _PAIRED_FULL[:k]] for k in range(3, len(_PAIRED_FULL) + 1)}

# Custom 7-colour qualitative palette. Entries are hex strings or RGB tuples whose
# channels are in [0, 1]; 8-bit channels are therefore divided by 255 (the maximum
# channel value), not 256.
color_qual_7 = [
    '#F53345',
    '#87D303',
    '#04CBCC',
    '#8650CD',
    (160/255, 95/255, 0),
    '#F5A637',
    '#DBD783',
]

# Show up to 500 rows when displaying DataFrames.
pd.set_option('display.max_rows', 500)

In [None]:
import scipy.special
import tqdm.contrib.concurrent

In [None]:
from MONET.datamodules.multiplex_datamodule import MultiplexDatamodule
from MONET.utils.loader import custom_collate_per_key, dataloader_apply_func
from MONET.utils.metrics import skincon_calcualte_auc_all
from MONET.utils.static import (
    concept_to_prompt,
    fitzpatrick17k_disease_label,
    fitzpatrick17k_ninelabel,
    fitzpatrick17k_threelabel,
    skincon_cols,
)
from MONET.utils.text_processing import generate_prompt_token_from_concept

In [None]:
def wandb_to_exppath(wandb, log_path="/gscratch/cse/chanwkim/MONET_log/train/runs"):
    """Return the experiment directory under ``log_path`` that contains wandb run ``wandb``.

    Every sub-directory of ``log_path`` that has a ``wandb/`` folder is inspected. Each
    entry of that folder whose name starts with "run" contributes a run id, taken as its
    last 8 characters. All such entries are checked, not just the first one listed.
    Experiments with no matching entry are skipped rather than raising IndexError.
    Experiments are visited in sorted order so the result is deterministic.

    Args:
        wandb: run id to look for (compared against the last 8 characters of entry names).
        log_path: directory whose sub-directories are experiment runs.

    Returns:
        pathlib.Path of the first matching experiment directory.

    Raises:
        RuntimeError: if no experiment contains the run.
    """
    log_path = Path(log_path)
    for experiment in sorted(os.listdir(log_path)):
        wandb_dir = log_path / experiment / "wandb"
        if not os.path.exists(wandb_dir):
            continue
        run_ids = {name[-8:] for name in os.listdir(wandb_dir) if name.startswith("run")}
        if wandb in run_ids:
            return log_path / experiment
    raise RuntimeError("not found")


# Find the experiment directory of wandb run "baqqmm5v" and list its checkpoint files.
# NOTE(review): absolute, machine-specific log path; consider moving it to a config cell.
exppath = wandb_to_exppath(
    wandb="baqqmm5v", log_path="/projects/leelab2/chanwkim/dermatology_datasets/logs/train/runs"
)
print([exppath / "checkpoints" / ckpt for ckpt in os.listdir(exppath / "checkpoints/")])

In [None]:
# Per-dataset state: variable_dict[dataset_name] accumulates the dataloader, features,
# labels/config and prompt embeddings produced by the cells below.
variable_dict={}

In [None]:
def setup_dataloader(dataset_name):
    if dataset_name=="clinical_fd_clean_nodup":
        cfg_dm = omegaconf.OmegaConf.load(root / "configs" / "datamodule" / "multiplex.yaml")
        # cfg.data_dir="/scr/chanwkim/dermatology_datasets"
        cfg_dm.data_dir = "/sdata/chanwkim/dermatology_datasets"
        cfg_dm.dataset_name_test = "clinical_fd_clean_nodup=all"
        cfg_dm.split_seed = 42

        dm = hydra.utils.instantiate(cfg_dm)
        dm.setup()     
        
        dataloader = dm.test_dataloader()      
         
        
    elif dataset_name=="isic":
        cfg_dm = omegaconf.OmegaConf.load(root / "configs" / "datamodule" / "multiplex.yaml")
        # cfg.data_dir="/scr/chanwkim/dermatology_datasets"
        cfg_dm.data_dir = "/sdata/chanwkim/dermatology_datasets"
        cfg_dm.dataset_name_test = "isic=all"
        cfg_dm.split_seed = 42

        dm = hydra.utils.instantiate(cfg_dm)
        dm.setup()     
        
        dataloader = dm.test_dataloader()    
        
    return {"dataloader": dataloader}

In [None]:
# Build the test dataloader for each dataset. The comma between the names matters:
# adjacent string literals are concatenated, which would produce the single unknown
# name "clinical_fd_clean_nodupisic".
for dataset_name in ["clinical_fd_clean_nodup", "isic"]:
    variable_dict.setdefault(dataset_name, {})
    variable_dict[dataset_name].update(setup_dataloader(dataset_name))

In [None]:
!gpustat

In [None]:
# Run id of the fine-tuned checkpoint to audit (a key of model_path_dir below) and its GPU.
model_name = "zt0n2xd0"
model_device = "cuda:6"

In [None]:
# Contrastive model config (configs/model/contrastive.yaml) with a ViT-L/14 backbone
# placed on model_device; the trailing expression displays the config.
cfg_model = omegaconf.OmegaConf.load(root / "configs" / "model" / "contrastive.yaml")
cfg_model.net.model_name_or_path = "ViT-L/14"
cfg_model.net.device = model_device
cfg_model

In [None]:
# Instantiate the model from the config, move it to model_device and switch to eval mode.
model = hydra.utils.instantiate(cfg_model)
model.to(model_device)
model.eval()

In [None]:
# Fine-tuned checkpoints keyed by run id. Paths are relative to the working directory,
# which is the project root after os.chdir(root).
model_path_dir = {
    "zt0n2xd0": "logs/train/runs/2023-01-17_20-58-15/checkpoints/last.ckpt",
}
# With model_name == "ViT-L/14" nothing is loaded and the weights built from the config
# are kept. Otherwise the checkpoint's "state_dict" entry is loaded into `model`.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
if model_name != "ViT-L/14":
    model_path = model_path_dir[model_name]
    loaded = torch.load(model_path, map_location=model_device)
    model.load_state_dict(loaded["state_dict"])
    print('loaded')

In [None]:
# Second instance of the same ViT-L/14 contrastive model. This cell loads no checkpoint
# into it, so it keeps the weights built from the config ("vanilla" baseline).
# NOTE(review): this re-assigns the globals model_name and model_device. model_device
# becomes "cuda:5" while `model` was moved to "cuda:6" above. Any later code that sends
# inputs for `model` to model_device therefore targets a different GPU than its weights,
# unless the model moves inputs itself. Consider a separate model_vanilla_device variable.
model_name = "zt0n2xd0"
model_device = "cuda:5"

cfg_model = omegaconf.OmegaConf.load(root / "configs" / "model" / "contrastive.yaml")
cfg_model.net.model_name_or_path = "ViT-L/14"
cfg_model.net.device = model_device
cfg_model

model_vanilla = hydra.utils.instantiate(cfg_model)
model_vanilla.to(model_device)
model_vanilla.eval()

In [None]:
# Root of cached artifacts, e.g. log_dir / "image_features" / "isic.pt" read below.
log_dir = Path("logs")

In [None]:
def batch_func(batch):
    """Embed one batch of images with both the fine-tuned and the vanilla model.

    Each model receives the images on the device of its own parameters. ``model`` and
    ``model_vanilla`` are moved to different GPUs in the cells above, so a single shared
    device would be wrong for one of them. The caller's ``batch`` is not modified; a
    shallow copy with the moved image tensor is passed to each model.

    Args:
        batch: mapping with at least "image" (a tensor) and "metadata"
            (assumed to be a dict produced by the dataloader's collate function).

    Returns:
        dict with CPU tensors "image_features" and "image_features_vanilla" and the
        batch's "metadata" unchanged.
    """

    def _image_features(net):
        # Send the images to wherever this network's weights live.
        device = next(net.parameters()).device
        net_batch = dict(batch)
        net_batch["image"] = batch["image"].to(device)
        return net.model_step_with_image(net_batch)["image_features"].detach().cpu()

    with torch.no_grad():
        image_features = _image_features(model)
        image_features_vanilla = _image_features(model_vanilla)
    return {
        "image_features": image_features,
        "image_features_vanilla": image_features_vanilla,
        "metadata": batch["metadata"],
    }

def setup_features(dataset_name, dataloader):
    """Return image embeddings and metadata for ``dataset_name``.

    For "isic", the fine-tuned embeddings are read from the cache file
    log_dir / "image_features" / "isic.pt", and no vanilla embeddings are produced.
    They are loaded onto the CPU so they match the other branch and can be combined
    with the CPU prompt embeddings later. Every other dataset is embedded on the fly
    with ``batch_func``.

    Args:
        dataset_name: dataset key, e.g. "isic" or "clinical_fd_clean_nodup".
        dataloader: that dataset's dataloader (for "isic" only its metadata is used).

    Returns:
        dict with "image_features", "metadata_all" and, except for "isic",
        "image_features_vanilla".

    Raises:
        AssertionError: if the cached ISIC features do not have one row per metadata row.
    """
    if dataset_name == "isic":
        image_features = torch.load(log_dir / "image_features" / "isic.pt", map_location="cpu")
        metadata_all = dataloader.dataset.metadata_all
        # Features are later indexed with row-aligned masks, so a stale cache must not
        # slip through silently.
        assert len(image_features) == len(metadata_all), (
            f"cached isic features have {len(image_features)} rows, metadata has {len(metadata_all)}"
        )
        return {"image_features": image_features, "metadata_all": metadata_all}

    loader_applied = dataloader_apply_func(
        dataloader=dataloader,
        func=batch_func,
        collate_fn=custom_collate_per_key,
    )
    return {
        "image_features": loader_applied["image_features"].cpu(),
        "image_features_vanilla": loader_applied["image_features_vanilla"].cpu(),
        "metadata_all": loader_applied["metadata"],
    }

In [None]:
# Add image embeddings and metadata ("image_features", "metadata_all", ...) to each
# dataset's entry.
for dataset_name in ["clinical_fd_clean_nodup", "isic"]:
    print(dataset_name)
    variable_dict[dataset_name].update(setup_features(dataset_name, variable_dict[dataset_name]["dataloader"])) 

In [None]:
!gpustat

In [None]:
import torchvision
def get_layer_feature(model, feature_layer_name, image):
    """Run ``model`` on ``image`` and return the output of one of its direct child modules.

    The output is captured with a temporary forward hook on ``model._modules[feature_layer_name]``.
    The hooked layer must produce a (batch, channels, 1, 1) tensor, such as a global
    average pool. That tensor is returned flattened to (batch, channels).

    Args:
        model: torch.nn.Module with a direct child named ``feature_layer_name``.
        feature_layer_name: attribute name of that child, e.g. "avgpool".
        image: input batch, passed to ``model`` unchanged.

    Returns:
        Tensor of shape (batch, channels).

    Raises:
        KeyError: if ``model`` has no direct child named ``feature_layer_name``.
        AssertionError: if the captured output is not (batch, C, 1, 1).
    """
    feature_layer = model._modules.get(feature_layer_name)
    if feature_layer is None:
        raise KeyError(f"model has no sub-module named {feature_layer_name!r}")

    captured = []

    def copy_output(module, inputs, output):
        captured.append(output.data)

    handle = feature_layer.register_forward_hook(copy_output)
    try:
        model(image)
    finally:
        # Always detach the hook, even if the forward pass raises, so it cannot leak
        # into later calls on the same model.
        handle.remove()

    embedding = captured[0]
    assert embedding.shape[0] == image.shape[0], f"{embedding.shape[0]} != {image.shape[0]}"
    assert embedding.shape[2] == 1, f"{embedding.shape[2]} != 1"
    assert embedding.shape[3] == 1, f"{embedding.shape[3]} != 1"
    return embedding[:, :, 0, 0]

def batch_func_efficientnet(batch):
    """Extract EfficientNet "avgpool" features for one batch.

    Uses the globals ``efficientnet`` and ``efficientnet_device``. Returns the pooled
    features as a CPU tensor together with the batch metadata.
    """
    images = batch["image"].to(efficientnet_device)
    with torch.no_grad():
        pooled = get_layer_feature(efficientnet, "avgpool", images)
    return {"efficientnet_feature": pooled.detach().cpu(), "metadata": batch["metadata"]}

def setup_features_efficientnet(dataset_name, dataloader):
    """Compute EfficientNet features for every batch of ``dataloader``.

    ``dataset_name`` is not used.

    Returns:
        {"efficientnet_feature": CPU tensor of the concatenated per-batch features}
    """
    applied = dataloader_apply_func(
        dataloader=dataloader,
        func=batch_func_efficientnet,
        collate_fn=custom_collate_per_key,
    )
    return {"efficientnet_feature": applied["efficientnet_feature"].cpu()}

In [None]:
# ImageNet-1k pretrained EfficientNetV2-S on its own GPU, in eval mode; presumably used
# as a generic image-feature baseline.
efficientnet_device="cuda:7"
efficientnet = torchvision.models.efficientnet_v2_s(
    weights=torchvision.models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
).to(efficientnet_device)
efficientnet.eval()

# Add "efficientnet_feature" to each dataset's entry.
for dataset_name in ["clinical_fd_clean_nodup", "isic"]:
    print(dataset_name)
    variable_dict[dataset_name].update(setup_features_efficientnet(dataset_name, variable_dict[dataset_name]["dataloader"])) 

In [None]:
# Inspect which entries have been computed for the clinical dataset so far.
variable_dict["clinical_fd_clean_nodup"].keys()

In [None]:
# Benign/malignant status for each ISIC "diagnosis" value. It is used only when an
# image's own "benign_malignant" field is not a string. Values other than
# benign/malignant ("indeterminate", "unknown") mark labels that are not usable.
diagnosis_malignant_mapping=\
{'AIMP':'indeterminate',
'acrochordon':'benign',
'actinic keratosis':'benign', # 
'angiofibroma or fibrous papule':'benign', 
'angiokeratoma':'benign',
'angioma':'benign',
'atypical melanocytic proliferation':'indeterminate',
'atypical spitz tumor':'indeterminate', #
'basal cell carcinoma':'malignant', #
'cafe-au-lait macule':'benign',
'clear cell acanthoma':'benign',
'dermatofibroma':'benign', #
'lentigo NOS':'benign',
'lentigo simplex':'benign',
'lichenoid keratosis':'benign',
'melanoma':'malignant',
'melanoma metastasis':'malignant',
'neurofibroma':'benign',
'nevus':'benign',
'other':'indeterminate',
'pigmented benign keratosis':'benign', #??
'scar':'benign',
'seborrheic keratosis':'benign',
'solar lentigo':'benign',
'squamous cell carcinoma':'malignant',
'vascular lesion':'unknown', # 
'verruca':'benign'
}

def map_diagnosis_malignant(diagnosis, benign_malignant):
    """Return the benign/malignant label for one ISIC record.

    Priority:
      1. ``benign_malignant`` itself when it is a string;
      2. ``diagnosis_malignant_mapping[diagnosis]`` for a known diagnosis string;
      3. "indeterminate" when ``diagnosis`` is missing (None / NaN).

    Raises:
        RuntimeError: if ``diagnosis`` is present but has no mapping. This includes an
            unknown diagnosis string, which used to fall through to ``np.isnan`` and
            raise an unrelated TypeError.
    """
    if isinstance(benign_malignant, str):
        return benign_malignant
    if isinstance(diagnosis, str):
        if diagnosis in diagnosis_malignant_mapping:
            return diagnosis_malignant_mapping[diagnosis]
        raise RuntimeError(f"no benign/malignant mapping for diagnosis {diagnosis!r}")
    # pd.isna covers None, float NaN and numpy NaN scalars; is_scalar guards against
    # array-likes, for which isna would return an array.
    if pd.api.types.is_scalar(diagnosis) and pd.isna(diagnosis):
        return "indeterminate"
    raise RuntimeError(f"unexpected diagnosis value {diagnosis!r}")

In [None]:
def set_config(dataset_name, metadata_all):
    if "clinical_fd_clean" in dataset_name:
        y_pos=(((metadata_all["source"]=="fitz")&(metadata_all["three_partition_label"]=="malignant"))|
              ((metadata_all["source"]=="ddi")&(metadata_all["malignant"] == True))).values
        
        valid_idx=(metadata_all["skincon_Do not consider this image"]!=1).values
        
        concept_list=skincon_cols
        
        
    elif dataset_name=="isic":  
        metadata_all["benign_malignant_full"]=\
        metadata_all.apply(lambda x: map_diagnosis_malignant(x["diagnosis"], x["benign_malignant"]), axis=1)
        #metadata_all["benign_malignant_full"].value_counts()
        #metadata_all.groupby("diagnosis").apply(lambda x: x["benign_malignant_full"].value_counts())
        metadata_all["benign_malignant_bool"]=metadata_all["benign_malignant_full"].str.contains("malignant")
        
        y_pos=metadata_all["benign_malignant_bool"].values
        
        valid_idx = (metadata_all["benign_malignant_full"].str.contains("malignant")|metadata_all["benign_malignant_full"].str.contains("benign")).values
        
        concept_list=skincon_cols+\
                            ["purple pen", 
                             "nail", 
                             "pinkish", 
                             "red", 
                             "hair", 
                             "orange sticker", 
                             "dermoscope border",
                             "gel",
                             "malignant",
                             "melanoma"]      
        
        concept_list+=concept_list+\
                        [f"disease_{disease_name}" for disease_name in ['seborrheic keratosis', 'nevus', 'squamous cell carcinoma',
                        'melanoma', 'lichenoid keratosis', 'lentigo',
                        'actinic keratosis', 'basal cell carcinoma', 'dermatofibroma',
                        'atypical melanocytic proliferation', 'verruca',
                        'clear cell acanthoma', 'angiofibroma or fibrous papule', 'scar',
                        'angioma', 'atypical spitz tumor', 'solar lentigo', 'AIMP',
                        'neurofibroma', 'lentigo simplex', 'acrochordon', 
                        'angiokeratoma', 'vascular lesion', 'cafe-au-lait macule',
                        'pigmented benign keratosis']]
        
    return {"valid_idx": valid_idx,
            "y_pos": y_pos,
            "metadata_all": metadata_all,
            "concept_list": concept_list}

In [None]:
# Add labels ("y_pos"), the valid-row mask and the concept list to each dataset's entry.
for dataset_name in ["clinical_fd_clean_nodup", "isic"]:
    variable_dict[dataset_name].update(
        set_config(dataset_name, variable_dict[dataset_name]["metadata_all"])
    )

In [None]:
def normalize_embedding(dataset_name, image_features):
    """Scale every row of ``image_features`` to unit L2 norm.

    ``dataset_name`` is not used.

    Returns:
        {"image_features_norm": image_features divided row-wise by its L2 norm}
    """
    row_norms = image_features.norm(dim=1, keepdim=True)
    return {"image_features_norm": image_features / row_norms}

In [None]:
# Unit-normalise the image embeddings of every dataset. Vanilla-model embeddings exist
# only for the clinical dataset (setup_features loads none for "isic"), so only that
# dataset gets "image_features_vanilla_norm".
for dataset_name in ["clinical_fd_clean_nodup", "isic"]:
    if dataset_name=="clinical_fd_clean_nodup":
        variable_dict[dataset_name].update(
            {"image_features_vanilla_norm":normalize_embedding(dataset_name, 
                            variable_dict[dataset_name]["image_features_vanilla"])["image_features_norm"]}
        )           
    variable_dict[dataset_name].update(
        {"image_features_norm":normalize_embedding(dataset_name, 
                        variable_dict[dataset_name]["image_features"])["image_features_norm"]}
    )  
# (An unconditional normalisation of "image_features_vanilla" for every dataset was
# removed here: the "isic" entry has no vanilla features.)

In [None]:
def _concept_terms(concept_name):
    """Return the sorted, de-duplicated, lower-case terms describing a SkinCon concept.

    The first 8 characters of ``concept_name`` (the "skincon_" prefix) are stripped
    before calling ``concept_to_prompt``. Prompts from every group except "original"
    are used, with the leading phrases "This is ", "This photo is ", "This lesion is "
    and "skin has become " removed. Sorting keeps the order independent of Python's
    per-process string-hash randomisation, which a plain ``list(set(...))`` is not.
    """
    prompt_dict, _ = concept_to_prompt(concept_name[8:])
    prompt_engineered_list = []
    for key, prompts in prompt_dict.items():
        if key != "original":
            prompt_engineered_list += prompts
    return sorted(
        {
            prompt.replace("This is ", "")
            .replace("This photo is ", "")
            .replace("This lesion is ", "")
            .replace("skin has become ", "")
            .lower()
            for prompt in prompt_engineered_list
        }
    )


def _concept_prompts(dataset_name, concept_name):
    """Return ``(prompt_target, prompt_ref)`` for one concept.

    Both are lists of prompt groups (lists of strings). Group i of the target is scored
    against group i of the reference, so both lists have the same number of groups.
    All groups within one list have the same length so their embeddings can be stacked.

    Raises:
        ValueError: for a dataset other than "clinical_fd_clean_nodup" or "isic".
    """
    if dataset_name == "clinical_fd_clean_nodup":
        terms = _concept_terms(concept_name)
        templates = ["This is skin image of {}", "This is dermatology image of {}", "This is image of {}"]
        prompt_target = [[template.format(term) for term in terms] for template in templates]
        prompt_ref = [["This is skin image"], ["This is dermatology image"], ["This is image"]]
        return prompt_target, prompt_ref

    if dataset_name != "isic":
        raise ValueError(f"unsupported dataset_name {dataset_name!r}")

    if concept_name.startswith("skincon_"):
        terms = _concept_terms(concept_name)
        templates = ["This is dermatoscopy of {}", "This is dermoscopy of {}"]
        # One group pooling both templates over all terms, against one pooled reference.
        prompt_target = [[template.format(term) for template in templates for term in terms]]
        prompt_ref = [["This is dermatoscopy", "This is dermoscopy"]]
    elif concept_name == "disease_AIMP":
        # Spell out the abbreviation in addition to using it.
        prompt_target = [["This is dermatoscopy of AIMP",
                          "This is dermatoscopy of Atypical intraepidermal melanocytic proliferation"],
                         ["This is dermoscopy of AIMP",
                          "This is dermoscopy of Atypical intraepidermal melanocytic proliferation"]]
        prompt_ref = [["This is dermatoscopy"], ["This is dermoscopy"]]
    elif concept_name.startswith("disease_"):
        disease_name = concept_name[8:]
        prompt_target = [[f"This is dermatoscopy of {disease_name}"],
                         [f"This is dermoscopy of {disease_name}"]]
        prompt_ref = [["This is dermatoscopy"], ["This is dermoscopy"]]
    elif concept_name == "gel":
        # Variants with "water drop" / "dermoscopy liquid" were tried and not kept.
        prompt_target = [["This is dermatoscopy of gel"],
                         ["This is dermoscopy of gel"]]
        prompt_ref = [["This is dermatoscopy"], ["This is dermoscopy"]]
    elif concept_name == "dermoscope border":
        prompt_target = [["This is dermatoscopy of dermoscopy"]]
        prompt_ref = [["This is dermatoscopy"]]
    else:
        prompt_target = [[f"This is dermatoscopy of {concept_name}"],
                         [f"This is dermoscopy of {concept_name}"]]
        prompt_ref = [["This is dermatoscopy"], ["This is dermoscopy"]]
    return prompt_target, prompt_ref


def _embed_prompt_groups(model, prompt_groups, device):
    """Embed each prompt group with ``model`` and L2-normalise along the last axis.

    Returns a CPU tensor stacked over groups. Presumably its shape is
    (n_groups, n_prompts_per_group, dim), given that the per-group text features are
    (n_prompts, dim).
    """
    tokenized = [clip.tokenize(group, truncate=True) for group in prompt_groups]
    with torch.no_grad():
        embedding = torch.stack([
            model.model_step_with_text({"text": tokens.to(device)})["text_features"].detach().cpu()
            for tokens in tokenized
        ])
    return embedding / embedding.norm(dim=2, keepdim=True)


def get_concept_embedding(dataset_name, concept_list, model, verbose=True):
    """Compute normalised target and reference prompt embeddings for each concept.

    Args:
        dataset_name: "clinical_fd_clean_nodup" or "isic"; selects the prompt templates.
        concept_list: concept names ("skincon_*", "disease_*" or free-text concepts).
        model: text encoder exposing ``model_step_with_text``.
        verbose: print the prompts and embeddings of every concept (the default, as before).

    Returns:
        {"prompt_info": {concept_name: {"prompt_ref_embedding_norm": Tensor,
                                        "prompt_target_embedding_norm": Tensor}}}

    Raises:
        ValueError: for an unsupported ``dataset_name`` when ``concept_list`` is non-empty.
    """
    # Tokens go to the device of `model` itself. The global model_device can point at
    # another GPU: it is re-assigned to "cuda:5" for model_vanilla while `model` sits
    # on "cuda:6".
    device = next(model.parameters()).device

    prompt_info = {}
    for concept_name in concept_list:
        prompt_target, prompt_ref = _concept_prompts(dataset_name, concept_name)
        prompt_target_embedding_norm = _embed_prompt_groups(model, prompt_target, device)
        prompt_ref_embedding_norm = _embed_prompt_groups(model, prompt_ref, device)

        prompt_info[concept_name] = {"prompt_ref_embedding_norm": prompt_ref_embedding_norm,
                                     "prompt_target_embedding_norm": prompt_target_embedding_norm,
                                     }
        if verbose:
            print(dataset_name, concept_name, prompt_target, prompt_ref)
            print(prompt_target_embedding_norm)
            print(prompt_ref_embedding_norm)

    return {"prompt_info": prompt_info}

In [None]:
# Build prompt embeddings for every dataset's concepts with the fine-tuned model
# ("prompt_info"). The clinical dataset additionally gets embeddings from the vanilla
# model ("prompt_info_vanilla").
for dataset_name in ["clinical_fd_clean_nodup", "isic"]:
    if dataset_name=="clinical_fd_clean_nodup":
        variable_dict[dataset_name].update(
            {"prompt_info_vanilla":get_concept_embedding(dataset_name, 
                          concept_list=variable_dict[dataset_name]["concept_list"],
                                 model=model_vanilla)["prompt_info"]})            
    variable_dict[dataset_name].update(
        {"prompt_info":get_concept_embedding(dataset_name, 
                      concept_list=variable_dict[dataset_name]["concept_list"],
                             model=model)["prompt_info"]})  

In [None]:
def calculate_similaity_score(image_features_norm,
                              prompt_target_embedding_norm,
                              prompt_ref_embedding_norm,
                              temp=1,
                              normalize=True):
    """Score how strongly each image expresses a concept.

    Args:
        image_features_norm: (n_images, dim) unit-norm image embeddings.
        prompt_target_embedding_norm: (n_groups, n_prompts, dim) concept prompt embeddings.
        prompt_ref_embedding_norm: (n_groups, n_ref_prompts, dim) reference prompt embeddings;
            n_groups must match the target when ``normalize`` is True.
        temp: softmax temperature, only used when ``normalize`` is True.
        normalize: if True, contrast target against reference with a two-way softmax.
            If False, return the raw mean target similarity.

    Returns:
        Per-image scores of shape (n_images,). This is a NumPy array when ``normalize``
        is True and a torch tensor otherwise.
    """
    images_t = image_features_norm.T.float()
    # Similarities averaged over the prompts within each group -> (n_groups, n_images).
    target_mean = (prompt_target_embedding_norm.float() @ images_t).mean(dim=1)
    ref_mean = (prompt_ref_embedding_norm.float() @ images_t).mean(dim=1)

    if not normalize:
        return target_mean.mean(axis=0)

    # Two-way softmax (target vs. reference) per group and image; keep the target
    # probability and average it over the groups.
    logits = [target_mean.numpy() / temp, ref_mean.numpy() / temp]
    target_prob = scipy.special.softmax(logits, axis=0)[0, :]
    return target_prob.mean(axis=0)

In [None]:
from sklearn.model_selection import train_test_split
from MONET.datamodules.components.base_dataset import BaseDataset

def get_training_data(dataloader, metadata_all, valid_idx, y_pos, subset_idx_train, subset_idx_test, n_px=None):
    """Build train/val/test dataloaders for a binary classifier on ``y_pos``.

    Rows of ``dataloader.dataset.metadata_all`` selected by ``valid_idx & subset_idx_train``
    are split 80/20 (random_state=42) into train and validation sets. Rows selected by
    ``valid_idx & subset_idx_test`` form the test set. All three reuse the source
    dataset's image paths and normalisation, without augmentation.

    Args:
        dataloader: source dataloader. Its dataset supplies image paths, metadata,
            n_px and the normalisation mean/std (``transforms_aftertensor.transforms[1]``).
        metadata_all: not used; the metadata is taken from ``dataloader.dataset``.
        valid_idx, subset_idx_train, subset_idx_test: boolean arrays aligned with the metadata rows.
        y_pos: boolean array aligned with the metadata rows; stored as the integer "label" column.
        n_px: image size; defaults to ``dataloader.dataset.n_px``.

    Returns:
        (train_dataloader, val_dataloader, test_dataloader). The test loader now
        iterates the test split; previously it wrapped the validation split.
    """
    from MONET.utils.loader import custom_collate

    source = dataloader.dataset
    metadata_all_new = source.metadata_all.copy()
    metadata_all_new["label"] = y_pos.astype(int)
    metadata_all_new_train = metadata_all_new[valid_idx & subset_idx_train]

    train_idx, val_idx = train_test_split(metadata_all_new_train.index, test_size=0.2, random_state=42)

    print("train:", len(metadata_all_new_train))

    if n_px is None:
        n_px = source.n_px
    normalize = source.transforms_aftertensor.transforms[1]

    def _make_dataset(metadata):
        # Same images and normalisation as the source dataset, no augmentation; yields "label".
        return BaseDataset(
            image_path_or_binary_dict=source.image_path_dict,
            n_px=n_px,
            norm_mean=normalize.mean,
            norm_std=normalize.std,
            augment=False,
            metadata_all=metadata,
            integrity_level="weak",
            return_label=["label"],
        )

    def _make_loader(dataset, shuffle):
        # Shared loader settings for all three splits.
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=64,
            num_workers=4,
            shuffle=shuffle,
            collate_fn=custom_collate,
        )

    data_train = _make_dataset(metadata_all_new_train.loc[train_idx])
    data_val = _make_dataset(metadata_all_new_train.loc[val_idx])
    data_test = _make_dataset(metadata_all_new[valid_idx & subset_idx_test])

    train_dataloader = _make_loader(data_train, shuffle=True)
    val_dataloader = _make_loader(data_val, shuffle=False)
    test_dataloader = _make_loader(data_test, shuffle=False)

    return train_dataloader, val_dataloader, test_dataloader

In [None]:
import torchvision
from torch import nn
from torch.nn import functional as F
from torchmetrics import AUROC

class Classifier(nn.Module):
    """ImageNet-pretrained ResNet-50 backbone with a new linear head.

    The backbone's ``fc`` layer is replaced by an identity so the backbone returns its
    pooled features. Every backbone parameter stays trainable (full fine-tuning).
    """

    def __init__(self, output_dim):
        super().__init__()
        backbone = torchvision.models.resnet50(weights="ResNet50_Weights.IMAGENET1K_V1")
        # Fine-tune the whole backbone, not only the head.
        backbone.requires_grad_(True)
        feature_dim = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.backbone = backbone
        self.head = nn.Linear(feature_dim, output_dim)

    def forward(self, x):
        return self.head(self.backbone(x))
    
class Inception(nn.Module):
    """ImageNet-pretrained Inception-v3 trunk with a new linear head.

    ``input_to_representation`` re-implements torchvision's Inception3 forward pass up
    to the pooled 2048-d features. ``forward`` maps those features to ``output_dim`` logits.
    """

    def __init__(self, output_dim):
        super().__init__()
        self.inception = torchvision.models.inception_v3(weights="Inception_V3_Weights.IMAGENET1K_V1")
        self.fc = nn.Linear(2048, output_dim)

    def forward(self, x):
        x = self.input_to_representation(x)
        x = self.fc(x)
        # N x output_dim
        return x

    def input_to_representation(self, x):
        # Apply the input re-normalisation that Inception3.forward performs before the
        # first conv. It is a no-op unless transform_input is set, which torchvision
        # enables when pretrained weights are loaded. Skipping it fed the pretrained
        # stem differently than `self.inception(x)` would.
        x = self.inception._transform_input(x)
        # N x 3 x 299 x 299
        x = self.inception.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.inception.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.inception.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = self.inception.maxpool1(x)
        # N x 64 x 73 x 73
        x = self.inception.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.inception.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = self.inception.maxpool2(x)
        # N x 192 x 35 x 35
        x = self.inception.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.inception.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.inception.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.inception.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.inception.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.inception.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.inception.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.inception.Mixed_6e(x)
        # N x 768 x 17 x 17
        # The auxiliary classifier (AuxLogits) is not run: its output was never used,
        # so computing it in training mode only cost time and memory.
        x = self.inception.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.inception.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.inception.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = self.inception.avgpool(x)
        # N x 2048 x 1 x 1
        x = self.inception.dropout(x)
        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 2048
        return x


class EarlyStopper:
    """Signal early stopping once the validation loss stops improving.

    A new minimum resets the counter. A loss more than ``min_delta`` above the best loss
    so far counts as a bad epoch. Anything in between (or a NaN loss) leaves the counter
    unchanged. ``early_stop`` returns True once ``patience`` bad epochs have accumulated
    since the last improvement.
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False
        is_worse = validation_loss > self.min_validation_loss + self.min_delta
        if not is_worse:
            return False
        self.counter += 1
        return bool(self.counter >= self.patience)
    
def find_thres_best_f1(y_test, y_test_pred):
    """Return the decision threshold on ``y_test_pred`` that maximises F1.

    Args:
        y_test: binary ground-truth labels.
        y_test_pred: scores for the positive class.

    Returns:
        The precision-recall-curve threshold with the highest F1 (the first one on ties).
    """
    precision, recall, thresholds = sklearn.metrics.precision_recall_curve(y_test, y_test_pred)
    pr_sum = recall + precision
    # F1 = 2PR / (P + R), taken as 0 where P + R == 0 to avoid dividing by zero.
    f1_scores = np.divide(2 * recall * precision, pr_sum, out=np.zeros_like(pr_sum), where=(pr_sum != 0))
    best = np.argmax(f1_scores)
    return thresholds[best]

def train_classifier(train_dataloader, val_dataloader, test_dataloader, classifier_type="resnet", verbose=True,
                     device="cuda:5", num_epochs=20, lr=1e-3):
    """Train a single-logit binary classifier, choose an F1-optimal threshold, and evaluate on a test set.

    Training details:
        - Adam optimiser with BCE-with-logits loss against ``label == 1``.
        - ReduceLROnPlateau(patience=2) on the summed validation loss.
        - ``EarlyStopper(patience=5)``.
        - After every epoch, the validation logits pick the F1-maximising threshold
          (``find_thres_best_f1``).

    The returned classifier holds the weights of the last epoch run (no best checkpoint is
    restored). The returned threshold is computed from that same epoch.

    Args:
        train_dataloader, val_dataloader, test_dataloader: loaders yielding dicts with
            ``"image"`` and ``"label"`` tensors. The test loader must also yield ``"metadata"``.
        classifier_type: ``"resnet"`` (builds ``Classifier``) or ``"inception"`` (builds ``Inception``).
        verbose: show progress bars and print per-epoch metrics.
        device: torch device used for training and inference.
        num_epochs: maximum number of training epochs.
        lr: Adam learning rate.

    Returns:
        A 6-tuple:
            - test AUROC;
            - the trained classifier;
            - list of per-batch test logit arrays of shape (N, 1);
            - list of per-batch label arrays;
            - list of per-batch metadata;
            - the F1-optimal threshold on the logit scale, or None if no epoch produced
              validation outputs.

    Raises:
        ValueError: if ``classifier_type`` is not recognised.
    """
    if classifier_type == "resnet":
        classifier = Classifier(output_dim=1)
    elif classifier_type == "inception":
        classifier = Inception(output_dim=1)
    else:
        raise ValueError(
            f"Unknown classifier_type {classifier_type!r}; expected 'resnet' or 'inception'."
        )
    classifier.to(device)
    optimizer = torch.optim.Adam(classifier.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=2, verbose=True)
    early_stopper = EarlyStopper(patience=5, min_delta=0)

    def _iterate(dataloader):
        # Wrap in a progress bar only when verbose.
        return tqdm.tqdm(dataloader) if verbose else dataloader

    def _bce_loss(logits, label):
        # Single-logit BCE against the binary target ``label == 1``.
        return F.binary_cross_entropy_with_logits(input=logits[:, 0], target=(label == 1).float())

    train_auroc = AUROC(task="binary")
    val_auroc = AUROC(task="binary")
    max_f1_thresh = None
    for epoch in range(num_epochs):
        # ---- training pass ----
        train_loss = 0.0
        classifier.train()
        for batch in _iterate(train_dataloader):
            image = batch["image"].to(device)
            label = batch["label"].to(device)
            logits = classifier(image)
            loss = _bce_loss(logits, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * image.size(0)
            # Detach so the metric's stored state does not keep the autograd graph alive.
            train_auroc.update(logits[:, 0].detach(), label == 1)

        # ---- validation pass ----
        val_loss = 0.0
        classifier.eval()
        val_labels, val_logits = [], []
        with torch.no_grad():
            for batch in _iterate(val_dataloader):
                image = batch["image"].to(device)
                label = batch["label"].to(device)
                logits = classifier(image)
                val_loss += _bce_loss(logits, label).item() * image.size(0)
                val_auroc.update(logits[:, 0], label == 1)
                val_logits.append(logits[:, 0].cpu().numpy())
                val_labels.append(label.cpu().numpy())
        if verbose:
            print(
                f"Train loss: {train_loss/len(train_dataloader.dataset):.3f} AUROC: {train_auroc.compute():.3f} Val loss: {val_loss/len(val_dataloader.dataset):.3f} AUROC: {val_auroc.compute():.3f}"
            )
        train_auroc.reset()
        val_auroc.reset()

        # Compute the threshold from this epoch's weights *before* the early-stop check,
        # so it always matches the classifier that is returned.
        if val_labels:
            max_f1_thresh = find_thres_best_f1(
                y_test=np.hstack(val_labels), y_test_pred=np.concatenate(val_logits)
            )
            if verbose:
                print(max_f1_thresh)

        # Summed (not mean) loss; the scale is constant across epochs, so plateau/stop logic is unaffected.
        scheduler.step(val_loss)
        if early_stopper.early_stop(val_loss):
            if verbose:
                print("break")
            break

    # ---- test pass ----
    test_auroc = AUROC(task="binary")
    test_loss = 0.0
    classifier.eval()
    logits_list, label_list, metadata_list = [], [], []
    with torch.no_grad():
        for batch in _iterate(test_dataloader):
            image = batch["image"].to(device)
            label = batch["label"].to(device)
            logits = classifier(image)
            test_loss += _bce_loss(logits, label).item() * image.size(0)
            test_auroc.update(logits[:, 0], label == 1)
            # Keep (N, 1) logits per batch; callers index the concatenation with [:, 0].
            logits_list.append(logits.cpu().numpy())
            label_list.append(label.cpu().numpy())
            metadata_list.append(batch["metadata"])

    test_auroc_value = test_auroc.compute()
    if verbose:
        print(
            f"Test loss: {test_loss/len(test_dataloader.dataset):.3f} AUROC: {test_auroc_value:.3f}"
        )
    return test_auroc_value, classifier, logits_list, label_list, metadata_list, max_f1_thresh

In [None]:
def fdrcorrection(pvals, alpha=0.05, method='indep', is_sorted=False):
    '''Benjamini/Hochberg or Benjamini/Yekutieli false-discovery-rate correction.

    Parameters
    ----------
    pvals : array_like, 1-D
        Raw p-values of the individual tests.
    alpha : float, optional
        Target FDR level used for the rejection decision. Defaults to ``0.05``.
    method : {'i', 'indep', 'p', 'poscorr', 'n', 'negcorr'}, optional
        ``'i'``/``'indep'``/``'p'``/``'poscorr'`` select Benjamini/Hochberg
        (independent or positively correlated tests).
        ``'n'``/``'negcorr'`` select Benjamini/Yekutieli (general or negatively
        correlated tests), which additionally divides the BH factor by
        ``sum(1/k for k in 1..m)``. Defaults to ``'indep'``.
    is_sorted : bool, optional
        Pass True only when ``pvals`` is already in ascending order. Otherwise the
        values are sorted internally and results come back in the caller's order.

    Returns
    -------
    rejected : ndarray of bool
        True where the null hypothesis is rejected.
    pvals_corrected : ndarray
        Adjusted p-values (monotone in sorted order, capped at 1).

    Raises
    ------
    AssertionError
        If ``pvals`` is not one-dimensional.
    ValueError
        If ``method`` is not one of the supported names.
    '''
    pvals = np.asarray(pvals)
    assert pvals.ndim == 1, "pvals must be 1-dimensional, that is of shape (n,)"

    if is_sorted:
        order = None
        p_sorted = pvals
    else:
        order = np.argsort(pvals)
        p_sorted = np.take(pvals, order)

    n_tests = len(p_sorted)
    ranks = np.arange(1, n_tests + 1)
    # Empirical CDF of the sorted p-values: k / m for k = 1..m.
    ecdf = ranks / float(n_tests)

    if method in ('i', 'indep', 'p', 'poscorr'):
        factor = ecdf
    elif method in ('n', 'negcorr'):
        factor = ecdf / np.sum(1. / ranks)
    else:
        raise ValueError('only indep and negcorr implemented')

    reject = p_sorted <= factor * alpha
    if reject.any():
        # Step-up rule: every hypothesis up to the largest passing rank is rejected.
        last_pass = np.nonzero(reject)[0].max()
        reject[:last_pass] = True

    # p_(k) * m / k, made monotone from the largest rank downwards, then capped at 1.
    adjusted = np.minimum.accumulate((p_sorted / factor)[::-1])[::-1]
    adjusted[adjusted > 1] = 1

    if order is None:
        return reject, adjusted

    reject_in_order = np.empty_like(reject)
    reject_in_order[order] = reject
    adjusted_in_order = np.empty_like(adjusted)
    adjusted_in_order[order] = adjusted
    return reject_in_order, adjusted_in_order

In [None]:
from scipy.special import xlogy
def log_loss(
    y_true, y_pred, *, eps="auto", normalize=True, sample_weight=None, labels=None
):
    r"""Per-sample log loss (cross-entropy) for probabilistic predictions.

    This is a variant of scikit-learn's ``log_loss`` with the same input validation,
    clipping and renormalisation. Unlike scikit-learn's version, it returns the
    **unaggregated** loss of every sample. For a single sample with true label
    :math:`y \in \{0,1\}` and probability estimate :math:`p = \operatorname{Pr}(y = 1)`:

    .. math::
        L_{\log}(y, p) = -(y \log (p) + (1 - y) \log (1 - p))

    Parameters
    ----------
    y_true : array-like or label indicator matrix
        Ground truth labels for n_samples samples.

    y_pred : array-like of float, shape = (n_samples, n_classes) or (n_samples,)
        Predicted probabilities. With shape ``(n_samples,)`` the values are taken as the
        probability of the positive class. Columns are assumed to follow the sorted label
        order of ``sklearn.preprocessing.LabelBinarizer``.
        Values must be probabilities, not logits: they are clipped to ``[eps, 1 - eps]``
        and each row is renormalised to sum to 1.

    eps : float or "auto", default="auto"
        Clipping bound. ``"auto"`` uses ``np.finfo(y_pred.dtype).eps``.

    normalize : bool, default=True
        Accepted for signature compatibility with scikit-learn but **ignored**.
        Per-sample losses are always returned.

    sample_weight : array-like of shape (n_samples,), default=None
        Only checked for consistent length. **Not applied** to the returned losses.

    labels : array-like, default=None
        The label set. When not provided, it is inferred from ``y_true``.

    Returns
    -------
    loss : ndarray of shape (n_samples,)
        Natural-log cross-entropy of each sample.

    Raises
    ------
    ValueError
        If fewer than two labels are known, or the number of classes in ``y_pred``
        does not match the label set.

    Examples
    --------
    >>> log_loss(["spam", "ham", "ham", "spam"],
    ...          [[.1, .9], [.9, .1], [.8, .2], [.35, .65]])
    array([0.1053..., 0.1053..., 0.2231..., 0.4307...])
    """
    y_pred = sklearn.utils.check_array(
        y_pred, ensure_2d=False, dtype=[np.float64, np.float32, np.float16]
    )
    eps = np.finfo(y_pred.dtype).eps if eps == "auto" else eps

    # sample_weight is validated for length only; see the docstring.
    sklearn.utils.check_consistent_length(y_pred, y_true, sample_weight)
    lb = sklearn.preprocessing.LabelBinarizer()
    if labels is not None:
        lb.fit(labels)
    else:
        lb.fit(y_true)

    if len(lb.classes_) == 1:
        if labels is None:
            raise ValueError(
                "y_true contains only one label ({0}). Please "
                "provide the true labels explicitly through the "
                "labels argument.".format(lb.classes_[0])
            )
        else:
            raise ValueError(
                "The labels array needs to contain at least two "
                "labels for log_loss, "
                "got {0}.".format(lb.classes_)
            )

    transformed_labels = lb.transform(y_true)

    # Binary case: LabelBinarizer yields one column; expand to [neg, pos] one-hot.
    if transformed_labels.shape[1] == 1:
        transformed_labels = np.append(
            1 - transformed_labels, transformed_labels, axis=1
        )

    # Clip away exact 0/1 so the logarithm stays finite.
    y_pred = np.clip(y_pred, eps, 1 - eps)

    # A 1-D y_pred holds positive-class probabilities; expand to [1 - p, p].
    if y_pred.ndim == 1:
        y_pred = y_pred[:, np.newaxis]
    if y_pred.shape[1] == 1:
        y_pred = np.append(1 - y_pred, y_pred, axis=1)

    # Check if dimensions are consistent.
    transformed_labels = sklearn.utils.check_array(transformed_labels)
    if len(lb.classes_) != y_pred.shape[1]:
        if labels is None:
            raise ValueError(
                "y_true and y_pred contain different number of "
                "classes {0}, {1}. Please provide the true "
                "labels explicitly through the labels argument. "
                "Classes found in "
                "y_true: {2}".format(
                    transformed_labels.shape[1], y_pred.shape[1], lb.classes_
                )
            )
        else:
            raise ValueError(
                "The number of classes in labels is different "
                "from that in y_pred. Classes found in "
                "labels: {0}".format(lb.classes_)
            )

    # Renormalise rows so they sum to 1 after clipping.
    y_pred_sum = y_pred.sum(axis=1)
    y_pred = y_pred / y_pred_sum[:, np.newaxis]
    # Per-sample loss (no aggregation).
    return -xlogy(transformed_labels, y_pred).sum(axis=1)

In [None]:
def similarity_matrix(embeddings,
                      prompt_info,
                      idx,
                      temp=1/np.exp(4.5944)):
    """Build a (sample x concept) table of image-concept similarity scores.

    For every concept in ``prompt_info``, ``calculate_similaity_score`` is called with the
    concept's target and reference prompt embeddings (``normalize=True``). Each result
    becomes one column named after the concept.

    Args:
        embeddings: normalised image features, passed as ``image_features_norm``.
            Presumably one row per entry of ``idx`` (TODO confirm).
        prompt_info: mapping from concept name to a dict with keys
            ``"prompt_target_embedding_norm"`` and ``"prompt_ref_embedding_norm"``.
        idx: index assigned to the rows of the result.
        temp: temperature forwarded to ``calculate_similaity_score``. The default
            ``1/exp(4.5944)`` is the constant this notebook has always used; it looks
            like an inverse CLIP logit scale (TODO confirm).

    Returns:
        DataFrame indexed by ``idx`` with one column per concept, in ``prompt_info`` order.
        If ``prompt_info`` is empty, returns an empty-column DataFrame on ``idx``.
    """
    columns = []
    for concept_name, concept_prompts in prompt_info.items():
        scores = calculate_similaity_score(
            image_features_norm=embeddings,
            prompt_target_embedding_norm=concept_prompts["prompt_target_embedding_norm"],
            prompt_ref_embedding_norm=concept_prompts["prompt_ref_embedding_norm"],
            temp=temp,
            normalize=True,
        )
        columns.append(pd.Series(scores, index=idx, name=concept_name))

    if not columns:
        # pd.concat raises on an empty list.
        return pd.DataFrame(index=idx)
    return pd.concat(columns, axis=1)

In [None]:
# Attach concept-similarity matrices to each dataset's entry in variable_dict.
# The clinical dataset also gets one built from its "vanilla" features/prompts (computed first).
for dataset_name in ["clinical_fd_clean_nodup", "isic"]:
    matrix_specs = [("similarity_matrix", "image_features_norm", "prompt_info")]
    if dataset_name == "clinical_fd_clean_nodup":
        matrix_specs.insert(
            0, ("similarity_matrix_vanilla", "image_features_vanilla_norm", "prompt_info_vanilla")
        )
    for output_key, features_key, prompts_key in matrix_specs:
        variable_dict[dataset_name][output_key] = similarity_matrix(
            embeddings=variable_dict[dataset_name][features_key],
            prompt_info=variable_dict[dataset_name][prompts_key],
            idx=variable_dict[dataset_name]["metadata_all"].index,
        )

In [None]:
def check_concept_name(dataset_name, concept_name):
    """Return whether ``concept_name`` may be used as an audit concept for ``dataset_name``.

    For ``"isic"``, names starting with ``"disease"`` and the names ``"melanoma"`` and
    ``"malignant"`` are excluded; every other concept is allowed.

    Raises:
        NotImplementedError: for any dataset other than ``"isic"``.
    """
    if dataset_name != "isic":
        # `NotImplemented` is a sentinel constant, not an exception; calling it raised TypeError.
        raise NotImplementedError(dataset_name)
    if concept_name.startswith("disease"):
        return False
    return concept_name not in ("melanoma", "malignant")

In [None]:
#check
from sklearn.model_selection import train_test_split
from MONET.datamodules.components.base_dataset import BaseDataset

def get_training_data_idx(dataloader, valid_idx, y_pos, subset_idx_train, subset_idx_test, n_px=None,
                          val_size=0.2, random_state=42, batch_size=64, num_workers=4):
    """Build train/val/test dataloaders over ``dataloader``'s images, relabelled with ``y_pos``.

    Steps:
        1. Copy the metadata of ``dataloader.dataset``, add an integer ``"label"`` column from
           ``y_pos``, and restrict it to the rows selected by ``valid_idx``.
        2. Split the *unique* ids in ``subset_idx_train`` into train/val (``val_size`` goes to val).
        3. Route every occurrence of an id in ``subset_idx_train`` (duplicates included, e.g.
           from sampling with replacement) to its id's split, so no id is in both.
        4. Use the rows of ``subset_idx_test`` as the test split unchanged.

    Args:
        dataloader: source loader. Its ``dataset`` provides ``metadata_all``, ``image_path_dict``,
            ``n_px`` and ``transforms_aftertensor``; ``transforms[1]`` of the latter supplies
            ``mean``/``std``, presumably a Normalize transform (TODO confirm).
        valid_idx: row selector applied to the metadata, presumably a boolean mask aligned
            with its rows (TODO confirm).
        y_pos: labels aligned with the metadata rows; cast to int.
        subset_idx_train: metadata index labels (may repeat) for train + val.
        subset_idx_test: metadata index labels for the test split.
        n_px: image size. Defaults to ``dataloader.dataset.n_px``.
        val_size: fraction of unique train ids sent to val. Default matches the previous constant.
        random_state: seed forwarded to ``train_test_split``. Default matches the previous constant.
        batch_size: DataLoader batch size. Default matches the previous constant.
        num_workers: DataLoader worker count. Default matches the previous constant.

    Returns:
        ``(train_dataloader, val_dataloader, test_dataloader)``; only the train loader shuffles.
    """
    from MONET.utils.loader import custom_collate

    source = dataloader.dataset
    metadata_labeled = source.metadata_all.copy()
    metadata_labeled["label"] = y_pos.astype(int)
    metadata_valid = metadata_labeled[valid_idx]

    train_pool = pd.Index(subset_idx_train)
    train_ids, val_ids = train_test_split(
        np.unique(train_pool), test_size=val_size, random_state=random_state
    )
    # Vectorised membership keeps order and duplicates of `train_pool`, with hashed lookups.
    metadata_train = metadata_valid.loc[train_pool[train_pool.isin(train_ids)]]
    metadata_val = metadata_valid.loc[train_pool[train_pool.isin(val_ids)]]
    metadata_test = metadata_valid.loc[subset_idx_test]

    print("train:", len(metadata_train))
    print("val:", len(metadata_val))
    print("test:", len(metadata_test))

    if n_px is None:
        n_px = source.n_px
    normalize = source.transforms_aftertensor.transforms[1]

    def _make_loader(metadata, shuffle):
        """Wrap one metadata split in a BaseDataset + DataLoader with the shared settings."""
        dataset = BaseDataset(
            image_path_or_binary_dict=source.image_path_dict,
            n_px=n_px,
            norm_mean=normalize.mean,
            norm_std=normalize.std,
            augment=False,
            metadata_all=metadata,
            integrity_level="weak",
            return_label=["label"],
        )
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=shuffle,
            collate_fn=custom_collate,
        )

    train_dataloader = _make_loader(metadata_train, shuffle=True)
    val_dataloader = _make_loader(metadata_val, shuffle=False)
    test_dataloader = _make_loader(metadata_test, shuffle=False)
    return train_dataloader, val_dataloader, test_dataloader

# forced direction

In [None]:
# Simulation benchmark. For each concept of the clinical dataset:
#   - subsample train/test sets in which the concept's co-occurrence with the label is fixed
#     by `proportion`;
#   - train a classifier (train_classifier);
#   - store its test logits, labels and metadata in `simulation_data_list` for the audit below.
# NOTE(review): every name assigned here is a notebook global; later cells may read them.
dataset_name="clinical_fd_clean_nodup"

simulation_data_list=[]

for concept_name in variable_dict[dataset_name]["concept_list"]:
    # Usable images: rows whose "skincon_Do not consider this image" flag equals 0.
    # NOTE(review): this mask does not depend on concept_name and could be computed once before the loop.
    skincon_idx=(variable_dict[dataset_name]["metadata_all"]["skincon_Do not consider this image"]==0).values
    # Skip concepts present on fewer than 30 usable images.
    if variable_dict[dataset_name]["metadata_all"][concept_name][skincon_idx].values.astype(bool).sum()<30:
        print(concept_name, "!!!!!!!!!!!!!!!!!!!!!!!! SKIPPED !!!!!!!!!!!!!!!!!!!!!!!!")
        continue  
    
#     if concept_name!="skincon_Brown(Hyperpigmentation)":
#         continue
    
    
    # Target number of sampled images per label in each split.
    num_train_pos=500
    num_train_neg=500
    
    num_test_pos=500
    num_test_neg=500
#     for proportion in [1.0, 0.8, 0.2, 0.0]:
    # `proportion` = fraction of train positives (and of test negatives) that carry the concept.
    # Train negatives and test positives carry it at rate 1 - proportion, so the concept/label
    # association seen in training is reversed at test time (none when proportion == 0.5).
    for proportion in [1]:
        num_train_pos_with = int(num_train_pos*proportion)
        num_train_pos_without = num_train_pos-num_train_pos_with
        
        num_train_neg_with = int(num_train_neg*(1-proportion))
        num_train_neg_without = num_train_neg-num_train_neg_with
        
        
        num_test_pos_with = int(num_test_pos*(1-proportion))
        num_test_pos_without = num_test_pos-num_test_pos_with
        
        num_test_neg_with = int(num_test_neg*(proportion))
        num_test_neg_without = num_test_neg-num_test_neg_with
        
        
        for random_seed in range(20):
            # 60/40 split of usable image positions into a train pool and a test pool.
            subset_idx_train_, subset_idx_test_ = train_test_split(np.arange(len(skincon_idx))[skincon_idx], 
                                                                   test_size=0.4, 
                                                                   random_state=random_seed)



            # Train pool metadata with the label attached as "y_pos".
            metadata_all_train_=variable_dict[dataset_name]["metadata_all"].iloc[subset_idx_train_]
            y_pos_train_=variable_dict[dataset_name]["y_pos"][subset_idx_train_]
            metadata_all_train_=metadata_all_train_.copy()
            metadata_all_train_["y_pos"]=y_pos_train_
            
            # Skip this seed if any (concept, label) cell of the train pool has fewer than 30 images.
            if len(metadata_all_train_[(metadata_all_train_[concept_name]==1)&(metadata_all_train_["y_pos"]==True)])<30 or\
            len(metadata_all_train_[(metadata_all_train_[concept_name]==0)&(metadata_all_train_["y_pos"]==True)])<30 or\
            len(metadata_all_train_[(metadata_all_train_[concept_name]==1)&(metadata_all_train_["y_pos"]==False)])<30 or\
            len(metadata_all_train_[(metadata_all_train_[concept_name]==0)&(metadata_all_train_["y_pos"]==False)])<30:
                continue
            
            
            # Sample each (concept, label) cell with replacement up to its target count.
            train_idx_pos_with=metadata_all_train_[(metadata_all_train_[concept_name]==1)&(metadata_all_train_["y_pos"]==True)].sample(n=num_train_pos_with, replace=True, random_state=random_seed).index
            train_idx_pos_without=metadata_all_train_[(metadata_all_train_[concept_name]==0)&(metadata_all_train_["y_pos"]==True)].sample(n=num_train_pos_without, replace=True, random_state=random_seed).index
            
            train_idx_neg_with=metadata_all_train_[(metadata_all_train_[concept_name]==1)&(metadata_all_train_["y_pos"]==False)].sample(n=num_train_neg_with, replace=True, random_state=random_seed).index
            train_idx_neg_without=metadata_all_train_[(metadata_all_train_[concept_name]==0)&(metadata_all_train_["y_pos"]==False)].sample(n=num_train_neg_without, replace=True, random_state=random_seed).index
            
            # Index labels may repeat because of sampling with replacement.
            train_idx=train_idx_pos_with.tolist()+train_idx_pos_without.tolist()+train_idx_neg_with.tolist()+train_idx_neg_without.tolist()
            # train_idx=metadata_all_train_[(metadata_all_train_[concept_name]==1] ###
            #train_idx=metadata_all_train_.index.tolist()
            #print(metadata_all_train_.loc[train_idx][[concept_name,"y_pos"]])
            metadata_all_train=metadata_all_train_.loc[train_idx]
            
            
            # Same construction for the test pool, with the reversed concept/label proportions.
            metadata_all_test_=variable_dict[dataset_name]["metadata_all"].iloc[subset_idx_test_]
            y_pos_test_=variable_dict[dataset_name]["y_pos"][subset_idx_test_]
            metadata_all_test_=metadata_all_test_.copy()
            metadata_all_test_["y_pos"]=y_pos_test_
            
            if len(metadata_all_test_[(metadata_all_test_[concept_name]==1)&(metadata_all_test_["y_pos"]==True)])<30 or\
            len(metadata_all_test_[(metadata_all_test_[concept_name]==0)&(metadata_all_test_["y_pos"]==True)])<30 or\
            len(metadata_all_test_[(metadata_all_test_[concept_name]==1)&(metadata_all_test_["y_pos"]==False)])<30 or\
            len(metadata_all_test_[(metadata_all_test_[concept_name]==0)&(metadata_all_test_["y_pos"]==False)])<30:        
                continue
            
            
            test_idx_pos_with=metadata_all_test_[(metadata_all_test_[concept_name]==1)&(metadata_all_test_["y_pos"]==True)].sample(n=num_test_pos_with, replace=True, random_state=random_seed).index
            test_idx_pos_without=metadata_all_test_[(metadata_all_test_[concept_name]==0)&(metadata_all_test_["y_pos"]==True)].sample(n=num_test_pos_without, replace=True, random_state=random_seed).index
            
            test_idx_neg_with=metadata_all_test_[(metadata_all_test_[concept_name]==1)&(metadata_all_test_["y_pos"]==False)].sample(n=num_test_neg_with, replace=True, random_state=random_seed).index
            test_idx_neg_without=metadata_all_test_[(metadata_all_test_[concept_name]==0)&(metadata_all_test_["y_pos"]==False)].sample(n=num_test_neg_without, replace=True, random_state=random_seed).index
            
            test_idx=test_idx_pos_with.tolist()+test_idx_pos_without.tolist()+test_idx_neg_with.tolist()+test_idx_neg_without.tolist()
            
            
            metadata_all_test=metadata_all_test_.loc[test_idx]
            print(len(train_idx), len(test_idx))
            
            # get_training_data_idx further splits the sampled train ids 80/20 into train/val.
            train_dataloader, val_dataloader, test_dataloader=\
            get_training_data_idx(dataloader=variable_dict[dataset_name]["dataloader"], 
                              valid_idx=skincon_idx, 
                              y_pos=variable_dict[dataset_name]["y_pos"], 
#                               y_pos=variable_dict[dataset_name]["dataloader"].dataset.metadata_all[concept_name].fillna(0), 
                              subset_idx_train=metadata_all_train.index, 
                              #subset_idx_test=test_idx, 
                              subset_idx_test=metadata_all_test.index,
                              n_px=None)            
    
#             print(len(train_dataloader))
#             print(len(val_dataloader))
#             print(len(test_dataloader))
#             subset_idx_train=
            
            # `x` is the trained classifier (not used further in this cell).
            auc, x, logits_test, label_test, metadata_test, max_f1_thres =train_classifier(train_dataloader=train_dataloader, 
                                   val_dataloader=val_dataloader,
                                   test_dataloader=test_dataloader, verbose=True)  
        
            # Flatten per-batch test outputs into Series indexed like the test metadata.
            metadata_test=pd.concat(metadata_test)  
            label_test=pd.Series(np.hstack(label_test), index=metadata_test.index)
            logit_test=pd.Series(np.concatenate(logits_test)[:,0], index=metadata_test.index)
            
            simulation_data_list.append({"concept_name": concept_name,
                                         "random_seed": random_seed,
                                         "proportion": proportion,
                                         "label_test": label_test,
                                         "logit_test": logit_test,
                                         "metadata_all_train": metadata_all_train,
                                         "metadata_all_test": metadata_all_test,
                                         "metadata_test": metadata_test,
                                         "max_f1_thres": max_f1_thres,
                                        })
#             label_list=np.hstack(label_list)
#             logits_list=np.concatenate(logits_list)[:,0]
#             metadata_list=pd.concat(metadata_list)         

#         record_dict_list.append({"concept_name": concept_name})
        print(concept_name)

In [None]:
# torch.save(simulation_data_list, "logs/experiment_results/model_audit_benchmark_0525.pt")

In [None]:
# simulation_data_list=torch.load("logs/experiment_results/model_audit_benchmark_0525.pt")

In [None]:
# Load cached simulation results instead of re-running the (GPU) training loop above.
# NOTE(review): this overwrites any `simulation_data_list` produced by the simulation cell.
# torch.load unpickles arbitrary Python objects, so only load trusted files. On PyTorch versions
# where `weights_only` defaults to True, loading pandas objects may need weights_only=False (verify).
simulation_data_list=torch.load("logs/experiment_results/model_audit_benchmark_0526.pt", map_location="cpu")

In [None]:
# torch.save(variable_dict, "logs/experiment_results/model_audit_data_0526.pt")

In [None]:
# Load a cached `variable_dict`, replacing the in-memory one used by the cells above.
# NOTE(review): hidden-state hazard: anything those cells added to variable_dict is lost unless
# it was saved in this file. Same pickle/`weights_only` caveats as the previous load apply.
variable_dict=torch.load("logs/experiment_results/model_audit_data_0526.pt", map_location="cpu")

In [None]:
# Tabulate the simulation records: one row per dict in simulation_data_list, keys as columns.
pd.DataFrame(simulation_data_list)

In [None]:
# logit_test

# log_loss(y_pred=logit_test,
#         y_true=label_test).mean()

In [None]:
# Show up to 500 columns when rendering wide DataFrames.
pd.options.display.max_columns = 500

In [None]:
# List the artefacts stored for the current dataset.
# NOTE(review): relies on the notebook global `dataset_name`, last assigned in an earlier cell.
variable_dict[dataset_name].keys()

current method+CLIP
DOMINO+MONET

In [None]:
# similarity_info_copy=cluster_concept_test(similarity_info=concept_similarity_all,
#                                    ground_truth=variable_dict["clinical_fd_clean_nodup"]["metadata_all"][skincon_cols],
#                                    clustering_features=concept_similarity_all,
#                                    labels=label_test, logits=logit_test,
#                                    threshold=0,
#                                    score_threshold=0.8, accuracy_diff=0.1)

In [None]:
from scipy.stats import fisher_exact
def fisher_test_df(data, columns, y_pos, fdr_alpha=0.05, bonferroni_alpha=0.01):
    """Fisher exact test of association between ``y_pos`` and each binary column of ``data``.

    For every column, a 2x2 table ``[[y=1&c=1, y=1&c=0], [y=0&c=1, y=0&c=0]]`` is tested
    with ``scipy.stats.fisher_exact``. P-values are then adjusted across columns with
    ``fdrcorrection`` (Benjamini/Hochberg) and Bonferroni.

    Args:
        data: DataFrame whose ``columns`` hold 0/1 indicators. Values other than 0/1 fall
            into neither cell.
        columns: column names of ``data`` to test.
        y_pos: labels compared element-wise with True/False. When it is a pandas Series,
            index alignment with ``data`` applies.
        fdr_alpha: FDR level for the ``FDR_rejected`` flag.
        bonferroni_alpha: significance level for ``bof_rejected``. Default 0.01 matches the
            previous hard-coded value.

    Returns:
        DataFrame indexed by column name, sorted ascending by ``statistic`` (odds ratio), with:
            - the four cell counts;
            - ``direction1 = (y=1,c=1) - (y=1,c=0)`` and ``direction2 = (y=0,c=1) - (y=0,c=0)``;
            - ``direction`` = their product;
            - raw and adjusted p-values, plus rejection flags.
        Empty input columns give an empty frame with the same columns.
    """
    result_columns = ["name", "y=1,c=1", "y=1,c=0", "y=0,c=1", "y=0,c=0",
                      "direction", "direction1", "direction2", "pvalue", "statistic"]
    # Element-wise comparisons (not identity checks): y_pos may be an array or a Series.
    is_pos = y_pos == True  # noqa: E712
    is_neg = y_pos == False  # noqa: E712

    records = []
    for column in columns:
        has_concept = data[column] == 1
        lacks_concept = data[column] == 0
        y_1_c_1 = (is_pos & has_concept).sum()
        y_1_c_0 = (is_pos & lacks_concept).sum()
        y_0_c_1 = (is_neg & has_concept).sum()
        y_0_c_0 = (is_neg & lacks_concept).sum()

        # Tuple unpacking works for both older (tuple) and newer (SignificanceResult) SciPy returns.
        statistic, pvalue = fisher_exact([[y_1_c_1, y_1_c_0],
                                          [y_0_c_1, y_0_c_0]])
        direction1 = y_1_c_1 - y_1_c_0
        direction2 = y_0_c_1 - y_0_c_0
        records.append({"name": column,
                        "y=1,c=1": y_1_c_1,
                        "y=1,c=0": y_1_c_0,
                        "y=0,c=1": y_0_c_1,
                        "y=0,c=0": y_0_c_0,
                        "direction": direction1 * direction2,
                        "direction1": direction1,
                        "direction2": direction2,
                        "pvalue": pvalue,
                        "statistic": statistic,
                        })

    # Explicit columns + float casts keep the pipeline working when `columns` is empty.
    res_df = (
        pd.DataFrame(records, columns=result_columns)
        .astype({"pvalue": float, "statistic": float})
        .sort_values("statistic")
        .set_index("name")
    )
    fdr_rejected, fdr_pvalue = fdrcorrection(res_df["pvalue"], alpha=fdr_alpha)
    res_df["FDR_rejected"] = fdr_rejected
    res_df["FDR_pvalue_adjusted"] = fdr_pvalue
    # Bonferroni: multiply by the number of tests and cap at 1.
    res_df["bof_pvalue_adjusted"] = (res_df["pvalue"] * len(res_df)).clip(upper=1)
    res_df["bof_rejected"] = res_df["bof_pvalue_adjusted"] < bonferroni_alpha
    return res_df

In [None]:
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import random
from scipy.stats import mode

def cluster_concept_test_real(similarity_info, clustering_features, fixed_answer,
                         labels, logits, threshold,
                         metric_diff=0.5,
                         n_clusters=40, random_state=42):
    """Find under-performing k-means clusters and rank the concepts that distinguish them.

    The clustering features of the images in ``labels`` are reduced with PCA and clustered
    with k-means. Every cluster whose accuracy is below the overall accuracy is compared with
    the nearest cluster (distance between centres in PCA space) whose accuracy exceeds
    ``overall accuracy + metric_diff``. Each concept column of ``similarity_info`` is then
    scored by (mean in target cluster) - (mean in reference cluster).

    Parameters
    ----------
    similarity_info : pd.DataFrame
        Per-image concept scores, one column per concept; must contain every index of ``labels``.
    clustering_features : pd.DataFrame
        Per-image features used for clustering, indexed like ``similarity_info``.
    fixed_answer : list
        Not used here; kept so existing call sites keep working.
    labels : pd.Series
        Ground-truth labels (presumably 0/1, see ``log_loss(labels=[0, 1])``).
    logits : pd.Series
        Model logits aligned with ``labels``; the prediction is ``logits > threshold``.
    threshold : float
        Decision threshold applied to ``logits``.
    metric_diff : float
        Margin by which a reference cluster's accuracy must exceed the overall accuracy.
    n_clusters : int
        Number of k-means clusters.
    random_state : int
        Seed for torch, ``random``, numpy and k-means.

    Returns
    -------
    list of dict
        One record per under-performing cluster that has a qualifying reference cluster, with
        keys ``on_the_spot_plus_pred``, ``on_the_spot_minus_pred`` (concept-name lists),
        ``statistics`` (DataFrame with ``diff_magnitude`` and ``mean_value``), ``labels`` and
        ``labels_ref`` (per-image ``kmeans_dist``/``accuracy`` of target and reference cluster).
    """
    torch.manual_seed(random_state)
    random.seed(random_state)
    np.random.seed(random_state)

    concept_cols = similarity_info.columns.tolist()
    metric_use = "accuracy"
    record_list = []

    # Per-label clustering is switched off: all images are clustered together.
    per_label = False
    labels_unique = np.unique(labels) if per_label else [None]

    for label in labels_unique:
        if label is not None:
            focus_mask = labels == label
            focus_idx = labels[focus_mask].index
            labels_focus = labels[focus_mask].copy()
            logits_focus = logits[focus_mask].copy()
        else:
            # Keeps every row for ordinary 0/1 labels.
            focus_idx = labels[labels.astype(int) > -9].index
            labels_focus = labels.copy()
            logits_focus = logits.copy()

        similarity_info_focus = similarity_info.loc[focus_idx].copy()
        clustering_features_focus = clustering_features.loc[focus_idx].copy()

        assert (similarity_info_focus.index == clustering_features_focus.index).all()
        assert (similarity_info_focus.index == labels_focus.index).all()
        assert (similarity_info_focus.index == logits_focus.index).all()

        # 10 principal components for narrow inputs (< 50 features), otherwise 50.
        pca = PCA(n_components=10 if clustering_features_focus.shape[1] < 50 else 50)
        features_pc = pca.fit_transform(clustering_features_focus)

        kmeans = KMeans(n_clusters=n_clusters // len(labels_unique), random_state=random_state,
                        n_init="auto").fit(features_pc)
        # Pairwise (Euclidean, the sklearn default) distances between cluster centres.
        kmeans_dist = sklearn.metrics.pairwise_distances(kmeans.cluster_centers_)

        per_image = similarity_info_focus.copy()
        per_image["kmeans_label"] = kmeans.labels_
        # Squared distance of each image to its own cluster centre in PCA space.
        per_image["kmeans_dist"] = ((features_pc - kmeans.cluster_centers_[kmeans.labels_]) ** 2).sum(axis=1)
        per_image["accuracy"] = (labels_focus == (logits_focus > threshold))
        # NOTE(review): log_loss appears to return one aggregate value, so this (negated) column
        # is constant across rows; it is not used for cluster selection below.
        per_image["loss"] = -log_loss(y_true=labels_focus,
                                      y_pred=logits_focus.map(lambda x: 1 / (1 + np.exp(-x))),
                                      labels=[0, 1])
        per_image["label"] = labels_focus
        per_image["logit"] = logits_focus

        grouped = per_image.groupby("kmeans_label")
        # One row per cluster; each concept cell holds the array of that cluster's per-image scores.
        cluster_table = grouped[concept_cols].apply(
            lambda x: pd.Series([x[i].values for i in x.columns], index=x.columns))
        cluster_table["count"] = grouped.apply(len)
        cluster_table["accuracy"] = grouped["accuracy"].mean()
        cluster_table["loss"] = grouped["loss"].mean()
        cluster_table["label_frequent"] = grouped["label"].apply(lambda x: mode(x, keepdims=False).mode)

        overall_metric = per_image[metric_use].mean()
        cluster_ids = range(kmeans_dist.shape[0])

        # Visit clusters from worst to best; only below-average clusters are audited.
        for idx, row in cluster_table.sort_values(metric_use, ascending=True).iterrows():
            if row[metric_use] >= overall_metric:
                continue

            # Reference = nearest cluster whose metric beats the overall mean by more than metric_diff.
            by_distance = pd.Series(kmeans_dist[idx], index=cluster_ids).sort_values(ascending=True).index
            candidates = [i for i in by_distance
                          if cluster_table.loc[i, metric_use] > overall_metric + metric_diff]
            if not candidates:
                # No cluster clears the margin, so there is nothing to contrast against
                # (this used to raise IndexError on candidates[0]).
                continue
            ref_idx = candidates[0]
            ref_row = cluster_table.loc[ref_idx]

            # Positive diff_magnitude: the concept scores higher in the target than in the reference.
            diff_magnitude = pd.Series(
                [np.mean(row[c]) - np.mean(ref_row[c]) for c in concept_cols],
                index=concept_cols, name="diff_magnitude")
            mean_value = row[concept_cols].map(lambda v: np.mean(v)).rename("mean_value")
            statistics = pd.concat([diff_magnitude, mean_value], axis=1)

            # Candidates: mean score > 0.5 in the target cluster and higher than in the reference.
            # NOTE(review): the "minus" list applies the same diff_magnitude > 0 filter, so it is
            # the "plus" list in ascending order; confirm whether diff_magnitude < 0 was intended.
            selected = statistics[(statistics["mean_value"] > 0.5) & (statistics["diff_magnitude"] > 0)]
            record_list.append({
                "on_the_spot_plus_pred": selected.sort_values("diff_magnitude", ascending=False).index.tolist(),
                "on_the_spot_minus_pred": selected.sort_values("diff_magnitude", ascending=True).index.tolist(),
                "statistics": statistics,
                "labels": per_image[per_image["kmeans_label"] == idx][["kmeans_dist", metric_use]],
                "labels_ref": per_image[per_image["kmeans_label"] == ref_idx][["kmeans_dist", metric_use]],
            })

    return record_list

In [None]:
def evaluate_model_audit(simulation_data_list, n_concepts=48, max_rank=5):
    """Score MONET vs. CLIP concept explanations on simulated spurious-correlation datasets.

    For every simulation:

    1. Fisher exact tests (``fisher_test_df``) on the train and test metadata select the
       "fixed answer". These are concepts with negative ``direction`` in both splits whose
       ``statistic`` lies on opposite sides of 1 in train vs. test.
    2. ``cluster_concept_test_real`` audits the test predictions twice: once with
       ``similarity_matrix`` (MONET, ``metric_diff=0``) and once with
       ``similarity_matrix_vanilla`` (CLIP, ``metric_diff=0.1``).
    3. Each model's top-N concept lists are scored against the fixed answer. They are also
       scored against an "on-the-spot" ground truth: concepts whose majority presence in
       ``metadata_all`` differs between the audited cluster and its reference cluster.

    Parameters
    ----------
    simulation_data_list : iterable of dict
        Each dict provides ``concept_name``, ``label_test``, ``logit_test``,
        ``metadata_all_train``, ``metadata_all_test`` and ``max_f1_thres``.
    n_concepts : int, default 48
        Concept-vocabulary size assumed by the random-ordering baselines. This is the value
        that was hard-coded; presumably ``len(skincon_cols)`` (confirm).
    max_rank : int, default 5
        Largest top-N cut-off scored.

    Returns
    -------
    list of dict
        Flat records (``model``, ``method``, ``rank_n``, ``metric``, ...) ready for ``pd.DataFrame``.

    Notes
    -----
    Reads the notebook globals ``variable_dict`` (only its ``"clinical_fd_clean_nodup"``
    entry), ``skincon_cols``, ``fisher_test_df`` and ``cluster_concept_test_real``.
    """
    import math  # local import: the notebook's `import math` cell runs after this function is first called

    def get_ground_truth_on_the_spot(ground_truth, target_idx, reference_idx):
        """Concepts present in the majority of one image set but not the other.

        Returns ``{"more_present": Index, "less_present": Index}``, where "more"/"less"
        is relative to the target set.
        """
        target_present = (ground_truth.loc[target_idx].mean(axis=0) > 0.5).astype(int)
        reference_present = (ground_truth.loc[reference_idx].mean(axis=0) > 0.5).astype(int)
        presence_diff = target_present - reference_present
        return {"more_present": presence_diff[presence_diff > 0].index,
                "less_present": presence_diff[presence_diff < 0].index}

    dataset_vars = variable_dict["clinical_fd_clean_nodup"]
    concept_list = dataset_vars["concept_list"]
    ground_truth = dataset_vars["metadata_all"][skincon_cols]
    # Both audits cluster on the same image features, so build this frame once.
    clustering_features = pd.DataFrame(dataset_vars["efficientnet_feature"].numpy(),
                                       index=dataset_vars["metadata_all"].index)
    ranks = range(1, max_rank + 1)

    record_dict_all = []
    for simulation_data in tqdm.tqdm(simulation_data_list):
        concept_name = simulation_data["concept_name"]
        label_test = simulation_data["label_test"]
        logit_test = simulation_data["logit_test"]
        metadata_train = simulation_data["metadata_all_train"]
        metadata_test = simulation_data["metadata_all_test"]
        max_f1_thres = simulation_data["max_f1_thres"]

        fisher_pvals_train = fisher_test_df(data=metadata_train, columns=concept_list,
                                            y_pos=metadata_train["y_pos"])
        fisher_pvals_test = fisher_test_df(data=metadata_test, columns=concept_list,
                                           y_pos=metadata_test["y_pos"])

        train_neg = fisher_pvals_train["direction"] < 0
        test_neg = fisher_pvals_test["direction"] < 0
        # statistic > 1 in train but < 1 in test, and vice versa.
        test_less_represented = fisher_pvals_train[train_neg & (fisher_pvals_train["statistic"] > 1)].index \
            .intersection(fisher_pvals_test[test_neg & (fisher_pvals_test["statistic"] < 1)].index)
        test_more_represented = fisher_pvals_train[train_neg & (fisher_pvals_train["statistic"] < 1)].index \
            .intersection(fisher_pvals_test[test_neg & (fisher_pvals_test["statistic"] > 1)].index)
        # Alternative definition from the per-row count differences; printed for inspection only.
        test_more_represented_ = fisher_pvals_train[(fisher_pvals_train["direction1"] > 0) & (fisher_pvals_train["direction2"] < 0)].index \
            .intersection(fisher_pvals_test[(fisher_pvals_test["direction1"] < 0) & (fisher_pvals_test["direction2"] > 0)].index)
        test_less_represented_ = fisher_pvals_train[(fisher_pvals_train["direction1"] < 0) & (fisher_pvals_train["direction2"] > 0)].index \
            .intersection(fisher_pvals_test[(fisher_pvals_test["direction1"] > 0) & (fisher_pvals_test["direction2"] < 0)].index)

        print(concept_name, "Train-/Test+", test_more_represented.tolist(), "Train+/Test-", test_less_represented.tolist())
        print("Test", concept_name, "Train-/Test+", test_more_represented_.tolist(), "Train+/Test-", test_less_represented_.tolist())

        fixed_answer = test_more_represented.tolist() + test_less_represented.tolist()
        fixed_set = set(fixed_answer)

        shared_kwargs = dict(clustering_features=clustering_features, fixed_answer=fixed_answer,
                             labels=label_test, logits=logit_test, threshold=max_f1_thres,
                             n_clusters=40)
        # Dict order fixes both the call order and the record order (MONET first, then CLIP).
        results_by_model = {
            "MONET": cluster_concept_test_real(
                similarity_info=dataset_vars["similarity_matrix"][skincon_cols],
                metric_diff=0, **shared_kwargs),
            "CLIP": cluster_concept_test_real(
                similarity_info=dataset_vars["similarity_matrix_vanilla"][skincon_cols],
                metric_diff=0.1, **shared_kwargs),
        }

        for model, test_result_list in results_by_model.items():
            for test_result in test_result_list:
                ground_truth_on_the_spot = get_ground_truth_on_the_spot(
                    ground_truth=ground_truth,
                    target_idx=test_result["labels"].index,
                    reference_idx=test_result["labels_ref"].index)
                more_present = set(ground_truth_on_the_spot["more_present"])
                less_present = set(ground_truth_on_the_spot["less_present"])
                plus_pred = test_result["on_the_spot_plus_pred"]
                minus_pred = test_result["on_the_spot_minus_pred"]

                for i in ranks:
                    if i > len(plus_pred):
                        continue
                    record_dict_all.append({
                        "model": model,
                        "method": "on_the_spot_plus",
                        "rank_n": i,
                        "metric": len(more_present.intersection(plus_pred[:i])) != 0,
                        "answer_length": len(more_present),
                        # P(a random ordering puts >= 1 ground-truth concept in its top i).
                        "random_performance": 1 - (math.perm(n_concepts - len(more_present), i)
                                                   / math.perm(n_concepts, i)),
                    })

                for i in ranks:
                    if i > len(minus_pred):
                        continue
                    record_dict_all.append({
                        "model": model,
                        "method": "on_the_spot_minus",
                        "rank_n": i,
                        "metric": len(less_present.intersection(minus_pred[:i])) != 0,
                        "answer_length": len(less_present),
                    })

            for i in ranks:
                plus_top = [p for result in test_result_list for p in result["on_the_spot_plus_pred"][:i]]
                record_dict_all.append({
                    "model": model,
                    "method": "fixed_answer_plus",
                    "answer_length": len(fixed_set),
                    "count": len(test_result_list),
                    "rank_n": i,
                    "metric": len(fixed_set.intersection(plus_top)) != 0,
                    # P(>= 1 hit in at least one of the clusters under random top-i picks).
                    "random_performance": 1 - (math.comb(n_concepts - len(fixed_set), i)
                                               / math.comb(n_concepts, i)) ** len(test_result_list),
                })

            for i in ranks:
                minus_top = [p for result in test_result_list for p in result["on_the_spot_minus_pred"][:i]]
                record_dict_all.append({
                    "model": model,
                    "method": "fixed_answer_minus",
                    "answer_length": len(fixed_set),
                    "count": len(test_result_list),
                    "rank_n": i,
                    "metric": len(fixed_set.intersection(minus_top)) != 0,
                })

    return record_dict_all

In [None]:
# Full audit run. Keep the records rather than discarding them; the bare call only echoed a huge list.
eval_result = evaluate_model_audit(simulation_data_list=simulation_data_list[:])
pd.DataFrame(eval_result).head()

In [None]:
# Record the pandas version; groupby/mean behaviour used below differs between 1.x and 2.x.
pd.__version__

In [None]:
# Preview the benchmark records (`eval_result_df[]` was a SyntaxError that stopped Run All).
# NOTE(review): eval_result_df is built later in the notebook from evaluate_model_audit's records.
eval_result_df.head()

In [None]:
import math  # the notebook's `import math` cell comes after this one

# Chance that a random top-1 pick from 48 concepts hits one of 2 ground-truth concepts
# in at least one of 20 clusters.
p_random_top1 = 1 - (math.comb(48 - 2, 1) / math.comb(48, 1)) ** 20
p_random_top1

In [None]:
import math  # the notebook's `import math` cell comes after this one

# Same baseline as above for a random top-2 pick (2 ground-truth concepts, 48 total, 20 clusters).
p_random_top2 = 1 - (math.comb(48 - 2, 2) / math.comb(48, 2)) ** 20
p_random_top2

In [None]:
import math  # the notebook's `import math` cell comes after this one

# Random-pick baseline for top-4 (2 ground-truth concepts out of 48, over 20 clusters).
p_random_top4 = 1 - (math.comb(48 - 2, 4) / math.comb(48, 4)) ** 20
p_random_top4

In [None]:
# Mean hit rate per top-N for the fixed-answer ("plus") evaluation.
# numeric_only=True skips the string "model" column; pandas >= 2 raises TypeError otherwise.
# NOTE(review): eval_result_df is built later in the notebook.
eval_result_df[eval_result_df["method"] == "fixed_answer_plus"].groupby(["method", "rank_n"]).mean(numeric_only=True)

In [None]:
import math  # the notebook's `import math` cell comes after this one

# Per-cluster chance that a random ordering of the 48 concepts puts >= 1 of `num_true_spot`
# ground-truth concepts in its top `top_n_spot`, matching the "random_performance" of the
# on-the-spot records. The original referenced `ground_truth_on_the_spot`/`i`, which exist
# only inside evaluate_model_audit (NameError here), so explicit example values are used.
num_true_spot = 2
top_n_spot = 3
p_random_spot = 1 - (math.perm(48 - num_true_spot, top_n_spot) / math.perm(48, top_n_spot))
p_random_spot

In [None]:
# Same summary as the (method, rank_n) view, grouped rank-first.
# numeric_only=True skips the string "model" column; pandas >= 2 raises TypeError otherwise.
eval_result_df[(eval_result_df["method"] == "fixed_answer_plus")].groupby(["rank_n", "method"]).mean(numeric_only=True)

In [None]:
import math

In [None]:
# Compare two random-baseline columns.
# NOTE(review): evaluate_model_audit only records "random_performance", so the unguarded
# call raised KeyError; plot only when the alternative column is actually present.
if "random_performance_" in eval_result_df.columns:
    eval_result_df.plot(x="random_performance", y="random_performance_");
else:
    print("eval_result_df has no 'random_performance_' column; nothing to compare.")

In [None]:
# Run the audit benchmark over every simulation and tabulate its records.
# For a quick check, pass a subset instead, e.g.
#   [simulation_data_list[i] for i in (0, 20, 40, 60, 80)]
eval_result = evaluate_model_audit(simulation_data_list=simulation_data_list)
eval_result_df = pd.DataFrame(eval_result)
# Hit (1/0) divided by the random-ordering baseline; NaN for records that carry no "random_performance".
eval_result_df = eval_result_df.assign(
    metric_random_ratio=eval_result_df["metric"].astype(int) / eval_result_df["random_performance"]
)

In [None]:
# Concept names of simulations 0, 20, 40, 60 and 80, displayed as a tuple.
tuple(simulation_data_list[i]["concept_name"] for i in (0, 20, 40, 60, 80))

In [None]:
# variable_dict["clinical_fd_clean_nodup"]["metadata_all"][
# (variable_dict["clinical_fd_clean_nodup"]["metadata_all"]["skincon_Crust"]==0)
# &(variable_dict["clinical_fd_clean_nodup"]["y_pos"]==False)
# ].iloc[5:]
# variable_dict["clinical_fd_clean_nodup"]["dataloader"].dataset.getitem(
# variable_dict["clinical_fd_clean_nodup"]["dataloader"].dataset.metadata_all.index.tolist().index(
# "7d2f3fa05f4f362299c1ed148e7fc719.jpg")
# )["image"]

# One ground-truth definition (fixed-answer panel only)

In [None]:
# Single-panel variant of the model-audit benchmark figure. Only panel "B." (fixed answer)
# is drawn; the right-hand axes stay blank so the layout matches the two-panel figure.
# ColorBrewer "Paired" colours are 0-255 RGB triples, so normalise by 255 (not 256).
plt.rcParams["axes.prop_cycle"] = cycler('color', [np.array(Paired[12][k]) / 255 for k in (1, 3, 5, 7, 9, 11)])
fig = plt.figure(figsize=(18, 5))
subfigs = fig.subfigures(1, 1)

axes = subfigs.subplots(1, 2, gridspec_kw={"wspace": 0.3})

axd = {'fixed': axes[0], "on_the_spot": axes[1]}

plot_key = "fixed"

# Mean hit rate per (rank_n, model) over simulations with a non-empty fixed answer.
# numeric_only=True drops the string "method" column explicitly; pandas >= 2 raises otherwise.
eval_result_df_mean_fixed = eval_result_df[
    (eval_result_df["method"] == "fixed_answer_plus")
    & (eval_result_df["answer_length"] != 0)
    & (eval_result_df["rank_n"] <= 3)].groupby(["rank_n", "model"]).mean(numeric_only=True)
print(eval_result_df_mean_fixed)
sns.barplot(
    x="rank_n", y="metric", hue="model", hue_order=["MONET", "CLIP"],
    data=eval_result_df_mean_fixed.reset_index(),
    width=0.7,
    ax=axd[plot_key]
)

for axis in ['top', 'bottom', 'left', 'right']:
    axd[plot_key].spines[axis].set_linewidth(1.5)
axd[plot_key].spines['right'].set_visible(False)
axd[plot_key].spines['top'].set_visible(False)

axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.1))
axd[plot_key].yaxis.grid(True, which='major', linewidth=0.4, alpha=0.4)
axd[plot_key].set_ylim(0, 1)

axd[plot_key].tick_params(axis='both', which='major', labelsize=16)
axd[plot_key].tick_params(axis='both', which='minor', labelsize=16)

axd[plot_key].set_xlabel("Top-N", fontsize=18)
axd[plot_key].set_ylabel("Freq. of recovering spurious corr.", fontsize=18)

for patch in axd[plot_key].patches:
    patch.set_linewidth(1)
    patch.set_edgecolor("black")

leg = axd[plot_key].legend(fontsize=16, facecolor='white', framealpha=0.5, ncols=2,
                           loc='upper center', bbox_to_anchor=(0.487, -0.13, 0, 0))
leg.set_title("", prop={"size": 16})

axd[plot_key].text(x=-0.2, y=1.09, transform=axd[plot_key].transAxes,
                   s=" B.", fontsize=23, weight='bold')
axd[plot_key].set_title("Do the top-N concept explanations recover spurious correlations?\n(across all low-performing clusters)",
                        fontsize=18)

# Right-hand panel: blank placeholder.
plot_key = "on_the_spot"
for axis in ['top', 'bottom', 'left', 'right']:
    axd[plot_key].spines[axis].set_visible(False)
axd[plot_key].tick_params(left=False, right=False, labelleft=False,
                          labelbottom=False, bottom=False)

# Distinct file name: the two-panel figure cell writes "model_audit_main_benchmark.*",
# so sharing that name let whichever cell ran last silently overwrite the other.
for ext in ("png", "jpg", "svg", "pdf"):
    fig.savefig(log_dir / "plots" / f"model_audit_main_benchmark_one_gt.{ext}", bbox_inches='tight')

# Two ground-truth definitions (fixed-answer and per-cluster panels)

In [None]:
# Two-panel model-audit benchmark figure:
#   "B." how often the top-N explanations recover the fixed (train/test) spurious correlation;
#   "C." per-cluster hit rate relative to a random concept ordering.
# ColorBrewer "Paired" colours are 0-255 RGB triples, so normalise by 255 (not 256).
plt.rcParams["axes.prop_cycle"] = cycler('color', [np.array(Paired[12][k]) / 255 for k in (1, 3, 5, 7, 9, 11)])
fig = plt.figure(figsize=(18, 5))
subfigs = fig.subfigures(1, 1)

axes = subfigs.subplots(1, 2, gridspec_kw={"wspace": 0.3})

axd = {'fixed': axes[0], "on_the_spot": axes[1]}

# ---- Panel B: fixed answer ----
plot_key = "fixed"

# numeric_only=True drops the string "method" column explicitly; pandas >= 2 raises otherwise.
eval_result_df_mean_fixed = eval_result_df[
    (eval_result_df["method"] == "fixed_answer_plus")
    & (eval_result_df["answer_length"] != 0)
    & (eval_result_df["rank_n"] <= 3)].groupby(["rank_n", "model"]).mean(numeric_only=True)
print(eval_result_df_mean_fixed)
sns.barplot(
    x="rank_n", y="metric", hue="model", hue_order=["MONET", "CLIP"],
    data=eval_result_df_mean_fixed.reset_index(),
    width=0.7,
    ax=axd[plot_key]
)

for axis in ['top', 'bottom', 'left', 'right']:
    axd[plot_key].spines[axis].set_linewidth(1.5)
axd[plot_key].spines['right'].set_visible(False)
axd[plot_key].spines['top'].set_visible(False)

axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.1))
axd[plot_key].yaxis.grid(True, which='major', linewidth=0.4, alpha=0.4)
axd[plot_key].set_ylim(0, 1)

axd[plot_key].tick_params(axis='both', which='major', labelsize=16)
axd[plot_key].tick_params(axis='both', which='minor', labelsize=16)

axd[plot_key].set_xlabel("Top-N", fontsize=18)
axd[plot_key].set_ylabel("Freq. of recovering spurious corr.", fontsize=18)

for patch in axd[plot_key].patches:
    patch.set_linewidth(1)
    patch.set_edgecolor("black")

# One shared legend is drawn under panel C instead.
axd[plot_key].get_legend().remove()
axd[plot_key].text(x=-0.2, y=1.09, transform=axd[plot_key].transAxes,
                   s=" B.", fontsize=23, weight='bold')
axd[plot_key].set_title("Do the top-N concept explanations recover spurious correlations?\n(across all low-performing clusters)",
                        fontsize=18)

# ---- Panel C: per-cluster ("on-the-spot") ground truth ----
plot_key = "on_the_spot"

eval_result_df_mean_spot = eval_result_df[
    (eval_result_df["method"] == "on_the_spot_plus")
    & (eval_result_df["answer_length"] != 0)
    & (eval_result_df["rank_n"] <= 3)].groupby(["rank_n", "model"]).mean(numeric_only=True)
print(eval_result_df_mean_spot)
# Mean hit rate divided by mean random-ordering hit rate, per (rank_n, model); named so the
# barplot does not depend on reset_index() producing a column called 0.
ratio_to_random = (eval_result_df_mean_spot["metric"]
                   / eval_result_df_mean_spot["random_performance"]).rename("ratio_to_random")
sns.barplot(
    x="rank_n", y="ratio_to_random", hue="model", hue_order=["MONET", "CLIP"],
    data=ratio_to_random.reset_index(),
    width=0.7,
    ax=axd[plot_key]
)

for axis in ['top', 'bottom', 'left', 'right']:
    axd[plot_key].spines[axis].set_linewidth(1.5)
axd[plot_key].spines['right'].set_visible(False)
axd[plot_key].spines['top'].set_visible(False)

axd[plot_key].yaxis.set_major_locator(MultipleLocator(1))
axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.5))
axd[plot_key].yaxis.grid(True, which='major', linewidth=0.4, alpha=0.4)
axd[plot_key].yaxis.grid(True, which='minor', linewidth=0.2, alpha=0.4)

axd[plot_key].tick_params(axis='both', which='major', labelsize=16)
axd[plot_key].tick_params(axis='both', which='minor', labelsize=16)

axd[plot_key].set_xlabel("Top-N", fontsize=18)
axd[plot_key].set_ylabel("Ratio of freq. of including ground truth\nto that in random ordering", fontsize=18)

for patch in axd[plot_key].patches:
    patch.set_linewidth(1)
    patch.set_edgecolor("black")

leg = axd[plot_key].legend(fontsize=16, facecolor='white', framealpha=0.5, ncols=2,
                           loc='upper center', bbox_to_anchor=(-0.2, -0.1, 0, 0))
leg.set_title("", prop={"size": 16})
axd[plot_key].text(x=-0.19, y=1.09, transform=axd[plot_key].transAxes,
                   s="C.", fontsize=23, weight='bold')
axd[plot_key].set_title("Do the top-N concept explanations include ground truth\ndefined per low-performing cluster?",
                        fontsize=18)

for ext in ("png", "jpg", "svg", "pdf"):
    fig.savefig(log_dir / "plots" / f"model_audit_main_benchmark.{ext}", bbox_inches='tight')

In [None]:
# Per-(rank_n, model) means behind the "C." (on-the-spot) panel of the two-panel figure.
eval_result_df_mean_spot

In [None]:
top_n, num_true = 3, 2
# Chance that a random ordering of the 48 concepts places >= 1 of the `num_true`
# ground-truth concepts in its top `top_n`, in at least one of 20 clusters.
1 - (math.comb(48 - num_true, top_n) / math.comb(48, top_n)) ** 20

In [None]:
# Load the cached ISIC audit results.
# The dict holds non-tensor Python objects (e.g. objects exposing `.dataset`), which
# torch >= 2.6 refuses to load under its new weights_only=True default, so opt out explicitly.
# SECURITY: weights_only=False unpickles arbitrary objects; only load files you produced/trust.
variable_dict_isic = torch.load("logs/experiment_results/data_audit_new_0429.pt", map_location='cpu', weights_only=False)

In [None]:
# List the entries stored under "isic" in the cached results file.
variable_dict_isic["isic"].keys()

In [None]:
# ISIC "attribution" strings of the two hospitals whose classifiers are evaluated below.
hospital_1, hospital_2 = (
    "ViDIR Group, Department of Dermatology, Medical University of Vienna",
    "Hospital Clínic de Barcelona",
)

In [None]:
# Max-F1 decision threshold for each hospital-specific classifier, chosen on that
# classifier's dataloader entry [1] (presumably its validation split; confirm).
# Thresholds are on raw logits: the sigmoid conversion is intentionally not applied.
# NOTE(review): relies on the global `dataset_name` (expected to be "isic" here).
max_f1_thres_isic = {}

for hospital in (hospital_1, hospital_2):
    classifier_val_idx = variable_dict_isic[dataset_name][f"classifier_dataloader_{hospital}"][1].dataset.metadata_all.index
    y_test = variable_dict_isic[dataset_name]["classifier_dataloader_all"].dataset.metadata_all["label"].loc[classifier_val_idx]
    # Despite the name, these are logits.
    y_test_predicted_probas = variable_dict_isic[dataset_name][f"classifier_model_{hospital}_eval"]["logits"].loc[classifier_val_idx]

    precision, recall, thresholds = sklearn.metrics.precision_recall_curve(y_test, y_test_predicted_probas)
    numerator = 2 * recall * precision
    denom = recall + precision
    # F1 = 2PR / (P + R), defined as 0 where P + R == 0.
    f1_scores = np.divide(numerator, denom, out=np.zeros_like(denom), where=(denom != 0))
    # The final (P=1, R=0) point has no threshold; excluding it keeps argmax a valid index into `thresholds`.
    best = np.argmax(f1_scores[:-1])
    max_f1 = f1_scores[best]
    max_f1_thresh = thresholds[best]
    max_f1_thres_isic[hospital] = max_f1_thresh
    print(max_f1_thresh)

In [None]:
def get_subset_index(dataset_name, metadata_all, attribution):
    """Boolean mask of the rows of ``metadata_all`` that belong to ``attribution``.

    Only ``dataset_name == "isic"`` is supported:

    * ``"all"``: every row (``collection_65`` is not consulted).
    * ``"barcelona_all"`` / ``"mskcc_all"``: rows whose ``attribution`` equals either
      spelling of that institution, restricted to ``collection_65 == 1``.
    * any other value: rows whose ``attribution`` equals it exactly, restricted to
      ``collection_65 == 1``.

    Returns
    -------
    numpy.ndarray of bool, length ``len(metadata_all)``.

    Raises
    ------
    NotImplementedError
        For any other ``dataset_name``. Previously this fell through to an UnboundLocalError.
    """
    if dataset_name != "isic":
        raise NotImplementedError(dataset_name)

    if attribution == "all":
        return np.array([True] * len(metadata_all))

    # Institutions that appear under more than one attribution string.
    attribution_aliases = {
        "barcelona_all": ["Department of Dermatology, Hospital Clínic de Barcelona", "Hospital Clínic de Barcelona"],
        "mskcc_all": ["MSKCC", "Memorial Sloan Kettering Cancer Center"],
    }
    collection_65 = (metadata_all["collection_65"] == 1).values

    subset_idx = np.zeros(len(metadata_all), dtype=bool)
    for name in attribution_aliases.get(attribution, [attribution]):
        subset_idx |= (metadata_all["attribution"] == name).values

    return subset_idx & collection_65

In [None]:
def check_image(dataset_name, idx):
    """Return False if ``idx`` is on the hard-coded exclusion list for ``dataset_name``.

    Only ``"clinical_fd_clean"`` and ``"fitzpatrick17k_clean_threelabel"`` have
    exclusion lists; every image of any other dataset is accepted.
    """
    excluded_by_dataset = {
        "clinical_fd_clean": (1866, 3016, 2737, 2905, 982),
        "fitzpatrick17k_clean_threelabel": (1866, 3016, 3227, 2949, 394, 2861, 1515, 2109, 1385),
    }
    return idx not in excluded_by_dataset.get(dataset_name, ())

In [None]:
def check_concept_name(dataset_name, concept_name):
    """Return whether a concept may be used when auditing ``dataset_name``.

    For ``"isic"``, any concept starting with ``"disease"`` and the concepts
    ``"melanoma"`` / ``"malignant"`` are rejected (presumably because they restate
    the diagnosis being predicted); every other concept is accepted.

    Raises
    ------
    NotImplementedError
        For datasets other than ``"isic"``. (The original ``raise NotImplemented(...)``
        called the non-callable ``NotImplemented`` singleton and raised ``TypeError``.)
    """
    if dataset_name != "isic":
        raise NotImplementedError(dataset_name)
    if concept_name.startswith("disease"):
        return False
    return concept_name not in ("melanoma", "malignant")

def shorten_concept_name(concept_name, strict=True):
    """Map an internal concept identifier to the short label used in figure titles.

    Resolution order:
      1. ``disease_*`` -> every ``"disease_"`` occurrence removed;
      2. an explicit entry in the lookup table below;
      3. ``skincon_*`` -> the text after the ``"skincon_"`` prefix;
      4. otherwise raise ``NotImplementedError`` when ``strict``, else return the input.
    """
    if concept_name.startswith("disease_"):
        return concept_name.replace("disease_", "")

    display_names = {
        "skincon_Erythema": "Erythema",
        "skincon_Bulla": "Bulla",
        "skincon_Lichenification": "Lichenification",
        "skincon_Pustule": "Pustule",
        "skincon_Ulcer": "Ulcer",
        "skincon_Warty/Papillomatous": "Warty",
        "skincon_White(Hypopigmentation)": "Hypopigmentation",
        "skincon_Brown(Hyperpigmentation)": "Hyperpigmentation",
        "skincon_Exophytic/Fungating": "Fungating",
        "purple pen": "Purple pen",
        "nail": "Nail",
        "orange sticker": "Orange sticker",
        "hair": "Hair",
        "gel": "Gel",
        "red": "Red",
        "dermoscope border": "Dermoscopic border",
        "pinkish": "Pinkish",
    }
    if concept_name in display_names:
        return display_names[concept_name]

    skincon_prefix = "skincon_"
    if concept_name.startswith(skincon_prefix):
        return concept_name[len(skincon_prefix):]

    if strict:
        raise NotImplementedError(concept_name)
    return concept_name

def shorten_hospital_name(hospital_name, strict=True):
    """Map a full ISIC attribution string to the hospital name used in figures.

    Parameters
    ----------
    hospital_name : str
        Full attribution string.
    strict : bool, default True
        If True, an unknown name raises ``NotImplementedError`` (previously it
        crashed with ``UnboundLocalError``); if False it is returned unchanged.
    """
    short_names = {
        "Hospital Clínic de Barcelona": "Hospital Clínic de Barcelona",
        "ViDIR Group, Department of Dermatology, Medical University of Vienna": "Medical University of Vienna",
    }
    if hospital_name in short_names:
        return short_names[hospital_name]
    if strict:
        raise NotImplementedError(hospital_name)
    return hospital_name

In [None]:
from scipy.stats import mode as sci_mode
def plot_slice_figure(data_dict,
                      prompt_info, 
                      row_per_slice=30, 
                      example_per_row=30, 
                      normalize=True, 
                      show_small_box=True, 
                      print_alphabet=True,
                      print_legend_color=True,
                      print_legend_color_idx=2,
                      print_legend_number=True,
                      fontsize=32,
                      slice_title_fontsize=32,
                      skip_section=0,
                      figure_title=None):
    

    
    
    red_color=np.array([212,17,89]) #np.array((222,40,40))
    green_color=np.array([26,133,255]) #np.array((40,200,40))
    [31, 120, 180], [51, 160, 44]
    # two_color=[np.array([90, 0, 220]), np.array([51, 160, 44])]
    #two_color=[np.array([31, 120, 180]), np.array([51, 160, 44])]
    two_color=[np.array([60, 50, 180]), np.array([51, 160, 44])]

    
    total_slices=(len([j for exp_name in data_dict.keys() for j in data_dict[exp_name]["sample_list_list"]]))
    
    fig = plt.figure(figsize=(3*(example_per_row), 
                              3*(row_per_slice)*total_slices+\
                              0.4*(len(data_dict)-1)+\
                              0.3*((total_slices-1)-(len(data_dict)-1))
                             )
                    )

    box1 = gridspec.GridSpec(len(data_dict), 1,
                             wspace=0.0,
                             hspace=0.4)
    
    axd={}
    for idx1, exp_name in enumerate(data_dict.keys()):
        box2 = gridspec.GridSpecFromSubplotSpec(len(data_dict[exp_name]["sample_list_list"]), 1,
                        subplot_spec=box1[idx1], wspace=0.0, hspace=0.3)

        for idx2, (slice_assignment) in enumerate(data_dict[exp_name]["sample_list_list"]):
            box3 = gridspec.GridSpecFromSubplotSpec(row_per_slice, example_per_row,
                                                                    subplot_spec=box2[idx2], wspace=0, hspace=0.05)            
#             if example_per_slice//10==1:
#                 box3 = gridspec.GridSpecFromSubplotSpec(row_per_slice, example_per_row,
#                                                         subplot_spec=box2[idx2], wspace=0, hspace=0.05)
#             else:
#                 box3 = gridspec.GridSpecFromSubplotSpec(row_per_slice, example_per_row,
#                                                         subplot_spec=box2[idx2], wspace=0.05, hspace=0.15)
            for rank_num in range(row_per_slice*example_per_row):
                ax=plt.Subplot(fig, box3[rank_num])
                fig.add_subplot(ax)

                plot_key=(idx1, idx2, rank_num)
                axd[plot_key]=ax   
                
    #dsdsd           
    for idx1, exp_name in enumerate(data_dict.keys()):
        
        targets=data_dict[exp_name]["targets"]
        preds=data_dict[exp_name]["preds"]
        main_title=data_dict[exp_name]["main_title"]   
                
        
        for idx2, (sample_list) in enumerate(data_dict[exp_name]["sample_list_list"]):
            image_idx_list=variable_dict[dataset_name]["metadata_all"].index.get_indexer(sample_list)
        
            count=0
            rank_num=0
            while rank_num<min(row_per_slice*example_per_row, len(image_idx_list)):
                if check_image(dataset_name, image_idx_list[count]):
                    pass
                else:
                    count+=1
                    continue
                    
                plot_key=(idx1, idx2, rank_num)
                
                item=variable_dict[dataset_name]["dataloader"].dataset.getitem(image_idx_list[count])
                image=item["image"]
                axd[plot_key].imshow(image.resize((300, 300)))
                
                if show_small_box:

                    
                    if item["metadata"]["benign_malignant_bool"]==True:
                        axd[plot_key].scatter(x=[0.905], y=[0.905], s=650, 
                                       linewidths=1.5,
                                       edgecolor=np.array((0,0,0, 120))/256,
                                       #edgecolor=np.array((255,255,0, 120))/256,
                                       color=red_color/256,
                                       marker="s",
                                       transform=axd[plot_key].transAxes)     

                    elif item["metadata"]["benign_malignant_bool"]==False:
                        axd[plot_key].scatter(x=[0.905], y=[0.905], s=650, 
                                   linewidths=1.5,
                                   #edgecolor=np.array((0,0,0, 120))/256,
                                   edgecolor=np.array((0,0,0, 120))/256,
                                   color=green_color/256,
                                   marker="s",
                                   transform=axd[plot_key].transAxes)                 

                    x1=0.82
                    x2=0.99
                    if preds.loc[item["metadata"].name]==1:
                        axd[plot_key].fill([x1, x2, x2, x1], [x1, x2, x1, x1], 
                                           color=red_color/256,
                                          transform=axd[plot_key].transAxes
                                          )    
                    else:
                        axd[plot_key].fill([x1, x2, x2, x1], [x1, x2, x1, x1], 
                                           color=green_color/256,
                                          transform=axd[plot_key].transAxes
                                          )
    #                     axd[plot_key].scatter(x=[0.99], y=[0.99], s=700, 
    #                                    linewidths=1.3,
    # #                                    edgecolor=np.array((0,0,0, 120))/256,
    #                                    color=np.array((40,200,40))/256,
    #                                         #color=np.array((100,40,40))/256,
    #                                    marker=6,
    #                                    transform=axd[plot_key].transAxes)                         
                else:
                    axd[plot_key].set_title(item["metadata"].name, fontsize=10)

                axd[plot_key].set_xticks([])
                axd[plot_key].set_yticks([])  
                      

                if rank_num==0:   
                    #shorten_concept_name(concept_name)

                    #axd[plot_key].set_ylabel(shorten_concept_name(concept_name), fontsize=30, zorder=-10)
                    #axd[plot_key].set_ylabel(str(idx1), fontsize=30, zorder=-10)
                    pass

                if rank_num==99:
                    diff_dict_df=pd.DataFrame(diff_dict)
                    #print(diff_dict_df["concept_name"])
                    diff_dict_df=diff_dict_df[diff_dict_df["concept_name"].map(lambda x: check_concept_name("isic", x))]
                    
#                     print(diff_dict_df.sort_values("diff_score", ascending=False).iloc[:]["concept_name"])
                    #title=f"{int(slice_mask.sum()):d} {(targets[slice_mask]==((prob[slice_mask]>0.5).astype(int))).mean():.2f} "
                    #title=', '.join(diff_dict_df.sort_values("diff_score", ascending=False).iloc[:5]["concept_name"].map(shorten_concept_name).tolist())

#                     concept_str=diff_dict_df.sort_values("diff_score", ascending=False).iloc[:5]["concept_name"]
                    concept_str=diff_dict_df.sort_values("concept_presence_score", ascending=False).iloc[:5]["concept_name"]
#                     concept_str=diff_dict_df.sort_values("slice_score", ascending=False).iloc[:5]["concept_name"]
                    concept_str=concept_str.map(shorten_concept_name)
                    concept_str=", ".join(concept_str.str.replace("skincon_",""))
                    concept_str=concept_str
                    
                    print(concept_str)
#                     print(diff_dict_df.sort_values("diff_score", ascending=False).iloc[:20])
                    print(diff_dict_df.sort_values("concept_presence_score", ascending=False).iloc[:20])                    
#                     print(diff_dict_df.sort_values("concept_presence_score", ascending=False).iloc[:20])                    
                    

                    
                    #title+= / Predicted Pos={(prob[slice_mask]>0.5).sum()} Neg={(prob[slice_mask]<0.5).sum()}"
                                      
                    title= concept_str
                        
                    targets.loc[sample_list].sum()
                    axd[plot_key].text(x=-0., y=1.1, transform=axd[plot_key].transAxes,
                                         s=title, fontsize=25, color="black",
                                      
#                                       bbox=dict(facecolor='white', edgecolor='red')
                                      )   
                if rank_num==9:
                    
                    label_str=f"True Malignant: {targets.loc[sample_list].sum()} Neg={(1-targets.loc[sample_list]).sum()}"
                    predicted_str=f" Pred +={(preds.loc[sample_list]==1).sum()} Neg={(preds.loc[sample_list]==0).sum()}"                    
                    
#                     title= f"Malignant: {targets[slice_mask].sum()} → {(preds[slice_mask]==1).sum()}   Benign: {(1-targets[slice_mask]).sum()} → {(prob[slice_mask]<0.5).sum()}"
                    title= f"True: {targets.loc[sample_list].sum()} / {(1-targets.loc[sample_list]).sum()} → Pred: {(preds.loc[sample_list]==1).sum()} / {(preds.loc[sample_list]==0).sum()} "
                    targets.loc[sample_list].sum()
                    axd[plot_key].text(x=1.0, y=1.1, transform=axd[plot_key].transAxes,
                                         s=title, fontsize=25, color="black",
                                       horizontalalignment="right",

#                                       bbox=dict(facecolor='white', edgecolor='red')
                                      )   
                    
                if print_legend_number and idx1==len(data_dict)-1 and idx2==len(data_dict[exp_name]["sample_list_list"])-1 and rank_num==example_per_row-1:
                    
                    title= f"True: # Malignant / # Benign → Pred: # Malignant / # Benign"
                    targets.loc[sample_list].sum()
                    axd[plot_key].text(x=1.0, y=-0.25, transform=axd[plot_key].transAxes,
                                         s=title, fontsize=23, color="black",
                                       horizontalalignment="right",

#                                       bbox=dict(facecolor='white', edgecolor='red')
                                      )                       
                    
                    pass
                   
                    
#                       axd[plot_key].text(x=-0.3, y=1.05, transform=axd[plot_key].transAxes,
#                                          s=["A", "B", "C", "D", "E"][idx1], fontsize=35, weight='bold')

                for axis in ['top','bottom','left','right']:
                    axd[plot_key].spines[axis].set_linewidth(1)      
#                 print('idx~~~~', idx1)
                if rank_num==0 and idx1==0 and idx2==0 and figure_title is not None:
                    axd[plot_key].text(x=-0.33, 
                                         #y=1.1, 
                                         y=1.4,
                                         transform=axd[plot_key].transAxes,
                                         s=figure_title[0], 
                                         fontsize=35, weight='bold')  
                    axd[plot_key].text(x=0.01, 
                                         #y=1.1, 
                                         y=1.4, 
                                         transform=axd[plot_key].transAxes,
                                         s=figure_title[1],
                                         fontsize=fontsize)        
                    
                if rank_num==0 and idx2==0:
                    if print_alphabet and skip_section+idx1<26:
                        axd[plot_key].text(x=-0.3, 
                                             #y=1.1, 
                                             y=1.1, 
                                             transform=axd[plot_key].transAxes,
                                             s=["A.", "B.", "C.", "D.", "E.", "F.", "G.", "H.", "I.", "J.", 
                                                "K.", "L.", "M.", "N.", "O.", "P.", "Q.", "R.", "S.", "T.", 
                                                "U.", "V.", "W.", "X.", "Y.", "Z."][skip_section+idx1], 
                                             fontsize=35, weight='bold')  
                        axd[plot_key].text(x=0.05, 
                                             #y=1.1, 
                                             y=1.1, 
                                             transform=axd[plot_key].transAxes,
                                             s=main_title[0], 
                                             fontsize=fontsize)
                    else:
#                         axd[plot_key].text(x=-0.3, 
#                                              #y=1.1, 
#                                              y=1.4, 
#                                              transform=axd[plot_key].transAxes,
#                                              s=["A.", "B.", "C.", "D.", "E."][skip_section+idx1], 
#                                              fontsize=35, weight='bold')  
                        axd[plot_key].text(x=0.0, 
                                             #y=1.1, 
                                             y=1.1, 
                                             transform=axd[plot_key].transAxes,
                                             s=main_title[0], 
                                             fontsize=slice_title_fontsize)
    
    
                if rank_num==0 and idx2==1:
                    axd[plot_key].text(x=0.0, 
                                         #y=1.1, 
                                         y=1.1, 
                                         transform=axd[plot_key].transAxes,
                                         s=main_title[1], 
                                         fontsize=slice_title_fontsize)    
#                                            , weight='bold')                        
                    
  
                
                
                
                if print_legend_color and idx1==len(data_dict)-1 and idx2==len(data_dict[exp_name]["sample_list_list"])-1 and rank_num==print_legend_color_idx:

                    legend_elements = [Line2D([0], [0], marker='o', color=(1,1,1,1), 
                                              markerfacecolor=np.array((200,40,40))/256, 
                                              markeredgecolor=np.array((0,0,0))/256, 
                                              markersize=30, 
                                              label="Maligant"),
                                       Line2D([0], [0], marker='X', color=(1,1,1,1), 
                                              markerfacecolor=np.array((40,200,40))/256, 
                                              markeredgecolor=np.array((0,0,0))/256, 
                                              markersize=30, label="Benign"),]

                    legend_elements = [Line2D([0], [0], marker='s', color=(1,1,1,1), 
                                              markerfacecolor=red_color/256, 
                                              markeredgecolor=np.array((0,0,0))/256, 
                                              markersize=30, 
                                              label="Maligant"),
                                       Line2D([0], [0], marker='s', color=(1,1,1,1), 
                                              markerfacecolor=green_color/256, 
                                              markeredgecolor=np.array((0,0,0))/256, 
                                              markersize=30, label="Benign   (Upper left: True, Lower right: Pred)"),]        

                    axd[plot_key].legend(handles=legend_elements, 
                                        ncol=2, 
                                        handlelength=3,
                                        handletextpad=-0.1, 
                                        columnspacing=1.5,
                                        fontsize=23,
                                        loc='lower center', 
                                        bbox_to_anchor=(1, -0.45))  
                    
#                     axd[plot_key].legend(handles=legend_elements, 
#                                         ncol=2, 
#                                         handlelength=3,
#                                         handletextpad=-0.1, 
#                                         columnspacing=1.5,
#                                         fontsize=23,
#                                         loc='lower center', 
#                                         bbox_to_anchor=(0, -0.45))                      
                    
                rank_num+=1
                count+=1                      
            
    return fig

# run test

In [None]:
def select_subset(image_features_norm, metadata_all,
                  logits, labels, subset_idx):
    """Restrict features, logits and labels to the rows of ``metadata_all[subset_idx]``.

    Parameters
    ----------
    image_features_norm : array-like
        Row-aligned with ``metadata_all`` (row i belongs to ``metadata_all.iloc[i]``).
    metadata_all : pandas.DataFrame
    logits, labels : pandas.Series
        Indexed by the same labels as ``metadata_all`` (in any order).
    subset_idx : boolean mask accepted by ``metadata_all[...]``.

    Returns
    -------
    (features_subset, logits_subset, labels_subset), all in ``metadata_all`` row order.

    Raises
    ------
    KeyError
        If a selected row is missing from ``logits`` or ``labels``. Previously
        ``get_indexer`` returned -1 and ``.iloc[-1]`` silently took the LAST row.
    """
    subset_index = metadata_all[subset_idx].index

    def _positions(index, what):
        # Map subset labels to integer positions in ``index``, refusing missing labels.
        positions = index.get_indexer(subset_index)
        missing = positions < 0
        if missing.any():
            raise KeyError(f"{int(missing.sum())} selected rows missing from {what}: "
                           f"{list(subset_index[missing][:5])}")
        return positions

    image_features_norm_subset = image_features_norm[_positions(metadata_all.index, "metadata_all")]
    logits_subset = logits.iloc[_positions(logits.index, "logits")]
    label_subset = labels.iloc[_positions(labels.index, "labels")]

    return image_features_norm_subset, logits_subset, label_subset

# Direction 1 -> 2: logits of the hospital_1 classifier, restricted to hospital_2 images
# that are also marked in ``valid_idx``.
(image_features_norm_subset_from1_to2,
 logits_subset_from1_to2,
 label_subset_from1_to2) = select_subset(
    image_features_norm=variable_dict_isic["isic"]["image_features_norm"],
    metadata_all=variable_dict_isic["isic"]["metadata_all"],
    logits=variable_dict_isic["isic"][f"classifier_model_{hospital_1}_eval"]["logits"],
    labels=variable_dict_isic["isic"]["classifier_dataloader_all"].dataset.metadata_all["label"],
    subset_idx=(
        get_subset_index(dataset_name="isic",
                         metadata_all=variable_dict_isic["isic"]["metadata_all"],
                         attribution=hospital_2)
        & variable_dict_isic["isic"]["valid_idx"]
    ),
)

# Concept-only run: similarity columns filtered through check_concept_name("isic", ...).
test_result_list_from1_to2_concept_only = cluster_concept_test_real(
    similarity_info=variable_dict["isic"]["similarity_matrix"][
        variable_dict["isic"]["similarity_matrix"].columns[
            variable_dict["isic"]["similarity_matrix"].columns.map(
                lambda x: check_concept_name("isic", x))
        ]
    ],
    clustering_features=pd.DataFrame(
        variable_dict["isic"]["efficientnet_feature"].numpy(),
        index=variable_dict["isic"]["metadata_all"].index,
    ),
    fixed_answer=["red"],
    labels=label_subset_from1_to2,
    logits=logits_subset_from1_to2,
    threshold=max_f1_thres_isic[hospital_1],
    metric_diff=0.1,
    n_clusters=80,
    random_state=42,
)

In [None]:
def select_subset(image_features_norm, metadata_all,
                  logits, labels, subset_idx):
    """Restrict features, logits and labels to the rows of ``metadata_all[subset_idx]``.

    Re-defined in this cell so it can run on its own; keep identical to the
    definition in the preceding cell.

    Parameters
    ----------
    image_features_norm : array-like
        Row-aligned with ``metadata_all`` (row i belongs to ``metadata_all.iloc[i]``).
    metadata_all : pandas.DataFrame
    logits, labels : pandas.Series
        Indexed by the same labels as ``metadata_all`` (in any order).
    subset_idx : boolean mask accepted by ``metadata_all[...]``.

    Returns
    -------
    (features_subset, logits_subset, labels_subset), all in ``metadata_all`` row order.

    Raises
    ------
    KeyError
        If a selected row is missing from ``logits`` or ``labels``. Previously
        ``get_indexer`` returned -1 and ``.iloc[-1]`` silently took the LAST row.
    """
    subset_index = metadata_all[subset_idx].index

    def _positions(index, what):
        # Map subset labels to integer positions in ``index``, refusing missing labels.
        positions = index.get_indexer(subset_index)
        missing = positions < 0
        if missing.any():
            raise KeyError(f"{int(missing.sum())} selected rows missing from {what}: "
                           f"{list(subset_index[missing][:5])}")
        return positions

    image_features_norm_subset = image_features_norm[_positions(metadata_all.index, "metadata_all")]
    logits_subset = logits.iloc[_positions(logits.index, "logits")]
    label_subset = labels.iloc[_positions(labels.index, "labels")]

    return image_features_norm_subset, logits_subset, label_subset

# Direction 2 -> 1: logits of the hospital_2 classifier, restricted to hospital_1 images
# that are also marked in ``valid_idx``.
(image_features_norm_subset_from2_to1,
 logits_subset_from2_to1,
 label_subset_from2_to1) = select_subset(
    image_features_norm=variable_dict_isic["isic"]["image_features_norm"],
    metadata_all=variable_dict_isic["isic"]["metadata_all"],
    logits=variable_dict_isic["isic"][f"classifier_model_{hospital_2}_eval"]["logits"],
    labels=variable_dict_isic["isic"]["classifier_dataloader_all"].dataset.metadata_all["label"],
    subset_idx=(
        get_subset_index(dataset_name="isic",
                         metadata_all=variable_dict_isic["isic"]["metadata_all"],
                         attribution=hospital_1)
        & variable_dict_isic["isic"]["valid_idx"]
    ),
)

# Concept-only run: similarity columns filtered through check_concept_name("isic", ...).
test_result_list_from2_to1_concept_only = cluster_concept_test_real(
    similarity_info=variable_dict["isic"]["similarity_matrix"][
        variable_dict["isic"]["similarity_matrix"].columns[
            variable_dict["isic"]["similarity_matrix"].columns.map(
                lambda x: check_concept_name("isic", x))
        ]
    ],
    clustering_features=pd.DataFrame(
        variable_dict["isic"]["efficientnet_feature"].numpy(),
        index=variable_dict["isic"]["metadata_all"].index,
    ),
    fixed_answer=["red"],
    labels=label_subset_from2_to1,
    logits=logits_subset_from2_to1,
    threshold=max_f1_thres_isic[hospital_2],
    metric_diff=0.1,
    n_clusters=40,
    random_state=42,
)

# Same run on the unfiltered similarity matrix (disease concepts included).
test_result_list_from2_to1_with_disease = cluster_concept_test_real(
    similarity_info=variable_dict["isic"]["similarity_matrix"],
    clustering_features=pd.DataFrame(
        variable_dict["isic"]["efficientnet_feature"].numpy(),
        index=variable_dict["isic"]["metadata_all"].index,
    ),
    fixed_answer=["red"],
    labels=label_subset_from2_to1,
    logits=logits_subset_from2_to1,
    threshold=max_f1_thres_isic[hospital_2],
    metric_diff=0.1,
    n_clusters=40,
    random_state=42,
)

# plot

In [None]:
# NOTE(review): ``test_result`` is only bound by the ``for`` loop in the next cell, so this
# line fails on a fresh kernel (Restart & Run All).
test_result["labels"].sort_values(by="kmeans_dist").index

In [None]:
# Build the audit panels for direction 1 -> 2 (hospital_1 classifier on hospital_2 images).
# Per cluster: the first title lists up to 5 "plus" concepts; the second lists up to 5
# concepts with negative diff_magnitude and mean_value > 0.3.
similarity_thres = variable_dict["isic"]["similarity_matrix"].quantile(0.5, axis=0)  # NOTE(review): unused in this cell
N_MAIN_PANELS = 5
N_SUPPLE_PANELS = 15

# Loop-invariant: every cluster shares the same targets and thresholded predictions.
targets_from1_to2 = label_subset_from1_to2.astype(int)
preds_from1_to2 = (logits_subset_from1_to2 > max_f1_thres_isic[hospital_1]).astype(int)

data_dict_main = {}
data_dict_supple = {}
for count, test_result in enumerate(test_result_list_from1_to2_concept_only):
    statistics = test_result["statistics"]
    print(statistics.sort_values("diff_magnitude", ascending=False))

    # Image order: ascending "accuracy" first (presumably errors first), then descending
    # summed similarity to the selected "plus" concepts.
    concept_name_list_plus = test_result["on_the_spot_plus_pred"][:5]
    concept_score = (variable_dict["isic"]["similarity_matrix"][concept_name_list_plus]
                     .sum(axis=1)
                     .loc[test_result["labels"].index]
                     .rename("concept"))
    sample_list_plus = (pd.concat([concept_score, test_result["labels"]], axis=1)
                        .sort_values(["accuracy", "concept"], ascending=[True, False])
                        .index)
    sub_title_plus = ", ".join(shorten_concept_name(i, strict=False) for i in concept_name_list_plus)

    concept_name_list_minus = (statistics[(statistics["diff_magnitude"] < 0) & (statistics["mean_value"] > 0.3)]
                               .sort_values("diff_magnitude", ascending=True)
                               .iloc[:5]
                               .index.tolist())
    sub_title_minus = ", ".join(shorten_concept_name(i, strict=False) for i in concept_name_list_minus)

    for panel_dict, panel_limit in ((data_dict_main, N_MAIN_PANELS), (data_dict_supple, N_SUPPLE_PANELS)):
        if count < panel_limit:
            panel_dict[count] = {
                "targets": targets_from1_to2,
                "preds": preds_from1_to2,
                "sample_list_list": [sample_list_plus],
                "main_title": [sub_title_plus, sub_title_minus],
            }

fig = plot_slice_figure(data_dict=data_dict_main,
                        prompt_info=variable_dict["isic"]["prompt_info"],
                        example_per_row=5,
                        row_per_slice=1,
                        normalize=True,
                        show_small_box=True,
                        skip_section=0,
                        print_alphabet=False,
                        print_legend_number=False,
                        print_legend_color=False,
                        slice_title_fontsize=27,
                        figure_title=("D. ", "Trained at Med U. Vienna / Tested at Hosp. Barcelona "))
for suffix in ("png", "pdf"):
    fig.savefig(log_dir / "plots" / f"model_audit_from1_to2_main.{suffix}", bbox_inches='tight')

In [None]:
# Images in supplementary cluster 5 with target == 1 (malignant) but pred == 0 (benign).
(
    (data_dict_supple[5]["preds"].loc[data_dict_supple[5]["sample_list_list"][0]] == 0).to_numpy()
    & (data_dict_supple[5]["targets"].loc[data_dict_supple[5]["sample_list_list"][0]] == 1).to_numpy()
).sum()

In [None]:
74 - 19  # NOTE(review): hard-coded counts, presumably copied from an earlier output; derive from data

In [None]:
(74 - 19) / 74  # NOTE(review): hard-coded counts, presumably copied from an earlier output; derive from data

In [None]:
# Supplementary figure: one lettered row of 10 images per cluster, with both legends.
fig = plot_slice_figure(
    data_dict=data_dict_supple,
    prompt_info=variable_dict["isic"]["prompt_info"],
    example_per_row=10,
    row_per_slice=1,
    normalize=True,
    show_small_box=True,
    skip_section=0,
    print_alphabet=True,
    print_legend_number=True,
    print_legend_color=True,
    figure_title=None,
)
fig.savefig(log_dir / "plots" / "model_audit_from1_to2_supple.pdf", bbox_inches='tight')

In [None]:
# Single-row figure for the first cluster, used to export the colour legend.
fig = plot_slice_figure(
    data_dict={"a": data_dict_main[0]},
    prompt_info=variable_dict["isic"]["prompt_info"],
    example_per_row=10,
    row_per_slice=1,
    normalize=True,
    show_small_box=True,
    skip_section=0,
    print_alphabet=False,
    print_legend_number=False,
    print_legend_color=True,
    print_legend_color_idx=4,
    figure_title=("E. ", "Trained at Med U. Vienna / Tested at Hosp. Barcelona "),
)
fig.savefig(log_dir / "plots" / "model_audit_main_legend.png", bbox_inches='tight')
fig.savefig(log_dir / "plots" / "model_audit_main_legend.pdf", bbox_inches='tight')

In [None]:
# Build the per-slice panels for the "from2_to1" audit (figure title below: trained at
# Hosp. Barcelona, tested at Med U. Vienna) and render the main-text figure.
# NOTE(review): this cell re-binds `data_dict_main` / `data_dict_supple`, which earlier cells read
# for the from1_to2 direction; re-running those cells after this one plots from2_to1 data.
N_MAIN_SLICES = 5            # slices shown in the main-text figure
N_SUPPLE_SLICES = 15         # slices shown in the supplementary figure
N_TITLE_CONCEPTS = 5         # concepts listed in each slice's "plus" / "minus" subtitle
MIN_MINUS_MEAN_VALUE = 0.3   # a "minus" concept must have statistics["mean_value"] above this

# NOTE(review): not used in this cell; kept because later cells may read it.
similarity_thres = variable_dict["isic"]["similarity_matrix"].quantile(0.5, axis=0)


def _slice_order_and_titles(test_result):
    """Return ``(sample_order, main_title)`` for one from2_to1 audit slice.

    sample_order: index of ``test_result["labels"]``, sorted by its "accuracy" column ascending,
        with ties broken by summed similarity to the top "plus" concepts (descending).
    main_title: ``[plus_subtitle, minus_subtitle]``, each a comma-joined list of shortened
        concept names.
    """
    similarity_matrix = variable_dict["isic"]["similarity_matrix"]
    statistics = test_result["statistics"]
    labels = test_result["labels"]

    plus_concepts = test_result["on_the_spot_plus_pred"][:N_TITLE_CONCEPTS]
    concept_score = (
        similarity_matrix[plus_concepts].sum(axis=1).loc[labels.index].rename("concept")
    )
    sample_order = (
        pd.concat([concept_score, labels], axis=1)
        .sort_values(["accuracy", "concept"], ascending=[True, False])
        .index
    )

    # "Minus" concepts: negative diff_magnitude (most negative first) among sufficiently
    # present concepts.
    minus_mask = (statistics["diff_magnitude"] < 0) & (
        statistics["mean_value"] > MIN_MINUS_MEAN_VALUE
    )
    minus_concepts = (
        statistics[minus_mask]
        .sort_values("diff_magnitude", ascending=True)
        .iloc[:N_TITLE_CONCEPTS]
        .index.tolist()
    )

    def _join(concept_names):
        return ", ".join(shorten_concept_name(name, strict=False) for name in concept_names)

    return sample_order, [_join(plus_concepts), _join(minus_concepts)]


def _slice_entry(targets, preds, sample_order, main_title):
    """Return a fresh ``plot_slice_figure`` entry (new containers; targets/preds are shared)."""
    return {
        "targets": targets,
        "preds": preds,
        "sample_list_list": [sample_order],
        "main_title": list(main_title),
    }


# Loop-invariant: the same labels / thresholded predictions back every slice.
audit_targets_from2_to1 = label_subset_from2_to1.astype(int)
audit_preds_from2_to1 = (logits_subset_from2_to1 > max_f1_thres_isic[hospital_2]).astype(int)

data_dict_main = {}
data_dict_supple = {}
for slice_idx, test_result in enumerate(test_result_list_from2_to1_concept_only):
    print(test_result["statistics"].sort_values("diff_magnitude", ascending=False))
    if slice_idx >= max(N_MAIN_SLICES, N_SUPPLE_SLICES):
        continue  # slice is not plotted in either figure; skip the ordering/title work
    sample_order, main_title = _slice_order_and_titles(test_result)
    if slice_idx < N_MAIN_SLICES:
        data_dict_main[slice_idx] = _slice_entry(
            audit_targets_from2_to1, audit_preds_from2_to1, sample_order, main_title
        )
    if slice_idx < N_SUPPLE_SLICES:
        data_dict_supple[slice_idx] = _slice_entry(
            audit_targets_from2_to1, audit_preds_from2_to1, sample_order, main_title
        )

fig = plot_slice_figure(
    data_dict=data_dict_main,
    prompt_info=variable_dict["isic"]["prompt_info"],
    example_per_row=5,
    row_per_slice=1,
    normalize=True,
    show_small_box=True,
    skip_section=0,
    print_alphabet=False,
    print_legend_number=False,
    print_legend_color=False,
    slice_title_fontsize=27,
    figure_title=("E. ", "Trained at Hosp. Barcelona / Tested at Med U. Vienna"),
)
fig.savefig(log_dir / "plots" / "model_audit_from2_to1_main.png", bbox_inches="tight")
fig.savefig(log_dir / "plots" / "model_audit_from2_to1_main.pdf", bbox_inches="tight")

In [None]:
# Count of samples in the first supplementary slice (in plot order) whose target is 0 but whose
# prediction is 1.
_fp_entry = data_dict_supple[0]
_fp_order = _fp_entry["sample_list_list"][0]
_fp_is_negative = _fp_entry["targets"].loc[_fp_order] == 0
(_fp_entry["preds"].loc[_fp_order][_fp_is_negative] == 1).sum()

In [None]:
# Number of samples in the first supplementary slice (in plot order) whose target is 0.
data_dict_supple[0]["targets"].loc[data_dict_supple[0]["sample_list_list"][0]].eq(0).sum()

In [None]:
# Supplementary figure for the from2_to1 audit: all slices collected in `data_dict_supple`,
# 10 examples per row, with panel letters and both legends; saved as PNG and PDF.
fig = plot_slice_figure(
    data_dict=data_dict_supple,
    prompt_info=variable_dict["isic"]["prompt_info"],
    example_per_row=10,
    row_per_slice=1,
    normalize=True,
    show_small_box=True,
    skip_section=0,
    print_alphabet=True,
    print_legend_number=True,
    print_legend_color=True,
    figure_title=None,
)
for supple_file_name in ("model_audit_from2_to1_supple.png", "model_audit_from2_to1_supple.pdf"):
    fig.savefig(log_dir / "plots" / supple_file_name, bbox_inches="tight")

In [None]:
# Per-concept diff_magnitude summed over the first 10 from1_to2 slices, largest first.
(
    pd.concat(
        [res["statistics"]["diff_magnitude"] for res in test_result_list_from1_to2_concept_only[:10]],
        axis=1,
    )
    .sum(axis=1)
    .sort_values(ascending=False)
)

In [None]:
# Per-concept diff_magnitude summed over every from2_to1 slice, largest first.
(
    pd.concat(
        [res["statistics"]["diff_magnitude"] for res in test_result_list_from2_to1_concept_only],
        axis=1,
    )
    .sum(axis=1)
    .sort_values(ascending=False)
)

In [None]:
# Load a captured text log line by line (newlines kept). The path is relative to the project
# root, which the first cell chdirs into.
# NOTE(review): "untitled.txt" looks like a scratch file; consider a descriptive, configurable path.
with open("notebooks/untitled.txt") as log_file:
    f_lines = list(log_file)

In [None]:
# Drop progress-bar lines (those beginning with "100%") from the loaded log.
f_lines = list(filter(lambda line: not line.startswith("100%"), f_lines))

In [None]:
# Keep each log line whose immediate successor starts with "break".
f_lines_valid = [
    current_line
    for current_line, next_line in zip(f_lines, f_lines[1:])
    if next_line.startswith("break")
]

In [None]:
# Mean of the value after the last "AUROC:" across the lines selected before each "break".
np.mean([float(line.rpartition("AUROC:")[2].strip()) for line in f_lines_valid])

In [None]:
# Lines reporting test metrics (they begin with "Test loss").
f_lines_test = list(filter(lambda line: line.startswith("Test loss"), f_lines))

In [None]:
# Mean of the value after the last "AUROC:" across the "Test loss" lines.
np.mean([float(line.rpartition("AUROC:")[2].strip()) for line in f_lines_test])