In [12]:
import tensorflow as tf
import os
import numpy as np
from dataset import path_to_image_crop, input_pipeline, images_as_float

def model1(x, y_, data_size):
    W = tf.Variable(tf.zeros([data_size, 1]))
    b = tf.Variable(tf.zeros([2]))
    y = tf.matmul(x, W) + b

    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
    return tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy), y, cross_entropy

def model2(x, y_):
    def fully_connected(input, size):
        weights = tf.get_variable( 'weights', 
            shape = [input.get_shape()[1], size],
            initializer = tf.contrib.layers.xavier_initializer()
          )
        biases = tf.get_variable( 'biases',
            shape = [size],
            initializer=tf.constant_initializer(0.0)
          )
        return tf.matmul(input, weights) + biases
    def model_pass(input):
        with tf.variable_scope('hidden'):
            hidden = fully_connected(input, size = 100)
        relu_hidden = tf.nn.relu(hidden)
        with tf.variable_scope('out'):
            prediction = fully_connected(relu_hidden, size = 2)
        return prediction   
    
    predictions = model_pass(x)  
    loss = tf.reduce_mean(tf.square(predictions - y_))
    optimizer = tf.train.MomentumOptimizer(
        learning_rate = 0.01, 
        momentum = 0.9, 
        use_nesterov = True
    ).minimize(loss)
    
    return optimizer, predictions, loss

people_path = '/data/people_classification_all/fold_*_data.txt'
image_prefix = 'coarse_tilt_aligned_face'
batch_size = 128
image_size = 227
image_dimension = [image_size,image_size]
num_epochs = 1000

model_name = "1fc_b" + str(batch_size) + "_e" + str(num_epochs - 1)
model_variable_scope = model_name

def extract_gender(features):
    def extract(v):
        def f1(): return tf.constant([1,0])
        def f2(): return tf.constant([0,1])
        def f_(): return tf.constant([0,0])
        return tf.case({
            tf.equal(v[1], tf.constant(1)): f1,
            tf.equal(v[1], tf.constant(2)): f2,
            }, default=f_, exclusive=True)
    return tf.map_fn(extract, features, dtype=tf.int32)


data_size = image_dimension[0] * image_dimension[1] * 3

graph = tf.Graph()
with graph.as_default():
    path_batch, label_batch = input_pipeline(people_path, batch_size, None, True)
    label_batch = extract_gender(label_batch)
    label_batch = tf.reshape(label_batch,[batch_size,2])

    data_batch = path_to_image_crop(path_batch, os.path.dirname(people_path), image_prefix, image_dimension)
    data_batch = tf.reshape(data_batch,[batch_size, data_size])

    x = tf.placeholder(tf.float32, [None, data_size])
    y_ = tf.placeholder(tf.float32, [None, 2])
    
    #train_step, y, loss = model1(x, y_, data_size)
    train_step, y, loss = model2(x, y_)

with tf.Session(graph = graph) as session:
    session.run(tf.global_variables_initializer())
    session.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator ()
    threads = tf.train.start_queue_runners (coord = coord)
    for i in range(num_epochs):
        batch_xs, batch_ys = session.run([data_batch, label_batch])
        batch_xs = images_as_float(batch_xs, batch_size, data_size)
        p,l,_ = session.run([y, loss, train_step], feed_dict={x: batch_xs, y_: batch_ys})
        print('%d: %s -> %s %s' % (i, l, p[i % batch_size], batch_ys[i % batch_size])) # 1 -> male, 2 -> female
        if (i == 0): print(batch_xs)
    coord.request_stop ()
    coord.join (threads)      

0: 0.820444 -> [ 0.7785638   0.17644775] [1 0]
[[ 0.42745098  0.4         0.36862745 ...,  0.18039216  0.08235294
   0.06666667]
 [ 0.40392157  0.35686275  0.2627451  ...,  0.16078431  0.11764706
   0.14117647]
 [ 0.23529412  0.1372549   0.12156863 ...,  0.36078431  0.11764706
   0.21568627]
 ..., 
 [ 0.60784314  0.43137255  0.34901961 ...,  0.19215686  0.10588235
   0.11764706]
 [ 0.41176471  0.28235294  0.24705882 ...,  0.38823529  0.27843137
   0.19215686]
 [ 0.          0.00392157  0.         ...,  0.          0.00392157
   0.01960784]]
1: 13963.5 -> [-179.78222656  171.47874451] [1 0]
2: 6.65374e+08 -> [ 14878.20507812 -14850.13378906] [1 0]
3: 3.57617e+07 -> [-3879.42919922  6603.88378906] [0 1]
4: 299687.0 -> [-575.04150391  518.39080811] [1 0]
5: 446401.0 -> [-706.7364502  627.1751709] [1 0]
6: 591358.0 -> [-817.00469971  717.83569336] [1 0]
7: 724631.0 -> [-907.0793457   791.44152832] [1 0]
8: 828642.0 -> [-954.19421387  866.4239502 ] [1 0]
9: 6.57381e+15 -> [ 66663184. -61471

KeyboardInterrupt: 