In [1]:
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

conf = SparkConf().setAppName('count_connections').setMaster('local[*]')
sc = SparkContext(conf=conf)
spark = SparkSession(sc)

In [2]:
# generate a synthetic graph

x = spark.range(1000)
x_1 = x.select(col('id').alias('id_1'))
x_2 = x.select(col('id').alias('id_2'))

y = (
    x_1.crossJoin(x_2)
    .filter(col('id_1') < col('id_2'))
    .sample(False, 0.01)
)

connections = y
connections.show(20)
connections.createOrReplaceTempView('connections')

+----+----+
|id_1|id_2|
+----+----+
|   0| 313|
|   0| 391|
|   0| 490|
|   0| 669|
|   0| 719|
|   1|  50|
|   1| 111|
|   1| 184|
|   1| 209|
|   1| 234|
|   1| 421|
|   1| 550|
|   1| 664|
|   1| 743|
|   1| 764|
|   1| 817|
|   1| 851|
|   1| 866|
|   1| 997|
|   2|  92|
+----+----+
only showing top 20 rows



In [3]:
connection_counts_df = spark.sql('''

SELECT id,
       COUNT(*) AS n_connections

FROM (
    SELECT id_1 AS id
    FROM connections
    UNION ALL
    SELECT id_2 AS id
    FROM connections
    )

GROUP BY 1
ORDER BY 2 DESC

''')


connection_counts_df.show(20)
connection_counts_df.createOrReplaceTempView('connection_counts')

+---+-------------+
| id|n_connections|
+---+-------------+
| 47|           22|
|306|           20|
|702|           20|
|700|           20|
|567|           19|
|450|           19|
|282|           19|
|352|           18|
|500|           18|
|554|           18|
|779|           18|
| 24|           18|
|538|           17|
|752|           17|
|171|           17|
|224|           17|
|345|           17|
|628|           17|
|644|           17|
|623|           17|
+---+-------------+
only showing top 20 rows



In [4]:
avg_connections = spark.sql(

    'SELECT COUNT(*), AVG(n_connections) FROM connection_counts'

).collect()[0]

msg = '''
{0} ids in the dataset,
with an average connection count of {1}.
'''.format(*avg_connections)

print(msg)


1000 ids in the dataset,
with an average connection count of 9.956.



In [5]:
spark.stop()