In [None]:
import os

import hydra
import omegaconf
import pyrootutils

# Resolve the project root starting from this notebook's absolute path.
# `pythonpath=True` presumably also puts the root on sys.path — confirm against
# the pyrootutils documentation.
root = pyrootutils.setup_root(os.path.abspath("revision_analysis.ipynb"), pythonpath=True)

# Make the project root the working directory so that the relative paths used
# further down ("scripts/...", "data/...") resolve against it.
os.chdir(root)

In [None]:
import sys

sys.path.append(str(root / "src"))

In [None]:
import os
import time
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib import gridspec
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn.metrics
import torch
import tqdm
from IPython import display
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

import clip

In [None]:
# !cp /System/Library/Fonts/Supplemental ~/.local/share/fonts/
# rm -fr ~/.cache/matplotlib
from matplotlib import font_manager
from matplotlib.lines import Line2D
from cycler import cycler
from matplotlib.ticker import MultipleLocator, AutoMinorLocator

# Scan for installed TTF fonts; the returned list is discarded here.
font_manager.findSystemFonts(fontpaths=None, fontext="ttf")
# Resolve "Arial" once up front; the result is not used — presumably a smoke
# test that the font is installed (confirm).
font_manager.findfont("Arial") # Test with "Special Elite" too
# Use Arial as the sans-serif face for every figure in this notebook.
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = 'Arial'

# Legend styling: no fancy (rounded) box, edge colour '1.0', frame alpha 0.
plt.rcParams['legend.fancybox'] = False
plt.rcParams['legend.edgecolor']='1.0'
plt.rcParams['legend.framealpha']=0

# https://github.com/dsc/colorbrewer-python/blob/master/colorbrewer.py

# The nine "Set1" RGB triples in order. The palette of size n is simply the
# first n of these, so the table is derived from one list instead of being
# spelled out per size.
_SET1_COLORS = (
    (228, 26, 28),
    (55, 126, 184),
    (77, 175, 74),
    (152, 78, 163),
    (255, 127, 0),
    (255, 255, 51),
    (166, 86, 40),
    (247, 129, 191),
    (153, 153, 153),
)

# Maps palette size (3..9) to a list of [R, G, B] lists (0-255 integers).
# Each entry gets its own fresh inner lists, so palettes do not alias.
Set1 = {
    size: [list(rgb) for rgb in _SET1_COLORS[:size]]
    for size in range(3, len(_SET1_COLORS) + 1)
}

# The twelve "Paired" RGB triples in order. The palette of size n is the first
# n of these.
_PAIRED_COLORS = (
    (166, 206, 227),
    (31, 120, 180),
    (178, 223, 138),
    (51, 160, 44),
    (251, 154, 153),
    (227, 26, 28),
    (253, 191, 111),
    (255, 127, 0),
    (202, 178, 214),
    (106, 61, 154),
    (255, 255, 153),
    (177, 89, 40),
)

# Maps palette size (3..12) to a list of [R, G, B] lists (0-255 integers).
# Every colour is a list: the hand-written table had a stray tuple as the first
# colour of the size-3 palette, which made that entry compare unequal to the
# same colour in every other palette and in Set1-style tables.
Paired = {
    size: [list(rgb) for rgb in _PAIRED_COLORS[:size]]
    for size in range(3, len(_PAIRED_COLORS) + 1)
}

# Seven-entry qualitative palette. Six entries are hex strings; the fifth is an
# RGB tuple of floats (note the divisor is 256, kept as-is).
_COLOR_QUAL_RGB_ENTRY = (160/256, 95/256, 0)
color_qual_7 = [
    '#F53345',
    '#87D303',
    '#04CBCC',
    '#8650CD',
    _COLOR_QUAL_RGB_ENTRY,
    '#F5A637',
    '#DBD783',
]

pd.set_option('display.max_rows', 500)

In [None]:
import scipy.special
import tqdm.contrib.concurrent

In [None]:
# import importlib
# importlib.reload(sys.modules["MONET.utils.static"])
# from MONET.utils.static import (
#     concept_to_prompt)

In [None]:
from MONET.datamodules.multiplex_datamodule import MultiplexDatamodule
from MONET.utils.loader import custom_collate_per_key, dataloader_apply_func
from MONET.utils.metrics import skincon_calcualte_auc_all
from MONET.utils.static import (
    concept_to_prompt,
    fitzpatrick17k_disease_label,
    fitzpatrick17k_ninelabel,
    fitzpatrick17k_threelabel,
    skincon_cols,
)
from MONET.utils.text_processing import generate_prompt_token_from_concept

In [None]:
def wandb_to_exppath(wandb, log_path="/gscratch/cse/chanwkim/MONET_log/train/runs"):
    """Return the experiment directory whose wandb run matches ``wandb``.

    Each sub-directory of ``log_path`` is treated as one experiment. An
    experiment matches when its ``wandb/`` directory contains an entry whose
    name starts with ``"run"`` and whose last eight characters equal ``wandb``
    (presumably the wandb run id — confirm the id is always 8 characters).

    Args:
        wandb: Suffix to look for at the end of a ``run*`` entry name.
        log_path: Directory holding one sub-directory per experiment.

    Returns:
        ``pathlib.Path`` of the first matching experiment, scanning experiments
        in sorted name order.

    Raises:
        RuntimeError: If no experiment matches.
    """
    log_path = Path(log_path)
    # Sorted so the result does not depend on the file system's listing order.
    for experiment in sorted(os.listdir(log_path)):
        wandb_dir = log_path / experiment / "wandb"
        if not wandb_dir.is_dir():
            continue
        # Look at every run* entry: a wandb directory may hold none (which used
        # to raise IndexError and abort the whole search) or several.
        for filename in os.listdir(wandb_dir):
            if filename.startswith("run") and filename[-8:] == wandb:
                return log_path / experiment
    raise RuntimeError("not found")


# Resolve the experiment directory for run "baqqmm5v" under the project log
# directory (this overrides the function's default log_path).
exppath = wandb_to_exppath(
    wandb="baqqmm5v", log_path="/projects/leelab2/chanwkim/dermatology_datasets/logs/train/runs"
)
# List the full paths of everything stored in that experiment's checkpoints/ folder.
print([exppath / "checkpoints" / ckpt for ckpt in os.listdir(exppath / "checkpoints/")])

In [None]:
variable_dict={}

In [None]:
# Maps each supported `dataset_name` to the value written into the datamodule
# config's `dataset_name_test` field. Every branch of the former if/elif chain
# differed only in this string.
_DATASET_NAME_TEST = {
    "clinical_fd_clean_nodup": "clinical_fd_clean_nodup=all",
    "fitzpatrick17k_clean_threelabel_nodup": "fitzpatrick17k_clean_threelabel_nodup=all",
    "fitzpatrick17k_skincon": "fitzpatrick17k_skincon=all",
    "ddi": "ddi=all",
    "ddiskincon": "ddiskincon=all",
    "isic": "isic=all",
    "derm7pt_derm": "derm7pt_derm=all",
    "allpubmedtextbook": "pubmed=all,textbook=all",
}


def setup_dataloader(dataset_name, data_dir="/sdata/chanwkim/dermatology_datasets", split_seed=42):
    """Build the test dataloader for one named dataset.

    Loads ``configs/datamodule/multiplex.yaml`` from the project root, overrides
    ``data_dir``, ``dataset_name_test`` and ``split_seed``, instantiates the
    datamodule through hydra, calls ``setup()`` and takes its test dataloader.

    Args:
        dataset_name: One of the keys of ``_DATASET_NAME_TEST``.
        data_dir: Value for the config's ``data_dir`` field. Defaults to the
            path that was previously hard-coded in every branch.
        split_seed: Value for the config's ``split_seed`` field.

    Returns:
        ``{"dataloader": <test dataloader>}``.

    Raises:
        ValueError: If ``dataset_name`` is not supported. Previously this fell
            through every branch and failed with an UnboundLocalError on
            ``dataloader``.
    """
    # Validate before touching the file system so a typo fails fast and clearly.
    if dataset_name not in _DATASET_NAME_TEST:
        raise ValueError(
            f"unknown dataset_name {dataset_name!r}; expected one of {sorted(_DATASET_NAME_TEST)}"
        )

    cfg_dm = omegaconf.OmegaConf.load(root / "configs" / "datamodule" / "multiplex.yaml")
    cfg_dm.data_dir = data_dir
    cfg_dm.dataset_name_test = _DATASET_NAME_TEST[dataset_name]
    cfg_dm.split_seed = split_seed

    dm = hydra.utils.instantiate(cfg_dm)
    dm.setup()

    return {"dataloader": dm.test_dataloader()}

In [None]:
# Build a test dataloader for each listed dataset and store it under
# variable_dict[dataset_name]["dataloader"].
# NOTE(review): "derm7pt_derm" is not in this list; it is only set up by the
# loop in the following cell.
for dataset_name in ["clinical_fd_clean_nodup", "fitzpatrick17k_clean_threelabel_nodup", "fitzpatrick17k_skincon", "ddi", "ddiskincon", "isic", "allpubmedtextbook"]:
    variable_dict.setdefault(dataset_name, {})
    variable_dict[dataset_name].update(setup_dataloader(dataset_name))

In [None]:
# Build (or rebuild) the test dataloaders for the four datasets that are
# featurized with EfficientNet below.
# NOTE(review): three of these names also appear in the previous cell's list,
# so running both cells constructs those dataloaders twice; the later result
# overwrites the earlier one in variable_dict.
for dataset_name in ["clinical_fd_clean_nodup", "isic", "derm7pt_derm", "allpubmedtextbook"]:
    variable_dict.setdefault(dataset_name, {})
    variable_dict[dataset_name].update(setup_dataloader(dataset_name))

In [None]:
import torchvision

# Device that the feature extractor and every image batch are moved to.
# NOTE(review): the GPU index is hard-coded — confirm it exists on the machine
# running this notebook.
efficientnet_device="cuda:6"
# EfficientNetV2-S with the IMAGENET1K_V1 weights, used only as a fixed feature
# extractor.
efficientnet = torchvision.models.efficientnet_v2_s(
    weights=torchvision.models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
).to(efficientnet_device)
# Switch to evaluation mode before extracting features.
efficientnet.eval()

def get_layer_feature(model, feature_layer_name, image):
    """Run ``model`` on ``image`` and return the output of one named submodule.

    A forward hook on ``model._modules[feature_layer_name]`` captures that
    layer's output during a single forward pass. The captured tensor must have
    shape ``(batch, channels, 1, 1)``; the two trailing singleton axes are
    dropped.

    Args:
        model: Module whose ``_modules`` mapping contains ``feature_layer_name``.
        feature_layer_name: Key of the submodule to hook (looked up in
            ``model._modules``, i.e. not a dotted path).
        image: Batch passed to ``model`` unchanged; the caller is responsible
            for placing it on the model's device.

    Returns:
        Tensor of shape ``(batch, channels)``.

    Raises:
        ValueError: If ``model`` has no submodule called ``feature_layer_name``.
        AssertionError: If the captured output is not ``(batch, C, 1, 1)``.
    """
    feature_layer = model._modules.get(feature_layer_name)
    if feature_layer is None:
        raise ValueError(f"model has no submodule named {feature_layer_name!r}")

    captured = []

    def copy_data(module, input, output):
        # `.data` stores the raw tensor without autograd history.
        captured.append(output.data)

    handle = feature_layer.register_forward_hook(copy_data)
    try:
        model(image)
    finally:
        # Always detach the hook, even when the forward pass raises; otherwise
        # it would stay registered on the shared model and fire on later calls.
        handle.remove()

    embedding = captured[0]
    assert embedding.shape[0] == image.shape[0], f"{embedding.shape[0]} != {image.shape[0]}"
    assert embedding.shape[2] == 1, f"{embedding.shape[2]} != 1"
    # The original checked shape[2] twice, so a non-singleton last axis slipped
    # through and was silently truncated by the indexing below.
    assert embedding.shape[3] == 1, f"{embedding.shape[3]} != 1"
    return embedding[:, :, 0, 0]

def batch_func(batch):
    """Extract "avgpool" features for one batch with the global EfficientNet.

    Args:
        batch: Mapping with an ``"image"`` tensor and a ``"metadata"`` entry.

    Returns:
        Dict with ``"efficientnet_feature"`` (features moved to the CPU) and
        ``"metadata"`` (passed through unchanged).
    """
    # Gradients are not needed for feature extraction.
    with torch.no_grad():
        images_on_device = batch["image"].to(efficientnet_device)
        pooled = get_layer_feature(efficientnet, "avgpool", images_on_device)

    result = {"efficientnet_feature": pooled.detach().cpu()}
    result["metadata"] = batch["metadata"]
    return result

def setup_efficientnet_features(dataset_name, dataloader):
    """Featurize every batch of ``dataloader`` and collect the results.

    Args:
        dataset_name: Not used by the function body; kept for the callers.
        dataloader: Loader whose batches are fed to ``batch_func``.

    Returns:
        Dict with ``"efficientnet_feature"`` (collected features on the CPU)
        and ``"efficientnet_metadata"`` (collected metadata).
    """
    collected = dataloader_apply_func(
        dataloader=dataloader,
        func=batch_func,
        collate_fn=custom_collate_per_key,
    )
    return {
        "efficientnet_feature": collected["efficientnet_feature"].cpu(),
        "efficientnet_metadata": collected["metadata"],
    }

# Featurize the four datasets and store the features and their metadata next
# to each dataset's dataloader in variable_dict.
for dataset_name in ["clinical_fd_clean_nodup", "isic", "derm7pt_derm", "allpubmedtextbook"]:
    print("Featurizing...")
    print(dataset_name)
    variable_dict[dataset_name].update(setup_efficientnet_features(dataset_name, variable_dict[dataset_name]["dataloader"]))

In [None]:
from sklearn.decomposition import PCA

def calculate_pca(efficientnet_feature):
    """Project features onto the 50 principal components fitted on themselves.

    Args:
        efficientnet_feature: 2-D feature array accepted by scikit-learn's PCA.

    Returns:
        ``{"efficientnet_feature_pc": <projected features>}``.
    """
    # fit() returns the fitted estimator, so fit and transform can be chained.
    projector = PCA(n_components=50, svd_solver="auto").fit(efficientnet_feature)
    projected = projector.transform(efficientnet_feature)
    return {"efficientnet_feature_pc": projected}

# Add a per-dataset PCA projection ("efficientnet_feature_pc") for each of the
# four featurized datasets. Each dataset gets its own PCA fit.
for dataset_name in ["clinical_fd_clean_nodup", "isic", "derm7pt_derm", "allpubmedtextbook"]:
    print("Calculating PCA...")
    print(dataset_name)
    variable_dict[dataset_name].update(calculate_pca(variable_dict[dataset_name]["efficientnet_feature"]))

In [None]:
# torch.save(variable_dict, "logs/experiment_results/revision_0813.pt")

In [None]:
# variable_dict= torch.load("logs/experiment_results/revision_0813.pt", map_location="cpu")

In [None]:
import sklearn.metrics
def get_idx_from_concat_dataset(idx, concat_dataset):
    """Translate a flat index over several datasets into (dataset number, local index).

    The datasets are treated as laid end to end in list order; ``idx`` counts
    positions in that concatenation.

    Args:
        idx: Position in the concatenation. Negative values are not normalized.
        concat_dataset: Non-empty list of ``torch.utils.data.Dataset`` objects.

    Returns:
        Tuple ``(count, local_idx)``: ``count`` is the position of the owning
        dataset in ``concat_dataset`` and ``local_idx`` the index inside it.

    Raises:
        IndexError: If ``idx`` is at or beyond the combined length. Previously
            the function fell off the end and returned ``None``, which surfaced
            at the caller as an unrelated unpacking error.
    """
    assert isinstance(concat_dataset, list)
    assert isinstance(concat_dataset[0], torch.utils.data.Dataset)

    offset = 0
    for count, dataset in enumerate(concat_dataset):
        size = len(dataset)
        if idx - offset < size:
            return count, idx - offset
        offset += size

    raise IndexError(
        f"index {idx} is out of range for concatenated datasets of total length {offset}"
    )

def overlap_check(target_features,
                  target_dataset,
                  ref_features,
                  ref_dataset,
                  start_idx=0,
                  end_idx=500,
                  cut_off=0.9,
                  n_top=5):
    """Plot target images next to their most similar reference images.

    Both feature sets are projected with a 50-component PCA fitted on
    ``ref_features``, and the cosine similarity between every target row and
    every reference row is computed. For each target row inside
    ``[start_idx, end_idx)`` that has at least one reference above ``cut_off``,
    one figure row is drawn: the target image first, followed by up to
    ``n_top`` reference images in decreasing similarity, each titled with its
    index and similarity. The plotted indices are also printed, one line per
    row.

    Args:
        target_features: 2-D features, one row per target sample (assumed to be
            in the same order as ``target_dataset`` — TODO confirm).
        target_dataset: Object exposing ``getitem(i)["image"]``; the image must
            support ``.resize``.
        ref_features: 2-D features, one row per reference sample.
        ref_dataset: Object with a ``datasets`` list whose members expose
            ``getitem(i)["image"]``.
        start_idx: First target row considered for plotting.
        end_idx: One past the last target row considered for plotting.
        cut_off: Similarities must be strictly greater than this to count.
        n_top: Maximum number of reference images per row.

    Returns:
        None. The figure is left for the notebook to render.
    """
    pca = PCA(n_components=50, svd_solver="auto")
    pca.fit(ref_features)
    ref_features = pca.transform(ref_features)
    target_features = pca.transform(target_features)

    similarity_matrix = sklearn.metrics.pairwise.cosine_similarity(X=target_features, Y=ref_features)

    # True for every target row with at least one reference above the cut-off.
    has_match = (similarity_matrix > cut_off).any(axis=1)
    print('total', has_match.sum(), similarity_matrix.shape)

    # Rows are targets, so the window is clipped to the number of target rows.
    # The original loop was bounded by len(ref_dataset) instead.
    window_stop = min(end_idx, similarity_matrix.shape[0])
    row_indices = [i for i in range(max(start_idx, 0), window_stop) if has_match[i]]
    n_row = len(row_indices)
    if n_row == 0:
        # plt.subplots rejects nrows=0; there is nothing to draw.
        return

    # squeeze=False keeps `axes` two-dimensional even when n_row == 1.
    fig, axes = plt.subplots(nrows=n_row, ncols=n_top + 1, figsize=(5, n_row), squeeze=False)
    title_font = {'color': 'red', "fontsize": 7}

    for row_count, idx_count in enumerate(row_indices):
        similarity_array_argsort = np.argsort(similarity_matrix[idx_count])[::-1]
        similarity_array_sort = similarity_matrix[idx_count][similarity_array_argsort]

        image_target = target_dataset.getitem(idx_count)["image"]
        axes[row_count, 0].imshow(image_target.resize((200, 200)))
        axes[row_count, 0].set_title(str(idx_count), y=0.6, fontdict=title_font)

        plotted_idx = []
        best_matches = zip(similarity_array_argsort[:n_top], similarity_array_sort[:n_top])
        for col_count, (similarity_idx, similarity) in enumerate(best_matches, start=1):
            if similarity <= cut_off:
                break
            concat_dataset_idx, sample_idx = get_idx_from_concat_dataset(
                idx=similarity_idx,
                concat_dataset=ref_dataset.datasets,
            )
            image_ref = ref_dataset.datasets[concat_dataset_idx].getitem(sample_idx)["image"]

            plotted_idx.append(similarity_idx)
            axes[row_count, col_count].imshow(image_ref.resize((200, 200)))
            axes[row_count, col_count].set_title(
                str(similarity_idx) + ', ' + f"{similarity:.2f}", y=0.6, fontdict=title_font
            )
        print(', '.join([str(i) for i in [idx_count] + plotted_idx]))

        # Hide ticks on every cell of the row, including the unused ones.
        for ax in axes[row_count]:
            ax.set_xticks([])
            ax.set_yticks([])
    
    
# Compare the clinical_fd_clean_nodup features (targets) against the
# allpubmedtextbook features (references, which the PCA is fitted on) and plot
# the target images that have a close match.
overlap_check(target_features=variable_dict["clinical_fd_clean_nodup"]["efficientnet_feature"],
              target_dataset=variable_dict["clinical_fd_clean_nodup"]["dataloader"].dataset,
              ref_features=variable_dict["allpubmedtextbook"]["efficientnet_feature"],
              ref_dataset=variable_dict["allpubmedtextbook"]["dataloader"].dataset,
             )

In [None]:
# Load the duplicate list: the first column is a target position, columns 1-5
# hold up to five reference positions (NaN where unused).
dup_found=pd.read_csv("scripts/preprocess/training_duplicate.csv", names=['target_idx', 1, 2, 3, 4, 5])
# Replace each positional target index with the matching index label of the
# target dataset's metadata table.
dup_found["target_idx"]=variable_dict["clinical_fd_clean_nodup"]["dataloader"].dataset.metadata_all.iloc[dup_found["target_idx"]].index.values
# Replace each positional reference index with that sample's metadata name,
# keeping NaN for empty slots.
# NOTE(review): this indexes the dataset once per entry just to read
# ["metadata"].name — presumably that also loads the image; confirm the cost.
for i in range(1,5+1):
    dup_found[i]=dup_found[i].map(lambda x: np.nan if np.isnan(x) else variable_dict["allpubmedtextbook"]["dataloader"].dataset[int(x)]["metadata"].name)
# dup_found.to_csv("data/fitzpatrick17k/training_overlap.csv", index=False)

In [None]:
duplicate_info=pd.read_csv("data/fitzpatrick17k/training_overlap.csv", index_col=0)
duplicate_info

In [None]:
def plot_duplicate(duplicate_info,
                  target_features,
                  target_metadata,
                  target_dataset,
                  ref_features,
                  ref_metadata,
                  ref_dataset):
    """Plot each listed target image next to its listed duplicate references.

    Both feature sets are projected with a 50-component PCA fitted on
    ``ref_features`` and their cosine similarity matrix is computed; it is used
    only to annotate each reference image with its similarity to the target.
    One figure row is drawn per row of ``duplicate_info``.

    Args:
        duplicate_info: DataFrame indexed by target label; each row lists
            reference labels, padded with NaN. A row is read up to its first NaN.
        target_features: 2-D features, one row per target sample.
        target_metadata: Object whose ``index`` lists target labels (assumed to
            be in the same order as ``target_features`` — TODO confirm).
        target_dataset: Object exposing ``getitem(i)["image"]``; the image must
            support ``.resize``.
        ref_features: 2-D features, one row per reference sample.
        ref_metadata: Object whose ``index`` lists reference labels.
        ref_dataset: Object with a ``datasets`` list whose members expose
            ``getitem(i)["image"]``.

    Returns:
        None. The figure is left for the notebook to render.

    Raises:
        KeyError: If a label in ``duplicate_info`` is missing from the
            corresponding metadata index.
    """
    pca = PCA(n_components=50, svd_solver="auto")
    pca.fit(ref_features)
    ref_features = pca.transform(ref_features)
    target_features = pca.transform(target_features)

    cut_off = 0.9
    n_top = 5

    similarity_matrix = sklearn.metrics.pairwise.cosine_similarity(X=target_features, Y=ref_features)
    print('total', (similarity_matrix > cut_off).any(axis=1).sum(), similarity_matrix.shape)

    n_row = len(duplicate_info)
    if n_row == 0:
        # plt.subplots rejects nrows=0; there is nothing to draw.
        return
    # Leave room for every reference column even if there are more than n_top.
    n_col = max(n_top, duplicate_info.shape[1]) + 1

    # Label -> first position, built once. The original rebuilt
    # `index.tolist()` and scanned it linearly for every single lookup.
    target_position = {}
    for position, label in enumerate(target_metadata.index.tolist()):
        target_position.setdefault(label, position)
    ref_position = {}
    for position, label in enumerate(ref_metadata.index.tolist()):
        ref_position.setdefault(label, position)

    # squeeze=False keeps `axes` two-dimensional even when n_row == 1.
    fig, axes = plt.subplots(nrows=n_row, ncols=n_col, figsize=(5, n_row), squeeze=False)
    title_font = {'color': 'red', "fontsize": 7}

    for row_count, (target_idx, row) in enumerate(duplicate_info.iterrows()):
        target_pos = target_position[target_idx]
        image_target = target_dataset.getitem(target_pos)["image"]
        axes[row_count, 0].imshow(image_target.resize((200, 200)))
        axes[row_count, 0].set_title(str(target_idx)[:5], y=0.6, fontdict=title_font)

        for col_count, ref_idx in enumerate(row, start=1):
            # The first NaN marks the end of this row's reference list.
            if not isinstance(ref_idx, str) and np.isnan(ref_idx):
                break

            ref_pos = ref_position[ref_idx]
            concat_dataset_idx, sample_idx = get_idx_from_concat_dataset(
                idx=ref_pos,
                concat_dataset=ref_dataset.datasets,
            )
            image_ref = ref_dataset.datasets[concat_dataset_idx].getitem(sample_idx)["image"]

            axes[row_count, col_count].imshow(image_ref.resize((200, 200)))
            axes[row_count, col_count].set_title(
                str(ref_idx)[:5] + "\n" + f"{similarity_matrix[target_pos, ref_pos]:.2f}",
                y=0.6,
                fontdict=title_font,
            )

        # Hide ticks on every cell of the row, including the unused ones.
        for ax in axes[row_count]:
            ax.set_xticks([])
            ax.set_yticks([])
        
# Plot every row of duplicate_info (`.iloc[:]` selects all rows): targets come
# from clinical_fd_clean_nodup, references from allpubmedtextbook.
plot_duplicate(duplicate_info=duplicate_info.iloc[:],
              target_features=variable_dict["clinical_fd_clean_nodup"]["efficientnet_feature"],
              target_metadata=variable_dict["clinical_fd_clean_nodup"]["efficientnet_metadata"],
              target_dataset=variable_dict["clinical_fd_clean_nodup"]["dataloader"].dataset,
              ref_features=variable_dict["allpubmedtextbook"]["efficientnet_feature"],
              ref_metadata=variable_dict["allpubmedtextbook"]["efficientnet_metadata"],
              ref_dataset=variable_dict["allpubmedtextbook"]["dataloader"].dataset,
             )

In [None]:
variable_dict["clinical_fd_clean_nodup"]["efficientnet_metadata"]

In [None]:
dup_found

In [None]:
dup_found

In [None]:
dup_found["target_idx"]

In [None]:
variable_dict["clinical_fd_clean_nodup"]["dataloader"].dataset.metadata_all

In [None]:
x=pd.read_csv("https://raw.githubusercontent.com/ISIC-Research/expert-annotation-agreement-data/main/metadata.csv")

In [None]:
x["exemplar"].value_counts()

In [None]:
print(1)

In [None]:
a=variable_dict["clinical_fd_clean_nodup"]["efficientnet_feature_pc"]

b=variable_dict["allpubmedtextbook"]["efficientnet_feature_pc"]

In [None]:
a.shape, b.shape

In [None]:
# NOTE(review): only `a` is passed, so this is the similarity of `a` with
# itself and `b` plays no part in this cell — confirm whether
# cosine_similarity(a, b) was intended.
sklearn.metrics.pairwise.cosine_similarity(a).max()

In [None]:
variable_dict["allpubmedtextbook"]["dataloader"].dataset[0]

In [None]:
variable_dict["clinical_fd_clean_nodup"]["efficientnet_metadata"]

In [None]:
# NOTE(review): no visible cell stores an "efficientnet_feature_norm" entry in
# variable_dict (only "efficientnet_feature" and "efficientnet_feature_pc" are
# written), so this presumably raises KeyError on a fresh kernel unless the
# entry comes from a saved variable_dict — confirm.
variable_dict["clinical_fd_clean_nodup"]["efficientnet_feature_norm"].shape

In [None]:
similarity_matrix.max(axis=1)

In [None]:
from IPython.display import display

In [None]:
# Per-row maximum of the similarity matrix.
# NOTE(review): `similarity_matrix` is not assigned at notebook scope in any
# visible cell (it only exists as a local inside overlap_check and
# plot_duplicate), so this relies on state from elsewhere — confirm where it
# is meant to be defined.
similarity_matrix_max=similarity_matrix.max(axis=1)

In [None]:
# One row per target: its best similarity and the position of the reference
# that achieved it, sorted from most to least similar.
# NOTE(review): the `.values` / `.indices.numpy()` accessors presumably mean
# `similarity_matrix_max` is the result of a torch tensor `max` over a
# dimension rather than a NumPy array — confirm.
similarity_matrix_max_sorted=pd.DataFrame(
    {
        "max_value":similarity_matrix_max.values,
        "ref_idx": similarity_matrix_max.indices.numpy()
    }).sort_values("max_value", ascending=False)

In [None]:
# sim_matrix_max_sorted=sim_matrix_max_sorted.astype({"ref_idx": np.int64})

In [None]:
# Show the sorted best-match table. The frame is built under the name
# `similarity_matrix_max_sorted`; the shorter name used here before was never
# assigned and raised NameError on a fresh kernel.
similarity_matrix_max_sorted

In [None]:
# Shape of the subset of targets whose best similarity exceeds 0.9. Uses
# `similarity_matrix_max_sorted`, the name the frame is actually built under;
# the shorter name used here before was never assigned.
similarity_matrix_max_sorted[similarity_matrix_max_sorted["max_value"]>0.9].shape

In [None]:
# Preview the strongest matches: for each of the first rows with similarity
# above 0.9, show the target image followed by its best reference image.
# `display` must be the function from `from IPython.display import display`;
# the earlier `from IPython import display` binds the module instead.
for idx, (target_idx, row) in enumerate(similarity_matrix_max_sorted[similarity_matrix_max_sorted["max_value"]>0.9].iterrows()):
    print(row)
    # Stops once idx exceeds 10, i.e. after rows 0..10 have been shown; the
    # row that triggers the break has already been printed above.
    if idx>10:
        break
    max_val=row["max_value"]
    ref_idx=row["ref_idx"]

    print(target_idx, ref_idx, max_val)


    # target_idx is the frame's index label, used here as the dataset position.
    image_target=variable_dict["clinical_fd_clean_nodup"]["dataloader"].dataset.getitem(target_idx)["image"]

    # ref_idx is cast to int before use — presumably because iterrows() yields
    # it as a float alongside max_value (confirm).
    concat_dataset_idx, sample_idx = get_idx_from_concat_dataset(
        idx=int(ref_idx),
        concat_dataset=variable_dict["allpubmedtextbook"]["dataloader"].dataset.datasets
    )
    image_ref=variable_dict["allpubmedtextbook"]["dataloader"].dataset.datasets[concat_dataset_idx].getitem(sample_idx)["image"]


    display(image_target.resize((100,100)))
    display(image_ref.resize((100,100)))
    print('------------------------------')

In [None]:
similarity_matrix_max_sorted[similarity_matrix_max_sorted["max_value"]>0.9].shape