In [1]:
import pyspark
spark = pyspark.sql.SparkSession.builder.getOrCreate()

In [2]:
from pydataset import data

mpg = spark.createDataFrame(data("mpg"))
mpg.show(5)

+------------+-----+-----+----+---+----------+---+---+---+---+-------+
|manufacturer|model|displ|year|cyl|     trans|drv|cty|hwy| fl|  class|
+------------+-----+-----+----+---+----------+---+---+---+---+-------+
|        audi|   a4|  1.8|1999|  4|  auto(l5)|  f| 18| 29|  p|compact|
|        audi|   a4|  1.8|1999|  4|manual(m5)|  f| 21| 29|  p|compact|
|        audi|   a4|  2.0|2008|  4|manual(m6)|  f| 20| 31|  p|compact|
|        audi|   a4|  2.0|2008|  4|  auto(av)|  f| 21| 30|  p|compact|
|        audi|   a4|  2.8|1999|  6|  auto(l5)|  f| 16| 26|  p|compact|
+------------+-----+-----+----+---+----------+---+---+---+---+-------+
only showing top 5 rows



In [3]:
mpg.hwy

Column<b'hwy'>

In [4]:
mpg.select(mpg.cty, mpg.hwy, mpg.model)

DataFrame[cty: bigint, hwy: bigint, model: string]

In [5]:
mpg.select(mpg.cty, mpg.hwy, mpg.model).show(10)

+---+---+----------+
|cty|hwy|     model|
+---+---+----------+
| 18| 29|        a4|
| 21| 29|        a4|
| 20| 31|        a4|
| 21| 30|        a4|
| 16| 26|        a4|
| 18| 26|        a4|
| 18| 27|        a4|
| 18| 26|a4 quattro|
| 16| 25|a4 quattro|
| 20| 28|a4 quattro|
+---+---+----------+
only showing top 10 rows



In [6]:
mpg.hwy + 1


Column<b'(hwy + 1)'>

In [7]:
mpg.select(mpg.hwy+1).show(5)

+---------+
|(hwy + 1)|
+---------+
|       30|
|       30|
|       32|
|       31|
|       27|
+---------+
only showing top 5 rows



In [8]:
mpg.select(mpg.hwy.alias('highway_mileage')).show(5)

+---------------+
|highway_mileage|
+---------------+
|             29|
|             29|
|             31|
|             30|
|             26|
+---------------+
only showing top 5 rows



In [9]:
col1 = mpg.hwy.alias("highway_mileage")
col2 = (mpg.hwy / 2).alias("highway_mileage_halved")
mpg.select(col1, col2).show(5)

+---------------+----------------------+
|highway_mileage|highway_mileage_halved|
+---------------+----------------------+
|             29|                  14.5|
|             29|                  14.5|
|             31|                  15.5|
|             30|                  15.0|
|             26|                  13.0|
+---------------+----------------------+
only showing top 5 rows



In [10]:
from pyspark.sql.functions import col, expr

In [11]:
col("hwy")

Column<b'hwy'>

In [12]:
avg_column = (col("hwy") + col("cty")) / 2

mpg.select(
    col("hwy").alias("highway_mileage"),
    mpg.cty.alias("city_mileage"),
    avg_column.alias("avg_mileage"),
).show(5)

+---------------+------------+-----------+
|highway_mileage|city_mileage|avg_mileage|
+---------------+------------+-----------+
|             29|          18|       23.5|
|             29|          21|       25.0|
|             31|          20|       25.5|
|             30|          21|       25.5|
|             26|          16|       21.0|
+---------------+------------+-----------+
only showing top 5 rows



In [13]:
mpg.select(
    expr("hwy"),  # the same as `col`
    expr("hwy + 1"),  # an arithmetic expression
    expr("hwy AS highway_mileage"),  # using an alias
    expr("hwy + 1 AS highway_incremented"),  # a combination of the above
).show(5)

+---+---------+---------------+-------------------+
|hwy|(hwy + 1)|highway_mileage|highway_incremented|
+---+---------+---------------+-------------------+
| 29|       30|             29|                 30|
| 29|       30|             29|                 30|
| 31|       32|             31|                 32|
| 30|       31|             30|                 31|
| 26|       27|             26|                 27|
+---+---------+---------------+-------------------+
only showing top 5 rows



In [14]:
mpg.select(
    mpg.hwy.alias("highway"),
    col("hwy").alias("highway"),
    expr("hwy").alias("highway"),
    expr("hwy AS highway"),
).show(5)

+-------+-------+-------+-------+
|highway|highway|highway|highway|
+-------+-------+-------+-------+
|     29|     29|     29|     29|
|     29|     29|     29|     29|
|     31|     31|     31|     31|
|     30|     30|     30|     30|
|     26|     26|     26|     26|
+-------+-------+-------+-------+
only showing top 5 rows



In [15]:
mpg.createOrReplaceTempView("mpg")

In [16]:
spark.sql(
    """
SELECT hwy, cty, (hwy + cty) / 2 AS avg
FROM mpg
"""
).show(5)

+---+---+----+
|hwy|cty| avg|
+---+---+----+
| 29| 18|23.5|
| 29| 21|25.0|
| 31| 20|25.5|
| 30| 21|25.5|
| 26| 16|21.0|
+---+---+----+
only showing top 5 rows



In [17]:
mpg.select(mpg.hwy.cast("string")).printSchema()

root
 |-- hwy: string (nullable = true)



In [18]:
mpg.select(mpg.model, mpg.model.cast("int")).show(5)

+-----+-----+
|model|model|
+-----+-----+
|   a4| null|
|   a4| null|
|   a4| null|
|   a4| null|
|   a4| null|
+-----+-----+
only showing top 5 rows



In [19]:
from pyspark.sql.functions import concat, sum, avg, min, max, count, mean
from pyspark.sql.functions import lit

In [20]:
mpg.select(
    sum(mpg.hwy) / count(mpg.hwy).alias("average_1"),
    avg(mpg.hwy).alias("average_2"),
    min(mpg.hwy),
    max(mpg.hwy),
).show()

+--------------------------------------+-----------------+--------+--------+
|(sum(hwy) / count(hwy) AS `average_1`)|        average_2|min(hwy)|max(hwy)|
+--------------------------------------+-----------------+--------+--------+
|                     23.44017094017094|23.44017094017094|      12|      44|
+--------------------------------------+-----------------+--------+--------+



In [21]:
mpg.select(concat(mpg.cyl, lit(" cylinders"))).show(5)


+-----------------------+
|concat(cyl,  cylinders)|
+-----------------------+
|            4 cylinders|
|            4 cylinders|
|            4 cylinders|
|            4 cylinders|
|            6 cylinders|
+-----------------------+
only showing top 5 rows



In [22]:
mpg.filter(mpg.cyl == 4).where(mpg["class"] == "subcompact").show()

+------------+-----------+-----+----+---+----------+---+---+---+---+----------+
|manufacturer|      model|displ|year|cyl|     trans|drv|cty|hwy| fl|     class|
+------------+-----------+-----+----+---+----------+---+---+---+---+----------+
|       honda|      civic|  1.6|1999|  4|manual(m5)|  f| 28| 33|  r|subcompact|
|       honda|      civic|  1.6|1999|  4|  auto(l4)|  f| 24| 32|  r|subcompact|
|       honda|      civic|  1.6|1999|  4|manual(m5)|  f| 25| 32|  r|subcompact|
|       honda|      civic|  1.6|1999|  4|manual(m5)|  f| 23| 29|  p|subcompact|
|       honda|      civic|  1.6|1999|  4|  auto(l4)|  f| 24| 32|  r|subcompact|
|       honda|      civic|  1.8|2008|  4|manual(m5)|  f| 26| 34|  r|subcompact|
|       honda|      civic|  1.8|2008|  4|  auto(l5)|  f| 25| 36|  r|subcompact|
|       honda|      civic|  1.8|2008|  4|  auto(l5)|  f| 24| 36|  c|subcompact|
|       honda|      civic|  2.0|2008|  4|manual(m6)|  f| 21| 29|  p|subcompact|
|     hyundai|    tiburon|  2.0|1999|  4

In [23]:
from pyspark.sql.functions import when

In [24]:
mpg.select(mpg.hwy, when(mpg.hwy > 25, "good_mileage").alias("mpg_desc")).show(
    12
)

+---+------------+
|hwy|    mpg_desc|
+---+------------+
| 29|good_mileage|
| 29|good_mileage|
| 31|good_mileage|
| 30|good_mileage|
| 26|good_mileage|
| 26|good_mileage|
| 27|good_mileage|
| 26|good_mileage|
| 25|        null|
| 28|good_mileage|
| 27|good_mileage|
| 25|        null|
+---+------------+
only showing top 12 rows



In [25]:
mpg.select(
    mpg.hwy,
    when(mpg.hwy > 25, "good_mileage")
    .otherwise("bad_mileage")
    .alias("mpg_desc"),
).show(12)

+---+------------+
|hwy|    mpg_desc|
+---+------------+
| 29|good_mileage|
| 29|good_mileage|
| 31|good_mileage|
| 30|good_mileage|
| 26|good_mileage|
| 26|good_mileage|
| 27|good_mileage|
| 26|good_mileage|
| 25| bad_mileage|
| 28|good_mileage|
| 27|good_mileage|
| 25| bad_mileage|
+---+------------+
only showing top 12 rows



In [26]:
mpg.select(
    mpg.displ,
    (
        when(mpg.displ < 2, "small")
        .when(mpg.displ < 3, "medium")
        .otherwise("large")
        .alias("engine_size")
    ),
).show(10)

+-----+-----------+
|displ|engine_size|
+-----+-----------+
|  1.8|      small|
|  1.8|      small|
|  2.0|     medium|
|  2.0|     medium|
|  2.8|     medium|
|  2.8|     medium|
|  3.1|      large|
|  1.8|      small|
|  1.8|      small|
|  2.0|     medium|
+-----+-----------+
only showing top 10 rows



In [27]:
mpg.sort(mpg.hwy).show(5)

+------------+-------------------+-----+----+---+----------+---+---+---+---+------+
|manufacturer|              model|displ|year|cyl|     trans|drv|cty|hwy| fl| class|
+------------+-------------------+-----+----+---+----------+---+---+---+---+------+
|       dodge|ram 1500 pickup 4wd|  4.7|2008|  8|  auto(l5)|  4|  9| 12|  e|pickup|
|       dodge|        durango 4wd|  4.7|2008|  8|  auto(l5)|  4|  9| 12|  e|   suv|
|        jeep| grand cherokee 4wd|  4.7|2008|  8|  auto(l5)|  4|  9| 12|  e|   suv|
|       dodge|ram 1500 pickup 4wd|  4.7|2008|  8|manual(m6)|  4|  9| 12|  e|pickup|
|       dodge|  dakota pickup 4wd|  4.7|2008|  8|  auto(l5)|  4|  9| 12|  e|pickup|
+------------+-------------------+-----+----+---+----------+---+---+---+---+------+
only showing top 5 rows



In [28]:
from pyspark.sql.functions import asc, desc

In [29]:
mpg.sort(desc("class"), mpg.cyl.asc(), col("hwy").desc()).show()

+------------+------------------+-----+----+---+----------+---+---+---+---+-----+
|manufacturer|             model|displ|year|cyl|     trans|drv|cty|hwy| fl|class|
+------------+------------------+-----+----+---+----------+---+---+---+---+-----+
|      subaru|      forester awd|  2.5|2008|  4|manual(m5)|  4| 20| 27|  r|  suv|
|      subaru|      forester awd|  2.5|2008|  4|  auto(l4)|  4| 20| 26|  r|  suv|
|      subaru|      forester awd|  2.5|1999|  4|manual(m5)|  4| 18| 25|  r|  suv|
|      subaru|      forester awd|  2.5|2008|  4|manual(m5)|  4| 19| 25|  p|  suv|
|      subaru|      forester awd|  2.5|1999|  4|  auto(l4)|  4| 18| 24|  r|  suv|
|      subaru|      forester awd|  2.5|2008|  4|  auto(l4)|  4| 18| 23|  p|  suv|
|      toyota|       4runner 4wd|  2.7|1999|  4|manual(m5)|  4| 15| 20|  r|  suv|
|      toyota|       4runner 4wd|  2.7|1999|  4|  auto(l4)|  4| 16| 20|  r|  suv|
|        jeep|grand cherokee 4wd|  3.0|2008|  6|  auto(l5)|  4| 17| 22|  d|  suv|
|        jeep|gr

In [30]:
mpg.groupBy(mpg.cyl).agg(avg(mpg.cty), avg(mpg.hwy)).show()

+---+------------------+-----------------+
|cyl|          avg(cty)|         avg(hwy)|
+---+------------------+-----------------+
|  6| 16.21518987341772|22.82278481012658|
|  5|              20.5|            28.75|
|  8|12.571428571428571|17.62857142857143|
|  4|21.012345679012345|28.80246913580247|
+---+------------------+-----------------+



In [31]:
mpg.groupBy("cyl", "class").agg(avg(mpg.cty), avg(mpg.hwy)).show()

+---+----------+------------------+------------------+
|cyl|     class|          avg(cty)|          avg(hwy)|
+---+----------+------------------+------------------+
|  5|   compact|              21.0|              29.0|
|  5|subcompact|              20.0|              28.5|
|  6|subcompact|              17.0|24.714285714285715|
|  6|    pickup|              14.5|              17.9|
|  4|subcompact|22.857142857142858| 30.80952380952381|
|  8|       suv|12.131578947368421|16.789473684210527|
|  8|    pickup|              11.8|              15.8|
|  8|   midsize|              16.0|              24.0|
|  4|   midsize|              20.5|           29.1875|
|  8|   2seater|              15.4|              24.8|
|  6|   compact|16.923076923076923|25.307692307692307|
|  6|   minivan|              15.6|              22.2|
|  4|   compact|            21.375|          29.46875|
|  8|subcompact|              14.8|              21.6|
|  6|   midsize|17.782608695652176| 26.26086956521739|
|  4|   mi

In [32]:
mpg.rollup("cyl").count().sort("cyl").show()

+----+-----+
| cyl|count|
+----+-----+
|null|  234|
|   4|   81|
|   5|    4|
|   6|   79|
|   8|   70|
+----+-----+



In [33]:
mpg.rollup("cyl").agg(expr("avg(hwy)")).sort("cyl").show()

+----+-----------------+
| cyl|         avg(hwy)|
+----+-----------------+
|null|23.44017094017094|
|   4|28.80246913580247|
|   5|            28.75|
|   6|22.82278481012658|
|   8|17.62857142857143|
+----+-----------------+



In [34]:
mpg.rollup("cyl", "class").mean("hwy").sort(col("cyl"), col("class")).show(10)

+----+----------+------------------+
| cyl|     class|          avg(hwy)|
+----+----------+------------------+
|null|      null| 23.44017094017094|
|   4|      null| 28.80246913580247|
|   4|   compact|          29.46875|
|   4|   midsize|           29.1875|
|   4|   minivan|              24.0|
|   4|    pickup|20.666666666666668|
|   4|subcompact| 30.80952380952381|
|   4|       suv|             23.75|
|   5|      null|             28.75|
|   5|   compact|              29.0|
+----+----------+------------------+
only showing top 10 rows



In [36]:
from vega_datasets import data

weather = data.seattle_weather().assign(date=lambda df: df.date.astype(str))
weather = spark.createDataFrame(weather)
weather.show(4)

+----------+-------------+--------+--------+----+-------+
|      date|precipitation|temp_max|temp_min|wind|weather|
+----------+-------------+--------+--------+----+-------+
|2012-01-01|          0.0|    12.8|     5.0| 4.7|drizzle|
|2012-01-02|         10.9|    10.6|     2.8| 4.5|   rain|
|2012-01-03|          0.8|    11.7|     7.2| 2.3|   rain|
|2012-01-04|         20.3|    12.2|     5.6| 4.7|   rain|
+----------+-------------+--------+--------+----+-------+
only showing top 4 rows



In [37]:
print(weather.count(), "rows", len(weather.columns), "columns")

1461 rows 6 columns


In [38]:
min_date, max_date = weather.select(min("date"), max("date")).first()
min_date, max_date

('2012-01-01', '2015-12-31')

In [39]:
weather = weather.withColumn(
    "temp_avg", expr("ROUND(temp_min + temp_max) / 2")
).drop("temp_max", "temp_min")
weather.show(6)

+----------+-------------+----+-------+--------+
|      date|precipitation|wind|weather|temp_avg|
+----------+-------------+----+-------+--------+
|2012-01-01|          0.0| 4.7|drizzle|     9.0|
|2012-01-02|         10.9| 4.5|   rain|     6.5|
|2012-01-03|          0.8| 2.3|   rain|     9.5|
|2012-01-04|         20.3| 4.7|   rain|     9.0|
|2012-01-05|          1.3| 6.1|   rain|     6.0|
|2012-01-06|          2.5| 2.2|   rain|     3.5|
+----------+-------------+----+-------+--------+
only showing top 6 rows



In [40]:
from pyspark.sql.functions import month, year, quarter