# set working directory

In [None]:
import os

import hydra
import omegaconf
import pyrootutils

# Locate the project root by searching upward from this notebook's path
# (per the pyrootutils docs, pythonpath=True also puts the root on the Python path).
root = pyrootutils.setup_root(os.path.abspath("inherently_interpretable_model.ipynb"), pythonpath=True)
import os  # NOTE(review): redundant -- `os` is already imported at the top of this cell.

# Make all relative paths below (configs, "logs", ...) resolve against the project root.
os.chdir(root)

# set python path

In [None]:
import sys

# Make the project's `src` directory importable (presumably where the `MONET` package lives).
sys.path.append(str(root / "src"))

# import packages

In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import tqdm
from scipy.special import softmax
from scipy.stats import norm, pearsonr
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from sklearn.model_selection import train_test_split
from torchvision import models, transforms

import clip
from MONET.datamodules.multiplex_datamodule import MultiplexDatamodule
from MONET.utils.loader import custom_collate_per_key, dataloader_apply_func
from MONET.utils.static import (
    concept_to_prompt,
    fitzpatrick17k_disease_label,
    fitzpatrick17k_ninelabel,
    fitzpatrick17k_threelabel,
    skincon_cols,
)
from MONET.utils.text_processing import generate_prompt_token_from_concept
from MONET.utils.io import load_pkl
from PIL import Image

In [None]:
def wandb_to_exppath(wandb, log_path="/gscratch/cse/chanwkim/MONET_log/train/runs"):
    """Return the experiment directory whose wandb run id equals ``wandb``.

    Every sub-directory of ``log_path`` is treated as one experiment. If it has
    a ``wandb`` sub-directory, each entry there whose name starts with "run" is
    considered, and its last 8 characters are taken as the run id (wandb
    typically names these ``run-<timestamp>-<id>``).

    Args:
        wandb: Run id to look for, compared against the last 8 characters of
            each "run*" entry name.
        log_path: Directory containing one sub-directory per experiment.

    Returns:
        ``pathlib.Path`` of the matching experiment directory.

    Raises:
        RuntimeError: If no experiment contains a matching run.
    """
    log_path = Path(log_path)
    for experiment in os.listdir(log_path):
        wandb_dir = log_path / experiment / "wandb"
        if not os.path.exists(wandb_dir):
            continue
        # Compare against every "run*" entry, not only the first one: listdir
        # order is arbitrary, and an experiment whose wandb dir has no run
        # entries must be skipped instead of raising IndexError.
        run_ids = (name[-8:] for name in os.listdir(wandb_dir) if name.startswith("run"))
        if wandb in run_ids:
            return log_path / experiment
    raise RuntimeError("not found")


# exppath = wandb_to_exppath(
#     wandb="15eh81uv", log_path="/projects/leelab2/chanwkim/dermatology_datasets/logs/train/runs"
# )
# print([exppath / "checkpoints" / ckpt for ckpt in os.listdir(exppath / "checkpoints/")])

In [None]:
# Log directory, relative to the current working directory.
log_dir = Path("logs")

In [None]:
# Show GPU usage (IPython shell escape) to help choose model_device below.
!gpustat

# Initialize Model

In [None]:
# Key into model_path_dir (next cell) for the checkpoint to load, and the GPU to use.
model_name = "zt0n2xd0"
model_device = "cuda:7"

# Build the contrastive model from its Hydra config with a ViT-L/14 backbone.
cfg_model = omegaconf.OmegaConf.load(root / "configs" / "model" / "contrastive.yaml")
cfg_model.net.model_name_or_path = "ViT-L/14"
cfg_model.net.device = model_device
cfg_model  # not the last expression of the cell, so nothing is displayed

model = hydra.utils.instantiate(cfg_model)
model.to(model_device)
model.eval()

In [None]:
# Run id -> checkpoint file; its "state_dict" entry holds the model weights.
model_path_dir = {
    "zt0n2xd0": "/projects/leelab2/chanwkim/dermatology_datasets/logs/train/runs/2023-01-17_20-58-15/checkpoints/last.ckpt",
}

# Load the fine-tuned weights unless model_name is the backbone name itself
# (presumably the switch for evaluating the un-fine-tuned backbone).
# NOTE: torch.load unpickles the file -- only point this at trusted checkpoints.
if model_name != "ViT-L/14":
    model_path = model_path_dir[model_name]
    loaded = torch.load(model_path, map_location=model_device)
    model.load_state_dict(loaded["state_dict"])

In [None]:
# Second model built from the same config. No checkpoint is loaded into it in
# the visible notebook, so it serves as the "vanilla" baseline.
# The two assignments below just repeat the values set above.
model_name = "zt0n2xd0"
model_device = "cuda:7"

cfg_model = omegaconf.OmegaConf.load(root / "configs" / "model" / "contrastive.yaml")
cfg_model.net.model_name_or_path = "ViT-L/14"
cfg_model.net.device = model_device
cfg_model

model_vanilla = hydra.utils.instantiate(cfg_model)
model_vanilla.to(model_device)
model_vanilla.eval()

In [None]:
# Datamodule config shared by all evaluation sets; only dataset_name_test
# changes per dataset.
cfg_dm = omegaconf.OmegaConf.load(root / "configs" / "datamodule" / "multiplex.yaml")
# cfg.data_dir="/scr/chanwkim/dermatology_datasets"
cfg_dm.data_dir = "/sdata/chanwkim/dermatology_datasets"
# Overwritten in the dataloader cell below before any datamodule is instantiated.
cfg_dm.dataset_name_test = "fitzpatrick17k_skincon=all"
# cfg_dm.dataset_name_test = "fitzpatrick17k=all"

# cfg_dm.dataset_name_train =

cfg_dm.split_seed = 42
cfg_dm

In [None]:
# NOTE(review): not referenced anywhere in the visible part of this notebook.
variable_dict={}

In [None]:
# One test dataloader per dataset, built by re-instantiating the datamodule
# with a different dataset_name_test each time:
#   test_dataloader_f17k         <- fitzpatrick17k_clean_threelabel_nodup
#   test_dataloader_f17k_all     <- fitzpatrick17k_threelabel
#   test_dataloader_ddi          <- ddi
#   test_dataloader_f17k_skincon <- fitzpatrick17k_skincon
# The "=all" suffix presumably selects every sample of the dataset -- confirm
# in MultiplexDatamodule.
cfg_dm.dataset_name_test = "fitzpatrick17k_clean_threelabel_nodup=all"
dm = hydra.utils.instantiate(cfg_dm)
dm.setup()
# train_dataloader = dm.train_dataloader()
test_dataloader_f17k = dm.test_dataloader()

cfg_dm.dataset_name_test = "fitzpatrick17k_threelabel=all"
dm = hydra.utils.instantiate(cfg_dm)
dm.setup()
# train_dataloader = dm.train_dataloader()
test_dataloader_f17k_all = dm.test_dataloader()

cfg_dm.dataset_name_test = "ddi=all"
dm = hydra.utils.instantiate(cfg_dm)
dm.setup()
# train_dataloader = dm.train_dataloader()
test_dataloader_ddi = dm.test_dataloader()

cfg_dm.dataset_name_test = "fitzpatrick17k_skincon=all"
dm = hydra.utils.instantiate(cfg_dm)
dm.setup()
# train_dataloader = dm.train_dataloader()
test_dataloader_f17k_skincon = dm.test_dataloader()

In [None]:
# 4386

In [None]:
def batch_func(batch):
    """Embed one batch of images with the fine-tuned ``model``.

    ``batch["image"]`` is replaced in place by its copy on ``model_device``.
    ``model.model_step_with_image`` then runs without gradient tracking.

    Returns:
        dict with the image features (moved to CPU) under "image_features"
        and the batch's untouched "metadata".
    """
    with torch.no_grad():
        images_on_device = batch["image"].to(model_device)
        batch["image"] = images_on_device
        step_output = model.model_step_with_image(batch)
        feats = step_output["image_features"]
    feats_cpu = feats.detach().cpu()
    return {"image_features": feats_cpu, "metadata": batch["metadata"]}


# Embed every test set with the fine-tuned model. dataloader_apply_func runs
# batch_func over each dataloader, and custom_collate_per_key (presumably)
# gathers the per-batch outputs by key. The result is one metadata table and
# one feature tensor per dataset, with rows in dataloader order.
# NOTE(review): nothing is written to disk in this cell despite the message.
print("Featurizing and saving...")
loader_applied = dataloader_apply_func(
    dataloader=test_dataloader_f17k,
    func=batch_func,
    collate_fn=custom_collate_per_key,
)
metadata_all_f17k = loader_applied["metadata"]
image_features_f17k = loader_applied["image_features"].cpu()



loader_applied = dataloader_apply_func(
    dataloader=test_dataloader_f17k_all,
    func=batch_func,
    collate_fn=custom_collate_per_key,
)
metadata_all_f17k_all = loader_applied["metadata"]
image_features_f17k_all = loader_applied["image_features"].cpu()



loader_applied = dataloader_apply_func(
    dataloader=test_dataloader_ddi,
    func=batch_func,
    collate_fn=custom_collate_per_key,
)
metadata_all_ddi = loader_applied["metadata"]
image_features_ddi = loader_applied["image_features"].cpu()



loader_applied = dataloader_apply_func(
    dataloader=test_dataloader_f17k_skincon,
    func=batch_func,
    collate_fn=custom_collate_per_key,
)
metadata_all_f17k_skincon = loader_applied["metadata"]
image_features_f17k_skincon = loader_applied["image_features"].cpu()

In [None]:
def batch_func_vanilla_clip(batch):
    """Embed one batch of images with the baseline ``model_vanilla``.

    The batch's "image" entry is swapped in place for its copy on
    ``model_device``. Features are computed without gradient tracking and
    returned on the CPU, together with ``batch["metadata"]``.
    """
    with torch.no_grad():
        on_device = batch["image"].to(model_device)
        batch["image"] = on_device
        vanilla_feats = model_vanilla.model_step_with_image(batch)["image_features"]
    result = dict(image_features=vanilla_feats.detach().cpu(), metadata=batch["metadata"])
    return result

# Same featurization with the baseline model (the skincon set is not embedded here).
# Later cells index these vanilla features with the *fine-tuned* pass's metadata
# index. That assumes both passes produce rows in the same order (i.e. the test
# dataloaders do not shuffle) -- TODO confirm.
loader_applied = dataloader_apply_func(
    dataloader=test_dataloader_f17k,
    func=batch_func_vanilla_clip,
    collate_fn=custom_collate_per_key,
)
metadata_all_f17k_vanilla = loader_applied["metadata"]
image_features_f17k_vanilla = loader_applied["image_features"].cpu()

loader_applied = dataloader_apply_func(
    dataloader=test_dataloader_f17k_all,
    func=batch_func_vanilla_clip,
    collate_fn=custom_collate_per_key,
)
metadata_all_f17k_all_vanilla = loader_applied["metadata"]
image_features_f17k_all_vanilla = loader_applied["image_features"].cpu()
# images_all_f17k = loader_applied["images"]

loader_applied = dataloader_apply_func(
    dataloader=test_dataloader_ddi,
    func=batch_func_vanilla_clip,
    collate_fn=custom_collate_per_key,
)
metadata_all_ddi_vanilla = loader_applied["metadata"]
image_features_ddi_vanilla = loader_applied["image_features"].cpu()

In [None]:
# DDI disease name -> coarse label for the melanoma task. Four keys map to
# "melanoma"; all others map to "melanoma look-alike". The keys also define
# which DDI images are kept for the melanoma task (the DDI filtering cell
# below queries `disease in` these keys).
ddi_map = {
    "acral-melanotic-macule": "melanoma look-alike",
    "atypical-spindle-cell-nevus-of-reed": "melanoma look-alike",
    "benign-keratosis": "melanoma look-alike",
    "blue-nevus": "melanoma look-alike",
    "dermatofibroma": "melanoma look-alike",
    "dysplastic-nevus": "melanoma look-alike",
    "epidermal-nevus": "melanoma look-alike",
    "hyperpigmentation": "melanoma look-alike",
    "keloid": "melanoma look-alike",
    "inverted-follicular-keratosis": "melanoma look-alike",
    "melanocytic-nevi": "melanoma look-alike",
    "melanoma": "melanoma",
    "melanoma-acral-lentiginous": "melanoma",
    "melanoma-in-situ": "melanoma",
    "nevus-lipomatosus-superficialis": "melanoma look-alike",
    "nodular-melanoma-(nm)": "melanoma",
    "pigmented-spindle-cell-nevus-of-reed": "melanoma look-alike",
    "seborrheic-keratosis": "melanoma look-alike",
    "seborrheic-keratosis-irritated": "melanoma look-alike",
    "solar-lentigo": "melanoma look-alike",
}

In [None]:
# Give each metadata table a fresh 0..n-1 index. Assuming the original index
# is unnamed, reset_index keeps it as an "index" column, so the guard makes the
# cell safe to re-run. Later cells use the index labels of filtered metadata as
# row positions into the matching feature tensors, which is only valid after
# this reset. (metadata_all_f17k_skincon is not reset here.)
if "index" not in metadata_all_ddi:
    metadata_all_ddi = metadata_all_ddi.reset_index()
#     metadata_all_ddi = metadata_all_ddi[(metadata_all_ddi["skincon_Do not consider this image"]!=1).values]
if "index" not in metadata_all_f17k:
    metadata_all_f17k = metadata_all_f17k.reset_index()
#     metadata_all_f17k = metadata_all_f17k[(metadata_all_f17k["skincon_Do not consider this image"]!=1).values]
if "index" not in metadata_all_f17k_all:
    metadata_all_f17k_all = metadata_all_f17k_all.reset_index()
#     metadata_all_f17k_all = metadata_all_f17k_all[(metadata_all_f17k_all["skincon_Do not consider this image"]!=1).values]

In [None]:
# metadata_all_f17k_filtered = metadata_all_f17k.query(
#     "nine_partition_label == 'malignant melanoma'"
#     " | nine_partition_label == 'benign melanocyte'"
#     " | label == 'seborrheic keratosis'"
#     " | label == 'dermatofibroma'"
# )

In [None]:
# Filter the f17k and DDI datasets to melanoma and melanoma look-alikes:
# images labelled malignant melanoma or benign melanocyte (nine-partition
# label), plus seborrheic keratosis and dermatofibroma (fine-grained label).
metadata_all_f17k_filtered = metadata_all_f17k.query(
    "nine_partition_label == 'malignant melanoma'"
    " | nine_partition_label == 'benign melanocyte'"
    " | label == 'seborrheic keratosis'"
    " | label == 'dermatofibroma'"
)

# Index labels are used as row positions into the feature tensors; this relies
# on the metadata having been reset to a 0..n-1 index.
image_features_f17k_filtered = image_features_f17k[metadata_all_f17k_filtered.index]
image_features_f17k_filtered_vanilla = image_features_f17k_vanilla[metadata_all_f17k_filtered.index]
print(image_features_f17k_filtered.shape)

In [None]:
# Same filter on the fitzpatrick17k_threelabel set (vs. the "clean ... nodup" set above).
metadata_all_f17k_all_filtered = metadata_all_f17k_all.query(
    "nine_partition_label == 'malignant melanoma'"
    " | nine_partition_label == 'benign melanocyte'"
    " | label == 'seborrheic keratosis'"
    " | label == 'dermatofibroma'"
)
image_features_f17k_all_filtered = image_features_f17k_all[metadata_all_f17k_all_filtered.index]
image_features_f17k_all_filtered_vanilla = image_features_f17k_all_vanilla[metadata_all_f17k_all_filtered.index]
print(image_features_f17k_all_filtered.shape)

In [None]:
# Keep only DDI images whose disease is a key of ddi_map (`@` references a local in query()).
mimic_kws = set(ddi_map.keys())
metadata_all_ddi_filtered = metadata_all_ddi.query("disease in @mimic_kws")
image_features_ddi_filtered = image_features_ddi[metadata_all_ddi_filtered.index]
image_features_ddi_filtered_vanilla = image_features_ddi_vanilla[metadata_all_ddi_filtered.index]

In [None]:
# Stack Fitzpatrick17k rows first, then DDI rows. The label lists and metadata
# tables built later are concatenated in the same order.
#   *_melanoma:   filtered melanoma / look-alike subsets
#   *_malignancy: all images of each dataset
image_features_all_melanoma = torch.cat((image_features_f17k_filtered, image_features_ddi_filtered))
image_features_all_malignancy = torch.cat((image_features_f17k, image_features_ddi))
image_features_all_f17all_melanoma = torch.cat((image_features_f17k_all_filtered, image_features_ddi_filtered))
image_features_all_f17all_malignancy = torch.cat((image_features_f17k_all, image_features_ddi))

In [None]:
# Same stacking for the vanilla-model features.
image_features_all_melanoma_vanilla = torch.cat((image_features_f17k_filtered_vanilla, image_features_ddi_filtered_vanilla))
image_features_all_malignancy_vanilla = torch.cat((image_features_f17k_vanilla, image_features_ddi_vanilla))
image_features_all_f17all_melanoma_vanilla = torch.cat((image_features_f17k_all_filtered_vanilla, image_features_ddi_filtered_vanilla))
image_features_all_f17all_malignancy_vanilla = torch.cat((image_features_f17k_all_vanilla, image_features_ddi_vanilla))

In [None]:
# image_features_all = image_features_f17k
# Binary targets as lists of Python ints, in the same row order as the
# concatenated feature tensors (Fitzpatrick17k rows first, then DDI rows):
#   *_melanoma:   1 = melanoma, 0 = the other kept (look-alike) classes
#   *_malignancy: 1 = malignant, 0 = otherwise (all images)
y_f17k_melanoma = list(map(int, metadata_all_f17k_filtered["nine_partition_label"] == "malignant melanoma"))
y_f17k_all_melanoma = list(map(int, metadata_all_f17k_all_filtered["nine_partition_label"] == "malignant melanoma"))
y_f17k_malignancy = list(map(int, metadata_all_f17k["three_partition_label"] == "malignant"))
y_f17k_all_malignancy = list(map(int, metadata_all_f17k_all["three_partition_label"] == "malignant"))
# DDI disease names are first collapsed to "melanoma" / "melanoma look-alike"
# via ddi_map (a KeyError here means the filter and ddi_map disagree).
y_ddi_melanoma = list(map(int, metadata_all_ddi_filtered["disease"].map(lambda x: ddi_map[x]) == "melanoma"))
y_ddi_malignancy = list(map(int, metadata_all_ddi["malignant"]))

y_melanoma = y_f17k_melanoma + y_ddi_melanoma
y_melanoma_all = y_f17k_all_melanoma + y_ddi_melanoma
y_malignancy = y_f17k_malignancy + y_ddi_malignancy
y_malignancy_all = y_f17k_all_malignancy + y_ddi_malignancy

# Metadata stacked in the same order as the labels and features above.
metadata_all_melanoma = pd.concat([metadata_all_f17k_filtered, metadata_all_ddi_filtered], axis=0)
metadata_all_melanoma_all = pd.concat([metadata_all_f17k_all_filtered, metadata_all_ddi_filtered], axis=0)
metadata_all_malignancy = pd.concat([metadata_all_f17k, metadata_all_ddi], axis=0)
metadata_all_malignancy_all = pd.concat([metadata_all_f17k_all, metadata_all_ddi], axis=0)

In [None]:
# Class balance of the Fitzpatrick17k melanoma subset: (#label 1, #label 0).
np.array(y_f17k_melanoma).sum(), (1-np.array(y_f17k_melanoma)).sum()

In [None]:
# Number of images in that subset.
len(np.array(y_f17k_melanoma))

In [None]:
# Class balance after appending the DDI melanoma subset.
np.array(y_melanoma).sum(), (1-np.array(y_melanoma)).sum()

In [None]:
# Inspect the combined melanoma-task metadata.
metadata_all_melanoma

In [None]:
def train_using_manual_labels(xtrain,
                              xtest,
                              ytrain,
                              ytest,
                              alpha=0.001):
    """Fit an L1-regularised logistic classifier and score it on held-out data.

    Args:
        xtrain, xtest: Feature matrices for the train / test split.
        ytrain, ytest: Binary labels for the train / test split.
        alpha: Regularisation strength passed to ``SGDClassifier``.

    Returns:
        ``(auc, clf, proba)``: test ROC AUC, the fitted classifier, and the
        predicted positive-class probabilities on ``xtest``.
    """
    # Logistic loss so that predict_proba is available; L1 penalty.
    clf_manual_labels = SGDClassifier(loss="log_loss", penalty="l1", alpha=alpha)
    clf_manual_labels.fit(xtrain, ytrain)
    # Predict once and reuse for both the AUC and the return value.
    proba_test = clf_manual_labels.predict_proba(xtest)[:, 1]
    auc = roc_auc_score(ytest, proba_test)
    return auc, clf_manual_labels, proba_test
    # accuracy_scores_f17k_test_set.append(auc)


#auc, clf_manual_labels = train_using_manual_labels(test_dataloader_skincon, skincon_cols)

In [None]:
def get_similarity_score(image_features, text_features_dict):
    """Score every image against each group of text embeddings.

    Image and text embeddings are L2-normalised along dim 1 (in their own
    dtype) and cast to float32, so each raw score is a cosine similarity.

    Args:
        image_features: 2-D tensor with one image embedding per row.
        text_features_dict: Mapping from any key to a 2-D tensor with one text
            (prompt) embedding per row.

    Returns:
        dict with the same keys. Each value is a 1-D numpy array with one score
        per image:
        - with a single prompt, the score is the cosine similarity itself;
        - with several prompts, each prompt's column is softmax-normalised
          across the images (dim 0) and then averaged over the prompts.
    """

    def _l2_normalize(t):
        return t / t.norm(dim=1, keepdim=True)

    images_unit = _l2_normalize(image_features).float()

    scores = {}
    for key, prompts in text_features_dict.items():
        # (n_images, n_prompts)
        sims = images_unit @ _l2_normalize(prompts).T.float()
        if sims.shape[1] > 1:
            ensembled = sims.softmax(dim=0).mean(dim=1).numpy()
        else:
            ensembled = sims[:, 0].numpy()
        assert ensembled.ndim == 1
        scores[key] = ensembled
    return scores

In [None]:
from collections import OrderedDict
def tune_best_temp_for_concepts(model, image_features, y, concept_list, train_idx, test_idx, alpha=0.001,
                                temperatures=(5, 2, 1, 0.5, 0.2, 0.1, 0.05, 0.02, 0.01, 0.005, 0.002, 0.001)):
    """Pick the softmax temperature that gives the best held-out AUC.

    For each concept, every value in ``concept_dict[concept]`` (a global defined
    outside this function) becomes the prompt "This is photo of <value>" and
    is scored against all images by cosine similarity. At a given temperature,
    a concept's feature is the softmax probability of its first prompt over
    all of its prompts. For each candidate temperature, an L1-regularised
    logistic ``SGDClassifier`` is fit on the ``train_idx`` rows and scored by
    ROC AUC on the ``test_idx`` rows.

    Args:
        model: Model exposing ``model_step_with_text``.
        image_features: 2-D tensor, one image embedding per row.
        y: Binary labels aligned with the rows of ``image_features``.
        concept_list: Concept names; keys into ``concept_dict``.
        train_idx, test_idx: Row indices used for fitting / scoring.
        alpha: Regularisation strength for ``SGDClassifier``.
        temperatures: Candidate temperatures, tried in order; on ties the
            earliest wins. Defaults to the previously hard-coded grid.

    Returns:
        ``(best_auc, best_temperature, best_clf)``. ``best_temperature`` and
        ``best_clf`` stay ``None`` when no temperature reaches an AUC above 0.
    """
    image_features_all_norm = image_features / image_features.norm(dim=1, keepdim=True)

    # concept -> (n_images, n_prompts) cosine similarities; column 0 is the
    # concept's own prompt.
    # NOTE(review): a concept with a single entry yields one column, so its
    # softmax feature is constant 1.0 at every temperature.
    x_dict = OrderedDict()
    for concept in concept_list:
        similarity_list = []
        for concept_value in concept_dict[concept]:
            prompt = f"This is photo of {concept_value}"
            with torch.no_grad():
                output = model.model_step_with_text(
                    {"text": clip.tokenize(prompt).to(model_device)}
                )
                similarity = get_similarity_score(
                    image_features=image_features_all_norm,
                    text_features_dict={0: output["text_features"].detach().cpu()},
                )[0]
            similarity_list.append(similarity)
        x_dict[concept] = np.array(similarity_list).T

    # Labels do not depend on the temperature; split them once.
    y_arr = np.array(y)
    ytrain = y_arr[train_idx]
    ytest = y_arr[test_idx]

    best_auc = 0
    best_temperature = None
    best_clf = None
    for temperature in temperatures:
        # Probability of prompt 0 after a temperature-scaled softmax over each concept's prompts.
        x_softmax = np.array(
            [softmax(x_dict[concept] / temperature, axis=1)[:, 0] for concept in x_dict]
        ).T

        clf = SGDClassifier(loss="log_loss", penalty="l1", alpha=alpha)
        clf.fit(x_softmax[train_idx], ytrain)
        auc = roc_auc_score(ytest, clf.predict_proba(x_softmax[test_idx])[:, 1])

        # Strict ">" keeps the earliest temperature on ties.
        if auc > best_auc:
            best_auc = auc
            best_temperature = temperature
            best_clf = clf
    return best_auc, best_temperature, best_clf

In [None]:
def train_with_best_temp_softmax(model, image_features_train, image_features_test, ytrain, ytest, concept_list,
                                 temp, num_ref_concepts=5, use_template_as_reference=True, alpha=0.001,
                                 rng=None):
    """Fit a concept-score classifier at a fixed softmax temperature.

    For each concept, a small prompt set is scored against the train and test
    images by cosine similarity:

    * ``use_template_as_reference=True``: the first entry of
      ``concept_dict[concept]`` (a global defined outside this function), plus
      up to ``num_ref_concepts`` other entries sampled without replacement.
      Each becomes "This is photo of <value>".
    * ``use_template_as_reference=False``: "This is photo of <first entry>"
      and the bare prompt "This is photo".

    If ``num_ref_concepts > 0``, a concept's feature is the softmax probability
    of prompt 0 over all its prompts at temperature ``temp``. Otherwise it is
    the raw prompt-0 similarity divided by ``temp``. An L1-regularised logistic
    ``SGDClassifier`` is then fit on the train features.

    Args:
        model: Model exposing ``model_step_with_text``.
        image_features_train, image_features_test: 2-D tensors, one image
            embedding per row.
        ytrain, ytest: Binary labels for the train / test rows.
        concept_list: Concept names; keys into ``concept_dict``.
        temp: Softmax temperature (or divisor when ``num_ref_concepts <= 0``).
        num_ref_concepts: Maximum number of reference prompts sampled per concept.
        use_template_as_reference: Selects the prompt set (see above).
        alpha: Regularisation strength for ``SGDClassifier``.
        rng: Optional ``numpy.random.Generator`` / ``RandomState`` used to sample
            the reference prompts. ``None`` keeps using the global ``np.random`` state.

    Returns:
        ``(clf, auc, proba)``: fitted classifier, test ROC AUC, and the test
        positive-class probabilities.
    """
    image_features_train_norm = image_features_train / image_features_train.norm(dim=1, keepdim=True)
    image_features_test_norm = image_features_test / image_features_test.norm(dim=1, keepdim=True)
    sampler = np.random if rng is None else rng

    x_dict_train = {}
    x_dict_test = {}
    for concept in concept_list:
        values = concept_dict[concept]
        if use_template_as_reference:
            references = values[1:]
            concept_sampled = values[:1] + sampler.choice(
                a=references, size=min(num_ref_concepts, len(references)), replace=False
            ).tolist()
            prompt_list = [f"This is photo of {concept_value}" for concept_value in concept_sampled]
        else:
            prompt_list = [f"This is photo of {values[0]}", "This is photo"]

        similarity_list_train = []
        similarity_list_test = []
        for prompt in prompt_list:
            with torch.no_grad():
                output = model.model_step_with_text(
                    {"text": clip.tokenize(prompt).to(model_device)}
                )
                text_features = {0: output["text_features"].detach().cpu()}
                similarity_list_train.append(
                    get_similarity_score(
                        image_features=image_features_train_norm,
                        text_features_dict=text_features,
                    )[0]
                )
                similarity_list_test.append(
                    get_similarity_score(
                        image_features=image_features_test_norm,
                        text_features_dict=text_features,
                    )[0]
                )

        # (n_images, n_prompts); column 0 is the concept's own prompt.
        x_dict_train[concept] = np.array(similarity_list_train).T
        x_dict_test[concept] = np.array(similarity_list_test).T

    def _features(x_dict):
        # One column per concept, in concept_list order.
        if num_ref_concepts > 0:
            # NOTE(review): a concept with only one prompt gets a constant 1.0 here.
            return np.array([softmax(x_dict[c] / temp, axis=1)[:, 0] for c in x_dict]).T
        # NOTE(review): this branch ignores any reference prompt, including
        # "This is photo" when use_template_as_reference=False.
        return np.array([(x_dict[c] / temp)[:, 0] for c in x_dict]).T

    x_softmax_train = _features(x_dict_train)
    x_softmax_test = _features(x_dict_test)

    clf = SGDClassifier(loss="log_loss", penalty="l1", alpha=alpha)
    clf.fit(x_softmax_train, ytrain)
    # Predict once and reuse for both the AUC and the return value.
    proba_test = clf.predict_proba(x_softmax_test)[:, 1]
    auc = roc_auc_score(ytest, proba_test)
    return clf, auc, proba_test

In [None]:
def return_train_with_best_temp_softmax(model, image_features_train, image_features_test, ytrain, ytest,
                                        concept_list, temp, num_ref_concepts=5, use_template_as_reference=True,
                                        alpha=0.001):
    """Diagnostic variant of the concept-score classifier that prints and plots.

    Builds the prompt-similarity features as follows:
    * ``use_template_as_reference=True``: the first entry of
      ``concept_dict[concept]`` plus up to ``num_ref_concepts`` sampled others.
    * ``use_template_as_reference=False``: the first entry and the bare
      prompt "This is photo".

    Each prompt used is printed. The temperature softmax over each concept's
    prompts is always applied (prompt 0's probability is kept). Three
    L1-regularised logistic ``SGDClassifier`` models are then fit, and the
    coefficients and test ROC AUC of each are printed:

    1. on the softmax features as-is,
    2. on features divided by their per-concept train maximum,
    3. on features divided by their per-concept train standard deviation.

    Finally, one figure per concept shows histograms of the train features
    under each transformation.

    Returns:
        ``(clf, clf_to1, clf_std)``. ``clf_to1`` / ``clf_std`` expect inputs
        rescaled with the train max / std computed here.
    """
    image_features_train_norm = image_features_train / image_features_train.norm(dim=1, keepdim=True)
    image_features_test_norm = image_features_test / image_features_test.norm(dim=1, keepdim=True)

    x_dict_train = {}
    x_dict_test = {}
    for concept in concept_list:
        similarity_list_train = []
        similarity_list_test = []

        if use_template_as_reference:
            references = concept_dict[concept][1:]
            concept_sampled = concept_dict[concept][:1] + np.random.choice(
                a=references, size=min(num_ref_concepts, len(references)), replace=False
            ).tolist()
            prompt_list = [f"This is photo of {concept_value}" for concept_value in concept_sampled]
        else:
            prompt_list = [f"This is photo of {concept_dict[concept][0]}", "This is photo"]

        for prompt in prompt_list:
            print("prompt:", prompt)
            with torch.no_grad():
                output = model.model_step_with_text(
                    {"text": clip.tokenize(prompt).to(model_device)}
                )
                similarity_train = get_similarity_score(
                    image_features=image_features_train_norm,
                    text_features_dict={0: output["text_features"].detach().cpu()},
                )[0]
                similarity_test = get_similarity_score(
                    image_features=image_features_test_norm,
                    text_features_dict={0: output["text_features"].detach().cpu()},
                )[0]
            similarity_list_train.append(similarity_train)
            similarity_list_test.append(similarity_test)

        # (n_images, n_prompts); column 0 is the concept's own prompt.
        x_dict_train[concept] = np.array(similarity_list_train).T
        x_dict_test[concept] = np.array(similarity_list_test).T

    # Probability of prompt 0 after a temperature-scaled softmax over each
    # concept's prompts; one column per concept.
    x_softmax_train = np.array(
        [softmax(x_dict_train[concept] / temp, axis=1)[:, 0] for concept in x_dict_train]
    ).T
    x_softmax_test = np.array(
        [softmax(x_dict_test[concept] / temp, axis=1)[:, 0] for concept in x_dict_test]
    ).T

    # 1) Unscaled softmax features.
    clf = SGDClassifier(loss="log_loss", penalty="l1", alpha=alpha)
    clf.fit(x_softmax_train, ytrain)
    auc = roc_auc_score(ytest, clf.predict_proba(x_softmax_test)[:, 1])
    print(list(zip(concept_list, clf.coef_[0, :])), clf.intercept_)
    print(auc)

    # 2) Divide by the per-concept train maximum (train max becomes 1);
    #    the test set is rescaled with the same train statistic.
    train_max = x_softmax_train.max(axis=0, keepdims=True)
    x_softmax_train_max1 = x_softmax_train / train_max
    clf_to1 = SGDClassifier(loss="log_loss", penalty="l1", alpha=alpha)
    clf_to1.fit(x_softmax_train_max1, ytrain)
    auc = roc_auc_score(ytest, clf_to1.predict_proba(x_softmax_test / train_max)[:, 1])
    print(list(zip(concept_list, clf_to1.coef_[0, :])), clf_to1.intercept_)
    print(auc)

    # 3) Divide by the per-concept train standard deviation.
    train_std = x_softmax_train.std(axis=0, keepdims=True)
    x_softmax_train_std = x_softmax_train / train_std
    clf_std = SGDClassifier(loss="log_loss", penalty="l1", alpha=alpha)
    clf_std.fit(x_softmax_train_std, ytrain)
    print(x_softmax_train_std.max(axis=0))
    auc = roc_auc_score(ytest, clf_std.predict_proba(x_softmax_test / train_std)[:, 1])
    print(list(zip(concept_list, clf_std.coef_[0, :])), clf_std.intercept_)
    print(auc)

    # Extra views for the histograms: raw prompt-0 cosine similarity, and the
    # softmax at temperature 1.
    x_nosoftmax_train = np.array([x_dict_train[concept][:, 0] for concept in x_dict_train]).T
    x_softmax_train_notemp = np.array(
        [softmax(x_dict_train[concept], axis=1)[:, 0] for concept in x_dict_train]
    ).T

    panels = [
        (x_nosoftmax_train, "Original (cosine sim)"),
        (x_softmax_train_notemp, "softmax (temp=1)"),
        # Label reflects the actual temperature (was hard-coded to 0.02).
        (x_softmax_train, f"softmax (temp={temp})"),
        (x_softmax_train_max1, "softmax and max to 1"),
        (x_softmax_train_std, "softmax and divide by std"),
    ]
    for j, concept in enumerate(concept_list):
        fig = plt.figure(figsize=(16, 3))
        axes = fig.subplots(1, len(panels))
        for ax, (values, label) in zip(axes, panels):
            column = values[:, j]
            ax.hist(column)
            ax.set_title(f"{label}\nmean: {column.mean():.4f} std: {column.std():.4f}")
        fig.suptitle(concept, y=1.1)

    return clf, clf_to1, clf_std

In [None]:
def _reference_contrast_similarities(model, concept, image_features_train_norm, image_features_test_norm, num_ref_concepts):
    """Score train/test images against *concept* and against a pooled set of its reference terms.

    ``concept_dict[concept][0]`` is the concept term itself; up to ``num_ref_concepts`` of the
    remaining entries are sampled without replacement (this consumes the global NumPy RNG) and
    their text embeddings are averaged into a single reference embedding.

    Returns:
        ``(train, test)``, each ``np.array([sim_to_concept, sim_to_reference_mean]).T``.
        NOTE(review): assumes ``get_similarity_score(...)[0]`` is a 1-D per-image score vector,
        which would make each array (n_images, 2) — confirm.
    """
    terms = concept_dict[concept]
    references = terms[1:]
    sampled_references = np.random.choice(
        a=references, size=min(num_ref_concepts, len(references)), replace=False
    ).tolist()
    prompt_list = [f"This is photo of {concept_value}" for concept_value in terms[:1] + sampled_references]
    print("prompt:", prompt_list)

    with torch.no_grad():
        output = model.model_step_with_text({"text": clip.tokenize(prompt_list).to(model_device)})
    text_features = output["text_features"]
    # Row 0 is the concept prompt; all remaining rows are pooled into one reference embedding.
    anchors = [text_features[0:1], text_features[1:].mean(axis=0, keepdims=True)]

    train_scores = []
    test_scores = []
    for features in anchors:
        features_cpu = features.detach().cpu()
        train_scores.append(
            get_similarity_score(image_features=image_features_train_norm, text_features_dict={0: features_cpu})[0]
        )
        test_scores.append(
            get_similarity_score(image_features=image_features_test_norm, text_features_dict={0: features_cpu})[0]
        )
    return np.array(train_scores).T, np.array(test_scores).T


def _fit_sgd_and_report(xtrain, xtest, ytrain, ytest, concept_list):
    """Fit an L1-penalised logistic SGDClassifier, print its coefficients and test ROC-AUC, and return it."""
    clf = SGDClassifier(loss="log_loss", penalty="l1", alpha=0.001)
    clf.fit(xtrain, ytrain)
    auc = roc_auc_score(ytest, clf.predict_proba(xtest)[:, 1])
    print(list(zip(concept_list, clf.coef_[0, :])), clf.intercept_)
    print(auc)
    return clf


def _plot_concept_score_histograms(x_dict_train, x_softmax_train, concept_list, temp):
    """Draw, per concept, five train-split histograms comparing score normalisations."""
    concepts = list(x_dict_train)
    raw_similarity = np.array([x_dict_train[c][:, 0] for c in concepts]).T
    softmax_temp1 = np.array([softmax(x_dict_train[c], axis=1)[:, 0] for c in concepts]).T
    panels = [
        ("Original (cosine sim)", raw_similarity),
        ("softmax (temp=1)", softmax_temp1),
        # Label with the temperature actually used rather than a hard-coded value.
        (f"softmax (temp={temp})", x_softmax_train),
        ("softmax and max to 1", x_softmax_train / x_softmax_train.max(axis=0, keepdims=True)),
        ("softmax and divide by std", x_softmax_train / x_softmax_train.std(axis=0, keepdims=True)),
    ]
    for j, concept in enumerate(concept_list):
        fig = plt.figure(figsize=(16, 3))
        axes = fig.subplots(1, len(panels))
        for ax, (label, scores) in zip(axes, panels):
            column = scores[:, j]
            ax.hist(column)
            ax.set_title(f"{label}\nmean: {column.mean():.4f} std: {column.std():.4f}")
        fig.suptitle(concept, y=1.1)


def return_train_with_best_temp_softmax_(model, image_features_train, image_features_test, ytrain, ytest, concept_list, temp, num_ref_concepts=5, use_template_as_reference=True):
    """Fit L1 logistic classifiers on softmax concept scores under three normalisations and plot them.

    For each concept, image features are compared with the concept prompt and with the mean
    embedding of randomly sampled reference terms from the notebook-global ``concept_dict``.
    The concept score is the temperature-``temp`` softmax probability on the concept prompt.
    Three ``SGDClassifier`` models are then trained on the train scores:
    - raw;
    - divided by the per-concept train max;
    - divided by the per-concept train std.
    Each prints its coefficients and test ROC-AUC. The test split reuses the train scale.

    Args:
        model: object exposing ``model_step_with_text`` (uses globals ``clip``, ``model_device``,
            ``get_similarity_score``).
        image_features_train / image_features_test: image embeddings; L2-normalised along dim 1 here.
        ytrain / ytest: binary labels.
        concept_list: keys of ``concept_dict`` to use as features.
        temp: softmax temperature.
        num_ref_concepts: maximum number of reference terms sampled per concept.
        use_template_as_reference: unused; kept for signature compatibility.

    Returns:
        ``(clf, clf_to1, clf_std)``.
    """
    image_features_train_norm = image_features_train / image_features_train.norm(dim=1, keepdim=True)
    image_features_test_norm = image_features_test / image_features_test.norm(dim=1, keepdim=True)

    x_dict_train = {}
    x_dict_test = {}
    for concept in concept_list:
        x_dict_train[concept], x_dict_test[concept] = _reference_contrast_similarities(
            model, concept, image_features_train_norm, image_features_test_norm, num_ref_concepts
        )

    # Probability mass on the concept prompt (column 0) after a temperature-scaled softmax.
    x_softmax_train = np.array([softmax(x_dict_train[c] / temp, axis=1)[:, 0] for c in x_dict_train]).T
    x_softmax_test = np.array([softmax(x_dict_test[c] / temp, axis=1)[:, 0] for c in x_dict_test]).T

    # 1) Raw softmax scores.
    clf = _fit_sgd_and_report(x_softmax_train, x_softmax_test, ytrain, ytest, concept_list)

    # 2) Scores divided by the per-concept training maximum.
    train_max = x_softmax_train.max(axis=0, keepdims=True)
    clf_to1 = _fit_sgd_and_report(x_softmax_train / train_max, x_softmax_test / train_max, ytrain, ytest, concept_list)

    # 3) Scores divided by the per-concept training standard deviation.
    train_std = x_softmax_train.std(axis=0, keepdims=True)
    print((x_softmax_train / train_std).max(axis=0))
    clf_std = _fit_sgd_and_report(x_softmax_train / train_std, x_softmax_test / train_std, ytrain, ytest, concept_list)

    _plot_concept_score_histograms(x_dict_train, x_softmax_train, concept_list, temp)
    return clf, clf_to1, clf_std

In [None]:
# Run return_train_with_best_temp_softmax (defined earlier in the notebook) with temp=0.02
# and up to 100 reference terms per concept; the return value is only displayed, not stored.
return_train_with_best_temp_softmax(model=model_select,
                                 image_features_train=image_select_train,                           
                                 image_features_test=image_select_test, 
                                 ytrain=y_select_train, 
                                 ytest=y_select_test,
                                 concept_list=concept_list_target, 
                                 temp=0.02,
                                 num_ref_concepts=100) 

In [None]:
# Same experiment with a sharper temperature (0.01) and a weaker penalty (alpha=0.0001).
# NOTE(review): the visible tail of return_train_with_best_temp_softmax does not show an
# `alpha` parameter — confirm its signature accepts it, otherwise this raises TypeError.
return_train_with_best_temp_softmax(model=model_select,
                                 image_features_train=image_select_train,                           
                                 image_features_test=image_select_test, 
                                 ytrain=y_select_train, 
                                 ytest=y_select_test,
                                 concept_list=concept_list_target, 
                                 temp=0.01,
                                 num_ref_concepts=100, alpha=0.0001) 

In [None]:
# Train the three score-normalisation variants (raw, max-scaled, std-scaled) and keep all of them.
clf, clf_to1, clf_std = return_train_with_best_temp_softmax_(
    model=model_select,
    image_features_train=image_select_train,
    image_features_test=image_select_test,
    ytrain=y_select_train,
    ytest=y_select_test,
    concept_list=concept_list_target,
    temp=0.02,
    num_ref_concepts=100,
)

In [None]:
# Generic "healthy skin" terms.
normal_skin = ["clean", "smooth", "Healthy", "normal", "soft", "flat"]

# Hand-curated concepts. Element 0 of each list is the concept term; the remaining
# elements are the contrasting reference terms sampled when scoring that concept.
concept_dict = {
    "Asymmetry": ["Asymmetry", "Symmetry", "Regular", "Uniform"],
    "Irregular": ["Irregular", "Regular", "Smooth"],
    "Black": ["Black", "White", "Creamy", "Colorless", "Unpigmented"],
    "Blue": ["Blue", "Green", "Red"],
    "White": ["White", "Black", "Colored", "Pigmented"],
    "Brown": ["Brown", "Pale", "White"],
    "Erosion": ["Erosion", "Deposition", "Buildup"],
    "Multiple Colors": ["Multiple colors", "Single Color", "Unicolor"],
    "Tiny": ["Tiny", "Large", "Big"],
    "Regular": ["Regular", "Irregular"],
}

for concept_key, concept_terms in concept_dict.items():
    print(f"{concept_key}: {concept_terms}")

In [None]:
# Hand-picked contrasting ("negative") reference terms for specific SkinCon concepts;
# every other concept falls back to generic healthy-skin terms.
skincon_negative_terms = {
    "skincon_Patch": ["Spotted"],
    "skincon_Exudate": ["Absence"],
    "skincon_Xerosis": ["Moisturized"],
    "skincon_Warty/Papillomatous": ["Smooth"],
    "skincon_Dome-shaped": ["Flat"],
    "skincon_Brown(Hyperpigmentation)": ["Hypopigmentation"],
    "skincon_Translucent": ["Opaque"],
    "skincon_White(Hypopigmentation)": ["Hyperpigmentation"],
    "skincon_Purple": ["Yellow"],
    "skincon_Yellow": ["Purple"],
    "skincon_Black": ["White", "Creamy", "Colorless", "Unpigmented"],
    "skincon_Lichenification": ["Softening"],
    "skincon_Blue": ["Orange"],
    "skincon_Gray": ["Colorful"],
}
default_negative_terms = ["clean", "smooth", "Healthy", "normal", "soft", "flat"]
# Template fragments removed, in this order, to reduce an engineered prompt to a bare term.
prompt_template_fragments = ("This is ", "This photo is ", "This lesion is ", "skin has become ")


def prompt_to_concept_term(prompt):
    """Remove every template fragment from *prompt* (substring replace, in order) and lowercase it."""
    for fragment in prompt_template_fragments:
        prompt = prompt.replace(fragment, "")
    return prompt.lower()


concept_dict_temp = {}
for concept_name in skincon_cols:
    # Drop the 8-character "skincon_" prefix before asking for the engineered prompts.
    prompt_dict, text_counter = concept_to_prompt(concept_name[8:])
    prompt_engineered_list = [
        prompt for key, prompts in prompt_dict.items() if key != "original" for prompt in prompts
    ]
    # Deduplicate while keeping first-seen order. A set would make the chosen positive
    # term (element 0) depend on per-process string hashing and vary between runs.
    concept_term_list = list(dict.fromkeys(prompt_to_concept_term(p) for p in prompt_engineered_list))
    negative_terms = list(skincon_negative_terms.get(concept_name, default_negative_terms))
    concept_dict_temp[concept_name] = [concept_term_list[0]] + negative_terms
concept_dict.update(concept_dict_temp)

In [None]:
# The ten hand-curated concepts (same order as the keys of the curated concept_dict cell).
concept_list_curated = [
    "Asymmetry",
    "Irregular",
    "Black",
    "Blue",
    "White",
    "Brown",
    "Erosion",
    "Multiple Colors",
    "Tiny",
    "Regular",
]

In [None]:
# Show current GPU utilisation (IPython shell escape, not plain Python).
!gpustat

In [None]:
# Image roots for the two datasets.
f17k_data_dir = "/sdata/chanwkim/dermatology_datasets/fitzpatrick17k/final_image"
ddi_data_dir = "/sdata/chanwkim/dermatology_datasets/ddi/final_image"

# Full per-image paths built from each metadata frame's "index" column.
image_paths_f17k = ["/".join((f17k_data_dir, image_id)) for image_id in metadata_all_f17k["index"]]
image_paths_ddi = ["/".join((ddi_data_dir, image_id)) for image_id in metadata_all_ddi["index"]]
image_paths_all = [*image_paths_f17k, *image_paths_ddi]

image_paths_f17k_filtered = ["/".join((f17k_data_dir, image_id)) for image_id in metadata_all_f17k_filtered["index"]]
image_paths_ddi_filtered = ["/".join((ddi_data_dir, image_id)) for image_id in metadata_all_ddi_filtered["index"]]
image_paths_all_filtered = [*image_paths_f17k_filtered, *image_paths_ddi_filtered]

IMAGE_SIZE = 224

# Per-channel (mean, std) consumed by ResNetDataset's Normalize transform.
# NOTE(review): these look like CIFAR-10 statistics rather than the ImageNet statistics
# an ImageNet-pretrained ResNet-50 expects — confirm this is intended.
norm_constants = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Merge the DDI and Fitzpatrick17k path maps (DDI first; a later duplicate key overwrites the value).
image_path_dict_all = OrderedDict()
image_path_dict_all.update(test_dataloader_ddi.dataset.image_path_dict)
image_path_dict_all.update(test_dataloader_f17k_all.dataset.image_path_dict)

In [None]:
def convert_image_to_rgb(image):
    """Return ``image.convert("RGB")`` so every sample has three channels."""
    return image.convert("RGB")


class ResNetDataset(torch.utils.data.Dataset):
    """Dataset of image files yielding ``(normalised_tensor, label)`` pairs for the ResNet baseline.

    Args:
        image_paths_list: image file paths, aligned index-for-index with ``labels``.
        labels: per-image labels; ``labels[idx]`` is returned unchanged.
        train: when True (default, previous behaviour) apply random resized-crop and
            horizontal-flip augmentation; when False use a deterministic resize + center
            crop, suitable for validation/test.

    NOTE(review): generate_data_and_train_resnet builds its validation and test sets with
    the default ``train=True``, so evaluation images are randomly augmented — consider
    passing ``train=False`` there.
    """

    def __init__(self, image_paths_list, labels, train=True):
        assert len(image_paths_list) == len(labels)
        self.image_paths_list = image_paths_list
        if train:
            spatial = [transforms.RandomResizedCrop(IMAGE_SIZE), transforms.RandomHorizontalFlip()]
        else:
            spatial = [transforms.Resize(IMAGE_SIZE), transforms.CenterCrop(IMAGE_SIZE)]
        self.transforms = transforms.Compose(
            spatial
            + [
                convert_image_to_rgb,
                transforms.ToTensor(),
                transforms.Normalize(*norm_constants),
            ]
        )
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Open in a context manager so the file handle is released as soon as the image
        # has been decoded, instead of lingering until garbage collection (relevant with
        # several DataLoader workers and per-process open-file limits).
        with Image.open(self.image_paths_list[idx]) as image:
            tensor = self.transforms(image)
        return tensor, self.labels[idx]

In [None]:
import torchvision
from torch import nn
from torch.nn import functional as F
from torchmetrics import AUROC

class Classifier(nn.Module):
    """ImageNet-pretrained ResNet-50 trunk followed by a single linear layer.

    Args:
        output_dim: number of output logits produced by the linear head.
        freeze_backbone: when True, backbone parameters get ``requires_grad=False`` so only
            the head is trained; otherwise every backbone parameter is trainable.
    """

    def __init__(self, output_dim, freeze_backbone=False):
        super().__init__()
        self.backbone = torchvision.models.resnet50(weights="ResNet50_Weights.IMAGENET1K_V1")

        trainable = not freeze_backbone
        for param in self.backbone.parameters():
            param.requires_grad = trainable

        # Replace the ImageNet classification layer with a pass-through so the backbone
        # emits pooled features, then attach our own head on top.
        feature_dim = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        self.head = nn.Linear(feature_dim, output_dim)

    def forward(self, x):
        return self.head(self.backbone(x))


class EarlyStopper:
    """Patience-based early stopping on a validation loss.

    Call ``early_stop`` once per epoch. A new minimum resets the patience counter.
    A loss more than ``min_delta`` above the best seen so far uses up one unit of patience.
    Anything in between leaves the counter untouched.
    """

    def __init__(self, patience=1, min_delta=0):
        # Number of regressing epochs tolerated before stopping.
        self.patience = patience
        # Slack above the best loss that is not counted as a regression.
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        """Record *validation_loss*; return True once patience is exhausted."""
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss, self.counter = validation_loss, 0
            return False
        # Written as `not (... > ...)` so a NaN loss neither resets nor consumes patience.
        if not validation_loss > self.min_validation_loss + self.min_delta:
            return False
        self.counter += 1
        return self.counter >= self.patience
        
def _maybe_progress(dataloader, verbose):
    """Wrap *dataloader* in a tqdm progress bar when *verbose* is set."""
    return tqdm.tqdm(dataloader) if verbose else dataloader


def _evaluate_classifier(classifier, dataloader, metric, device, verbose):
    """Run *classifier* over *dataloader* without gradients.

    Updates *metric* in place with per-sample logits and returns
    ``(summed_bce_loss, logits)``, where ``logits`` is a list of one-element lists in loader order.
    """
    classifier.eval()
    total_loss = 0.0
    preds = []
    with torch.no_grad():
        for batch in _maybe_progress(dataloader, verbose):
            image, label = batch[0].to(device), batch[1].to(device)
            logits = classifier(image)
            target = label == 1
            loss = F.binary_cross_entropy_with_logits(input=logits[:, 0], target=target.float())
            total_loss += loss.item() * image.size(0)
            metric.update(logits[:, 0], target)
            preds += logits.detach().cpu().numpy().tolist()
    return total_loss, preds


def train_classifier(train_dataloader, val_dataloader, test_dataloader, freeze_backbone, lr, verbose, device="cuda:6", max_epochs=50):
    """Train a binary ResNet-50 ``Classifier`` and evaluate the final model on the test split.

    Uses Adam + ReduceLROnPlateau on the summed validation loss and stops early after 5 epochs
    without improvement (``EarlyStopper``). The model from the last completed epoch is the one
    evaluated; no best-checkpoint restoring is done.

    Args:
        train_dataloader, val_dataloader, test_dataloader: loaders yielding ``(image, label)``
            batches; ``label == 1`` is the positive class.
        freeze_backbone: if True only the linear head is trained.
        lr: Adam learning rate.
        verbose: show tqdm progress bars and per-epoch metrics.
        device: torch device to train on (default keeps the previously hard-coded GPU).
        max_epochs: upper bound on training epochs.

    Returns:
        ``(classifier, test_auroc, test_preds)``: ``test_auroc`` is the tensor returned by
        torchmetrics, and ``test_preds`` holds raw logits as one-element lists.
    """
    classifier = Classifier(output_dim=1, freeze_backbone=freeze_backbone)
    classifier.to(device)
    optimizer = torch.optim.Adam(classifier.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=2, verbose=True)
    early_stopper = EarlyStopper(patience=5, min_delta=0)

    train_auroc = AUROC(task="binary")
    val_auroc = AUROC(task="binary")
    for epoch in range(max_epochs):
        classifier.train()
        train_loss = 0.0
        for batch in _maybe_progress(train_dataloader, verbose):
            image, label = batch[0].to(device), batch[1].to(device)
            logits = classifier(image)
            target = label == 1
            loss = F.binary_cross_entropy_with_logits(input=logits[:, 0], target=target.float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * image.size(0)
            # Binary AUROC expects preds shaped like the (B,) target; detach so the
            # accumulated metric state does not hold on to the autograd graph.
            train_auroc.update(logits[:, 0].detach(), target)

        val_loss, _ = _evaluate_classifier(classifier, val_dataloader, val_auroc, device, verbose)
        if verbose:
            print(
                f"Epoch {epoch}: Train loss: {train_loss/len(train_dataloader.dataset):.3f} AUROC: {train_auroc.compute():.3f} Val loss: {val_loss/len(val_dataloader.dataset):.3f} AUROC: {val_auroc.compute():.3f}"
            )

        # The summed (not averaged) loss is fine here: the validation set size is constant.
        scheduler.step(val_loss)
        if early_stopper.early_stop(val_loss):
            if verbose:
                print("break")
            break
        train_auroc.reset()
        val_auroc.reset()

    test_auroc = AUROC(task="binary")
    test_loss, test_preds = _evaluate_classifier(classifier, test_dataloader, test_auroc, device, verbose)
    if verbose:
        print(
            f"Test loss: {test_loss/len(test_dataloader.dataset):.3f} AUROC: {test_auroc.compute():.3f}"
        )

    return classifier, test_auroc.compute(), test_preds

def _make_resnet_loader(metadata_select, y_select, idx, shuffle):
    """Build a batch-32 DataLoader of ``ResNetDataset`` items for the rows at positions *idx*."""
    dataset = ResNetDataset(
        image_paths_list=[image_path_dict_all[key] for key in metadata_select["index"].iloc[idx]],
        labels=np.array(y_select)[idx],
    )
    return torch.utils.data.DataLoader(
        dataset, batch_size=32, shuffle=shuffle, pin_memory=True, drop_last=False, num_workers=4
    )


def generate_data_and_train_resnet(metadata_select, y_select, train_idx, test_idx, freeze_backbone=False, lr=1e-3, random_state=None):
    """Split *train_idx* 75/25 into train/validation, train a ResNet baseline, and score *test_idx*.

    Args:
        metadata_select: frame whose ``"index"`` column keys into ``image_path_dict_all``.
        y_select: labels aligned positionally with ``metadata_select``.
        train_idx, test_idx: positional row indices.
        freeze_backbone, lr: forwarded to ``train_classifier``.
        random_state: seed for the train/validation split. When None, falls back to the
            notebook-global ``random_seed``, which is the previous implicit behaviour.

    Returns:
        ``(test_auroc, test_preds)``: test ROC-AUC as a Python float and the raw test logits.
    """
    if random_state is None:
        random_state = random_seed
    train_idx_train, train_idx_valid = train_test_split(
        train_idx, random_state=random_state, test_size=0.25, shuffle=True
    )

    train_dataloader_resnet = _make_resnet_loader(metadata_select, y_select, train_idx_train, shuffle=True)
    # Evaluation loaders need no shuffling: loss sums and AUROC are order-independent.
    valid_dataloader_resnet = _make_resnet_loader(metadata_select, y_select, train_idx_valid, shuffle=False)
    test_dataloader_resnet = _make_resnet_loader(metadata_select, y_select, test_idx, shuffle=False)

    _, test_auc, test_preds = train_classifier(
        train_dataloader_resnet,
        valid_dataloader_resnet,
        test_dataloader_resnet,
        freeze_backbone=freeze_backbone,
        lr=lr,
        verbose=False,
    )
    return test_auc.item(), test_preds
    

In [None]:
# Show current GPU utilisation (IPython shell escape, not plain Python).
!gpustat

In [None]:
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import matplotlib.ticker as ticker

In [None]:
def calculate_zero_short(model, image_features_train, image_features_test, ytrain, ytest, temp, num_ref_concepts=5, use_template_as_reference=True, concept="melanoma"):
    """Score a zero-shot concept prompt directly and through a one-feature logistic fit.

    Each image is compared with the prompt ``"This is photo of {concept}"`` and with the bare
    template ``"This is photo"``. If ``num_ref_concepts > 0`` the feature is the softmax
    probability (temperature ``temp``) on the concept prompt versus the template. Otherwise the
    feature is the concept-prompt similarity divided by ``temp``. The function prints the dict
    keys, the feature shapes and the test ROC-AUC of the raw feature. It then fits an
    L1-penalised logistic ``SGDClassifier`` on the train feature.

    Args:
        model: object exposing ``model_step_with_text`` (uses globals ``clip``, ``model_device``,
            ``get_similarity_score``).
        image_features_train / image_features_test: image embeddings; L2-normalised along dim 1 here.
        ytrain / ytest: binary labels.
        temp: temperature / divisor applied to the similarities.
        num_ref_concepts: only its sign is used here; ``> 0`` selects the softmax feature.
        use_template_as_reference: unused; kept for signature compatibility.
        concept: concept name inserted into the prompt (defaults to the previously
            hard-coded ``"melanoma"``).

    Returns:
        Test ROC-AUC of the fitted classifier.
    """
    image_features_train_norm = image_features_train / image_features_train.norm(dim=1, keepdim=True)
    image_features_test_norm = image_features_test / image_features_test.norm(dim=1, keepdim=True)

    prompt_list = [f"This is photo of {concept}", "This is photo"]
    similarity_list_train = []
    similarity_list_test = []
    for prompt in prompt_list:
        with torch.no_grad():
            output = model.model_step_with_text({"text": clip.tokenize(prompt).to(model_device)})
            text_features = output["text_features"].detach().cpu()
            similarity_list_train.append(
                get_similarity_score(image_features=image_features_train_norm, text_features_dict={0: text_features})[0]
            )
            similarity_list_test.append(
                get_similarity_score(image_features=image_features_test_norm, text_features_dict={0: text_features})[0]
            )

    # Column 0: concept prompt, column 1: bare template.
    x_dict_train = {concept: np.array(similarity_list_train).T}
    x_dict_test = {concept: np.array(similarity_list_test).T}
    print(x_dict_train.keys())

    # `[:, [0]]` keeps a 2-D (n_images, 1) feature matrix for the classifier.
    if num_ref_concepts > 0:
        x_softmax_train = softmax(x_dict_train[concept] / temp, axis=1)[:, [0]]
        x_softmax_test = softmax(x_dict_test[concept] / temp, axis=1)[:, [0]]
    else:
        x_softmax_train = (x_dict_train[concept] / temp)[:, [0]]
        x_softmax_test = (x_dict_test[concept] / temp)[:, [0]]

    print(x_softmax_train.shape, x_softmax_test.shape, roc_auc_score(ytest, x_softmax_test[:, 0]))

    clf = SGDClassifier(loss="log_loss", penalty="l1", alpha=0.001)
    clf.fit(x_softmax_train, ytrain)

    auc = roc_auc_score(ytest, clf.predict_proba(x_softmax_test)[:, 1])
    return auc

In [None]:
# Display the positions of the Ulcer and Erosion columns within skincon_cols.
skincon_cols.index("skincon_Ulcer"), skincon_cols.index("skincon_Erosion")

In [None]:
# Mean of SkinCon column 9 over positive (y == 1) training rows.
# NOTE(review): index 9 is presumably one of the positions displayed by the cell above — confirm.
metadata_select_skincon_train[skincon_cols].values[:,9][y_select_train==1].mean()

In [None]:
# Mean of SkinCon column 9 over negative (y == 0) training rows.
metadata_select_skincon_train[skincon_cols].values[:,9][y_select_train==0].mean()

In [None]:
# Mean of SkinCon column 11 over positive (y == 1) training rows.
# NOTE(review): index 11 is presumably one of the positions displayed earlier — confirm.
metadata_select_skincon_train[skincon_cols].values[:,11][y_select_train==1].mean()

In [None]:
# Mean of SkinCon column 11 over negative (y == 0) training rows.
metadata_select_skincon_train[skincon_cols].values[:,11][y_select_train==0].mean()

In [None]:
# Display clf's coefficients labelled by SkinCon column name.
# NOTE(review): assumes clf was trained on features in skincon_cols order — confirm.
pd.Series(clf.coef_[0], index=skincon_cols)

In [None]:
# Display the SkinCon concept column names.
skincon_cols

In [None]:
# Keep only the records that do not carry an "alpha" hyper-parameter.
# NOTE(review): the previous version had an empty `if` body (a SyntaxError); skipping
# alpha records is what its `else` branch implies — confirm that was the intent.
record_all_list_new = [record for record in record_all_list if "alpha" not in record]

In [None]:
# Display the records produced with alpha == 0.0001.
[record for record in record_all_list if "alpha" in record and record["alpha"]==0.0001]

# test

In [None]:
# Approximate training-set sizes under an 80/20 split (the main loop uses test_size=0.2).
int(len(y_melanoma)*0.8), int(len(y_malignancy)*0.8) 

# main

In [None]:
# record_all_list=[]
for task in ["melanoma", "malignancy"]:
# for task in ["melanoma"]:
    #for is_clean in ["clean_only", "all"]:
    for is_clean in ["clean_only"]:
        if task=="melanoma" and is_clean=="clean_only":
            metadata_select=metadata_all_melanoma[(metadata_all_melanoma["skincon_Do not consider this image"]!=1).values]
            y_select=np.array(y_melanoma)[(metadata_all_melanoma["skincon_Do not consider this image"]!=1).values]
            image_monet_select=image_features_all_melanoma[(metadata_all_melanoma["skincon_Do not consider this image"]!=1).values]
            image_vanilla_select=image_features_all_melanoma_vanilla[(metadata_all_melanoma["skincon_Do not consider this image"]!=1).values]
        elif task=="melanoma" and is_clean=="all":
            metadata_select=metadata_all_melanoma_all[(metadata_all_melanoma_all["skincon_Do not consider this image"]!=1).values]
            y_select=np.array(y_melanoma_all)[(metadata_all_melanoma_all["skincon_Do not consider this image"]!=1).values]
            image_monet_select=image_features_all_f17all_melanoma[(metadata_all_melanoma_all["skincon_Do not consider this image"]!=1).values]
            image_vanilla_select=image_features_all_f17all_melanoma_vanilla[(metadata_all_melanoma_all["skincon_Do not consider this image"]!=1).values]
        elif task=="malignancy" and is_clean=="clean_only":
            metadata_select=metadata_all_malignancy[(metadata_all_malignancy["skincon_Do not consider this image"]!=1).values]
            y_select=np.array(y_malignancy)[(metadata_all_malignancy["skincon_Do not consider this image"]!=1).values]
            image_monet_select=image_features_all_malignancy[(metadata_all_malignancy["skincon_Do not consider this image"]!=1).values]
            image_vanilla_select=image_features_all_malignancy_vanilla[(metadata_all_malignancy["skincon_Do not consider this image"]!=1).values]
        elif task=="malignancy" and is_clean=="all":
            metadata_select=metadata_all_malignancy_all[(metadata_all_malignancy_all["skincon_Do not consider this image"]!=1).values]
            y_select=np.array(y_malignancy_all)[(metadata_all_malignancy_all["skincon_Do not consider this image"]!=1).values]
            image_monet_select=image_features_all_f17all_malignancy[(metadata_all_malignancy_all["skincon_Do not consider this image"]!=1).values]
            image_vanilla_select=image_features_all_f17all_malignancy_vanilla[(metadata_all_malignancy_all["skincon_Do not consider this image"]!=1).values]
        else:
            raise NotImplementedError(task, is_clean)        
#         if task=="melanoma" and is_clean=="clean_only":
#             y_select=y_melanoma
#             metadata_select=metadata_all_melanoma
#             image_monet_select=image_features_all_melanoma
#             image_vanilla_select=image_features_all_melanoma_vanilla
#         elif task=="melanoma" and is_clean=="all":
#             y_select=y_melanoma_all
#             metadata_select=metadata_all_melanoma_all            
#             image_monet_select=image_features_all_f17all_melanoma
#             image_vanilla_select=image_features_all_f17all_melanoma_vanilla            
#         elif task=="malignancy" and is_clean=="clean_only":
#             y_select=y_malignancy              
#             metadata_select=metadata_all_malignancy
#             image_monet_select=image_features_all_malignancy
#             image_vanilla_select=image_features_all_malignancy_vanilla            
#         elif task=="malignancy" and is_clean=="all":
#             y_select=y_malignancy_all
#             metadata_select=metadata_all_malignancy_all            
#             image_monet_select=image_features_all_f17all_malignancy
#             image_vanilla_select=image_features_all_f17all_malignancy_vanilla            
#         else:
#             raise NotImplementedError(task, is_clean)
            
        assert len(y_select)==len(metadata_select)==len(image_monet_select)==len(image_vanilla_select)
        
        for random_seed in tqdm.tqdm(range(1,20)):
            train_idx, test_idx = train_test_split(np.arange(len(y_select)), random_state=random_seed, test_size=0.2, shuffle=True)

            #print(task, is_clean, len(y_select), len(metadata_select), len(train_idx), len(test_idx))

#             for method in ["skincon_manual"]:
            for method in ["skincon_manual", "automatic", "resnet", "resnet_freeze_backbone"]:
                if method=="skincon_manual":
                    for test_mode in ["full", "less_concept", "less_sample"]:
                        if test_mode=="full":
                            #for alpha in [0.001, 0.0001]:
                            for alpha in [0.001]:
                                metadata_select_train=metadata_select[skincon_cols].iloc[train_idx]
                                metadata_select_skincon_train=metadata_select_train[~metadata_select_train["skincon_Vesicle"].isnull()]
                                metadata_select_test=metadata_select[skincon_cols].iloc[test_idx]
                                metadata_select_skincon_test=metadata_select_test[~metadata_select_test["skincon_Vesicle"].isnull()]

                                y_select_train=np.array(y_select)[train_idx][~metadata_select_train["skincon_Vesicle"].isnull()]
                                y_select_test=np.array(y_select)[test_idx][~metadata_select_test["skincon_Vesicle"].isnull()]

                                auc,_,y_test_pred=train_using_manual_labels(xtrain=metadata_select_skincon_train[skincon_cols].values,
                                                          xtest=metadata_select_skincon_test[skincon_cols].values,
                                                          ytrain=y_select_train,
                                                          ytest=y_select_test,
                                                                alpha=alpha,
                                                         )

                                print(task, is_clean, len(y_select), random_seed, method, test_mode, alpha, f"{auc:.3f}")
                                record_all_list.append({"task": task,
                                                        "is_clean": is_clean,
                                                        "num_sample": len(y_select_train)+len(y_select_test),
                                                        "num_sample_train": len(y_select_train),
                                                        "num_sample_test": len(y_select_test),
                                                        "random_seed": random_seed,                                            
                                                        "method": method+"_"+test_mode,
                                                        "auc":auc,
                                                        "y_test":y_select_test,
                                                        "y_test_pred":y_test_pred,
                                                        "alpha": alpha,
                                                       })
                            
                        elif test_mode=="less_concept":
                            for num_concept in [1]+list(range(5, len(skincon_cols), 5))+[len(skincon_cols)]:
                                skincon_cols_select=np.random.choice(skincon_cols, 
                                                                     size=num_concept, 
                                                                     replace=False, p=None).tolist()                      
                            
                                metadata_select_train=metadata_select[skincon_cols].iloc[train_idx]
                                metadata_select_skincon_train=metadata_select_train[~metadata_select_train["skincon_Vesicle"].isnull()]
                                metadata_select_test=metadata_select[skincon_cols].iloc[test_idx]
                                metadata_select_skincon_test=metadata_select_test[~metadata_select_test["skincon_Vesicle"].isnull()]

                                y_select_train=np.array(y_select)[train_idx][~metadata_select_train["skincon_Vesicle"].isnull()]
                                y_select_test=np.array(y_select)[test_idx][~metadata_select_test["skincon_Vesicle"].isnull()]

                                auc,_,y_test_pred=train_using_manual_labels(xtrain=metadata_select_skincon_train[skincon_cols_select].values,
                                                          xtest=metadata_select_skincon_test[skincon_cols_select].values,
                                                          ytrain=y_select_train,
                                                          ytest=y_select_test 
                                                         )
                                print(task, is_clean, len(y_select), random_seed, method, test_mode, num_concept, f"{auc:.3f}")
                                
                                record_all_list.append({"task": task,
                                                        "is_clean": is_clean,
                                                        "num_sample": len(y_select_train)+len(y_select_test),
                                                        "num_sample_train": len(y_select_train),
                                                        "num_sample_test": len(y_select_test),
                                                        "random_seed": random_seed,                                            
                                                        "method": method+"_"+test_mode,
                                                        "num_concept": num_concept,
                                                        "auc":auc,
                                                        "y_test":y_select_test,
                                                        "y_test_pred":y_test_pred,                                                        
                                                       })
                        elif test_mode=="less_sample":
                            if len(train_idx)>600 and len(train_idx)<700:
                                num_sample_train_select_range=[100, 200, 300, 400, 500, 600, len(train_idx)]
                            elif len(train_idx)>3900 and len(train_idx)<4000:
                                num_sample_train_select_range=[100, 500, 1000, 1500, 2000, 2500, 3000, 3500, len(train_idx)]
                            else:
                                raise
                            for num_sample_train_select in num_sample_train_select_range:
                                train_idx_select=np.random.choice(train_idx, size=num_sample_train_select, replace=False)                                                                    
                                
                                metadata_select_train=metadata_select[skincon_cols].iloc[train_idx_select]
                                metadata_select_skincon_train=metadata_select_train[~metadata_select_train["skincon_Vesicle"].isnull()]
                                metadata_select_test=metadata_select[skincon_cols].iloc[test_idx]
                                metadata_select_skincon_test=metadata_select_test[~metadata_select_test["skincon_Vesicle"].isnull()]

                                y_select_train=np.array(y_select)[train_idx_select][~metadata_select_train["skincon_Vesicle"].isnull()]
                                y_select_test=np.array(y_select)[test_idx][~metadata_select_test["skincon_Vesicle"].isnull()]

                                auc,_,y_test_pred=train_using_manual_labels(xtrain=metadata_select_skincon_train[skincon_cols].values,
                                                          xtest=metadata_select_skincon_test[skincon_cols].values,
                                                          ytrain=y_select_train,
                                                          ytest=y_select_test 
                                                         )

                                print(task, is_clean, len(y_select), random_seed, method, test_mode, num_sample_train_select, f"{auc:.3f}")
                                record_all_list.append({"task": task,
                                                        "is_clean": is_clean,
                                                        "num_sample": len(y_select_train)+len(y_select_test),
                                                        "num_sample_train": len(y_select_train),
                                                        "num_sample_test": len(y_select_test),
                                                        "random_seed": random_seed,                                            
                                                        "method": method+"_"+test_mode,
                                                        "num_sample_train_select": num_sample_train_select,
                                                        "auc":auc,
                                                        "y_test":y_select_test,
                                                        "y_test_pred":y_test_pred,                                                        
                                                       })
                                
                                
                    
                elif method=="resnet_freeze_backbone":
                    auc, y_test_pred=generate_data_and_train_resnet(metadata_select=metadata_select, 
                                                       y_select=y_select, 
                                                       train_idx=train_idx, 
                                                       test_idx=test_idx, 
                                                       freeze_backbone=True)
                    
                    print(task, is_clean, len(y_select), random_seed, method, f"{auc:.3f}")
                    record_all_list.append({"task": task,
                                            "is_clean": is_clean,
                                            "num_sample": len(train_idx)+len(test_idx),
                                            "num_sample_train": len(train_idx),
                                            "num_sample_test": len(test_idx),                                            
                                            "random_seed": random_seed,                                            
                                            "method": method,
                                            "auc":auc,
                                            "y_test":np.array(y_select)[test_idx],
                                            "y_test_pred":y_test_pred,                                              
                                           })                     
                    
                elif method=="resnet":
                    auc, y_test_pred=generate_data_and_train_resnet(metadata_select=metadata_select, 
                                                       y_select=y_select, 
                                                       train_idx=train_idx, 
                                                       test_idx=test_idx)
                    
                    print(task, is_clean, len(y_select), random_seed, method, f"{auc:.3f}")
                    record_all_list.append({"task": task,
                                            "is_clean": is_clean,
                                            "num_sample": len(train_idx)+len(test_idx),
                                            "num_sample_train": len(train_idx),
                                            "num_sample_test": len(test_idx),                                            
                                            "random_seed": random_seed,                                            
                                            "method": method,
                                            "auc":auc,
                                            "y_test":np.array(y_select)[test_idx],
                                            "y_test_pred":y_test_pred,                                             
                                           }) 
                    
                   
                    
                elif method=="automatic":
                    for concept_list_type in ["curated", "skincon"]:
                        if concept_list_type=="curated":
                            concept_list_target=concept_list_curated
                        elif concept_list_type=="skincon":
                            concept_list_target=skincon_cols
                        else:
                            raise NotImplementedError                            
                            
                        for trained in ["monet", "vanilla"]:
                            for test_mode in ["full", "less_concept", "less_reference", "less_sample"]:
                                y_select_train=np.array(y_select)[train_idx]
                                y_select_test=np.array(y_select)[test_idx]                         
                                if trained=="monet":
                                    image_select_train=image_monet_select[train_idx]
                                    image_select_test=image_monet_select[test_idx]
                                    model_select=model
                                elif trained=="vanilla":
                                    image_select_train=image_vanilla_select[train_idx]
                                    image_select_test=image_vanilla_select[test_idx]
                                    model_select=model_vanilla

                                if test_mode=="full":
                                    #for alpha in [0.001, 0.0001]:
                                    for alpha in [0.001]:
                                        #for temp in [0.02, 0.01, 0.005]:
                                        for temp in [0.02]:
                                            clf, auc, y_test_pred=train_with_best_temp_softmax(model=model_select,
                                                                             image_features_train=image_select_train,                           
                                                                             image_features_test=image_select_test, 
                                                                             ytrain=y_select_train, 
                                                                             ytest=y_select_test,
                                                                             concept_list=concept_list_target, 
                                                                             temp=temp,
                                                                             num_ref_concepts=100, alpha=alpha) 
                                            print(task, is_clean, len(y_select), random_seed, method, concept_list_type, trained, test_mode, alpha, temp, f"{auc:.3f}")
                                            record_all_list.append({"task": task,
                                                                    "is_clean": is_clean,
                                                                    "num_sample": len(y_select_train)+len(y_select_test),
                                                                    "num_sample_train": len(y_select_train),
                                                                    "num_sample_test": len(y_select_test),                                                        
                                                                    "random_seed": random_seed,                                            
                                                                    "method": method+"_"+concept_list_type+"_"+trained+"_"+test_mode,
                                                                    "auc": auc,
                                                                    "y_test":y_select_test,
                                                                    "y_test_pred":y_test_pred,                                                                       
                                                                    "alpha": alpha,
                                                                    "temp": temp,
                                                                    "clf":clf
                                                                   })   



                                elif test_mode=="less_concept":
                                    if len(concept_list_target)>20:
                                        num_concept_list=[1]+list(range(5, len(concept_list_target), 5))+[len(concept_list_target)]
                                    else:
                                        num_concept_list=list(range(1, len(concept_list_target)+1))
                                    for num_concept in num_concept_list:
                                        concept_list_target_select=np.random.choice(concept_list_target, 
                                                                                     size=num_concept, 
                                                                                     replace=False, p=None)

                                        clf, auc, y_test_pred=train_with_best_temp_softmax(model=model_select,                                                    
                                                                         image_features_train=image_select_train, 
                                                                         image_features_test=image_select_test, 
                                                                         ytrain=y_select_train, 
                                                                         ytest=y_select_test,
                                                                         concept_list=concept_list_target_select, 
                                                                         temp=0.02,
                                                                         num_ref_concepts=100) 

                                        print(task, is_clean, len(y_select), random_seed, method, concept_list_type, trained, test_mode, num_concept, f"{auc:.3f}")
                                        record_all_list.append({"task": task,
                                                                "is_clean": is_clean,
                                                                "num_sample": len(y_select_train)+len(y_select_test),
                                                                "num_sample_train": len(y_select_train),
                                                                "num_sample_test": len(y_select_test),                                                            
                                                                "random_seed": random_seed,                                            
                                                                "method": method+"_"+concept_list_type+"_"+trained+"_"+test_mode,
                                                                "num_concept": num_concept,
                                                                "auc":auc,
                                                                    "y_test":y_select_test,
                                                                    "y_test_pred":y_test_pred,                                                                  
                                                               })  

                                elif test_mode=="less_reference":
                                    for num_ref_concepts in range(0, 5+1):
                                        clf, auc, y_test_pred=train_with_best_temp_softmax(model=model_select,                       
                                                                         image_features_train=image_select_train,                               
                                                                         image_features_test=image_select_test, 
                                                                         ytrain=y_select_train, 
                                                                         ytest=y_select_test,
                                                                         concept_list=concept_list_target, 
                                                                         temp=0.02,
                                                                         num_ref_concepts=num_ref_concepts)
                                        print(task, is_clean, len(y_select), random_seed, method, concept_list_type, trained, test_mode, num_ref_concepts, f"{auc:.3f}")
                                        record_all_list.append({"task": task,
                                                                "is_clean": is_clean,
                                                                "num_sample": len(y_select_train)+len(y_select_test),
                                                                "num_sample_train": len(y_select_train),
                                                                "num_sample_test": len(y_select_test),                                                               
                                                                "random_seed": random_seed,                                            
                                                                "method": method+"_"+concept_list_type+"_"+trained+"_"+test_mode,
                                                                "num_ref_concepts": num_ref_concepts,
                                                                "auc":auc,
                                                                    "y_test":y_select_test,
                                                                    "y_test_pred":y_test_pred,                                                                  
                                                               })         



                                elif test_mode=="less_sample":
                                    for sample_prop in [0.1, 0.2, 0.4, 0.6, 0.8, 1]:
                                        sample_select=np.random.choice(np.arange(len(image_select_train)), size=int(len(image_select_train)*sample_prop), replace=False)                                    
                                        clf, auc, y_test_pred=train_with_best_temp_softmax(model=model_select,                                                     
                                                                         image_features_train=image_select_train[sample_select], 
                                                                         image_features_test=image_select_test, 
                                                                         ytrain=y_select_train[sample_select], 
                                                                         ytest=y_select_test,
                                                                         concept_list=concept_list_target, 
                                                                         temp=0.02,
                                                                         num_ref_concepts=100)                                    

                                        print(task, is_clean, len(y_select), random_seed, method, concept_list_type, trained, test_mode, sample_prop, f"{auc:.3f}")
                                        record_all_list.append({"task": task,
                                                                "is_clean": is_clean,
                                                                "num_sample": len(y_select_train[sample_select])+len(y_select_test),
                                                                "num_sample_train": len(y_select_train[sample_select]),
                                                                "num_sample_test": len(y_select_test),                                                               
                                                                "random_seed": random_seed,                                            
                                                                "method": method+"_"+concept_list_type+"_"+trained+"_"+test_mode,
                                                                "sample_prop": sample_prop,
                                                                "auc":auc,
                                                                    "y_test":y_select_test,
                                                                    "y_test_pred":y_test_pred,                                                                  
                                                               })
                                else:
                                    raise NotImplementedError
                else:
                    raise NotImplementedError
                #print(task, is_clean, len(y_select), random_seed, method, f"{auc:.3f}")
                #print(y_select)
            #print(method, train_using_manual_labels(xtrain= ))
            pass
            

In [None]:
# torch.save(record_all_list, f=log_dir/"experiment_results"/"cbm_complete_valid_230607.pt")

In [None]:
# Collect the experiment records into a DataFrame (one row per appended record dict)
# and show the rows whose method string is "skincon_manual_full".
# pd.set_option('display.max_rows', 50)
x=pd.DataFrame(record_all_list)
x[x["method"]=="skincon_manual_full"]

In [None]:
# All distinct method labels present in the records.
x["method"].unique()

In [None]:
x[x["method"]=="automatic_skincon_monet_full"]

In [None]:
# Scratch arithmetic: presumably the size of an 80% training split of 775 samples -- TODO confirm.
775*0.8

In [None]:
# Presumably the training split left over from the last iteration of the experiment loop.
len(train_idx)

In [None]:
len(y_malignancy)*0.8

In [None]:
len(y_melanoma)

In [None]:
# Load previously saved experiment records (files suffixed 230513 / 230501) onto the CPU.
record_all_list_0513=torch.load(f=log_dir/"experiment_results"/"cbm_complete_230513.pt", 
                                map_location="cpu")

In [None]:
record_all_list_0501=torch.load(f=log_dir/"experiment_results"/"cbm_complete_230501.pt", 
                                map_location="cpu")

In [None]:
 # cbm_complete_230501.pt -- filename note; as a bare indented name this cell raised IndentationError/NameError.

In [None]:
# DataFrame view of the records loaded from cbm_complete_230513.pt.
x=pd.DataFrame(record_all_list_0513)

In [None]:
379+304

In [None]:
# Columns available in the older (230501) records.
pd.DataFrame(record_all_list_0501).columns

In [None]:
# Sanity check: recompute every record's AUC from its stored y_test / y_test_pred and
# subtract the stored "auc"; each entry should be (close to) zero.
# The notebook's import cell only does `from sklearn.metrics import ..., roc_auc_score`;
# the bare `sklearn` module name is never bound there, so `sklearn.metrics.roc_auc_score`
# raised NameError. Use the imported function directly.
_auc_check_df = pd.DataFrame(record_all_list)
_auc_check_df.apply(lambda row: roc_auc_score(row["y_test"], row["y_test_pred"]), axis=1) - _auc_check_df["auc"]

In [None]:
record_all_list

In [None]:
# Boolean mask of test rows that carry SkinCon annotations (the same filter the
# manual-label experiments apply before training).
~metadata_select_test["skincon_Vesicle"].isnull()

In [None]:
# NOTE(review): ground_truth / predictions_one / predictions_two (and delong_roc_test itself)
# are not defined in the visible code -- confirm they exist before running this cell.
delong_roc_test(ground_truth, predictions_one, predictions_two)

In [None]:
# Number of melanoma-task rows that have a SkinCon "Cyst" annotation.
(~metadata_all_melanoma["skincon_Cyst"].isnull()).sum()

In [None]:
record_all_list[19]["y_test_pred"]

In [None]:
record_all_list[-1]["y_test"]

In [None]:
# NOTE(review): unlike most other torch.load calls here, no map_location is given;
# may fail on a CPU-only machine if the pickle holds CUDA tensors -- confirm.
record_all_list_new=torch.load(f=log_dir/"experiment_results"/"cbm_complete_230513.pt")

In [None]:
pd.DataFrame(record_all_list_new)["method"].unique()

In [None]:
# Per-seed AUC difference on the malignancy task: MONET+CBM with the curated concept
# list minus ResNet-50 linear probing. Each (task, method) group is reduced to its AUC
# values ordered by random_seed; assumes both methods were run on the same seeds so the
# two arrays line up element by element.
_auc_by_seed = (
    pd.DataFrame(record_all_list_new)
    .groupby(["task", "method"])
    .apply(lambda grp: grp.sort_values("random_seed")["auc"].values)
)
score_diff = (
    _auc_by_seed.loc["malignancy", "automatic_curated_monet_full"]
    - _auc_by_seed.loc["malignancy", "resnet_freeze_backbone"]
)

In [None]:
# Expected number of paired AUC differences in score_diff (presumably one per random seed);
# used as n in the paired t-test that follows. TODO confirm it equals len(score_diff).
cv=40

In [None]:
# Two-sided paired t-test on the per-seed AUC differences in `score_diff`:
#   t = mean(d) * sqrt(n) / std(d, ddof=1),   p = 2 * P(T_{n-1} > |t|),   with n = cv.
# The notebook's import cell only imports names *from* scipy.special / scipy.stats and
# never binds the `scipy` package, so `scipy.stats.t` raised NameError; import it here.
import scipy.stats

# The formula is only valid when n matches the data; fail loudly on a stale `cv`.
if len(score_diff) != cv:
    raise ValueError(
        f"cv={cv} but score_diff has {len(score_diff)} paired differences"
    )

avg_diff = np.mean(score_diff)

numerator = avg_diff * np.sqrt(cv)
# Sample standard deviation (ddof=1) of the differences.
denominator = np.std(score_diff, ddof=1)
t_stat = numerator / denominator

pvalue = scipy.stats.t.sf(np.abs(t_stat), cv - 1) * 2.0
float(t_stat), float(pvalue)

In [None]:
# Number of paired differences actually available (compare with cv).
len(score_diff)

In [None]:
np.log10(0.001)

In [None]:
# Stored AUCs of two individual records (indices 19 and 42).
record_all_list[19]["auc"], record_all_list[42]["auc"] 

In [None]:
# TODO: unfinished draft of a two-sample t-test (scipy.stats.ttest_ind); the arguments were
# never filled in and the original call `ttest_ind(a=)` was a SyntaxError.

In [None]:
def paired_t_test(score_diff):
    """Run a two-sided paired t-test on a sequence of paired score differences.

    This cell previously held a function body pasted at top level (a bare ``return``,
    which is a SyntaxError, and an undefined ``stats`` name); it is now a real function.

    Args:
        score_diff: paired differences, e.g. per-seed AUC of one method minus another.

    Returns:
        ``(t_stat, pvalue)`` as Python floats, where
        ``t = mean(d) * sqrt(n) / std(d, ddof=1)`` and ``p = 2 * P(T_{n-1} > |t|)``.
        If all differences are identical the denominator is 0 and ``t`` is inf/nan.

    Raises:
        ValueError: if fewer than two differences are given (no degrees of freedom).
    """
    from scipy import stats

    diffs = np.asarray(score_diff, dtype=float)
    n = len(diffs)
    if n < 2:
        raise ValueError(f"need at least two paired differences, got {n}")

    avg_diff = np.mean(diffs)
    numerator = avg_diff * np.sqrt(n)
    denominator = np.sqrt(np.sum((diffs - avg_diff) ** 2) / (n - 1))
    t_stat = numerator / denominator

    pvalue = stats.t.sf(np.abs(t_stat), n - 1) * 2.0
    return float(t_stat), float(pvalue)

In [None]:
# DeLong test comparing the predictions of records 19 and 42.
# NOTE(review): the ground truth is taken from the LAST record's y_test; this is only valid
# if records 19, 42 and -1 share the same test split -- confirm before trusting the p-value.
delong_roc_test(ground_truth=record_all_list[-1]["y_test"], 
             predictions_one=record_all_list[19]["y_test_pred"], 
             predictions_two=np.array(record_all_list[42]["y_test_pred"]))

In [None]:
pd.DataFrame(record_all_list)

In [None]:
# pd.DataFrame(record_all_list)#.apply(lambda x: x["y_test"].sum(), axis=1)

In [None]:
(17*40)/60

In [None]:
# Display the AUC function; `sklearn` itself is not bound by the notebook's import cell
# (only `from sklearn.metrics import ... roc_auc_score`), so `sklearn.metrics...` raised NameError.
roc_auc_score

In [None]:
                            # Debug re-run of the "less_sample" subsampling step of the experiment loop.
# Dedented so the cell parses: it was pasted with the loop's indentation (IndentationError).
# Relies on num_sample_train_select_range, train_idx, test_idx and metadata_select still
# holding values from the last loop iteration; it re-draws train_idx_select at random and
# does NOT recompute y_select_train / y_select_test.
for num_sample_train_select in num_sample_train_select_range:
    train_idx_select=np.random.choice(train_idx, size=num_sample_train_select, replace=False)

    metadata_select_train=metadata_select[skincon_cols].iloc[train_idx_select]
    metadata_select_skincon_train=metadata_select_train[~metadata_select_train["skincon_Vesicle"].isnull()]
    metadata_select_test=metadata_select[skincon_cols].iloc[test_idx]
    metadata_select_skincon_test=metadata_select_test[~metadata_select_test["skincon_Vesicle"].isnull()]


In [None]:
# Debug: number of subsampled training indices vs. number of training labels kept after
# the SkinCon-annotation (skincon_Vesicle not null) filter.
len(train_idx_select), len(y_select_train)

In [None]:
record_all_list

In [None]:
num_sample_train_select

In [None]:
train_idx.shape

In [None]:
len(train_idx_select)

In [None]:
len(train_idx)

In [None]:
len(train_idx)

In [None]:
# Noted while debugging: len(train_idx) == 657. (Written as `len(train_idx)=657`, the line assigned to a call: SyntaxError.)

In [None]:
# Sizes of the filtered train/test label arrays and the sums of their labels
# (the sums are positive counts assuming binary 0/1 labels -- TODO confirm).
len(y_select_train), len(y_select_test), y_select_train.sum(), y_select_test.sum()

In [None]:
num_sample_train_select

In [None]:
# Number of records per (task, is_clean, method) configuration. The original `.groupby()`
# had no keys and raised TypeError; the keys match the aggregation cells that follow.
pd.DataFrame(record_all_list).groupby(["task", "is_clean", "method"]).size()

In [None]:
# Median of the numeric columns (auc, sample counts, ...) per configuration.
# numeric_only=True: records also store arrays (y_test, y_test_pred) and, for the automatic
# runs, fitted classifiers (clf); pandas >= 2.0 raises TypeError when asked to aggregate
# those, while older versions silently dropped them.
pd.DataFrame(record_all_list).groupby(["task", "is_clean", "method"]).median(numeric_only=True)

In [None]:
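# Notes from an earlier run (values presumably AUC); kept as comments because
# plain text in a code cell is a SyntaxError.
# malignancy
#   Use negative  0.811498
#   Use template  0.812509
#   No reference  0.804630
#
# melanoma
#   Use negative  0.907219
#   Use template  0.877861
#   No reference  0.832318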

In [None]:
# Print each concept followed by its associated terms, skipping the first entry of each list
# (presumably the first entry is the concept's own prompt -- TODO confirm).
for concept_name in concept_dict:
    terms = concept_dict[concept_name]
    print(concept_name + ":", ", ".join(terms[1:]))

In [None]:
# One row per experiment record (task, method, random_seed, auc, stored predictions, ...).
record_all_df=pd.DataFrame(record_all_list)

In [None]:
# Mean of the numeric columns (auc, sample counts, ...) per configuration.
# numeric_only=True: records also store arrays (y_test, y_test_pred) and, for the automatic
# runs, fitted classifiers (clf); pandas >= 2.0 raises TypeError when asked to average
# those, while older versions silently dropped them.
record_all_df.groupby(["task", "is_clean", "method"]).mean(numeric_only=True)

In [None]:
#torch.save(record_all_list, f=log_dir/"experiment_results"/"cbm_complete_230501.pt")

In [None]:
torch.load(log_dir/"experiment_results"/"cbm_complete_230501.pt")

In [None]:
# WARNING: overwrites cbm_complete_230513.pt with whatever record_all_list currently holds.
# Other cells load this file as the saved 230513 results, and later cells reassign
# record_all_list, so re-running this cell out of order clobbers the saved results.
torch.save(record_all_list, f=log_dir/"experiment_results"/"cbm_complete_230513.pt")

In [None]:
# Replaces the in-memory records with the 230501 results.
record_all_list=torch.load(log_dir/"experiment_results"/"cbm_complete_230501.pt", map_location="cpu")

In [None]:
#record_all_df[record_all_df["method"].str.contains("skincon_manual")].groupby(["task", "is_clean", "method", "num_concept"]).mean()
#record_all_df[record_all_df["method"].str.contains("skincon_manual")].groupby(["task", "is_clean", "method", "num_sample_train"]).mean()
#record_all_df[record_all_df["method"].str.contains("automatic_skincon_monet")]
#record_all_df[record_all_df["method"].str.contains("automatic_skincon_monet")].groupby(["task", "is_clean", "method"]).mean()

In [None]:
record_all_list=torch.load(f=log_dir/"experiment_results"/"cbm_result.pt")
record_all_list_new=torch.load(f=log_dir/"experiment_results"/"cbm_result_new.pt")

In [None]:
# torch.save(record_all_list, f=log_dir/"cbm.result")

In [None]:
record_all_list

In [None]:
pd.DataFrame(record_all_list)["method"].unique()

In [None]:
from matplotlib import gridspec
from matplotlib.ticker import MultipleLocator, AutoMinorLocator

In [None]:
# !cp /System/Library/Fonts/Supplemental ~/.local/share/fonts/
# rm -fr ~/.cache/matplotlib
from matplotlib import font_manager
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from cycler import cycler
from matplotlib.ticker import MultipleLocator, AutoMinorLocator

# Resolve the Arial font up front, then make it the default sans-serif family.
font_manager.findSystemFonts(fontpaths=None, fontext="ttf")
font_manager.findfont("Arial") # Test with "Special Elite" too

_rc_overrides = {
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    # Legend styling: square corners, edge color '1.0', fully transparent frame.
    'legend.fancybox': False,
    'legend.edgecolor': '1.0',
    'legend.framealpha': 0,
}
for _rc_key, _rc_value in _rc_overrides.items():
    plt.rcParams[_rc_key] = _rc_value

# https://github.com/dsc/colorbrewer-python/blob/master/colorbrewer.py

# ColorBrewer qualitative palettes as 0-255 RGB triples, keyed by the number of colors.
# Every k-color variant is the first k entries of the largest variant, so the dicts are
# built as prefixes rather than repeating the literals. Each entry gets its own fresh
# lists, so mutating one palette never affects another. (Previously Paired[3][0] was a
# tuple while every other entry was a list.)
_SET1_RGB = [[228,26,28], [55,126,184], [77,175,74], [152,78,163], [255,127,0],
             [255,255,51], [166,86,40], [247,129,191], [153,153,153]]
_PAIRED_RGB = [[166,206,227], [31,120,180], [178,223,138], [51,160,44], [251,154,153],
               [227,26,28], [253,191,111], [255,127,0], [202,178,214], [106,61,154],
               [255,255,153], [177,89,40]]

Set1 = {k: [list(rgb) for rgb in _SET1_RGB[:k]] for k in range(3, len(_SET1_RGB) + 1)}
Paired = {k: [list(rgb) for rgb in _PAIRED_RGB[:k]] for k in range(3, len(_PAIRED_RGB) + 1)}

# Custom 7-color qualitative palette: hex strings plus one 0-1 RGB tuple.
# NOTE(review): the tuple divides by 256 rather than 255; kept as-is to preserve the color.
color_qual_7=['#F53345',
            '#87D303',
            '#04CBCC',
            '#8650CD',
            (160/256, 95/256, 0),
            '#F5A637',              
            '#DBD783',            
             ]

pd.set_option('display.max_rows', 500)

In [None]:
# Internal method id -> short label used in figures. Several ids share a label:
# "automatic_monet_full" / "automatic_vanilla_full" presumably come from older result
# files that predate the curated/skincon concept-list split -- TODO confirm.
_METHOD_SHORT_NAMES = {
    "automatic_monet_full": "MONET+CBM (Curated)",
    "automatic_curated_monet_full": "MONET+CBM (Curated)",
    "automatic_curated_vanilla_full": "CLIP+CBM",
    "automatic_skincon_monet_full": "MONET+CBM (SkinCon)",
    "skincon_manual": "Manual Label (SkinCon)",
    "skincon_manual_full": "Manual Label (SkinCon)",
    "resnet": "Supervised (ResNet-50)",
    "resnet_freeze_backbone": "Linear probing (ResNet-50)",
    "automatic_vanilla_full": "CLIP+CBM",
}


def shorten_method_name(method_name):
    """Map an internal experiment method id to its short figure/legend label.

    Args:
        method_name: method string as stored in the experiment records,
            e.g. ``"automatic_curated_monet_full"`` or ``"resnet"``.

    Returns:
        The human-readable label, e.g. ``"MONET+CBM (Curated)"``.

    Raises:
        ValueError: if ``method_name`` has no short label. Previously an unmapped
            name fell through every branch and raised UnboundLocalError.
    """
    try:
        return _METHOD_SHORT_NAMES[method_name]
    except KeyError:
        raise ValueError(f"no short name defined for method {method_name!r}") from None

In [None]:
#record_all_list=torch.load(f=log_dir/"experiment_results"/"cbm_complete_new.pt")

In [None]:
# Load the saved concept-bottleneck experiment records onto the CPU.
# The records hold arbitrary pickled Python objects, not just tensors: the
# "clf" entries are fitted classifier objects whose `coef_` is read when the
# weight panels are drawn. torch's safe `weights_only` loader (the default from
# PyTorch 2.6) refuses such objects, so full unpickling is requested explicitly.
# NOTE: this executes pickle code from the file; only load result files that
# this project produced itself. (`weights_only` requires PyTorch >= 1.13.)
record_all_list = torch.load(
    f=log_dir / "experiment_results" / "cbm_complete_230513.pt",
    map_location="cpu",
    weights_only=False,
)

In [None]:
# Tabulate the loaded experiment records: one DataFrame row per record.
record_all_df = pd.DataFrame(data=record_all_list)

In [None]:
# List the distinct method identifiers present in the records.
pd.unique(record_all_df["method"])

In [None]:
# Inspect the records whose method is "automatic_skincon_monet_full".
record_all_df.loc[record_all_df["method"].eq("automatic_skincon_monet_full")]

In [None]:
# Inspect the records whose method is "skincon_manual_full".
record_all_df.loc[record_all_df["method"].eq("skincon_manual_full")]

In [None]:
# Colour cycle for all panels: entries 1, 3, 5, 7, 9 and 11 of the 12-colour
# `Paired` table (the darker member of each light/dark pair), converted from
# 0-255 integers to the 0-1 floats matplotlib expects.
# NOTE(review): /255 would be the exact 8-bit normalisation; /256 is kept so
# the colours match the rest of this notebook.
plt.rcParams["axes.prop_cycle"] = cycler(
    "color", [np.array(Paired[12][i]) / 256 for i in (1, 3, 5, 7, 9, 11)]
)

# Methods shown in the performance panels, in plotting (and legend) order.
main_method_list = [
    "automatic_curated_monet_full",
    "skincon_manual_full",
    "resnet",
    "resnet_freeze_backbone",
    "automatic_curated_vanilla_full",
]

# Four stacked rows (top to bottom): overview, skincon, performance, weight.
fig = plt.figure(figsize=(3 * 10, 3 * (3 + 2.5 + 4 + 3 + 0.35 * 3)))
box1 = gridspec.GridSpec(4, 1,
                         height_ratios=[3, 2.5, 4, 3],
                         wspace=0.0,
                         hspace=0.35)

# Selected CBM hyperparameters. Values present in the results:
#   temp  in {nan, 0.02, 0.01, 0.005}
#   alpha in {0.001, 0.0001, nan}
alpha = 0.001
temp = 0.02

# Build every axes, keyed "<stage>" (overview) or "<stage>_<panel>". Panels
# whose name contains "empty" are narrow spacer columns; their ticks and
# spines are hidden below.
axd = {}
for idx1, stage in enumerate(["overview", "skincon", "performance", "weight"]):
    if stage == "overview":
        plot_key = stage
        ax = fig.add_subplot(box1[idx1])
        axd[plot_key] = ax
        continue

    if stage == "performance":
        box2 = gridspec.GridSpecFromSubplotSpec(
            1, 4, subplot_spec=box1[idx1],
            width_ratios=[0.1, 1, 0.1, 1], wspace=0., hspace=0.)
        panel_names = ["empty_malignancy", "malignancy", "empty_melanoma", "melanoma"]
    elif stage == "weight":
        box2 = gridspec.GridSpecFromSubplotSpec(
            1, 4, subplot_spec=box1[idx1],
            width_ratios=[0.15, 1, 0.15, 1], wspace=0.1, hspace=0.)
        panel_names = ["empty_malignancy", "malignancy", "empty_melanoma", "melanoma"]
    else:  # stage == "skincon"
        box2 = gridspec.GridSpecFromSubplotSpec(
            1, 8, subplot_spec=box1[idx1],
            width_ratios=[0.15, 1, 0.15, 1, 0.27, 1, 0.15, 1], wspace=0.0, hspace=0.)
        panel_names = ["empty", "malignancy_num_concept", "empty1", "malignancy_num_sample",
                       "empty2", "melanoma_num_concept", "empty3", "melanoma_num_sample"]

    for idx2, panel_name in enumerate(panel_names):
        plot_key = f"{stage}_{panel_name}"
        # add_subplot accepts a SubplotSpec directly; no need for plt.Subplot.
        ax = fig.add_subplot(box2[idx2])
        axd[plot_key] = ax

# The overview axes and the spacer axes are blank canvases: no ticks, no spines.
for plot_key, ax in axd.items():
    if "overview" in plot_key or "empty" in plot_key:
        ax.set_xticks([])
        ax.set_yticks([])
        for axis in ["top", "bottom", "left", "right"]:
            ax.spines[axis].set_linewidth(0)
        
import scipy.stats  # binds `scipy` for the paired t-tests below

# Every panel draws from the same base selection of experiment records: the
# first 20 random seeds, restricted to is_clean == "clean_only". Build it once
# instead of re-creating the DataFrame inside each panel.
record_all_df_base = pd.DataFrame(record_all_list)
record_all_df_base = record_all_df_base[record_all_df_base["random_seed"] < 20]
record_all_df_base = record_all_df_base[record_all_df_base["is_clean"] == "clean_only"]

# Method every other method is compared against in the significance markers.
reference_method = "automatic_curated_monet_full"
# Methods that are restricted to the selected `alpha` / `temp` setting.
cbm_sweep_methods = ["automatic_curated_vanilla_full", "automatic_curated_monet_full"]
# Concept order (top to bottom) of the coefficient bar charts.
weight_concept_order = ["Asymmetry", "Irregular", "Erosion",
                        "Black", "Blue", "White", "Brown",
                        "Multiple Colors", "Tiny", "Regular"]
# Colours of the active property cycle.
prop_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

for idx1, stage in enumerate(["overview", "performance", "weight", "skincon"]):
    if stage == "overview":
        plot_key = stage
        axd[plot_key].text(x=-0.0, y=1.0, transform=axd[plot_key].transAxes,
                           s="A", fontsize=35, weight="bold")

    elif stage == "performance":
        # AUC box plots per method, with significance markers vs. reference_method.
        for idx2, task in enumerate(["empty_malignancy", "malignancy", "empty_melanoma", "melanoma"]):
            plot_key = f"{stage}_{task}"

            if task in ("malignancy", "melanoma"):
                record_all_df_perf_filtered = record_all_df_base[
                    record_all_df_base["method"].isin(main_method_list)]
                is_cbm_sweep = record_all_df_perf_filtered["method"].isin(cbm_sweep_methods)
                record_all_df_perf_filtered = record_all_df_perf_filtered[
                    ~is_cbm_sweep
                    | ((record_all_df_perf_filtered["alpha"] == alpha)
                       & (record_all_df_perf_filtered["temp"] == temp))]
                record_all_df_perf_filtered = record_all_df_perf_filtered[
                    record_all_df_perf_filtered["task"] == task]
                axd[plot_key].set_ylim(0.49, 1.01)

                sns.boxplot(x="method", y="auc",
                            order=main_method_list,
                            width=0.5,
                            linewidth=3,
                            saturation=1.3,
                            boxprops=dict(alpha=.9),
                            data=record_all_df_perf_filtered,
                            ax=axd[plot_key])
                sns.swarmplot(x="method", y="auc",
                              order=main_method_list,
                              color="black",
                              alpha=0.8,
                              size=9,
                              data=record_all_df_perf_filtered, ax=axd[plot_key])

                # One-sided paired t-test per method (alternative: its AUC is
                # lower than the reference method's). scipy pairs samples by
                # position and ignores the pandas index, so the runs are first
                # aligned on random_seed; seeds missing for either side are dropped.
                auc_by_seed = record_all_df_perf_filtered.pivot(
                    index="random_seed", columns="method", values="auc")
                for method_name in main_method_list[::-1]:
                    if method_name == reference_method:
                        continue
                    paired_auc = auc_by_seed[[method_name, reference_method]].dropna()
                    pvalue = scipy.stats.ttest_rel(paired_auc[method_name],
                                                   paired_auc[reference_method],
                                                   alternative="less").pvalue
                    print(method_name, pvalue)
                    if pvalue < 0.001:
                        pvalue_str = "***"
                    elif pvalue < 0.01:
                        pvalue_str = "**"
                    elif pvalue < 0.05:
                        pvalue_str = "*"
                    else:
                        pvalue_str = "ns"

                    # Marker sits just above this method's highest AUC.
                    pvalue_x = main_method_list.index(method_name)
                    pvalue_y = record_all_df_perf_filtered.loc[
                        record_all_df_perf_filtered["method"] == method_name, "auc"].max()
                    axd[plot_key].text(pvalue_x,
                                       pvalue_y + 0.0035,
                                       s=pvalue_str,
                                       fontsize=25,
                                       ha="center",
                                       va="bottom",
                                       color="k")

                print(task, record_all_df_perf_filtered.groupby("method")["auc"].apply(lambda x: {"mean": x.mean(),
                                                                      "std": x.std(),
                                                                      "q3": x.quantile(q=0.75),
                                                                      "median": x.median(),
                                                                      "q1": x.quantile(q=0.25),
                                                                     }))

                axd[plot_key].set_title(task.capitalize(), fontsize=30)

                for axis in ["top", "bottom", "left", "right"]:
                    axd[plot_key].spines[axis].set_linewidth(3)

                if idx2 == 1:
                    axd[plot_key].set_ylabel("Area under the ROC curve", fontsize=25)
                if idx2 == 3:
                    axd[plot_key].set_ylabel("")

                axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.05))
                axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.01))
                axd[plot_key].yaxis.grid(True, which="major", linewidth=2, alpha=0.6)
                axd[plot_key].yaxis.grid(True, which="minor", linewidth=1, alpha=0.1)

                axd[plot_key].spines["right"].set_visible(False)
                axd[plot_key].spines["top"].set_visible(False)
                axd[plot_key].spines["bottom"].set_visible(False)

                axd[plot_key].tick_params(axis="x", which="major", labelsize=25)
                axd[plot_key].tick_params(axis="y", which="major", labelsize=25)
                # Method names are conveyed by the legend, so hide x ticks and labels.
                axd[plot_key].tick_params(axis="x", which="both", bottom=False, labelbottom=False)

                axd[plot_key].set_xlabel(None)

            if task == "empty_malignancy":
                axd[plot_key].text(x=-0.05, y=1.02, transform=axd[plot_key].transAxes,
                                   s="F", fontsize=35, weight="bold")

            if task == "malignancy":
                legend_elements = [Patch(facecolor=prop_colors[method_idx],
                                         edgecolor="black", linewidth=2,
                                         label=shorten_method_name(method_name))
                                   for method_idx, method_name in enumerate(main_method_list)]
                axd[plot_key].legend(handles=legend_elements,
                                     ncol=5,
                                     handlelength=2.5,
                                     handletextpad=0.4,
                                     columnspacing=1.3,
                                     fontsize=23,
                                     loc="lower center", bbox_to_anchor=(1., -0.1))

            if task == "empty_melanoma":
                axd[plot_key].text(x=-0.0, y=1.02, transform=axd[plot_key].transAxes,
                                   s="G", fontsize=35, weight="bold")

    elif stage == "weight":
        # Coefficients of the MONET+CBM linear model per concept, across seeds.
        for idx2, task in enumerate(["empty_malignancy", "malignancy", "empty_melanoma", "melanoma"]):
            plot_key = f"{stage}_{task}"
            if task in ("malignancy", "melanoma"):
                record_all_df_weight_filtered = record_all_df_base[
                    (record_all_df_base["method"] == "automatic_curated_monet_full")
                    & (record_all_df_base["alpha"] == alpha)
                    & (record_all_df_base["temp"] == temp)]
                record_all_df_weight_filtered = record_all_df_weight_filtered[
                    record_all_df_weight_filtered["task"] == task]

                # One row per (run, concept); coef_[0] is paired positionally
                # with concept_list_curated.
                coef_dict_list = [{"concept_name": concept_name,
                                   "coef": coef,
                                   "clf_idx": clf_idx}
                                  for clf_idx, clf in enumerate(record_all_df_weight_filtered["clf"])
                                  for concept_name, coef in zip(concept_list_curated, clf.coef_[0])]
                coef_df = pd.DataFrame(coef_dict_list)
                print(task, coef_df.groupby("concept_name")["coef"].mean())

                weight_bar = sns.barplot(y="concept_name", x="coef",
                                         color=np.array(Paired[12][7]) / 256,
                                         edgecolor="black",
                                         linewidth=2,
                                         width=0.5,
                                         order=weight_concept_order,
                                         errwidth=5,
                                         data=coef_df, ax=axd[plot_key])

                # The x-axis is limited to [-3, 3] below; concepts whose mean
                # coefficient exceeds 3 get their mean (±1.96 SE) written next
                # to the clipped bar.
                for p, concept_name in zip(weight_bar.patches, weight_concept_order):
                    _y = p.get_y() + p.get_height()
                    concept_coef = coef_df.loc[coef_df["concept_name"] == concept_name, "coef"]
                    if concept_coef.mean() > 3:
                        value = f"{concept_coef.mean():.2f} (±{1.96*concept_coef.std()/ np.sqrt(len(concept_coef)) :.2f})"
                        axd[plot_key].text(2.8, _y + 0.6, value, ha="center", fontsize=20, zorder=100)

            if task == "empty_malignancy":
                axd[plot_key].text(x=-0., y=1.05, transform=axd[plot_key].transAxes,
                                   s="H", fontsize=35, weight="bold")

            if task == "empty_melanoma":
                axd[plot_key].text(x=-0.1, y=1.05, transform=axd[plot_key].transAxes,
                                   s="I", fontsize=35, weight="bold")

            if task in ("malignancy", "melanoma"):
                axd[plot_key].set_title(task.capitalize(), fontsize=30, pad=20)
                axd[plot_key].set_xlim(-3, 3)
                # Zero line separating negative and positive coefficients.
                axd[plot_key].axvline(x=0, ymin=0, ymax=1, color="black", alpha=0.7, linewidth=5, zorder=-5)

                for axis in ["top", "bottom", "left", "right"]:
                    axd[plot_key].spines[axis].set_linewidth(3)

                axd[plot_key].xaxis.set_major_locator(MultipleLocator(0.5))
                axd[plot_key].xaxis.set_minor_locator(MultipleLocator(0.1))
                axd[plot_key].xaxis.grid(True, which="major", linewidth=2, alpha=0.6)
                axd[plot_key].xaxis.grid(True, which="minor", linewidth=1, alpha=0.1)
                axd[plot_key].yaxis.grid(True, which="major", linewidth=2, alpha=0.6)

                axd[plot_key].spines["left"].set_visible(False)
                axd[plot_key].spines["right"].set_visible(False)
                axd[plot_key].spines["top"].set_visible(False)

                axd[plot_key].tick_params(axis="x", which="major", labelsize=20)
                axd[plot_key].tick_params(axis="y", which="major", labelsize=20)
                # x ticks and labels only along the bottom edge.
                axd[plot_key].tick_params(axis="x", which="both",
                                          top=False, labeltop=False,
                                          bottom=True, labelbottom=True)

                axd[plot_key].set_ylabel(None)
                axd[plot_key].set_xlabel("Coefficients of linear model", fontsize=25, labelpad=5)

    elif stage == "skincon":
        # Manual SkinCon labels with fewer concepts / fewer samples, vs. the
        # automatic_skincon_monet_full reference (drawn as a single "X").
        for idx2, variable in enumerate(["malignancy_num_concept", "malignancy_num_sample",
                                         "melanoma_num_concept", "melanoma_num_sample", ]):
            plot_key = f"{stage}_{variable}"
            skincon_task = variable.split("_", 1)[0]  # "malignancy" or "melanoma"

            record_all_df_skincon_filtered = record_all_df_base[
                record_all_df_base["method"].str.contains("skincon")]
            record_all_df_skincon_filtered = record_all_df_skincon_filtered[
                record_all_df_skincon_filtered["task"] == skincon_task]

            record_all_df_skincon_filtered_ref = record_all_df_skincon_filtered[
                record_all_df_skincon_filtered["method"] == "automatic_skincon_monet_full"]

            if variable.endswith("num_sample"):
                # .copy(): a new column is added below, so work on an
                # independent frame rather than a slice (SettingWithCopyWarning).
                record_all_df_skincon_filtered_obs = record_all_df_skincon_filtered[
                    record_all_df_skincon_filtered["method"] == "skincon_manual_less_sample"].copy()
                # x position of each requested subsample setting = mean number
                # of training samples actually used by runs with that setting.
                record_all_df_skincon_filtered_obs_sample_prop = record_all_df_skincon_filtered_obs.groupby(
                    "num_sample_train_select")["num_sample_train"].mean()
                record_all_df_skincon_filtered_obs["num_sample_train_pseudo"] = \
                    record_all_df_skincon_filtered_obs["num_sample_train_select"].map(
                        record_all_df_skincon_filtered_obs_sample_prop)
                b = sns.lineplot(x="num_sample_train_pseudo", y="auc",
                                 color=prop_colors[1], linewidth=4,
                                 data=record_all_df_skincon_filtered_obs, ax=axd[plot_key])
                if b.legend_ is not None:
                    b.legend_.remove()

                # Reference mean AUC, placed at x=0 on the expert-labeled-sample axis.
                axd[plot_key].scatter(0, record_all_df_skincon_filtered_ref["auc"].mean(),
                                      s=300, marker="X", color=prop_colors[0])
                print(stage, variable, record_all_df_skincon_filtered_ref["auc"].mean())

            elif variable.endswith("num_concept"):
                record_all_df_skincon_filtered_obs = record_all_df_skincon_filtered[
                    record_all_df_skincon_filtered["method"] == "skincon_manual_less_concept"]
                b = sns.lineplot(x="num_concept", y="auc",
                                 color=prop_colors[1], linewidth=4,
                                 data=record_all_df_skincon_filtered_obs, ax=axd[plot_key])
                if b.legend_ is not None:
                    b.legend_.remove()

                # Reference mean AUC at x=48 (presumably the full SkinCon
                # concept count — confirm).
                axd[plot_key].scatter(48, record_all_df_skincon_filtered_ref["auc"].mean(),
                                      s=300, marker="X", color=prop_colors[0])
                print(stage, variable, record_all_df_skincon_filtered_ref["auc"].mean())

            axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.05))
            axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.01))
            axd[plot_key].yaxis.grid(True, which="major", linewidth=2, alpha=0.6)
            axd[plot_key].yaxis.grid(True, which="minor", linewidth=1, alpha=0.1)
            axd[plot_key].tick_params(axis="y", which="major", labelsize=20)

            if variable.endswith("num_concept"):
                axd[plot_key].xaxis.set_major_locator(MultipleLocator(10))
                axd[plot_key].xaxis.set_minor_locator(MultipleLocator(5))
                axd[plot_key].xaxis.grid(True, which="major", linewidth=2, alpha=0.6)
                axd[plot_key].xaxis.grid(True, which="minor", linewidth=1, alpha=0.1)
                axd[plot_key].tick_params(axis="x", which="major", labelsize=20)
                axd[plot_key].set_xlabel("Num. of concepts", fontsize=25)

            elif variable.endswith("num_sample"):
                # Coarser ticks for malignancy, finer for the smaller melanoma
                # sample counts. (The second branch previously re-tested
                # "malignancy", so melanoma never got its locators.)
                if variable.startswith("malignancy"):
                    axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
                    axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))
                elif variable.startswith("melanoma"):
                    axd[plot_key].xaxis.set_major_locator(MultipleLocator(100))
                    axd[plot_key].xaxis.set_minor_locator(MultipleLocator(50))

                axd[plot_key].xaxis.grid(True, which="major", linewidth=2, alpha=0.6)
                axd[plot_key].xaxis.grid(True, which="minor", linewidth=1, alpha=0.1)
                axd[plot_key].tick_params(axis="x", which="major", labelsize=20)
                axd[plot_key].set_xlabel("Num. of expert-labeled samples", fontsize=25)

            if idx2 == 0:
                axd[plot_key].set_ylabel("Area under the ROC curve", fontsize=25)
            else:
                axd[plot_key].set_ylabel(None)

            axd[plot_key].set_ylim(0.49, 1.01)
            axd[plot_key].set_title(skincon_task.capitalize(), fontsize=25, pad=20)

            if variable.endswith("num_sample"):
                if variable.startswith("malignancy"):
                    axd[plot_key].set_xlim(left=-50)
                elif variable.startswith("melanoma"):
                    axd[plot_key].set_xlim(left=-10)

            if variable.endswith("num_concept"):
                axd[plot_key].set_xlim(-1, 50)

            for axis in ["top", "bottom", "left", "right"]:
                axd[plot_key].spines[axis].set_linewidth(3)
            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 
            
            axd[plot_key].text(x=-0.15, y=1.05, transform=axd[plot_key].transAxes,
                                     s=["B", "C", "D", "E"][idx2], fontsize=35, weight='bold')  
            
            
            
            
            if idx2==1:
                legend_elements=[Line2D([], [], color=plt.rcParams["axes.prop_cycle"].by_key()["color"][0], label="MONET+CBM (SkinCon)", linestyle='None', marker='X', markersize=20),
                                 Line2D([0], [0], color=plt.rcParams["axes.prop_cycle"].by_key()["color"][1], linewidth=10, label="Manual Label (SkinCon)")]
                axd[plot_key].legend(handles=legend_elements, 
                            ncol=2, 
                            handlelength=3,
                            handletextpad=0.6, 
                            columnspacing=1.5,
                            fontsize=23,
                            loc='lower center', 
                            bbox_to_anchor=(1, -0.33)).set_zorder(100)              
            
            #record_all_df_filtered=record_all_df_filtered[record_all_df_filtered["method"]=="automatic_monet_full"]    

# fig.savefig(log_dir/"plots"/f"main_cbm_a_{alpha:.1e}_t_{temp:.1e}.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/f"main_cbm_a_{alpha:.1e}_t_{temp:.1e}.jpg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/f"main_cbm_a_{alpha:.1e}_t_{temp:.1e}.svg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/f"main_cbm_a_{alpha:.1e}_t_{temp:.1e}.pdf", bbox_inches='tight')

# fig.savefig(log_dir/"plots"/"main_cbm.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"main_cbm.jpg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"main_cbm.svg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"main_cbm.pdf", bbox_inches='tight')
# # plt.close(fig)

In [None]:
# Two paired toy samples of ten integer scores each.
x1 = np.asarray([65, 75, 86, 69, 60, 81, 88, 53, 75, 73])
x2 = np.asarray([77, 98, 92, 77, 65, 77, 100, 73, 93, 75])

In [None]:
# Paired standardized mean difference: mean of the element-wise differences
# divided by their sample standard deviation (ddof=1).
_paired_diff = x1 - x2
_paired_diff.mean() / _paired_diff.std(ddof=1)

In [None]:
# Bind `scipy` here: the `import scipy` cell appears later in the notebook and
# the file header only imports names from scipy submodules.
import scipy.stats


def _paired_comparison(method_df, reference_df):
    """Compare one method's per-seed AUCs with the reference method's.

    Both AUC series are indexed by ``random_seed`` and inner-aligned before
    any statistic is computed: ``scipy.stats.ttest_rel`` pairs observations
    by position (it ignores the pandas index), so without alignment the t-test
    could pair different seeds while the difference below is index-aligned.

    Returns a tuple of
      - the one-sided paired t-test p-value for "method AUC < reference AUC",
      - mean(diff) / std(diff, ddof=1), a paired standardized mean difference,
      - mean(diff),
    where diff = method AUC - reference AUC per shared seed.
    """
    method_auc = method_df.set_index("random_seed")["auc"]
    reference_auc = reference_df.set_index("random_seed")["auc"]
    method_auc, reference_auc = method_auc.align(reference_auc, join="inner")
    diff = method_auc - reference_auc
    pvalue = scipy.stats.ttest_rel(method_auc, reference_auc, alternative="less").pvalue
    return (pvalue, np.mean(diff) / np.std(diff, ddof=1), np.mean(diff))


# For each task, compare every method against automatic_curated_monet_full;
# after .T rows are methods and columns are tasks.
record_all_df_perf_filtered_pvalue = (
    record_all_df_perf_filtered.groupby("task")
    .apply(
        lambda task_df: task_df.groupby("method").apply(
            lambda method_df: _paired_comparison(
                method_df,
                task_df[task_df["method"] == "automatic_curated_monet_full"],
            )
        )
    )
    .T
)


In [None]:
# Look up the automatic_curated_vanilla_full / melanoma entry of the p-value
# table (a tuple: p-value, standardized mean difference, mean difference).
record_all_df_perf_filtered_pvalue.loc["automatic_curated_vanilla_full"]["melanoma"]

In [None]:
# Export the current figure as main_cbm.{png,jpg,svg,pdf} under log_dir/plots.
for _ext in ("png", "jpg", "svg", "pdf"):
    fig.savefig(log_dir / "plots" / f"main_cbm.{_ext}", bbox_inches='tight')

In [None]:
import scipy

In [None]:
# fig.savefig(log_dir/"plots"/"main_cbm.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"main_cbm.jpg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"main_cbm.svg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"main_cbm.pdf", bbox_inches='tight')

In [None]:
# Inspect `aa`. NOTE(review): not defined anywhere visible in this part of the
# notebook -- likely a leftover; raises NameError if undefined.
aa

In [None]:
# Inspect the global `task` (assigned elsewhere in the notebook).
task

In [None]:
# (rows, columns) of the filtered performance table.
record_all_df_perf_filtered.shape

In [None]:
# Display the filtered performance table.
record_all_df_perf_filtered

In [None]:
# Rebuild the per-run results table from record_all_list (presumably a list of
# per-run dicts, one row each -- confirm against the producing cell).
record_all_df_perf=pd.DataFrame(record_all_list)

In [None]:
# Mean of every numeric column per (task, method).
# numeric_only=True is required: the table also has string columns (e.g.
# "is_clean" == "clean_only"), and pandas >= 2.0 raises TypeError when asked to
# average them instead of silently dropping them as older versions did.
record_all_df_perf.groupby(["task", "method"]).mean(numeric_only=True)

In [None]:
# Distinct concept names among the recorded coefficients.
pd.DataFrame(coef_dict_list)["concept_name"].unique()

In [None]:
# Export the current figure as main_cbm.{png,jpg,svg,pdf} under log_dir/plots.
for _ext in ("png", "jpg", "svg", "pdf"):
    fig.savefig(log_dir / "plots" / f"main_cbm.{_ext}", bbox_inches='tight')

In [None]:
import scipy

In [None]:
# Display the filtered performance table.
record_all_df_perf_filtered

In [None]:
# Display the filtered performance table.
record_all_df_perf_filtered

In [None]:
# record_all_df_perf_filtered.groupby(["task","random_seed"]).apply(lambda x: 
# ((x[x["method"]=="resnet"]["auc"].iloc[0]-x[x["method"]=="automatic_curated_monet_full"]["auc"].iloc[0])>0)
# )

# Rebuild the performance table and restrict it to the main comparison:
# seeds < 20, clean_only rows, methods in main_method_list, and -- for the two
# automatic_curated_* methods only -- the selected alpha and temp.
_curated_methods = ["automatic_curated_vanilla_full", "automatic_curated_monet_full"]

record_all_df_perf = pd.DataFrame(record_all_list)
record_all_df_perf = record_all_df_perf[record_all_df_perf["random_seed"] < 20]

_is_curated = record_all_df_perf["method"].isin(_curated_methods)
_row_mask = (
    (record_all_df_perf["is_clean"] == "clean_only")
    & record_all_df_perf["method"].isin(main_method_list)
    & (~_is_curated | ((record_all_df_perf["alpha"] == alpha) & (record_all_df_perf["temp"] == temp)))
)
record_all_df_perf_filtered = record_all_df_perf[_row_mask]

record_all_df_perf_filtered = record_all_df_perf_filtered[record_all_df_perf_filtered["task"] == "malignancy"]
# record_all_df_perf_filtered=record_all_df_perf_filtered[record_all_df_perf_filtered["task"]=="melanoma"]


def _count_seeds_beating_reference(task_df):
    """Per method, count seeds whose AUC is strictly above the reference's on the same seed.

    The reference is automatic_curated_monet_full; AUCs are matched by
    random_seed through index alignment of the subtraction.
    """
    reference_auc = task_df.loc[task_df["method"] == "automatic_curated_monet_full"].set_index("random_seed")["auc"]
    return task_df.groupby("method").apply(
        lambda method_df: (method_df.set_index("random_seed")["auc"].sub(reference_auc) > 0).sum()
    )


record_all_df_perf_filtered.groupby("task").apply(_count_seeds_beating_reference).T

In [None]:
# record_all_df_perf_filtered.groupby(["task","random_seed"]).apply(lambda x: 
# ((x[x["method"]=="resnet"]["auc"].iloc[0]-x[x["method"]=="automatic_curated_monet_full"]["auc"].iloc[0])>0)
# )

# Bind `scipy` here: the `import scipy` cell appears later in the notebook.
import scipy.stats


def _paired_auc(method_df, reference_df):
    """Return (method AUC, reference AUC) as Series inner-aligned on random_seed.

    scipy.stats.ttest_rel pairs observations by position and ignores the pandas
    index, so the two series must share the same seed order before testing.
    """
    method_auc = method_df.set_index("random_seed")["auc"]
    reference_auc = reference_df.set_index("random_seed")["auc"]
    return method_auc.align(reference_auc, join="inner")


# Per task and method: 1 - p of the one-sided paired t-test "method AUC >
# automatic_curated_monet_full AUC" (for a t-test this equals the p-value of
# the opposite alternative, "less").
record_all_df_perf_filtered.groupby("task").apply(
    lambda task_df: task_df.groupby("method").apply(
        lambda method_df: 1 - scipy.stats.ttest_rel(
            *_paired_auc(
                method_df,
                task_df[task_df["method"] == "automatic_curated_monet_full"],
            ),
            alternative="greater",
        ).pvalue
    )
).T

In [None]:
# record_all_df_perf_filtered.groupby(["task","random_seed"]).apply(lambda x: 
# ((x[x["method"]=="resnet"]["auc"].iloc[0]-x[x["method"]=="automatic_curated_monet_full"]["auc"].iloc[0])>0)
# )

# Bind `scipy` here: the `import scipy` cell appears later in the notebook.
import scipy.stats


def _paired_auc(method_df, reference_df):
    """Return (method AUC, reference AUC) as Series inner-aligned on random_seed.

    scipy.stats.ttest_rel pairs observations by position and ignores the pandas
    index, so the two series must share the same seed order before testing.
    """
    method_auc = method_df.set_index("random_seed")["auc"]
    reference_auc = reference_df.set_index("random_seed")["auc"]
    return method_auc.align(reference_auc, join="inner")


# Per task and method: p-value of the one-sided paired t-test "method AUC <
# automatic_curated_monet_full AUC".
record_all_df_perf_filtered.groupby("task").apply(
    lambda task_df: task_df.groupby("method").apply(
        lambda method_df: scipy.stats.ttest_rel(
            *_paired_auc(
                method_df,
                task_df[task_df["method"] == "automatic_curated_monet_full"],
            ),
            alternative="less",
        ).pvalue
    )
).T

In [None]:
from scipy.stats import wilcoxon

In [None]:
# Distinct methods present in the filtered skincon table.
record_all_df_skincon_filtered["method"].unique()

In [None]:
# Scatter plot of AUC against num_sample for record_all_df_skincon_filtered_obs.
record_all_df_skincon_filtered_obs.plot.scatter(x="num_sample", y="auc")

In [None]:
# Scatter plot of AUC against sample_prop for record_all_df_skincon_filtered_obs.
record_all_df_skincon_filtered_obs.plot.scatter(x="sample_prop", y="auc")

In [None]:
# Scatter plot of AUC against num_sample_train for record_all_df_skincon_filtered_obs.
record_all_df_skincon_filtered_obs.plot.scatter(x="num_sample_train", y="auc")

In [None]:
# Display the reference-method rows of the skincon table.
record_all_df_skincon_filtered_ref

In [None]:
#                 record_all_df_skincon_filtered_ref_matched=pd.DataFrame(list(itertools.product(record_all_df_skincon_filtered_obs_sample_prop.values, 
#                                                     record_all_df_skincon_filtered_ref["auc"].values)),
#                             columns=["num_sample_train_pseudo", "auc"])                
#                 b=sns.lineplot(x="num_sample_train_pseudo", y="auc", 
#                                color=plt.rcParams["axes.prop_cycle"].by_key()["color"][0], linewidth=4, linestyle="--",
#                                data=record_all_df_skincon_filtered_ref_matched, ax=axd[plot_key])
#                 if b.legend_ is not None:
#                     b.legend_.remove()

In [None]:
# Display the filtered performance table.
record_all_df_perf_filtered

In [None]:
# Number of rows (runs) per method in the filtered performance table.
record_all_df_perf_filtered.groupby(["method"]).apply(lambda x: len(x))

In [None]:
# Display the current list of main comparison methods.
main_method_list

In [None]:
import itertools

In [None]:
def _auc_summary(auc):
    """Summarize one method's AUC values: mean, std, upper quartile, median, lower quartile."""
    return {
        "mean": auc.mean(),
        "std": auc.std(),
        "q3": auc.quantile(q=0.75),
        "median": auc.median(),
        "q1": auc.quantile(q=0.25),
    }


# Per-method AUC summary statistics.
record_all_df_perf_filtered.groupby("method")["auc"].apply(_auc_summary)

In [None]:
# Save per-concept summary statistics (count, mean, std, min, quartiles, max)
# of the recorded coefficients. A SeriesGroupBy has no `to_csv` method
# (calling it raises AttributeError), so the grouped coefficients are first
# aggregated into a DataFrame with `.describe()`.
pd.DataFrame(coef_dict_list).groupby("concept_name")["coef"].describe().to_csv(
    log_dir/"plots"/f"main_cbm_a_{alpha:.1e}_t_{temp:.1e}_{task}.csv"
)




In [None]:
# Mean AUC per method in the weight-filtered table.
record_all_df_weight_filtered.groupby("method")["auc"].mean()

In [None]:
# Display the raw per-run records.
record_all_list

In [None]:
# Display the raw coefficient records.
coef_dict_list

In [None]:
# Display the filtered performance table.
record_all_df_perf_filtered

In [None]:
import itertools

In [None]:
# Rows of the skincon table whose method name contains "skincon".
record_all_df_skincon_filtered[record_all_df_skincon_filtered["method"].str.contains("skincon")]

In [None]:
# Inspect the currently selected temperature value.
temp

In [None]:
# Preview a plot file name built from alpha and temp. NOTE(review): this uses
# "a=" / "t=" separators, whereas the commented savefig calls use "_a_" / "_t_".
f"main_cbm_a={alpha:.1e}_t={temp:.1e}.png"

In [None]:
# Row mask: keep every method other than automatic_curated_monet_full, and keep
# automatic_curated_monet_full only at the selected alpha.
# The two conditions must be joined with `|`: joining them with `&` requires a
# row to be both not-equal and equal to the same method, so the mask was
# always all-False.
(
    (record_all_df_filtered["method"] != "automatic_curated_monet_full")
    | (
        (record_all_df_filtered["method"] == "automatic_curated_monet_full")
        & (record_all_df_filtered["alpha"] == alpha)
    )
)


In [None]:
# Display the current list of main comparison methods.
main_method_list

In [None]:
"automatic_curated_monet_full"

In [None]:
# Distinct temp values in record_all_df_filtered; the comments below record
# previously observed outputs for temp and alpha.
record_all_df_filtered["temp"].unique()
# temp array([  nan, 0.02 , 0.01 , 0.005])
# alpha array([0.001 , 0.0001,    nan])

In [None]:
# Distinct alpha values in record_all_df_filtered.
record_all_df_filtered["alpha"].unique()

In [None]:
# Ablation figure for the automatic_curated_monet_* runs: panel A plots AUC
# against the number of concepts, panel B against the proportion of training
# data, with one seaborn line per task (hue="task").

# Colour cycle: odd-indexed entries of the 12-colour Paired palette, scaled for
# matplotlib. NOTE(review): divides by 256 rather than 255, so colours come out
# very slightly darker than the palette values.
plt.rcParams["axes.prop_cycle"]=cycler('color', [np.array(i)/256 for i in [Paired[12][1], 
                                                                                Paired[12][3],
                                                                                Paired[12][5],
                                                                                Paired[12][7],
                                                                                Paired[12][9],
                                                                                Paired[12][11]
                                                                                ]])

# Overwrites the global main_method_list read by other cells; it is not used
# further in this cell.
main_method_list=["automatic_monet_full", 
                  "skincon_manual", 
                  "resnet", 
                  "resnet_freeze_backbone", 
                  "automatic_vanilla_full"]


# Wide single-row figure; the 1x1 outer GridSpec is split into two columns below.
fig = plt.figure(figsize=(3*10, 3*(2.5) + 0.25*3))

box1 = gridspec.GridSpec(1, 1,
                         height_ratios=[2.5],
                         wspace=0.0,
                         hspace=0.25)

# Pass 1: create one Axes per (stage, variable) panel, keyed "{stage}_{variable}" in axd.
axd={}
for idx1, stage in enumerate(["ablation"]):
            
    if stage=="ablation":
        box2 = gridspec.GridSpecFromSubplotSpec(1, 2,
                        subplot_spec=box1[idx1], width_ratios=[1, 1], wspace=0.2, hspace=0.)        
        for idx2, variable in enumerate(["num_concept", "num_samples"]):
#             elif investigation_type=="statistics":
            plot_key=f"{stage}_{variable}"
            ax=plt.Subplot(fig, box2[idx2])
            fig.add_subplot(ax)
            axd[plot_key]=ax    
            
# Pass 2: draw and style each panel.
for idx1, stage in enumerate(["ablation"]):            
    if stage=="ablation":
  
        for idx2, variable in enumerate(["num_concept",  "num_samples"]):
#             elif investigation_type=="statistics":
            plot_key=f"{stage}_{variable}"    
    
            # Rebuilt from record_all_list on every iteration, keeping only
            # clean_only rows; this reassigns the global record_all_df_filtered.
            record_all_df=pd.DataFrame(record_all_list)
            record_all_df_filtered=record_all_df[record_all_df["is_clean"]=="clean_only"]
#             sdsd
            if variable=="num_concept":
                # AUC vs number of concepts for the reduced-concept runs.
                record_all_df_filtered=record_all_df_filtered[record_all_df_filtered["method"]=="automatic_curated_monet_less_concept"]
                b=sns.lineplot(x="num_concept", y="auc", hue="task", data=record_all_df_filtered, ax=axd[plot_key])
                b.legend_.remove()
                
                axd[plot_key].set_xlabel("Num. of concepts", fontsize=30)
                axd[plot_key].set_xlim(1-0.2, 11.5)
            elif variable=="num_reference":
                # NOTE(review): unreachable here -- "num_reference" is not in
                # the list of variables iterated above.
                record_all_df_filtered=record_all_df_filtered[record_all_df_filtered["method"]=="automatic_curated_monet_less_reference"]
                b=sns.lineplot(x="num_ref_concepts", y="auc", hue="task", data=record_all_df_filtered, ax=axd[plot_key])
                b.legend_.remove()
                
                axd[plot_key].set_xlabel("Num. of reference concepts", fontsize=30)
                axd[plot_key].set_xlim(-0.2, 5.1)
                
            elif variable=="num_samples":
                # AUC vs proportion of training data for the reduced-sample runs.
                record_all_df_filtered=record_all_df_filtered[record_all_df_filtered["method"]=="automatic_curated_monet_less_sample"]                
                b=sns.lineplot(x="sample_prop", y="auc", hue="task", data=record_all_df_filtered, ax=axd[plot_key])
                b.legend_.remove()
                
                axd[plot_key].set_xlabel("Proportion of training data", fontsize=30)
            else:
                raise ValueError
                
            axd[plot_key].set_ylim(0.61, 0.94)
            
                
            # y axis: major ticks every 0.05 and minor ticks every 0.01, each with grid lines.
            axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.05))
            axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.01))            
            axd[plot_key].yaxis.grid(True, which='major', linewidth=2, alpha=0.6)
            axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)    
            axd[plot_key].tick_params(axis='y', which='major', labelsize=25)
            
            # x-axis ticks and grid, per panel type.
            if variable=="num_concept":
                axd[plot_key].xaxis.set_major_locator(MultipleLocator(2))
                axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
                axd[plot_key].tick_params(axis='x', which='major', labelsize=25)                
                
            elif variable=="num_reference":
                axd[plot_key].xaxis.set_major_locator(MultipleLocator(1))
                axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
                axd[plot_key].tick_params(axis='x', which='major', labelsize=25)
                       
            elif variable=="num_samples":
                #axd[plot_key].set_xticks([0.05, 0.1 , 0.2 , 0.4 , 0.6 , 0.8 , 1.])
                #.set_xticks([2,4,6,8,10])
                axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
                axd[plot_key].tick_params(axis='x', which='major', labelsize=30)
            
            # Both branches currently set the same y label; the commented-out
            # line would instead hide it on the second panel.
            if idx2==0:
                axd[plot_key].set_ylabel("Area under the ROC curve", fontsize=30)
            else:
                #axd[plot_key].set_ylabel(None)    
                axd[plot_key].set_ylabel("Area under the ROC curve", fontsize=30)
            
                
                #axd[plot_key].tick_params(axis='x', which='major', left=False, labelleft=False)
                
            # Thicker spines, with the top and right ones hidden.
            for axis in ['top','bottom','left','right']:
                axd[plot_key].spines[axis].set_linewidth(3)                     
            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 
            
            # Panel letter ("A", "B") at the top-left, just outside the axes.
            axd[plot_key].text(x=-0.15, y=1.0, transform=axd[plot_key].transAxes,
                                     s=["A", "B"][idx2], fontsize=35, weight='bold')               
            
            # Shared task legend placed under the second panel.
            # NOTE(review): colours are hard-coded to Paired[12][1]/[3] (the
            # first two prop-cycle colours), which assumes seaborn assigns hue
            # colours to tasks in the order malignancy, melanoma -- TODO confirm.
            if idx2==1:
                legend_elements=[Line2D([0], [0], color=np.array(Paired[12][1])/256, linewidth=10, label="Malignancy"),
                                 Line2D([0], [0], color=np.array(Paired[12][3])/256, linewidth=10, label="Melanoma")]
                axd[plot_key].legend(handles=legend_elements, 
                            ncol=2, 
                            handlelength=3,
                            handletextpad=0.6, 
                            columnspacing=1.5,
                            fontsize=30,
                            loc='lower center', bbox_to_anchor=(-0.15, -0.3))              
            
# fig.savefig(log_dir/"plots"/"main_cbm_ablation.png", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"main_cbm_ablation.jpg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"main_cbm_ablation.svg", bbox_inches='tight')
# fig.savefig(log_dir/"plots"/"main_cbm_ablation.pdf", bbox_inches='tight')
#plt.close(fig)               

In [None]:
# Distinct methods in record_all_df_filtered. NOTE: if run right after the
# ablation-figure cell, this holds only the last panel's filtered rows.
record_all_df_filtered["method"].unique()

In [None]:
import matplotlib.pyplot as plt

# Scratch demo: draw two triangles as closed outlines (markers + edges), each
# filled. Each vertex list repeats its first point so the outline closes.
# Previously the first triangle was drawn three times on top of itself and the
# second triangle's coordinates were defined but never used.
trianglex = [ 1, 10, 7, 1 ] 
triangley = [ 2, 8, 4, 2 ]    
triangle2x = [ 13, 25, 21, 13]
triangle2y = [ 5,  7 , 14, 5 ]

plt.figure('Triangles')
for xs, ys in ((trianglex, triangley), (triangle2x, triangle2y)):
    plt.plot(xs, ys, 'o-')
    plt.fill(xs, ys)


plt.show()

In [None]:
# Inspect the inner GridSpec created by the ablation-figure cell.
box2

In [None]:
        # Label each bar patch on `ax` with its height (two decimals), centred horizontally on the bar's top edge.
# The loop is at column 0: the previous 8-space indentation at cell top level
# was an IndentationError when the cell was run.
for p in ax.patches:
    _x = p.get_x() + p.get_width() / 2
    _y = p.get_y() + p.get_height()
    value = '{:.2f}'.format(p.get_height())
    ax.text(_x, _y, value, ha="center")

In [None]:
# Print the height of every bar in weight_bar (one line per bar).
for rect in weight_bar:
    height = rect.get_height()
    print(height)
    #plt.text(rect.get_x() + rect.get_width() / 2.0, height, f'{height:.0f}', ha='center', va='bottom')


In [None]:
# Inspect the weight_bar container.
weight_bar

In [None]:
# Inspect the weight_bar container.
weight_bar

In [None]:
# First colour of the current matplotlib colour cycle.
plt.rcParams["axes.prop_cycle"].by_key()["color"][0]

In [None]:
# Reference rows: the automatic_skincon_monet_full method.
record_all_df_skincon_filtered_ref=record_all_df_skincon_filtered[record_all_df_skincon_filtered["method"]=="automatic_skincon_monet_full"]
if variable.endswith("num_sample"):
    # .copy() so that the column added below goes into an independent frame
    # rather than a slice of record_all_df_skincon_filtered (which triggers
    # pandas' SettingWithCopyWarning and may not write through).
    record_all_df_skincon_filtered_obs = record_all_df_skincon_filtered[
        record_all_df_skincon_filtered["method"] == "skincon_manual_less_sample"
    ].copy()
    # Mean number of training samples actually used at each sample proportion.
    record_all_df_skincon_filtered_obs_sample_prop = (
        record_all_df_skincon_filtered_obs.groupby("sample_prop")["num_sample_train"].mean()
    )
    # Place every run at the mean training-set size of its sample_prop, so runs
    # sharing a proportion share one x position. Series.map does the per-row
    # lookup in one vectorised step and also works on an empty frame.
    record_all_df_skincon_filtered_obs["num_sample_train_pseudo"] = record_all_df_skincon_filtered_obs[
        "sample_prop"
    ].map(record_all_df_skincon_filtered_obs_sample_prop)
    b=sns.lineplot(x="num_sample_train_pseudo", y="auc", style="method", data=record_all_df_skincon_filtered_obs, ax=axd[plot_key])


In [None]:
import itertools

In [None]:
# Mean num_sample_train per sample_prop.
record_all_df_skincon_filtered_obs_sample_prop

In [None]:
# AUC values of the reference (automatic_skincon_monet_full) runs.
record_all_df_skincon_filtered_ref["auc"]

In [None]:
# The sample_prop values that index the mean-training-size Series.
record_all_df_skincon_filtered_obs_sample_prop.index