# groupBy() and Filtering Aggregated Data

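`groupBy()` is a transformation in PySpark that groups the rows of a DataFrame by one or more columns. On a DataFrame it returns a `GroupedData` object, on which you then call aggregation methods such as `count()`, `sum()`, `avg()`, or `agg()` to produce a new DataFrame with one row per group. RDDs also have a `groupBy()` method, but it takes a key function and returns an RDD of (key, values) pairs rather than a `GroupedData` object; this recipe covers the DataFrame API only.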

The `groupBy()` operation is commonly used for tasks such as calculating statistics for different categories, grouping data by time intervals, and creating pivot tables.

For example, if you have a DataFrame with columns "category" and "value", you can use `groupBy("category")` to group the data by category and then apply an aggregation function such as `sum("value")` to calculate the total value for each category. This would give you a new DataFrame with one row for each category and the total value for that category.

The following six examples show the syntax for `groupBy()` in PySpark. The first two cells import the required functions and create the Spark session used throughout:

In [15]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import count, avg, max, sum

In [2]:
spark = SparkSession.builder.appName("example").getOrCreate()

## Example 1 - Grouping data by a single column and counting the number of occurrences for each value in that column - `count()`

In [8]:
# create a DataFrame with a single column "fruit" and several rows
data = [("apple",), ("banana",), ("apple",), ("orange",), ("banana",), ("banana",)]
df = spark.createDataFrame(data, ["fruit"])
df.show()

+------+
| fruit|
+------+
| apple|
|banana|
| apple|
|orange|
|banana|
|banana|
+------+



In [9]:
# group the data by the "fruit" column and count the number of occurrences for each fruit
grouped_df = df.groupBy("fruit").count()

# show the results
grouped_df.show()

+------+-----+
| fruit|count|
+------+-----+
| apple|    2|
|banana|    3|
|orange|    1|
+------+-----+



## Example 2 - Grouping data by a single column and computing the average value for each group - `avg()`

In [10]:
# create a DataFrame with columns "city" and "temperature"
data = [("New York", 10.0), ("New York", 12.0),
        ("Los Angeles", 20.0), ("Los Angeles", 22.0),
        ("San Francisco", 15.0), ("San Francisco", 18.0)]
df = spark.createDataFrame(data, ["city", "temperature"])
df.show()

+-------------+-----------+
|         city|temperature|
+-------------+-----------+
|     New York|       10.0|
|     New York|       12.0|
|  Los Angeles|       20.0|
|  Los Angeles|       22.0|
|San Francisco|       15.0|
|San Francisco|       18.0|
+-------------+-----------+



In [11]:
# group the data by the "city" column and compute the average temperature for each city
grouped_df = df.groupBy("city").agg(avg("temperature").alias("avg_temperature"))

# show the results
grouped_df.show()

+-------------+---------------+
|         city|avg_temperature|
+-------------+---------------+
|     New York|           11.0|
|  Los Angeles|           21.0|
|San Francisco|           16.5|
+-------------+---------------+



## Example 3 - Grouping data by multiple columns and computing the maximum value of another column - `max()`

In [13]:
# create a DataFrame with columns "city", "year", and "temperature"
data = [("New York", 2021, 10.0), ("New York", 2022, 12.0),
        ("Los Angeles", 2021, 20.0), ("Los Angeles", 2022, 22.0),
        ("San Francisco", 2021, 15.0), ("San Francisco", 2022, 18.0)]
df = spark.createDataFrame(data, ["city", "year", "temperature"])
df.show()

+-------------+----+-----------+
|         city|year|temperature|
+-------------+----+-----------+
|     New York|2021|       10.0|
|     New York|2022|       12.0|
|  Los Angeles|2021|       20.0|
|  Los Angeles|2022|       22.0|
|San Francisco|2021|       15.0|
|San Francisco|2022|       18.0|
+-------------+----+-----------+



In [14]:
# group the data by the "city" and "year" columns and compute the maximum temperature for each group
grouped_df = df.groupBy(["city", "year"]).agg(max("temperature").alias("max_temperature"))

# show the results
grouped_df.show()

+-------------+----+---------------+
|         city|year|max_temperature|
+-------------+----+---------------+
|     New York|2021|           10.0|
|  Los Angeles|2021|           20.0|
|     New York|2022|           12.0|
|  Los Angeles|2022|           22.0|
|San Francisco|2021|           15.0|
|San Francisco|2022|           18.0|
+-------------+----+---------------+



## Example 4 - Grouping data by a column and computing multiple aggregate functions for another column - `avg()`, `sum()`, `count()`

In [17]:
# create a DataFrame with columns "city" and "temperature"
data = [("New York", 10.0), ("New York", 12.0),
        ("Los Angeles", 20.0), ("Los Angeles", 22.0),
        ("San Francisco", 15.0), ("San Francisco", 18.0)]
df = spark.createDataFrame(data, ["city", "temperature"])
df.show()

+-------------+-----------+
|         city|temperature|
+-------------+-----------+
|     New York|       10.0|
|     New York|       12.0|
|  Los Angeles|       20.0|
|  Los Angeles|       22.0|
|San Francisco|       15.0|
|San Francisco|       18.0|
+-------------+-----------+



In [18]:
# group the data by the "city" column and compute the average, sum, and count of the temperature for each city
grouped_df = df.groupBy("city").agg(avg("temperature").alias("avg_temperature"),
                                     sum("temperature").alias("total_temperature"),
                                     count("temperature").alias("num_measurements"))

# show the results
grouped_df.show()

+-------------+---------------+-----------------+----------------+
|         city|avg_temperature|total_temperature|num_measurements|
+-------------+---------------+-----------------+----------------+
|     New York|           11.0|             22.0|               2|
|  Los Angeles|           21.0|             42.0|               2|
|San Francisco|           16.5|             33.0|               2|
+-------------+---------------+-----------------+----------------+



## Example 5 - Filtering groups based on an aggregate condition - `filter()`

In [6]:
# create a DataFrame with columns "department" and "salary"
data = [("IT", 5000), ("IT", 6000), ("Sales", 7000), ("Sales", 8000)]
df = spark.createDataFrame(data, ["department", "salary"])
df.show()

+----------+------+
|department|salary|
+----------+------+
|        IT|  5000|
|        IT|  6000|
|     Sales|  7000|
|     Sales|  8000|
+----------+------+



In [7]:
# group the data by the "department" column and compute the average salary for each department
grouped_df = df.groupBy("department").agg(avg("salary").alias("avg_salary"))

# filter the groups to include only those where the average salary is greater than 6000
filtered_df = grouped_df.filter("avg_salary > 6000")

# show the results
filtered_df.show()

+----------+----------+
|department|avg_salary|
+----------+----------+
|     Sales|    7500.0|
+----------+----------+



## Example 6 - Filtering groups based on the count of rows in each group - `count()`, `filter()`

In [3]:
# create a DataFrame with columns "department" and "salary"
data = [("IT", 5000), ("IT", 6000), ("Sales", 7000), ("Sales", 8000)]
df = spark.createDataFrame(data, ["department", "salary"])
df.show()

+----------+------+
|department|salary|
+----------+------+
|        IT|  5000|
|        IT|  6000|
|     Sales|  7000|
|     Sales|  8000|
+----------+------+



In [4]:
# group the data by the "department" column and compute the count of rows in each group
grouped_df = df.groupBy("department").agg(count("*").alias("num_employees"))

# filter the groups to include only those with more than one employee
filtered_df = grouped_df.filter("num_employees > 1")

# show the results
filtered_df.show()

+----------+-------------+
|department|num_employees|
+----------+-------------+
|        IT|            2|
|     Sales|            2|
+----------+-------------+



## Conclusion

This recipe covered `groupBy()` and the aggregate functions `count()`, `avg()`, `max()`, and `sum()` on a Spark DataFrame. It showed how to group by one or several columns, how to compute several aggregates in a single `agg()` call, and how to filter the result on an aggregated column with `filter()`. The examples can serve as starting points for the same operations in your own PySpark projects.