# Visualize debiasing experiments on CelebA

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import numpy as np
import os
from os.path import join
import matplotlib.pyplot as plt
import torch

from post_hoc_celeba import load_celeba, get_resnet_model
from PIL import Image

In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Names of the label columns. A name's position in this list is used as the
# column index into each label row (looked up with `descriptions.index(...)`).
# NOTE(review): presumably the CelebA attribute annotations followed by extra
# 'White'/'Black'/'Asian' columns and the image 'Index' -- confirm against
# `load_celeba`.
descriptions = [
    '5_o_Clock_Shadow',
    'Arched_Eyebrows',
    'Attractive',
    'Bags_Under_Eyes',
    'Bald',
    'Bangs',
    'Big_Lips',
    'Big_Nose',
    'Black_Hair',
    'Blond_Hair',
    'Blurry',
    'Brown_Hair',
    'Bushy_Eyebrows',
    'Chubby',
    'Double_Chin',
    'Eyeglasses',
    'Goatee',
    'Gray_Hair',
    'Heavy_Makeup',
    'High_Cheekbones',
    'Male',
    'Mouth_Slightly_Open',
    'Mustache',
    'Narrow_Eyes',
    'No_Beard',
    'Oval_Face',
    'Pale_Skin',
    'Pointy_Nose',
    'Receding_Hairline',
    'Rosy_Cheeks',
    'Sideburns',
    'Smiling',
    'Straight_Hair',
    'Wavy_Hair',
    'Wearing_Earrings',
    'Wearing_Hat',
    'Wearing_Lipstick',
    'Wearing_Necklace',
    'Wearing_Necktie',
    'Young',
    'White',
    'Black',
    'Asian',
    'Index',
]

def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of ``x``.

    Evaluated as ``exp(-log(1 + exp(-x)))`` through ``np.logaddexp`` so that
    large negative inputs do not overflow ``exp`` (the direct formula emits a
    NumPy overflow warning there before rounding to 0).

    Args:
        x: A number or NumPy array of logits.

    Returns:
        The element-wise sigmoid as a NumPy scalar or array.
    """
    # logaddexp(0, -x) == log(exp(0) + exp(-x)) == log(1 + exp(-x)),
    # computed without forming exp(-x) explicitly.
    return np.exp(-np.logaddexp(0, -x))

In [None]:
def image_from_index(index, folder='~/post_hoc_debiasing/data/celeba/img_align_celeba/', show=False):
    """Load the image whose file name is ``index`` zero-padded to 6 digits.

    Args:
        index: Image number; the file read is ``<6-digit index>.jpg``.
            Integral floats (e.g. a value taken out of a float label tensor)
            are accepted and converted with ``int``.
        folder: Directory holding the images; a leading ``~`` is expanded.
        show: When true, also display the image with matplotlib.

    Returns:
        The loaded PIL image.
    """
    # Coerce first: str(5.0).zfill(6) would otherwise yield '0005.0'.
    file = '{:06d}.jpg'.format(int(index))
    img = Image.open(join(os.path.expanduser(folder), file))
    # Image.open is lazy and keeps the file open until the pixel data is read.
    # Loading now lets Pillow release the handle instead of leaving one open
    # per returned image.
    img.load()
    if show:
        plt.imshow(img)
        plt.show()
    return img
    
def imshow_group(imgs, n):
    """Plot up to ``n`` images side by side in a single row.

    Args:
        imgs: Sequence of images accepted by ``plt.imshow``.
        n: Number of columns in the row. When ``imgs`` holds fewer than ``n``
            images, the remaining columns are left empty instead of raising
            ``IndexError``.
    """
    plt.figure(figsize=(20, 10))

    # Keep the grid n columns wide so a short row uses the same image size as
    # a full one; max() guards the slice against a negative n.
    for position, img in enumerate(imgs[:max(n, 0)], start=1):
        plt.subplot(1, n, position)
        plt.axis('off')
        plt.imshow(img)

In [None]:
def output_debiased_imgs(biased_net,
                         debiased_net,
                         loader,
                         protected_attr,
                         prediction_attr):
    """
    Collect each image together with its biased and debiased predictions.

    Args:
        biased_net: Model whose output column 0 is read as the biased logit.
        debiased_net: Model whose output column 0 is read as the debiased
            logit.
        loader: Sized iterable of ``(inputs, labels)`` tensor batches. Each
            label row is indexed by position in ``descriptions``.
        protected_attr: Name in ``descriptions`` of the protected attribute.
        prediction_attr: Name in ``descriptions`` of the predicted attribute.

    Returns:
        A list with one ``[img, label, protected, biased_output,
        debiased_output]`` entry per sample, where the two outputs are the
        sigmoid of the respective model's logit.

    Raises:
        ValueError: If either attribute name is not in ``descriptions``.

    Note:
        Only the inputs are moved to ``device`` here; both models are
        expected to already be on it.
    """
    prediction_index = descriptions.index(prediction_attr)
    protected_index = descriptions.index(protected_attr)
    ind = descriptions.index('Index')

    outputs = []
    total_batches = len(loader)
    for batch_num, (inputs, labels) in enumerate(loader):
        inputs = inputs.to(device)

        # Inference only: skip autograd bookkeeping for both forward passes.
        with torch.no_grad():
            biased_logits = biased_net(inputs)[:, 0].tolist()
            debiased_logits = debiased_net(inputs)[:, 0].tolist()

        # One bulk conversion per batch instead of an .item() call per value;
        # the labels are only read on the host, so they stay off the device.
        rows = labels.tolist()
        for row, biased_logit, debiased_logit in zip(rows, biased_logits, debiased_logits):
            outputs.append([
                # int() so a float-typed label column still names the file.
                image_from_index(int(row[ind])),
                row[prediction_index],
                row[protected_index],
                sigmoid(biased_logit),
                sigmoid(debiased_logit),
            ])

        if batch_num % 10 == 0:
            print('At', batch_num, '/', total_batches)

    return outputs

In [None]:
# load the test set
_, _, _, _, _, testloader = load_celeba(trainsize=0,
                                        testsize=100,
                                        num_workers=0,
                                        batch_size=32,
                                        transform_type='tensor')

biased_model_path = 'models/by_random_checkpoint.pt'
debiased_model_path = 'models/by_checkpoint.pt'

# load the biased and debiased models
# NOTE(review): the biased checkpoint is read as a bare state dict while the
# debiased one is unwrapped from 'model_state_dict' -- confirm the two files
# really differ in layout.
biased_net = get_resnet_model()
biased_net.load_state_dict(torch.load(biased_model_path, map_location=device))

debiased_net = get_resnet_model()
debiased_net.load_state_dict(torch.load(debiased_model_path, map_location=device)['model_state_dict'])

# The batches are moved to `device` before the forward pass, so the models
# must live there too. eval() switches layers such as batch norm and dropout
# to inference behaviour.
biased_net = biased_net.to(device).eval()
debiased_net = debiased_net.to(device).eval()

# output images
outputs = output_debiased_imgs(biased_net=biased_net,
                               debiased_net=debiased_net,
                               loader=testloader,
                               protected_attr='Black',
                               prediction_attr='Smiling')
imgs = [output[0] for output in outputs]

In [None]:
# Show the collected images in rows of `rowsize`, at most 5 full rows;
# a trailing partial row is not shown.
rowsize = 8
n_rows = min(len(imgs) // rowsize, 5)
for start in range(0, n_rows * rowsize, rowsize):
    imshow_group(imgs[start:start + rowsize], rowsize)