In [0]:
from datetime import datetime

from pyspark import SparkConf
from pyspark.sql.functions import avg, coalesce, col, lit, round, sum
from pyspark.sql.session import SparkSession
from pyspark.sql.types import DateType, IntegerType, StructField, StructType

# Local SparkSession with 2 worker threads; every cell below uses `spark`.
spark = (
    SparkSession.builder
    .appName("app")
    .master("local[2]")
    .getOrCreate()
)

In [0]:
# Price periods: each row is the unit `price` of a product between
# `start_date` and `end_date`. All columns are declared non-nullable.
schema = StructType([
    StructField(field_name, field_type, False)
    for field_name, field_type in (
        ("product_id", IntegerType()),
        ("start_date", DateType()),
        ("end_date", DateType()),
        ("price", IntegerType()),
    )
])

data = [
    (1, datetime(2019, 2, 17), datetime(2019, 2, 28), 5),
    (1, datetime(2019, 3, 1), datetime(2019, 3, 22), 20),
    (2, datetime(2019, 2, 1), datetime(2019, 2, 20), 15),
    (2, datetime(2019, 2, 21), datetime(2019, 3, 31), 30),
]

prices = spark.createDataFrame(data, schema)
prices.show()

+----------+----------+----------+-----+
|product_id|start_date|  end_date|price|
+----------+----------+----------+-----+
|         1|2019-02-17|2019-02-28|    5|
|         1|2019-03-01|2019-03-22|   20|
|         2|2019-02-01|2019-02-20|   15|
|         2|2019-02-21|2019-03-31|   30|
+----------+----------+----------+-----+



In [0]:
# Sales: how many `units` of a product were bought on `purchase_date`.
# All columns are declared non-nullable.
schema = (
    StructType()
    .add("product_id", IntegerType(), False)
    .add("purchase_date", DateType(), False)
    .add("units", IntegerType(), False)
)

data = [
    (1, datetime(2019, 2, 25), 100),
    (1, datetime(2019, 3, 1), 15),
    (2, datetime(2019, 2, 10), 200),
    (2, datetime(2019, 3, 22), 30),
]

sale = spark.createDataFrame(data, schema)
sale.show()

+----------+-------------+-----+
|product_id|purchase_date|units|
+----------+-------------+-----+
|         1|   2019-02-25|  100|
|         1|   2019-03-01|   15|
|         2|   2019-02-10|  200|
|         2|   2019-03-22|   30|
+----------+-------------+-----+



In [0]:
# Write a solution to find the average selling price for each product.
# average_price should be rounded to 2 decimal places.
# Return the result table in any order. There can be products that were never sold.
#
# A sale is priced by the period whose [start_date, end_date] contains its
# purchase_date (`between` is inclusive). The right join keeps every row of
# `prices`, so products with no matching sale are still present, but their
# `sale` columns are null. Therefore:
#   * group on prices.product_id (never null), not sale.product_id, so unsold
#     products don't collapse into a single null-keyed group;
#   * coalesce the resulting null average to 0, matching the SQL version.
# (The former `range_join` hint was attached to a DataFrame that never took
# part in the join, so it had no effect; it is also Databricks-specific.)
sale.join(
    prices,
    (sale.product_id == prices.product_id)
    & sale.purchase_date.between(prices.start_date, prices.end_date),
    "right",
) \
    .withColumn("sale_amount", sale.units * prices.price) \
    .groupBy(prices.product_id) \
    .agg(
        coalesce(round(sum("sale_amount") / sum("units"), 2), lit(0)).alias("average_price")
    ) \
    .show()

+----------+-------------+
|product_id|average_price|
+----------+-------------+
|         1|         6.96|
|         2|        16.96|
+----------+-------------+



In [0]:
# Same computation in Spark SQL. `sale` is registered as view "s" and `prices`
# as view "p". The right join keeps every priced product, and coalesce maps
# the null average of a never-sold product to 0.
for view_name, frame in (("s", sale), ("p", prices)):
    frame.createOrReplaceTempView(view_name)

average_price_sql = """select p.product_id, coalesce(round(sum(p.price*s.units)/sum(s.units),2),0) as average_price 
                from  s right join  
                p 
                on s.product_id=p.product_id and s.purchase_date between p.start_date and p.end_date 
                group by p.product_id"""
spark.sql(average_price_sql).show()

+----------+-------------+
|product_id|average_price|
+----------+-------------+
|         1|         6.96|
|         2|        16.96|
+----------+-------------+



In [0]:
# Stop the SparkSession now that both solutions have been shown; `spark` and the
# DataFrames built from it should not be used after this cell.
spark.stop()