In [1]:
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession, Row, DataFrame
from pyspark.sql.types import StructType, StructField, IntegerType, ArrayType
import pyspark.sql.functions as sf
import os


In [2]:
# Setup: one local SparkContext plus a SparkSession on top of it, shared by every cell below.
conf = SparkConf().setMaster("local").setAppName("DegreesOfSeparation")
# getOrCreate (rather than the SparkContext constructor) keeps this cell re-runnable:
# constructing a second SparkContext in the same kernel raises
# "Cannot run multiple SparkContexts at once".
sc = SparkContext.getOrCreate(conf=conf)

# The builder attaches to the context created above.
# NOTE(review): the existing context's app name presumably wins over 'HerosBreadthSearch' — confirm in the Spark UI.
spark = SparkSession.builder.appName('HerosBreadthSearch').getOrCreate()

In [None]:
# from enum import Enum
# from functools import total_ordering

# @total_ordering
# class ProcessStatus(Enum):
#     NOT_PROCESSED: str = 0
#     PROCESSING: str = 1
#     PROCESSED: str = 2
#     def __eq__(self, other) -> bool:
#         if isinstance(other, self.__class__):
#             return self.value == other.value
#         else:
#             return NotImplemented
#     def __lt__(self, other) -> bool:
#         if isinstance(other, self.__class__):
#             return self.value < other.value
#         else:
#             return NotImplemented

# ProcessStatus.PROCESSED > ProcessStatus.NOT_PROCESSED

In [5]:
# Schema of the raw graph: one row per input line, a hero id plus the ids it is connected to.
# (The misspelled name is kept because later cells reference it.)
heroSeprationSchema = StructType(
    [
        StructField(name='id', dataType=IntegerType(), nullable=False),
        StructField(
            name='connections',
            dataType=ArrayType(elementType=IntegerType(), containsNull=True),
            nullable=False,
        ),
    ]
)

def hero_connections_parse_line(line: str) -> Row:
    """Turn one line of the Marvel graph file into a Row.

    The first whitespace-separated token is a hero id; every following token is
    the id of a connected hero.

    Args:
        line (str): raw text line, e.g. "5988 748 1722 3752".

    Returns:
        Row: `id` (int) and `connections` (list of int, empty when the line has a single token).
    """
    tokens = line.split()
    hero_token, neighbour_tokens = tokens[0], tokens[1:]
    return Row(id=int(hero_token), connections=list(map(int, neighbour_tokens)))

def loadStartingDF(path: str = '../data/Marvel-Graph.txt') -> DataFrame:
    """Load the Marvel graph text file into a DataFrame.

    Every line is parsed with `hero_connections_parse_line` and the result follows
    `heroSeprationSchema` (`id`, `connections`). Ids are not unique across rows.

    Args:
        path (str): location of the graph file; a relative path is resolved against
            the current working directory. Defaults to '../data/Marvel-Graph.txt',
            the file this notebook has always used.

    Returns:
        DataFrame: one row per input line.
    """
    # Build a clean file:/// URI: normalise '..', use forward slashes (Windows drive
    # paths included) and avoid the 'file:////' the old f-string produced on POSIX.
    absolute = os.path.abspath(path).replace(os.sep, '/')
    uri = 'file:///' + absolute.lstrip('/')
    rdd = sc.textFile(uri).map(hero_connections_parse_line)
    return spark.createDataFrame(rdd, schema=heroSeprationSchema)

# Load the raw graph once and take a quick look at its schema and first rows.
df = loadStartingDF()
df.printSchema()
df.show(n=20, truncate=True)

root
 |-- id: integer (nullable = false)
 |-- connections: array (nullable = false)
 |    |-- element: integer (containsNull = true)

+----+--------------------+
|  id|         connections|
+----+--------------------+
|5988|[748, 1722, 3752,...|
|5989|[4080, 4264, 4446...|
|5982|[217, 595, 1194, ...|
|5983|[1165, 3836, 4361...|
|5980|[2731, 3712, 1587...|
|5981|[3569, 5353, 4087...|
|5986|[2658, 3712, 2650...|
|5987|[2614, 5716, 1765...|
|5984|[590, 4898, 745, ...|
|5985|[3233, 2254, 212,...|
|6294|[4898, 1127, 3220...|
| 270|[2658, 3003, 3805...|
| 271|[4935, 5716, 4309...|
| 272|[2717, 4363, 4088...|
| 273|[1165, 5013, 5110...|
| 274|[3920, 5310, 4024...|
| 275|[4366, 3373, 1587...|
| 276|[2277, 5251, 4806...|
| 277|[1068, 3495, 6194...|
| 278|[1145, 667, 2650,...|
+----+--------------------+
only showing top 20 rows



In [15]:
# get list of ids
# df.select(sf.collect_list('id')).first()[0]

In [12]:
def connections_by_hero(df: DataFrame) -> DataFrame:
    """Collapse repeated hero ids into one row per id.

    The input may list the same `id` on several rows; their `connections` arrays
    are merged into one array without duplicates.

    Note: `explode` emits no rows for an empty `connections` array, so an id whose
    rows all have empty arrays does not appear in the result.

    Args:
        df (DataFrame): frame with `id` and `connections` (array) columns; ids may repeat.

    Returns:
        DataFrame: columns `id` (unique) and `connections` (de-duplicated array,
        element order not guaranteed).
    """
    # one (id, connection) pair per array element
    pairs = df.select('id', sf.explode('connections').alias('connections'))
    # regroup per id; collect_set drops repeated connections
    merged = sf.collect_set('connections').alias('connections')
    return pairs.groupBy('id').agg(merged)


In [35]:
def firstIterationDf(startDf: DataFrame, startId: int) -> DataFrame:
    """Build the BFS state before the first iteration.

    Ids are made unique via `connections_by_hero`, and a `processStatus` column is
    added: 1 for the start hero (the initial frontier), 0 for everyone else.

    Args:
        startDf (DataFrame): raw graph with `id` and `connections` columns.
        startId (int): hero the search starts from.

    Returns:
        DataFrame: columns `id`, `connections`, `processStatus`.
    """
    heroes = connections_by_hero(startDf)
    is_start = heroes.id == startId
    return heroes.withColumn('processStatus', sf.when(is_start, 1).otherwise(0))

def countTargetHits(df: DataFrame, targetId: int) -> int:
    """Count the rows of `df` whose `connections` array contains `targetId`.

    Also prints how many connections (total array elements) were inspected. Both
    figures come from a single aggregation, so `df` is evaluated once instead of
    twice; this matters because the frame passed in is lazily recomputed on every
    Spark action.

    Args:
        df (DataFrame): frame with a `connections` array column (in this notebook,
            the current BFS frontier).
        targetId (int): hero id to look for.

    Returns:
        int: number of rows (heroes) that list `targetId` among their connections.
    """
    stats = df.agg(
        # count() skips the nulls when() yields for rows that don't contain the target
        sf.count(sf.when(sf.array_contains('connections', targetId), True)).alias('hits'),
        # summed array sizes == number of exploded elements; coalesce covers an empty df
        sf.coalesce(sf.sum(sf.size('connections')), sf.lit(0)).alias('inspected'),
    ).first()
    print(f"{stats['inspected']} connections were inspected.")
    return stats['hits']

def updateProcessStatus(df: DataFrame) -> DataFrame:
    """Advance the BFS frontier by one layer.

    Status codes: 0 = not reached yet, 1 = in the current frontier, 2 = processed.
    Frontier rows become 2. Heroes with status 0 that appear among the frontier's
    connections become 1. All other rows keep their status.

    Runs a Spark action (`first()`) to collect the next-layer ids on the driver.

    Args:
        df (DataFrame): columns `id` (unique), `connections`, `processStatus`.

    Returns:
        DataFrame: same columns with `processStatus` advanced one step.
    """
    in_frontier = df.processStatus == 1

    # Neighbours of the current frontier that have not been reached yet.
    frontier_neighbours = (
        df.where(in_frontier)
        .select(sf.explode('connections').alias('id'))
    )
    unreached = (
        frontier_neighbours
        .join(df.select('id', 'processStatus'), on='id')
        .where(sf.col('processStatus') == 0)
    )
    next_ids = unreached.select(sf.collect_set('id')).first()[0]
    joins_frontier = df.id.isin(next_ids)

    new_status = (
        sf.when(in_frontier, 2)          # current frontier -> processed
        .when(joins_frontier, 1)         # newly reached -> next frontier
        .otherwise(df.processStatus)     # everything else keeps its status
    )
    return df.select('id', 'connections', new_status.alias('processStatus'))

def degreeSeparation(startDf: DataFrame, startId: int, targetId: int, maxDistance: int = 10) -> int:
    """Breadth-first search for the degrees of separation between two heroes.

    Each iteration inspects the current frontier (rows with processStatus == 1).
    If any frontier hero lists `targetId` among its connections, the target is
    `distance` hops from `startId`; otherwise the frontier is advanced one layer
    with `updateProcessStatus`.

    Args:
        startDf (DataFrame): raw graph with `id` and `connections` columns (ids may repeat).
        startId (int): hero to start from.
        targetId (int): hero to look for.
        maxDistance (int): maximum number of BFS layers to explore (default 10).

    Returns:
        int: number of hops from `startId` to `targetId`; 0 when they are the same
        hero; -1 when the target is not reached, either because the frontier runs
        out or because `maxDistance` layers were explored without a hit.
    """
    if startId == targetId:
        return 0

    # Each iteration triggers several Spark actions on frames derived from this one;
    # caching it avoids re-reading the file and re-running the groupBy every time.
    base = firstIterationDf(startDf, startId).cache()
    df = base
    try:
        for distance in range(1, maxDistance + 1):
            print(f"Running BFS iteration # {distance}")
            # process layer (rows with process status 1)
            frontier = df.where(df.processStatus == 1)
            if frontier.first() is None:
                # nothing left to expand: the target is unreachable from startId
                print("Frontier is empty; the target character is unreachable.")
                return -1
            hit_times = countTargetHits(frontier, targetId)
            if hit_times:
                print(f"Hit the target character! From {hit_times} different direction(s).")
                return distance
            df = updateProcessStatus(df)
    finally:
        base.unpersist()

    print(f"Target character not reached within {maxDistance} iterations.")
    return -1

# Degrees of separation between hero 5306 and hero 14 (the result is the cell output).
degreeSeparation(loadStartingDF(), startId=5306, targetId=14)

Running BFS iteration # 1
1741 connections were inspected.
Running BFS iteration # 2
214129 connections were inspected.
Hit the target character! From 1 different direction(s).


2