# Create a SQL table from a dataframe

A dataframe can be used to create a temporary table, which exists only until the session ends. The Spark documentation calls this registering the dataframe as a SQL temporary view. The command, createOrReplaceTempView, is called on the dataframe itself: it creates the table if it does not already exist, and otherwise replaces it with the current data from the dataframe.

In [1]:
from pyspark.sql import SparkSession

# Create a SparkSession
spark = SparkSession.builder \
    .appName("example") \
    .getOrCreate()

In [6]:
# Load trainsched.txt
df = spark.read.csv("dataset/trainsched.txt", header=True)

# Create temporary table called schedule
df.createOrReplaceTempView("schedule")

# Determine the column names of a table

It is important to be able to look up column names, because in practice relational tables are often provided without any documentation of the table schema.

In [7]:
# Inspect the columns in the table schedule
spark.sql("DESCRIBE schedule").show()

+--------+---------+-------+
|col_name|data_type|comment|
+--------+---------+-------+
|train_id|   string|   NULL|
| station|   string|   NULL|
|    time|   string|   NULL|
+--------+---------+-------+



# Running sums using window function SQL

A window function is like an aggregate function, except that it gives an output for every row in the dataset instead of a single row per group.

You can combine aggregation with window functions. A running sum written with a window function is simpler than the equivalent query written with joins, and it can also run much faster.

In [9]:
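# Intended: add col running_total that sums diff_min col in each group.
# Note: schedule has no diff_min column at this point, so the query sums the
# string column time instead; that is why running_total is NULL below.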
query = """
SELECT train_id, station, time, 
SUM(time) OVER (PARTITION BY train_id ORDER BY time) AS running_total
FROM schedule
"""

# Run the query and display the result
spark.sql(query).show()

+--------+-------------+-----+-------------+
|train_id|      station| time|running_total|
+--------+-------------+-----+-------------+
|     217|       Gilroy|6:06a|         NULL|
|     217|   San Martin|6:15a|         NULL|
|     217|  Morgan Hill|6:21a|         NULL|
|     217| Blossom Hill|6:36a|         NULL|
|     217|      Capitol|6:42a|         NULL|
|     217|       Tamien|6:50a|         NULL|
|     217|     San Jose|6:59a|         NULL|
|     324|San Francisco|7:59a|         NULL|
|     324|  22nd Street|8:03a|         NULL|
|     324|     Millbrae|8:16a|         NULL|
|     324|    Hillsdale|8:24a|         NULL|
|     324| Redwood City|8:31a|         NULL|
|     324|    Palo Alto|8:37a|         NULL|
|     324|     San Jose|9:05a|         NULL|
+--------+-------------+-----+-------------+
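
As a rough sketch of what a running sum looks like on a numeric column, the snippet below runs the same OVER clause against a diff_min column. Two assumptions are made for illustration only: SQLite (from the Python standard library) stands in for Spark so the snippet runs without a cluster, and the diff_min values for train 217 were worked out by hand from the times in the table above. It also contrasts the window function with a plain GROUP BY aggregate.

```python
import sqlite3

# Minutes to the next stop for train 217, computed by hand from the schedule
# (assumption: not taken from trainsched.txt). The last stop has no next stop.
rows = [
    ("217", "Gilroy", "6:06a", 9),
    ("217", "San Martin", "6:15a", 6),
    ("217", "Morgan Hill", "6:21a", 15),
    ("217", "Blossom Hill", "6:36a", 6),
    ("217", "Capitol", "6:42a", 8),
    ("217", "Tamien", "6:50a", 9),
    ("217", "San Jose", "6:59a", None),
]

# SQLite is only a stand-in for Spark here.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE schedule (train_id TEXT, station TEXT, time TEXT, diff_min INTEGER)"
)
conn.executemany("INSERT INTO schedule VALUES (?, ?, ?, ?)", rows)

# Aggregate function: a single output row per group.
group_totals = conn.execute(
    "SELECT train_id, SUM(diff_min) FROM schedule GROUP BY train_id"
).fetchall()

# Window function: one output row for every input row.
running_totals = conn.execute("""
SELECT station, diff_min,
SUM(diff_min) OVER (PARTITION BY train_id ORDER BY time) AS running_total
FROM schedule
ORDER BY time
""").fetchall()

print(group_totals)
for row in running_totals:
    print(row)
```

The aggregate returns one row with the total of 53 minutes, while the window query returns seven rows whose running_total climbs 9, 15, 30, 36, 44, 53, 53. The last row repeats 53 because SUM skips the NULL diff_min.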



# Fix the broken query

This query runs without error, but gives an incorrect result in one of the rows because of an omission in the OVER clause. Can you locate the bug? Can you modify the query to make it give a reasonable result?

In [10]:
query = """
SELECT 
ROW_NUMBER() OVER (ORDER BY time) AS row,
train_id, 
station, 
time, 
LEAD(time,1) OVER (ORDER BY time) AS time_next 
FROM schedule
"""
spark.sql(query).show()


+---+--------+-------------+-----+---------+
|row|train_id|      station| time|time_next|
+---+--------+-------------+-----+---------+
|  1|     217|       Gilroy|6:06a|    6:15a|
|  2|     217|   San Martin|6:15a|    6:21a|
|  3|     217|  Morgan Hill|6:21a|    6:36a|
|  4|     217| Blossom Hill|6:36a|    6:42a|
|  5|     217|      Capitol|6:42a|    6:50a|
|  6|     217|       Tamien|6:50a|    6:59a|
|  7|     217|     San Jose|6:59a|    7:59a|
|  8|     324|San Francisco|7:59a|    8:03a|
|  9|     324|  22nd Street|8:03a|    8:16a|
| 10|     324|     Millbrae|8:16a|    8:24a|
| 11|     324|    Hillsdale|8:24a|    8:31a|
| 12|     324| Redwood City|8:31a|    8:37a|
| 13|     324|    Palo Alto|8:37a|    9:05a|
| 14|     324|     San Jose|9:05a|     NULL|
+---+--------+-------------+-----+---------+



In [12]:

# Give the number of the bad row as an integer
bad_row = 7

# Provide the missing clause, SQL keywords in upper case
clause = 'PARTITION BY train_id'

query = """
SELECT 
ROW_NUMBER() OVER (PARTITION BY train_id ORDER BY time) AS row,
train_id, 
station, 
time, 
LEAD(time,1) OVER (PARTITION BY train_id ORDER BY time) AS time_next 
FROM schedule
"""
spark.sql(query).show()


+---+--------+-------------+-----+---------+
|row|train_id|      station| time|time_next|
+---+--------+-------------+-----+---------+
|  1|     217|       Gilroy|6:06a|    6:15a|
|  2|     217|   San Martin|6:15a|    6:21a|
|  3|     217|  Morgan Hill|6:21a|    6:36a|
|  4|     217| Blossom Hill|6:36a|    6:42a|
|  5|     217|      Capitol|6:42a|    6:50a|
|  6|     217|       Tamien|6:50a|    6:59a|
|  7|     217|     San Jose|6:59a|     NULL|
|  1|     324|San Francisco|7:59a|    8:03a|
|  2|     324|  22nd Street|8:03a|    8:16a|
|  3|     324|     Millbrae|8:16a|    8:24a|
|  4|     324|    Hillsdale|8:24a|    8:31a|
|  5|     324| Redwood City|8:31a|    8:37a|
|  6|     324|    Palo Alto|8:37a|    9:05a|
|  7|     324|     San Jose|9:05a|     NULL|
+---+--------+-------------+-----+---------+
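
The effect of the missing clause can be imitated in plain Python. This is only an illustrative sketch, not how Spark evaluates the query: the helper lead and the four-row sample (the last two stops of train 217 and the first two of train 324, copied from the table above) are assumptions made for the example.

```python
from itertools import groupby

# Rows around the boundary between the two trains, already sorted by
# train_id and time.
rows = [
    ("217", "Tamien", "6:50a"),
    ("217", "San Jose", "6:59a"),
    ("324", "San Francisco", "7:59a"),
    ("324", "22nd Street", "8:03a"),
]


def lead(values):
    """Return each value's successor, with None for the last one (like LEAD(x, 1))."""
    return values[1:] + [None]


# No PARTITION BY: a single window over all rows, so the last stop of
# train 217 is paired with the first stop of train 324.
unpartitioned = lead([time for _, _, time in rows])

# PARTITION BY train_id: a separate window for each train.
partitioned = []
for _, group in groupby(rows, key=lambda row: row[0]):
    partitioned.extend(lead([time for _, _, time in group]))

print(unpartitioned)  # ['6:59a', '7:59a', '8:03a', None]
print(partitioned)    # ['6:59a', None, '8:03a', None]
```

The second element is the bad row: without partitioning, San Jose at 6:59a is given 7:59a as its next time even though that departure belongs to a different train.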



# Aggregation, step by step

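Dot notation can give a counterintuitive result, such as when a second aggregation on a column clobbers a prior aggregation on that column. In the last cell of this section, the dictionary passed to agg contains the key 'time' twice, so Python keeps only the last entry and only max(time) is computed.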

In [13]:
# The SQL query here and the dot notation command in the next cell give the identical result
spark.sql('SELECT train_id, MIN(time) AS start FROM schedule GROUP BY train_id').show()


+--------+-----+
|train_id|start|
+--------+-----+
|     217|6:06a|
|     324|7:59a|
+--------+-----+



In [19]:
df.groupBy('train_id').agg({'time':'min'}).withColumnRenamed('min(time)', 'start').show()



+--------+-----+
|train_id|start|
+--------+-----+
|     217|6:06a|
|     324|7:59a|
+--------+-----+



In [16]:
# Print the second column of the result
spark.sql('SELECT train_id, MIN(time), MAX(time) FROM schedule GROUP BY train_id').show()


+--------+---------+---------+
|train_id|min(time)|max(time)|
+--------+---------+---------+
|     217|    6:06a|    6:59a|
|     324|    7:59a|    9:05a|
+--------+---------+---------+



In [17]:
result = df.groupBy('train_id').agg({'time':'min', 'time':'max'})
result.show()
print(result.columns[1])

+--------+---------+
|train_id|max(time)|
+--------+---------+
|     217|    6:59a|
|     324|    9:05a|
+--------+---------+

max(time)
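
The missing min(time) column is decided by Python before Spark is involved. A minimal check, evaluating the same dictionary literal on its own (the name agg_spec is introduced here for illustration):

```python
# The dict literal passed to agg() in the cell above, evaluated by itself.
# A dict cannot hold the same key twice, so the later value wins.
agg_spec = {'time': 'min', 'time': 'max'}

print(agg_spec)       # {'time': 'max'}
print(len(agg_spec))  # 1
```

By the time agg receives its argument, the 'min' entry is already gone, which is why the result has only a max(time) column.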


# Aggregating the same column twice

There are cases where dot notation can be more cumbersome than SQL. This exercise calculates the first and last times for each train line. 

In [20]:
# Import the module under an alias so that Python's built-in min and max are not shadowed
from pyspark.sql import functions as F
expr = [F.min(F.col("time")).alias('start'), F.max(F.col("time")).alias('end')]
dot_df = df.groupBy("train_id").agg(*expr)
dot_df.show()

+--------+-----+-----+
|train_id|start|  end|
+--------+-----+-----+
|     217|6:06a|6:59a|
|     324|7:59a|9:05a|
+--------+-----+-----+



In [21]:
# Write a SQL query giving a result identical to dot_df
query = "SELECT train_id, MIN(time) AS start, MAX(time) AS end FROM schedule GROUP BY train_id"
sql_df = spark.sql(query)
sql_df.show()

+--------+-----+-----+
|train_id|start|  end|
+--------+-----+-----+
|     217|6:06a|6:59a|
|     324|7:59a|9:05a|
+--------+-----+-----+



# Aggregate dot SQL

The following code uses SQL to define a dataframe called df, which adds a time_next column to the schedule.

In [23]:
df = spark.sql("""
SELECT *, 
LEAD(time,1) OVER(PARTITION BY train_id ORDER BY time) AS time_next 
FROM schedule
""")
df.show(5)

+--------+------------+-----+---------+
|train_id|     station| time|time_next|
+--------+------------+-----+---------+
|     217|      Gilroy|6:06a|    6:15a|
|     217|  San Martin|6:15a|    6:21a|
|     217| Morgan Hill|6:21a|    6:36a|
|     217|Blossom Hill|6:36a|    6:42a|
|     217|     Capitol|6:42a|    6:50a|
+--------+------------+-----+---------+
only showing top 5 rows



In [28]:
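from pyspark.sql.window import Window
from pyspark.sql.functions import lead

# Obtain the identical result using dot notation.
# df already has a time_next column from the SQL query above; withColumn
# replaces a column of the same name, so the columns stay the same.
dot_df = df.withColumn(
    'time_next',
    lead('time', 1).over(Window.partitionBy('train_id').orderBy('time')))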
dot_df.show(5)

+--------+------------+-----+---------+
|train_id|     station| time|time_next|
+--------+------------+-----+---------+
|     217|      Gilroy|6:06a|    6:15a|
|     217|  San Martin|6:15a|    6:21a|
|     217| Morgan Hill|6:21a|    6:36a|
|     217|Blossom Hill|6:36a|    6:42a|
|     217|     Capitol|6:42a|    6:50a|
+--------+------------+-----+---------+
only showing top 5 rows



# Convert window function from dot notation to SQL

We are going to add a column to a train schedule so that each row contains the number of minutes for the train to reach its next stop.

In [32]:
from pyspark.sql.functions import unix_timestamp
window = Window.partitionBy('train_id').orderBy('time')
dot_df = df.withColumn('diff_min', 
                    (unix_timestamp(lead('time', 1).over(window),'H:m') 
                     - unix_timestamp('time', 'H:m'))/60)
# dot_df.show()

In [34]:
# Create a SQL query that computes the same diff_min column as dot_df
# (dot_df also carries time_next, because df was redefined with that column above)
query = """
SELECT *, 
(UNIX_TIMESTAMP(LEAD(time, 1) OVER (PARTITION BY train_id ORDER BY time),'H:m') 
 - UNIX_TIMESTAMP(time, 'H:m'))/60 AS diff_min 
FROM schedule 
"""
sql_df = spark.sql(query)
# sql_df.show()
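
Since the output of both cells is commented out, the expected diff_min values can be checked by hand. The sketch below does the arithmetic in plain Python for train 217, using the times shown earlier. The helper minutes_since_midnight is hypothetical, and it assumes the trailing a or p in the schedule means AM or PM; it is a cross-check of the arithmetic, not a reproduction of how unix_timestamp parses the strings.

```python
def minutes_since_midnight(text):
    """Convert a schedule time such as '6:06a' to minutes after midnight.

    Assumption: the trailing 'a' or 'p' stands for AM or PM.
    """
    clock, suffix = text[:-1], text[-1]
    hours, minutes = (int(part) for part in clock.split(":"))
    hours = hours % 12 + (12 if suffix == "p" else 0)
    return hours * 60 + minutes


# Times for train 217, copied from the schedule table.
times = ["6:06a", "6:15a", "6:21a", "6:36a", "6:42a", "6:50a", "6:59a"]
minutes = [minutes_since_midnight(t) for t in times]

# Minutes to the next stop; the last stop has no next stop, hence None.
diff_min = [nxt - cur for cur, nxt in zip(minutes, minutes[1:])] + [None]

print(diff_min)  # [9, 6, 15, 6, 8, 9, None]
```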