In [12]:
import copy
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms, models

from sklearn.metrics import f1_score, classification_report

import lightning as L
from lightning import Fabric

from src.resnet_modifications import resnet18
from src.models import AlexNetInception1x1, AlexNetSeparable11, AlexNetSkipConnection, AlexNetWithBatchNorm

In [2]:
# Folder the image files are read from (joined with each image id by MakeupDataset).
DATA_DIR = 'data/img_align_celeba/img_align_celeba/'
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# Attribute table (source of the Heavy_Makeup label) and the partition table
# (source of the train/val/test assignment), both keyed by image_id.
ys, split = (
    pd.read_csv(csv_path)
    for csv_path in ('data/list_attr_celeba.csv', 'data/list_eval_partition.csv')
)

In [4]:
# Build the test split: image ids with partition code 2, plus their Heavy_Makeup
# labels. The labels are joined on image_id instead of being filtered separately,
# so ids and labels stay paired even if the two CSVs differ in row order.
# An inner merge keeps the order of the left (partition) rows.
test_rows = split.loc[split['partition'] == 2, ['image_id']].merge(
    ys[['image_id', 'Heavy_Makeup']],
    on='image_id',
    how='inner',
    validate='one_to_one',
)
# Every test image must have exactly one attribute row; otherwise labels would be missing.
assert len(test_rows) == int((split['partition'] == 2).sum()), \
    'some test images have no row in the attribute table'

X_test = test_rows['image_id'].values
# The attribute file encodes "absent" as -1; map it to 0 and everything else to 1.
y_test = np.where(test_rows['Heavy_Makeup'].values == -1, 0, 1)

In [5]:
# Evaluation-time preprocessing (no augmentation): fixed 224x224 resize, tensor
# conversion, then per-channel normalisation. The mean/std values match the
# channel statistics commonly used with torchvision's ImageNet models.
val_tfms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [6]:
def get_preds(current_model, loader=None, threshold=0.5):
    """Run a single-logit binary classifier over a loader and collect hard predictions.

    Args:
        current_model: module whose forward pass returns a tensor with a size-1
            second dimension (it is squeezed with ``squeeze(1)``).
        loader: iterable yielding ``(images, targets)`` batches. When omitted, the
            notebook-level ``test_loader`` is used, which keeps existing
            ``get_preds(model)`` calls working.
        threshold: sigmoid probability above which a sample is predicted as 1.

    Returns:
        Tuple ``(targets, preds)`` of integer numpy arrays, in loader order.
    """
    if loader is None:
        loader = test_loader

    current_model.eval()
    all_preds, all_targets = [], []

    with torch.no_grad():
        for imgs, targets in tqdm(loader):
            # Only the images are needed on the device; targets are just collected.
            outputs = current_model(imgs.to(DEVICE)).squeeze(1)

            preds = torch.sigmoid(outputs).cpu().numpy() > threshold
            all_preds.extend(preds.astype(int))
            all_targets.extend(targets.cpu().numpy().astype(int))

    return np.array(all_targets), np.array(all_preds)


In [7]:
class MakeupDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs.

    Images are read from ``root_dir / image_id`` and converted to RGB; the
    optional ``transform`` is applied to the image only. ``image_ids`` and
    ``labels`` are indexed with the same position.
    """

    def __init__(self, image_ids, labels, root_dir, transform=None):
        self.ids, self.labels = image_ids, labels
        self.root, self.transform = root_dir, transform

    def __len__(self):
        # Dataset size is defined by the number of image ids.
        return len(self.ids)

    def __getitem__(self, idx):
        # Load the file for this position and force three colour channels.
        image = Image.open(os.path.join(self.root, self.ids[idx])).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx]


In [8]:
# Test pipeline. shuffle=False keeps batches in X_test order, so predictions
# collected from this loader can be compared position-by-position with y_test.
test_ds = MakeupDataset(
    image_ids=X_test,
    labels=y_test,
    root_dir=DATA_DIR,
    transform=val_tfms,
)
test_loader = DataLoader(dataset=test_ds, batch_size=64, num_workers=6, shuffle=False)

In [12]:
# vanilla: torchvision AlexNet with the last classifier layer swapped for a single-logit head
alexnet_model = models.alexnet()
alexnet_model.classifier[6] = nn.Linear(4096, 1)
# map_location lets a checkpoint saved on a GPU load when DEVICE fell back to the CPU.
alexnet_model.load_state_dict(torch.load('models/best_alexnet.pth', map_location=DEVICE))
# Assignment (instead of a bare call) keeps the module repr out of the cell output.
alexnet_model = alexnet_model.to(DEVICE).eval()

In [13]:
# pretrain: same architecture as the vanilla model; weights come from the
# 'best_alexnet_pretrain.pth' checkpoint (presumably fine-tuned from pretrained
# weights — the training run is not visible in this notebook).
alexnet_model_pretrain = models.alexnet()
alexnet_model_pretrain.classifier[6] = nn.Linear(4096, 1)
# map_location lets a checkpoint saved on a GPU load when DEVICE fell back to the CPU.
alexnet_model_pretrain.load_state_dict(
    torch.load('models/best_alexnet_pretrain.pth', map_location=DEVICE)
)
# Assignment (instead of a bare call) keeps the module repr out of the cell output.
alexnet_model_pretrain = alexnet_model_pretrain.to(DEVICE).eval()

In [14]:
# with 1x1 kernels (project-local AlexNetInception1x1 from src.models)
alexnet_model_1x1 = AlexNetInception1x1()
# map_location lets a checkpoint saved on a GPU load when DEVICE fell back to the CPU.
alexnet_model_1x1.load_state_dict(torch.load('models/best_alexnet_1x1.pth', map_location=DEVICE))
# Assignment (instead of a bare call) keeps the module repr out of the cell output.
alexnet_model_1x1 = alexnet_model_1x1.to(DEVICE).eval()

In [15]:
# with 11x1 and 1x11 kernels (project-local AlexNetSeparable11 from src.models)
alexnet_model_1x11 = AlexNetSeparable11()
# map_location lets a checkpoint saved on a GPU load when DEVICE fell back to the CPU.
alexnet_model_1x11.load_state_dict(torch.load('models/best_alexnet_11.pth', map_location=DEVICE))
# Assignment (instead of a bare call) keeps the module repr out of the cell output.
alexnet_model_1x11 = alexnet_model_1x11.to(DEVICE).eval()

In [16]:
# with skip-connection (project-local AlexNetSkipConnection from src.models)
alexnet_model_sc = AlexNetSkipConnection()
# map_location lets a checkpoint saved on a GPU load when DEVICE fell back to the CPU.
alexnet_model_sc.load_state_dict(
    torch.load('models/best_alexnet_skip_connection.pth', map_location=DEVICE)
)
# Assignment (instead of a bare call) keeps the module repr out of the cell output.
alexnet_model_sc = alexnet_model_sc.to(DEVICE).eval()

In [17]:
# with BatchNorm (project-local AlexNetWithBatchNorm from src.models)
alexnet_model_bn = AlexNetWithBatchNorm()
# map_location lets a checkpoint saved on a GPU load when DEVICE fell back to the CPU.
alexnet_model_bn.load_state_dict(torch.load('models/best_alexnet_bn.pth', map_location=DEVICE))
# Assignment (instead of a bare call) keeps the module repr out of the cell output.
alexnet_model_bn = alexnet_model_bn.to(DEVICE).eval()

In [18]:
# with pixels: same architecture as the vanilla model, loaded from the
# 'best_alexnet_pixels.pth' checkpoint (what "pixels" changed during training is
# not visible in this notebook).
alexnet_model_pixels = models.alexnet()
alexnet_model_pixels.classifier[6] = nn.Linear(4096, 1)
# map_location lets a checkpoint saved on a GPU load when DEVICE fell back to the CPU.
alexnet_model_pixels.load_state_dict(
    torch.load('models/best_alexnet_pixels.pth', map_location=DEVICE)
)
# Assignment (instead of a bare call) keeps the module repr out of the cell output.
alexnet_model_pixels = alexnet_model_pixels.to(DEVICE).eval()

In [19]:
# Score every AlexNet variant on the test loader, one full pass per model.
# Ground truth is kept from the first pass only; later passes contribute
# just their prediction arrays.
y_true, y_pred_alexnet = get_preds(alexnet_model)
(
    y_pred_alexnet_pretrain,
    y_pred_alexnet_1x1,
    y_pred_alexnet_1x11,
    y_pred_alexnet_sc,
    y_pred_alexnet_bn,
    y_pred_alexnet_pixels,
) = (
    get_preds(variant)[1]
    for variant in (
        alexnet_model_pretrain,
        alexnet_model_1x1,
        alexnet_model_1x11,
        alexnet_model_sc,
        alexnet_model_bn,
        alexnet_model_pixels,
    )
)

100%|█████████████████████████████████████████| 312/312 [00:06<00:00, 45.34it/s]
100%|█████████████████████████████████████████| 312/312 [00:06<00:00, 46.17it/s]
100%|█████████████████████████████████████████| 312/312 [00:07<00:00, 42.21it/s]
100%|█████████████████████████████████████████| 312/312 [00:08<00:00, 38.41it/s]
100%|█████████████████████████████████████████| 312/312 [00:05<00:00, 53.26it/s]
100%|█████████████████████████████████████████| 312/312 [00:07<00:00, 44.35it/s]
100%|█████████████████████████████████████████| 312/312 [00:06<00:00, 45.98it/s]


In [20]:
# Boolean mask over the test set: True where all seven AlexNet variants
# predicted the ground-truth label.
true_images_mask = np.logical_and.reduce(
    [
        variant_preds == y_true
        for variant_preds in (
            y_pred_alexnet,
            y_pred_alexnet_pretrain,
            y_pred_alexnet_1x1,
            y_pred_alexnet_1x11,
            y_pred_alexnet_sc,
            y_pred_alexnet_bn,
            y_pred_alexnet_pixels,
        )
    ]
)

In [21]:
# Peek at the first 30 image ids selected by the mask.
agreed_ids = X_test[true_images_mask]
agreed_ids[:30]

array(['182638.jpg', '182639.jpg', '182641.jpg', '182642.jpg',
       '182643.jpg', '182644.jpg', '182646.jpg', '182647.jpg',
       '182648.jpg', '182649.jpg', '182652.jpg', '182653.jpg',
       '182655.jpg', '182656.jpg', '182658.jpg', '182659.jpg',
       '182660.jpg', '182661.jpg', '182662.jpg', '182663.jpg',
       '182664.jpg', '182665.jpg', '182666.jpg', '182667.jpg',
       '182669.jpg', '182670.jpg', '182671.jpg', '182672.jpg',
       '182673.jpg', '182674.jpg'], dtype=object)

In [None]:
# Pair the image ids and labels of the test samples selected by true_images_mask.
true_images = tuple(arr[true_images_mask] for arr in (X_test, y_test))

In [35]:
# Persist the (image_ids, labels) pair for later use. The output folder is
# created first so the write does not fail on a fresh checkout; `pickle` comes
# from the notebook's top import cell.
os.makedirs('results', exist_ok=True)
with open('results/true_images.pickle', 'wb') as f:
    pickle.dump(true_images, f)

In [23]:
# with dropout=0.8. NOTE: despite the `_bn_do` suffix in the variable name, this
# is the stock torchvision AlexNet with a higher dropout rate; no batch-norm
# layers are added in this cell.
alexnet_model_bn_do = models.alexnet(dropout=0.8)
alexnet_model_bn_do.classifier[6] = nn.Linear(4096, 1)
# map_location lets a checkpoint saved on a GPU load when DEVICE fell back to the CPU.
alexnet_model_bn_do.load_state_dict(torch.load('models/best_alexnet_do.pth', map_location=DEVICE))
# Assignment (instead of a bare call) keeps the module repr out of the cell output.
alexnet_model_bn_do = alexnet_model_bn_do.to(DEVICE).eval()

In [24]:
# Score the dropout variant on the test loader; y_true is refreshed from this pass.
dropout_targets_and_preds = get_preds(alexnet_model_bn_do)
y_true, y_pred_alexnet_bn_do = dropout_targets_and_preds

100%|█████████████████████████████████████████| 312/312 [00:06<00:00, 45.56it/s]


In [25]:
# Rebuild the agreement mask with the dropout variant included: True where all
# eight AlexNet variants predicted the ground-truth label.
true_images_mask = np.logical_and.reduce(
    [
        variant_preds == y_true
        for variant_preds in (
            y_pred_alexnet,
            y_pred_alexnet_pretrain,
            y_pred_alexnet_1x1,
            y_pred_alexnet_1x11,
            y_pred_alexnet_sc,
            y_pred_alexnet_bn,
            y_pred_alexnet_pixels,
            y_pred_alexnet_bn_do,
        )
    ]
)

In [26]:
# Peek at the first 30 image ids selected by the updated mask.
agreed_ids_with_dropout = X_test[true_images_mask]
agreed_ids_with_dropout[:30]

array(['182638.jpg', '182639.jpg', '182641.jpg', '182642.jpg',
       '182643.jpg', '182644.jpg', '182646.jpg', '182647.jpg',
       '182648.jpg', '182649.jpg', '182652.jpg', '182653.jpg',
       '182655.jpg', '182656.jpg', '182658.jpg', '182659.jpg',
       '182660.jpg', '182661.jpg', '182662.jpg', '182663.jpg',
       '182664.jpg', '182665.jpg', '182666.jpg', '182667.jpg',
       '182669.jpg', '182670.jpg', '182671.jpg', '182672.jpg',
       '182673.jpg', '182674.jpg'], dtype=object)

# ResNet

In [10]:
# vanilla: torchvision ResNet-18 with the final fc layer swapped for a single-logit head
resnet_18_model = models.resnet18()
resnet_18_model.fc = nn.Linear(512, 1)
# map_location lets a checkpoint saved on a GPU load when DEVICE fell back to the CPU.
resnet_18_model.load_state_dict(torch.load('models/best_resnet18_2.pth', map_location=DEVICE))
# Assignment (instead of a bare call) keeps the module repr out of the cell output.
resnet_18_model = resnet_18_model.to(DEVICE).eval()

In [11]:
# pretrain: same architecture as the vanilla ResNet-18; weights come from the
# 'best_resnet18_pretrain.pth' checkpoint (presumably fine-tuned from pretrained
# weights — the training run is not visible in this notebook).
resnet_18_pretrain_model = models.resnet18()
resnet_18_pretrain_model.fc = nn.Linear(512, 1)
# map_location lets a checkpoint saved on a GPU load when DEVICE fell back to the CPU.
resnet_18_pretrain_model.load_state_dict(
    torch.load('models/best_resnet18_pretrain.pth', map_location=DEVICE)
)
# Assignment (instead of a bare call) keeps the module repr out of the cell output.
resnet_18_pretrain_model = resnet_18_pretrain_model.to(DEVICE).eval()

In [13]:
# silu: project-local resnet18 from src.resnet_modifications (presumably the
# SiLU-activation variant, going by the checkpoint name — confirm in that module)
resnet_18_silu_model = resnet18()
resnet_18_silu_model.fc = nn.Linear(512, 1)
# map_location lets a checkpoint saved on a GPU load when DEVICE fell back to the CPU.
resnet_18_silu_model.load_state_dict(torch.load('models/best_resnet_silu.pth', map_location=DEVICE))
# Assignment (instead of a bare call) keeps the module repr out of the cell output.
resnet_18_silu_model = resnet_18_silu_model.to(DEVICE).eval()

In [14]:
# resnet101: torchvision ResNet-101 with the final fc layer (2048 inputs) swapped
# for a single-logit head
resnet_101_model = models.resnet101()
resnet_101_model.fc = nn.Linear(2048, 1)
# map_location lets a checkpoint saved on a GPU load when DEVICE fell back to the CPU.
resnet_101_model.load_state_dict(torch.load('models/best_resnet_101.pth', map_location=DEVICE))
# Assignment (instead of a bare call) keeps the module repr out of the cell output.
resnet_101_model = resnet_101_model.to(DEVICE).eval()

In [15]:
# Score every ResNet variant on the test loader, one full pass per model.
# Ground truth is kept from the first pass only.
y_true, y_pred_resnet_18 = get_preds(resnet_18_model)
y_pred_resnet_18_pretrain, y_pred_resnet_silu, y_pred_resnet_101 = (
    get_preds(variant)[1]
    for variant in (resnet_18_pretrain_model, resnet_18_silu_model, resnet_101_model)
)

100%|█████████████████████████████████████████████████████████████████████████████████| 312/312 [00:13<00:00, 23.82it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 312/312 [00:12<00:00, 24.20it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 312/312 [00:12<00:00, 24.21it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 312/312 [00:56<00:00,  5.56it/s]


In [16]:
# Boolean mask over the test set: True where all four ResNet variants
# predicted the ground-truth label.
true_images_mask = np.logical_and.reduce(
    [
        variant_preds == y_true
        for variant_preds in (
            y_pred_resnet_18,
            y_pred_resnet_18_pretrain,
            y_pred_resnet_silu,
            y_pred_resnet_101,
        )
    ]
)

In [17]:
# Image ids and labels of the test samples selected by true_images_mask, kept as a pair.
correct_ids = X_test[true_images_mask]
correct_labels = y_test[true_images_mask]
true_images = (correct_ids, correct_labels)

In [18]:
# Persist the ResNet (image_ids, labels) pair. The output folder is created
# first so the write does not fail on a fresh checkout; `pickle` comes from the
# notebook's top import cell.
os.makedirs('results', exist_ok=True)
with open('results/true_images_resnet.pickle', 'wb') as f:
    pickle.dump(true_images, f)