In [2]:
import pyspark as ps
import pyspark.sql.functions as f

In [3]:
## Build (or reuse, if one is already running) a Spark session named "pipeline"
## that runs against the "local" master.
spark = ps.sql.SparkSession.builder \
    .master("local") \
    .appName("pipeline") \
    .getOrCreate()

## Keep a handle on the underlying SparkContext and limit its log output
## to the WARN level.
sc = spark.sparkContext
sc.setLogLevel("WARN")

In [4]:
## Survey responses: read the CSV into a Spark DataFrame, taking the column
## names from the file's header row.
path = "../data/SharedResponsesSurvey_10000.csv"
responses = spark.read.options(header=True).csv(path)

## Country list: only the ISO3 code column of the country/cluster mapping file
## is kept (original note: n < 100).
path = "../data/country_cluster_map.csv"
countries = spark.read.options(header=True).csv(path).select("ISO3")

In [5]:
## Inspect the column names of the loaded survey responses DataFrame.
responses.columns

['ResponseID',
 'ExtendedSessionID',
 'UserID',
 'ScenarioOrder',
 'Intervention',
 'PedPed',
 'Barrier',
 'CrossingSignal',
 'AttributeLevel',
 'ScenarioTypeStrict',
 'ScenarioType',
 'DefaultChoice',
 'NonDefaultChoice',
 'DefaultChoiceIsOmission',
 'NumberOfCharacters',
 'DiffNumberOFCharacters',
 'Saved',
 'Template',
 'DescriptionShown',
 'LeftHand',
 'UserCountry3',
 'Review_age',
 'Review_education',
 'Review_gender',
 'Review_income',
 'Review_political',
 'Review_religious']

In [8]:
## some EDA stuff
## Collapse the responses to one row per distinct (user, demographics)
## combination; count(UserID) is the number of response rows in that group.
## The column list is defined once so select() and groupby() cannot drift apart.
user_cols = ["UserID", "UserCountry3",
             "Review_age", "Review_education",
             "Review_gender", "Review_income",
             "Review_political", "Review_religious"]
users = responses.select(user_cols).groupby(user_cols).agg({"UserID": "count"})

## Number of user rows per country (not raw responses: `users` is already
## aggregated per user above).
users_by_country = users.select(["UserCountry3", "UserID"]) \
                        .groupby("UserCountry3") \
                        .agg({"UserID": "count"})

## Keep countries with more than 20 user rows, largest first, capped at 50.
## The column is referenced with its exact generated name, count(UserID), so the
## sort does not depend on Spark's case-insensitive column resolution.
top_countries = users_by_country \
    .filter(users_by_country["count(UserID)"] > 20) \
    .orderBy("count(UserID)", ascending=False) \
    .limit(50)

## mode("overwrite") makes this cell re-runnable: the default save mode raises
## an error when the target path already exists.
top_countries.write.format('csv').mode("overwrite").save("../data/users_by_country_50.csv")

## data for United States users
# us_users = users.select("*").filter("UserCountry3 = 'USA' ")
# us_users.write.csv("../data/us_users.csv", mode="overwrite")

users.show(20)

+----------------+------------+----------+----------------+-------------+-------------+----------------+----------------+-------------+
|          UserID|UserCountry3|Review_age|Review_education|Review_gender|Review_income|Review_political|Review_religious|count(UserID)|
+----------------+------------+----------+----------------+-------------+-------------+----------------+----------------+-------------+
| 809483690453245|         USA|        16|       underHigh|       female|      default|             0.5|             0.5|           12|
|7134809177082630|         FRA|        37|         college|         male|        35000|            0.22|            0.52|           12|
|8751981610944470|         KOR|        27|        bachelor|         male|        25000|            0.33|             0.9|           11|
|6057689097792270|         TWN|        16|            high|       female|    under5000|            0.39|            0.84|           12|
|7689448073852970|         VEN|        13|      

In [82]:
## user demographics
## Restrict the per-user table to rows whose country code is USA and report
## how many such rows there are.
us_users = users.filter("UserCountry3 = 'USA' ")
n = us_users.count()
print(n)

## Gender breakdown of the US rows. The table is displayed explicitly so the
## cell reproduces the output recorded below it.
genders = us_users.select("Review_gender").groupby("Review_gender").agg({"Review_gender": "count"})
genders.show()

## Age breakdown and the raw political-score column, kept for later use;
## neither is displayed in this cell.
ages = us_users.select("Review_age").groupby("Review_age").agg({"Review_age": "count"})
political = us_users.select("Review_political")


204
+-------------+--------------------+
|Review_gender|count(Review_gender)|
+-------------+--------------------+
|       female|                  66|
|       others|                  12|
|         male|                 105|
|      default|                  21|
+-------------+--------------------+

