In [5]:
import os

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from scipy.ndimage import rotate
import Utils
import cv2
from facenet_pytorch import MTCNN, InceptionResnetV1

In [63]:
data_root = '../../data/'


def filtered_labels(path, columns=('skin_tone', 'gender', 'age'), root=None):
    """Load a label CSV and encode its demographic columns as integers.

    Parameters
    ----------
    path : str
        CSV file name, resolved relative to ``root``.
    columns : str, sequence of str, or None
        Columns in which a missing value causes the whole row to be dropped.
        ``None`` or an empty sequence keeps every row.
    root : str or None
        Directory containing ``path``; defaults to the module-level ``data_root``.

    Returns
    -------
    pandas.DataFrame
        The (possibly row-filtered) frame with:
        - ``skin_tone``: ``'monk_K'`` -> ``K - 1`` (0-based),
        - ``gender``: 1 for ``'male'``, 0 for any other non-missing value,
        - ``age``: ``'0_17'``/``'18_30'``/``'31_60'``/``'61_100'`` -> 0/1/2/3.
        Missing values in columns that were not used for dropping stay missing.

    Raises
    ------
    KeyError
        If ``age`` contains a bucket that is not in the mapping above.
    """
    if root is None:
        root = data_root
    df = pd.read_csv(os.path.join(root, path))

    if columns is not None and len(columns) > 0:
        # A bare string is one column label, not a sequence of characters.
        subset = [columns] if isinstance(columns, str) else list(columns)
        df = df.dropna(how='any', subset=subset)

    age_map = {
        '0_17': 0,
        '18_30': 1,
        '31_60': 2,
        '61_100': 3,
    }
    # na_action='ignore' passes NaN through untouched. Without it, a NaN that
    # survived the dropna would crash skin_tone (float has no .replace), raise
    # KeyError for age, and be silently encoded as gender 0.
    df['skin_tone'] = df['skin_tone'].map(
        lambda tone: int(tone.replace('monk_', '')) - 1, na_action='ignore')
    df['gender'] = df['gender'].map(
        lambda gender: int(gender == 'male'), na_action='ignore')
    df['age'] = df['age'].map(
        lambda bucket: age_map[bucket], na_action='ignore')

    return df

# Encode both label files; rows missing any of skin_tone/gender/age are dropped
# by filtered_labels' default `columns`.
train_labels = filtered_labels(path='train_labels.csv')
test_labels = filtered_labels(path='labels.csv')

In [65]:
def detect_nonfaces(df, root, image_size=256):
    """Flag which images listed in ``df`` contain a face that MTCNN detects.

    Sets a boolean ``is_face`` column on ``df`` in place and returns ``df``.
    If the detector itself raises for an image, the error is printed and the
    image is kept (``is_face=True``) so a detector failure never drops data.
    Errors opening the image file are not caught and propagate to the caller.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have a ``name`` column of image file names relative to ``root``.
    root : str
        Directory containing the images.
    image_size : int
        Forwarded to ``MTCNN(image_size=...)``.

    Returns
    -------
    pandas.DataFrame
        The same ``df`` object, with ``is_face`` set.
    """
    mtcnn = MTCNN(image_size=image_size)

    def is_face(file):
        # `with` closes the file handle even when detection raises before PIL
        # has finished (and auto-closed) the lazy load.
        with Image.open(os.path.join(root, file)) as img:
            try:
                # Drop alpha / expand grayscale to 3 channels. Without this,
                # non-RGB PNGs make the detector raise and fall into the
                # "keep as face" branch below without ever being checked.
                return mtcnn(img.convert('RGB')) is not None
            except Exception as e:
                print(e, file)
                return True

    df['is_face'] = df['name'].apply(is_face)
    return df

# Tag every row with an `is_face` flag; no rows are removed here.
train_labels = detect_nonfaces(df=train_labels, root=data_root)
test_labels = detect_nonfaces(df=test_labels, root=data_root)
# Preview of face rows only -- this displays a subset but does NOT filter train_labels.
train_labels.loc[train_labels['is_face']]

Unnamed: 0,name,skin_tone,gender,age,is_face
2,TRAIN0002.png,5,1,0,True
6,TRAIN0006.png,1,0,1,True
7,TRAIN0007.png,1,0,1,True
11,TRAIN0011.png,3,0,1,True
12,TRAIN0012.png,3,1,2,True
...,...,...,...,...,...
12271,TRAIN9988.png,5,1,3,True
12275,TRAIN9992.png,4,0,2,True
12276,TRAIN9993.png,1,1,1,True
12278,TRAIN9995.png,8,0,1,True


In [66]:
# Hold out a random 20% of the labelled training rows for validation.
# A fixed seed makes the split -- and the CSVs written later -- reproducible
# across kernel restarts. Rows with is_face == False stay in both splits.
# NOTE: this cell rebinds `train_labels`; re-running it without re-running the
# cells above removes another 20% from the training set.
VALIDATION_FRAC = 0.2
SPLIT_SEED = 0

# drop(index) below removes every row sharing a sampled label, so labels must be unique.
assert train_labels.index.is_unique
validation_labels = train_labels.sample(frac=VALIDATION_FRAC, replace=False, random_state=SPLIT_SEED)
train_labels = train_labels.drop(validation_labels.index)
train_labels

Unnamed: 0,name,skin_tone,gender,age,is_face
1,TRAIN0001.png,0,0,1,False
2,TRAIN0002.png,5,1,0,True
5,TRAIN0005.png,1,1,0,False
7,TRAIN0007.png,1,0,1,True
9,TRAIN0009.png,7,0,1,False
...,...,...,...,...,...
12275,TRAIN9992.png,4,0,2,True
12276,TRAIN9993.png,1,1,1,True
12278,TRAIN9995.png,8,0,1,True
12281,TRAIN9998.png,4,1,1,False


In [69]:
# Persist the three splits (including the is_face flag) without the pandas index.
csv_opts = {'index': False}
train_labels.to_csv('train_data_clean.csv', **csv_opts)
validation_labels.to_csv('validation_data_clean.csv', **csv_opts)
test_labels.to_csv('test_data_clean.csv', **csv_opts)