In [None]:
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import os
import tqdm
import pickle
from tqdm import tqdm
import pandas as pd
from model_utils import get_trained_model, get_model_parts
from concept_utils import get_concept_scores_mv_valid, ConceptBank
from data_utils import SkinDataset
import os
from skimage import io
from PIL import Image
import torch


def config():
    """Build the command-line parser for this notebook and parse the arguments.

    Every option has a default, so the function also works when no flags are
    supplied.

    Returns:
        argparse.Namespace: parsed options, exposed as attributes
        (``fitzpatrick_csv_path``, ``pretrained_model_path``, ``model_dir``,
        ``concept_bank_path``, ``data_dir``, ``device``, ``batch_size``,
        ``num_workers``, ``n_samples``, ``seed``, ``model_type`` and ``f``).
    """
    # One row per option: (flag, default, type, help text).
    option_table = (
        ("--fitzpatrick-csv-path", "../../dataset/fitzpatrick17k.csv", str, None),
        ("--pretrained-model-path", "model-path_resnet_25_random-holdout_low_0.pth", str, None),
        ("--model-dir", "/path/outputs/fitzpatrick/", str, None),
        ("--concept-bank-path", "./derma_concepts_resnet_new.pkl", str, None),
        ("--data-dir", "/path/data/finalfitz17k", str, None),
        ("--device", "cuda", str, None),
        ("--batch-size", 32, int, None),
        ("--num-workers", 4, int, None),
        ("--n-samples", 100, int, "Number of positive/negatives for learning the concept."),
        ("--seed", 42, int, "Random seed"),
        ("--model-type", "resnet", str, None),
    )

    cli = argparse.ArgumentParser()
    for flag, default_value, value_type, help_text in option_table:
        cli.add_argument(flag, default=default_value, type=value_type, help=help_text)

    # NOTE(review): presumably present so the `-f <connection file>` argument a
    # Jupyter kernel is started with does not make parsing fail — confirm.
    cli.add_argument("-f")

    return cli.parse_args()

# Parse the run configuration once; the cells below read their paths, device,
# seed and model type from `args`.
args = config()

### Skin Color Analysis

In [None]:
# Seed PyTorch's RNG from the configured seed.
torch.manual_seed(args.seed)

# Read the Fitzpatrick metadata CSV and add derived columns:
#   low / mid / high - integer category codes of the `label`,
#                      `nine_partition_label` and `three_partition_label` columns
#   hasher           - copy of the `md5hash` column
df = pd.read_csv(args.fitzpatrick_csv_path).assign(
    low=lambda frame: frame["label"].astype("category").cat.codes,
    mid=lambda frame: frame["nine_partition_label"].astype("category").cat.codes,
    high=lambda frame: frame["three_partition_label"].astype("category").cat.codes,
    hasher=lambda frame: frame["md5hash"],
)

# Three-element mean / std arrays handed to `transforms.Normalize` further down.
# NOTE(review): values match the commonly used ImageNet statistics — confirm
# they are what the checkpoints were trained with.
mean_pxs = np.array([0.485, 0.456, 0.406])
std_pxs = np.array([0.229, 0.224, 0.225])

In [None]:
# Load the pickled concepts object and wrap it in a ConceptBank on the
# configured device. The file is opened in a `with` block so the handle is
# closed as soon as unpickling finishes.
# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only point --concept-bank-path at files from a trusted source.
with open(args.concept_bank_path, "rb") as concept_file:
    all_concepts = pickle.load(concept_file)
concept_bank = ConceptBank(all_concepts, args.device)


### Finding Images Where Transformations Can Explain Model Mistakes

In [None]:
# Each scenario bundles:
#   "path" - suffix used to build the checkpoint file name,
#   "df"   - the metadata rows to evaluate,
#   "tag"  - a short identifier for the scenario.
# The subset keeps rows with fitzpatrick 5 or 6 whose `low` label code is 6.
# NOTE(review): which class code 6 maps to is not visible here — confirm.
scenarios = []
for i in range(1):
    scenarios.append({"path": f"random-holdout_low_{i}",
                      "df": df[((df.fitzpatrick == 5) | (df.fitzpatrick == 6)) & (df.low == 6)],
                      "tag": f"skin-type-bias-{i}"})

# Concept whose position in the per-sample result list is tracked.
concept_name = "dark-skin-color"

for scenario in scenarios[:1]:
    # One entry per evaluated sample: None when the model is already correct,
    # otherwise the 0-based index of `concept_name` in the returned list.
    res = []

    model_path = os.path.join(args.model_dir, f"model-path_{args.model_type}_25_{scenario['path']}.pth")
    model_ft = get_trained_model(args.model_type, model_path)
    model_ft = model_ft.to(args.device)
    model_bottom, model_top = get_model_parts(model_name=args.model_type, model=model_ft)
    model_bottom = model_bottom.to(args.device)
    model_bottom.eval()
    model_top.eval()

    test_df = scenario["df"]
    test_ds = SkinDataset(test_df,
                          root_dir=args.data_dir,
                          transform=transforms.Compose([
                              transforms.ToPILImage(),
                              transforms.Resize(size=256),
                              transforms.CenterCrop(size=224),
                              transforms.ToTensor(),
                              transforms.Normalize(mean_pxs, std_pxs)
                          ]))

    # batch_size=1 is required: the `pred == labels` check below treats the
    # comparison result as a single boolean.
    loader = torch.utils.data.DataLoader(test_ds,
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=1)

    for batch in tqdm(loader):
        inputs = batch["image"].to(args.device)
        labels = batch["low"].long().to(args.device)

        # The plain classification pass only feeds argmax, so skip building an
        # autograd graph for it. The concept-score call below runs outside
        # this context and is unaffected.
        with torch.no_grad():
            orig_out = model_top(model_bottom(inputs))
        pred = orig_out.argmax(dim=1)

        if pred == labels:
            # Correct prediction: nothing to explain for this sample.
            res.append(None)
        else:
            opt_result = get_concept_scores_mv_valid(inputs, labels,
                                                     concept_bank,
                                                     model_bottom, model_top,
                                                     alpha=1e-2, beta=2e-2, lr=2e-2,
                                                     enforce_validity=True, momentum=0.9,
                                                     kappa="mean")
            # NOTE(review): presumably `concept_scores_list` holds concept
            # names in ranked order, making this index a 0-based rank — confirm.
            res.append(opt_result.concept_scores_list.index(concept_name))

    top_ranks = np.array([a for a in res if a is not None])
    bottom_ranks = len(concept_bank.concept_names) - top_ranks
    print(res)

    # With no misclassified samples the arrays are empty and the summary
    # statistics are undefined (np.mean would return NaN with a warning).
    if top_ranks.size == 0:
        print("No misclassified samples; rank statistics are undefined.")
        continue

    # Printed per group: mean, median, 25% quantile, 75% quantile, and the
    # fraction of samples whose rank is below 6.
    print("Bottom Ranks:")
    print(np.mean(bottom_ranks), np.median(bottom_ranks), np.quantile(bottom_ranks, 0.25),
          np.quantile(bottom_ranks, 0.75), (bottom_ranks < 6).mean())
    print("Top Ranks:")
    print(np.mean(1 + top_ranks), np.median(1 + top_ranks), np.quantile(1 + top_ranks, 0.25),
          np.quantile(1 + top_ranks, 0.75), (1 + top_ranks < 6).mean())
    