In [1]:
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

conf = SparkConf().setAppName('count_connections').setMaster('local[*]')
sc = SparkContext(conf=conf)
spark = SparkSession(sc)

In [2]:
# generate a synthetic graph by taking 1% of all
# possible connections between 1000 vertices

x = spark.range(1000)
x_1 = x.select(col('id').alias('id_1'))
x_2 = x.select(col('id').alias('id_2'))

y = (
    x_1.crossJoin(x_2)
    .filter(col('id_1') < col('id_2'))
    .sample(False, 0.01)
)

connections = y
connections.show(20)
connections.createOrReplaceTempView('connections')

+----+----+
|id_1|id_2|
+----+----+
|   0|  21|
|   0|  33|
|   0| 120|
|   0| 320|
|   0| 702|
|   0| 741|
|   0| 820|
|   0| 899|
|   0| 928|
|   0| 956|
|   0| 960|
|   1|  21|
|   1| 248|
|   1| 356|
|   1| 523|
|   1| 731|
|   1| 767|
|   1| 999|
|   2|  31|
|   2| 176|
+----+----+
only showing top 20 rows



In [3]:
# perform analytics on the graph

connection_counts_df = spark.sql('''

SELECT id,
       COUNT(*) AS n_connections

FROM (
    SELECT id_1 AS id
    FROM connections
    UNION ALL
    SELECT id_2 AS id
    FROM connections
    )

GROUP BY 1
ORDER BY 2 DESC

''')


connection_counts_df.show(20)
connection_counts_df.createOrReplaceTempView('connection_counts')

+---+-------------+
| id|n_connections|
+---+-------------+
|578|           25|
|414|           23|
|170|           22|
|986|           22|
|676|           21|
|907|           21|
|884|           20|
|235|           20|
| 78|           19|
|976|           19|
|358|           19|
|798|           19|
|565|           19|
|355|           18|
|838|           18|
|240|           18|
| 44|           18|
|914|           18|
|665|           18|
|741|           17|
+---+-------------+
only showing top 20 rows



In [4]:
connection_stats = spark.sql('''

SELECT
    COUNT(*)                     AS count,
    AVG(n_connections)           AS mean_connections,
    SQRT(VAR_POP(n_connections)) AS stdev_connections
FROM connection_counts

''').collect()[0]

msg = '''
{0} ids in the dataset,
with an average connection count of {1:.1f} +/- {2:.1f}.
'''.strip().format(*connection_stats)

print(msg)

1000 ids in the dataset,
with an average connection count of 9.9 +/- 3.3.


In [5]:
spark.stop()