In [1]:
import numpy as np
from PIL import Image
import copy
import os
import torch
import random
from codes.util.check_dir import mk_dir
import matplotlib.pyplot as plt
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
IMG_SIZE = 224
SUPER_SIZE = 4
BATCH_SIZE = 56
SUPER_IMGSIZE = IMG_SIZE // SUPER_SIZE
def pos_transform(ids):   
    comps = []
    assert (len(ids.shape) == 2)
    for idx in ids:        
        # print (idx)
        row, col = idx//SUPER_IMGSIZE, idx%SUPER_IMGSIZE
        idx = IMG_SIZE * row * SUPER_SIZE + col*SUPER_SIZE
        comp = []
        for i in range(SUPER_SIZE):
            for j in range(SUPER_SIZE):
                comp.append(idx+i*IMG_SIZE+j)
        comp = np.concatenate(comp)
        comps.append(comp)
    return np.array(comps)

def get_bound(img):
    vmin, vmax = img.min(), img.max()
    bd = max(np.abs(vmin), np.abs(vmax))
    return bd
def pixel_shap(path):
    shap = np.genfromtxt(path, delimiter=',')[:,0]
    ids = np.genfromtxt(path, delimiter=',')[:,1].astype(int)[:,np.newaxis]
    # row, col = ids//SUPER_IMGSIZE, ids%SUPER_IMGSIZE
    # ids = IMG_SIZE * row * SUPER_SIZE + col*SUPER_SIZE
    ids = pos_transform(ids)
    val = np.zeros(IMG_SIZE**2)
    for i in range(ids.shape[0]):   
        val[ids[i]] = shap[i]
    draw = val.reshape((IMG_SIZE,IMG_SIZE))
    return draw
def gaussian_prob(x, y):   
    def distance(x, y):
        m, d = x.shape
        n, _ = y.shape        
        x = torch.from_numpy(x).cuda()
        y = torch.from_numpy(y).cuda()
        x = torch.unsqueeze(x, dim=1).expand(m, n, d)
        y = torch.unsqueeze(y, dim=0).expand(m, n, d)
        dist = torch.sum((x - y) ** 2, dim=2).float()
#         print (x.shape, y.shape, dist.shape)       
        return dist
    dist = distance(x, y)
    prob = 1 / (2*np.pi) * torch.exp(-dist/2)
    prob = prob.cpu().numpy()
    return np.sqrt(prob)
def batch_compute(cluster, batch_size=BATCH_SIZE**2):
    img = np.argwhere(np.zeros((IMG_SIZE, IMG_SIZE)) == 0)
    prob = np.zeros(IMG_SIZE**2)
    batch_num = IMG_SIZE**2 // batch_size + 2
    for batch_id in range(batch_num):
        batch_start = batch_id * batch_size
        batch_end = min(IMG_SIZE**2, (batch_id+1)*batch_size)
        batch = img[batch_start: batch_end, :]
        batch_prob = gaussian_prob(batch, cluster).sum(1)        
        prob[batch_start: batch_end] = batch_prob
    return prob
def cluster_color(ids, shap, rgbs):    
    cnt = 0
    tshap = shap.reshape(-1)
    color_map = np.ones((IMG_SIZE**2,3))
    shap_color = []
    color_id = -1*np.ones(IMG_SIZE**2)
    smap = np.ones(IMG_SIZE**2)
    for i in range(ids.shape[0]): 
        color_id[ids[i]] = cnt
        shap_color.append(np.abs(tshap[ids[i]]).mean())       
        cnt = cnt+1
    return color_id.astype(int), np.array(shap_color)
def prepare_color(cluster_ids, shap, rgbs):
    color_id, shap_color = cluster_color(cluster_ids, shap, rgbs)
    probs = []
    for clus in range(0,color_id.max()+1,1):
        ids = np.argwhere(color_id==clus).squeeze()    
        if ids.size == 0:
            continue
        coords = np.array([ids//IMG_SIZE,ids%IMG_SIZE])
        if ids.size == 1:
            coords = coords.reshape(1,2)
        else:
            coords = coords.transpose(1,0)   
#         print (coords.shape)
        probs.append(batch_compute(coords))
    probs = np.array(probs)
    colors = np.argmax(probs,0).squeeze().astype(int)
    bright = np.max(probs,0).squeeze()    
    return colors, bright, shap_color
def channel(img, ax):
    if img.shape[-1] != 3:
        img = img.transpose(1,2,0)
    img = img/np.abs(img).max() + 0.5
    ax.imshow(img)
    ax.axis('off')

In [7]:
setting, model, data = 'pretrained_normal','res18', 'tiny_cub'
img_root = f'./data/interaction/{data}/img'
ptb_root = f'./experiment/interaction/perturb/{setting}/{data}_{model}'
csv_root = f'./experiment/interaction/csv_4/{setting}/{data}_{model}'
gcolors =  [[65,105,225],[34,139,34],[227,23,13],[187,160,105],[135,206,235],[100,255,0],[255,192,203],[240,230,140]]
gcolors = 255 - np.array(gcolors)
out_root = f'./experiment/interaction/show/{setting}/{data}_{model}'
mk_dir(out_root)

cnt = 0
for cate in os.listdir(csv_root): 
    # if '080' not in cate:
    #     continue
    for fname in os.listdir(f'{csv_root}/{cate}'):
        # try:           
        prob_path = f'./experiment/interaction/prob/{setting}/{data}_{model}/{cate}/{fname}'
        mk_dir(prob_path)
        fig, ax = plt.subplots(1,3,figsize=(6,2))
        if data != 'celeba':
            img = Image.open(f'{img_root}/{cate}/{fname}').resize((IMG_SIZE,IMG_SIZE)) 
        else:
            img = Image.open(f'{img_root}/{cate}').resize((IMG_SIZE,IMG_SIZE))
        ax[0].imshow(img)
        ax[0].axis('off')                
        adv_info = np.load(f'{ptb_root}/{cate}/{fname}/ptb.npy', allow_pickle=True).item()
        ptb, target = adv_info['ptb'], adv_info['tgt']
        channel(ptb,ax[1])
        shap = pixel_shap(f'{csv_root}/{cate}/{fname}/backshap1.csv')       
        tgt = 64   
        print (cate, fname, target)
        # --------shap_color: v, bright:prob, whether to show, colors: color_id --------------- #
        path = f'{csv_root}/{cate}/{fname}/clusback{tgt}.csv' 
        if not os.path.exists(path):
            print ('Cluster not finish')
            continue
            
        cluster_ids = np.genfromtxt(path, delimiter=',')
        if len(cluster_ids.shape) >= 2:
            cluster_ids = cluster_ids[:,1:].astype(int)
        else:
            cluster_ids = cluster_ids[np.newaxis,:][:,1:].astype(int)
 
        cluster_ids = pos_transform(cluster_ids)
        colors, bright, shap_color = prepare_color(cluster_ids, shap, gcolors)  
        np.save(f'{prob_path}/idcolor_{tgt}.npy', colors)
        np.save(f'{prob_path}/bright_{tgt}.npy', bright)
        np.save(f'{prob_path}/scolor_{tgt}.npy', shap_color)
 
        bright = bright ** 0.01
       
        shap_color = shap_color / np.linalg.norm(shap_color)
        shap_color[shap_color>0.5] = 0      
        k = random.randint(0,3)
        for i in range(shap_color.shape[0]):          
            if shap_color[i]>0:
                colors[colors==i] = -((k+i)%4+4)               
            else:
                colors[colors==i] = -((i+k)%4)                
        colors = np.abs(colors)  
        show = gcolors[colors] 
        bright[bright< bright.mean()] = 0
        bright[bright>= bright.mean()] = 1    
        show = (show * bright[:,np.newaxis]).reshape(IMG_SIZE,IMG_SIZE,3)
        ax[2].imshow((255-show).astype(int))      
        ax[2].axis('off') 
        
        plt.subplots_adjust(wspace=0,hspace=0)   
      
        if data != 'celeba':
            fig.savefig(f'{out_root}/{fname[:-4]}.png', dpi=500, bbox_inches='tight',pad_inches=0)
        else:
            fig.savefig(f'{out_root}/{cate[:-4]}.png', dpi=500, bbox_inches='tight',pad_inches=0)
        plt.close()
    

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
160.Black_throated_Blue_Warbler Black_Throated_Blue_Warbler_0085_161621.jpg 3
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
160.Black_throated_Blue_Warbler Black_Throated_Blue_Warbler_0064_161656.jpg 3
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
160.Black_throated_Blue_Warbler Black_Throated_Blue_Warbler_0079_161194.jpg 5
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
100.Brown_Pelican Brown_Pelican_0122_94022.jpg 2
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
100.Brown_Pelican Brown_Pelican_0095_94290.jpg 5
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
100.Brow