# Spark DataFrames in Python

## Jupyter essentials

#### To display keyboard shortcuts, press Esc and then H

#### There are two important types of cells in Jupyter: code and markdown
#### Code is the default: whatever you type in a code cell is evaluated as code
#### Markdown cells hold text labels that help organize the notebook

#### You can execute a cell by clicking the "run cell" button or by pressing Shift+Enter

### Start the Spark application and check that the SparkSession object is available

In [1]:
spark

### Creating a DataFrame from a Hive table

In [2]:
spark.read.table("tomek.taxi_cleaned")

DataFrame[unique_key: string, taxi_id: string, trip_start_timestamp: string, trip_end_timestamp: string, trip_seconds: bigint, trip_miles: float, pickup_census_tract: bigint, dropoff_census_tract: bigint, pickup_community_area: bigint, dropoff_community_area: bigint, fare: float, tips: float, tolls: float, extras: float, trip_total: float, payment_type: string, company: string, yyyymm: string]

In [3]:
spark.sql("""
    SELECT
        trip_start_timestamp,
        payment_type,
        fare
    FROM tomek.taxi_cleaned
    WHERE CAST(extras AS INT) > 0
""").show(3)

+--------------------+------------+------+
|trip_start_timestamp|payment_type|  fare|
+--------------------+------------+------+
|2017-05-22 12:00:...|        Cash|1000.0|
|2017-05-24 12:45:...|        Cash| 800.0|
|2017-05-04 09:30:...| Credit Card|4750.0|
+--------------------+------------+------+
only showing top 3 rows



In [4]:
import numpy as np

In [5]:
month_to_analyze = "{yyyy}{mm}".format(yyyy=str(np.random.choice(range(2013,2017))),
                                       mm="{:02d}".format(np.random.choice(range(12))+1))

In [6]:
print(month_to_analyze)

201407
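The cell above picks a year from `range(2013, 2017)`, which covers 2013 through 2016, and a month from 1 to 12. The sketch below reproduces that choice with the standard `random` module. It takes a seeded generator so that a rerun analyzes the same month. The function name `random_yyyymm` is hypothetical.

```python
import random

def random_yyyymm(rng, first_year=2013, last_year=2016):
    """Return a random month as a 'YYYYMM' string, e.g. '201407'."""
    year = rng.randint(first_year, last_year)  # randint includes both ends
    month = rng.randint(1, 12)
    return f"{year}{month:02d}"  # zero-pad the month to two digits

rng = random.Random(42)  # fixed seed so the choice is reproducible
month_to_analyze = random_yyyymm(rng)
print(month_to_analyze)
```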


In [7]:
taxi = spark.sql("""
    SELECT *
    FROM tomek.taxi_cleaned
    WHERE yyyymm = '{yyyymm}'
""".format(yyyymm=month_to_analyze))

In [8]:
taxi.show(3)

+--------------------+--------------------+--------------------+--------------------+------------+----------+-------------------+--------------------+---------------------+----------------------+------+----+-----+------+----------+------------+--------------------+------+
|          unique_key|             taxi_id|trip_start_timestamp|  trip_end_timestamp|trip_seconds|trip_miles|pickup_census_tract|dropoff_census_tract|pickup_community_area|dropoff_community_area|  fare|tips|tolls|extras|trip_total|payment_type|             company|yyyymm|
+--------------------+--------------------+--------------------+--------------------+------------+----------+-------------------+--------------------+---------------------+----------------------+------+----+-----+------+----------+------------+--------------------+------+
|7bf2b8b50e96833e8...|2e92587d66f35bf32...|2014-07-21 11:15:...|2014-07-21 11:15:...|         180|      0.05|        17031062100|         17031062000|                    6|         

### Creating a DataFrame from a CSV file

In [9]:
fromcsv = spark.read.csv(
    "gs://dataproc-43f8dac5-83ec-43bc-9165-c8d872ee1626-us-central1/notebooks/tomek/test.csv",
    header=True,
    inferSchema=True,
    mode="DROPMALFORMED",
    sep="|"
)
fromcsv.show(3)

+----+---+---+---+
|name|  g|  a|  s|
+----+---+---+---+
|   a|  2|  3|  4|
|   b| 54|  2|  6|
|   c|  1|  1|  1|
+----+---+---+---+
only showing top 3 rows



In [10]:
! cat home/workshop/test.csv

cat: home/workshop/test.csv: No such file or directory


#### Spark lets you run standard SQL on DataFrames. To do so, you first have to register the DataFrame with Spark as a temporary view

In [11]:
fromcsv.createOrReplaceTempView("fromcsv")

#### Only once the DataFrame is registered as a temporary view can you run SQL queries against it

In [12]:
spark.sql("select * from fromcsv where s > 5").show(5)

+----+-----+----+---+
|name|    g|   a|  s|
+----+-----+----+---+
|   b|   54|   2|  6|
|   z|23451|1234|123|
+----+-----+----+---+



## Basic DataFrame Operations

#### Check what is in a DataFrame

In [13]:
taxi.show(3)

+--------------------+--------------------+--------------------+--------------------+------------+----------+-------------------+--------------------+---------------------+----------------------+------+----+-----+------+----------+------------+--------------------+------+
|          unique_key|             taxi_id|trip_start_timestamp|  trip_end_timestamp|trip_seconds|trip_miles|pickup_census_tract|dropoff_census_tract|pickup_community_area|dropoff_community_area|  fare|tips|tolls|extras|trip_total|payment_type|             company|yyyymm|
+--------------------+--------------------+--------------------+--------------------+------------+----------+-------------------+--------------------+---------------------+----------------------+------+----+-----+------+----------+------------+--------------------+------+
|7bf2b8b50e96833e8...|2e92587d66f35bf32...|2014-07-21 11:15:...|2014-07-21 11:15:...|         180|      0.05|        17031062100|         17031062000|                    6|         

In [14]:
taxi.select("payment_type","company").show(3)

+------------+--------------------+
|payment_type|             company|
+------------+--------------------+
|        Cash|                    |
|        Cash|                    |
|        Cash|Taxi Affiliation ...|
+------------+--------------------+
only showing top 3 rows



#### Check the DataFrame schema

In [15]:
taxi.printSchema()

root
 |-- unique_key: string (nullable = true)
 |-- taxi_id: string (nullable = true)
 |-- trip_start_timestamp: string (nullable = true)
 |-- trip_end_timestamp: string (nullable = true)
 |-- trip_seconds: long (nullable = true)
 |-- trip_miles: float (nullable = true)
 |-- pickup_census_tract: long (nullable = true)
 |-- dropoff_census_tract: long (nullable = true)
 |-- pickup_community_area: long (nullable = true)
 |-- dropoff_community_area: long (nullable = true)
 |-- fare: float (nullable = true)
 |-- tips: float (nullable = true)
 |-- tolls: float (nullable = true)
 |-- extras: float (nullable = true)
 |-- trip_total: float (nullable = true)
 |-- payment_type: string (nullable = true)
 |-- company: string (nullable = true)
 |-- yyyymm: string (nullable = true)



### Apply a function

In [16]:
from pyspark.sql.functions import col,lower

#### How to reference a column

In [17]:
taxi.select(lower(taxi.company)).show(3)

+--------------------+
|      lower(company)|
+--------------------+
|                    |
|                    |
|taxi affiliation ...|
+--------------------+
only showing top 3 rows



#### Another way to reference a column, plus renaming the output column

In [18]:
taxi.select(lower(col("company")).alias("lowered_company")).show(3)

+--------------------+
|     lowered_company|
+--------------------+
|                    |
|                    |
|taxi affiliation ...|
+--------------------+
only showing top 3 rows



### Add a new column

In [19]:
taxi = taxi.withColumn('has_tolls', col("tolls") > 0)
taxi.select("has_tolls").show(3)

+---------+
|has_tolls|
+---------+
|    false|
|    false|
|    false|
+---------+
only showing top 3 rows



In [20]:
taxi.printSchema()

root
 |-- unique_key: string (nullable = true)
 |-- taxi_id: string (nullable = true)
 |-- trip_start_timestamp: string (nullable = true)
 |-- trip_end_timestamp: string (nullable = true)
 |-- trip_seconds: long (nullable = true)
 |-- trip_miles: float (nullable = true)
 |-- pickup_census_tract: long (nullable = true)
 |-- dropoff_census_tract: long (nullable = true)
 |-- pickup_community_area: long (nullable = true)
 |-- dropoff_community_area: long (nullable = true)
 |-- fare: float (nullable = true)
 |-- tips: float (nullable = true)
 |-- tolls: float (nullable = true)
 |-- extras: float (nullable = true)
 |-- trip_total: float (nullable = true)
 |-- payment_type: string (nullable = true)
 |-- company: string (nullable = true)
 |-- yyyymm: string (nullable = true)
 |-- has_tolls: boolean (nullable = true)



In [21]:
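# do the same using standard SQL
# note: this reads the whole table (no yyyymm filter), so `taxi` now holds every month,
# and the new column is named has_tools (a misspelling of has_tolls) here and in later cells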
taxi = spark.sql("""
    SELECT *, case when tolls > 0 then TRUE else FALSE end as has_tools
    FROM tomek.taxi_cleaned
""")

In [22]:
taxi.printSchema()

root
 |-- unique_key: string (nullable = true)
 |-- taxi_id: string (nullable = true)
 |-- trip_start_timestamp: string (nullable = true)
 |-- trip_end_timestamp: string (nullable = true)
 |-- trip_seconds: long (nullable = true)
 |-- trip_miles: float (nullable = true)
 |-- pickup_census_tract: long (nullable = true)
 |-- dropoff_census_tract: long (nullable = true)
 |-- pickup_community_area: long (nullable = true)
 |-- dropoff_community_area: long (nullable = true)
 |-- fare: float (nullable = true)
 |-- tips: float (nullable = true)
 |-- tolls: float (nullable = true)
 |-- extras: float (nullable = true)
 |-- trip_total: float (nullable = true)
 |-- payment_type: string (nullable = true)
 |-- company: string (nullable = true)
 |-- yyyymm: string (nullable = true)
 |-- has_tools: boolean (nullable = false)



### Rename a column

In [23]:
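# after the SQL cell above the column is named has_tools, so this finds no has_tolls column and silently does nothing
taxi = taxi.withColumnRenamed("has_tolls","tolls_paid")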

In [24]:
# withColumnRenamed does not warn when no column with the given name exists; it returns the DataFrame unchanged
taxi = taxi.withColumnRenamed("aaaaaaa","vvvvvvv")

In [25]:
taxi.printSchema()

root
 |-- unique_key: string (nullable = true)
 |-- taxi_id: string (nullable = true)
 |-- trip_start_timestamp: string (nullable = true)
 |-- trip_end_timestamp: string (nullable = true)
 |-- trip_seconds: long (nullable = true)
 |-- trip_miles: float (nullable = true)
 |-- pickup_census_tract: long (nullable = true)
 |-- dropoff_census_tract: long (nullable = true)
 |-- pickup_community_area: long (nullable = true)
 |-- dropoff_community_area: long (nullable = true)
 |-- fare: float (nullable = true)
 |-- tips: float (nullable = true)
 |-- tolls: float (nullable = true)
 |-- extras: float (nullable = true)
 |-- trip_total: float (nullable = true)
 |-- payment_type: string (nullable = true)
 |-- company: string (nullable = true)
 |-- yyyymm: string (nullable = true)
 |-- has_tools: boolean (nullable = false)



In [26]:
# do the same using standard sql
taxi.createOrReplaceTempView("taxi")
spark.sql("SELECT has_tools as tools_paid FROM taxi").show(3)

+----------+
|tools_paid|
+----------+
|     false|
|     false|
|     false|
+----------+
only showing top 3 rows



### Drop columns

In [27]:
taxi.drop('unique_key').drop('taxi_id').printSchema()

root
 |-- trip_start_timestamp: string (nullable = true)
 |-- trip_end_timestamp: string (nullable = true)
 |-- trip_seconds: long (nullable = true)
 |-- trip_miles: float (nullable = true)
 |-- pickup_census_tract: long (nullable = true)
 |-- dropoff_census_tract: long (nullable = true)
 |-- pickup_community_area: long (nullable = true)
 |-- dropoff_community_area: long (nullable = true)
 |-- fare: float (nullable = true)
 |-- tips: float (nullable = true)
 |-- tolls: float (nullable = true)
 |-- extras: float (nullable = true)
 |-- trip_total: float (nullable = true)
 |-- payment_type: string (nullable = true)
 |-- company: string (nullable = true)
 |-- yyyymm: string (nullable = true)
 |-- has_tools: boolean (nullable = false)



### Filter conditions

In [28]:
taxi.filter(taxi.payment_type == 'Cash').select("payment_type","fare").show(3)

+------------+------+
|payment_type|  fare|
+------------+------+
|        Cash|1550.0|
|        Cash| 950.0|
|        Cash| 500.0|
+------------+------+
only showing top 3 rows



In [29]:
taxi.filter("payment_type = 'Cash'").select("payment_type","fare").show(3)

+------------+------+
|payment_type|  fare|
+------------+------+
|        Cash|1225.0|
|        Cash| 845.0|
|        Cash|1065.0|
+------------+------+
only showing top 3 rows



In [30]:
# do the same using standard sql
spark.sql("""
    SELECT payment_type, fare
    FROM taxi
    WHERE payment_type = 'Cash'
""").show(3)

+------------+------+
|payment_type|  fare|
+------------+------+
|        Cash|1550.0|
|        Cash| 950.0|
|        Cash| 500.0|
+------------+------+
only showing top 3 rows



### Parentheses are required when combining multiple conditions, because Python's & and | bind more tightly than comparison operators

In [31]:
taxi.filter((taxi.payment_type == 'Cash') & (taxi.trip_miles > 30)).select("payment_type","trip_miles").show(3)

+------------+----------+
|payment_type|trip_miles|
+------------+----------+
|        Cash|      58.0|
|        Cash|      65.9|
|        Cash|      37.8|
+------------+----------+
only showing top 3 rows
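Python's own parser shows why the parentheses matter. Without them, `&` is applied first, so the expression becomes a chained comparison `payment_type == ('Cash' & trip_miles) > 30`. With Spark columns, such a chain needs the truth value of a Column, which PySpark rejects with an error. The sketch below only parses the two expressions with the standard `ast` module and does not need Spark.

```python
import ast

# Without parentheses: parsed as payment_type == ('Cash' & trip_miles) > 30.
unparenthesized = ast.parse("payment_type == 'Cash' & trip_miles > 30", mode="eval").body
# With parentheses: two comparisons combined by &, as intended.
parenthesized = ast.parse("(payment_type == 'Cash') & (trip_miles > 30)", mode="eval").body

print(type(unparenthesized).__name__)  # Compare
print(type(parenthesized).__name__)    # BinOp
```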



In [32]:
taxi.filter((taxi.payment_type == 'Cash') | (taxi.trip_miles > 30)).select("payment_type","trip_miles").show(3)

+------------+----------+
|payment_type|trip_miles|
+------------+----------+
|        Cash|       1.8|
|        Cash|       2.8|
|        Cash|       3.2|
+------------+----------+
only showing top 3 rows



### Aggregation

In [33]:
taxi.count()

78470890

In [34]:
# (optional) do the same using standard SQL
spark.sql("SELECT count(*) AS num_rows FROM taxi").show()

In [35]:
%%time
taxi.groupBy("payment_type").count().show()

+------------+--------+
|payment_type|   count|
+------------+--------+
| Credit Card|29757998|
|        Cash|48712892|
+------------+--------+

CPU times: user 8 ms, sys: 4 ms, total: 12 ms
Wall time: 14.7 s


In [36]:
%%time
spark.sql("""
    SELECT count(*)
    FROM taxi
    GROUP BY payment_type
""").show()

+--------+
|count(1)|
+--------+
|29757998|
|48712892|
+--------+

CPU times: user 0 ns, sys: 8 ms, total: 8 ms
Wall time: 7.87 s


In [37]:
%%time
taxi.groupBy("payment_type").avg("fare").show(5)

+------------+------------------+
|payment_type|         avg(fare)|
+------------+------------------+
| Credit Card|1549.7433185189407|
|        Cash|1172.3793519998771|
+------------+------------------+

CPU times: user 4 ms, sys: 4 ms, total: 8 ms
Wall time: 6.63 s


In [38]:
# (optional) do the same using standard SQL
spark.sql("""
    SELECT payment_type, avg(fare) AS avg_fare
    FROM taxi
    GROUP BY payment_type
""").show()

### Sorting

In [39]:
taxi.sort(taxi.trip_miles.desc()).select("trip_miles","trip_seconds","fare").show(3)

+----------+------------+------+
|trip_miles|trip_seconds|  fare|
+----------+------------+------+
|     100.0|         960|2185.0|
|     100.0|        1020|2205.0|
|     100.0|        1860|3975.0|
+----------+------------+------+
only showing top 3 rows



In [40]:
taxi.groupBy("company").count().sort(col("count").desc()).show(3)

+--------------------+--------+
|             company|   count|
+--------------------+--------+
|                    |39356942|
|Taxi Affiliation ...|18807000|
|Dispatch Taxi Aff...| 7538486|
+--------------------+--------+
only showing top 3 rows



### Multiple aggregates

In [41]:
from pyspark.sql.functions import min,max

In [42]:
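# The previous cell shadowed Python's built-in min and max with the Spark versions,
# so min([10, 2]) no longer returns 2 and fails instead
try:
    min([10, 2])
except Exception as err:
    print("built-in min is shadowed:", err)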

In [43]:
taxi\
    .groupBy("payment_type")\
    .agg(min("fare").alias("smallest"),
         max("fare").alias("biggest")).show(3)

+------------+--------+--------+
|payment_type|smallest| biggest|
+------------+--------+--------+
| Credit Card|     1.0| 85000.0|
|        Cash|     1.0|990045.0|
+------------+--------+--------+



## Joins

In [44]:
taxi_rides_per_taxi = taxi.groupBy("taxi_id").count().withColumnRenamed("count","num_rides")

In [45]:
taxi_rides_per_taxi.show(3)

+--------------------+---------+
|             taxi_id|num_rides|
+--------------------+---------+
|0e21caaf401f961b7...|    13753|
|bb7eb49d01457ba3d...|    18479|
|a8aee50b5b0787156...|    13245|
+--------------------+---------+
only showing top 3 rows



In [46]:
taxi.select('taxi_id','company').join(taxi_rides_per_taxi, 'taxi_id').show(3)

+--------------------+-------+---------+
|             taxi_id|company|num_rides|
+--------------------+-------+---------+
|0678f31a979c4f311...|       |      328|
|0678f31a979c4f311...|       |      328|
|0678f31a979c4f311...|       |      328|
+--------------------+-------+---------+
only showing top 3 rows



### Selecting the join type

In [47]:
taxi.select('taxi_id','company').join(taxi_rides_per_taxi.sample(False,0.01), 'taxi_id','left').show(3)

+--------------------+-------+---------+
|             taxi_id|company|num_rides|
+--------------------+-------+---------+
|0678f31a979c4f311...|       |     null|
|0678f31a979c4f311...|       |     null|
|0678f31a979c4f311...|       |     null|
+--------------------+-------+---------+
only showing top 3 rows



## Custom functions

#### A sample function that extracts the month from trip_start_timestamp

In [48]:
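def extract_month(value):
    # 'YYYY-MM-DD ...' -> 'MM'; pass nulls through instead of raising TypeError
    return value[5:7] if value is not None else None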

In [49]:
from pyspark.sql.types import StringType, IntegerType
from pyspark.sql.functions import udf
udf_extract_month = udf(extract_month,StringType())

taxi.select(udf_extract_month(taxi.trip_start_timestamp)).show(10)

+-----------------------------------+
|extract_month(trip_start_timestamp)|
+-----------------------------------+
|                                 05|
|                                 05|
|                                 05|
|                                 05|
|                                 05|
|                                 05|
|                                 05|
|                                 05|
|                                 05|
|                                 05|
+-----------------------------------+
only showing top 10 rows



#### Now prepare a function that maps payment type Cash to 1 and Credit Card to 0

In [50]:
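def payment_type(value):
    # Cash -> 1, Credit Card -> 0; any other value (or null) becomes null instead of raising KeyError
    mapping = {'Cash': 1, 'Credit Card': 0}
    return mapping.get(value)
udf_payment_type = udf(payment_type, IntegerType())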

In [51]:
taxi.select(udf_payment_type(taxi.payment_type)).show(10)

+--------------------------+
|payment_type(payment_type)|
+--------------------------+
|                         0|
|                         0|
|                         1|
|                         1|
|                         0|
|                         1|
|                         0|
|                         1|
|                         0|
|                         1|
+--------------------------+
only showing top 10 rows

