In [19]:
import pyspark
from pyspark.sql.types import StructField, StructType, StringType, LongType
from pyspark.sql import SparkSession
from py4j.protocol import Py4JJavaError

# Dataframe example
- This section reads data from BigQuery (BQ). Note that loading a table with the method shown below does not issue a query in BQ.
- Alternatively, you can read the data by issuing a query to BQ directly. The query result is written to a temporary table under the `materializationDataset` and then loaded into Spark. For more information, see https://github.com/GoogleCloudDataproc/spark-bigquery-connector.

In [20]:
# Two-column schema: a nullable string "subreddit" and a nullable long "count"
schema = (
    StructType()
    .add("subreddit", StringType(), True)
    .add("count", LongType(), True)
)

# Empty DataFrame with that schema; the aggregation loop unions its
# per-table results onto this frame
subreddit_counts = spark.createDataFrame([], schema=schema)

# Load one table through the 'bigquery' format; the type displayed below
# confirms that the reader hands back a DataFrame
table_df = (
    spark.read
    .format('bigquery')
    .option('table', "fh-bigquery.reddit_posts.2017_01")
    .load()
)
type(table_df)

pyspark.sql.dataframe.DataFrame

In [21]:
# Establish a set of years and months to iterate over
years = ['2017']
months = ['01', '02']

# Keep track of all tables accessed via the job
tables_read = []

# Rebuild the accumulator from scratch every time this cell executes.
# Without this, re-running the cell unions the new per-table counts onto the
# result of the previous run, so every subreddit ends up counted twice.
subreddit_counts = spark.createDataFrame([], schema)

for year in years:
    for month in months:

        # In the form of <project-id>.<dataset>.<table>
        table = f"fh-bigquery.reddit_posts.{year}_{month}"

        # If the table doesn't exist we move on to the next month and do not
        # log it into our "tables_read" list; any other error is re-raised
        try:
            table_df = spark.read.format('bigquery').option('table', table).load()
        except Py4JJavaError as e:
            if f"Table {table} not found" in str(e):
                continue
            raise

        # Only reached when the load succeeded
        tables_read.append(table)

        # Group by subreddit, count the rows per group, and union the
        # result onto the running accumulator (summed in a later cell)
        subreddit_counts = (
            table_df
            .groupBy("subreddit")
            .count()
            .union(subreddit_counts)
        )

In [22]:
# List every table that was loaded, one per line beneath a header
print("The following list of tables will be accounted for in our analysis:", *tables_read, sep="\n")

# Add up the per-table counts for each subreddit, then put the column name
# back to "count" for readability
subreddit_totals = (
    subreddit_counts
    .groupBy("subreddit")
    .sum("count")
    .withColumnRenamed("sum(count)", "count")
)

# Largest totals first; show() prints the leading rows to standard out
subreddit_totals.orderBy("count", ascending=False).show()

The following list of tables will be accounted for in our analysis:
fh-bigquery.reddit_posts.2017_01
fh-bigquery.reddit_posts.2017_02




+--------------------+------+
|           subreddit| count|
+--------------------+------+
|           AskReddit|489877|
|RocketLeagueExchange|394825|
|          The_Donald|358716|
|       AutoNewspaper|324971|
|GlobalOffensiveTrade|224767|
|                news|144650|
|              videos|143173|
|      Showerthoughts|138401|
|               funny|132761|
|              gaming| 97418|
|           Overwatch| 96623|
|            politics| 93481|
|                 aww| 82689|
|                pics| 79858|
|     leagueoflegends| 78422|
|        dirtykikpals| 75812|
|              me_irl| 74686|
| GlobalParadigmShift| 70842|
|           worldnews| 67100|
|        dirtypenpals| 64638|
+--------------------+------+
only showing top 20 rows



                                                                                

# RDD example
This section reads a text file from GCS as unstructured data.

In [23]:
# GCS text file that the word count reads
inputUri = 'gs://pub/shakespeare/rose.txt'

# GCS folder that the word-count results are written to
outputUri = 'gs://dataproc-staging-us-central1-712368347106-boh5iflc/misc/output_folder'

A SparkContext can be obtained from a SparkSession (via `spark.sparkContext`).

In [24]:
def read_text(inputUri, entrypoint):
    """Read a text file and count how often each whitespace-separated word occurs.

    Args:
        inputUri: Location of the text file, passed straight to ``textFile``.
        entrypoint: Either a SparkContext-like object (has ``textFile``) or a
            SparkSession-like object (exposes the context as ``sparkContext``).

    Returns:
        An RDD of ``(word, count)`` pairs.

    Raises:
        TypeError: If ``entrypoint`` is neither of the supported kinds.
    """
    # Resolve the context from the argument itself instead of comparing it
    # with notebook globals: a session exposes its context as `.sparkContext`,
    # anything else is treated as the context directly.
    context = getattr(entrypoint, "sparkContext", entrypoint)
    if not hasattr(context, "textFile"):
        raise TypeError(
            "entrypoint must be a SparkContext or a SparkSession, "
            f"got {type(entrypoint).__name__}"
        )

    lines = context.textFile(inputUri)

    # One record per word, then pair each word with 1 and sum the 1s per word
    words = lines.flatMap(lambda line: line.split())
    wordCounts = words.map(lambda word: (word, 1)).reduceByKey(lambda count1, count2: count1 + count2)
    return wordCounts

# Build the word-count RDD for the sample text, passing `sc` as the entrypoint.
# NOTE(review): `sc` is not defined in the cells shown here — presumably it is
# supplied by the notebook kernel; confirm before running on a fresh kernel.
wordCounts = read_text(inputUri, sc)

For an RDD, the data must be collected to the driver with the `collect()` method before its contents can be inspected. This returns the data as a Python list.

In [25]:
# The first placeholder renders the RDD object itself; the second calls
# collect() and renders the data it returns.
print(f"Printing wordCounts as RDD:\n{wordCounts}\n\nPrinting wordCounts after it as been collected to the driver:\n{wordCounts.collect()}")



Printing wordCounts as RDD:
PythonRDD[59] at RDD at PythonRDD.scala:53

Printing wordCounts after it as been collected to the driver:
[('other', 1), ('name', 1), ('would', 1), ('smell', 1), ('as', 1), ('sweet.', 1), ("What's", 1), ('in', 1), ('name?', 1), ('That', 1), ('we', 1), ('call', 1), ('rose', 1), ('a', 2), ('which', 1), ('By', 1), ('any', 1)]


                                                                                

The RDD will be saved as multiple files under the output folder (one part file per partition).

In [28]:
# Save: write the word counts out as text under outputUri.
# NOTE(review): saveAsTextFile is expected to fail when outputUri already
# exists, so re-running this cell may need a fresh or cleared path — confirm.
wordCounts.saveAsTextFile(outputUri)

# Read: load the saved text back as an RDD and print its type to confirm
rdd_output = spark.sparkContext.textFile(outputUri)
print(type(rdd_output))

                                                                                

<class 'pyspark.rdd.RDD'>
