In [19]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import lit, array_contains, sum, collect_list, array, col, explode, struct

In [2]:
spark = (
    SparkSession.builder.master("local[2]")
    .appName("local-tests")
    .config("spark.jars.packages", "graphframes:graphframes:0.8.2-spark3.2-s_2.12")
    .getOrCreate()
)
from graphframes import *

df = spark.createDataFrame(
    [
        ("id1", "anid1", None, "userIDFA1", "SapphireId1", 0.1),
        ("id2", "anid1", "AdId1", None, None, 0.2),
        ("id3", "anid2", "AdId1", None, "SapphireId2", 0.3),
        ("id4", "anid3", None, None, "SapphireId2", 0.4),
        ("id5", "anid4", None, "userIDFA2", None, 0.5),
        ("id6", "anid5", None, "userIDFA2", "SapphireId3", 0.6),
        ("id7", "anid5", None, "userIDFA2", "SapphireId3", 0.7),
    ],
    ["id", "anid", "AdId", "userIDFA", "SapphireId", "Cashback"],
)

id_cols = ["anid", "AdId", "userIDFA", "SapphireId"]

In [3]:
def melt(
    df, id_vars, value_vars, var_name="variable", value_name="value", dropna=False
):
    # Create an array of column names for the id variables
    id_cols = [col(c) for c in id_vars]

    # Create an array of column names for the value variables
    value_cols = [
        struct(lit(c).alias(var_name), col(c).alias(value_name)) for c in value_vars
    ]

    # Explode the value columns into rows
    exploded_df = df.select(id_cols + [explode(array(value_cols)).alias("tmp")])

    # Extract the variable and value columns from the struct
    result_df = exploded_df.select(
        id_cols
        + [
            col("tmp")[var_name].alias(var_name),
            col("tmp")[value_name].alias(value_name),
        ]
    )

    # Drop null values if dropna is True
    if dropna:
        result_df = result_df.dropna()

    return result_df

In [12]:
def createGraphFrame(df, id_vars, value_vars, src_col="src", dst_col="dst"):
    melted_df = melt(
        df, id_vars, value_vars, var_name="Type", value_name="value", dropna=True
    )
    src = melted_df.select(*id_vars, col("value").alias(src_col)).alias("src")
    dst = melted_df.select(*id_vars, col("value").alias(dst_col)).alias("dst")
    edges_df = (
        src.join(dst, on=id_vars)
        .where(col(src_col) != col(dst_col))
        .select(src_col, dst_col, *id_vars)
    )
    v = melted_df.select(col("value").alias("id"), "Type").distinct()
    g = GraphFrame(v, edges_df)
    return g


g = createGraphFrame(df, ["id", "Cashback"], id_cols)
g.vertices.show()
g.edges.show()

+-----------+----------+
|         id|      Type|
+-----------+----------+
|      AdId1|      AdId|
|      anid1|      anid|
|SapphireId1|SapphireId|
|      anid2|      anid|
|  userIDFA1|  userIDFA|
|SapphireId2|SapphireId|
|  userIDFA2|  userIDFA|
|      anid3|      anid|
|SapphireId3|SapphireId|
|      anid4|      anid|
|      anid5|      anid|
+-----------+----------+

+-----------+-----------+---+--------+
|        src|        dst| id|Cashback|
+-----------+-----------+---+--------+
|      anid1|  userIDFA1|id1|     0.1|
|      anid1|SapphireId1|id1|     0.1|
|  userIDFA1|      anid1|id1|     0.1|
|  userIDFA1|SapphireId1|id1|     0.1|
|SapphireId1|      anid1|id1|     0.1|
|SapphireId1|  userIDFA1|id1|     0.1|
|      anid1|      AdId1|id2|     0.2|
|      AdId1|      anid1|id2|     0.2|
|      anid2|      AdId1|id3|     0.3|
|      anid2|SapphireId2|id3|     0.3|
|      AdId1|      anid2|id3|     0.3|
|      AdId1|SapphireId2|id3|     0.3|
|SapphireId2|      anid2|id3|     0.3|


In [13]:
spark.sparkContext.setCheckpointDir(".spark_checkpoint")
type_id_to_component = g.connectedComponents()

component_df_grouped = (
    type_id_to_component.groupBy("component").pivot("Type").agg(collect_list("id"))
)
component_df_grouped = component_df_grouped.toDF(
    *["component_" + c if c != "component" else c for c in component_df_grouped.columns]
)

In [18]:
def join_component_with_transactions(component_df_grouped, transactions_df, id_cols):
    # Join component_df_grouped with transactions_df on matching ids
    for col in id_cols:
        join_condition = array_contains(
            component_df_grouped["component_" + col], transactions_df[col]
        )
    joined_df = component_df_grouped.alias("component_df").join(
        transactions_df.alias("transactions_df"), join_condition, "left_outer"
    )
    return joined_df


test = join_component_with_transactions(component_df_grouped, df, id_cols)
test.show()

+------------+--------------+--------------------+--------------------+------------------+---+-----+-----+---------+-----------+--------+
|   component|component_AdId|component_SapphireId|      component_anid|component_userIDFA| id| anid| AdId| userIDFA| SapphireId|Cashback|
+------------+--------------+--------------------+--------------------+------------------+---+-----+-----+---------+-----------+--------+
|  8589934592|            []|       [SapphireId3]|      [anid4, anid5]|       [userIDFA2]|id6|anid5| null|userIDFA2|SapphireId3|     0.6|
|  8589934592|            []|       [SapphireId3]|      [anid4, anid5]|       [userIDFA2]|id7|anid5| null|userIDFA2|SapphireId3|     0.7|
|369367187456|       [AdId1]|[SapphireId1, Sap...|[anid1, anid2, an...|       [userIDFA1]|id1|anid1| null|userIDFA1|SapphireId1|     0.1|
|369367187456|       [AdId1]|[SapphireId1, Sap...|[anid1, anid2, an...|       [userIDFA1]|id3|anid2|AdId1|     null|SapphireId2|     0.3|
|369367187456|       [AdId1]|[Sapp

In [15]:
component_level_features = test.groupBy("component").agg(
    sum(col("Cashback")).alias("totalCashback")
)

In [16]:
component_df_grouped.join(component_level_features, on="component").show()

+------------+--------------+--------------------+--------------------+------------------+------------------+
|   component|component_AdId|component_SapphireId|      component_anid|component_userIDFA|     totalCashback|
+------------+--------------+--------------------+--------------------+------------------+------------------+
|  8589934592|            []|       [SapphireId3]|      [anid4, anid5]|       [userIDFA2]|1.2999999999999998|
|369367187456|       [AdId1]|[SapphireId1, Sap...|[anid1, anid2, an...|       [userIDFA1]|               0.8|
+------------+--------------+--------------------+--------------------+------------------+------------------+

