In [1]:
import pyspark

from pyspark.sql import SparkSession
from pyspark.sql.types import *


spark = SparkSession.builder.appName('').getOrCreate()
sdf = spark.read.csv('../data/imputation.txt', sep='\t', inferSchema=True, header=True)

Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/06/15 12:16:24 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
sdf.show()
sdf.describe().printSchema()

+---+-------+-----+
| id|   time|value|
+---+-------+-----+
|  0|0:00:00|  0.0|
|  1|1:00:00|  1.0|
|  2|2:00:00|  NaN|
|  3|3:00:00|  NaN|
|  4|4:00:00|  4.0|
|  5|5:00:00|  5.0|
|  6|6:00:00|  NaN|
|  7|7:00:00|  7.0|
|  8|8:00:00|  8.0|
|  9|9:00:00|  9.0|
+---+-------+-----+

root
 |-- summary: string (nullable = true)
 |-- id: string (nullable = true)
 |-- time: string (nullable = true)
 |-- value: string (nullable = true)



In [3]:
from pyspark.sql.functions import mean as _mean, col
def replace_nan_with_mean(sdf, columns):
    sdf_dropped = sdf.dropna(how ='any')
    sdf_stats = dict()
    for column in columns:
        if isinstance(sdf_dropped.schema[column].dataType, DoubleType):
            sdf_stats = sdf_dropped.select(_mean(col(column)).alias('mean')).collect()
            mean = round(sdf_stats[0]['mean'])
            print(column)
            sdf_filled = sdf.fillna(mean, subset=[column])
    
    sdf_filled.show()
    

In [4]:
replace_nan_with_mean(sdf, columns=['id', 'time', 'value'])

value
+---+-------+-----+
| id|   time|value|
+---+-------+-----+
|  0|0:00:00|  0.0|
|  1|1:00:00|  1.0|
|  2|2:00:00|  5.0|
|  3|3:00:00|  5.0|
|  4|4:00:00|  4.0|
|  5|5:00:00|  5.0|
|  6|6:00:00|  5.0|
|  7|7:00:00|  7.0|
|  8|8:00:00|  8.0|
|  9|9:00:00|  9.0|
+---+-------+-----+



In [5]:
print(list(sdf.select('value').toPandas()['value']))


[0.0, 1.0, nan, nan, 4.0, 5.0, nan, 7.0, 8.0, 9.0]


In [16]:
from pyspark.sql.functions import mean as _mean, col
def run_impute(sdf, columns, methods):
    sdf_dropped = sdf.dropna(how ='any')
    for column, method in zip(columns, methods): 
        output = 0
        if method == 'mean':
            mean = lambda c: round(sum(c)/len(c))
            c = []
            c = list(sdf_dropped.select(column).toPandas()[column])
            output = mean(c)
        elif method == 'median':
            output = sdf_dropped.approxQuantile(column, [0.5], 0.25)[0]
        else:
            output == 0 
        
        sdf_filled = sdf.fillna(output, subset=[column])
        sdf_filled.show()
    
 
    
run_impute(sdf, columns=['value', 'value'], methods=['mean', 'median'])

4.857142857142857
+---+-------+-----+
| id|   time|value|
+---+-------+-----+
|  0|0:00:00|  0.0|
|  1|1:00:00|  1.0|
|  2|2:00:00|  5.0|
|  3|3:00:00|  5.0|
|  4|4:00:00|  4.0|
|  5|5:00:00|  5.0|
|  6|6:00:00|  5.0|
|  7|7:00:00|  7.0|
|  8|8:00:00|  8.0|
|  9|9:00:00|  9.0|
+---+-------+-----+

+---+-------+-----+
| id|   time|value|
+---+-------+-----+
|  0|0:00:00|  0.0|
|  1|1:00:00|  1.0|
|  2|2:00:00|  4.0|
|  3|3:00:00|  4.0|
|  4|4:00:00|  4.0|
|  5|5:00:00|  5.0|
|  6|6:00:00|  4.0|
|  7|7:00:00|  7.0|
|  8|8:00:00|  8.0|
|  9|9:00:00|  9.0|
+---+-------+-----+

