The problem statement is as follows:

We are working with two tables, polls and poll_answers, which contain the individual poll responses and the correct answer for each poll, respectively. The goal is to calculate the winnings of users who selected the correct poll option: within each poll, the total amount wagered on incorrect options is split among the winners in proportion to the amount each winner wagered, rounded down to a whole number.
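To make the payout rule concrete, here is a minimal plain-Python sketch of the arithmetic for poll p1 from the sample data below, where option C is correct. The helper name winnings is hypothetical and is not part of the notebook; it assumes integer amounts, as in the sample data.

```python
def winnings(amount, winners_total, losers_total):
    """Share of the losers' pool for one winning wager, rounded down."""
    # Integer arithmetic gives the floor directly, with no float rounding.
    return losers_total * amount // winners_total


# Poll p1: wagers on the correct option C, keyed by user_id.
winner_amounts = {"id2": 250, "id5": 50, "id7": 200}
# Poll p1: wagers on the incorrect options A, B and D.
losers_total = 200 + 200 + 500 + 500 + 100
winners_total = sum(winner_amounts.values())

p1_winnings = {
    user: winnings(amount, winners_total, losers_total)
    for user, amount in winner_amounts.items()
}
print(p1_winnings)  # {'id2': 750, 'id5': 150, 'id7': 600}
```

The losers' pool for p1 is 1500 and the winners' pool is 500, so each winner receives three times their own wager.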

polls Table:

Contains information about individual responses to various polls.
Columns:
user_id: The ID of the user who participated in the poll.
poll_id: The ID of the poll.
poll_option_id: The option chosen by the user.
amount: The amount wagered by the user.
created_date: The date when the poll response was submitted.

poll_answers Table:

Contains the correct option for each poll.
Columns:
poll_id: The ID of the poll.
correct_option_id: The correct option for the poll.

In [0]:
from datetime import datetime
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DateType


In [0]:
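# Define schema for the polls DataFrame
polls_schema = StructType([
    StructField("user_id", StringType(), True),
    StructField("poll_id", StringType(), True),
    StructField("poll_option_id", StringType(), True),
    StructField("amount", IntegerType(), True),
    StructField("created_date", DateType(), True)
])

# Sample data for the polls DataFrame with proper date conversion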
polls_data = [
    ('id1', 'p1', 'A', 200, datetime.strptime('2021-12-01', '%Y-%m-%d').date()),
    ('id2', 'p1', 'C', 250, datetime.strptime('2021-12-01', '%Y-%m-%d').date()),
    ('id3', 'p1', 'A', 200, datetime.strptime('2021-12-01', '%Y-%m-%d').date()),
    ('id4', 'p1', 'B', 500, datetime.strptime('2021-12-01', '%Y-%m-%d').date()),
    ('id5', 'p1', 'C', 50, datetime.strptime('2021-12-01', '%Y-%m-%d').date()),
    ('id6', 'p1', 'D', 500, datetime.strptime('2021-12-01', '%Y-%m-%d').date()),
    ('id7', 'p1', 'C', 200, datetime.strptime('2021-12-01', '%Y-%m-%d').date()),
    ('id8', 'p1', 'A', 100, datetime.strptime('2021-12-01', '%Y-%m-%d').date()),
    ('id9', 'p2', 'A', 300, datetime.strptime('2023-01-10', '%Y-%m-%d').date()),
    ('id10', 'p2', 'C', 400, datetime.strptime('2023-01-11', '%Y-%m-%d').date()),
    ('id11', 'p2', 'B', 250, datetime.strptime('2023-01-12', '%Y-%m-%d').date()),
    ('id12', 'p2', 'D', 600, datetime.strptime('2023-01-13', '%Y-%m-%d').date()),
    ('id13', 'p2', 'C', 150, datetime.strptime('2023-01-14', '%Y-%m-%d').date()),
    ('id14', 'p2', 'A', 100, datetime.strptime('2023-01-15', '%Y-%m-%d').date()),
    ('id15', 'p2', 'C', 200, datetime.strptime('2023-01-16', '%Y-%m-%d').date())
]
# Create polls DataFrame
polls_df = spark.createDataFrame(polls_data, schema=polls_schema)
polls_df.display()


user_id,poll_id,poll_option_id,amount,created_date
id1,p1,A,200,2021-12-01
id2,p1,C,250,2021-12-01
id3,p1,A,200,2021-12-01
id4,p1,B,500,2021-12-01
id5,p1,C,50,2021-12-01
id6,p1,D,500,2021-12-01
id7,p1,C,200,2021-12-01
id8,p1,A,100,2021-12-01
id9,p2,A,300,2023-01-10
id10,p2,C,400,2023-01-11


In [0]:
# Define schema for the poll_answers DataFrame
poll_answers_schema = StructType([
    StructField("poll_id", StringType(), True),
    StructField("correct_option_id", StringType(), True)
])

# Sample data for the poll_answers DataFrame
poll_answers_data = [
    ('p1', 'C'),
    ('p2', 'A')
]

# Create poll_answers DataFrame
poll_answers_df = spark.createDataFrame(poll_answers_data, schema=poll_answers_schema)
poll_answers_df.display()

poll_id,correct_option_id
p1,C
p2,A


In [0]:
poll_answers_df.createOrReplaceTempView('poll_answers')
polls_df.createOrReplaceTempView('polls')

In [0]:
%sql
WITH joined AS (
    SELECT
        a.user_id,
        a.poll_id,
        b.correct_option_id,
        -- Losers' pool for the poll times this winner's share of the winners' pool
        FLOOR(
            SUM(CASE WHEN b.correct_option_id IS NULL THEN a.amount END) OVER (PARTITION BY a.poll_id)
            *
            CASE WHEN b.correct_option_id IS NOT NULL THEN
                (a.amount * 1.0 / SUM(CASE WHEN b.correct_option_id IS NOT NULL THEN a.amount END) OVER (PARTITION BY a.poll_id) * 100)
            END / 100
        ) AS amount_win
    FROM polls a
    -- correct_option_id is non-NULL only for rows whose option matches the poll's answer
    LEFT JOIN poll_answers b
        ON a.poll_id = b.poll_id AND b.correct_option_id = a.poll_option_id
)
SELECT user_id, poll_id, amount_win
FROM joined
WHERE correct_option_id IS NOT NULL



user_id,poll_id,amount_win
id2,p1,750
id5,p1,150
id7,p1,600
id9,p2,1200
id14,p2,400


In [0]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Assuming polls_df and poll_answers_df are already created

# Step 1: Perform LEFT JOIN on poll_id and poll_option_id
# Rename columns to avoid ambiguity after the join
joined_df = polls_df.alias("a").join(
    poll_answers_df.alias("b"),
    (F.col("a.poll_id") == F.col("b.poll_id")) & (F.col("a.poll_option_id") == F.col("b.correct_option_id")),
    "left"
).select(
    F.col("a.user_id"),
    F.col("a.poll_id").alias("poll_id"),
    F.col("a.poll_option_id"),
    F.col("a.amount"),
    F.col("b.correct_option_id")
)

# Step 2: Define window specification to partition by poll_id
window_spec = Window.partitionBy("poll_id")

# Step 3: Calculate total losers amount and winners amount using conditional aggregation
result_df = joined_df.withColumn(
    "total_losers_amount", 
    F.sum(F.when(F.col("correct_option_id").isNull(), F.col("amount")).otherwise(0)).over(window_spec)
).withColumn(
    "total_winners_amount", 
    F.sum(F.when(F.col("correct_option_id").isNotNull(), F.col("amount")).otherwise(0)).over(window_spec)
).withColumn(
    "amount_win", 
    F.floor(
        F.col("total_losers_amount") * (
            F.when(F.col("correct_option_id").isNotNull(), 
                   (F.col("amount") * 1.0) / F.col("total_winners_amount") * 100
            ).otherwise(0)
        ) / 100
    )
)

# Step 4: Filter for only those records where correct_option_id is not null
final_df = result_df.filter(F.col("correct_option_id").isNotNull()).select("user_id", "poll_id", "amount_win")

# Display the final result
final_df.display()


user_id,poll_id,amount_win
id2,p1,750
id5,p1,150
id7,p1,600
id9,p2,1200
id14,p2,400


Explanation:

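Join: A left join between polls_df and poll_answers_df on poll_id and on poll_option_id = correct_option_id. Because the option is part of the join condition, correct_option_id is non-NULL only for responses that chose the correct option.

Window Functions: Window.partitionBy("poll_id") makes the per-poll totals available on every row: the sum of amounts where correct_option_id is NULL (the losers' pool) and the sum where it is NOT NULL (the winners' pool).

Conditional Aggregation: The when function selects which amounts each sum includes, and restricts the payout expression inside floor to winning rows.

Filter: The final step keeps only the records where correct_option_id is NOT NULL, that is, the winners.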
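As a cross-check, the same four steps can be mirrored in plain Python on the sample data, without Spark. This is only a sketch for verifying the expected output by hand; the variable names are illustrative, and created_date is left out because the calculation does not use it.

```python
from collections import defaultdict

# (user_id, poll_id, poll_option_id, amount) from the sample data.
polls = [
    ("id1", "p1", "A", 200), ("id2", "p1", "C", 250), ("id3", "p1", "A", 200),
    ("id4", "p1", "B", 500), ("id5", "p1", "C", 50), ("id6", "p1", "D", 500),
    ("id7", "p1", "C", 200), ("id8", "p1", "A", 100), ("id9", "p2", "A", 300),
    ("id10", "p2", "C", 400), ("id11", "p2", "B", 250), ("id12", "p2", "D", 600),
    ("id13", "p2", "C", 150), ("id14", "p2", "A", 100), ("id15", "p2", "C", 200),
]
poll_answers = {"p1": "C", "p2": "A"}

# Join: keep correct_option_id only when the user's option matches it,
# otherwise None (the NULL produced by the left join).
joined = [
    (user, poll, amount, option if poll_answers.get(poll) == option else None)
    for user, poll, option, amount in polls
]

# Window + conditional aggregation: per-poll totals for losers and winners.
losers_total = defaultdict(int)
winners_total = defaultdict(int)
for _, poll, amount, correct in joined:
    if correct is None:
        losers_total[poll] += amount
    else:
        winners_total[poll] += amount

# Filter: only rows with a matched correct_option_id receive a payout.
result = [
    (user, poll, losers_total[poll] * amount // winners_total[poll])
    for user, poll, amount, correct in joined
    if correct is not None
]
for row in result:
    print(row)
```

The printed rows match the five rows returned by the SQL and PySpark versions.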

In [0]:
Final Output:
    
You need to output the following columns:

user_id: The ID of the user.
poll_id: The ID of the poll.
amount_win: The amount won by the user who selected the correct option.

Constraints:
Ensure that the calculation handles rows where correct_option_id is NULL after the join (indicating incorrect answers).
Avoid column name ambiguities when joining the two tables.
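The sample data does not exercise the edge cases behind the NULL constraint, so the following plain-Python sketch shows how the same logic behaves for them. The polls p3, p4 and p5, the users u1 to u6, and the helper payouts are all hypothetical. The guard against a winners' pool of zero is an added assumption, and the result of 0 when there are no losers follows the PySpark version's otherwise(0) (the SQL version would return NULL there).

```python
def payouts(rows, answers):
    """rows: (user_id, poll_id, option, amount); answers: poll_id -> correct option."""
    # A row wins only if its poll has an answer and the option matches it.
    # This plays the role of correct_option_id being non-NULL after the join.
    is_winner = [answers.get(poll) == option for _, poll, option, _ in rows]

    totals = {}  # poll_id -> [losers_total, winners_total]
    for (_, poll, _, amount), won in zip(rows, is_winner):
        totals.setdefault(poll, [0, 0])[int(won)] += amount

    return [
        (user, poll, totals[poll][0] * amount // totals[poll][1])
        for (user, poll, _, amount), won in zip(rows, is_winner)
        if won and totals[poll][1] > 0  # assumed guard: all winners wagered 0
    ]


# Hypothetical polls that are not in the sample data.
rows = [
    ("u1", "p3", "A", 100), ("u2", "p3", "B", 300),  # p3 has no recorded answer
    ("u3", "p4", "A", 100), ("u4", "p4", "B", 300),  # nobody chose p4's answer
    ("u5", "p5", "A", 100), ("u6", "p5", "A", 300),  # everyone chose p5's answer
]
answers = {"p4": "C", "p5": "A"}
print(payouts(rows, answers))  # [('u5', 'p5', 0), ('u6', 'p5', 0)]
```

Polls p3 and p4 produce no output rows because they have no winners, and the winners of p5 receive 0 because there is no losers' pool to share.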