In [None]:
# Imports and setup
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from matplotlib.image import imread
import matplotlib.pyplot as plt
from collections import defaultdict
import os
from tqdm import tqdm
import random
import math
import pickle
from sklearn.model_selection import StratifiedKFold

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

%matplotlib inline

In [None]:
# Setup general variables, etc. that are needed

# Load the list of image files from the dataset
# NOTE(review): os.listdir returns entries in arbitrary order, and later cells index a cached
# image array positionally against this list -- confirm both were produced with the same ordering
BASE_PATH = '../input/utk-face-cropped/utkcropped/'
IMAGE_FILES = os.listdir(BASE_PATH)

# Used to extract the meaning of the numeric representation encoded in each image name
# (parse_name splits a name on '_' into age, gender code, race code, and a trailing field)
GENDER_KEYS = {'0': 'male', '1': 'female'}
RACE_KEYS = {'0': 'white', '1': 'black', '2': 'asian', '3': 'indian', '4': 'other'}

In [None]:
# Functions for parsing the image labels

def is_valid_name(name):
    '''
    Returns True when the name is made of exactly four underscore-separated fields
    (i.e. it contains exactly three underscores), False otherwise
    '''
    return name.count('_') == 3

def parse_name(name):
    '''
    Splits a file name into its four underscore-separated fields and returns
    (age as an int, gender label, race label); the fourth field is discarded

    Raises ValueError if the name does not split into exactly four fields or the age
    field is not an integer, and KeyError if the gender or race code is not a key of
    GENDER_KEYS / RACE_KEYS
    '''
    age_field, gender_code, race_code, _unused = name.split('_')

    return int(age_field), GENDER_KEYS[gender_code], RACE_KEYS[race_code]

In [None]:
def create_label_dataframe(image_files=IMAGE_FILES):
    '''
    Returns a pandas dataframe with information about the age, gender, and race, for each image

    image_files: iterable of image file names; names rejected by is_valid_name are skipped
    Returns a DataFrame with the columns 'age', 'gender', 'race' and one row per valid name
    (an empty frame with those columns if no name is valid)
    '''
    # Collect every row first and build the frame once. The previous DataFrame.append call
    # copied the whole frame on each iteration (quadratic time) and no longer exists in pandas 2.0
    rows = [
        parse_name(image_name)
        for image_name in tqdm(image_files)
        if is_valid_name(image_name)
    ]

    return pd.DataFrame(rows, columns=['age', 'gender', 'race'])

In [None]:
def print_age_stats(data):
    '''
    Prints summary statistics of the 'age' column of data: minimum, maximum, mean,
    standard deviation (as computed by np.std), and the count of each distinct age
    '''
    ages = data["age"]

    print(f'min age: {min(ages)}')
    print(f'max age: {max(ages)}')
    print(f'mean age: {np.mean(ages)}')
    print(f'std. dev. age: {np.std(ages)}')
    print(f'Most common ages: {ages.value_counts()}')

def plot_basic_label_distributions(data):
    '''
    Shows three figures one after another: a 100-bin histogram of the 'age' column,
    then bar charts of the value counts of the 'gender' and 'race' columns
    '''
    age_axes = data['age'].plot.hist(bins=100,  title='Age Distribution')
    age_axes.set_xlabel('Age (years)')
    plt.show()

    bar_charts = (
        ('gender', 'Gender Distribution', 'Gender'),
        ('race', 'Race Distribution', 'Race'),
    )
    for column, chart_title, axis_label in bar_charts:
        data[column].value_counts().plot.bar(title=chart_title, xlabel=axis_label)
        plt.show()

In [None]:
def plot_adv_label_distributions(data):
    '''
    Shows the age histogram (100 bins) of each gender and of the white, black, indian and
    asian race groups, followed by a bar chart of the race counts within each gender

    Note: no age histogram is drawn for the 'other' race label
    '''
    # (column to filter on, value to keep, figure title), in display order
    age_histograms = (
        ('gender', 'male', 'Male Age Distribution'),
        ('gender', 'female', 'Female Age Distribution'),
        ('race', 'white', 'White Age Distribution'),
        ('race', 'black', 'Black Age Distribution'),
        ('race', 'indian', 'Indian Age Distribution'),
        ('race', 'asian', 'Asian Age Distribution'),
    )
    for column, value, chart_title in age_histograms:
        group_ages = data[data[column] == value]['age']
        ax = group_ages.plot.hist(bins=100, title=chart_title)
        ax.set_xlabel('Age (years)')
        plt.show()

    race_bar_charts = (
        ('male', 'Male Race Distribution'),
        ('female', 'Female Race Distribution'),
    )
    for gender, chart_title in race_bar_charts:
        group_races = data[data['gender'] == gender]['race']
        group_races.value_counts().plot.bar(title=chart_title, xlabel='Race')
        plt.show()

In [None]:
# Build the label table from the default file list; the next cell reads it as `data`
data = create_label_dataframe()

In [None]:
# Mean age of each race group, printed in the order white, black, asian, indian, other
for race_label in ('white', 'black', 'asian', 'indian', 'other'):
    print(np.mean(data[data['race'] == race_label]['age']))


In [None]:
def label_analysis():
    '''
    Builds the label dataframe from the default file list, prints the age statistics,
    and then shows the basic and the per-group label distribution plots
    '''
    labels = create_label_dataframe()

    print_age_stats(labels)
    plot_basic_label_distributions(labels)
    plot_adv_label_distributions(labels)

In [None]:
'''
Execute this code cell to print a complete report of the distribution of the labels

Note: it may take a few minutes to execute
'''

# Rebuilds the label dataframe, then prints the age statistics and shows every distribution plot
label_analysis()

In [None]:
def show_image_sample(num_images=100, image_files=IMAGE_FILES, rand_seed=0):
    '''
    Uses matplotlib to display a sample of images, along with their labels.
    Can be used to audit images and/or labels

    num_images: number of images to show (capped at the number of available files)
    image_files: list of file names, relative to BASE_PATH, to sample from
    rand_seed: seed passed to random.seed so the same sample is drawn on every run
    '''

    random.seed(rand_seed)
    # random.sample raises ValueError when asked for more items than exist, so cap the request
    sample_size = min(num_images, len(image_files))
    audit_imgs = random.sample(image_files, sample_size)

    for img_path in audit_imgs:
        # BASE_PATH is the module-level constant; the lowercase `base_path` used before was undefined
        img = imread(os.path.join(BASE_PATH, img_path))
        plt.imshow(img)
        plt.show()

        # `not` is Python's negation; the previous `!is_valid_name(...)` was a syntax error
        if not is_valid_name(img_path):
            print(f'Invalid file name: {img_path}')
        else:
            print(parse_name(img_path))

In [None]:
'''
Run this code cell to display a sample of 100 randomly selected images
'''

# The function is defined as show_image_sample (singular); the plural name raised NameError
show_image_sample()

In [None]:
def img_to_nparray(image_files=IMAGE_FILES):
    '''
    Reads every validly named image under BASE_PATH, keeps every second row and column,
    and returns all of them stacked in a single float32 numpy array divided by 255
    Save this array somewhere to avoid having to regenerate it

    Note: the 2x subsampling yields 100x100 images assuming 200x200 sources -- TODO confirm
    '''
    downsampled = [
        imread(os.path.join(BASE_PATH, file_name))[::2, ::2]
        for file_name in tqdm(image_files)
        if is_valid_name(file_name)
    ]

    return np.array(downsampled).astype('float32') / 255.0

# Build the array for the default file list
img_array = img_to_nparray()

In [None]:
# Load one extra image, keep only its first three channels, and wrap it in a list so the
# array has a leading batch axis of length 1
# NOTE(review): assumes the file is already 100x100 with values in [0, 1] like img_array -- confirm
my_face = np.array([imread('../input/face-autoencoder-10-27/colin_face3.png')[:, :, :3]])
plt.imshow(my_face[0])

In [None]:
'''
Run this cell to load the data from an existing 'images.npz' file

Note: this takes a few seconds because the file is quite large
'''

IMAGES_LOCATION = '../input/utkfacecroppednumpy/images.npz'
# Rebinds img_array, replacing any array built by img_to_nparray in an earlier cell;
# 'arr_0' is the key numpy assigns to the first unnamed array passed to np.savez
img_array = np.load(IMAGES_LOCATION)['arr_0']

In [None]:
'''
Defines components of the autoencoder model
'''

def Encoder(latent_space_size, image_dim=100):
    '''
    Builds the encoder model: a stack of stride-2 3x3 convolutions followed by a dense layer

    latent_space_size: number of units of the final (linear) Dense layer
    image_dim: height and width of the square, 3-channel input image

    The first convolution has 8 filters; ceil(log2(image_dim)) - 1 further convolutions follow,
    each with twice the filters of the previous one (for image_dim=100: 7 convolutions, 8 to 512 filters)
    '''
    def halve(tensor, filter_count):
        # One stride-2 'same' convolution block, as used for every encoder stage
        conv = layers.Conv2D(filters=filter_count, kernel_size=3, strides=(2, 2), padding='same', activation='relu')
        return conv(tensor)

    inputs = keras.Input(shape=(image_dim, image_dim, 3))
    features = halve(inputs, 8)

    remaining_stages = math.ceil(math.log2(image_dim)) - 1
    for _stage in range(remaining_stages):
        features = halve(features, features.shape[-1] * 2)

    flattened = layers.Flatten()(features)
    latent = layers.Dense(latent_space_size)(flattened)

    return keras.Model(inputs, latent, name='Encoder')
    
def Decoder(latent_space_size):
    '''
    Builds the decoder model that maps a latent vector of length latent_space_size to an image

    The vector is reshaped to a 1x1 feature map and passed through six stride-2, 3x3 transposed
    convolutions. With these settings 'same' padding doubles the spatial size and 'valid' padding
    gives 2n + 1, so the sizes go 1 -> 3 -> 6 -> 12 -> 25 -> 50 -> 100 and the output is 100x100x3
    '''
    # Note: I realize this is all hard coded, which isn't great, but b/c the size of the images is 100x100, not 128x128, there
    # is no simple way to determine which layers need to be 'same' vs 'valid' to reach the target size

    # The trailing comma makes the shape a 1-tuple; `(latent_space_size)` is just an int, which
    # not every Keras version accepts as a shape
    decode_input = keras.Input(shape=(latent_space_size,))
    decode_reshape = layers.Reshape((1, 1, latent_space_size))(decode_input)
    decodeA = layers.Conv2DTranspose(filters=256, kernel_size=3, strides=(2, 2), padding='valid', activation='relu')(decode_reshape)
    decodeB = layers.Conv2DTranspose(filters=128, kernel_size=3, strides=(2, 2), padding='same', activation='relu')(decodeA)
    decodeC = layers.Conv2DTranspose(filters=64, kernel_size=3, strides=(2, 2), padding='same', activation='relu')(decodeB)
    decodeD = layers.Conv2DTranspose(filters=32, kernel_size=3, strides=(2, 2), padding='valid', activation='relu')(decodeC)
    decodeE = layers.Conv2DTranspose(filters=8, kernel_size=3, strides=(2, 2), padding='same', activation='relu')(decodeD)
    # NOTE(review): the output activation is relu (unbounded above) while the notebook trains with
    # binary_crossentropy on [0, 1] pixels; sigmoid is the usual pairing. Left unchanged so results
    # stay comparable with the models already saved -- confirm before changing
    decodeF = layers.Conv2DTranspose(filters=3, kernel_size=3, strides=(2, 2), padding='same', activation='relu')(decodeE)

    decoder = keras.Model(decode_input, decodeF, name='Decoder')

    return decoder
    
def Autoencoder(encoder, decoder, image_dim=100):
    '''
    Chains the given encoder and decoder into a single model named 'Autoencoder'
    that takes an image_dim x image_dim x 3 input and returns the decoder's output
    '''
    images = keras.Input(shape=(image_dim, image_dim, 3)) # In theory, could pull this size from the encoder
    reconstruction = decoder(encoder(images))

    return keras.Model(images, reconstruction, name='Autoencoder')


In [None]:
'''
Execute this cell to prepare for training the model
'''

LATENT_SPACE_SIZE = 128

# Note: this doesn't seem to be aggressive enough in reducing the learning rate. I think a lower start might be worthwhile
LEARNING_RATE = 0.0003

# The string below is a disabled learning-rate schedule kept for reference; it is not executed
'''
tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.0005,
    decay_steps=6000,
    decay_rate=0.95)
'''

# NOTE(review): images 16000-19999 fall in neither split here, while later cells train on
# the first 20000 -- confirm the gap is intentional
X_train = img_array[:16000]
X_test = img_array[20000:]

# Build fresh (untrained) models that share LATENT_SPACE_SIZE
encoder = Encoder(LATENT_SPACE_SIZE)
decoder = Decoder(LATENT_SPACE_SIZE)
autoencoder = Autoencoder(encoder, decoder)

encoder.summary()
decoder.summary()
autoencoder.summary()

# Only the combined model is compiled; it is the one that gets fitted
optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE)
autoencoder.compile(optimizer=optimizer, loss='binary_crossentropy')

In [None]:
'''
Execute this cell to actually fit the model
'''

EPOCHS = 100 * 1
BATCH_SIZE = 32

# Input and target are the same array: the model is trained to reproduce its input.
# `history` is read by the next cell to report the validation loss
history = autoencoder.fit(X_train, X_train, epochs=EPOCHS, batch_size=BATCH_SIZE, shuffle=True, validation_data=(X_test, X_test))

In [None]:
# Mean validation loss over the last 10 recorded epochs
print(np.mean(history.history['val_loss'][-10:]))
# Last expression of the cell, so the notebook displays the whole val_loss list as a string
str(history.history['val_loss'][::1])

In [None]:
'''
Execute this cell to save the models
'''

# Suffix used in the three output file names below
NAME = '128'
# Each model is written to the current working directory
encoder.save(f'./encoder_{NAME}.hdf5')
decoder.save(f'./decoder_{NAME}.hdf5')
autoencoder.save(f'./autoencoder_{NAME}.hdf5')

In [None]:
def show_predictions(autoencoder, np_images):
    '''
    For every image in np_images, shows the image itself and then the output that
    autoencoder.predict returns for it
    '''
    for idx in range(len(np_images)):
        original = np_images[idx]
        plt.imshow(original)
        plt.show()

        # Slice instead of index so predict receives a batch of one image
        reconstruction = autoencoder.predict(np_images[idx:idx+1])[0]
        plt.imshow(reconstruction)
        plt.show()

In [None]:
# Rebind the three model variables to previously saved models, replacing whatever
# models the earlier cells left in memory
autoencoder = keras.models.load_model('../input/face-autoencoder-10-27/autoencoder_128.hdf5')
decoder = keras.models.load_model('../input/face-autoencoder-10-27/decoder_128.hdf5')
encoder = keras.models.load_model('../input/face-autoencoder-10-27/encoder_128.hdf5')
autoencoder.summary()

In [None]:
# Show input and autoencoder output for the ten images at positions 20000-20009
show_predictions(autoencoder, img_array[20000:20010])

In [None]:
# Despite the name, `encoded` holds the full autoencoder output (an image batch), not a latent vector
encoded = autoencoder.predict(my_face)
plt.imshow(encoded[0])

In [None]:
def pull_gender(file_names, imgs, target_gender='male'):
    '''
    Returns the entries of imgs whose file name, at the same position in file_names,
    parses to target_gender; positions with an invalid name are never selected

    imgs is indexed with a boolean list as long as file_names
    '''
    keep = [
        is_valid_name(file_name) and parse_name(file_name)[1] == target_gender
        for file_name in file_names
    ]

    return imgs[keep]

def pull_race(file_names, imgs, target_race='white'):
    '''
    Returns the entries of imgs whose file name, at the same position in file_names,
    parses to target_race; positions with an invalid name are never selected

    imgs is indexed with a boolean list as long as file_names
    '''
    keep = [
        is_valid_name(file_name) and parse_name(file_name)[2] == target_race
        for file_name in file_names
    ]

    return imgs[keep]

def pull_age(file_names, imgs, min_age, max_age):
    '''
    Returns the entries of imgs whose file name, at the same position in file_names,
    parses to an age between min_age and max_age (both inclusive); positions with an
    invalid name are never selected

    imgs is indexed with a boolean list as long as file_names
    '''
    keep = [
        is_valid_name(file_name) and min_age <= parse_name(file_name)[0] <= max_age
        for file_name in file_names
    ]

    return imgs[keep]

In [None]:
# Keep only the names that pass is_valid_name, in IMAGE_FILES order
# NOTE(review): presumably this lines up index-for-index with img_array -- confirm for the cached .npz
valid_image_files = [image_file for image_file in IMAGE_FILES if is_valid_name(image_file)]

In [None]:
# Select the images whose file name encodes an age from 30 to 39 inclusive
imgs = pull_age(valid_image_files, img_array, min_age=30, max_age=39)

# Disabled preview loop, kept as a string so it does not run
'''
for img in imgs:
    plt.imshow(img)
    plt.show()
'''

# Reconstruction loss of the autoencoder on that subset (input and target are the same images)
autoencoder.evaluate(imgs, imgs)

In [None]:
loss_32 = [0.6627695560455322, 0.6259118914604187, 0.6138525009155273, 0.6089658141136169, 0.6052066683769226, 0.6025187373161316, 0.6072785258293152, 0.5883831977844238, 0.5862176418304443, 0.5827500224113464, 0.5795606374740601, 0.5778688192367554, 0.577028751373291, 0.5815162658691406, 0.5756434202194214, 0.5733323693275452, 0.578891396522522, 0.5720478296279907, 0.5746335983276367, 0.5718228816986084, 0.5705004334449768, 0.5709825754165649, 0.5711780786514282, 0.5702975988388062, 0.5691111087799072, 0.5694326758384705, 0.5686717629432678, 0.5696514248847961, 0.5675574541091919, 0.5694934129714966, 0.5671796798706055, 0.566648006439209, 0.5667614936828613, 0.5689466595649719, 0.565967321395874, 0.5662250518798828, 0.5682482123374939, 0.5740189552307129, 0.5663620829582214, 0.5655991435050964, 0.5658228397369385, 0.5659757852554321, 0.5652344822883606, 0.5652780532836914, 0.5654373168945312, 0.5651289224624634, 0.5650496482849121, 0.5652844309806824, 0.5646237134933472, 0.5644608736038208, 0.5656793117523193, 0.5676941871643066, 0.5647034645080566, 0.564418375492096, 0.5649163722991943, 0.5658340454101562, 0.564697802066803, 0.5649095177650452, 0.5642573237419128, 0.5653145909309387, 0.5639434456825256, 0.5642932057380676, 0.5639004707336426, 0.5642625689506531, 0.5638266801834106, 0.563861608505249, 0.5638389587402344, 0.5637657046318054, 0.5638052225112915, 0.5653149485588074, 0.5641324520111084, 0.5636299252510071, 0.5636019110679626, 0.5639102458953857, 0.5635942220687866, 0.5637409090995789, 0.5635852217674255, 0.5634598135948181, 0.5635145306587219, 0.5636105537414551, 0.5635759234428406, 0.5634059906005859, 0.5634787082672119, 0.5634635090827942, 0.5641558766365051, 0.5632580518722534, 0.5635772943496704, 0.5655040740966797, 0.5634446144104004, 0.5641424059867859, 0.5632159113883972, 0.5641984939575195, 0.563368558883667, 0.5639365315437317, 0.5634017586708069, 0.5634316205978394, 0.5632849931716919, 0.5633142590522766, 0.5632380247116089, 
0.5634582042694092]
loss_64 = [0.6556639075279236, 0.627168595790863, 0.6166462302207947, 0.6130629181861877, 0.6045221090316772, 0.5962830185890198, 0.6077504754066467, 0.5901625752449036, 0.5901321768760681, 0.5878445506095886, 0.5891215801239014, 0.5882571339607239, 0.5861573219299316, 0.5851282477378845, 0.5833849310874939, 0.5826190710067749, 0.5860950946807861, 0.5818547606468201, 0.5787014961242676, 0.5794712901115417, 0.5754133462905884, 0.5739091634750366, 0.5729219317436218, 0.5718095302581787, 0.5708395838737488, 0.5710378885269165, 0.5700592398643494, 0.5694400072097778, 0.5726715326309204, 0.5694040656089783, 0.5680436491966248, 0.5675168037414551, 0.5674852132797241, 0.5673973560333252, 0.5668566226959229, 0.5665867328643799, 0.5665884017944336, 0.5669726729393005, 0.5875481963157654, 0.5660555362701416, 0.5649645924568176, 0.5651966333389282, 0.5642680525779724, 0.564110279083252, 0.5636363625526428, 0.5643065571784973, 0.5667364001274109, 0.5635185837745667, 0.5629991292953491, 0.5628889203071594, 0.5628605484962463, 0.5638269782066345, 0.5625545978546143, 0.5634298920631409, 0.5629264116287231, 0.5617396235466003, 0.5615983605384827, 0.561500072479248, 0.5613858103752136, 0.5615896582603455, 0.5609035491943359, 0.5616069436073303, 0.5614066123962402, 0.560908854007721, 0.5617967844009399, 0.5606575012207031, 0.5609208941459656, 0.5603482127189636, 0.5604383945465088, 0.5603731870651245, 0.5633828639984131, 0.560508668422699, 0.5599803328514099, 0.5598334074020386, 0.561272919178009, 0.5600394010543823, 0.5600714683532715, 0.5602531433105469, 0.5603107213973999, 0.5596372485160828, 0.5595220327377319, 0.5601940751075745, 0.5605542659759521, 0.559871256351471, 0.5595396161079407, 0.5603870153427124, 0.5591939091682434, 0.5593236684799194, 0.5593409538269043, 0.5597257614135742, 0.5593305826187134, 0.5594316720962524, 0.5589523911476135, 0.559037446975708, 0.5593436360359192, 0.5595834255218506, 0.5591671466827393, 0.5627565383911133, 0.5588695406913757, 
0.5592069625854492]
loss_128 = [0.6651434302330017, 0.6230427026748657, 0.6101459264755249, 0.611891508102417, 0.596498429775238, 0.6077919006347656, 0.5957257151603699, 0.5916337370872498, 0.5892846584320068, 0.5884402990341187, 0.6150095462799072, 0.5878422856330872, 0.5868470668792725, 0.5841116905212402, 0.5830162763595581, 0.5852934718132019, 0.5861793160438538, 0.580944299697876, 0.5778355598449707, 0.574596107006073, 0.5732763409614563, 0.5750085115432739, 0.5716807246208191, 0.5759696960449219, 0.5709481239318848, 0.5698457956314087, 0.5699139833450317, 0.5693665742874146, 0.5683586001396179, 0.5678092241287231, 0.5672999620437622, 0.5670068264007568, 0.5680339336395264, 0.5663964152336121, 0.5661787390708923, 0.567940354347229, 0.5668559670448303, 0.5649978518486023, 0.5673456192016602, 0.564848780632019, 0.5649939179420471, 0.5648012757301331, 0.5641854405403137, 0.5633525848388672, 0.5649952292442322, 0.5637995600700378, 0.5644568800926208, 0.5631456971168518, 0.5624324083328247, 0.5632839798927307, 0.5620836019515991, 0.5616880059242249, 0.5629041194915771, 0.5615797638893127, 0.5617144703865051, 0.561769962310791, 0.5613058805465698, 0.5618610382080078, 0.5611031651496887, 0.5616395473480225, 0.5611751675605774, 0.5604587197303772, 0.5609548091888428, 0.5608373284339905, 0.5603697299957275, 0.5599216818809509, 0.5633348822593689, 0.5597605109214783, 0.5598016977310181, 0.5599160194396973, 0.560280978679657, 0.559654951095581, 0.5597550272941589, 0.5592321157455444, 0.5593011975288391, 0.5589351654052734, 0.5590755343437195, 0.5590528845787048, 0.574508547782898, 0.5601709485054016, 0.5628482103347778, 0.5587095022201538, 0.5585806965827942, 0.558451771736145, 0.5582389831542969, 0.5582034587860107, 0.5585802793502808, 0.5596485137939453, 0.5582942366600037, 0.5583699941635132, 0.5580384731292725, 0.558172881603241, 0.5582118630409241, 0.5589061379432678, 0.5578951239585876, 0.5578790307044983, 0.5579734444618225, 0.558055579662323, 0.5578036904335022, 0.5586898922920227]
loss_256 = [0.6519070267677307, 0.6138643622398376, 0.6221556067466736, 0.6038567423820496, 0.5958243012428284, 0.5932930707931519, 0.5906574130058289, 0.5924011468887329, 0.5851490497589111, 0.5823303461074829, 0.5815746784210205, 0.5787850618362427, 0.5793103575706482, 0.5763710737228394, 0.5759060382843018, 0.5729438066482544, 0.5744713544845581, 0.5724223256111145, 0.5708431601524353, 0.5711768269538879, 0.5691547393798828, 0.5685892701148987, 0.5684875845909119, 0.5677867531776428, 0.567676842212677, 0.5662437677383423, 0.5663963556289673, 0.5661959052085876, 0.5653982758522034, 0.5662538409233093, 0.5650900602340698, 0.564748227596283, 0.5641091465950012, 0.5636481046676636, 0.5635164976119995, 0.56280118227005, 0.5629335641860962, 0.5646363496780396, 0.564286470413208, 0.5621851682662964, 0.5626335144042969, 0.5619467496871948, 0.5618473887443542, 0.5621809959411621, 0.5614476203918457, 0.5611790418624878, 0.5624386668205261, 0.5611171126365662, 0.5610040426254272, 0.560554563999176, 0.5607685446739197, 0.5601770281791687, 0.560185968875885, 0.5599826574325562, 0.5602937340736389, 0.5596851110458374, 0.5595584511756897, 0.559391975402832, 0.5593541264533997, 0.5592750310897827, 0.5594449043273926, 0.5606568455696106, 0.5591813325881958, 0.5595486164093018, 0.5586241483688354, 0.5586174130439758, 0.5588067173957825, 0.5586187839508057, 0.5588344931602478, 0.5582901835441589, 0.5581218004226685, 0.5583655834197998, 0.5587472915649414, 0.5580026507377625, 0.557977020740509, 0.5577593445777893, 0.5578620433807373, 0.5598688125610352, 0.5680185556411743, 0.5582558512687683, 0.558681845664978, 0.5580493807792664, 0.5575885772705078, 0.5582588315010071, 0.5574234127998352, 0.5573419332504272, 0.5576935410499573, 0.5574342608451843, 0.5573176741600037, 0.5580106377601624, 0.5574899911880493, 0.5570968389511108, 0.558347225189209, 0.5572624802589417, 0.5569573640823364, 0.5574367046356201, 0.5571238994598389, 0.5569007396697998, 0.5569027662277222, 0.5566669702529907]
loss_512 = [0.6586852073669434, 0.6446666121482849, 0.6190512776374817, 0.6116129159927368, 0.6089377403259277, 0.6089082956314087, 0.6058579683303833, 0.6043663024902344, 0.6018903255462646, 0.587471067905426, 0.5860590934753418, 0.588096022605896, 0.585354745388031, 0.5927261114120483, 0.5851246118545532, 0.5815338492393494, 0.5791661143302917, 0.5759374499320984, 0.5745606422424316, 0.5730831623077393, 0.5750294923782349, 0.5929646492004395, 0.5712465047836304, 0.5887829661369324, 0.5707591772079468, 0.5697836875915527, 0.5684753060340881, 0.5682927370071411, 0.5679036974906921, 0.5685930848121643, 0.5673362612724304, 0.5673955678939819, 0.5660635828971863, 0.5675015449523926, 0.5670884847640991, 0.5649962425231934, 0.5657569766044617, 0.5660125017166138, 0.5638850331306458, 0.5700755715370178, 0.5640339851379395, 0.5633885264396667, 0.5626549124717712, 0.5627242922782898, 0.5631037950515747, 0.5622428059577942, 0.5619184374809265, 0.5617914795875549, 0.5624736547470093, 0.5612391829490662, 0.5609789490699768, 0.5617445111274719, 0.5612916946411133, 0.5607461929321289, 0.5613078474998474, 0.5607343912124634, 0.5605432987213135, 0.5613492727279663, 0.5612592101097107, 0.560352087020874, 0.5621692538261414, 0.5600977540016174, 0.5596922636032104, 0.5606716871261597, 0.5610446333885193, 0.5597125291824341, 0.5600242018699646, 0.5595635175704956, 0.5590877532958984, 0.5595449209213257, 0.5607255101203918, 0.5592942237854004, 0.5590581297874451, 0.5587218999862671, 0.5596281290054321, 0.5584062337875366, 0.5587581396102905, 0.559036910533905, 0.5592011213302612, 0.5588974356651306, 0.5579813122749329, 0.5582935214042664, 0.5585764646530151, 0.5582740902900696, 0.5578482747077942, 0.5591400265693665, 0.5596509575843811, 0.5575116872787476, 0.5576379895210266, 0.5579900145530701, 0.5574676990509033, 0.5582157373428345, 0.5575436353683472, 0.5576226115226746, 0.5585836172103882, 0.5572994351387024, 0.5573899745941162, 0.5577874183654785, 0.558302640914917, 
0.5585382580757141]

# One (values, line format, legend label) entry per hard-coded loss list above
# NOTE(review): the labels presumably denote the latent space size of each run -- confirm
loss_curves = (
    (loss_32, 'r-', '32'),
    (loss_64, 'b-', '64'),
    (loss_128, 'g-', '128'),
    (loss_256, 'y-', '256'),
    (loss_512, 'k-', '512'),
)

for curve, line_format, curve_label in loss_curves:
    plt.plot(curve, line_format, label=curve_label)

elements = 10

# Mean of the last `elements` values of each list, in the same order as plotted
for curve, line_format, curve_label in loss_curves:
    print(np.mean(curve[-elements:]))

plt.legend()

plt.show()

In [None]:
# Decode each of the 128 one-hot latent vectors (a single 1 at position i, zeros elsewhere)
# and show the resulting image; 128 is hard-coded and must equal the decoder's input length
for i in range(128):
    test = np.zeros((1, 128))
    test[0][i] = 1
    plt.imshow(decoder.predict(test)[0])
    plt.show()


In [None]:
# NOTE(review): this builds a brand-new Encoder (no weights are loaded) and rebinds `encoder`,
# replacing the model loaded from encoder_128.hdf5 in an earlier cell; the cells below then
# encode with this new model -- confirm that is intended
encoder = Encoder(128)

In [None]:
from sklearn.decomposition import PCA

# Encode every image, then fit a 128-component PCA on the resulting latent vectors
encoded = encoder.predict(img_array)
pca = PCA(128)
pca.fit(encoded)

# Write through a context manager so the file is flushed and closed even if pickling fails
# (the previous bare open() call left the handle open)
with open('pca_model_128', 'wb') as pca_file:
    pickle.dump(pca, pca_file)

In [None]:
# NOTE(review): unpickling can execute arbitrary code -- only load files from a trusted source
# The context manager closes the file handle that the previous bare open() call leaked
with open('../input/face-autoencoder-10-27/pca_model_128', 'rb') as pca_file:
    pca = pickle.load(pca_file)

In [None]:
# Encode every image and express the latent vectors in the PCA basis;
# `projected` is the training input of the decoder cells that follow
encoded = encoder.predict(img_array)
projected = pca.transform(encoded)

In [None]:
# Inputs are the PCA projections, targets are the corresponding images;
# the first 20000 rows are used for training and the rest for validation
X_train = projected[:20000]
X_test = projected[20000:]

Y_train = img_array[:20000]
Y_test = img_array[20000:]

# Fresh decoder (rebinds `decoder`) that takes a 128-long PCA vector as input
decoder = Decoder(128)
optimizer = keras.optimizers.Adam(learning_rate=0.0003)
decoder.compile(optimizer=optimizer, loss='binary_crossentropy')

EPOCHS = 200
BATCH_SIZE = 32

history = decoder.fit(X_train, Y_train, epochs=EPOCHS, batch_size=BATCH_SIZE, shuffle=True, validation_data=(X_test, Y_test))

In [None]:
# Full validation-loss history of the last fit, then its mean over the final 10 epochs
print(str(history.history['val_loss']))
print(np.mean(history.history['val_loss'][-10:]))

In [None]:
# Collect the age encoded in every valid file name, in IMAGE_FILES order
ages = []

for image_name in tqdm(IMAGE_FILES):
    if is_valid_name(image_name):
        age, _, _ = parse_name(image_name)

        ages.append(age)

# Column vector of ages divided by 120. reshape(-1, 1) lets numpy infer the row count, so the
# cell no longer breaks when the number of valid files differs from the hard-coded 23705
# NOTE(review): 120 is presumably an upper bound on age chosen to scale values into [0, 1] -- confirm
ages = np.array(ages).reshape((-1, 1)) / 120

#age_one_hot = keras.backend.one_hot(ages, 120)
#pca_with_age = np.hstack((projected, age_one_hot))

# Append the scaled age as one extra feature column after the PCA components;
# np.hstack requires `projected` and `ages` to have the same number of rows
pca_with_age = np.hstack((projected, ages))

In [None]:
# Same split as the PCA-only decoder, but each input row carries the extra scaled-age column
X_train = pca_with_age[:20000]
X_test = pca_with_age[20000:]

Y_train = img_array[:20000]
Y_test = img_array[20000:]

# 129 = 128 PCA components + 1 age feature; this rebinds `decoder` to a new model
decoder = Decoder(129)
optimizer = keras.optimizers.Adam(learning_rate=0.001)
decoder.compile(optimizer=optimizer, loss='binary_crossentropy')

EPOCHS = 100
BATCH_SIZE = 32

decoder.fit(X_train, Y_train, epochs=EPOCHS, batch_size=BATCH_SIZE, shuffle=True, validation_data=(X_test, Y_test))

In [None]:
# NOTE(review): this rebinds `encoded` and `projected` to the single my_face sample, replacing
# the dataset-wide arrays of the same names; re-running earlier cells afterwards will see these
encoded = encoder.predict(my_face)
projected = pca.transform(encoded)

In [None]:
# Rebind `decoder` to a previously saved model, replacing the decoder trained in the cells above
decoder = keras.models.load_model('../input/face-autoencoder-10-27/pca_decoder_128.hdf5')

In [None]:
# Image decoded from the all-zero vector, shown first as a baseline
plt.imshow(decoder.predict(np.zeros((1, 128)))[0])
plt.show()

# NOTE(review): `ims` is never used afterwards and `fig` is never referenced again
ims = []

fig = plt.figure()

# For each of the first 20 components, show the image decoded from +1 and then from -1
# on that component alone (all other entries zero)
for i in range(20):
    print(f'Component {i}')
    test = np.zeros((1, 128))

    test[0][i] = 1
    plt.imshow(decoder.predict(test)[0])
    plt.show()
    
    test[0][i] = -1
    plt.imshow(decoder.predict(test)[0])
    plt.show()

In [None]:
# Generates random faces!
for i in range(36):
    #pca_vector = np.random.random((1, 128)) * 2 - 1
    pca_vector = np.random.normal(0, 6.2 * pca.explained_variance_, (1, 128))
    
    #for j in range(128):
        #pca_vector[0][j] *= 10 * pca.explained_variance_[j]
        
    plt.imshow(decoder.predict(pca_vector)[0])
    plt.show()

In [None]:
# Project my_face into PCA space, then decode three variants that differ only in one component
encoded = encoder.predict(my_face)
projected = pca.transform(encoded)
face_vec = np.copy(projected)

# Index of the PCA component to vary and the factor applied to it
component = 9
multiplier = 2

# Original value of that component
print(face_vec[0][component])

# Variant 1: component sign-flipped and scaled by `multiplier`
face_vec[0][component] *= -multiplier
plt.imshow(decoder.predict(face_vec)[0])
plt.show()

# Variant 2: unmodified projection (fresh copy, so the edit above is discarded)
face_vec = np.copy(projected)
plt.imshow(decoder.predict(face_vec)[0])
plt.show()

# Variant 3: component scaled by `multiplier`, sign kept
face_vec = np.copy(projected)
face_vec[0][component] *= multiplier
plt.imshow(decoder.predict(face_vec)[0])
plt.show()

In [None]:
# Print the layer-by-layer summary of whichever model `decoder` currently refers to
decoder.summary()

In [None]:
# NOTE(review): in a top-to-bottom run `decoder` is the model loaded from pca_decoder_128.hdf5
# a few cells earlier, not a freshly trained one -- confirm which model is meant to be saved
decoder.save('./pca_decoder_128.hdf5')

In [None]:
# Hand-set the first twelve PCA components (the remaining ones stay 0) and decode the result;
# edit the values below to explore what each component changes in the image
pca_vector = np.zeros((1, 128))

pca_vector[0][0] = 0
pca_vector[0][1] = 0
pca_vector[0][2] = 1
pca_vector[0][3] = 0
pca_vector[0][4] = 0
pca_vector[0][5] = 0
pca_vector[0][6] = -1
pca_vector[0][7] = 0
pca_vector[0][8] = 0
pca_vector[0][9] = 0
pca_vector[0][10] = 0
pca_vector[0][11] = 0

plt.imshow(decoder.predict(pca_vector)[0])
plt.show()
