# CelebA

## Download using `torchvision.datasets.CelebA` e.g. in `post_hoc_celeba.py`
 - Alternatively, download the files manually from these links:
     - https://drive.google.com/file/d/0B7EVK8r0v71pZjFTYXZWM3FlRnM/view?usp=sharing
     - https://drive.google.com/file/d/0B7EVK8r0v71pd0FJY3Blby1HUTQ/view?usp=sharing

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
import os
from os import listdir
from os.path import isfile, join
from pathlib import Path
from PIL import Image
import cv2

# Data exploration

In [None]:
def load(n=100, folder='~/post_hoc_debiasing/data/celeba/img_align_celeba/'):
    """Read the first ``n`` images of ``folder`` into a single NumPy array.

    Files are looked up by a zero-padded, six-digit, 1-based index
    (``000001.jpg``, ``000002.jpg``, ...).

    Args:
        n: Number of images to read, starting at ``000001.jpg``.
        folder: Directory containing the images; a leading ``~`` is expanded.

    Returns:
        ``np.array`` built from the list of per-image arrays. It is empty
        when ``n`` is zero or negative.
    """
    # Expand "~" once instead of on every iteration.
    root = os.path.expanduser(folder)

    data = []
    for i in range(1, n + 1):
        file = str(i).zfill(6) + '.jpg'
        # Close the image file explicitly once its pixels have been copied
        # out, rather than relying on Pillow's lazy loading to release it.
        with Image.open(join(root, file)) as img:
            data.append(np.array(img))

    return np.array(data)

def plot(data, n):
    """Draw the first ``n`` images of ``data`` side by side in one row."""
    plt.figure(figsize=(20, 10))

    # Subplot positions are 1-based, so walk 1..n and read data at pos - 1.
    for pos in range(1, n + 1):
        plt.subplot(1, n, pos)
        pixels = data[pos - 1].astype(int)
        plt.axis('off')
        plt.imshow(pixels)

def load_attrs(file='~/post_hoc_debiasing/data/celeba/list_attr_celeba.txt', max_n=-1):
    """Parse a whitespace-separated attribute table.

    Layout read by this function: line 0 is ignored, line 1 holds the
    attribute names, and every later line holds a leading token (skipped)
    followed by one integer per attribute.

    Args:
        file: Path to the attribute file; a leading ``~`` is expanded.
        max_n: Line index (0-based, counting the two header lines) at which
            parsing stops; that line itself is not read. The default ``-1``
            never matches, so the whole file is parsed.

    Returns:
        ``(attrs, descriptions)``: ``attrs`` is an integer array with one row
        per parsed data line, and ``descriptions`` is the list of names taken
        from line 1.
    """
    attrs = []
    descriptions = []
    # The context manager guarantees the handle is closed, even if a row
    # fails to parse.
    with open(os.path.expanduser(file), "r") as f:
        for index, line in enumerate(f):
            if index == 0:
                # first header line carries no attribute data
                continue
            if index == 1:
                descriptions = line.split()
                continue
            if index == max_n:
                break
            fields = line.split()
            if not fields:
                # A blank line (e.g. at end of file) would otherwise add an
                # empty row and make the array ragged.
                continue
            # Drop the leading token; the remaining tokens are the values.
            attrs.append([int(num) for num in fields[1:]])

    attrs = np.array(attrs)
    print(attrs.shape)
    return attrs, descriptions

In [None]:
# load all the data
# Reads the first 5000 images into `data` and parses the whole attribute file
# into `attrs` / `descriptions` (load_attrs is called with its default max_n,
# which never stops early). These three names are notebook globals that the
# following cells read.
# NOTE(review): the trailing 202599 presumably records the full dataset size
# to use instead of 5000 — confirm.
data = load(n=5000) # 202599
print(data.shape)
attrs, descriptions = load_attrs()

In [None]:
# check the attributes are correct
# Shows the first three images and prints, under each one, the parsed value of
# a few attributes so the labels can be compared with the picture by eye.
print(descriptions)
for i in range(3):
    plt.imshow(data[i])
    plt.show()
    for attr in ['Male', 'Attractive', 'Smiling', 'Pale_Skin']:
        # The position of the name in `descriptions` is used as the column of
        # `attrs`; assumes the header row has no leading filename column —
        # TODO confirm against the attribute file.
        print(attr, attrs[i][descriptions.index(attr)])

In [None]:
# check features
# Among the first 1000 attribute rows, collect the images flagged with one
# attribute and plot matches 9-16 (slice 8:16) in a single row.
print(descriptions)
attr = 'Goatee'
# Look the column up once rather than once per row.
attr_col = descriptions.index(attr)
inds = [i for i in range(1000) if attrs[i][attr_col]==1]
shown = [data[i] for i in inds[8:16]]
# Size the grid from what was actually found: asking plot for 8 panels when
# the slice holds fewer images would index past the end of the list.
plot(shown, len(shown))

## Load from torch

In [None]:
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from sklearn.metrics import accuracy_score, roc_auc_score, roc_curve
from torchvision import models, transforms

In [None]:
def load_celeba(num_workers=2, batch_size=4):
    """Build CelebA train/validation/test datasets and their data loaders.

    The torchvision ``train`` split is divided 70/30 at random into a
    training and a validation subset; the ``test`` split is used as-is.
    Both splits are requested with ``download=True`` under ``./data``.

    Args:
        num_workers: Worker process count passed to every ``DataLoader``.
        batch_size: Batch size passed to every ``DataLoader``.

    Returns:
        ``(trainset, valset, testset, trainloader, valloader, testloader)``.
    """
    transform = transforms.ToTensor()

    trainset = torchvision.datasets.CelebA(root='./data', download=True, split='train', transform=transform)
    print(len(trainset))

    # random_split requires the lengths to add up to len(trainset) exactly.
    # Truncating 70% and 30% separately can drop samples, so the validation
    # size is taken as the remainder.
    n_train = int(len(trainset) * 0.7)
    n_val = len(trainset) - n_train
    trainset, valset = torch.utils.data.random_split(trainset, [n_train, n_val])
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=num_workers)
    valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
                                            shuffle=True, num_workers=num_workers)

    testset = torchvision.datasets.CelebA(root='./data', split='test',
                                          download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=num_workers)
    return trainset, valset, testset, trainloader, valloader, testloader

In [None]:
# Build the datasets and loaders, then select the first CUDA device when one
# is available and the CPU otherwise.
# NOTE(review): `tetstset` looks like a typo for `testset`; it is left as-is
# because other cells may refer to this name — confirm before renaming.
trainset, valset, tetstset, trainloader, valloader, testloader = load_celeba()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
def imshow(img):
    """Display a channel-first image tensor with matplotlib.

    Args:
        img: 3-D tensor whose first axis is moved to the end before plotting
            (channel, height, width -> height, width, channel). It may live
            on any device.
    """
    # Tensor.numpy() only works for CPU tensors that do not require grad, so
    # detach and copy to host memory first. Both calls are no-ops for a plain
    # CPU tensor.
    npimg = img.detach().cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
    
def get_single_attr(labels, attr='Attractive'):
    """Extract the values of one named attribute from a batch of label rows.

    Args:
        labels: 2-D batch of labels, one row per sample, with columns ordered
            as in the name list below. It may live on any device.
        attr: Name of the attribute to extract.

    Returns:
        1-D tensor with the selected column, one value per row of ``labels``.
        It is a copy, so modifying it leaves ``labels`` untouched.

    Raises:
        ValueError: If ``attr`` is not one of the listed attribute names.
    """
    descriptions = ['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive',
                    'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose',
                    'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair',
                    'Bushy_Eyebrows', 'Chubby', 'Double_Chin', 'Eyeglasses',
                    'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
                    'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes',
                    'No_Beard', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose',
                    'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns', 'Smiling',
                    'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings', 'Wearing_Hat',
                    'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie',
                    'Young']
    print(labels.shape)
    column = descriptions.index(attr)
    # Slice the column directly instead of collecting it row by row through
    # NumPy: the NumPy round-trip fails for tensors on a CUDA device.
    attrs = torch.as_tensor(labels)[:, column].clone()
    print(attrs.shape)
    return attrs
    
# Preview up to three training batches: show the first image of a batch when
# its label row has the attribute named by the notebook-level `attr` set.
# `descriptions` and `attr` are the globals bound by earlier cells.
attr_idx = descriptions.index(attr)  # look the column up once, not per batch
# The loop variable is `batch`, not `data`, so the image array loaded into the
# global `data` by an earlier cell is not overwritten.
for i, batch in enumerate(trainloader, 0):
    # each batch is a list of [inputs, labels]
    inputs, labels = batch[0].to(device), batch[1].to(device)
    img = inputs[0]
    label = labels[0]
    # Hand CPU copies to the helpers: they convert through NumPy, which
    # cannot read tensors that live on a CUDA device.
    labels = get_single_attr(labels.cpu())
    if label[attr_idx]==1:
        imshow(img.cpu())
    if i > 1:
        break
