In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import models, transforms
import matplotlib.pyplot as plt

import os
import pickle
import requests
import PIL
from tqdm import tqdm
import sys
sys.path.append("..")
from model_utils import get_model_parts
from argparse import Namespace 
from concept_utils import ConceptBank
from PIL import Image


# Report the library versions in use so results can be tied to an environment.
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

# Notebook setup: draw matplotlib figures inline and enable IPython's
# autoreload extension in mode 2.
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Filesystem locations for the dataset subsets and the experiment outputs.
data_root = '/path/metadataset/MetaDataset/subsets'
experiment_root = '/path/outputs/conceptualexplanations/metadataset/resnet_mv_scale_50'

# Every tunable setting of the notebook lives on this single namespace.
args = Namespace(
    model_name="squeezenet",
    input_size=224,
    batch_size=8,
    num_epochs=5,
    SEED=4,
    num_classes=1000,
    feature_extract=True,
    num_workers=4,
    # Use the first CUDA device when one is available, otherwise the CPU.
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    bank_path='/path/conceptual-explanations/banks/concept_squeezenet_170.pkl',
)

# Per-channel mean / standard deviation handed to transforms.Normalize below.
mean_pxs = np.array([0.485, 0.456, 0.406])
std_pxs = np.array([0.229, 0.224, 0.225])

# Preprocessing pipeline: resize, centre-crop to a square of side
# args.input_size, convert to a tensor, then normalise each channel.
data_transforms = transforms.Compose(
    [
        transforms.Resize(args.input_size),
        transforms.CenterCrop(args.input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean_pxs, std_pxs),
    ]
)

In [None]:
# Load the pickled concept bank and wrap it in the project's ConceptBank helper.
# The context manager guarantees the file handle is closed once the data is read.
# NOTE: unpickling can execute arbitrary code; only load bank files from a trusted source.
with open(args.bank_path, 'rb') as bank_file:
    all_concepts = pickle.load(bank_file)
concept_bank = ConceptBank(all_concepts, args.device)


# Set up the Model

In [None]:
# Build the SqueezeNet 1.0 network with its pretrained weights.
model_ft = models.squeezenet1_0(pretrained=True)

# Split the network into the two parts returned by get_model_parts: the
# "bottom" part produces the embedding and the "top" part consumes it.
model_bottom, model_top = get_model_parts(model_ft, args.model_name)

# Both parts are only used for inference here, so switch them to eval mode ...
model_bottom.eval()
model_top.eval()

# ... and place them on the configured device.
model_bottom, model_top = model_bottom.to(args.device), model_top.to(args.device)

# Evaluation Methods

In [None]:
# Download the newline-separated list of class names used to map label strings
# (e.g. "Granny Smith") to output indices.
# A timeout keeps the cell from hanging forever on a stalled connection, and
# raise_for_status() stops an HTTP error page from being parsed as class names.
response = requests.get("https://git.io/JJkYN", timeout=30)
response.raise_for_status()
class_labels = response.text.split("\n")

In [None]:

def rgb_to_hsv(rgb):
    """Vectorised counterpart of ``colorsys.rgb_to_hsv`` for numpy arrays.

    The last axis of ``rgb`` holds the channels (R, G, B and optionally more,
    e.g. alpha). The result is a float array of the same shape in which
    channel 0 is the hue in [0, 1), channel 1 the saturation and channel 2 the
    value (the per-pixel maximum of R, G, B, so it keeps the input's scale).
    Any channel after the third is copied through unchanged.
    """
    pixels = rgb.astype('float')
    red, green, blue = pixels[..., 0], pixels[..., 1], pixels[..., 2]

    brightest = np.max(pixels[..., :3], axis=-1)
    darkest = np.min(pixels[..., :3], axis=-1)
    spread = brightest - darkest
    # Pixels whose three channels are not all equal; grey pixels keep hue 0 and saturation 0.
    chromatic = brightest != darkest

    def scaled_gap(channel):
        # Distance of ``channel`` from the maximum, relative to the spread (0 for grey pixels).
        gap = np.zeros_like(channel)
        gap[chromatic] = (brightest - channel)[chromatic] / spread[chromatic]
        return gap

    red_gap = scaled_gap(red)
    green_gap = scaled_gap(green)
    blue_gap = scaled_gap(blue)

    # Pick the hue sector from whichever channel is the maximum.
    sector = np.select(
        [red == brightest, green == brightest],
        [blue_gap - green_gap, 2.0 + red_gap - blue_gap],
        default=4.0 + green_gap - red_gap,
    )

    hsv = np.zeros_like(pixels)
    # Pass any extra channel (e.g. alpha) through untouched.
    hsv[..., 3:] = pixels[..., 3:]
    hsv[..., 0] = (sector / 6.0) % 1.0
    hsv[chromatic, 1] = spread[chromatic] / brightest[chromatic]
    hsv[..., 2] = brightest
    return hsv

def hsv_to_rgb(hsv):
    """Vectorised counterpart of ``colorsys.hsv_to_rgb`` for numpy arrays.

    The last axis of ``hsv`` holds hue, saturation and value (plus optional
    extra channels such as alpha, which are copied through). Hue and
    saturation are expected in [0, 1]; value is expected on the 0-255 scale
    because the result is cast to ``uint8``.
    """
    out = np.empty_like(hsv)
    # Pass any extra channel (e.g. alpha) through untouched.
    out[..., 3:] = hsv[..., 3:]

    hue, sat, val = hsv[..., 0], hsv[..., 1], hsv[..., 2]
    scaled_hue = hue * 6.0
    sector = scaled_hue.astype('uint8')
    frac = scaled_hue - sector

    low = val * (1.0 - sat)
    falling = val * (1.0 - sat * frac)
    rising = val * (1.0 - sat * (1.0 - frac))

    sector = sector % 6
    # First branch handles grey pixels; the rest are hue sectors 1..5.
    # Sector 0 is covered by the ``default`` of each select call.
    branches = [sat == 0.0] + [sector == k for k in range(1, 6)]
    per_channel = (
        ([val, falling, low, low, rising, val], val),   # red
        ([val, val, val, falling, low, low], rising),   # green
        ([val, low, rising, val, val, falling], low),   # blue
    )
    for channel, (choices, sector_zero) in enumerate(per_channel):
        out[..., channel] = np.select(branches, choices, default=sector_zero)
    return out.astype('uint8')

def shift_hue(arr,hout):
    """Return a copy of the RGB(A) array ``arr`` whose hue channel is replaced by ``hout``.

    The array is converted to HSV with rgb_to_hsv, its hue channel is
    overwritten with ``hout`` (saturation and value are untouched), and the
    result is converted back with hsv_to_rgb.
    """
    recoloured = rgb_to_hsv(arr)
    recoloured[..., 0] = hout
    return hsv_to_rgb(recoloured)

def colorize(image, hue):
    """Return a new PIL image equal to ``image`` with every pixel's hue set to ``hue``."""
    # Go through numpy so the hue replacement can be done on the whole image at once.
    recoloured_pixels = shift_hue(np.array(image), hue)
    return Image.fromarray(recoloured_pixels)

def get_concept_scores_mv_valid(tensor, labels, concept_bank, model_bottom, model_top, 
                                alpha=1e-4, beta=1e-4, n_steps=100,
                                lr=1e-1, momentum=0.9, enforce_validity=True):
    """Optimise per-concept weights that push a sample's embedding towards ``labels``.

    The sample is embedded with ``model_bottom``. A weight vector ``W`` (one
    entry per concept, initialised to zero) is then optimised with SGD so that
    ``model_top(embedding + W @ normalized_concepts)`` minimises the
    cross-entropy with ``labels`` plus L1 and L2 penalties on ``W``.

    Args:
        tensor: Input batch for ``model_bottom``. The function is written for a
            single sample: ``W`` has a leading dimension of 1.
        labels: Target class index tensor passed to ``nn.CrossEntropyLoss``.
        concept_bank: Object exposing ``bank`` (concept vectors, one per row),
            ``norms``, ``intercepts``, ``margin_info.max``, ``margin_info.min``
            and ``concept_names``.
            # assumes norms / intercepts / margins have shape (n_concepts, 1) — TODO confirm
        model_bottom: Callable mapping ``tensor`` to its embedding.
        model_top: Callable mapping an embedding to class scores.
        alpha: Weight of the L1 penalty on ``W``.
        beta: Weight of the L2 penalty on ``W``.
        n_steps: Number of SGD steps.
        lr: SGD learning rate.
        momentum: SGD momentum.
        enforce_validity: When true, ``W`` is projected after every step into a
            per-concept interval derived from the bank's min/max margins, and
            that interval is additionally capped at zero on one side depending
            on the sign of the sample's current margin.

    Returns:
        A tuple ``(success, concept_scores, concept_names, W)`` where
        ``success`` is True when the perturbed embedding is classified as
        ``labels``, ``concept_scores`` maps concept name to its weight,
        ``concept_names`` lists the names sorted by decreasing weight and ``W``
        is the weight vector as a 1-D numpy array in bank order.
    """
    max_margins = concept_bank.margin_info.max
    min_margins = concept_bank.margin_info.min
    concept_norms = concept_bank.norms
    concept_intercepts = concept_bank.intercepts
    concepts = concept_bank.bank
    concept_names = concept_bank.concept_names.copy()
    device = tensor.device

    # The embedding is a constant of the optimisation, so no autograd graph is needed for it.
    with torch.no_grad():
        embedding = model_bottom(tensor)

    criterion = nn.CrossEntropyLoss()
    W = nn.Parameter(torch.zeros(1, concepts.shape[0], device=device), requires_grad=True)

    # Normalize the concept vectors
    normalized_C = max_margins * concepts / concept_norms

    # Projection of the sample onto every concept vector (shared by the margin
    # and by both clamp bounds below).
    projections = torch.matmul(concepts, embedding.T)

    # Compute the current distance of the sample to decision boundaries of SVMs
    margins = (projections + concept_intercepts) / concept_norms

    # Computing constraints for the concepts scores
    W_clamp_max = (max_margins*concept_norms - concept_intercepts - projections)
    W_clamp_min = (min_margins*concept_norms - concept_intercepts - projections)
    W_clamp_max = (W_clamp_max / (max_margins * concept_norms)).detach().T
    W_clamp_min = (W_clamp_min / (max_margins * concept_norms)).detach().T

    if enforce_validity:
        # Cap the allowed interval at zero on one side, depending on the sign
        # of the sample's current margin for each concept.
        W_clamp_max[(margins > 0).T] = 0.
        W_clamp_min[(margins < 0).T] = 0.

    optimizer = optim.SGD([W], lr=lr, momentum=momentum)
    for _ in range(n_steps):
        optimizer.zero_grad()
        new_embedding = embedding + torch.matmul(W, normalized_C)
        new_out = model_top(new_embedding)
        l1_loss = torch.norm(W, dim=1, p=1)
        l2_loss = torch.norm(W, dim=1, p=2)
        ce_loss = criterion(new_out, labels)
        loss = ce_loss + l1_loss*alpha + l2_loss*beta

        loss.backward()
        optimizer.step()
        if enforce_validity:
            # Project W back into [W_clamp_min, W_clamp_max]; the projection is
            # not part of the differentiable objective.
            with torch.no_grad():
                W_projected = torch.where(W < W_clamp_min, W_clamp_min, W)
                W_projected = torch.where(W > W_clamp_max, W_clamp_max, W_projected)
                W.copy_(W_projected)

    # Evaluate the final perturbed embedding once, without tracking gradients.
    with torch.no_grad():
        final_emb = embedding + torch.matmul(W, normalized_C)
        new_out = model_top(final_emb)
    W = W[0].detach().cpu().numpy().tolist()

    concept_scores = {name: W[i] for i, name in enumerate(concept_names)}
    concept_names = sorted(concept_names, key=concept_scores.get, reverse=True)

    success = bool((new_out.argmax(dim=1) == labels).all())
    return success, concept_scores, concept_names, np.array(W)


In [None]:
# Open the file once and derive both variants from it; the context manager
# closes the underlying file handle as soon as the pixel data has been converted.
#   img  : the colour image
#   img_ : the same image reduced to grey levels and expanded back to three
#          channels, so both variants have the same mode and can be blended.
with PIL.Image.open("green_apple.jpeg") as apple_file:
    img = apple_file.convert("RGB")
    img_ = apple_file.convert("L").convert("RGB")

In [None]:
# Sweep the blend weight between the grey (alpha = 0) and the colour (alpha = 1)
# version of the apple. For every blend, record the blended image, the model's
# output for the "Granny Smith" class and the 'greenness' concept score.
preds = []
ces = []
images = []

alphas = np.concatenate([np.linspace(0., 0.3, 16), np.linspace(0.3, 0.6, 4)])

# Loop-invariant values: the target class index and the two pixel arrays.
granny_smith_idx = class_labels.index("Granny Smith")
labels = torch.tensor(granny_smith_idx).long().view(1).to(args.device) 
color_px = np.array(img)
gray_px = np.array(img_)

for alpha in tqdm(alphas):
    average_img = PIL.Image.fromarray(np.array(alpha*color_px + (1-alpha)*gray_px, dtype=np.uint8))
    images.append(average_img)
    tensor = data_transforms(average_img).unsqueeze(0).to(args.device)
    success, concept_scores, concept_scores_list, W_old = get_concept_scores_mv_valid(tensor, labels, 
                                                                                      concept_bank, 
                                                                                      model_bottom, model_top,
                                                                                      alpha=1e-2, beta=1e-1, lr=1e-2)

    # Plain inference: no gradients are needed for the recorded model output.
    # NOTE(review): this is the raw model_top output for the class; the plot legend
    # calls it a probability — confirm whether model_top ends in a softmax.
    with torch.no_grad():
        pred = model_top(model_bottom(tensor)).cpu().numpy()[0, granny_smith_idx]
    preds.append(pred)
    ces.append(concept_scores['greenness'])

In [None]:
# Save the first, middle and last blended image of the sweep, each as PDF and PNG.
# The output directory is created up front so savefig cannot fail on a fresh checkout.
os.makedirs("./paper_figures", exist_ok=True)

for image, stem in (
    (images[0], "fig4_gray0"),
    (images[len(images)//2], "fig4_gray_half"),
    (images[-1], "fig4_gray_1"),
):
    plt.imshow(image)
    plt.axis("off")
    plt.savefig(f"./paper_figures/{stem}.pdf")
    plt.savefig(f"./paper_figures/{stem}.png")
    # Close the figure so the next image starts from a clean canvas.
    plt.close()


In [None]:
# Plot the 'greenness' concept score and the 'Granny Smith' model output against
# the degree of perturbation. Both series are reversed so the x axis runs from
# the largest blend weight (least perturbed) to the smallest (fully grey).
plt.figure(figsize=[7, 5])

perturbation = np.linspace(0,1,len(alphas))
plt.plot(
    perturbation, np.array(ces)[::-1],
    marker='o', color='green', label='\'Greenness\' CCE',
)
plt.plot(
    perturbation, preds[::-1],
    marker='o', color='black', label='\'Granny Smith\' prob predicted ',
)

plt.yticks(fontname='Arial', fontsize=18)
plt.xticks(fontname='Arial', fontsize=16)
plt.xlabel('Degree of perturbation', fontname='Arial', fontsize=18)
plt.legend(prop={'family':'Arial', 'size':16}, loc="upper right")

plt.savefig("./paper_figures/fig4_low_level_img.png")
plt.savefig("./paper_figures/fig4_low_level_img.pdf")

# 50 Images

In [None]:

from google_images_download import google_images_download  # third-party image downloader

# Search for "granny smith apple" and download up to 25 medium-sized results,
# printing the URLs and keeping the metadata of every image.
# NOTE(review): the section heading mentions 50 images while the limit here is 25 — confirm.
response = google_images_download.googleimagesdownload()
arguments = {
    "keywords": "granny smith apple",
    "limit": 25,
    "print_urls": True,
    "size": "medium",
    "metadata": True,
}
paths = response.download(arguments)

print(paths)

In [None]:
# Collect the downloaded files. os.listdir returns entries in arbitrary order,
# so the names are sorted to make the image order (and any later slicing of
# this list) reproducible across runs and machines.
paths = [os.path.join("./downloads/granny smith apple", f) for f in sorted(os.listdir("./downloads/granny smith apple/"))]

In [None]:
# For every downloaded image, sweep the blend weight between its grey and its
# colour version and record the 'greenness' concept score of each blend.
ces_scores = []
#img_paths = paths[0]['granny smith apple']
img_paths = paths
labels = torch.tensor(class_labels.index("Granny Smith")).long().view(1).to(args.device) 

# The blend weights are the same for every image, so build them once. Defining
# them before the loop also keeps `alphas` available when no image is processed.
alphas = np.concatenate([np.linspace(0., 0.3, 16), np.linspace(0.3, 0.6, 4)])

# NOTE(review): the first two entries of img_paths are skipped — confirm why.
for path in tqdm(img_paths[2:]):
    # Open each file once and close it as soon as both variants are decoded.
    with PIL.Image.open(path) as image_file:
        img = image_file.convert("RGB")
        img_ = image_file.convert("L").convert("RGB")
    color_px = np.array(img)
    gray_px = np.array(img_)

    ces_img = []
    for alpha in alphas:
        average_img = PIL.Image.fromarray(np.array(alpha*color_px + (1-alpha)*gray_px, dtype=np.uint8))
        tensor = data_transforms(average_img).unsqueeze(0).to(args.device)
        success, concept_scores, concept_scores_list, W_old = get_concept_scores_mv_valid(tensor, labels, 
                                                                                          concept_bank, 
                                                                                          model_bottom, model_top,
                                                                                          alpha=0., beta=1e-2, lr=1.,
                                                                                          enforce_validity=True)
        ces_img.append(concept_scores['greenness'])
    ces_scores.append(ces_img)

In [None]:
# One thin grey curve per image plus a thick green curve for the mean over all
# images. Every curve is flipped so the x axis runs from the largest blend
# weight (least perturbed) to the smallest (fully grey).
plt.figure(figsize=[7, 5])

perturbation = np.linspace(0, 1, len(alphas))

# The scores are used as recorded; no per-image rescaling is applied.
ces_normalized = [np.array(image_scores) for image_scores in ces_scores]
for normalized_ces in ces_normalized:
    plt.plot(perturbation, np.flip(normalized_ces), color='gray', lw=1)

mean_curve = np.mean(np.flip(np.array(ces_normalized), axis=1), axis=0)
plt.plot(perturbation, mean_curve, color='green', lw=4, marker='o')

plt.yticks(fontname='Arial', fontsize=18)
plt.xticks(fontname='Arial', fontsize=16)
plt.ylabel('Greenness CCE', fontname='Arial', fontsize=18)
plt.xlabel('Degree of perturbation', fontname='Arial', fontsize=18)

plt.savefig("./paper_figures/fig4_low_level.png")
plt.savefig("./paper_figures/fig4_low_level.pdf")