# Spark data frames from CSV files: handling headers & column types

### Source
https://www.nodalpoint.com/spark-data-frames-from-csv-files-handling-headers-column-types/

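啟動方式：Spark 1.x 用 `IPYTHON_OPTS="notebook" $SPARK_HOME/bin/pyspark`；Spark 2.0 起 `IPYTHON_OPTS` 已移除，改用 `PYSPARK_DRIVER_PYTHON=jupyter PYSPARK_DRIVER_PYTHON_OPTS=notebook $SPARK_HOME/bin/pyspark`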

In [1]:
from pyspark import SparkConf, SparkContext
# "local" is the master URL (run Spark in-process on this machine);
# "test App" is the application name.
# getOrCreate() reuses a SparkContext that already exists -- e.g. the `sc`
# predefined when the notebook is launched through bin/pyspark, or one left over
# from a previous run of this cell -- instead of failing with
# "Cannot run multiple SparkContexts at once".
sc = SparkContext.getOrCreate(SparkConf().setMaster("local").setAppName("test App"))

from pyspark.sql import SQLContext
from pyspark.sql.types import *
sqlContext = SQLContext(sc)

In [2]:
taxiFile = sc.textFile("nyctaxisub.csv")
taxiFile.count() # number of lines in the file, including the header line

250000

In [3]:
# sc.textFile() gives a plain RDD whose elements are the raw text lines
type(taxiFile)

pyspark.rdd.RDD

In [4]:
taxiFile.take(5) # first 5 lines as a Python list; the first one is the header

['"_id","_rev","dropoff_datetime","dropoff_latitude","dropoff_longitude","hack_license","medallion","passenger_count","pickup_datetime","pickup_latitude","pickup_longitude","rate_code","store_and_fwd_flag","trip_distance","trip_time_in_secs","vendor_id"',
 '"29b3f4a30dea6688d4c289c9672cb996","1-ddfdec8050c7ef4dc694eeeda6c4625e","2013-01-11 22:03:00",+4.07033460000000E+001,-7.40144200000000E+001,"A93D1F7F8998FFB75EEF477EB6077516","68BC16A99E915E44ADA7E639B4DD5F59",2,"2013-01-11 21:48:00",+4.06760670000000E+001,-7.39810790000000E+001,1,,+4.08000000000000E+000,900,"VTS"',
 '"2a80cfaa425dcec0861e02ae44354500","1-b72234b58a7b0018a1ec5d2ea0797e32","2013-01-11 04:28:00",+4.08190960000000E+001,-7.39467470000000E+001,"64CE1B03FDE343BB8DFB512123A525A4","60150AA39B2F654ED6F0C3AF8174A48A",1,"2013-01-11 04:07:00",+4.07280540000000E+001,-7.40020370000000E+001,1,,+8.53000000000000E+000,1260,"VTS"',
 '"29b3f4a30dea6688d4c289c96758d87e","1-387ec30eac5abda89d2abefdf947b2c1","2013-01-11 22:02:00",+4.0727

想要知道每個欄位的名字，所以取出 header

In [5]:
header = taxiFile.first() # first line of the file: the quoted, comma-separated column names
header

'"_id","_rev","dropoff_datetime","dropoff_latitude","dropoff_longitude","hack_license","medallion","passenger_count","pickup_datetime","pickup_latitude","pickup_longitude","rate_code","store_and_fwd_flag","trip_distance","trip_time_in_secs","vendor_id"'

In [6]:
# first() returns the line itself as a plain Python str, not an RDD
type(header)

str

因為 header 的每個元素都還有雙引號括起來，所以要去掉雙引號的部分

In [7]:
# Every column name in the header is wrapped in double quotes; delete all of them.
schemaString = header.translate({ord('"'): None})
schemaString

'_id,_rev,dropoff_datetime,dropoff_latitude,dropoff_longitude,hack_license,medallion,passenger_count,pickup_datetime,pickup_latitude,pickup_longitude,rate_code,store_and_fwd_flag,trip_distance,trip_time_in_secs,vendor_id'

去掉雙引號之後就可以拿來做 schema 了，先全部當成 `StringType`

In [8]:
# One nullable StringType field per header name for now; the numeric and
# timestamp columns get their real types afterwards.
fields = [StructField(column, StringType(), nullable=True)
          for column in schemaString.split(",")]
fields

[StructField(_id,StringType,true),
 StructField(_rev,StringType,true),
 StructField(dropoff_datetime,StringType,true),
 StructField(dropoff_latitude,StringType,true),
 StructField(dropoff_longitude,StringType,true),
 StructField(hack_license,StringType,true),
 StructField(medallion,StringType,true),
 StructField(passenger_count,StringType,true),
 StructField(pickup_datetime,StringType,true),
 StructField(pickup_latitude,StringType,true),
 StructField(pickup_longitude,StringType,true),
 StructField(rate_code,StringType,true),
 StructField(store_and_fwd_flag,StringType,true),
 StructField(trip_distance,StringType,true),
 StructField(trip_time_in_secs,StringType,true),
 StructField(vendor_id,StringType,true)]

In [9]:
len(fields) # number of columns in the header

16

In [10]:
# Give the columns that are not really strings their proper types.
# Columns are looked up by name rather than by list position, so the types
# cannot silently land on the wrong column if the CSV column order changes.
# NOTE(review): FloatType is single precision -- a CSV value such as
# 40.7277180 comes back as 40.727718353271484; use DoubleType if that matters.
column_types = {
    "dropoff_datetime": TimestampType(),
    "dropoff_latitude": FloatType(),
    "dropoff_longitude": FloatType(),
    "passenger_count": IntegerType(),
    "pickup_datetime": TimestampType(),
    "pickup_latitude": FloatType(),
    "pickup_longitude": FloatType(),
    "rate_code": IntegerType(),
    "trip_distance": FloatType(),
    "trip_time_in_secs": IntegerType(),
}
missing_columns = set(column_types) - {field.name for field in fields}
assert not missing_columns, "columns not found in header: %s" % sorted(missing_columns)
for field in fields:
    if field.name in column_types:
        field.dataType = column_types[field.name]

fields

[StructField(_id,StringType,true),
 StructField(_rev,StringType,true),
 StructField(dropoff_datetime,TimestampType,true),
 StructField(dropoff_latitude,FloatType,true),
 StructField(dropoff_longitude,FloatType,true),
 StructField(hack_license,StringType,true),
 StructField(medallion,StringType,true),
 StructField(passenger_count,IntegerType,true),
 StructField(pickup_datetime,TimestampType,true),
 StructField(pickup_latitude,FloatType,true),
 StructField(pickup_longitude,FloatType,true),
 StructField(rate_code,IntegerType,true),
 StructField(store_and_fwd_flag,StringType,true),
 StructField(trip_distance,FloatType,true),
 StructField(trip_time_in_secs,IntegerType,true),
 StructField(vendor_id,StringType,true)]

In [11]:
# Drop the leading underscore from "_id" and "_rev". Fields are matched by name,
# not by position, so this keeps working if the column order changes, and
# re-running the cell is harmless (already-renamed fields no longer match).
renamed_columns = {"_id": "id", "_rev": "rev"}
for field in fields:
    field.name = renamed_columns.get(field.name, field.name)

fields

[StructField(id,StringType,true),
 StructField(rev,StringType,true),
 StructField(dropoff_datetime,TimestampType,true),
 StructField(dropoff_latitude,FloatType,true),
 StructField(dropoff_longitude,FloatType,true),
 StructField(hack_license,StringType,true),
 StructField(medallion,StringType,true),
 StructField(passenger_count,IntegerType,true),
 StructField(pickup_datetime,TimestampType,true),
 StructField(pickup_latitude,FloatType,true),
 StructField(pickup_longitude,FloatType,true),
 StructField(rate_code,IntegerType,true),
 StructField(store_and_fwd_flag,StringType,true),
 StructField(trip_distance,FloatType,true),
 StructField(trip_time_in_secs,IntegerType,true),
 StructField(vendor_id,StringType,true)]

In [12]:
# Build the DataFrame schema from the retyped and renamed field list
schema = StructType(fields)
schema

StructType(List(StructField(id,StringType,true),StructField(rev,StringType,true),StructField(dropoff_datetime,TimestampType,true),StructField(dropoff_latitude,FloatType,true),StructField(dropoff_longitude,FloatType,true),StructField(hack_license,StringType,true),StructField(medallion,StringType,true),StructField(passenger_count,IntegerType,true),StructField(pickup_datetime,TimestampType,true),StructField(pickup_latitude,FloatType,true),StructField(pickup_longitude,FloatType,true),StructField(rate_code,IntegerType,true),StructField(store_and_fwd_flag,StringType,true),StructField(trip_distance,FloatType,true),StructField(trip_time_in_secs,IntegerType,true),StructField(vendor_id,StringType,true)))

In [13]:
# Isolate the header line as a one-element RDD so the next cell can remove it
# with RDD.subtract(), which takes an RDD rather than a plain str. The `header`
# string can still be used inside the lambda (Spark ships it with the closure).
# Matching the whole line exactly is safer than the substring test
# `"_id" in l`, which would also catch any data row that contains "_id".
taxiHeader = taxiFile.filter(lambda line: line == header)
taxiHeader.collect()

['"_id","_rev","dropoff_datetime","dropoff_latitude","dropoff_longitude","hack_license","medallion","passenger_count","pickup_datetime","pickup_latitude","pickup_longitude","rate_code","store_and_fwd_flag","trip_distance","trip_time_in_secs","vendor_id"']

In [14]:
# Every line except the header (250000 - 1 = 249999 rows).
# subtract() is a shuffle-based operation, so the result is not in file order:
# the row returned by take(1) further down is not the file's first data line.
# taxiFile.filter(lambda l: l != header) would drop the header without a
# shuffle and keep the original order.
taxiNoHeader = taxiFile.subtract(taxiHeader) # all rows except the header
taxiNoHeader.count()

249999

In [15]:
from datetime import *  # NOTE(review): no visible cell uses a name from this star import; consider removing it
from dateutil.parser import parse
parse("2013-02-09 18:16:10") # sanity check: parse() handles the "YYYY-MM-DD HH:MM:SS" timestamps in the CSV

datetime.datetime(2013, 2, 9, 18, 16, 10)

In [16]:
# Inspect one raw data line before parsing it
taxiNoHeader.take(1)

['"29b3f4a30dea6688d4c289c96758d87e","1-387ec30eac5abda89d2abefdf947b2c1","2013-01-11 22:02:00",+4.07277180000000E+001,-7.39942860000000E+001,"2D73B0C44F1699C67AB8AE322433BDB7","6F907BC9A85B7034C8418A24A0A75489",5,"2013-01-11 21:46:00",+4.07577480000000E+001,-7.39649810000000E+001,1,,+3.01000000000000E+000,960,"VTS"']

In [17]:
# First attempt: split each line on "," and convert each column to the Python
# type its schema field expects. Only the two timestamp columns have their
# surrounding double quotes stripped before being parsed; the other string
# columns keep them (e.g. '"VTS"' in the output).
taxi_temp = taxiNoHeader.map(lambda k: k.split(","))\
                        .map(lambda p: (p[0], p[1], parse(p[2].strip('"')), float(p[3]),
                                        float(p[4]), p[5], p[6], int(p[7]),
                                        parse(p[8].strip('"')), float(p[9]), float(p[10]), int(p[11]),
                                        p[12], float(p[13]), int(p[14]), p[15]
                                       ))
# top(2) returns the 2 *largest* rows (tuples compare element-wise, so the
# highest ids come first), not the first 2 rows; use taxi_temp.take(2) for those.
taxi_temp.top(2)

[('"fff43e5eb5662eecf42a3f9b5ff42214"',
  '"1-2e9ea2f49a29663d699d1940f42fab66"',
  datetime.datetime(2013, 11, 26, 13, 15),
  40.764915,
  -73.982536,
  '"564F38A1BC4B1AA7EC528E6C2C81EAAC"',
  '"3E29713986A6762D985C4FC53B177F61"',
  1,
  datetime.datetime(2013, 11, 26, 13, 2),
  40.786667,
  -73.972023,
  1,
  '',
  1.87,
  780,
  '"VTS"'),
 ('"fff43e5eb5662eecf42a3f9b5ff1fc5b"',
  '"1-18b010dab3a3f83ebf4b9f31e88c615d"',
  datetime.datetime(2013, 11, 26, 3, 59),
  40.686081,
  -73.952072,
  '"5E3208C5FA0E44EA08223489E3853EAD"',
  '"DC67FC4851D7642EDCA34A8A3C44F116"',
  1,
  datetime.datetime(2013, 11, 26, 3, 42),
  40.740715,
  -74.004562,
  1,
  '',
  5.84,
  1020,
  '"VTS"')]

In [18]:
# map() on an RDD yields a PipelinedRDD -- still an RDD, not yet a DataFrame
type(taxi_temp)

pyspark.rdd.PipelinedRDD

In [19]:
# Turn the RDD of tuples (taxi_temp) into a DataFrame with the explicit schema.
# The string columns still contain their CSV double quotes at this point.
taxi_df = sqlContext.createDataFrame(taxi_temp, schema=schema)
taxi_df.head(10)

[Row(id='"29b3f4a30dea6688d4c289c96758d87e"', rev='"1-387ec30eac5abda89d2abefdf947b2c1"', dropoff_datetime=datetime.datetime(2013, 1, 11, 22, 2), dropoff_latitude=40.727718353271484, dropoff_longitude=-73.9942855834961, hack_license='"2D73B0C44F1699C67AB8AE322433BDB7"', medallion='"6F907BC9A85B7034C8418A24A0A75489"', passenger_count=5, pickup_datetime=datetime.datetime(2013, 1, 11, 21, 46), pickup_latitude=40.757747650146484, pickup_longitude=-73.96498107910156, rate_code=1, store_and_fwd_flag='', trip_distance=3.009999990463257, trip_time_in_secs=960, vendor_id='"VTS"'),
 Row(id='"c635d7759ccce57b89417432052af020"', rev='"1-87d72b8d964c0bc4804c9ef14ac1a72f"', dropoff_datetime=datetime.datetime(2013, 2, 11, 17, 6, 2), dropoff_latitude=40.74830627441406, dropoff_longitude=-73.9885025024414, hack_license='"47E3BB27C9CD8B519C8B29E92F4386C6"', medallion='"E22AB8679FCC7FBCB81BE56FD5BB2C3C"', passenger_count=1, pickup_datetime=datetime.datetime(2013, 2, 11, 16, 52, 48), pickup_latitude=40.76

In [20]:
# The DataFrame above still has the CSV double quotes inside its string values
# (e.g. '"VTS"'). Instead of calling .strip('"') on every string column, parse
# each line with the standard csv module: it removes the enclosing quotes and,
# unlike split(","), also copes with commas inside quoted fields.
import csv


def parse_taxi_line(line):
    """Convert one raw data line of the taxi CSV into a tuple matching `schema`.

    The two timestamp columns are parsed with dateutil's `parse`; the
    latitude/longitude/trip_distance columns become float, and
    passenger_count/rate_code/trip_time_in_secs become int. Empty fields
    (such as an empty store_and_fwd_flag) stay as the empty string.
    """
    p = next(csv.reader([line]))
    return (p[0], p[1], parse(p[2]), float(p[3]),
            float(p[4]), p[5], p[6], int(p[7]),
            parse(p[8]), float(p[9]), float(p[10]), int(p[11]),
            p[12], float(p[13]), int(p[14]), p[15])


taxi_df = taxiNoHeader.map(parse_taxi_line).toDF(schema)  # rdd.toDF() builds the DataFrame directly
taxi_df.head(10)

[Row(id='29b3f4a30dea6688d4c289c96758d87e', rev='1-387ec30eac5abda89d2abefdf947b2c1', dropoff_datetime=datetime.datetime(2013, 1, 11, 22, 2), dropoff_latitude=40.727718353271484, dropoff_longitude=-73.9942855834961, hack_license='2D73B0C44F1699C67AB8AE322433BDB7', medallion='6F907BC9A85B7034C8418A24A0A75489', passenger_count=5, pickup_datetime=datetime.datetime(2013, 1, 11, 21, 46), pickup_latitude=40.757747650146484, pickup_longitude=-73.96498107910156, rate_code=1, store_and_fwd_flag='', trip_distance=3.009999990463257, trip_time_in_secs=960, vendor_id='VTS'),
 Row(id='c635d7759ccce57b89417432052af020', rev='1-87d72b8d964c0bc4804c9ef14ac1a72f', dropoff_datetime=datetime.datetime(2013, 2, 11, 17, 6, 2), dropoff_latitude=40.74830627441406, dropoff_longitude=-73.9885025024414, hack_license='47E3BB27C9CD8B519C8B29E92F4386C6', medallion='E22AB8679FCC7FBCB81BE56FD5BB2C3C', passenger_count=1, pickup_datetime=datetime.datetime(2013, 2, 11, 16, 52, 48), pickup_latitude=40.76223373413086, pick

In [21]:
# Number of trips per vendor
(taxi_df
    .groupBy("vendor_id")
    .count()
    .show())

+---------+------+
|vendor_id| count|
+---------+------+
|      CMT|114387|
|      VTS|135612|
+---------+------+



In [22]:
# Rows whose store_and_fwd_flag is missing (the CSV leaves the field empty, so it is '')
taxi_df.where(taxi_df["store_and_fwd_flag"] == '').count()

135616

In [23]:
# How many of the empty flags belong to vendor VTS (the same count as all VTS rows per the vendor counts)
taxi_df.where((taxi_df["store_and_fwd_flag"] == '') & (taxi_df["vendor_id"] == 'VTS')).count()

135612

In [24]:
taxi_df.dtypes # (column name, type name) pair for every column

[('id', 'string'),
 ('rev', 'string'),
 ('dropoff_datetime', 'timestamp'),
 ('dropoff_latitude', 'float'),
 ('dropoff_longitude', 'float'),
 ('hack_license', 'string'),
 ('medallion', 'string'),
 ('passenger_count', 'int'),
 ('pickup_datetime', 'timestamp'),
 ('pickup_latitude', 'float'),
 ('pickup_longitude', 'float'),
 ('rate_code', 'int'),
 ('store_and_fwd_flag', 'string'),
 ('trip_distance', 'float'),
 ('trip_time_in_secs', 'int'),
 ('vendor_id', 'string')]

In [25]:
# Tree view of the schema, including each column's nullability
taxi_df.printSchema()

root
 |-- id: string (nullable = true)
 |-- rev: string (nullable = true)
 |-- dropoff_datetime: timestamp (nullable = true)
 |-- dropoff_latitude: float (nullable = true)
 |-- dropoff_longitude: float (nullable = true)
 |-- hack_license: string (nullable = true)
 |-- medallion: string (nullable = true)
 |-- passenger_count: integer (nullable = true)
 |-- pickup_datetime: timestamp (nullable = true)
 |-- pickup_latitude: float (nullable = true)
 |-- pickup_longitude: float (nullable = true)
 |-- rate_code: integer (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- trip_distance: float (nullable = true)
 |-- trip_time_in_secs: integer (nullable = true)
 |-- vendor_id: string (nullable = true)



In [26]:
# Register the DataFrame as the temporary view "taxi" so it can be queried with SQL.
# createOrReplaceTempView replaces registerTempTable, deprecated since Spark 2.0.
# Requires Spark 2.0+; on Spark 1.x use taxi_df.registerTempTable("taxi") instead.
taxi_df.createOrReplaceTempView("taxi")
sqlContext.sql("SELECT vendor_id, COUNT(*) FROM taxi GROUP BY vendor_id").show()

+---------+--------+
|vendor_id|count(1)|
+---------+--------+
|      CMT|  114387|
|      VTS|  135612|
+---------+--------+



In [27]:
# SQL version of the empty store_and_fwd_flag count (mind the exact column name)
(sqlContext
    .sql("SELECT COUNT(*) FROM taxi WHERE store_and_fwd_flag =''")
    .show())

+--------+
|count(1)|
+--------+
|  135616|
+--------+



In [28]:
# SQL version of the VTS + empty store_and_fwd_flag count
(sqlContext
    .sql("SELECT COUNT(*) FROM taxi WHERE vendor_id = 'VTS' AND store_and_fwd_flag = ''")
    .show())

+--------+
|count(1)|
+--------+
|  135612|
+--------+



In [29]:
# Shorten the latitude/longitude column names. withColumnRenamed returns a new
# DataFrame each time, so the calls are chained and the result is reassigned.
taxi_df = (
    taxi_df
    .withColumnRenamed('dropoff_longitude', 'dropoff_long')
    .withColumnRenamed('dropoff_latitude', 'dropoff_lat')
    .withColumnRenamed('pickup_latitude', 'pickup_lat')
    .withColumnRenamed('pickup_longitude', 'pickup_long')
)

taxi_df.dtypes

[('id', 'string'),
 ('rev', 'string'),
 ('dropoff_datetime', 'timestamp'),
 ('dropoff_lat', 'float'),
 ('dropoff_long', 'float'),
 ('hack_license', 'string'),
 ('medallion', 'string'),
 ('passenger_count', 'int'),
 ('pickup_datetime', 'timestamp'),
 ('pickup_lat', 'float'),
 ('pickup_long', 'float'),
 ('rate_code', 'int'),
 ('store_and_fwd_flag', 'string'),
 ('trip_distance', 'float'),
 ('trip_time_in_secs', 'int'),
 ('vendor_id', 'string')]

In [30]:
import pandas as pd

# Keep only CMT trips that have a store_and_fwd_flag value, then collect them
# into a pandas DataFrame. toPandas() brings every matching row to the driver,
# so filter first.
taxi_CMT = (
    taxi_df
    .where("vendor_id = 'CMT' and store_and_fwd_flag != ''")
    .toPandas()
)
taxi_CMT.head()

Unnamed: 0,id,rev,dropoff_datetime,dropoff_lat,dropoff_long,hack_license,medallion,passenger_count,pickup_datetime,pickup_lat,pickup_long,rate_code,store_and_fwd_flag,trip_distance,trip_time_in_secs,vendor_id
0,c635d7759ccce57b89417432052af020,1-87d72b8d964c0bc4804c9ef14ac1a72f,2013-02-11 17:06:02,40.748306,-73.988503,47E3BB27C9CD8B519C8B29E92F4386C6,E22AB8679FCC7FBCB81BE56FD5BB2C3C,1,2013-02-11 16:52:48,40.762234,-73.973122,1,N,1.8,793,CMT
1,c7a046238400159b6d2f7dbbce76a55a,1-251e553dd840296c5f4c89f7a24bdc67,2013-11-26 15:46:53,40.781536,-73.981041,DEB3CA92C327A5EEE5355FCC61DD0800,0AD9AEDF8965687A1B8AE575EDD96774,1,2013-11-26 15:26:55,40.763351,-73.970451,1,N,2.8,1198,CMT
2,cf56894cd0bedc202dbdd97351b8f4cc,1-f58fda287b6413bd4aed319125ade266,2013-11-26 08:28:07,40.748169,-74.003952,ABA6223933ED3C6761DBA50C8834D86C,35F0522181AE560EC9817B4B12B60EEE,2,2013-11-26 08:15:06,40.729084,-73.993782,1,N,2.1,781,CMT
3,5b6168e8c48445dba0b4c09500d8b8a0,1-f61c3e2b6419b12c7e14f36c743bac3f,2013-01-11 11:21:36,40.757236,-73.975632,9F129AB93B1551856B390F284853FE87,BCDEAD3783CF34DA6D85762CECDA81FE,1,2013-01-11 11:08:07,40.732456,-73.981644,1,N,2.1,809,CMT
4,f0a7675beb17a8b2400ce09028b680d4,1-a35999d5e34a702a09ab313a3db70886,2013-11-26 12:35:42,40.767132,-73.989899,8F327F7849789A57F0DEE7256B045D4F,429AD5759D9F317802F3A6C5F138E7B9,1,2013-11-26 12:22:48,40.763004,-73.973801,1,N,1.0,774,CMT


In [31]:
# toPandas() returned an ordinary pandas DataFrame
type(taxi_CMT)

pandas.core.frame.DataFrame