### Task 1

Based on the sales dataframe used above, create a new dataframe that has the following fields:
* order_id
* email
* purchase_revenue_in_usd
* value_added_tax

value_added_tax should be computed by a UDF as follows:
* if purchase_revenue_in_usd > 1500: tax = purchase_revenue_in_usd * 0.2
* else: tax = purchase_revenue_in_usd * 0.1

The dataframe should be filtered to only include rows where value_added_tax is above 110.

The dataframe should be returned sorted in descending order of the value_added_tax column.
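
The rule, the filter, and the sort can be checked on a few made-up rows in plain Python before the function is wrapped in a UDF. This is only a sketch: the sample rows and the helper name `apply_vat_rules` are hypothetical, and passing `None` through unchanged is an assumption about how missing revenue should be treated.

```python
def value_added_tax(revenue):
    """Return 20% tax above 1500, otherwise 10%; None stays None (assumption)."""
    if revenue is None:
        return None
    return revenue * 0.2 if revenue > 1500 else revenue * 0.1


def apply_vat_rules(rows, threshold=110):
    """Add the tax, keep rows with tax above the threshold, sort descending by tax."""
    with_tax = [
        {**row, "value_added_tax": value_added_tax(row["purchase_revenue_in_usd"])}
        for row in rows
    ]
    kept = [
        row for row in with_tax
        if row["value_added_tax"] is not None and row["value_added_tax"] > threshold
    ]
    return sorted(kept, key=lambda row: row["value_added_tax"], reverse=True)


# Hypothetical sample rows, not taken from the sales dataset.
sample_rows = [
    {"order_id": 1, "email": "a@example.com", "purchase_revenue_in_usd": 2000.0},
    {"order_id": 2, "email": "b@example.com", "purchase_revenue_in_usd": 1500.0},
    {"order_id": 3, "email": "c@example.com", "purchase_revenue_in_usd": 1000.0},
    {"order_id": 4, "email": "d@example.com", "purchase_revenue_in_usd": 1200.0},
]

result = apply_vat_rules(sample_rows)
print([row["order_id"] for row in result])  # prints [1, 2, 4]
```

Order 2 sits exactly on the 1500 boundary, so it falls in the 10% branch (tax of about 150). Order 3 has a tax of about 100 and is dropped by the filter.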

In [0]:
from pyspark.sql.types import DoubleType
import pyspark.sql.functions as F

sales_df = spark.read.parquet("/mnt/training/ecommerce/sales/sales.parquet")

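def value_added_tax(revenue):
  # Null revenue would raise a TypeError on comparison, so pass it through as null.
  if revenue is None:
    return None
  if revenue > 1500:
    return revenue * 0.2
  else:
    return revenue * 0.1

# Use F.udf explicitly so the cell does not depend on a notebook-provided udf name.
vat = F.udf(value_added_tax, DoubleType())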

# Keep the requested columns, add the tax, then filter and sort on it.
vat_df = (sales_df.select("order_id", "email", "purchase_revenue_in_usd")
  .withColumn("value_added_tax", vat("purchase_revenue_in_usd"))
  .filter("value_added_tax > 110")
  .orderBy(F.desc("value_added_tax")))

display(vat_df)


### Task 2

Join the events dataframe from above with the zips dataset at path:<br>
/mnt/training/zips.json

The join should be done on city and state. Note that the join compares strings case-sensitively, so convert the city and state columns on both sides to the same case before the join.
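
Why the case matters can be seen in a small plain-Python sketch of an inner join on (city, state). The records, the `zip` field, and the helper name `join_on_city_state` are hypothetical and only illustrate the idea; they are not taken from the events or zips datasets.

```python
# Hypothetical sample records, not taken from the events or zips datasets.
events = [
    {"user_id": "U1", "city": "Springfield", "state": "ma"},
    {"user_id": "U2", "city": "Riverton", "state": "MA"},
]
zips = [
    {"city": "SPRINGFIELD", "state": "MA", "zip": "00001"},
    {"city": "RIVERTON", "state": "MA", "zip": "00002"},
]


def join_on_city_state(left, right, normalize):
    """Inner join two lists of dicts on (city, state) after normalizing the keys."""
    index = {}
    for row in right:
        key = (normalize(row["city"]), normalize(row["state"]))
        index.setdefault(key, []).append(row)
    joined = []
    for row in left:
        key = (normalize(row["city"]), normalize(row["state"]))
        for match in index.get(key, []):
            joined.append({"user_id": row["user_id"], "zip": match["zip"]})
    return joined


# Without normalization the differently cased keys never match.
raw_matches = join_on_city_state(events, zips, normalize=lambda s: s)
# Upper-casing both sides makes the keys comparable.
upper_matches = join_on_city_state(events, zips, normalize=str.upper)

print(len(raw_matches), len(upper_matches))  # prints 0 2
```

Upper case is an arbitrary choice here; lower-casing both sides would work equally well, as long as the same transformation is applied to both inputs.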

Return a dataframe which has the following columns:
* user_id
* latitude
* longitude

Latitude is element 2 in the "loc" column of the zips dataset (positions are counted from 1)<br>
Longitude is element 1 in the "loc" column of the zips dataset
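
Spark's `element_at` counts array positions from 1, whereas Python lists are indexed from 0, which is easy to mix up. The sketch below mimics the 1-based lookup in plain Python; the helper and the sample `loc` value are hypothetical, and returning `None` for an out-of-range position is a simplification made for this example.

```python
def element_at(values, position):
    """Return the element at a 1-based position, or None if it is out of range."""
    if position < 1 or position > len(values):
        return None
    return values[position - 1]


# Hypothetical loc value in [longitude, latitude] order, as the task text implies.
loc = [-72.5, 42.1]

latitude = element_at(loc, 2)   # element 2, i.e. loc[1] in Python
longitude = element_at(loc, 1)  # element 1, i.e. loc[0] in Python

print(latitude, longitude)  # prints 42.1 -72.5
```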

In [0]:
# The JSON reader infers the schema by default, so no extra reader option is needed.
events_df = spark.read.json("/mnt/training/ecommerce/events/events-500k.json")

zips_df = spark.read.json("/mnt/training/zips.json")

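# Upper-case city and state on both sides so the case-sensitive join keys match.
result_df = (events_df
  .select("user_id",
          F.upper("geo.city").alias("city"),
          F.upper("geo.state").alias("state"))
  .join(
    zips_df.select(F.upper("city").alias("city"),
                   F.upper("state").alias("state"),
                   # element_at is 1-based: position 2 is latitude, position 1 is longitude
                   F.element_at(F.col("loc"), 2).alias("latitude"),
                   F.element_at(F.col("loc"), 1).alias("longitude")),
    ["city", "state"])
  # The join keys are not part of the requested output.
  .drop("city", "state"))

display(result_df)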