In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql import DataFrame
from pyspark.sql.types import DoubleType

In [2]:
spark = SparkSession.builder.appName("linkage").getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [3]:
# Read the CSV data under datasets/donations: take column names from the
# header row, treat "?" as a null marker, and let Spark infer column types.
parsed = (
    spark.read
    .option("header", "true")
    .option("nullValue", "?")
    .option("inferSchema", "true")
    .csv("datasets/donations")
)

                                                                                

In [4]:
parsed.show()

+-----+-----+------------+------------+------------+------------+-------+------+------+------+-------+--------+
| id_1| id_2|cmp_fname_c1|cmp_fname_c2|cmp_lname_c1|cmp_lname_c2|cmp_sex|cmp_bd|cmp_bm|cmp_by|cmp_plz|is_match|
+-----+-----+------------+------------+------------+------------+-------+------+------+------+-------+--------+
| 3148| 8326|         1.0|        null|         1.0|        null|      1|     1|     1|     1|      1|    true|
|14055|94934|         1.0|        null|         1.0|        null|      1|     1|     1|     1|      1|    true|
|33948|34740|         1.0|        null|         1.0|        null|      1|     1|     1|     1|      1|    true|
|  946|71870|         1.0|        null|         1.0|        null|      1|     1|     1|     1|      1|    true|
|64880|71676|         1.0|        null|         1.0|        null|      1|     1|     1|     1|      1|    true|
|25739|45991|         1.0|        null|         1.0|        null|      1|     1|     1|     1|      1|  

In [28]:
parsed.printSchema()

root
 |-- id_1: integer (nullable = true)
 |-- id_2: integer (nullable = true)
 |-- cmp_fname_c1: double (nullable = true)
 |-- cmp_fname_c2: double (nullable = true)
 |-- cmp_lname_c1: double (nullable = true)
 |-- cmp_lname_c2: double (nullable = true)
 |-- cmp_sex: integer (nullable = true)
 |-- cmp_bd: integer (nullable = true)
 |-- cmp_bm: integer (nullable = true)
 |-- cmp_by: integer (nullable = true)
 |-- cmp_plz: integer (nullable = true)
 |-- is_match: boolean (nullable = true)



In [29]:
parsed.select("cmp_fname_c1").count()

                                                                                

5749132

In [30]:
# Count the rows whose cmp_fname_c2 is populated. The predicate is applied to
# the full frame so it refers to a column that is actually present in the
# frame being filtered (no projection sits between the source and the filter).
parsed.where(col("cmp_fname_c2").isNotNull()).count()

                                                                                

103698

In [37]:
parsed.select(col("is_match")).rdd.countByValue()

                                                                                

defaultdict(int, {Row(is_match=True): 20931, Row(is_match=False): 5728201})

In [52]:
parsed.groupBy("is_match").count().orderBy(col("count").desc()).show()



+--------+-------+
|is_match|  count|
+--------+-------+
|   false|5728201|
|    true|  20931|
+--------+-------+



                                                                                

In [39]:
parsed.agg(avg("cmp_sex"), stddev("cmp_sex")).show()



+-----------------+--------------------+
|     avg(cmp_sex)|stddev_samp(cmp_sex)|
+-----------------+--------------------+
|0.955001381078048| 0.20730111116897781|
+-----------------+--------------------+



                                                                                

In [5]:
parsed.createOrReplaceTempView("linkage")

In [6]:
spark.sql("""SELECT is_match, count(*) cnt from linkage group by is_match order by cnt DESC""").show()



+--------+-------+
|is_match|    cnt|
+--------+-------+
|   false|5728201|
|    true|  20931|
+--------+-------+



                                                                                

In [7]:
summary = parsed.describe()

                                                                                

In [8]:
summary.select("summary", "cmp_fname_c1", "cmp_fname_c2", "cmp_sex").show()

+-------+------------------+------------------+-------------------+
|summary|      cmp_fname_c1|      cmp_fname_c2|            cmp_sex|
+-------+------------------+------------------+-------------------+
|  count|           5748125|            103698|            5749132|
|   mean|0.7129024704437267|0.9000176718903189|  0.955001381078048|
| stddev|0.3887583596162802|0.2713176105782334|0.20730111116897781|
|    min|               0.0|               0.0|                  0|
|    max|               1.0|               1.0|                  1|
+-------+------------------+------------------+-------------------+



In [9]:
matches = parsed.where(col("is_match") == True)
matchSummary = matches.describe()

                                                                                

In [10]:
matchSummary.show()

+-------+------------------+-----------------+-------------------+-------------------+--------------------+-------------------+-------------------+--------------------+--------------------+-------------------+-------------------+
|summary|              id_1|             id_2|       cmp_fname_c1|       cmp_fname_c2|        cmp_lname_c1|       cmp_lname_c2|            cmp_sex|              cmp_bd|              cmp_bm|             cmp_by|            cmp_plz|
+-------+------------------+-----------------+-------------------+-------------------+--------------------+-------------------+-------------------+--------------------+--------------------+-------------------+-------------------+
|  count|             20931|            20931|              20922|               1333|               20931|                475|              20931|               20925|               20925|              20925|              20902|
|   mean| 34575.72117911232|51259.95939037791| 0.9973163859635039| 0.98989003203

In [11]:
misses = parsed.where("is_match = False")
missSummary = misses.describe()

                                                                                

In [12]:
missSummary.show()

+-------+------------------+------------------+-------------------+-------------------+------------------+-------------------+-------------------+------------------+------------------+------------------+--------------------+
|summary|              id_1|              id_2|       cmp_fname_c1|       cmp_fname_c2|      cmp_lname_c1|       cmp_lname_c2|            cmp_sex|            cmp_bd|            cmp_bm|            cmp_by|             cmp_plz|
+-------+------------------+------------------+-------------------+-------------------+------------------+-------------------+-------------------+------------------+------------------+------------------+--------------------+
|  count|           5728201|           5728201|            5727203|             102365|           5728201|               1989|            5728201|           5727412|           5727412|           5727412|             5715387|
|   mean|33319.913548075565| 66643.44259218557|  0.711863480217509| 0.8988473514090156|0.31313801133

In [70]:
schema = summary.schema

In [96]:
schema[0].name

'summary'

In [189]:
summary.columns

['summary',
 'id_1',
 'id_2',
 'cmp_fname_c1',
 'cmp_fname_c2',
 'cmp_lname_c1',
 'cmp_lname_c2',
 'cmp_sex',
 'cmp_bd',
 'cmp_bm',
 'cmp_by',
 'cmp_plz']

In [199]:
summary.select(summary['id_1']).show()

+------------------+
|              id_1|
+------------------+
|           5749132|
| 33324.48559643438|
|23659.859374488064|
|                 1|
|             99980|
+------------------+



In [13]:
summary_p = summary.toPandas()

In [14]:
summary_p

Unnamed: 0,summary,id_1,id_2,cmp_fname_c1,cmp_fname_c2,cmp_lname_c1,cmp_lname_c2,cmp_sex,cmp_bd,cmp_bm,cmp_by,cmp_plz
0,count,5749132.0,5749132.0,5748125.0,103698.0,5749132.0,2464.0,5749132.0,5748337.0,5748337.0,5748337.0,5736289.0
1,mean,33324.48559643438,66587.43558331935,0.7129024704437267,0.9000176718903189,0.3156278193080382,0.3184128315317443,0.955001381078048,0.2244652670850717,0.488855298497635,0.2227485966810923,0.0055286614743434
2,stddev,23659.859374488064,23620.48761326969,0.3887583596162802,0.2713176105782334,0.3342336339615828,0.3685670662006654,0.2073011111689778,0.4172297223846263,0.4998758236779031,0.4160909629831755,0.0741491492542004
3,min,1.0,6.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,max,99980.0,100000.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [174]:
summary_pT = summary_p.set_index('summary').T.reset_index()

In [175]:
summary_pT = summary_pT.rename(columns={'index':'field'})

In [176]:
summary_pT = summary_pT.rename_axis(None, axis=1)

In [177]:
summary_pT

Unnamed: 0,field,count,mean,stddev,min,max
0,id_1,5749132,33324.48559643438,23659.859374488064,1.0,99980.0
1,id_2,5749132,66587.43558331935,23620.4876132697,6.0,100000.0
2,cmp_fname_c1,5748125,0.7129024704437266,0.3887583596162802,0.0,1.0
3,cmp_fname_c2,103698,0.9000176718903189,0.2713176105782334,0.0,1.0
4,cmp_lname_c1,5749132,0.3156278193080382,0.3342336339615828,0.0,1.0
5,cmp_lname_c2,2464,0.3184128315317442,0.3685670662006654,0.0,1.0
6,cmp_sex,5749132,0.955001381078048,0.2073011111689778,0.0,1.0
7,cmp_bd,5748337,0.2244652670850717,0.4172297223846263,0.0,1.0
8,cmp_bm,5748337,0.488855298497635,0.499875823677903,0.0,1.0
9,cmp_by,5748337,0.2227485966810923,0.4160909629831755,0.0,1.0


In [186]:
summaryT = spark.createDataFrame(summary_pT)

In [202]:
summaryT.count()

11

In [192]:
summaryT.printSchema()

root
 |-- field: string (nullable = true)
 |-- count: string (nullable = true)
 |-- mean: string (nullable = true)
 |-- stddev: string (nullable = true)
 |-- min: string (nullable = true)
 |-- max: string (nullable = true)



In [207]:
# Cast every statistic column to double in one projection, leaving the
# "field" label column as-is. A single select() avoids stacking one extra
# projection per column, which is what repeated withColumn() calls do.
summaryT = summaryT.select(
    *[
        summaryT[c] if c == "field" else summaryT[c].cast(DoubleType()).alias(c)
        for c in summaryT.columns
    ]
)

In [208]:
summaryT.show()

+------------+---------+-------------------+-------------------+---+--------+
|       field|    count|               mean|             stddev|min|     max|
+------------+---------+-------------------+-------------------+---+--------+
|        id_1|5749132.0|  33324.48559643438| 23659.859374488064|1.0| 99980.0|
|        id_2|5749132.0|  66587.43558331935|   23620.4876132697|6.0|100000.0|
|cmp_fname_c1|5748125.0| 0.7129024704437266| 0.3887583596162802|0.0|     1.0|
|cmp_fname_c2| 103698.0| 0.9000176718903189| 0.2713176105782334|0.0|     1.0|
|cmp_lname_c1|5749132.0|0.31562781930803824| 0.3342336339615828|0.0|     1.0|
|cmp_lname_c2|   2464.0|0.31841283153174427| 0.3685670662006654|0.0|     1.0|
|     cmp_sex|5749132.0|  0.955001381078048|0.20730111116897781|0.0|     1.0|
|      cmp_bd|5748337.0|0.22446526708507172| 0.4172297223846263|0.0|     1.0|
|      cmp_bm|5748337.0|0.48885529849763504|  0.499875823677903|0.0|     1.0|
|      cmp_by|5748337.0| 0.2227485966810923|0.41609096298317555|

In [15]:
def pivotSummary(desc: DataFrame) -> DataFrame:
    """Transpose a summary frame so that each original column becomes a row.

    The input must have a ``summary`` column naming each statistic (as the
    frames returned by ``describe()`` in this notebook do) plus one column
    per field.

    Args:
        desc: Summary frame to transpose. It is collected to the driver with
            ``toPandas()``, so it should be small.

    Returns:
        A Spark DataFrame with a ``field`` column holding the original column
        names and one double-typed column per statistic.
    """
    # Transposing is done in pandas: index by statistic name, flip, then turn
    # the former column names back into a regular column called "field".
    desc_p = desc.toPandas()
    desc_p = desc_p.set_index("summary").T.reset_index()
    desc_p = desc_p.rename(columns={'index': 'field'})
    # Drop the leftover "summary" axis label so it does not leak into the schema.
    desc_p = desc_p.rename_axis(None, axis=1)
    # Uses the notebook-level `spark` session.
    descT = spark.createDataFrame(desc_p)
    # Cast all statistic columns in a single projection instead of calling
    # withColumn() once per column.
    return descT.select(
        *[
            descT[c] if c == "field" else descT[c].cast(DoubleType()).alias(c)
            for c in descT.columns
        ]
    )

In [217]:
def pivotSummary(desc: DataFrame) -> DataFrame:
    """Transpose a summary frame so that each original column becomes a row.

    NOTE(review): this cell redefines ``pivotSummary``; whichever definition
    runs last wins. The signature here is kept identical to the annotated
    definition so the two cannot drift apart — consider deleting one cell.

    Args:
        desc: Frame with a ``summary`` column naming each statistic and one
            column per field. Collected to the driver via ``toPandas()``.

    Returns:
        A Spark DataFrame with a ``field`` column (the original column names)
        and one double-typed column per statistic.
    """
    transposed = (
        desc.toPandas()
        .set_index("summary")
        .T
        .reset_index()
        .rename(columns={'index': 'field'})
        # Remove the leftover "summary" label from the column axis.
        .rename_axis(None, axis=1)
    )
    # Uses the notebook-level `spark` session.
    descT = spark.createDataFrame(transposed)
    # One select() casts every statistic column at once; "field" stays a string.
    projected = [
        descT[c] if c == "field" else descT[c].cast(DoubleType()).alias(c)
        for c in descT.columns
    ]
    return descT.select(*projected)

In [16]:
matchSummaryT = pivotSummary(matchSummary)
missSummaryT = pivotSummary(missSummary)


In [17]:
missSummaryT.show()



+------------+---------+--------------------+-------------------+----+--------+
|       field|    count|                mean|             stddev| min|     max|
+------------+---------+--------------------+-------------------+----+--------+
|        id_1|5728201.0|  33319.913548075565| 23665.760130330673| 1.0| 99980.0|
|        id_2|5728201.0|   66643.44259218557| 23599.551728241317|30.0|100000.0|
|cmp_fname_c1|5727203.0|   0.711863480217509|0.38908060096985553| 0.0|     1.0|
|cmp_fname_c2| 102365.0|  0.8988473514090156|0.27272090294010215| 0.0|     1.0|
|cmp_lname_c1|5728201.0|  0.3131380113364304| 0.3322812130572686| 0.0|     1.0|
|cmp_lname_c2|   1989.0| 0.16295544855122535| 0.1930236663528703| 0.0|     1.0|
|     cmp_sex|5728201.0|  0.9548833918362851|0.20755988859217375| 0.0|     1.0|
|      cmp_bd|5727412.0|  0.2216425149788421| 0.4153518275558732| 0.0|     1.0|
|      cmp_bm|5727412.0|   0.486995347986141| 0.4998308940493865| 0.0|     1.0|
|      cmp_by|5727412.0|  0.219923064728

                                                                                

In [18]:
# Expose both transposed summaries to Spark SQL.
matchSummaryT.createOrReplaceTempView("match_desc")
missSummaryT.createOrReplaceTempView("miss_desc")
# For every field except the two id columns, show the summed count and the
# difference between the match and non-match means, largest difference first.
# String literals are single-quoted: double-quoted tokens are parsed as
# identifiers when ANSI double-quoted identifiers are enabled.
spark.sql("""
SELECT a.field, a.count + b.count total, a.mean - b.mean delta
FROM match_desc a INNER JOIN miss_desc b ON a.field = b.field
WHERE a.field NOT IN ('id_1', 'id_2')
ORDER BY delta DESC, total DESC
""").show()

+------------+---------+--------------------+
|       field|    total|               delta|
+------------+---------+--------------------+
|     cmp_plz|5736289.0|  0.9563812499852176|
|cmp_lname_c2|   2464.0|  0.8064147192926266|
|      cmp_by|5748337.0|  0.7762059675300512|
|      cmp_bd|5748337.0|   0.775442311783404|
|cmp_lname_c1|5749132.0|  0.6838772482594513|
|      cmp_bm|5748337.0|  0.5109496938298685|
|cmp_fname_c1|5748125.0|  0.2854529057459949|
|cmp_fname_c2| 103698.0| 0.09104268062280196|
|     cmp_sex|5749132.0|0.032408185250332844|
+------------+---------+--------------------+



                                                                                

In [25]:
good_features = ["cmp_lname_c1", "cmp_plz", "cmp_by", "cmp_bd","cmp_bm"]

In [59]:
good_features1 = ["cmp_lname_c1 * 1.6", "cmp_plz * 1", "cmp_by * 1.3", "cmp_bd * 1.3","cmp_bm * 1.8"]

In [60]:
sum_expression = " + ".join(good_features1)

In [61]:
sum_expression

'cmp_lname_c1 * 1.6 + cmp_plz * 1 + cmp_by * 1.3 + cmp_bd * 1.3 + cmp_bm * 1.8'

In [62]:
# Replace nulls with 0 in the selected feature columns, evaluate the weighted
# sum expression as a new "score" column, and keep only the score and label.
scored = (
    parsed
    .fillna(0, good_features)
    .withColumn('score', expr(sum_expression))
    .select("score", "is_match")
)

In [63]:
scored.show()

+-----+--------+
|score|is_match|
+-----+--------+
|  7.0|    true|
|  7.0|    true|
|  7.0|    true|
|  7.0|    true|
|  7.0|    true|
|  7.0|    true|
|  6.0|    true|
|  7.0|    true|
|  7.0|    true|
|  7.0|    true|
|  7.0|    true|
|  7.0|    true|
|  7.0|    true|
|  7.0|    true|
|  7.0|    true|
|  7.0|    true|
|  6.0|    true|
|  7.0|    true|
|  7.0|    true|
|  7.0|    true|
+-----+--------+
only showing top 20 rows



In [64]:
def crossTabs(scored: DataFrame, t: float) -> DataFrame:
    """Cross-tabulate ``score >= t`` against the ``is_match`` label.

    Args:
        scored: Frame with a ``score`` column and an ``is_match`` column.
        t: Score threshold. It is interpolated into the SQL expression as a
            literal, so pass a plain Python number.

    Returns:
        A frame with one row per value of ``above`` and the row counts in the
        ``true`` and ``false`` pivot columns.
    """
    flagged = scored.selectExpr(f"score >= {t} as above", "is_match")
    # Pivot values are listed explicitly, and as a list, which is the type
    # pivot() documents for its `values` argument.
    return flagged.groupBy("above").pivot("is_match", ["true", "false"]).count()

In [73]:
crossTabs(scored, 5.8).show()



+-----+-----+-------+
|above| true|  false|
+-----+-----+-------+
| true|20726|    286|
|false|  205|5727915|
+-----+-----+-------+



                                                                                