In [None]:
'''
REMEMBER TO INCREASE THE RAM
REMEMBER TO INCLUDE THE JAR FILE: tensorflow-hadoop-1.0-SNAPSHOT.jar
'''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from array import array
from hops import hdfs

import numpy as np
from PIL import Image
import io
from skimage.color import rgb2lab

def is_grey_scale(array, w, h):
    """Return True when every inspected pixel has equal R, G and B values.

    Args:
        array: image data indexable as ``array[i, j, channel]`` with the
            red, green and blue values in channels 0, 1 and 2.
        w: number of positions to inspect along the first axis.
        h: number of positions to inspect along the second axis.

    Returns:
        True if all inspected pixels satisfy ``r == g == b`` (also when the
        inspected region is empty), False as soon as any pixel has at least
        one differing channel.
    """
    region = np.asarray(array)[:w, :h]
    red, green, blue = region[..., 0], region[..., 1], region[..., 2]
    # A pixel is grey only if all three channels agree. The chained test
    # `r != g != b` means `r != g and g != b`, which misses pixels such as
    # (10, 10, 200), so compare the channels pairwise instead.
    return bool(np.all((red == green) & (green == blue)))

def readImages(sc, folder):
    """Load JPEG files into an RDD of flattened, scaled Lab channels.

    Files are matched with the glob ``folder + "/**/**/*.jpg"``, i.e. JPEGs
    located two directory levels below ``folder``. Each image is resized to
    256x256; images whose array shape is not (256, 256, 3) and images that
    ``is_grey_scale`` reports as grey are dropped.

    Args:
        sc: Spark context used to read the binary files.
        folder: base path (without trailing slash) that contains the images.

    Returns:
        An RDD of ``(L, A, B)`` tuples. Each element is a 1-D array of
        256 * 256 values: ``L`` exactly as returned by ``rgb2lab``, ``A`` and
        ``B`` divided by 128.
    """
    imageRDD = sc.binaryFiles(folder + "/**/**/*.jpg")
    # binaryFiles yields (path, content) pairs; only the raw bytes are used.
    imageRDD = imageRDD.map(lambda binaryData: binaryData[1])
    imageRDD = imageRDD.map(lambda rawData: Image.open(io.BytesIO(rawData)))
    # Image.LANCZOS is the same filter as the older Image.ANTIALIAS alias,
    # which was removed in Pillow 10.
    imageRDD = imageRDD.map(lambda image: image.resize((256, 256), Image.LANCZOS))
    imageRDD = imageRDD.map(lambda image: np.array(image, dtype=float))
    # Keep only three-channel images (drops single-channel, RGBA, ...).
    imageRDD = imageRDD.filter(lambda array: array.shape == (256, 256, 3))
    imageRDD = imageRDD.filter(lambda array: not is_grey_scale(array, 256, 256))
    # rgb2lab does not rescale floating-point input and expects values in
    # [0, 1], so bring the 0..255 pixel values into that range first.
    imageRDD = imageRDD.map(lambda array: rgb2lab(array / 255.0))
    imageRDD = imageRDD.map(
        lambda lab: (lab[:, :, 0], lab[:, :, 1] / 128, lab[:, :, 2] / 128)
    )
    # Flatten every channel to 1-D. A single-argument lambda is used because
    # tuple parameters (`lambda (L, A, B): ...`) are a SyntaxError on Python 3.
    imageRDD = imageRDD.map(
        lambda channels: tuple(channel.reshape(-1) for channel in channels)
    )

    return imageRDD

def toTFExample(L, A, B):
    """Pack three float sequences into a serialized tf.train.Example.

    Args:
        L: iterable of floats stored under the feature key 'L'.
        A: iterable of floats stored under the feature key 'A'.
        B: iterable of floats stored under the feature key 'B'.

    Returns:
        The byte string produced by ``Example.SerializeToString()``.
    """
    def _float_feature(values):
        # Wrap a sequence of floats in a FloatList-backed Feature.
        return tf.train.Feature(float_list=tf.train.FloatList(value=values))

    channel_features = {
        'L': _float_feature(L),
        'A': _float_feature(A),
        'B': _float_feature(B),
    }
    record = tf.train.Example(
        features=tf.train.Features(feature=channel_features)
    )
    return record.SerializeToString()

def fromTFExample(bytestr):
    """Parse a serialized byte string back into a tf.train.Example.

    Args:
        bytestr: serialized Example bytes, e.g. the output of toTFExample.

    Returns:
        A new tf.train.Example populated from ``bytestr``.
    """
    parsed = tf.train.Example()
    parsed.ParseFromString(bytestr)
    return parsed

def writeTFRECORDS(sc, input_dir,  output_dir):
    """Convert the images under ``input_dir`` to TFRecord files.

    The images are loaded with ``readImages``, each ``(L, A, B)`` tuple is
    serialized with ``toTFExample`` and the result is saved through the
    ``org.tensorflow.hadoop.io.TFRecordFileOutputFormat`` Hadoop output
    format (the file header notes that the tensorflow-hadoop jar must be
    included for this).

    Args:
        sc: Spark context used to read the images.
        input_dir: base directory passed to ``readImages``.
        output_dir: destination path for the TFRecord output.
    """
    imageRDD = readImages(sc, input_dir)

    # Records are written as (BytesWritable key, NullWritable value) pairs,
    # so the serialized example is the key and the value is None. The tuple
    # is unpacked with `*` because tuple parameters in a lambda
    # (`lambda (L, A, B): ...`) are a SyntaxError on Python 3.
    tfRDD = imageRDD.map(lambda lab: (bytearray(toTFExample(*lab)), None))
    tfRDD.saveAsNewAPIHadoopFile(output_dir, "org.tensorflow.hadoop.io.TFRecordFileOutputFormat",
                                keyClass="org.apache.hadoop.io.BytesWritable",
                                valueClass="org.apache.hadoop.io.NullWritable")

In [None]:
from pyspark.context import SparkContext
from pyspark.conf import SparkConf

# `spark` is not defined in this file; presumably the notebook kernel
# provides the session object - confirm before running elsewhere.
sc = spark.sparkContext

# Read the face-dataset JPEGs and write them back out as TFRecords.
writeTFRECORDS(
    sc,
    input_dir="hdfs:///Projects/colorizeML2/imbd_face_dataset",
    output_dir="hdfs:///Projects/colorizeML2/imbd_face_dataset/processed",
)