In [0]:
from pyspark import SparkConf
from pyspark.sql import functions as F
from pyspark.sql.functions import col, count, sum
from pyspark.sql.session import SparkSession
from pyspark.sql.types import FloatType, IntegerType, StructField, StructType
from pyspark.sql.window import Window

# Local-mode Spark session using 2 threads ("local[2]"); getOrCreate() hands back
# an already-active session instead of building a second one.
spark = (
    SparkSession.builder
    .appName("app")
    .master("local[2]")
    .getOrCreate()
)

In [0]:
# Toy Insurance table: every column is a non-nullable integer.
schema = StructType(
    [
        StructField(column_name, IntegerType(), False)
        for column_name in ("pid", "tiv_2015", "tiv_2016", "lat", "lon")
    ]
)
# One tuple per policyholder, in schema order: (pid, tiv_2015, tiv_2016, lat, lon).
data = [
    (1, 10, 5, 10, 10),
    (2, 20, 20, 20, 20),
    (3, 10, 30, 20, 20),
    (4, 10, 40, 40, 40),
]
insurance = spark.createDataFrame(data, schema)
insurance.show()

+---+--------+--------+---+---+
|pid|tiv_2015|tiv_2016|lat|lon|
+---+--------+--------+---+---+
|  1|      10|       5| 10| 10|
|  2|      20|      20| 20| 20|
|  3|      10|      30| 20| 20|
|  4|      10|      40| 40| 40|
+---+--------+--------+---+---+



In [0]:
# Write a solution to report the sum of all total investment values in 2016 (tiv_2016) for all policyholders who:
#   * have the same tiv_2015 value as one or more other policyholders, and
#   * are not located in the same city as any other policyholder (i.e., the (lat, lon) attribute pairs must be unique).
# Round tiv_2016 to two decimal places.
#
# Approach: two window counts over `insurance`, then filter and aggregate.
#   * same_tiv_2015_count - rows sharing this row's tiv_2015. It includes the row itself,
#     so > 1 means at least one *other* policyholder has the same value.
#   * same_location_count - rows sharing this row's (lat, lon); == 1 means the location is unique.
# F.sum / F.round are used via the F namespace so Python's builtin sum/round are not shadowed.
window_spec_policyholder = Window.partitionBy("tiv_2015")
window_spec_address = Window.partitionBy("lat", "lon")
(
    insurance
    .select(
        "pid",
        "tiv_2016",
        F.count("pid").over(window_spec_policyholder).alias("same_tiv_2015_count"),
        F.count("pid").over(window_spec_address).alias("same_location_count"),
    )
    .filter((F.col("same_tiv_2015_count") > 1) & (F.col("same_location_count") == 1))
    # The spec requires rounding to two decimals; the unrounded sum did not meet it.
    .select(F.round(F.sum("tiv_2016"), 2).alias("tiv_2016"))
    .show()
)

+--------+
|tiv_2016|
+--------+
|      45|
+--------+



In [0]:
# Same query in Spark SQL: window counts in a CTE, then filter and aggregate.
#   * same_policy  - rows sharing this row's tiv_2015 (including itself), so > 1 means "shared".
#   * same_address - rows sharing this row's (lat, lon), so = 1 means the location is unique.
# The final sum is rounded to two decimals, as the problem statement requires.
insurance.createOrReplaceTempView("ins")
spark.sql("""
    with cte as (
        select tiv_2016,
               count(*) over (partition by tiv_2015) as same_policy,
               count(*) over (partition by lat, lon) as same_address
        from ins
    )
    select round(sum(tiv_2016), 2) as tiv_2016
    from cte
    where same_policy > 1 and same_address = 1
""").show()

+--------+
|tiv_2016|
+--------+
|      45|
+--------+



In [0]:
# Shut down the Spark session created at the top of the notebook; keep this as the last cell.
spark.stop()