# CSCI 4253 / 5253 - Lab #4 - Patent Problem with Spark DataFrames
<div>
 <h2> CSCI 4283 / 5253 
  <IMG SRC="https://www.colorado.edu/cs/profiles/express/themes/cuspirit/logo.png" WIDTH=50 ALIGN="right"/> </h2>
</div>

This [Spark cheatsheet](https://s3.amazonaws.com/assets.datacamp.com/blog_assets/PySpark_SQL_Cheat_Sheet_Python.pdf) is useful as is [this reference on doing joins in Spark dataframe](http://www.learnbymarketing.com/1100/pyspark-joins-by-example/).

The [DataBricks company has one of the better reference manuals for PySpark](https://docs.databricks.com/spark/latest/dataframes-datasets/index.html) -- they show you how to perform numerous common data operations such as joins, aggregation operations following `groupBy` and the like.

In [2]:
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession

The following aggregation functions may be useful -- [these can be used to aggregate results of `groupby` operations](https://docs.databricks.com/spark/latest/dataframes-datasets/introduction-to-dataframes-python.html#example-aggregations-using-agg-and-countdistinct). More documentation is at the [PySpark SQL Functions manual](https://spark.apache.org/docs/2.3.0/api/python/pyspark.sql.html#module-pyspark.sql.functions). Feel free to use other functions from that library.

In [3]:
from pyspark.sql.functions import col, count, countDistinct

Create our session as described in the tutorials

In [4]:
spark = SparkSession \
    .builder \
    .appName("Lab4-Dataframe") \
    .master("local[*]")\
    .getOrCreate()

Read in the citations and patents data and check that the data makes sense. Note that unlike in the RDD solution, the data is automatically inferred to be Integer() types.

In [5]:
citations = spark.read.load('cite75_99.txt.gz',
            format="csv", sep=",", header=True,
            compression="gzip",
            inferSchema="true")

In [6]:
citations.show(5)

+-------+-------+
| CITING|  CITED|
+-------+-------+
|3858241| 956203|
|3858241|1324234|
|3858241|3398406|
|3858241|3557384|
|3858241|3634889|
+-------+-------+
only showing top 5 rows



In [7]:
from pyspark.sql import functions as F

# Keep US patents with a real state and alias POSTATE -> STATE for joining
# create data frame that stores columns patent, country, and POSTATE
patents_us = (
    patents.select(
        F.col("PATENT").cast("long").alias("PATENT"),
        F.col("COUNTRY"),
        F.col("POSTATE").alias("STATE")
    )
    .withColumn("STATE", F.when(F.length("STATE")==0, None).otherwise(F.col("STATE")))
    .filter((F.col("COUNTRY")=="US") & F.col("STATE").isNotNull())
)

# Build lookup for the cited side: CITED patent id -> CITED_STATE
cited_states  = patents_us.select(F.col("PATENT").alias("CITED"),  F.col("STATE").alias("CITED_STATE"))
# Build lookup for the citing side: CITING patent id -> CITING_STATE
citing_states = patents_us.select(F.col("PATENT").alias("CITING"), F.col("STATE").alias("CITING_STATE"))

# Join the citation pairs with both lookups(above) so each row has both states
joined = (
    citations
    .join(cited_states,  on="CITED",  how="left")
    .join(citing_states, on="CITING", how="left")
)

# Compute a bool flag to represent if matching states
same_counts = (
    joined
    # Flag as 1 only when both states exist and are equal, otherwise 0
    .withColumn(
        "is_same",
        F.when(
            (F.col("CITING_STATE").isNotNull()) &
            (F.col("CITED_STATE").isNotNull()) &
            (F.col("CITING_STATE")==F.col("CITED_STATE")), 1
        ).otherwise(0)
    )
    .groupBy("CITING")
    .agg(F.sum("is_same").cast("int").alias("same_state_citations"))
)


NameError: name 'patents' is not defined

In [None]:
# Rename for a clean join and desired output column name
counts_for_join = same_counts.select(
    F.col("CITING").alias("PATENT"),
    F.col("same_state_citations").cast("int").alias("SAME_STATE")
)

# Left-join onto the full patents data frame, fill missing with 0
patents_with_same = (
    patents
    .join(counts_for_join, on="PATENT", how="left")
    .fillna({"SAME_STATE": 0})
)

# Reorder columns so SAME_STATE is last, like the screenshot
cols_in_order = patents.columns + ["SAME_STATE"]
patents_with_same_ordered = patents_with_same.select(*cols_in_order)


In [None]:
# Sort patents by SAME_STATE (highest first), and break ties by PATENT id (lowest first)
top10 = (
    patents_with_same_ordered
    .orderBy(F.col("SAME_STATE").desc(), F.col("PATENT").asc())
    .limit(10)
)

top10.show(truncate=False)
