In [1]:
%load_ext autoreload
%autoreload 2

In [17]:
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
from torchvision.utils import save_image
import torchvision
import numpy as np

from data_loader import *
from model import Generator
from model_conv import BlondClassifier, GenderClassifier

In [None]:
img_dir = '/mnt/data/10708-controllable-generation/data/celeba/img_align_celeba'
attr_fp = '/mnt/data/10708-controllable-generation/data/celeba/list_attr_celeba.txt'
stargan_fp = 'stargan_celeba_128/models/200000-G.ckpt'
male_classifier_fp = '../male_classifier/saved_models/epoch5.pth'
blonde_classifier_fp = 'Blond.pt'

In [3]:
celeba_loader = get_loader(img_dir, attr_fp, 
                           ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'],
                           178, 128, 32,
                           'CelebA', 'test', 2)

Finished preprocessing the CelebA dataset...


In [4]:
def denorm(x):
    """Convert the range from [-1, 1] to [0, 1]."""
    out = (x + 1) / 2
    return out.clamp_(0, 1)

In [9]:
def mean_psnr(recon_imgs, imgs):
    recon_imgs = recon_imgs.reshape((recon_imgs.shape[0], -1, 3))
    imgs = imgs.reshape((recon_imgs.shape[0], -1, 3))
    channel_wise_mse = ((imgs - recon_imgs) ** 2).mean(axis=1)
    sample_wise_mse = channel_wise_mse.mean(axis=1)
    mse = sample_wise_mse.mean()
    mean_psnr = 10*np.log10(255*255/mse)
    return mean_psnr

In [10]:
def mean_rmse(recon_imgs, imgs):
    recon_imgs = recon_imgs.reshape((recon_imgs.shape[0], -1, 3))
    imgs = imgs.reshape((recon_imgs.shape[0], -1, 3))
    channel_wise_rmse = ((imgs - recon_imgs) ** 2).mean(axis=1) ** 0.5
    sample_wise_rmse = channel_wise_rmse.mean(axis=1)
    rmse = sample_wise_rmse.mean()
    return rmse

## Blondes v/s Blacks

### Get StarGAN preds

In [7]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
generator = Generator(64, 5, 6).to(device)
generator.load_state_dict(torch.load(stargan_fp, map_location=device))
# _ = generator.eval()

<All keys matched successfully>

In [None]:
blondes = []
blacks = []
blondes_orig = []
blacks_orig = []
for images, labels in tqdm(celeba_loader):
    labels[:, 0] = 1 - labels[:, 0]
    labels[:, 1] = 1 - labels[:, 1]
    black_idxs = (labels[:, 0] == 1).nonzero()
    blonde_idxs = (labels[:, 1] == 1).nonzero()
    with torch.no_grad():
        outs = generator(images, labels).detach()
    blondes.append(outs[blonde_idxs])
    blacks.append(outs[black_idxs])
    
    labels[:, 0] = 1 - labels[:, 0]
    labels[:, 1] = 1 - labels[:, 1]
    with torch.no_grad():
        outs = generator(images, labels).detach()
    blondes_orig.append(outs[blonde_idxs])
    blacks_orig.append(outs[black_idxs])

In [160]:
blo = torch.cat(blondes, dim=0)
bla = torch.cat(blacks, dim=0)
ori_blo = torch.cat(blondes_orig, dim=0)
ori_bla = torch.cat(blacks_orig, dim=0)

In [7]:
# torch.save(blo, 'blondes.pth')
# torch.save(bla, 'blacks.pth')
# torch.save(ori_blo, 'blondes_orig.pth')
# torch.save(ori_bla, 'blacks_orig.pth')

blo = torch.load('blondes.pth')
bla = torch.load('blacks.pth')
ori_blo = torch.load('blondes_orig.pth')
ori_bla = torch.load('blacks_orig.pth')

### Evaluate

In [33]:
blond_classifier = BlondClassifier()
blond_classifier.load_state_dict(torch.load(blonde_classifier_fp, map_location='cpu'))
_ = blond_classifier.eval()

In [8]:
blo_resized = torchvision.transforms.Resize(64)(blo.squeeze())
bla_resized = torchvision.transforms.Resize(64)(bla.squeeze())
ori_blo = torchvision.transforms.Resize(64)(ori_blo.squeeze())
ori_bla = torchvision.transforms.Resize(64)(ori_bla.squeeze())

In [10]:
loader_blo = torch.utils.data.DataLoader(list(zip(blo_resized, ori_blo)), batch_size=32, shuffle=False)
loader_bla = torch.utils.data.DataLoader(list(zip(bla_resized, ori_bla)), batch_size=32, shuffle=False)

In [40]:
total_cnt = 0
correct = 0

In [41]:
for batch, batch_ori in loader_blo:
    blo_preds = blond_classifier(batch)
    blo_classes = blo_preds.argmax(dim=1)
    total_cnt += blo_classes.shape[0]
    correct += (blo_classes == 1).sum()

In [42]:
for batch, batch_ori in loader_bla:
    bla_preds = blond_classifier(batch)
    bla_classes = bla_preds.argmax(dim=1)
    total_cnt += bla_classes.shape[0]
    correct += (bla_classes == 0).sum()

In [43]:
correct / total_cnt

tensor(0.6235)

In [11]:
for batch, batch_ori in loader_blo:
    break

In [23]:
save_image(torch.cat((denorm(batch_ori[[4,1]]), denorm(batch[[4,1]])), dim=0), 'tmp.png', nrow=2)

In [8]:
recon = []
orig = []
for images, labels in tqdm(celeba_loader):
    labels[:, 0] = 1 - labels[:, 0]
    labels[:, 1] = 1 - labels[:, 1]
    with torch.no_grad():
        outs = generator(images.to(device), labels.to(device)).detach()
    labels[:, 0] = 1 - labels[:, 0]
    labels[:, 1] = 1 - labels[:, 1]
    with torch.no_grad():
        outs = generator(outs.to(device), labels.to(device)).detach()
    orig.append(images)
    recon.append(outs)

100%|███████████████████████████████████████████████████████████████████████████████████| 63/63 [00:37<00:00,  1.68it/s]


In [15]:
recon_all = torch.cat(recon, dim=0).to(device)
orig_all = torch.cat(orig, dim=0).to(device)

In [18]:
mean_psnr(recon_all.cpu().numpy(), orig_all.cpu().numpy())

69.44711589072598

## Males v/s Females

### Get StarGAN preds

In [19]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
generator = Generator(64, 5, 6).to(device)
generator.load_state_dict(torch.load(stargan_fp, map_location=device))
# _ = generator.eval()

<All keys matched successfully>

In [10]:
males = []
females = []
males_orig = []
females_orig = []
for images, labels in tqdm(celeba_loader):
    labels[:, -2] = 1 - labels[:, 0]
    male_idxs = (labels[:, -2] == 1).nonzero()
    female_idxs = (labels[:, -2] == 0).nonzero()
    with torch.no_grad():
        outs = generator(images.to(device), labels.to(device)).detach()
    males.append(outs[male_idxs])
    females.append(outs[female_idxs])
    
    labels[:, -2] = 1 - labels[:, -2]
    with torch.no_grad():
        outs = generator(images.to(device), labels.to(device)).detach()
    males_orig.append(outs[male_idxs])
    females_orig.append(outs[female_idxs])

100%|███████████████████████████████████████████████████████████████████████████████████| 63/63 [00:39<00:00,  1.61it/s]


In [11]:
ma = torch.cat(males, dim=0)
fe = torch.cat(females, dim=0)
ori_ma = torch.cat(males_orig, dim=0)
ori_fe = torch.cat(females_orig, dim=0)

In [20]:
# torch.save(ma, 'males.pth')
# torch.save(fe, 'females.pth')
# torch.save(ori_ma, 'males_orig.pth')
# torch.save(ori_fe, 'females_orig.pth')

ma = torch.load('males.pth')
fe = torch.load('females.pth')
ori_ma = torch.load('males_orig.pth')
ori_fe = torch.load('females_orig.pth')

### Evaluate

In [16]:
male_classifier = torch.load(male_classifier_fp).to(device)
_ = male_classifier.eval()

In [21]:
ma_resized = torchvision.transforms.Resize(64)(ma.squeeze())
fe_resized = torchvision.transforms.Resize(64)(fe.squeeze())
ori_ma = torchvision.transforms.Resize(64)(ori_ma.squeeze())
ori_fe = torchvision.transforms.Resize(64)(ori_fe.squeeze())

In [22]:
loader_ma = torch.utils.data.DataLoader(list(zip(ma_resized, ori_ma)), batch_size=32, shuffle=False)
loader_fe = torch.utils.data.DataLoader(list(zip(fe_resized, ori_fe)), batch_size=32, shuffle=False)

In [20]:
total_cnt = 0
correct = 0

In [21]:
for batch, batch_ori in loader_ma:
    ma_preds = male_classifier(batch)
    ma_classes = ma_preds.argmax(dim=1)
    total_cnt += ma_classes.shape[0]
    correct += (ma_classes == 1).sum()

In [22]:
for batch, batch_ori in loader_fe:
    fe_preds = male_classifier(batch)
    fe_classes = fe_preds.argmax(dim=1)
    total_cnt += fe_classes.shape[0]
    correct += (fe_classes == 0).sum()

In [23]:
correct / total_cnt

tensor(0.7449, device='cuda:0')

In [28]:
for batch, batch_ori in loader_ma:
    break

In [29]:
save_image(torch.cat((denorm(batch_ori[[4,1]]), denorm(batch[[4,1]])), dim=0), 'tmp.png', nrow=2)

In [26]:
recon = []
orig = []
for images, labels in tqdm(celeba_loader):
    labels[:, -2] = 1 - labels[:, -2]
    with torch.no_grad():
        outs = generator(images.to(device), labels.to(device)).detach()
    labels[:, -2] = 1 - labels[:, -2]
    with torch.no_grad():
        outs = generator(outs.to(device), labels.to(device)).detach()
    orig.append(images)
    recon.append(outs)

100%|███████████████████████████████████████████████████████████████████████████████████| 63/63 [00:37<00:00,  1.70it/s]


In [27]:
recon_all = torch.cat(recon, dim=0).to(device)
orig_all = torch.cat(orig, dim=0).to(device)

In [28]:
mean_psnr(recon_all.cpu().numpy(), orig_all.cpu().numpy())

70.04880108580704