Develop a PySpark script to clean and preprocess data before performing entity resolution.
Include steps like tokenization and normalization.

In [1]:
import pyspark
from pyspark import SparkContext
from pyspark.sql import SparkSession

In [2]:
spark = SparkSession.builder.config("spark.driver.memory", "4g").appName("entityres").getOrCreate()



In [3]:
parsed = spark.read.option("header", "true").\
                    option("nullValue", "?").\
                    option("inferSchema", "true").\
                    option("recursiveFileLookup", "true").\
                    csv('linkage/block_*.csv')

In [4]:
parsed.show()

+-----+-----+-----------------+------------+------------+------------+-------+------+------+------+-------+--------+
| id_1| id_2|     cmp_fname_c1|cmp_fname_c2|cmp_lname_c1|cmp_lname_c2|cmp_sex|cmp_bd|cmp_bm|cmp_by|cmp_plz|is_match|
+-----+-----+-----------------+------------+------------+------------+-------+------+------+------+-------+--------+
|37291|53113|0.833333333333333|        null|         1.0|        null|      1|     1|     1|     1|      0|    true|
|39086|47614|              1.0|        null|         1.0|        null|      1|     1|     1|     1|      1|    true|
|70031|70237|              1.0|        null|         1.0|        null|      1|     1|     1|     1|      1|    true|
|84795|97439|              1.0|        null|         1.0|        null|      1|     1|     1|     1|      1|    true|
|36950|42116|              1.0|        null|         1.0|         1.0|      1|     1|     1|     1|      1|    true|
|42413|48491|              1.0|        null|         1.0|       

In [5]:
parsed.printSchema()

root
 |-- id_1: integer (nullable = true)
 |-- id_2: integer (nullable = true)
 |-- cmp_fname_c1: double (nullable = true)
 |-- cmp_fname_c2: double (nullable = true)
 |-- cmp_lname_c1: double (nullable = true)
 |-- cmp_lname_c2: double (nullable = true)
 |-- cmp_sex: integer (nullable = true)
 |-- cmp_bd: integer (nullable = true)
 |-- cmp_bm: integer (nullable = true)
 |-- cmp_by: integer (nullable = true)
 |-- cmp_plz: integer (nullable = true)
 |-- is_match: boolean (nullable = true)



In [6]:
parsed.count()

5749132

In [7]:
parsed.cache()

DataFrame[id_1: int, id_2: int, cmp_fname_c1: double, cmp_fname_c2: double, cmp_lname_c1: double, cmp_lname_c2: double, cmp_sex: int, cmp_bd: int, cmp_bm: int, cmp_by: int, cmp_plz: int, is_match: boolean]

In [8]:
from pyspark.sql.functions import col

In [9]:
parsed.groupBy('is_match').count().orderBy(col('count').desc()).show()

+--------+-------+
|is_match|  count|
+--------+-------+
|   false|5728201|
|    true|  20931|
+--------+-------+



In [10]:
parsed.createOrReplaceTempView('linkage')

In [11]:
spark.sql("""
    SELECT is_match, COUNT(*) cnt
    FROM linkage
    GROUP BY is_match
    ORDER BY cnt DESC
""").show()

+--------+-------+
|is_match|    cnt|
+--------+-------+
|   false|5728201|
|    true|  20931|
+--------+-------+



In [12]:
summary = parsed.describe()

In [13]:
summary.select("summary", "cmp_fname_c1", "cmp_fname_c2").show()

+-------+------------------+-------------------+
|summary|      cmp_fname_c1|       cmp_fname_c2|
+-------+------------------+-------------------+
|  count|           5748125|             103698|
|   mean|  0.71290247044295| 0.9000176718903214|
| stddev|0.3887583596162793|0.27131761057823345|
|    min|               0.0|                0.0|
|    max|               1.0|                1.0|
+-------+------------------+-------------------+



In [14]:
matches = parsed.where("is_match = true")
match_summary = matches.describe()

misses = parsed.filter(col("is_match") == False)
miss_summary = misses.describe()

In [15]:
summary_p = summary.toPandas()

In [16]:
summary_p.head()

Unnamed: 0,summary,id_1,id_2,cmp_fname_c1,cmp_fname_c2,cmp_lname_c1,cmp_lname_c2,cmp_sex,cmp_bd,cmp_bm,cmp_by,cmp_plz
0,count,5749132.0,5749132.0,5748125.0,103698.0,5749132.0,2464.0,5749132.0,5748337.0,5748337.0,5748337.0,5736289.0
1,mean,33324.48559643438,66587.43558331935,0.71290247044295,0.9000176718903214,0.3156278193075508,0.318412831531744,0.955001381078048,0.2244652670850717,0.488855298497635,0.2227485966810923,0.0055286614743434
2,stddev,23659.85937448807,23620.48761326976,0.3887583596162793,0.2713176105782334,0.3342336339615835,0.3685670662006653,0.2073011111689782,0.4172297223846248,0.4998758236779034,0.4160909629831756,0.0741491492542002
3,min,1.0,6.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,max,99980.0,100000.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [17]:
summary_p.shape

(5, 12)

In [18]:
summary_p = summary_p.set_index("summary").transpose().reset_index()
summary_p = summary_p.rename(columns={"index": "field"})
summary_p = summary_p.rename_axis(None, axis=1)

In [19]:
summary_p.shape

(11, 6)

In [20]:
summaryT = spark.createDataFrame(summary_p)

In [21]:
summaryT

DataFrame[field: string, count: string, mean: string, stddev: string, min: string, max: string]

In [22]:
summaryT.printSchema()

root
 |-- field: string (nullable = true)
 |-- count: string (nullable = true)
 |-- mean: string (nullable = true)
 |-- stddev: string (nullable = true)
 |-- min: string (nullable = true)
 |-- max: string (nullable = true)



In [23]:
from pyspark.sql.types import DoubleType

In [24]:
for c in summaryT.columns:
    if c == "field":
        continue
    summaryT = summaryT.withColumn(c, summaryT[c].cast(DoubleType()))

In [25]:
summaryT.printSchema()

root
 |-- field: string (nullable = true)
 |-- count: double (nullable = true)
 |-- mean: double (nullable = true)
 |-- stddev: double (nullable = true)
 |-- min: double (nullable = true)
 |-- max: double (nullable = true)



In [26]:
from pyspark.sql import DataFrame

In [27]:
def pivot_summary(desc):
    desc_p = desc.toPandas()
    
    desc_p = desc_p.set_index("summary").transpose().reset_index()
    desc_p = desc_p.rename(columns={"index": "field"})
    desc_p = desc_p.rename_axis(None, axis=1)
    
    descT = spark.createDataFrame(desc_p)
    
    for c in descT.columns:
        if c == 'field':
            continue
        descT = descT.withColumn(c, descT[c].cast(DoubleType()))
    
    return descT

In [28]:
match_summaryT = pivot_summary(match_summary)
miss_summaryT = pivot_summary(miss_summary)

In [29]:
match_summaryT.createOrReplaceTempView("match_desc")
miss_summaryT.createOrReplaceTempView("miss_desc")

spark.sql("""
    SELECT a.field, a.count + b.count total, a.mean - b.mean delta
    FROM match_desc a INNER JOIN miss_desc b ON a.field = b.field
    WHERE a.field NOT IN ("id_1", "id_2")
    ORDER BY delta DESC, total DESC
""")

DataFrame[field: string, total: double, delta: double]

In [30]:
good_features = ["cmp_lname_c1", "cmp_plz", "cmp_by", "cmp_bd", "cmp_bm"]

sum_expression = " + ".join(good_features)

sum_expression

'cmp_lname_c1 + cmp_plz + cmp_by + cmp_bd + cmp_bm'

In [31]:
from pyspark.sql.functions import expr

In [32]:
scored = parsed.fillna(0, subset=good_features).\
                withColumn('score', expr(sum_expression)).\
                select('score', 'is_match')

In [33]:
scored.show()

+-----+--------+
|score|is_match|
+-----+--------+
|  4.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  4.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  4.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
+-----+--------+
only showing top 20 rows



In [34]:
def crossTabs(scored, t):
    return scored.selectExpr(f"score >= {t} as above", "is_match").\
                    groupBy("above").pivot("is_match", ("true", "false")).\
                    count()

In [35]:
crossTabs(scored, 4.0).show()

+-----+-----+-------+
|above| true|  false|
+-----+-----+-------+
| true|20871|    637|
|false|   60|5727564|
+-----+-----+-------+



In [36]:
crossTabs(scored, 2.0).show()

+-----+-----+-------+
|above| true|  false|
+-----+-----+-------+
| true|20931| 596414|
|false| null|5131787|
+-----+-----+-------+



In [37]:
def print_metrics(confusion_matrix):
    tp = confusion_matrix[0][0]
    tn = confusion_matrix[1][1]
    fp = confusion_matrix[1][0]
    fn = confusion_matrix[0][1]
    
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    
    f1_score = (2 * precision * recall) / (precision + recall)
    
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"F1 Score: {f1_score}")

In [38]:
confusion_matrix = crossTabs(scored, 4.0).collect()

In [39]:
print_metrics(confusion_matrix)

Precision: 1.0
Recall: 4.7911077041011885e-05
F1 Score: 9.581756335936378e-05
