In [None]:
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
import os
import tqdm
import pickle
from tqdm import tqdm
import pandas as pd
from model_utils import get_embedding, get_trained_model, get_model_parts
from concept_utils import get_concept_scores_mv_valid
from concept_utils import learn_concept
from data_utils import SkinDataset
import os
from skimage import io
from PIL import Image
import torch
from imagecorruptions import corrupt, get_corruption_names

from sklearn.svm import SVC

def config():
    parser = argparse.ArgumentParser()
    parser.add_argument("--fitzpatrick-csv-path", default="./dataset/fitzpatrick17k.csv", type=str)
    parser.add_argument("--model-dir", default="/path/outputs/fitzpatrick/", type=str)
    parser.add_argument("--concept-bank-path", default="./banks/concept_resnet_170.pkl", type=str)
    parser.add_argument("--data-dir", default="/path/dataset/data/finalfitz17k", type=str)
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--batch-size", default=32, type=int)
    parser.add_argument("--num-workers", default=4, type=int)
    parser.add_argument("--n-samples", default=40, type=int, help="Number of positive/negatives for learning the concept.")
    parser.add_argument("--seed", default=42, type=int, help="Random seed")
    parser.add_argument("--model-type", default="resnet", type=str)
    parser.add_argument("--C", default=0.01, type=float)
    parser.add_argument("-f")
    return parser.parse_args()

args = config()

In [None]:
# ImageNet-pretrained ResNet-18. Only the "bottom" part returned by the project
# helper get_model_parts is kept; its embeddings are what the concepts are learned on.
model_ft = models.resnet18(pretrained=True)
model_bottom, _ = get_model_parts(model_name=args.model_type, model=model_ft)
# Inference mode + target device (assumes get_model_parts returns an nn.Module,
# whose .eval() and .to() both return the module itself).
model_bottom = model_bottom.eval().to(args.device)

In [None]:
derma_concepts = dict()  # concept name -> learned concept info, filled by the cells below

### Skin Color Concept

In [None]:

# Load Fitzpatrick17k annotations and integer-encode the condition label at three
# granularities ('label', 'nine_partition_label', 'three_partition_label').
df = pd.read_csv(args.fitzpatrick_csv_path)
df["low"] = df['label'].astype('category').cat.codes
df["mid"] = df['nine_partition_label'].astype('category').cat.codes
df["high"] = df['three_partition_label'].astype('category').cat.codes
# The dataset classes build image paths from 'hasher' (CorruptedSkinDataset does;
# SkinDataset presumably too).
df["hasher"] = df["md5hash"]
# Seeds torch's global RNG, which drives DataLoader shuffling -> reproducible example picks.
torch.manual_seed(args.seed)
# ImageNet-style preprocessing (ImageNet mean/std), reused by every later cell.
ds_transforms = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize(size=256),
                transforms.CenterCrop(size=224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
# Concept positives: Fitzpatrick types 5-6; negatives: types 1-2.
# NOTE(review): both sides drop rows with low == 1; the reason for excluding that
# label code is not visible here — confirm.
pos_ds = SkinDataset(df[((df.fitzpatrick == 5) | (df.fitzpatrick == 6)) & (df.low != 1)],
                     root_dir=args.data_dir,
                     transform=ds_transforms)
pos_loader = torch.utils.data.DataLoader(pos_ds,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.num_workers)

neg_ds = SkinDataset(df[((df.fitzpatrick == 1) | (df.fitzpatrick == 2)) & (df.low != 1)],
                     root_dir=args.data_dir,
                     transform=ds_transforms)

neg_loader = torch.utils.data.DataLoader(neg_ds,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.num_workers)

# Get model activations (up to args.n_samples per class)
pos_act, pos_fps = get_embedding(pos_loader, model_bottom, n_samples=args.n_samples, device=args.device)
neg_act, neg_fps = get_embedding(neg_loader, model_bottom, n_samples=args.n_samples, device=args.device)
activations = torch.cat([pos_act, neg_act], dim=0)
# Size the labels from the embeddings actually returned, so that a class with fewer
# than n_samples images cannot leave activations and labels misaligned.
c_labels = torch.cat([torch.ones(len(pos_act)), torch.zeros(len(neg_act))], dim=0)
activations, c_labels = activations.cpu().numpy(), c_labels.numpy()
concept_info = learn_concept(activations, c_labels, args, C=args.C)
derma_concepts["dark-skin-color"] = concept_info

## Image Quality Concepts

### Zoom

In [None]:
class EasyDataset():
    """Minimal dataset over a list of image file paths (e.g. the hair examples).

    Each item is ``{"image": <transformed image>, "fitzpatrick": -1}``; ``-1`` is a
    placeholder because these images carry no Fitzpatrick skin-type label.
    """

    def __init__(self, images, transform=None):
        """
        Args:
            images (sequence of str): Paths of the image files to load.
            transform (callable, optional): Applied to the decoded H x W x 3 numpy
                image (e.g. a ``transforms.Compose`` starting with ``ToPILImage``).
        """
        self.images = images
        self.transform = transform

    def __len__(self):
        # Return the length of the dataset
        return len(self.images)

    @staticmethod
    def _to_rgb(image):
        """Coerce a decoded image array to 3 channels (H x W x 3).

        Grayscale (H x W or H x W x 1) is replicated across three channels and an
        alpha channel (H x W x 4) is dropped; otherwise the tensor reaching the
        3-channel ``Normalize`` would have the wrong number of channels.
        """
        if image.ndim == 2:
            return np.stack([image] * 3, axis=-1)
        if image.ndim == 3 and image.shape[-1] == 1:
            return np.repeat(image, 3, axis=-1)
        if image.ndim == 3 and image.shape[-1] == 4:
            return image[..., :3]
        return image

    def __getitem__(self, idx):
        """Load, RGB-coerce and transform the image at position ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image = self._to_rgb(io.imread(self.images[idx]))
        if self.transform:
            image = self.transform(image)
        return {"image": image,
                "fitzpatrick": -1}

In [None]:
# 30 images per Fitzpatrick group, then 2*n_samples of those; both draws are seeded
# (the second one previously was not, so the zoom concept changed on every run).
df1 = df.groupby("fitzpatrick").sample(n=30, random_state=1).sample(args.n_samples*2, random_state=args.seed)

In [None]:
torch.manual_seed(args.seed)
# Positives: "zoomed-in" views. Resizing the short side to 4*224 before a 224 centre
# crop keeps only the central quarter (per side) of the image, magnified.
pos_ds = SkinDataset(df1,
                     root_dir=args.data_dir,
                     transform=transforms.Compose([
                         transforms.ToPILImage(),
                         transforms.Resize(size=224*4),
                         transforms.CenterCrop(size=224),
                         transforms.ToTensor(),
                         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                     ]))


pos_loader = torch.utils.data.DataLoader(pos_ds,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.num_workers)

# Negatives: the same images with the standard (un-zoomed) preprocessing.
neg_ds = SkinDataset(df1,
                     root_dir=args.data_dir,
                     transform=ds_transforms)

neg_loader = torch.utils.data.DataLoader(neg_ds,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.num_workers)

# Get model activations (up to args.n_samples per class)
pos_act, pos_fps = get_embedding(pos_loader, model_bottom, n_samples=args.n_samples, device=args.device)
neg_act, neg_fps = get_embedding(neg_loader, model_bottom, n_samples=args.n_samples, device=args.device)
activations = torch.cat([pos_act, neg_act], dim=0)
# Labels sized from the embeddings actually returned, so they always line up with activations.
c_labels = torch.cat([torch.ones(len(pos_act)), torch.zeros(len(neg_act))], dim=0)
activations, c_labels = activations.cpu().numpy(), c_labels.numpy()
concept_info = learn_concept(activations, c_labels, args, C=args.C)
derma_concepts["zoom"] = concept_info

### Common Corruptions

In [None]:
class CorruptedSkinDataset():
    """Fitzpatrick17k images with an ``imagecorruptions`` corruption applied.

    Items have the same keys as the clean samples used elsewhere in this notebook
    ('image', 'high', 'mid', 'low', 'hasher', 'fitzpatrick').
    """

    def __init__(self, df, root_dir, transform=None,
                corruption_name="brightness", severity=2):
        """
        Args:
            df (pandas.DataFrame): Annotations; must provide the columns 'hasher',
                'high', 'mid', 'low' and 'fitzpatrick'.
            root_dir (string): Directory with all the images, named ``<hasher>.jpg``.
            transform (callable, optional): Optional transform applied to the
                corrupted image.
            corruption_name (string): Corruption name passed to ``corrupt``.
            severity (int): Corruption severity passed to ``corrupt``.
        """
        self.df = df
        self.root_dir = root_dir
        self.transform = transform
        self.corruption_name = corruption_name
        self.severity = severity

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _ensure_rgb(image):
        """Replicate a grayscale (H x W) image to H x W x 3; other images pass through.

        Done with numpy because ``skimage.color`` was never imported here (only
        ``skimage.io`` is), and done *before* corrupting so ``corrupt`` always
        receives a 3-channel image.
        """
        if image.ndim < 3:
            return np.stack([image] * 3, axis=-1)
        return image

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # Positional idx -> index label, looked up once and reused for every column.
        row_label = self.df.index[idx]
        hasher = self.df.loc[row_label, 'hasher']
        img_name = os.path.join(self.root_dir, f"{hasher}.jpg")
        image = self._ensure_rgb(io.imread(img_name))
        image = corrupt(image, corruption_name=self.corruption_name, severity=self.severity)

        high = self.df.loc[row_label, 'high']
        mid = self.df.loc[row_label, 'mid']
        low = self.df.loc[row_label, 'low']
        fitzpatrick = self.df.loc[row_label, 'fitzpatrick']

        if self.transform:
            image = self.transform(image)
        sample = {
                    'image': image,
                    'high': high,
                    'mid': mid,
                    'low': low,
                    'hasher': hasher,
                    'fitzpatrick': fitzpatrick
                }
        return sample


# Corruption name -> severity used when learning that corruption's concept.
corruptions = {"zoom_blur": 2,
              "contrast": 2,
              "defocus_blur": 3,
              "brightness": 3,
              "motion_blur": 3}

In [None]:
# One concept per corruption: positives are corrupted images, negatives the same
# images without corruption.
for i, (corruption, severity) in enumerate(corruptions.items()):
    # Reproducible but distinct image pool per corruption (seed offset by position).
    df1 = df.sample(2*args.n_samples, random_state=args.seed + i)
    pos_ds = CorruptedSkinDataset(df1,
                                  root_dir=args.data_dir,
                                  transform=ds_transforms,
                                  corruption_name=corruption,
                                  severity=severity)

    pos_loader = torch.utils.data.DataLoader(pos_ds,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.num_workers)

    neg_ds = SkinDataset(df1,
                         root_dir=args.data_dir,
                         transform=ds_transforms)

    neg_loader = torch.utils.data.DataLoader(neg_ds,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.num_workers)

    # Get model activations (up to args.n_samples per class)
    pos_act, pos_fps = get_embedding(pos_loader, model_bottom, n_samples=args.n_samples, device=args.device)
    neg_act, neg_fps = get_embedding(neg_loader, model_bottom, n_samples=args.n_samples, device=args.device)
    activations = torch.cat([pos_act, neg_act], dim=0)
    # Labels sized from the embeddings actually returned, so they always line up.
    c_labels = torch.cat([torch.ones(len(pos_act)), torch.zeros(len(neg_act))], dim=0)
    activations, c_labels = activations.cpu().numpy(), c_labels.numpy()
    concept_info = learn_concept(activations, c_labels, args, C=args.C)
    derma_concepts[corruption] = concept_info
    # concept_info[1] / [2] are printed as the concept's train / test accuracy.
    print(corruption, severity, concept_info[1], concept_info[2])

### Hair

In [None]:
class EasyDataset():
    """Minimal dataset over a list of image file paths (e.g. the hair examples).

    Each item is ``{"image": <transformed image>, "fitzpatrick": -1}``; ``-1`` is a
    placeholder because these images carry no Fitzpatrick skin-type label.
    """

    def __init__(self, images, transform=None):
        """
        Args:
            images (sequence of str): Paths of the image files to load.
            transform (callable, optional): Applied to the decoded H x W x 3 numpy
                image (e.g. a ``transforms.Compose`` starting with ``ToPILImage``).
        """
        self.images = images
        self.transform = transform

    def __len__(self):
        # Return the length of the dataset
        return len(self.images)

    @staticmethod
    def _to_rgb(image):
        """Coerce a decoded image array to 3 channels (H x W x 3).

        Grayscale (H x W or H x W x 1) is replicated across three channels and an
        alpha channel (H x W x 4) is dropped; otherwise the tensor reaching the
        3-channel ``Normalize`` would have the wrong number of channels.
        """
        if image.ndim == 2:
            return np.stack([image] * 3, axis=-1)
        if image.ndim == 3 and image.shape[-1] == 1:
            return np.repeat(image, 3, axis=-1)
        if image.ndim == 3 and image.shape[-1] == 4:
            return image[..., :3]
        return image

    def __getitem__(self, idx):
        """Load, RGB-coerce and transform the image at position ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image = self._to_rgb(io.imread(self.images[idx]))
        if self.transform:
            image = self.transform(image)
        return {"image": image,
                "fitzpatrick": -1}


In [None]:
hair_images_folder = "./hair_images/"  # folder holding the positive (hair) example images
# Fixed seed for the hair cell's DataLoader shuffling; the trailing ';' keeps the
# returned Generator object from being echoed as cell output.
torch.manual_seed(1);

In [None]:
# Positives: every regular, non-hidden file in the hair folder (skips e.g. .DS_Store
# or sub-directories, which io.imread cannot read). Sorted so dataset order does not
# depend on the file system's listing order.
hair_ims = sorted(
    os.path.join(hair_images_folder, p)
    for p in os.listdir(hair_images_folder)
    if not p.startswith(".") and os.path.isfile(os.path.join(hair_images_folder, p))
)
hair_n_samples = 10  # examples per class used to train the SVM; the rest are held out
pos_ds = EasyDataset(hair_ims, transform=ds_transforms)

pos_loader = torch.utils.data.DataLoader(pos_ds,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.num_workers)

# Negatives: a seeded random sample of 200 Fitzpatrick17k images.
neg_ds = SkinDataset(df.sample(200, random_state=args.seed),
                     root_dir=args.data_dir,
                     transform=ds_transforms)

neg_loader = torch.utils.data.DataLoader(neg_ds,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.num_workers)

# Get model activations
pos_act, pos_fps = get_embedding(pos_loader, model_bottom, n_samples=len(hair_ims), device=args.device)
neg_act, neg_fps = get_embedding(neg_loader, model_bottom, n_samples=len(hair_ims), device=args.device)

print(pos_act.shape, neg_act.shape)

# Keep the classes balanced even if a loader returned fewer rows than requested,
# and make sure a non-empty held-out split remains.
n_per_class = min(len(pos_act), len(neg_act))
if n_per_class <= hair_n_samples:
    raise ValueError(
        f"Need more than {hair_n_samples} examples per class for the hair concept "
        f"(got {n_per_class}); add images to {hair_images_folder}."
    )
n_test = n_per_class - hair_n_samples

X_train = torch.cat([pos_act[:hair_n_samples], neg_act[:hair_n_samples]], dim=0).cpu().numpy()
X_test = torch.cat([pos_act[hair_n_samples:n_per_class], neg_act[hair_n_samples:n_per_class]], dim=0).cpu().numpy()
y_train = torch.cat([torch.ones(hair_n_samples), torch.zeros(hair_n_samples)], dim=0).numpy()
y_test = torch.cat([torch.ones(n_test), torch.zeros(n_test)], dim=0).numpy()

# Linear SVM with its own (smaller) C instead of args.C.
svm = SVC(kernel="linear", C=0.001, probability=False)
svm.fit(X_train, y_train)
train_acc = svm.score(X_train, y_train)
test_acc = svm.score(X_test, y_test)
print(f"Accuracy - Training: {train_acc}, Test: {test_acc}")

# Signed distance of each training point to the separating hyperplane.
train_margin = ((np.dot(svm.coef_, X_train.T) + svm.intercept_) / np.linalg.norm(svm.coef_)).T

margin_info = {"max": np.max(train_margin),
               "min": np.min(train_margin),
               "pos_mean": np.mean(train_margin[train_margin > 0]),
               "pos_std": np.std(train_margin[train_margin > 0]),
               "neg_mean": np.mean(train_margin[train_margin < 0]),
               "neg_std": np.std(train_margin[train_margin < 0]),
               "q_90": np.quantile(train_margin, 0.9),
               "q_10": np.quantile(train_margin, 0.1)
               }
# Store THIS SVM (previously the stale concept_info from the last corruption was
# saved as "derma-hair"). Elements [1]/[2] are train/test accuracy, as printed in the
# corruption loop. NOTE(review): full order assumed to match concept_utils.learn_concept's
# return, (coef, train_acc, test_acc, intercept, margin_info) — confirm.
concept_info = (svm.coef_, train_acc, test_acc, svm.intercept_, margin_info)
derma_concepts["derma-hair"] = concept_info

In [None]:
# NOTE: pickle.load can execute arbitrary code — only load a concept bank you trust.
with open(args.concept_bank_path, "rb") as f:
    concept_bank = pickle.load(f)

# Refuse to overwrite existing concepts. Checked for all names up front (and not with
# `assert`, which `python -O` strips) so the bank is never left partially merged.
collisions = [name for name in derma_concepts if name in concept_bank]
if collisions:
    raise ValueError(f"Concepts already present in {args.concept_bank_path}: {collisions}")

for derm_concept, info in derma_concepts.items():
    print(derm_concept)
    concept_bank[derm_concept] = info

with open("./derma_concepts_resnet_new.pkl", "wb") as f:
    pickle.dump(concept_bank, f)

In [None]:
import seaborn as sns
sns.set_context("poster")
# Only the keys (which corruptions get a row) are used below; every row shows
# severities 2-5 next to the original image.
corruptions = {"zoom_blur": 2,
              "contrast": 2,
              "defocus_blur": 3,
              "brightness": 3,
              "motion_blur": 3}
# Both draws are seeded so the demo image is the same on every run (the second
# .sample was previously unseeded, defeating random_state=3).
df1 = df.groupby("fitzpatrick").sample(n=30, random_state=3).sample(args.n_samples*2, random_state=3)

n_rows = len(corruptions)
# squeeze=False keeps axs 2-D even if only one corruption is listed.
fig, axs = plt.subplots(n_rows, 5, figsize=(20, 4*n_rows), squeeze=False)
img_name = os.path.join(args.data_dir, f"{df1.loc[df1.index[1], 'hasher']}.jpg")
image = io.imread(img_name)
for i, corruption in enumerate(corruptions):
    for j, severity in enumerate(range(1, 6)):
        ax = axs[i, j]
        ax.axis("off")
        if severity == 1:
            # First column shows the clean image instead of severity 1.
            ax.imshow(image)
            ax.set_title("Original")
        else:
            corr_img = corrupt(image, severity=severity, corruption_name=corruption)
            ax.imshow(corr_img)
            ax.set_title(f"{corruption}:{severity}")
fig.tight_layout()
fig.savefig("./concept_demonstrations.png")
plt.close(fig)  # close this figure explicitly rather than whatever is "current"