In [None]:
# Execute the shared setup notebook in this kernel. It presumably creates the
# `spark` session used by every cell below — TODO confirm.
%run "./Setup.ipynb"

In [None]:
# Location of the flight-summary JSON data. The active path presumably points at
# the directory holding the per-year files (Spark reads every file in it); the
# commented-out path reads just the 2015 file.
# Every backslash is doubled: the original "\j" was an unrecognised escape that
# only kept its backslash by accident (SyntaxWarning on Python 3.12+).
#input_file = "E:\\PySpark\\data\\flight-data\\json\\2015-summary.json"
input_file = "E:\\PySpark\\data\\flight-data\\json"

df1 = spark.read.json(input_file)

df1.show()

In [None]:
# Print the schema Spark inferred from the JSON files (column names and types).
df1.printSchema()

In [None]:
from pyspark.sql.functions import *

**select**

In [None]:
# Keep three columns, given as one list of column names (select accepts a list
# as well as separate arguments).
df2 = df1.select(["ORIGIN_COUNTRY_NAME", "DEST_COUNTRY_NAME", "count"])

df2.show()

In [None]:
# Different ways to build a column reference/expression (examples only, not run):
# col() + arithmetic, a SQL expression via expr(), and two DataFrame-bound
# references (indexing and attribute access).
#col("age") + 1
#expr("age > 18")
#df1["ORIGIN_COUNTRY_NAME"]
#df1.ORIGIN_COUNTRY_NAME

In [None]:
# One projection mixing the column helpers: col()/expr() references with alias(),
# a cast, SQL expressions carrying their own "as" alias, and generated columns
# (current date and a literal).
df2 = df1.select([
    col("ORIGIN_COUNTRY_NAME").alias("origin"),
    expr("DEST_COUNTRY_NAME").alias("destination"),
    col("count").cast("int"),
    expr("count + 10 as new_count"),
    expr("count > 200 as high_frequency"),
    expr("ORIGIN_COUNTRY_NAME = DEST_COUNTRY_NAME as domestic"),
    current_date().alias("today"),
    lit("India").alias("country"),
])

df2.show()

In [None]:
# Inspect the column names and types produced by the select() above.
df2.printSchema()

**where / filter**

In [None]:
# where() and filter() are aliases; both take a SQL predicate string.
#df3 = df2.filter("count > 100 and domestic = false")
df3 = df2.where("count > 100 and domestic = false")

df3.show()

**orderBy / sort**

In [None]:
# Other ways to sort: by column names, via the sort() alias, or with explicit
# per-column directions.
#df3 = df2.orderBy("count", "origin")
#df3 = df2.sort("count", "origin")
#df3 = df2.orderBy(desc("count"), asc("origin"))

# Sort by a computed expression, largest first (ascending=False applies desc()).
df3 = df2.sort(expr("count + 10"), ascending=False)
df3.show(n=50)

**groupBy**

- Returns a `pyspark.sql.group.GroupedData` object (not a DataFrame)
- Call an aggregation method on it (`count`, `sum`, `avg`, `max`, `agg`, ...) to get a DataFrame back

In [None]:
# Single-aggregate shortcuts available directly on GroupedData:
#df3 = df2.groupBy("high_frequency", "domestic").count()
#df3 = df2.groupBy("high_frequency", "domestic").sum("count")
#df3 = df2.groupBy("high_frequency", "domestic").avg("count")
#df3 = df2.groupBy("high_frequency", "domestic").max("count")

# agg() computes several named aggregates per (high_frequency, domestic) group.
# count/sum/round/max here are the pyspark.sql.functions versions brought in by
# the star import, not the Python builtins.
df3 = (
    df2.groupBy("high_frequency", "domestic")
       .agg(
           count("count").alias("count"),
           sum("count").alias("sum"),
           round(avg("count"), 1).alias("average"),
           max("count").alias("max"),
       )
)

df3.show()

**selectExpr**

In [None]:
# Each argument is a SQL expression string; "as" names the output column.
df2 = df1.selectExpr(*[
    "ORIGIN_COUNTRY_NAME as origin",
    "DEST_COUNTRY_NAME as destination",
    "cast(count as int)",
    "count + 10 as new_count",
    "count > 200 as high_frequency",
    "ORIGIN_COUNTRY_NAME = DEST_COUNTRY_NAME as domestic",
    "current_date() as today",
    "'India' as country",
])

df2.show()

In [None]:
# Schema produced by the selectExpr() version.
df2.printSchema()

**withColumn** & **withColumnRenamed**

In [None]:
# Starting schema of df1, for comparison with the withColumn result below.
df1.printSchema()

In [None]:
# Add derived columns one at a time, then rename the two country columns.
# new_count is computed from the un-cast "count"; the cast below replaces the
# existing "count" column in place.
df2 = (
    df1
    .withColumn("new_count", col("count") + 10)
    .withColumn("high_frequency", col("count") > 200)
    .withColumn("domestic", col("DEST_COUNTRY_NAME") == col("ORIGIN_COUNTRY_NAME"))
    .withColumn("today", current_date())
    .withColumn("country", lit("India"))
    .withColumn("count", col("count").cast("int"))
    .withColumnRenamed("ORIGIN_COUNTRY_NAME", "origin")
    .withColumnRenamed("DEST_COUNTRY_NAME", "destination")
)

df2.show()

In [None]:
# Schema after the withColumn / withColumnRenamed chain.
df2.printSchema()

In [None]:
# Sample (id, name, age) tuples used to build the users DataFrame.
listUsers = [
    (1, "Raju", 5), (2, "Ramesh", 75), (3, "Rajesh", 18), (4, "Raghu", 35),
    (5, "Ramya", 25), (6, "Radhika", 35), (7, "Ravi", 10),
]

In [None]:
# Build a DataFrame from the tuples; toDF() assigns the column names afterwards.
users_df = spark.createDataFrame(listUsers).toDF("id", "name", "age")
users_df.show()

In [None]:
# Label each user with a CASE-WHEN style chain; the first matching condition wins.
age_group_df = users_df.withColumn(
    "age_group",
    when(col("age") < 13, "child")
    .when(col("age") < 20, "teenager")
    .when(col("age") < 60, "adult")
    .otherwise("senior"),
)
age_group_df.show()

**UDF (user-defined function)**

In [None]:
from pyspark.sql.types import *

In [None]:
# Sample (id, name, age) tuples for the UDF examples.
listUsers = [
    (1, "Raju", 5), (2, "Ramesh", 75), (3, "Rajesh", 18), (4, "Raghu", 35),
    (5, "Ramya", 25), (6, "Radhika", 35), (7, "Ravi", 10),
]


In [None]:
# Same data, with the column names passed directly to createDataFrame().
users_df = spark.createDataFrame(listUsers, ["id", "name", "age"])
users_df.show()

In [None]:
def get_age_group(age):
    """Map an age to a coarse age-group label.

    Upper bounds are exclusive: below 13 is "child", below 20 is "teenager",
    below 60 is "adult", and anything else is "senior". Using exclusive bounds
    means fractional ages (e.g. 12.5 or 19.5) land in the right group instead of
    falling through to "senior".

    Args:
        age: A numeric age, or None (a SQL NULL reaches a Python UDF as None).

    Returns:
        The label string, or None when ``age`` is None. This keeps NULL ages NULL
        instead of raising TypeError inside the UDF.
    """
    if age is None:
        return None
    if age < 13:
        return "child"
    if age < 20:
        return "teenager"
    if age < 60:
        return "adult"
    return "senior"

In [None]:
# Wrap the plain Python function as a DataFrame-API UDF whose result type is string.
get_age_group_udf = udf(get_age_group, returnType=StringType())
get_age_group_udf

In [None]:
# Apply the UDF row by row to the "age" column.
age_group_df = users_df.withColumn("age_group", get_age_group_udf(col("age")))
age_group_df.show()

**Register UDF in the catalog**

In [None]:
# Expose users_df to Spark SQL as the temporary view "users".
users_df.createOrReplaceTempView("users")

In [None]:
# List the tables/views visible in the current catalog (includes the "users" temp view).
spark.catalog.listTables()

In [None]:
# Print the name of every function the catalog currently knows about.
for f in spark.catalog.listFunctions():
    print(f.name)

In [None]:
# Register the Python function under the SQL name "age_group" so SQL queries can call it.
spark.udf.register("age_group", get_age_group, returnType = StringType())

In [None]:
# Call the registered UDF from SQL against the "users" view.
qry = "select id, name, age, age_group(age) as age_group from users"
spark.sql(qry).show()

**drop**
- Returns a new DataFrame without the listed columns (used to exclude columns from the output)

In [None]:
# When run top to bottom, df2 is the withColumn/withColumnRenamed result at this point.
df2.printSchema()

In [None]:
# Return a new DataFrame without three of the derived columns.
df3 = df2.drop("new_count", "country", "high_frequency")
df3.printSchema()

**dropDuplicates**

- Drops duplicate rows; pass a list of columns (e.g. `["name", "age"]`) to compare only those columns

In [None]:
# Sample data with duplicates: rows 1-2 and rows 4-5 are fully identical, while
# ids 3 and 6 repeat an existing (name, age) pair under a different id.
listUsers = [(1, "Raju", 5),
             (1, "Raju", 5),
             (3, "Raju", 5),
             (4, "Raghu", 35),
             (4, "Raghu", 35),
             (6, "Raghu", 35),
             (7, "Ravi", 70)]

users_df = spark.createDataFrame(listUsers, ["id", "name", "age"])
users_df.show()

In [None]:
# Without arguments only fully identical rows are removed; the commented variant
# compares just "name" and "age", so it also collapses the rows that differ only by id.
dropdups_df = users_df.dropDuplicates()
#dropdups_df = users_df.dropDuplicates(["name", "age"])
dropdups_df.show()

**dropna**
- Drops rows that contain NULL values (optionally checking only the columns listed in `subset`)

In [None]:
# Load users data that contains NULLs.
# NOTE(review): hard-coded absolute local path; consider a configurable data directory.
users_df = spark.read.json("E:\\PySpark\\data\\users.json")
users_df.show()

In [None]:
# The commented call drops rows with a NULL in any column; the active call only
# checks the "age" and "phone" columns.
#clean_df = users_df.dropna()
clean_df = users_df.dropna(subset = ["age", "phone"])
clean_df.show()

**distinct**

In [None]:
# Sample data containing exact duplicate rows (rows 1-2 and rows 4-5).
listUsers = [(1, "Raju", 5),
             (1, "Raju", 5),
             (3, "Raju", 5),
             (4, "Raghu", 35),
             (4, "Raghu", 35),
             (6, "Raghu", 35),
             (7, "Ravi", 70)]

users_df = spark.createDataFrame(listUsers, ["id", "name", "age"])
users_df.show()

In [None]:
# distinct() keeps one copy of each fully identical row (like dropDuplicates() with no arguments).
users_df.distinct().show()

**Q: How many unique DEST_COUNTRY_NAME values are there in df1?**


In [None]:
# Look at the flight data again before answering the question.
df1.show()

In [None]:
## How many exact-duplicate rows are there in df1: total rows minus distinct rows
df1.count() - df1.distinct().count()

In [None]:
# Number of unique destination countries.
df1.select("DEST_COUNTRY_NAME").distinct().count()

In [None]:
# Same answer via dropDuplicates on that one column (one full row kept per destination).
df1.dropDuplicates(["DEST_COUNTRY_NAME"]).count()

**union, intersect, subtract** *(TODO: this section has not been written yet)*

In [None]:
## CONTINUE FROM HERE ....

In [None]:
# NOTE(review): tabular rendering of a Spark DataFrame via display() depends on the
# notebook environment (e.g. Databricks) — confirm where this section runs.
display(df1)

In [None]:
# Keep only routes with more than 1000 flights.
df2 = df1.where("count > 1000")
display(df2)

In [None]:
# A filter is a narrow transformation, so df2 keeps df1's partition count.
df2.rdd.getNumPartitions()

In [None]:
# Keep only rows whose destination is India.
df3 = df1.where("DEST_COUNTRY_NAME = 'India'")
display(df3)

In [None]:
# Partition count is unchanged by the filter, even if most partitions are now empty.
df3.rdd.getNumPartitions()

**repartition**
- Increases or decreases the number of partitions of the output DF
- Causes a full (global) shuffle, which spreads rows roughly evenly across the new partitions

In [150]:
def partitions(df):
    """Print how many rows of ``df`` live in each partition.

    Tags every row with ``spark_partition_id()``, counts rows per partition id
    and shows the counts sorted by partition id. Partitions holding no rows
    produce no group, so they do not appear in the table.

    Args:
        df: The Spark DataFrame to inspect.

    Returns:
        None; the table is printed via ``DataFrame.show``.
    """
    # show() prints only 20 rows by default. Size it to the partition count so
    # DataFrames with more than 20 partitions are not silently truncated.
    num_partitions = df.rdd.getNumPartitions()
    (df.withColumn("partition", spark_partition_id())
       .groupBy("partition")
       .count()
       .sort("partition")
       .show(num_partitions))

In [151]:
# Rows per partition of df1 as read from the source files.
partitions(df1)

+---------+-----+
|partition|count|
+---------+-----+
|        0|  256|
|        1|  255|
|        2|  255|
|        3|  250|
|        4|  245|
|        5|  241|
+---------+-----+



In [152]:
# repartition(10) shuffles df1 into 10 partitions of nearly equal size.
df10 = df1.repartition(10)
partitions(df10)

+---------+-----+
|partition|count|
+---------+-----+
|        0|  150|
|        1|  150|
|        2|  150|
|        3|  151|
|        4|  150|
|        5|  150|
|        6|  151|
|        7|  150|
|        8|  150|
|        9|  150|
+---------+-----+



In [154]:
# Reducing with repartition() also shuffles, so the 4 partitions stay evenly sized.
df11 = df10.repartition(4)
partitions(df11)

+---------+-----+
|partition|count|
+---------+-----+
|        0|  376|
|        1|  375|
|        2|  373|
|        3|  378|
+---------+-----+



**coalesce**
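- Can only decrease the number of partitions of the output DF (requesting more partitions than exist leaves the count unchanged)
- Merges existing partitions instead of shuffling, so it is cheaper than `repartition` but partition sizes can be uneven (e.g. 300 vs. 451 rows below)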

In [155]:
# Baseline for the coalesce example: the 10 evenly sized partitions of df10.
partitions(df10)

+---------+-----+
|partition|count|
+---------+-----+
|        0|  150|
|        1|  150|
|        2|  150|
|        3|  151|
|        4|  150|
|        5|  150|
|        6|  151|
|        7|  150|
|        8|  150|
|        9|  150|
+---------+-----+



In [157]:
# coalesce(4) merges existing partitions without a shuffle; compare the uneven
# sizes here with the repartition(4) result.
df12 = df10.coalesce(4)
partitions(df12)

+---------+-----+
|partition|count|
+---------+-----+
|        0|  300|
|        1|  451|
|        2|  301|
|        3|  450|
+---------+-----+



**Window functions**

In [None]:
# Employee CSV used by the window-function examples.
# NOTE(review): this is a DBFS (Databricks) path while earlier cells read local
# E:\ paths, so this section presumably runs on Databricks — confirm.
data_file = "dbfs:/FileStore/data/empdata.csv"

In [None]:
# DDL-style schema string: column names and types for the CSV, so no type inference is needed.
csv_schema = "id INT, name STRING, dept STRING, salary INT"

In [None]:
# Read the employee CSV with the explicit schema.
# NOTE(review): no header option is set, so a header line (if the file has one)
# would be read as a data row — confirm the file format.
windows_df = spark.read.csv(data_file, schema=csv_schema)

display(windows_df)

In [None]:
from pyspark.sql.window import Window
from pyspark.sql.functions import col, desc, sum, max, avg, col, dense_rank, rank, row_number, to_date, round

In [None]:
# Window over each department with no ordering: aggregates see every row of the dept.
window_spec = Window.partitionBy("dept")

In [None]:
# Department-wide aggregates attached to every employee row. The window has no
# ordering, so each row in a dept gets the same total, average and max salary.
window_df_2 = (
    windows_df
    .withColumn("total_dept_salary", sum(col("salary")).over(window_spec))
    .withColumn("avg_dept_salary", round(avg(col("salary")).over(window_spec), 1))
    .withColumn("max_dept_salary", max(col("salary")).over(window_spec))
)

In [None]:
# Show the per-department aggregates alongside each employee.
display(window_df_2)

In [None]:
# Running window: within each dept, rows ordered by ascending salary, with a frame
# from the first row of the partition up to and including the current row.
window_spec_3 = (
    Window.partitionBy("dept")
    .orderBy(col("salary"))
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)

In [None]:
# Running total/average over the ordered frame, plus the three ranking functions.
# On salary ties, rank() leaves gaps, dense_rank() does not, and row_number()
# still numbers the rows consecutively.
window_df_3 = (
    windows_df
    .withColumn("total_salary", sum(col("salary")).over(window_spec_3))
    .withColumn("avg_salary", round(avg(col("salary")).over(window_spec_3), 1))
    .withColumn("rank", rank().over(window_spec_3))
    .withColumn("drank", dense_rank().over(window_spec_3))
    .withColumn("row_num", row_number().over(window_spec_3))
)

In [None]:
# Compare the running aggregates and the three ranking columns per department.
display(window_df_3)

**Get top 3 employees with highest salary in each department**

In [None]:
# Order employees within each department by salary, highest first. "id" is a
# secondary sort key so row_number() is deterministic when salaries tie;
# otherwise which tied employee makes the top 3 can change between runs.
window_spec = Window.partitionBy("dept").orderBy(desc("salary"), col("id"))

In [None]:
# Number the rows in each dept using window_spec, keep the first three, then
# discard the helper column.
top_emp_df = (
    windows_df
    .withColumn("row_num", row_number().over(window_spec))
    .filter(col("row_num") <= 3)
    .drop("row_num")
)

In [None]:
# Top three earners per department.
display(top_emp_df)