In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data
from torchvision import models,transforms
import matplotlib.pyplot as plt
import pickle
from collections import OrderedDict
import csv
import collections
from  PIL import Image
from tqdm.notebook import tqdm_notebook
from scipy.spatial import distance
import warnings
warnings.filterwarnings('ignore')
import math
# Prefer Apple's Metal (MPS) backend when it can actually be used. `torch.has_mps` is
# deprecated and only reports build-time support; `torch.backends.mps.is_available()`
# also checks that the device is usable at runtime.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
from itertools import product
import senet50  # provides make_model(), used to build the network below
# Seed for the negative-pair subsampling further down.
random_state = 1

In [None]:
# Build SENet-50 and copy in the weights stored in the pickled `senet50_ft` checkpoint.
model_scratch = senet50.make_model()
fname = 'weights/senet50_ft_weight.pkl'
# NOTE(review): pickle.load can execute arbitrary code — only load checkpoints from a trusted source.
with open(fname, 'rb') as f:
    # encoding='latin1' is needed to read numpy arrays pickled under Python 2 (presumably how this file was made).
    weights = pickle.load(f, encoding='latin1')

own_state = model_scratch.state_dict()
for name, param in weights.items():
    if name not in own_state:
        raise KeyError('unexpected key "{}" in state_dict'.format(name))
    try:
        # state_dict() tensors share storage with the module, so copy_ updates the model in place.
        own_state[name].copy_(torch.from_numpy(param))
    except Exception as err:
        raise RuntimeError('While copying the parameter named {}, whose dimensions in the model are {} and whose '\
                            'dimensions in the checkpoint are {}.'.format(name, own_state[name].size(), param.shape)) from err

# Entries the checkpoint did not provide keep their random initialisation, so surface them.
# BatchNorm `num_batches_tracked` counters are bookkeeping rather than learned weights and are skipped.
missing_keys = [k for k in own_state if k not in weights and not k.endswith('num_batches_tracked')]
if missing_keys:
    print('{} model entries not found in {}: {}'.format(len(missing_keys), fname, missing_keys))
model_scratch = model_scratch.to(device)

In [None]:
### load data
# Create DataFrames listing every identity's image files together with ethnicity and gender.
# For each identity the first (sorted) image is the reference; the next `n` images are candidates.
n = 24
img_path = 'data/BFW/CROPPED_ALIGNED'


def _visible_entries(names):
    """Return `names` sorted, skipping hidden entries such as macOS `.DS_Store` files."""
    return sorted(name for name in names if not name.startswith('.'))


def load_bfw_subgroup(root, ethnicity, gender, n_candidates):
    """
    Load one BFW subgroup laid out as ``<root>/<ethnicity>_<gender>s/<identity>/<image>``.

    :param root: BFW image root directory.
    :param ethnicity: ethnicity label and directory prefix, e.g. 'asian'.
    :param gender: 'female' or 'male' (the directory name uses the plural form).
    :param n_candidates: maximum number of candidate images kept per identity.
    :return: ``(candidates, references)`` DataFrames with columns
        ``file``, ``ethnicity``, ``gender``, ``identity``.
    """
    group_dir = os.path.join(root, '{}_{}s'.format(ethnicity, gender))
    # os.walk lists entries in filesystem-dependent order; sorting keeps the identity order,
    # and therefore every downstream pair index and random subsample, reproducible.
    identities = _visible_entries(next(os.walk(group_dir))[1])
    candidate_files, candidate_ids, reference_files = [], [], []
    for identity in identities:
        identity_dir = os.path.join(group_dir, identity)
        files = [os.path.join(identity_dir, fname)
                 for fname in _visible_entries(next(os.walk(identity_dir))[2])]
        reference_files.append(files[0])
        selected = files[1:n_candidates + 1]
        candidate_files.extend(selected)
        # Label by the actual number of candidates, so identities with fewer than
        # n_candidates + 1 images are still labelled correctly.
        candidate_ids.extend([identity] * len(selected))
    candidates = pd.DataFrame({'file': candidate_files, 'ethnicity': ethnicity,
                               'gender': gender, 'identity': candidate_ids})
    references = pd.DataFrame({'file': reference_files, 'ethnicity': ethnicity,
                               'gender': gender, 'identity': identities})
    return candidates, references


asian_female_candidates, asian_female_references = load_bfw_subgroup(img_path, 'asian', 'female', n)
black_female_candidates, black_female_references = load_bfw_subgroup(img_path, 'black', 'female', n)
indian_female_candidates, indian_female_references = load_bfw_subgroup(img_path, 'indian', 'female', n)
white_female_candidates, white_female_references = load_bfw_subgroup(img_path, 'white', 'female', n)
asian_male_candidates, asian_male_references = load_bfw_subgroup(img_path, 'asian', 'male', n)
black_male_candidates, black_male_references = load_bfw_subgroup(img_path, 'black', 'male', n)
indian_male_candidates, indian_male_references = load_bfw_subgroup(img_path, 'indian', 'male', n)
white_male_candidates, white_male_references = load_bfw_subgroup(img_path, 'white', 'male', n)

# all ethnicity
asian_references = pd.concat([asian_male_references,asian_female_references],ignore_index=True)
black_references = pd.concat([black_male_references,black_female_references],ignore_index=True)
indian_references = pd.concat([indian_male_references,indian_female_references],ignore_index=True)
white_references = pd.concat([white_male_references,white_female_references],ignore_index=True)
female_references = pd.concat([asian_female_references,black_female_references,indian_female_references,white_female_references],ignore_index=True)
male_references = pd.concat([asian_male_references,black_male_references,indian_male_references,white_male_references],ignore_index=True)
references = pd.concat([male_references,female_references],ignore_index=True)

asian_candidates = pd.concat([asian_male_candidates,asian_female_candidates],ignore_index=True)
black_candidates = pd.concat([black_male_candidates,black_female_candidates],ignore_index=True)
indian_candidates = pd.concat([indian_male_candidates,indian_female_candidates],ignore_index=True)
white_candidates = pd.concat([white_male_candidates,white_female_candidates],ignore_index=True)
female_candidates = pd.concat([asian_female_candidates,black_female_candidates,indian_female_candidates,white_female_candidates],ignore_index=True)
male_candidates = pd.concat([asian_male_candidates,black_male_candidates,indian_male_candidates,white_male_candidates],ignore_index=True)
candidates = pd.concat([male_candidates,female_candidates],ignore_index=True)

In [None]:
class BFW_dataset(data.Dataset):
    '''
    Map-style dataset over BFW face images described by a DataFrame.

    Each item is ``(image_tensor, identity, ethnicity, gender)``; ``image_tensor`` is a
    float32 ``3 x 224 x 224`` tensor in BGR channel order with ``mean_bgr`` subtracted.
    '''
    # Per-channel BGR means. The preprocessing mean must match the weights that are loaded.
    MEAN_BGR_FT = np.array([91.4953, 103.8827, 131.0912])       # from senet50_ft.prototxt
    MEAN_BGR_SCRATCH = np.array([93.5940, 104.7624, 129.1863])  # from senet50_scratch.prototxt
    # The notebook loads `weights/senet50_ft_weight.pkl`, so use the fine-tuned model's mean.
    mean_bgr = MEAN_BGR_FT

    def __init__(self, img_df, root=None):
        """
        :param img_df: DataFrame with columns ``file`` (image path), ``identity``,
            ``ethnicity`` and ``gender``.
        :param root: dataset directory that must exist; defaults to the notebook-level ``img_path``.
        """
        if root is None:
            root = img_path  # notebook-level dataset directory, kept for backward compatibility
        assert os.path.exists(root), "root: {} not found.".format(root)
        self.img_df = img_df
        self.img_info = []

        # Count positionally so progress messages don't depend on the frame's index labels.
        for i, row in enumerate(img_df.itertuples(index=False)):
            self.img_info.append({
                'img_file': row.file,
                'identity': row.identity,
                'ethnicity': row.ethnicity,
                'gender': row.gender,
            })
            if i % 1000 == 0:
                print("processing: {} images".format(i))

    def __len__(self):
        return len(self.img_info)

    def __getitem__(self, index):
        """Load, resize (256), center-crop (224) and normalise one image; return it with its metadata."""
        info = self.img_info[index]
        # convert('RGB') guarantees exactly three channels, so grayscale, palette or RGBA files
        # do not trip the channel assertion; the with-block closes the underlying file handle.
        with Image.open(info['img_file']) as pil_img:
            img = pil_img.convert('RGB')
        img = transforms.Resize(256)(img)
        img = transforms.CenterCrop(224)(img)
        img = np.array(img, dtype=np.uint8)
        assert img.ndim == 3 and img.shape[2] == 3  # H x W x RGB
        return self.transform(img), info['identity'], info['ethnicity'], info['gender']

    def transform(self, img):
        """uint8 H x W x RGB array -> float32 C x H x W BGR tensor with ``mean_bgr`` subtracted."""
        img = img[:, :, ::-1]  # RGB -> BGR
        img = img.astype(np.float32)
        img -= self.mean_bgr
        img = img.transpose(2, 0, 1)  # C x H x W
        img = torch.from_numpy(img).float()
        return img

    def untransform(self, img, lbl):
        """Invert :meth:`transform` for display; returns ``(uint8 H x W x RGB array, lbl)``."""
        img = img.numpy()
        img = img.transpose(1, 2, 0)  # H x W x C
        # Restore the subtracted mean before casting; otherwise negative values wrap around in uint8.
        img = img + self.mean_bgr
        img = np.clip(np.rint(img), 0, 255).astype(np.uint8)
        img = img[:, :, ::-1]  # BGR -> RGB
        return img, lbl
        
def apply_model(model, dataloader, device):
    """
    Run ``model`` over every batch of ``dataloader`` and collect flattened outputs with metadata.

    :param model: network applied to each image batch; it is switched to eval mode here.
    :param dataloader: sized iterable of ``(images, identities, ethnicities, genders)`` batches
        (``len()`` is only used for the progress bar).
    :param device: device the image batches are moved to before the forward pass.
    :return: ``(outputs, identities, ethnicities, genders)``. ``outputs`` is a
        ``(num_images, features)`` tensor left on ``device``; the others are 1-D numpy arrays
        aligned with its rows.
    """
    model.eval()
    outputs, identities, ethnicities, genders = [], [], [], []
    with torch.no_grad():
        for imgs, identity_ids, ethnicity, gender in tqdm_notebook(dataloader, total=len(dataloader)):
            x = model(imgs.to(device))
            # reshape (unlike view) also copes with non-contiguous model outputs
            outputs.append(x.reshape(x.size(0), -1))
            identities.append(np.asarray(identity_ids))
            ethnicities.append(np.asarray(ethnicity))
            genders.append(np.asarray(gender))

    # Concatenate the per-batch arrays directly: stacking them with np.array first raises
    # (NumPy >= 1.24) or builds an object array when the final batch is smaller.
    return (
        torch.cat(outputs),
        np.concatenate(identities),
        np.concatenate(ethnicities),
        np.concatenate(genders),
    )

In [None]:
# load reference images
# Embed every reference and candidate image. shuffle=False keeps the embedding rows in the
# same order as the rows of `references` / `candidates`.
reference_dataset = BFW_dataset(references.reset_index(drop=True))
reference_loader = torch.utils.data.DataLoader(
    reference_dataset,
    batch_size=4,
    shuffle=False,
)
candidate_dataset = BFW_dataset(candidates.reset_index(drop=True))
candidate_loader = torch.utils.data.DataLoader(
    candidate_dataset,
    batch_size=4,
    shuffle=False,
)

(reference_outputs, reference_identities,
 reference_ethnicities, reference_genders) = apply_model(model_scratch, reference_loader, device)
(candidate_outputs, candidate_identities,
 candidate_ethnicities, candidate_genders) = apply_model(model_scratch, candidate_loader, device)


In [None]:
# One row per embedded image: its embedding (as a numpy row) plus identity/ethnicity/gender.
reference_outputs_list = list(reference_outputs.cpu().numpy())
candidate_outputs_list = list(candidate_outputs.cpu().numpy())

output_references = pd.DataFrame({
    'outputs': reference_outputs_list,
    'identity': reference_identities,
    'ethnicity': reference_ethnicities,
    'gender': reference_genders,
})
output_candidates = pd.DataFrame({
    'outputs': candidate_outputs_list,
    'identity': candidate_identities,
    'ethnicity': candidate_ethnicities,
    'gender': candidate_genders,
})
output_references

In [None]:
# Every (reference, candidate) pair in itertools.product order: row k pairs reference
# k // n_cand with candidate k % n_cand. Later cells rely on this RangeIndex ordering.
# np.repeat / np.tile build the same columns without materialising millions of Python tuples.
n_ref, n_cand = len(output_references), len(output_candidates)
logistic_df = pd.DataFrame({
    'reference_identity': np.repeat(output_references['identity'].to_numpy(), n_cand),
    'candidate_identity': np.tile(output_candidates['identity'].to_numpy(), n_ref),
    'reference_ethnicity': np.repeat(output_references['ethnicity'].to_numpy(), n_cand),
    'candidate_ethnicity': np.tile(output_candidates['ethnicity'].to_numpy(), n_ref),
    'reference_gender': np.repeat(output_references['gender'].to_numpy(), n_cand),
    'candidate_gender': np.tile(output_candidates['gender'].to_numpy(), n_ref),
})
# 1 when reference and candidate are the same identity, else 0.
logistic_df['labels'] = (logistic_df.reference_identity == logistic_df.candidate_identity).astype('int64')

logistic_df

In [None]:
# Keep only pairs whose reference and candidate share both ethnicity and gender;
# the original index labels (positions in the full pair table) are preserved.
logistic_df2 = logistic_df.loc[
    logistic_df['reference_ethnicity'].eq(logistic_df['candidate_ethnicity'])
    & logistic_df['reference_gender'].eq(logistic_df['candidate_gender'])
]
logistic_df2


In [None]:
# Class balance of the same-group pairs: 1 = same identity, 0 = different identity.
labels = logistic_df2['labels']
labels.value_counts()

In [None]:
# Balance the classes: keep every matching pair and draw, without replacement, the same
# number of non-matching pairs. All indices are positions within logistic_df2.
match_idx = np.flatnonzero(labels.to_numpy() == 1)
not_match_idx = np.flatnonzero(labels.to_numpy() == 0)
np.random.seed(random_state)
not_match_idx_sub = np.take(
    not_match_idx,
    np.random.choice(len(not_match_idx), size=len(match_idx), replace=False),
)
print(not_match_idx_sub)

In [None]:
# logistic_df enumerates pairs in itertools.product order with a RangeIndex, so the pair at
# index label k uses reference row k // num_candidates and candidate row k % num_candidates.
# Computing that directly avoids materialising a meshgrid over all reference x candidate pairs.
num_candidates = candidate_outputs.shape[0]
assert len(logistic_df) == reference_outputs.shape[0] * num_candidates
reference_rows, candidate_rows = np.divmod(logistic_df2.index.to_numpy(), num_candidates)
combinations = np.column_stack((reference_rows, candidate_rows))  # (pairs, 2): [ref_row, cand_row]
combinations.shape

In [None]:
# (reference_row, candidate_row) index pairs for the positives and the sampled negatives.
match_pairs = combinations[match_idx]
not_match_pairs = combinations[not_match_idx_sub]


def _pair_embeddings(pairs):
    """Rows of [reference embedding | candidate embedding] for each (ref_row, cand_row) pair."""
    # Index tensors live on the embeddings' device so the gather works on MPS/CPU alike.
    ref_rows = torch.as_tensor(pairs[:, 0], device=reference_outputs.device)
    cand_rows = torch.as_tensor(pairs[:, 1], device=candidate_outputs.device)
    return torch.cat((reference_outputs[ref_rows], candidate_outputs[cand_rows]), dim=1)


match_tensor = _pair_embeddings(match_pairs)
not_match_tensor = _pair_embeddings(not_match_pairs)

# Metadata for each pair, as plain Python lists aligned with the rows above.
match_ref_ids = reference_identities[match_pairs[:, 0]].tolist()
match_ref_eth = reference_ethnicities[match_pairs[:, 0]].tolist()
match_ref_gend = reference_genders[match_pairs[:, 0]].tolist()

not_match_ref_ids = reference_identities[not_match_pairs[:, 0]].tolist()
not_match_ref_eth = reference_ethnicities[not_match_pairs[:, 0]].tolist()
not_match_ref_gend = reference_genders[not_match_pairs[:, 0]].tolist()
not_match_cand_ids = candidate_identities[not_match_pairs[:, 1]].tolist()
not_match_cand_eth = candidate_ethnicities[not_match_pairs[:, 1]].tolist()
not_match_cand_gend = candidate_genders[not_match_pairs[:, 1]].tolist()

In [None]:
# Persist the pair embeddings (positives first, then negatives) and their labels.
os.makedirs('inputs', exist_ok=True)
all_inputs = torch.cat([match_tensor, not_match_tensor])
# torch.save records each tensor's device; save a CPU copy so the file loads on machines
# without MPS without needing map_location.
torch.save(all_inputs.cpu(), 'inputs/bfw_senet50_face_embeddings.pt')
match_labels = torch.ones(len(match_pairs))
# Size the negative labels from the negative pairs themselves.
not_match_labels = torch.zeros(len(not_match_pairs))
all_labels = torch.cat([match_labels, not_match_labels])
assert len(all_labels) == len(all_inputs)
torch.save(all_labels, 'inputs/bfw_senet50_labels.pt')

In [None]:
# Metadata for each row of all_inputs / all_labels: matched pairs first, then sampled
# non-matched pairs. A matched candidate shares the reference's identity, and every pair was
# filtered to the same ethnicity and gender, so the reference metadata doubles as the
# candidate metadata for matched pairs.
all_ref_ids, all_cand_ids = match_ref_ids + not_match_ref_ids, match_ref_ids + not_match_cand_ids
all_ref_eth, all_cand_eth = match_ref_eth + not_match_ref_eth, match_ref_eth + not_match_cand_eth
all_ref_gend, all_cand_gend = match_ref_gend + not_match_ref_gend, match_ref_gend + not_match_cand_gend

all_df = pd.DataFrame(dict(
    reference_identity=all_ref_ids,
    candidate_identity=all_cand_ids,
    reference_ethnicity=all_ref_eth,
    candidate_ethnicity=all_cand_eth,
    reference_gender=all_ref_gend,
    candidate_gender=all_cand_gend,
    labels=all_labels.cpu().numpy(),
))
all_df

In [None]:
# Pair metadata; its rows line up with the saved embeddings and labels.
all_df.to_csv(path_or_buf='inputs/bfw_senet50_df.csv', index=False)