In [1]:
import json
from pyspark import SparkContext, SparkConf
from pyspark.sql import SQLContext, DataFrameWriter, DataFrameReader
from pyspark.sql.types import *
from pyspark.sql.functions import *

In [2]:
def get_config():
    '''Load the project configuration.

    Reads "../config.json" (resolved against the current working directory),
    parses it as JSON and returns the parsed object.
    '''
    with open("../config.json", "r") as config_file:
        return json.load(config_file)

In [3]:
def get_spark_conf(config):
    '''Create the SparkConf for this job.

    The application name is fixed to 'yelp'; 'spark.master' is taken from
    config["spark"]["master_url"]. Returns the configured SparkConf.
    '''
    # Both setters return the conf itself, so the calls can be chained.
    return (
        SparkConf()
        .setAppName('yelp')
        .set('spark.master', config["spark"]["master_url"])
    )

In [4]:
def get_pg_props(config):
    '''Build the JDBC connection properties for Postgres.

    Takes "user" and "password" from config["postgres"] and pairs them with the
    PostgreSQL JDBC driver class name. Returns a plain dict.
    '''
    pg_settings = config["postgres"]
    return dict(
        user=pg_settings["user"],
        password=pg_settings["password"],
        driver="org.postgresql.Driver",
    )

In [5]:
# Name of the config["postgres"] sub-section whose "jdbc" URL write_to_pg uses.
env = "development"

In [6]:
def getdf(sql_context, config):
    '''Load the JSON data at config["yelp"]["s3"] into a DataFrame.

    No filtering is applied; the DataFrame produced by the reader is returned
    unchanged.
    '''
    json_reader = sql_context.read
    return json_reader.json(config["yelp"]["s3"])

In [7]:
# Load settings, build the Spark configuration and obtain the SparkContext.
# getOrCreate() reuses a context that is already running, so executing this cell
# again in the same kernel no longer fails because a SparkContext already exists.
# Note that an already-running context keeps the configuration it was started with.
config = get_config()
spark_conf = get_spark_conf(config)
sc = SparkContext.getOrCreate(conf=spark_conf)

In [8]:
# SQLContext bound to the SparkContext; later cells use it to read JSON and to
# create DataFrames.
sql_context = SQLContext(sc)

In [9]:
# Load the JSON data referenced by config["yelp"]["s3"] as a DataFrame.
raw_df = getdf(sql_context, config)

In [22]:
def get_category_set(raw_df):
    '''Return the set of distinct category names found in raw_df.

    Each row's `categories` value is split on ", "; rows whose `categories` is
    None contribute nothing.

    The names are de-duplicated on the executors before being collected, so the
    driver receives each name once instead of one list per input row.

    raw_df: DataFrame whose rows expose a `categories` attribute.
    Returns: a Python set of category name strings.
    '''
    def split_categories(row):
        '''Split one row's categories string; return [] when it is None.'''
        if row.categories is not None:
            return row.categories.split(", ")
        return []

    distinct_names = raw_df.rdd.flatMap(split_categories).distinct().collect()
    return set(distinct_names)

In [25]:
# Sort the names so the list order is the same on every run. The category ids are
# later derived from positions in this list, and iterating a set of strings
# directly gives no guaranteed order, which would reshuffle the ids between runs.
all_categories = sorted(get_category_set(raw_df))
print(len(all_categories))

1305


In [46]:
def write_to_pg(df, table, config):
    '''Write df to the Postgres table `table` over JDBC, replacing its contents.

    The JDBC URL comes from config["postgres"][env]["jdbc"], where `env` is the
    module-level setting; connection properties come from get_pg_props(config).
    The write mode is always 'overwrite'. Returns None.
    '''
    jdbc_kwargs = dict(
        url=config["postgres"][env]["jdbc"],
        table=table,
        mode='overwrite',
        properties=get_pg_props(config),
    )
    df.write.jdbc(**jdbc_kwargs)

In [78]:
def write_category_df_to_pg(all_categories, config):
    '''Create the (id, name) category DataFrame and write it to "categories".

    Each name in all_categories gets its list position as id. The DataFrame is
    created through the module-level sql_context, written with write_to_pg, and
    returned.
    '''
    category_schema = StructType([
        StructField("id", IntegerType(), False),
        StructField("name", StringType(), False),
    ])
    # enumerate yields the same (position, name) tuples as zipping with a range.
    id_name_pairs = list(enumerate(all_categories))
    category_frame = sql_context.createDataFrame(id_name_pairs, category_schema)
    write_to_pg(category_frame, "categories", config)
    return category_frame

In [79]:
# Build the (id, name) categories DataFrame and write it to the "categories" table.
category_df = write_category_df_to_pg(all_categories, config)

In [73]:
def write_yelp_df_to_pg(raw_df, config):
    '''Add a generated "id" column to raw_df and write the "yelp" table.

    Only the listed columns are written. The returned DataFrame holds those same
    columns plus "categories".

    NOTE(review): the id-tagged DataFrame is not persisted, so the generated ids
    are presumably recomputed whenever the returned DataFrame is evaluated
    again -- confirm they stay aligned with the rows written here.
    '''
    table_columns = (
        "id", "name", "latitude", "longitude", "stars",
        "review_count", "address", "city", "state",
    )
    with_ids = raw_df.withColumn("id", monotonically_increasing_id())
    write_to_pg(with_ids.select(*table_columns), "yelp", config)
    return with_ids.select(*(table_columns + ("categories",)))

In [104]:
def write_yelp2category_to_pg(yelp_df, category_df, config):
    '''Write the business-to-category link table ("yelp2category") to Postgres.

    yelp_df: DataFrame whose rows carry an "id" and a ", "-separated
        "categories" string (or None).
    category_df: DataFrame whose rows carry "id" and "name"; collected to the
        driver to build the name -> id lookup.
    config: passed through to write_to_pg.

    Returns the DataFrame of (yelp_id, category_id) pairs that was written.
    Raises KeyError if a row names a category that is absent from category_df.

    The yelp_df argument is used as given. It is no longer rebuilt from the
    global raw_df, which ignored the argument and rewrote the "yelp" table as a
    side effect.
    '''
    # Build the lookup before defining the closure that captures it.
    category_ids = {row.name: row.id for row in category_df.collect()}

    def categories_to_ids(row):
        '''Return the (yelp id, category id) pairs for one row.'''
        raw_categories = row["categories"]
        if raw_categories is None:
            return []
        return [(row.id, category_ids[name]) for name in raw_categories.split(", ")]

    yelp2cat_rdd = yelp_df.rdd.flatMap(categories_to_ids)
    schema = StructType([
        StructField("yelp_id", LongType(), False),
        StructField("category_id", IntegerType(), False)
    ])
    # The DataFrame is created through the module-level sql_context.
    yelp2cat_df = sql_context.createDataFrame(yelp2cat_rdd, schema)

    write_to_pg(yelp2cat_df, "yelp2category", config)
    return yelp2cat_df

In [98]:
# Add an "id" column, write the selected columns to the "yelp" table, and keep the
# returned DataFrame (those columns plus "categories") for building the link table.
yelp_df = write_yelp_df_to_pg(raw_df, config)

In [105]:
# Write the (yelp_id, category_id) pairs to the "yelp2category" table.
write_yelp2category_to_pg(yelp_df, category_df, config)

DataFrame[yelp_id: bigint, category_id: int]