In [None]:
!pip install pyspark

Collecting pyspark
  Downloading pyspark-3.5.1.tar.gz (317.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m317.0/317.0 MB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... done
  Created wheel for pyspark: filename=pyspark-3.5.1-py2.py3-none-any.whl size=317488491 sha256=c34c0432165ce3776b488c4c93a388b3ceaebb537cdee3491bbe8d74e44eecc6
  Stored in directory: /root/.cache/pip/wheels/80/1d/60/2c256ed38dddce2fdd93be545214a63e02fbd8d74fb0b7f3a6
Successfully built pyspark
Installing collected packages: pyspark
Successfully installed pyspark-3.5.1


In [None]:
!mkdir linkage
!curl -L -o donation.zip https://bit.ly/1Aoywaq
!unzip -q donation.zip
!unzip 'block_*.zip'
!mv block_*.csv linkage/
!rm block_*.zip

In [None]:
import pyspark
from pyspark import SparkContext
from pyspark.sql import SparkSession

In [None]:
# Reuse the active Spark session if there is one, otherwise start a new one.
spark = (
    SparkSession.builder
    .appName('Entity Resolution')
    .getOrCreate()
)

In [None]:
# Load a single block of the data; cells containing '?' are read as null.
df = spark.read.csv(
    './linkage/block_1.csv',
    header=True,
    inferSchema=True,
    nullValue='?',
)

df.show(5)

+-----+-----+-----------------+------------+------------+------------+-------+------+------+------+-------+--------+
| id_1| id_2|     cmp_fname_c1|cmp_fname_c2|cmp_lname_c1|cmp_lname_c2|cmp_sex|cmp_bd|cmp_bm|cmp_by|cmp_plz|is_match|
+-----+-----+-----------------+------------+------------+------------+-------+------+------+------+-------+--------+
|37291|53113|0.833333333333333|        NULL|         1.0|        NULL|      1|     1|     1|     1|      0|    true|
|39086|47614|              1.0|        NULL|         1.0|        NULL|      1|     1|     1|     1|      1|    true|
|70031|70237|              1.0|        NULL|         1.0|        NULL|      1|     1|     1|     1|      1|    true|
|84795|97439|              1.0|        NULL|         1.0|        NULL|      1|     1|     1|     1|      1|    true|
|36950|42116|              1.0|        NULL|         1.0|         1.0|      1|     1|     1|     1|      1|    true|
+-----+-----+-----------------+------------+------------+------------+-------+------+------+------+-------+--------+

In [None]:
# Inspect the column types chosen by inferSchema when the CSV was read.
df.printSchema()

root
 |-- id_1: integer (nullable = true)
 |-- id_2: integer (nullable = true)
 |-- cmp_fname_c1: double (nullable = true)
 |-- cmp_fname_c2: double (nullable = true)
 |-- cmp_lname_c1: double (nullable = true)
 |-- cmp_lname_c2: double (nullable = true)
 |-- cmp_sex: integer (nullable = true)
 |-- cmp_bd: integer (nullable = true)
 |-- cmp_bm: integer (nullable = true)
 |-- cmp_by: integer (nullable = true)
 |-- cmp_plz: integer (nullable = true)
 |-- is_match: boolean (nullable = true)



In [None]:
# Total number of rows in this block (triggers a full Spark job).
df.count()

574913

In [None]:
from pyspark.sql.functions import col

# Class balance of the is_match label, largest group first.
df.groupBy('is_match').count().orderBy(col('count').desc()).show()

+--------+------+
|is_match| count|
+--------+------+
|   false|572820|
|    true|  2093|
+--------+------+



In [None]:
# Expose df to spark.sql queries under the name 'linkage' (replaces any
# existing temp view with that name).
df.createOrReplaceTempView('linkage')

In [None]:
# The same is_match class counts as the DataFrame groupBy version, written in SQL
# against the 'linkage' temp view.
spark.sql("""
  SELECT is_match, COUNT(*) count
  FROM linkage
  GROUP BY is_match
  ORDER BY count DESC
""").show()

+--------+------+
|is_match| count|
+--------+------+
|   false|572820|
|    true|  2093|
+--------+------+



In [None]:
# Per-column summary statistics (count, mean, stddev, min, max).
summary = df.describe()

summary.select(['summary', 'cmp_fname_c1', 'cmp_fname_c2']).show()

+-------+------------------+------------------+
|summary|      cmp_fname_c1|      cmp_fname_c2|
+-------+------------------+------------------+
|  count|            574811|             10325|
|   mean|0.7127592938252765|0.8977586763518972|
| stddev|0.3889286452463553|0.2742577520430534|
|    min|               0.0|               0.0|
|    max|               1.0|               1.0|
+-------+------------------+------------------+



In [None]:
# Split rows by label and describe each group separately; rows whose label
# is null fall into neither group.
matches = df.where(col('is_match'))
match_summary = matches.describe()

misses = df.where(~col('is_match'))
miss_summary = misses.describe()

In [None]:
# The describe() result is tiny (one row per statistic), so collecting it into
# pandas for reshaping is cheap.
summary_p = summary.toPandas()

summary_p.head()

Unnamed: 0,summary,id_1,id_2,cmp_fname_c1,cmp_fname_c2,cmp_lname_c1,cmp_lname_c2,cmp_sex,cmp_bd,cmp_bm,cmp_by,cmp_plz
0,count,574913.0,574913.0,574811.0,10325.0,574913.0,239.0,574913.0,574851.0,574851.0,574851.0,573618.0
1,mean,33271.962171667714,66564.6636865056,0.7127592938252765,0.8977586763518972,0.3155724578098796,0.3269155414552906,0.9550923357099248,0.224755632329073,0.4886361857246487,0.2226663952919974,0.0054949461139643
2,stddev,23622.66942593358,23642.00230967225,0.3889286452463553,0.2742577520430534,0.3342494687554251,0.378309202054067,0.207101522405044,0.4174216587235586,0.4998712818281627,0.416036504164562,0.073924023213019
3,min,1.0,6.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,max,99894.0,100000.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [None]:
# Rows are statistics; columns are 'summary' plus one per described column.
summary_p.shape

(5, 12)

In [None]:
# Transpose so each described column becomes a row and each statistic a column.
# After the first run 'summary' is no longer a column (it becomes the columns'
# axis name), so the guard keeps a re-run of this cell from raising KeyError.
if 'summary' in summary_p.columns:
  summary_p = summary_p.set_index('summary').transpose().reset_index()

summary_p.head()

summary,index,count,mean,stddev,min,max
0,id_1,574913,33271.962171667714,23622.66942593358,1.0,99894.0
1,id_2,574913,66564.6636865056,23642.00230967225,6.0,100000.0
2,cmp_fname_c1,574811,0.7127592938252765,0.3889286452463553,0.0,1.0
3,cmp_fname_c2,10325,0.8977586763518972,0.2742577520430534,0.0,1.0
4,cmp_lname_c1,574913,0.3155724578098796,0.3342494687554251,0.0,1.0


In [None]:
# reset_index() produced a column literally called 'index'; call it 'field'.
summary_p = summary_p.rename({'index': 'field'}, axis='columns')

summary_p.head()

summary,field,count,mean,stddev,min,max
0,id_1,574913,33271.962171667714,23622.66942593358,1.0,99894.0
1,id_2,574913,66564.6636865056,23642.00230967225,6.0,100000.0
2,cmp_fname_c1,574811,0.7127592938252765,0.3889286452463553,0.0,1.0
3,cmp_fname_c2,10325,0.8977586763518972,0.2742577520430534,0.0,1.0
4,cmp_lname_c1,574913,0.3155724578098796,0.3342494687554251,0.0,1.0


In [None]:
# Drop the leftover name on the columns axis so it doesn't show up in displays.
summary_p = summary_p.rename_axis(None, axis='columns')

summary_p.head()

Unnamed: 0,field,count,mean,stddev,min,max
0,id_1,574913,33271.962171667714,23622.66942593358,1.0,99894.0
1,id_2,574913,66564.6636865056,23642.00230967225,6.0,100000.0
2,cmp_fname_c1,574811,0.7127592938252765,0.3889286452463553,0.0,1.0
3,cmp_fname_c2,10325,0.8977586763518972,0.2742577520430534,0.0,1.0
4,cmp_lname_c1,574913,0.3155724578098796,0.3342494687554251,0.0,1.0


In [None]:
# Move the reshaped summary back into Spark. Every resulting column is inferred
# as a string (see the printSchema() output below), so the statistics still
# need casting.
summary_T = spark.createDataFrame(summary_p)

summary_T.show(5)

+------------+------+-------------------+------------------+---+------+
|       field| count|               mean|            stddev|min|   max|
+------------+------+-------------------+------------------+---+------+
|        id_1|574913| 33271.962171667714| 23622.66942593358|  1| 99894|
|        id_2|574913|   66564.6636865056| 23642.00230967225|  6|100000|
|cmp_fname_c1|574811| 0.7127592938252765|0.3889286452463553|0.0|   1.0|
|cmp_fname_c2| 10325| 0.8977586763518972|0.2742577520430534|0.0|   1.0|
|cmp_lname_c1|574913|0.31557245780987964|0.3342494687554251|0.0|   1.0|
+------------+------+-------------------+------------------+---+------+
only showing top 5 rows



In [None]:
# Check the column types Spark inferred from the pandas frame.
summary_T.printSchema()

root
 |-- field: string (nullable = true)
 |-- count: string (nullable = true)
 |-- mean: string (nullable = true)
 |-- stddev: string (nullable = true)
 |-- min: string (nullable = true)
 |-- max: string (nullable = true)



In [None]:
from pyspark.sql.types import DoubleType

# Cast every statistic column to double in a single select(); chaining
# withColumn() in a loop adds one projection to the plan per call.
summary_T = summary_T.select(
    *[
        summary_T[c] if c == 'field' else summary_T[c].cast(DoubleType()).alias(c)
        for c in summary_T.columns
    ]
)

summary_T.printSchema()

root
 |-- field: string (nullable = true)
 |-- count: double (nullable = true)
 |-- mean: double (nullable = true)
 |-- stddev: double (nullable = true)
 |-- min: double (nullable = true)
 |-- max: double (nullable = true)



In [None]:
def pivot_summary(desc):
  """Transpose a Spark describe() result into one row per described column.

  Args:
    desc: Spark DataFrame returned by DataFrame.describe(). Its 'summary'
      column holds the statistic names (count, mean, stddev, min, max), which
      become the column names of the result.

  Returns:
    Spark DataFrame with a string 'field' column naming each described column
    and one double-typed column per statistic.
  """
  from pyspark.sql.types import DoubleType

  # describe() output is small, so reshape it in pandas.
  desc_p = (
      desc.toPandas()
      .set_index('summary')
      .transpose()
      .reset_index()
      .rename(columns={'index': 'field'})
      .rename_axis(None, axis=1)
  )

  # Use the session that owns `desc` instead of relying on a notebook-global.
  desc_T = desc.sparkSession.createDataFrame(desc_p)

  # The statistics come back as strings; cast them all in a single projection
  # rather than one withColumn() per column.
  return desc_T.select(
      *[
          desc_T[c] if c == 'field' else desc_T[c].cast(DoubleType()).alias(c)
          for c in desc_T.columns
      ]
  )

In [None]:
# Reshape both per-label summaries into the field / count / mean / ... layout.
match_summary_T, miss_summary_T = pivot_summary(match_summary), pivot_summary(miss_summary)

In [None]:
match_summary_T.createOrReplaceTempView('match_desc')
miss_summary_T.createOrReplaceTempView('miss_desc')

# Compare each feature's mean between matches and misses, skipping the id columns.
# String literals are single-quoted: with spark.sql.ansi.doubleQuotedIdentifiers
# enabled, "id_1" would be parsed as a column reference, not a string.
spark.sql("""
  SELECT a.field, a.count + b.count AS total, a.mean - b.mean AS delta
  FROM match_desc a INNER JOIN miss_desc b ON a.field = b.field
  WHERE a.field NOT IN ('id_1', 'id_2')
  ORDER BY delta DESC, total DESC
""").show()

+------------+--------+-------------------+
|       field|   total|              delta|
+------------+--------+-------------------+
|     cmp_plz|573618.0| 0.9524975516429005|
|cmp_lname_c2|   239.0| 0.8136949970410104|
|      cmp_by|574851.0| 0.7763379425859384|
|      cmp_bd|574851.0| 0.7732820129086737|
|cmp_lname_c1|574913.0| 0.6844795197263095|
|      cmp_bm|574851.0|  0.510834819548174|
|cmp_fname_c1|574811.0| 0.2853115682852544|
|cmp_fname_c2| 10325.0|0.09900440489032625|
|     cmp_sex|574913.0|0.03452211590529575|
+------------+--------+-------------------+



In [None]:
# Features with large match/miss mean deltas in the table above. cmp_lname_c2
# is left out, presumably because only 239 rows have a value for it.
good_features = [
    'cmp_lname_c1',
    'cmp_plz',
    'cmp_by',
    'cmp_bd',
    'cmp_bm',
]

# Spark SQL expression that adds the selected features together.
sum_expression = ' + '.join(good_features)

print(sum_expression)

cmp_lname_c1 + cmp_plz + cmp_by + cmp_bd + cmp_bm


In [None]:
from pyspark.sql.functions import expr

# Treat missing feature values as 0, then sum the selected features into a
# single score per row.
scored = (
    df.fillna(0, subset=good_features)
    .withColumn('score', expr(sum_expression))
    .select('score', 'is_match')
)

scored.show()

+-----+--------+
|score|is_match|
+-----+--------+
|  4.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  4.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  4.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
+-----+--------+
only showing top 20 rows



In [None]:
from pyspark.sql import DataFrame

def crossTabs(scored: DataFrame, t: float) -> DataFrame:
  """Count rows above/below a score threshold, split by the is_match label.

  Args:
    scored: DataFrame with a numeric 'score' column and a boolean 'is_match'
      column.
    t: Threshold; a row counts as 'above' when score >= t.

  Returns:
    DataFrame with one row per value of 'above' and one count column for each
    of is_match = 'true' / 'false'. A bucket with no rows shows up as NULL.
  """
  from pyspark.sql.functions import col

  # Build the comparison as a Column expression so `t` is bound as a literal
  # rather than interpolated into SQL text.
  return (
      scored
      .select((col('score') >= t).alias('above'), 'is_match')
      .groupBy('above')
      .pivot('is_match', ('true', 'false'))
      .count()
  )

In [None]:
# Strict cutoff: a row is 'above' when its score is at least 4.0.
crossTabs(scored, 4.0).show()

+-----+----+------+
|above|true| false|
+-----+----+------+
| true|2087|    66|
|false|   6|572754|
+-----+----+------+



In [None]:
# Looser cutoff of 2.0. A NULL cell in the output means no rows fell in that
# bucket (pivot count of an empty group).
crossTabs(scored, 2.0).show()

+-----+----+------+
|above|true| false|
+-----+----+------+
| true|2093| 59729|
|false|NULL|513091|
+-----+----+------+

