In [1]:
import math
import os
import sys
import time

from pyspark import RDD, SparkConf, SparkContext
from pyspark.sql import SparkSession
from graphframes import *

In [4]:
# Ask PySpark to pull in the GraphFrames package when it launches the JVM.
# This must be set before the SparkContext is created.
os.environ["PYSPARK_SUBMIT_ARGS"] = "--packages graphframes:graphframes:0.8.2-spark3.1-s_2.12 pyspark-shell"

In [5]:
# Reuse the running SparkContext if there is one, otherwise start a local one
# named "HW4". Calling the SparkContext constructor directly would raise if this
# cell were executed a second time in the same kernel.
sc = SparkContext.getOrCreate(SparkConf().setMaster("local").setAppName("HW4"))
sc.setLogLevel("WARN")

In [6]:
# SQL entry point; needed later to turn the vertex/edge RDDs into DataFrames.
spark = (
    SparkSession.builder
    .appName('hw4')
    .getOrCreate()
)

In [7]:
# File locations.
# NOTE: `input` shadows the builtin of the same name; the name is kept because
# later cells read it.
input = "ub_sample_data.csv"
test_file = "yelp_val.csv"  # not referenced by any cell visible in this notebook
output = "output.csv"

# Two users are linked when their sets of second-column values (presumably
# business ids) share at least this many entries.
filter_threshold = 7


In [8]:
# Split every line of the input file on commas, then drop the header row
# (any row whose fields equal those of the first line).
raw_rows = sc.textFile(input).map(lambda line: line.split(','))
header = raw_rows.first()
to_csv_RDD = raw_rows.filter(lambda row: row != header)


In [9]:
# Group column 1 by column 0: one (key, set of values) record per distinct key
# (presumably user id -> set of business ids).
user_biz_RDD = (
    to_csv_RDD
    .map(lambda row: (row[0], row[1]))
    .groupByKey()
    .mapValues(set)
)

In [10]:
# Number of distinct keys (presumably users). count() returns only the total,
# instead of shipping every (key, set) record to the driver just to measure it.
user_biz_RDD.count()

3374

In [60]:
# Pair every record with every other record, dropping self-pairs. Both (a, b)
# and (b, a) are kept, so each relation appears once per direction.
all_pairs = user_biz_RDD.cartesian(user_biz_RDD).filter(lambda pair: pair[0] != pair[1])

# Score each ordered pair by the size of the overlap between the two value sets
# and keep only the pairs that reach filter_threshold.
filtered_pairs = (
    all_pairs
    .map(lambda pair: ((pair[0][0], pair[1][0]), len(pair[0][1] & pair[1][1])))
    .filter(lambda scored: scored[1] >= filter_threshold)
)

In [61]:
# An edge is the (id, id) key of every pair that met the threshold.
edges = filtered_pairs.keys()
# A vertex is any id that appears at either end of at least one edge.
vertices = edges.flatMap(lambda pair: pair).distinct()

In [64]:
# Pull every edge tuple back to the driver as a Python list.
e = edges.collect()

In [65]:
# Size of the collected edge list `e`.
len(e)

996

In [62]:
# Turn the vertex ids into a one-column string DataFrame, then name the column
# "id".
vertex_values_df = spark.createDataFrame(vertices, "String")
vertex_spark_df = vertex_values_df.toDF("id")

In [35]:
# Edge DataFrame with columns "src" and "dst". The schema is spelled out so
# Spark does not have to infer it from the data, which costs an extra job and
# fails when the edge RDD is empty.
edge_spark_df = spark.createDataFrame(edges, "src string, dst string")

In [66]:
# Build the graph from the vertex ("id") and edge ("src", "dst") DataFrames.
g = GraphFrame(vertex_spark_df, edge_spark_df)

In [68]:
# Run GraphFrames' label propagation on the graph with maxIter=5.
communities = g.labelPropagation(maxIter=5)

In [69]:
# Swap the first two fields of each result row so the second field (presumably
# the community label) becomes the key, gather the first field (presumably the
# vertex id) per key, and return each group as a sorted list.
# NOTE: this rebinds `communities` from a DataFrame to a plain list, so the
# cell cannot be re-run without first re-running the labelPropagation cell.
communities = (
    communities.rdd
    .map(lambda row: (row[1], row[0]))
    .groupByKey()
    .mapValues(sorted)
    .values()
    .collect()
)

In [39]:
def _community_sort_key(members):
    """Order communities by size first, then by their member lists."""
    return len(members), members


output_format_comm = sorted(communities, key=_community_sort_key)

In [50]:
# Display the sorted communities.
output_format_comm

[['23y0Nv9FFWn_3UWudpnFMA'],
 ['3Vd_ATdvvuVVgn_YCpz8fw'],
 ['453V8MlGr8y61PpsDAFjKQ'],
 ['46HhzhpBfTdTSB5ceTx_Og'],
 ['Cf0chERnfd06ltnN45xLNQ'],
 ['F47atsRPw-KHmRVk5exBFw'],
 ['JeOHA8tW7gr-FDYOcPJoeA'],
 ['QYKexxaOJQlseGWmc6soRg'],
 ['Si3aMsOVGSVlsc54iuiPwA'],
 ['YVQFzWm0H72mLUh-8gzd5w'],
 ['_m1ot2zZetDgjerAD2Sidg'],
 ['cyuDrrG5eEK-TZI867MUPA'],
 ['d5WLqmTMvmL7-RmUDVKqqQ'],
 ['eqWEgMH-DCP74i82BEAZzw'],
 ['gH0dJQhyKUOVCKQA6sqAnw'],
 ['gUu0uaiU7UEUVIgCdnqPVQ'],
 ['jJDUCuPwVqwjbth3s92whA'],
 ['jSbXY_rno4hYHQCFftsWXg'],
 ['tX0r-C9BaHYEolRUfufTsQ'],
 ['vENR70IrUsDNTDebbuxyQA'],
 ['0KhRPd66BZGHCtsb9mGh_g', '5fQ9P6kbQM_E0dx8DL6JWA'],
 ['98rLDXbloLXekGjieuQSlA', 'MJ0Wphhko2-LbJ0uZ5XyQA'],
 ['EY8h9IJimXDNbPXVFpYF3A', 'LiNx18WUre9WFCEQlUhtKA'],
 ['Gr-MqCunME2K_KmsAwjpTA', '_6Zg4ukwS0kst9UtkfVw3w'],
 ['QRsuZ_LqrRU65dTs5CL4Lw', 'lJFBgSAccsMGwIjfD7LMeQ'],
 ['S1cjSFKcS5NVc3o1MkfpwA', 'mm9WYrFhiNqvHCyhQKw3Mg'],
 ['jgoG_hHqnhZvQEoBK0-82w', 'qd16czwFUVHICKF7A4qWsQ'],
 ['750rhwO7D_Cul7_GtO9Jsg',
  'DjcR

In [51]:
# Write one community per line: each member id wrapped in single quotes, with
# members separated by ", ". The file is only written, never read back, so it
# is opened in plain write mode.
with open(output, "w") as f:
    for comm in output_format_comm:
        f.write(", ".join("'" + member + "'" for member in comm) + "\n")