In [1]:
from pyspark.sql import SparkSession
from graphframes import GraphFrame
from receiptprocessor.graphframes.connected_components import ConnectedComponents


In [2]:
# Local two-core Spark session; the GraphFrames package is fetched through
# `spark.jars.packages`.
spark = (
    SparkSession.builder
    .master("local[2]")
    .appName("local-tests")
    .config(
        key="spark.jars.packages",
        value="graphframes:graphframes:0.8.2-spark3.2-s_2.12",
    )
    .getOrCreate()
)

# NOTE(review): presumably needed by the connected-components step below
# (GraphFrames' default algorithm checkpoints) — confirm.
spark.sparkContext.setCheckpointDir(".spark_checkpoint")

In [3]:
def cleaned_data(spark_session):
    """Build the small in-memory fixture DataFrame used by this notebook.

    The fixture has nine rows keyed by ``id``. ``anid``, ``AdId``, ``userIDFA``
    and ``SapphireId`` are string columns that may be ``None``, and
    ``Cashback`` is a float.

    Args:
        spark_session: Object whose ``createDataFrame(rows, columns)`` builds
            the result (a ``SparkSession`` in this notebook).

    Returns:
        Whatever ``spark_session.createDataFrame`` returns for the fixture rows.
    """
    column_names = ["id", "anid", "AdId", "userIDFA", "SapphireId", "Cashback"]
    rows = [
        # id     anid     AdId     userIDFA     SapphireId     Cashback
        ("id1", "anid1", None, "userIDFA1", "SapphireId1", 0.1),
        ("id2", "anid1", "AdId1", None, None, 0.2),
        ("id3", "anid2", "AdId1", None, "SapphireId2", 0.3),
        ("id4", "anid3", None, None, "SapphireId2", 0.4),
        ("id5", "anid4", None, "userIDFA2", None, 0.5),
        ("id6", "anid5", None, "userIDFA2", "SapphireId3", 0.6),
        ("id7", "anid5", None, "userIDFA2", "SapphireId3", 0.7),
        ("id8", "anid5", None, "userIDFA2", "SapphireId3", 0.8),
        ("id9", "anid6", None, None, None, 0.9),
    ]
    return spark_session.createDataFrame(rows, column_names)

# Materialize the fixture and print it so the input to the graph step is visible.
df = cleaned_data(spark_session=spark)
df.show()

+---+-----+-----+---------+-----------+--------+
| id| anid| AdId| userIDFA| SapphireId|Cashback|
+---+-----+-----+---------+-----------+--------+
|id1|anid1| null|userIDFA1|SapphireId1|     0.1|
|id2|anid1|AdId1|     null|       null|     0.2|
|id3|anid2|AdId1|     null|SapphireId2|     0.3|
|id4|anid3| null|     null|SapphireId2|     0.4|
|id5|anid4| null|userIDFA2|       null|     0.5|
|id6|anid5| null|userIDFA2|SapphireId3|     0.6|
|id7|anid5| null|userIDFA2|SapphireId3|     0.7|
|id8|anid5| null|userIDFA2|SapphireId3|     0.8|
|id9|anid6| null|     null|       null|     0.9|
+---+-----+-----+---------+-----------+--------+



In [4]:
# NOTE(review): `edge_columns` presumably become edge attributes — confirm
# against ConnectedComponents.
edge_columns = ["id", "Cashback"]
# Values of these columns become graph vertices; the output's `Type` column
# records which of these columns each vertex came from.
# (`vertice_columns` is the keyword name ConnectedComponents accepts.)
vertice_columns = ["anid", "AdId", "SapphireId", "userIDFA"]
connected_components_transformer = ConnectedComponents(
    edge_columns=edge_columns,
    vertice_columns=vertice_columns,
)

In [5]:
# Run connected components over the fixture and show one row per vertex
# (columns: id, Type, component).
components_df = connected_components_transformer.transform(df)
components_df.show()

# Verify the grouping instead of relying on eyeballing the table. The raw
# component ids are opaque labels, so compare which vertices share a component
# rather than the id values. Expected groups follow from the fixture rows:
#   id1/id2 share anid1, id2/id3 share AdId1, id3/id4 share SapphireId2;
#   id5..id8 share userIDFA2; id9 only has anid6.
vertices_by_component = {}
for vertex in components_df.select("id", "component").collect():
    vertices_by_component.setdefault(vertex["component"], set()).add(vertex["id"])

expected_groups = {
    frozenset({"anid1", "anid2", "anid3", "AdId1", "userIDFA1", "SapphireId1", "SapphireId2"}),
    frozenset({"anid4", "anid5", "userIDFA2", "SapphireId3"}),
    frozenset({"anid6"}),
}
actual_groups = {frozenset(members) for members in vertices_by_component.values()}
assert actual_groups == expected_groups, actual_groups

Wrote _tmp_graphframe/vertices to: _tmp_graphframe/vertices-0fd6b7d8-6f28-41e6-992c-9f827a708c90
Wrote _tmp_graphframe/edges to: _tmp_graphframe/edges-94f11e8f-f887-4543-a88b-bcda0beffec5
+-----------+----------+------------+
|         id|      Type|   component|
+-----------+----------+------------+
|      AdId1|      AdId|369367187456|
|      anid3|      anid|369367187456|
|      anid1|      anid|369367187456|
|SapphireId1|SapphireId|369367187456|
|      anid2|      anid|369367187456|
|  userIDFA1|  userIDFA|369367187456|
|SapphireId2|SapphireId|369367187456|
|  userIDFA2|  userIDFA|  8589934592|
|SapphireId3|SapphireId|  8589934592|
|      anid6|      anid|111669149696|
|      anid4|      anid|  8589934592|
|      anid5|      anid|  8589934592|
+-----------+----------+------------+

