In [1]:
/*
* Author: Mike Urciuoli
* email: urciuolim@gmail.com
* description: This notebook shows how to clean the IRS SOI Tax Stats (county level) for
*              use in a county prediction model. The overall goal of the project is to predict some feature
*              of a county based on other input features. The dataset is split into files by year, and within each
*              year every county is spread across at most 8 rows identified by "agi_stub" (the AGI size class of the
*              returns counted in that row). For example, agi_stub "1" covers returns with AGI under $1, which is
*              possible because AGI accounts for certain deficits. Note that AGI is reported in thousands of
*              dollars (1 in the dataset = $1,000). This notebook keeps that unit as-is and does NOT rescale AGI
*              when reading it in, so the aggregated Total_AGI is also in thousands of dollars.
*/

import org.apache.spark.sql.functions.{concat, input_file_name, lpad, sum, udf}
// Derives a four-digit tax year ("20" + first two characters of the file name) from a source file path.
// NOTE(review): assumes every IRS SOI county file name starts with a two-digit year - confirm for new years.
// Returns None (a SQL null) instead of throwing when the path is null/empty or the name does not start
// with two digits; the previous substring(0, 2) threw StringIndexOutOfBoundsException on short names.
val extractYear = udf { (path: String) =>
    Option(path)
        .map(_.split("/").last.take(2))
        .filter(yy => yy.length == 2 && yy.forall(_.isDigit))
        .map(yy => "20" + yy)
}
// Lpad'ing each state/county FIPS is needed: sometimes those values are stored without their leading zeros,
// so pad state to 2 digits and county to 3 digits before filtering and concatenating into a 5-digit FIPS.
// Rows whose county code is "000" are dropped (presumably state-level summary rows - confirm).
// AGI (A00100) is kept in thousands of dollars exactly as published; it is NOT rescaled here.
val irsRaw = spark.read.option("header", "true").csv("s3://agimodeltrainer/DATA/IRS").
    filter(lpad($"COUNTYFIPS", 3, "0") =!= "000").
    select(
        concat(lpad($"STATEFIPS", 2, "0"), lpad($"COUNTYFIPS", 3, "0")).as("FIPS"),
        // Cast first, then alias, so the column name is attached to the final (casted) expression.
        extractYear(input_file_name()).cast("int").as("Year"),
        $"agi_stub".cast("int").as("AGI Size"),
        $"N1".cast("long").as("Number_of_Returns"),
        $"A00100".cast("long").as("AGI")
    )
irsRaw.printSchema
irsRaw.show

VBox()

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
11,application_1598467819421_0012,spark,idle,Link,Link,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

import org.apache.spark.sql.functions.{udf, input_file_name, lpad}
extractYear: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,StringType,Some(List(StringType)))
irsRaw: org.apache.spark.sql.DataFrame = [FIPS: string, Year: int ... 3 more fields]
root
 |-- FIPS: string (nullable = true)
 |-- Year: integer (nullable = true)
 |-- AGI Size: integer (nullable = true)
 |-- Number_of_Returns: long (nullable = true)
 |-- AGI: long (nullable = true)

+-----+----+--------+-----------------+-------+
| FIPS|Year|AGI Size|Number_of_Returns|    AGI|
+-----+----+--------+-----------------+-------+
|01001|2017|       1|              200|  -6342|
|01001|2017|       2|             3050|  16376|
|01001|2017|       3|             5340|  91181|
|01001|2017|       4|             5620| 203737|
|01001|2017|       5|             3540| 217836|
|01001|2017|       6|             2520| 218722|
|01001|2017|       7|             3400| 447884|
|01001|2017|       8|            

In [2]:
// Collapse the per-agi_stub rows into one row per (FIPS, Year) by summing return counts and AGI.
// The input column is referenced by its exact name ("Number_of_Returns"). The previous lowercase spelling
// only resolved because Spark matches column names case-insensitively by default. The output alias keeps
// the original "Number_of_returns" spelling because that is the schema written to parquet later.
// Total_AGI stays in thousands of dollars, like the AGI column it sums.
// Cached because the following cells run several actions (counts, describe, write) over this result;
// without caching, each action would re-read and re-aggregate the raw CSVs from S3.
val irsFinal = irsRaw.groupBy("FIPS", "Year").
    agg(
        sum("Number_of_Returns").as("Number_of_returns"),
        sum("AGI").as("Total_AGI")
    ).
    cache()
irsFinal.printSchema
irsFinal.show

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

irsFinal: org.apache.spark.sql.DataFrame = [FIPS: string, Year: int ... 2 more fields]
root
 |-- FIPS: string (nullable = true)
 |-- Year: integer (nullable = true)
 |-- Number_of_returns: long (nullable = true)
 |-- Total_AGI: long (nullable = true)

+-----+----+-----------------+---------+
| FIPS|Year|Number_of_returns|Total_AGI|
+-----+----+-----------------+---------+
|13219|2017|            17540|  1942388|
|13301|2017|             2230|    80577|
|16021|2017|             5200|   247296|
|25007|2017|            11460|   897259|
|26021|2017|            73420|  4191570|
|46007|2017|             1150|    36890|
|40005|2017|             5170|   218668|
|54103|2017|             6750|   363112|
|13075|2017|             6570|   256600|
|29211|2017|             2680|    97616|
|39077|2017|            29040|  1400002|
|16051|2017|            11260|   602747|
|27015|2017|            13210|   738172|
|31023|2017|             3920|   206565|
|37187|2017|             5070|   200554|
|55021|201

In [3]:
// Null check: if no row contains a null, dropping null-containing rows leaves the count unchanged.
val totalRows = irsFinal.count
val nonNullRows = irsFinal.na.drop().count
// Cardinality check: the number of distinct counties and years should be in the expected ballpark.
val distinctCounties = irsFinal.select("FIPS").distinct().count
val distinctYears = irsFinal.select("Year").distinct().count
// Per-column summary statistics (count, mean, stddev, min, max).
irsFinal.describe().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

res10: Long = 21990
res11: Long = 21990
res13: Long = 3144
res14: Long = 7
+-------+------------------+------------------+------------------+------------------+
|summary|              FIPS|              Year| Number_of_returns|         Total_AGI|
+-------+------------------+------------------+------------------+------------------+
|  count|             21990|             21990|             21990|             21990|
|   mean| 30391.44433833561|2013.9997271487039|  46938.7105047749|1435326.4655752615|
| stddev|15161.316170557868| 2.000068196150814|151970.00799021602| 8353992.244544055|
|    min|             01001|              2011|                 0|                 0|
|    max|             56045|              2017|           4699360|         364557900|
+-------+------------------+------------------+------------------+------------------+



In [4]:
// Persist the cleaned county/year table. Overwrite so the notebook can be re-run end to end; Spark's default
// SaveMode (ErrorIfExists) fails as soon as this output path already exists.
irsFinal.write.mode("overwrite").parquet("s3://agimodeltrainer/Clean_Data/IRS_SOI_County.parquet")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [5]:
// Round-trip check: read the parquet output back and inspect its schema, sample rows, row count,
// and the number of distinct counties and years to confirm the write succeeded.
val test = spark.read.parquet("s3://agimodeltrainer/Clean_Data/IRS_SOI_County.parquet")
test.printSchema()
test.show()
test.count()
test.select($"FIPS").distinct().count()
test.select($"Year").distinct().count()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

test: org.apache.spark.sql.DataFrame = [FIPS: string, Year: int ... 2 more fields]
root
 |-- FIPS: string (nullable = true)
 |-- Year: integer (nullable = true)
 |-- Number_of_returns: long (nullable = true)
 |-- Total_AGI: long (nullable = true)

+-----+----+-----------------+---------+
| FIPS|Year|Number_of_returns|Total_AGI|
+-----+----+-----------------+---------+
|22111|2017|             8750|   424074|
|29051|2017|            36620|  2262800|
|45047|2017|            29520|  1428313|
|56003|2017|             4940|   227308|
|16035|2017|             3460|   167195|
|30087|2017|             3850|   203728|
|20119|2017|             1930|   105747|
|37017|2017|            12440|   502354|
|40033|2017|             2380|    99098|
|42049|2017|           128200|  6771920|
|47009|2017|            60960|  3559702|
|02060|2017|              480|    27648|
|21003|2017|             7960|   344395|
|51670|2017|            10570|   408769|
|13085|2017|            10930|   704852|
|31135|2017|  

In [6]:
// Count county-years whose summed AGI is exactly zero.
// NOTE(review): these may be suppressed or missing values in the source data - confirm before modeling.
test.where("Total_AGI = 0").count

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

res25: Long = 66
