In [1]:
from pyspark import SparkContext, SparkConf
from pyspark.sql import Row
from pyspark.sql.types import *
from pyspark.sql import SQLContext
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.classification import RandomForestClassifier

from PIL import Image
from PIL import ImageOps 
import numpy as np 
import os 
from IPython.display import display, HTML

import glob
import re

In [2]:
conf = SparkConf()
conf = (conf.setMaster('local[*]')
        .set('spark.executor.memory', '4G')
        .set('spark.driver.memory', '16G')
        .set('spark.driver.maxResultSize', '10G'))
sc = SparkContext(conf=conf)
sqlContext = SQLContext(sc)

# Resizing the images

In [11]:
def imageResize(basename,imageName):
    """
    resize image
    basename : eg. /home/username/XYZFolder
    image name : xyz.jpg
    New folder in the working directory will be created with '_resized' as suffix
    """
    new_width  = 128
    new_height = 128
    try:  
        img = Image.open(basename+"/"+imageName) # image extension *.png,*.jpg
        img = img.resize((new_width, new_height), Image.ANTIALIAS)
        img.save(basename+'_resized/'+imageName)
    except:
        os.mkdir(basename+'_resized/')
        img = Image.open(basename+"/"+imageName) # image extension *.png,*.jpg
        img = img.resize((new_width, new_height), Image.ANTIALIAS)
        img.save(basename+'_resized/'+imageName)

def resizer(folderPath):
    """
    to resize all files present in a folder
    resizer('/home/username/XYZFolder')
    """
    
    for subdir, dirs, files in os.walk(folderPath):
        for fileName in files:
#             try:
                #  print os.path.join(subdir, file)
                filepath = subdir + os.sep + fileName
                #  print filepath
                if filepath.endswith(".jpg" or ".jpeg" or ".png" or ".gif"):
                    imageResize(subdir,fileName)
#             except:
#                 print traceback.print_exc()

In [8]:
# resizer('wiki_crop/wiki_crop_new/')
# # went to wiki_crop/wiki_crop_new/_resized/

# Filtering & Converting images to pixels

In [73]:
path = 'wiki_crop/wiki_crop_new/_resized/'

In [74]:
def filterimage(path):
    my_sub_dir = glob.glob(path + '*.jpg')
    for i in my_sub_dir:
        if os.path.getsize(i) < 1000:
            # print(path + str(i) + '/' + str(j))
            os.remove(i)

#filter images that are corrupted
# filterimage(path)

In [75]:
def load_image2(infilename) :
    '''  
    convert image file to pixels
    load_image2('fileName')
    '''
    img = Image.open(infilename).convert('L')
    data = np.array(img)
    return data


def ageAtPhoto(fileName):
    '''
    get age at time of photo
    ageAtPhoto('full_path_to_file')
    10049200_1891-09-16_1958.jpg
    yob is 1891
    dtpt is 1958
    '''
    basename = fileName.split('/')[-1].split('_')
    birth = int(basename[1].split('-')[0])
    today = int(basename[2].split('.')[0])
    currAge = abs(today - birth)
    return currAge



def convertToNumpy(folder):
    '''
    get pixels and age for each image in a folder
    x_values, y_values = convertToNumpy(fileNames)    
    '''
    pixels = []
    ages = []
    filename =[]
    for fileName in folder:
#         if fileName.endswith(".jpg" or ".jpeg" or ".png"):
            age = ageAtPhoto(fileName)
            if (age<100 and age>0):
                img_px = np.ravel(load_image2(fileName))
                pixels.append(img_px)
                ages.append(age)
                filename.append(fileName.split('/')[-1])
    return pixels, ages,filename

In [77]:
folderName = 'wiki_crop/wiki_crop_new/_resized/' 
fileNames = glob.glob(folderName +'*.jpg')
# only test 20 files for now
NumberOfFileToTrained =70000

x_values, y_values,filename = convertToNumpy(fileNames[:NumberOfFileToTrained])

# print x_values
# print y_values

In [78]:
len(y_values)

14833

# Putting the data on RDD and converting to DF

In [79]:
flat_rdd = sc.parallelize(x_values).map(lambda x : x.tolist()).map(lambda x: [int(element) for element in x])
# len(flat_rdd.take(5)[0])
age_rdd = sc.parallelize(y_values).map(lambda x:int(x))
# age_rdd.take(5)
f_name = sc.parallelize(filename)
combined = flat_rdd.zip(age_rdd).zip(f_name)
# combined.getNumPartitions()

In [14]:
combined.collect()

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.


In [80]:
#Create a DataFrame
imageschema = StructType([
   StructField("features", ArrayType(elementType=IntegerType(),containsNull=False),True),
   StructField("label", IntegerType(),True),
   StructField("f_name", StringType(),True)
])
df = sqlContext.createDataFrame(combined.map(lambda x : Row(x[0][0][:],x[0][1],x[1])), imageschema)
df.show()

+--------------------+-----+--------------------+
|            features|label|              f_name|
+--------------------+-----+--------------------+
|[67, 67, 67, 65, ...|   21|31843216_1990-06-...|
|[139, 140, 142, 1...|   46|1000001952958_196...|
|[0, 0, 0, 0, 0, 0...|   98|12872267_1910-10-...|
|[224, 223, 216, 2...|   19|43481633_1995-11-...|
|[220, 219, 218, 2...|   52|1000001671401_195...|
|[3, 3, 3, 3, 3, 3...|   59|100000185210_1944...|
|[21, 19, 18, 21, ...|   27|1000007243240_197...|
|[196, 196, 196, 1...|   31|10000022821347_19...|
|[35, 35, 35, 35, ...|   61|10000029563647_19...|
|[23, 23, 23, 23, ...|   28|1000003156912_198...|
|[61, 60, 60, 62, ...|   39|10000026256606_19...|
|[83, 71, 57, 54, ...|   34|610477_1976-04-20...|
|[123, 123, 123, 1...|   45|30948829_1963-03-...|
|[204, 204, 204, 2...|   34|10000012277158_19...|
|[0, 0, 0, 0, 0, 0...|   27|1000001924882_194...|
|[252, 252, 252, 2...|   24|10000015230998_19...|
|[195, 200, 205, 2...|   63|10000028738868_19...|


In [81]:
df.printSchema()

root
 |-- features: array (nullable = true)
 |    |-- element: integer (containsNull = false)
 |-- label: integer (nullable = true)
 |-- f_name: string (nullable = true)



In [26]:
# Download df in json format
# df.write.format('json').save('project/dataset.json')

In [27]:
# Write df to MongoDB
df.write.format("com.mongodb.spark.sql.DefaultSource").mode("append").save()

# Read df from MongoDB
df = spark.read.format("com.mongodb.spark.sql.DefaultSource").load()
df.show()

+--------------------+-----+--------------------+
|            features|label|              f_name|
+--------------------+-----+--------------------+
|[67, 67, 67, 65, ...|   21|31843216_1990-06-...|
|[139, 140, 142, 1...|   46|1000001952958_196...|
|[0, 0, 0, 0, 0, 0...|   98|12872267_1910-10-...|
|[224, 223, 216, 2...|   19|43481633_1995-11-...|
|[220, 219, 218, 2...|   52|1000001671401_195...|
|[3, 3, 3, 3, 3, 3...|   59|100000185210_1944...|
|[21, 19, 18, 21, ...|   27|1000007243240_197...|
|[196, 196, 196, 1...|   31|10000022821347_19...|
|[35, 35, 35, 35, ...|   61|10000029563647_19...|
|[23, 23, 23, 23, ...|   28|1000003156912_198...|
|[61, 60, 60, 62, ...|   39|10000026256606_19...|
|[83, 71, 57, 54, ...|   34|610477_1976-04-20...|
|[123, 123, 123, 1...|   45|30948829_1963-03-...|
|[204, 204, 204, 2...|   34|10000012277158_19...|
|[0, 0, 0, 0, 0, 0...|   27|1000001924882_194...|
|[252, 252, 252, 2...|   24|10000015230998_19...|
|[195, 200, 205, 2...|   63|10000028738868_19...|


# Splitting the df to train/test

In [82]:
from pyspark.ml.linalg import Vectors, VectorUDT
from pyspark.sql.functions import udf
list_to_vector_udf = udf(lambda l: Vectors.dense(l), VectorUDT())
df = df.select(list_to_vector_udf(df["features"]).alias("features"),'label','f_name')

dataset = df.randomSplit([0.8, 0.2])
train = dataset[0].cache()
test = dataset[1].cache()

In [83]:
train.show()
test.show()

+--------------------+-----+--------------------+
|            features|label|              f_name|
+--------------------+-----+--------------------+
|[0.0,0.0,0.0,0.0,...|    8|29307820_1984-05-...|
|[0.0,0.0,0.0,0.0,...|   11|12321426_1951-04-...|
|[0.0,0.0,0.0,0.0,...|   15|2729047_1981-01-2...|
|[0.0,0.0,0.0,0.0,...|   19|10000046942333_19...|
|[0.0,0.0,0.0,0.0,...|   20|10000038904783_19...|
|[0.0,0.0,0.0,0.0,...|   20|37068594_1993-06-...|
|[0.0,0.0,0.0,0.0,...|   21|10000041491567_19...|
|[0.0,0.0,0.0,0.0,...|   21|17474041_1951-10-...|
|[0.0,0.0,0.0,0.0,...|   21|43684651_1963-04-...|
|[0.0,0.0,0.0,0.0,...|   22|10000012249351_19...|
|[0.0,0.0,0.0,0.0,...|   22|10000033328513_19...|
|[0.0,0.0,0.0,0.0,...|   22|10000038381611_19...|
|[0.0,0.0,0.0,0.0,...|   22|2587518_1983-01-3...|
|[0.0,0.0,0.0,0.0,...|   22|42356895_1992-10-...|
|[0.0,0.0,0.0,0.0,...|   23|10000029763303_19...|
|[0.0,0.0,0.0,0.0,...|   23|10000044065591_19...|
|[0.0,0.0,0.0,0.0,...|   23|27045513_1987-03-...|


# Training the model

In [84]:
from pyspark.ml.classification import LogisticRegression
lr = LogisticRegression(maxIter=10, fitIntercept=True)
lrmodel = lr.fit(train)

## Predictions

In [85]:
from pyspark.ml.classification import LogisticRegression
lr = LogisticRegression(maxIter=10, fitIntercept=True)
lrmodel = lr.fit(train)

In [86]:
validpredict = lrmodel.transform(test)
validpredict.show()

+--------------------+-----+--------------------+--------------------+--------------------+----------+
|            features|label|              f_name|       rawPrediction|         probability|prediction|
+--------------------+-----+--------------------+--------------------+--------------------+----------+
|[0.0,0.0,0.0,0.0,...|   21|19215308_1986-05-...|[-4.0758157443478...|[8.32752880314212...|      26.0|
|[0.0,0.0,0.0,0.0,...|   22|10000046408247_19...|[-4.0758157443478...|[8.32752880314212...|      26.0|
|[0.0,0.0,0.0,0.0,...|   25|10000019521511_19...|[-4.0758157443478...|[8.32752880314212...|      26.0|
|[0.0,0.0,0.0,0.0,...|   26|1000003749782_197...|[-4.0758157443478...|[8.32752880314212...|      26.0|
|[0.0,0.0,0.0,0.0,...|   47|971296_1953-02-03...|[-4.0758157443478...|[8.32752880314212...|      26.0|
|[1.0,2.0,3.0,4.0,...|   35|2986476_1965-09-2...|[-4.5217781994925...|[4.17183340552383...|      29.0|
|[3.0,3.0,3.0,3.0,...|   36|6241134_1975-04-1...|[-4.4949556592175...|[3.

# Testing on Individual Images

In [87]:
path = 'test_files/'

In [88]:
folderName = 'test_files/' 
fileNames = glob.glob(folderName +'*.png')
# only testing 4 files for now

x_values, y_values,filename = convertToNumpy(fileNames)

In [89]:
flat_rdd = sc.parallelize(x_values).map(lambda x : x.tolist()).map(lambda x: [int(element) for element in x])
age_rdd = sc.parallelize(y_values).map(lambda x:int(x))
f_name = sc.parallelize(filename)
combined = flat_rdd.zip(age_rdd).zip(f_name)

In [90]:
#Create a DataFrame
imageschema = StructType([
   StructField("features", ArrayType(elementType=IntegerType(),containsNull=False),True),
   StructField("label", IntegerType(),True),
   StructField("f_name", StringType(),True)
])
df = sqlContext.createDataFrame(combined.map(lambda x : Row(x[0][0][:],x[0][1],x[1])), imageschema)
df.show()

+--------------------+-----+--------------------+
|            features|label|              f_name|
+--------------------+-----+--------------------+
|[118, 135, 133, 1...|   30|666666_1986-10-24...|
|[255, 255, 255, 2...|   29|101010_1987-06-24...|
|[22, 21, 21, 21, ...|   32|666777_1986-10-24...|
|[79, 108, 105, 10...|   31|101011_1987-06-24...|
+--------------------+-----+--------------------+



In [91]:
list_to_vector_udf = udf(lambda l: Vectors.dense(l), VectorUDT())
df_test = df.select(list_to_vector_udf(df["features"]).alias("features"),'label','f_name')

In [92]:
validpredict = lrmodel.transform(df_test)
validpredict.show()

+--------------------+-----+--------------------+--------------------+--------------------+----------+
|            features|label|              f_name|       rawPrediction|         probability|prediction|
+--------------------+-----+--------------------+--------------------+--------------------+----------+
|[118.0,135.0,133....|   30|666666_1986-10-24...|[-4.4491482230859...|[6.29815527043899...|      43.0|
|[255.0,255.0,255....|   29|101010_1987-06-24...|[-4.7206409159015...|[3.82764225278437...|      21.0|
|[22.0,21.0,21.0,2...|   32|666777_1986-10-24...|[-4.3567617354430...|[6.67115644994620...|      43.0|
|[79.0,108.0,105.0...|   31|101011_1987-06-24...|[-4.4738548639421...|[3.60126166168097...|      25.0|
+--------------------+-----+--------------------+--------------------+--------------------+----------+

