In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType, StructField, StructType
import pyspark.sql.functions as F

In [None]:
# Obtain the Spark session for this notebook (reuses an active one if present),
# registered under the application name "cache".
spark = (
    SparkSession.builder
    .appName("cache")
    .getOrCreate()
)

In [None]:
# Load the customers dataset from its parquet directory.
# Parquet files carry their own schema, so the CSV-only "header" option that
# used to be set here had no effect and has been dropped.
df_customers = spark.read.parquet(
    "/opt/bitnami/spark/custom_data/chapter7/customers/"
)

# Preview the loaded rows (show() prints the first 20 by default).
df_customers.show()

# Without cache

In [None]:
# Base DataFrame shared by the df1/df2 cells: Boston customers labelled with
# an age-based customer group.
#   20-30    -> "young"
#   31-50    -> "mid"
#   above 50 -> "old"
#   otherwise (under 20, or no condition evaluates to true) -> "kid"
# NOTE(review): confirm that under-20 / missing ages are really meant to be "kid".
df_base = (
    df_customers
    .filter(F.col("city") == "boston")
    .withColumn(
        "customer_group",
        F.when(F.col("age").between(20, 30), F.lit("young"))
        .when(F.col("age").between(31, 50), F.lit("mid"))
        # between() is inclusive, so "mid" ends at 50. The threshold must be
        # > 50 (previously > 51), otherwise age 51 matched no band and was
        # mislabelled as "kid".
        .when(F.col("age") > 50, F.lit("old"))
        .otherwise(F.lit("kid"))
    )
    .select("cust_id", "name", "age", "gender", "birthday", "zip", "city", "customer_group")
)

df_base.show(5, truncate=False)

In [None]:
# First derivative of df_base: a constant marker column plus the third
# "/"-separated piece of `birthday`.
# NOTE(review): assumes birthday is formatted with the year last (e.g. d/m/yyyy) — confirm.
df1 = (
    df_base
    .withColumn("test_column_1", F.lit("test_column_1"))
    .withColumn("birth_year", F.split(F.col("birthday"), "/")[2])
)

# Print the physical plan, then a small untruncated sample.
df1.explain()
df1.show(5, truncate=False)

In [None]:
# Second derivative of df_base: a constant marker column plus the second
# "/"-separated piece of `birthday`.
# NOTE(review): assumes the month is the middle component of birthday — confirm.
df2 = (
    df_base
    .withColumn("test_column_2", F.lit("test_column_2"))
    .withColumn("birth_month", F.split(F.col("birthday"), "/")[1])
)

# Print the physical plan, then a small untruncated sample.
df2.explain()
df2.show(5, truncate=False)

# With cache

In [None]:
# Mark df_base as cached so the df1/df2 cells that follow can reuse it.
# NOTE(review): cache() is lazy in Spark — data is only stored once an action
# (e.g. show/count) runs on df_base or a DataFrame derived from it; confirm
# whether an explicit materializing action is wanted here.
df_base.cache()

In [None]:
# Rebuild df1 on top of df_base after cache() has been called on it: a constant
# marker column plus the third "/"-separated piece of `birthday`.
# NOTE(review): assumes birthday is formatted with the year last (e.g. d/m/yyyy) — confirm.
df1 = (
    df_base
    .withColumn("test_column_1", F.lit("test_column_1"))
    .withColumn("birth_year", F.split(F.col("birthday"), "/")[2])
)

# Print the plan (expected to reference the cached relation — verify in the
# output), then a small untruncated sample.
df1.explain()
df1.show(5, truncate=False)

In [None]:
# Rebuild df2 on top of df_base after cache() has been called on it: a constant
# marker column plus the second "/"-separated piece of `birthday`.
# NOTE(review): assumes the month is the middle component of birthday — confirm.
df2 = (
    df_base
    .withColumn("test_column_2", F.lit("test_column_2"))
    .withColumn("birth_month", F.split(F.col("birthday"), "/")[1])
)

# Print the plan (expected to reference the cached relation — verify in the
# output), then a small untruncated sample.
df2.explain()
df2.show(5, truncate=False)

In [None]:
# Shut down the Spark session now that the demo is finished.
spark.stop()