In [0]:
import pickle
import boto3
import re
import json
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt



In [0]:
from pyspark.sql import SparkSession
sc = spark.sparkContext
from pyspark.sql import SQLContext
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import IntegerType, StringType, FloatType, ArrayType, DoubleType, StructType, StructField
sqlContext = SQLContext(sc)



In [0]:
base_save_path = "s3://openalex-data-copy/snapshot_2023_02_15/"
iteration_save_path = "s3://author-disambiguation/V3/"

In [0]:
all_names = spark.read.parquet(f"{iteration_save_path}name_embedding_training_data/given_family_output_names") \
.select(F.trim(F.col('given_name')).alias('given_name'), 
        F.trim(F.col('family_name')).alias('family_name'), 
        F.trim(F.col('output_name')).alias('output_name'), 
        F.trim(F.col('raw_input')).alias('raw_input')) \
.drop_duplicates(subset=['given_name','family_name','output_name'])
all_names.cache().count()

Out[7]: 52420616

In [0]:
all_names.groupby('given_name').agg(F.collect_list(F.col('output_name')).alias('output_name')) \
.write.mode('overwrite').parquet(f"{iteration_save_path}name_embedding_training_data/hard_samples/given_name_matches")

In [0]:
all_names.groupby('family_name').agg(F.collect_list(F.col('output_name')).alias('output_name')) \
.write.mode('overwrite').parquet(f"{iteration_save_path}name_embedding_training_data/hard_samples/family_name_matches")

In [0]:
given_matches = spark.read \
.parquet(f"{iteration_save_path}name_embedding_training_data/hard_samples/given_name_matches")

given_matches.cache().count()

Out[5]: 8460722

In [0]:
family_matches = spark.read \
.parquet(f"{iteration_save_path}name_embedding_training_data/hard_samples/family_name_matches")

family_matches.cache().count()

Out[7]: 7684839

In [0]:
def get_pair_of_hard_samples_(list_of_like_names):
    if len(list_of_like_names) <8:
        return []
    else:
        choice_1 = random.choice(list_of_like_names)
        choice_2 = random.choice(list_of_like_names)
        if choice_1 == choice_2:
            return list_of_like_names[:2]
        else:
            return [choice_1, choice_2]

get_pair_of_hard_samples = F.udf(get_pair_of_hard_samples_, ArrayType(StringType()))

In [0]:
given_matches \
.withColumn("hard_pairs", get_pair_of_hard_samples(F.col('output_name'))) \
.withColumn("hard_pair_len", F.size(F.col('hard_pairs'))) \
.filter(F.col('hard_pair_len') ==2) \
.select('hard_pairs') \
.write.mode('overwrite').parquet(f"{iteration_save_path}name_embedding_training_data/hard_samples/given_hard_pairs")

In [0]:
family_matches \
.withColumn("hard_pairs", get_pair_of_hard_samples(F.col('output_name'))) \
.withColumn("hard_pair_len", F.size(F.col('hard_pairs'))) \
.filter(F.col('hard_pair_len') ==2) \
.select('hard_pairs') \
.write.mode('overwrite').parquet(f"{iteration_save_path}name_embedding_training_data/hard_samples/family_hard_pairs")

In [0]:
spark.read.parquet(f"{iteration_save_path}name_embedding_training_data/hard_samples/family_hard_pairs") \
.union(spark.read.parquet(f"{iteration_save_path}name_embedding_training_data/hard_samples/given_hard_pairs")) \
.dropDuplicates() \
.coalesce(1).write.mode('overwrite').parquet(f"{iteration_save_path}name_embedding_training_data/hard_samples/final_hard_negative_pairs")

In [0]:
spark.read.parquet(f"{iteration_save_path}name_embedding_training_data/hard_samples/family_hard_pairs") \
.union(spark.read.parquet(f"{iteration_save_path}name_embedding_training_data/hard_samples/given_hard_pairs"))

Out[45]: 1118133

### For Disambiguator

In [0]:
orcid_names = spark.read.parquet(f"{iteration_save_path}orcid_names_data_dump.parquet") \
    .select('orcid',F.trim(F.col('given_names')).alias('given_name'),F.trim(F.col('family_name')).alias('family_name')) \
    .select('orcid', 'given_name','family_name', 
    F.concat_ws(' ', F.col('given_name'), F.col('family_name')).alias('output_name')) \
    .dropDuplicates()
orcid_names.cache().count()

Out[21]: 14845875

In [0]:
for i in range(4):
    orcid_names.select(F.lit('positive').alias('sample_type'), F.col('orcid').alias('orcid_1'), F.col('orcid').alias('orcid_2')) \
        .coalesce(50).write.mode('append') \
        .parquet(f"{iteration_save_path}disambiguator_training_data/final_positive_sample_orcids")

In [0]:
pos_samples = spark.read.parquet(f"{iteration_save_path}disambiguator_training_data/final_positive_sample_orcids")
pos_samples.cache().count()

Out[24]: 59383500

In [0]:
orcid_names.select(F.lit('positive').alias('sample_type'), F.col('orcid').alias('orcid_1'), F.col('orcid').alias('orcid_2')).show()

+-----------+-------------------+-------------------+
|sample_type|            orcid_1|            orcid_2|
+-----------+-------------------+-------------------+
|   positive|0000-0001-5041-9000|0000-0001-5041-9000|
|   positive|0000-0001-5153-4000|0000-0001-5153-4000|
|   positive|0000-0001-5351-3000|0000-0001-5351-3000|
|   positive|0000-0001-5435-6000|0000-0001-5435-6000|
|   positive|0000-0001-5550-0000|0000-0001-5550-0000|
|   positive|0000-0001-5763-7000|0000-0001-5763-7000|
|   positive|0000-0001-6414-0000|0000-0001-6414-0000|
|   positive|0000-0001-7488-3000|0000-0001-7488-3000|
|   positive|0000-0001-7557-9000|0000-0001-7557-9000|
|   positive|0000-0001-7748-6000|0000-0001-7748-6000|
|   positive|0000-0001-8089-6000|0000-0001-8089-6000|
|   positive|0000-0001-8331-7000|0000-0001-8331-7000|
|   positive|0000-0001-8469-4000|0000-0001-8469-4000|
|   positive|0000-0001-9019-7000|0000-0001-9019-7000|
|   positive|0000-0001-9537-7000|0000-0001-9537-7000|
|   positive|0000-0001-9666-

In [0]:
orcid_names.show(10)

+-------------------+----------+------------+-------------------+
|              orcid|given_name| family_name|        output_name|
+-------------------+----------+------------+-------------------+
|0000-0001-5074-2000|    Martin|Perez-Santos|Martin Perez-Santos|
|0000-0001-5001-3000|   Vincent|      Nguyen|     Vincent Nguyen|
|0000-0001-5109-1000|    Andrew|     Porteus|     Andrew Porteus|
|0000-0001-5031-2000|  Yi-Chien|        null|           Yi-Chien|
|0000-0001-5028-3000|  Benjamin|      Becket|    Benjamin Becket|
|0000-0001-5024-0000|   Pâmella|       Romão|      Pâmella Romão|
|0000-0001-5073-4000|    Kleice|    Oliveira|    Kleice Oliveira|
|0000-0001-5091-0000|   Michele|     Saysana|    Michele Saysana|
|0000-0001-5123-5000|    Şevval|     Özdemir|     Şevval Özdemir|
|0000-0001-5006-4000|Risma Dede|      Andini|  Risma Dede Andini|
+-------------------+----------+------------+-------------------+
only showing top 10 rows



In [0]:
w = Window().partitionBy(F.lit('a')).orderBy(F.rand())

orcid_names.withColumn("row_num", F.row_number().over(w)) \
    .select('orcid', 'row_num') \
    .coalesce(20).write.mode('overwrite') \
    .parquet(f"{iteration_save_path}disambiguator_training_data/easy_samples/random_1")

In [0]:
w = Window().partitionBy(F.lit('a')).orderBy(F.rand())

orcid_names.withColumn("row_num", F.row_number().over(w)) \
    .select('orcid', 'row_num') \
    .coalesce(20).write.mode('overwrite') \
    .parquet(f"{iteration_save_path}disambiguator_training_data/easy_samples/random_2")

In [0]:
w = Window().partitionBy(F.lit('a')).orderBy(F.rand())

orcid_names.withColumn("row_num", F.row_number().over(w)) \
    .select('orcid', 'row_num') \
    .coalesce(20).write.mode('overwrite') \
    .parquet(f"{iteration_save_path}disambiguator_training_data/easy_samples/random_3")

In [0]:
orcid_names \
    .filter(~F.col('given_name').isNull()) \
    .filter(~F.col('family_name').isNull()) \
    .groupBy('output_name').agg(F.collect_list(F.col('orcid')).alias('orcids')) \
    .withColumn('orcid_len', F.size(F.col('orcids'))) \
    .filter(F.col('orcid_len')>1) \
    .coalesce(20).write.mode('overwrite') \
    .parquet(f"{iteration_save_path}disambiguator_training_data/hard_samples/hard_negatives_output_name")

In [0]:
orcid_names.groupBy('given_name').agg(F.collect_list(F.col('orcid')).alias('orcids')) \
    .withColumn('orcid_len', F.size(F.col('orcids'))) \
    .filter(F.col('orcid_len')>1) \
    .coalesce(20).write.mode('overwrite') \
    .parquet(f"{iteration_save_path}disambiguator_training_data/hard_samples/hard_negatives_given_name")

In [0]:
orcid_names.groupBy('family_name').agg(F.collect_list(F.col('orcid')).alias('orcids')) \
    .withColumn('orcid_len', F.size(F.col('orcids'))) \
    .filter(F.col('orcid_len')>1) \
    .coalesce(20).write.mode('overwrite') \
    .parquet(f"{iteration_save_path}disambiguator_training_data/hard_samples/hard_negatives_family_name")

In [0]:
orcid_names.groupBy('output_name').agg(F.collect_list(F.col('orcid')).alias('orcids')) \
    .withColumn('orcid_len', F.size(F.col('orcids'))) \
    .filter(F.col('orcid_len')>1) \
    .sample(0.001).show(10, truncate=False)

+-------------------+---------------------------------------------------------------------------------------------------------+---------+
|output_name        |orcids                                                                                                   |orcid_len|
+-------------------+---------------------------------------------------------------------------------------------------------+---------+
|Abbas Jahanara     |[0000-0001-5832-9309, 0000-0003-3027-0119]                                                               |2        |
|Aiping Song        |[0000-0001-8981-2213, 0000-0003-3812-6744]                                                               |2        |
|António Castanheira|[0000-0002-9811-2961, 0000-0002-4346-2745, 0000-0002-4861-7267]                                          |3        |
|Dalya              |[0000-0001-9326-0896, 0000-0001-9222-7285, 0000-0001-9379-2737, 0000-0002-6733-748X]                     |4        |
|Daqi Liu           |[0000-0002-04

In [0]:
def get_pair_of_hard_samples_(list_of_like_names):
    if len(list_of_like_names) <2:
        return []
    else:
        choice_1 = random.choice(list_of_like_names)
        choice_2 = random.choice(list_of_like_names)
        if choice_1 == choice_2:
            return list_of_like_names[:2]
        else:
            return [choice_1, choice_2]

get_pair_of_hard_samples = F.udf(get_pair_of_hard_samples_, ArrayType(StringType()))

In [0]:
random_1 = spark.read\
    .parquet(f"{iteration_save_path}disambiguator_training_data/easy_samples/random_1") \
    .select('row_num',F.col('orcid').alias('orcid_1'))

In [0]:
random_2 = spark.read\
    .parquet(f"{iteration_save_path}disambiguator_training_data/easy_samples/random_2") \
    .select('row_num',F.col('orcid').alias('orcid_2'))

In [0]:
random_3 = spark.read\
    .parquet(f"{iteration_save_path}disambiguator_training_data/easy_samples/random_3") \
    .select('row_num',F.col('orcid').alias('orcid_3'))

In [0]:
def turn_into_neg_pair_(num_1, num_2):
    return [num_1, num_2]

turn_into_neg_pair = F.udf(turn_into_neg_pair_, ArrayType(StringType()))

In [0]:
random_1\
    .join(random_2, how='inner', on='row_num') \
    .join(random_3, how='inner', on='row_num') \
    .withColumn('pair_1', turn_into_neg_pair(F.col('orcid_1'), F.col('orcid_2'))) \
    .withColumn('pair_2', turn_into_neg_pair(F.col('orcid_2'), F.col('orcid_3'))) \
    .withColumn('pair_3', turn_into_neg_pair(F.col('orcid_1'), F.col('orcid_3'))) \
    .select('pair_1','pair_2','pair_3') \
    .coalesce(20).write.mode('overwrite') \
    .parquet(f"{iteration_save_path}disambiguator_training_data/easy_samples/random_negatives")

In [0]:
random_1 = spark.read.parquet(f"{iteration_save_path}disambiguator_training_data/easy_samples/random_negatives") \
    .select(F.col('pair_1').alias('pairs'), F.lit('random_negative').alias('sample_type'))

random_1.cache().count()

Out[4]: 14845875

In [0]:
random_2 = spark.read.parquet(f"{iteration_save_path}disambiguator_training_data/easy_samples/random_negatives") \
    .select(F.col('pair_2').alias('pairs'), F.lit('random_negative').alias('sample_type'))

random_2.cache().count()

Out[5]: 14845875

In [0]:
random_3 = spark.read.parquet(f"{iteration_save_path}disambiguator_training_data/easy_samples/random_negatives") \
    .select(F.col('pair_3').alias('pairs'), F.lit('random_negative').alias('sample_type'))

random_3.cache().count()

Out[6]: 14845875

In [0]:
hard_output = spark.read\
    .parquet(f"{iteration_save_path}disambiguator_training_data/hard_samples/hard_negatives_output_name") \
    .select(F.col('orcids').alias('pairs'), F.lit('output_negative').alias('sample_type'))
hard_output.cache().count()

Out[13]: 986832

In [0]:
hard_given = spark.read\
    .parquet(f"{iteration_save_path}disambiguator_training_data/hard_samples/hard_negatives_given_name") \
    .select(F.col('orcids').alias('pairs'), F.lit('given_negative').alias('sample_type'))
hard_given.cache().count()

Out[14]: 532129

In [0]:
hard_family = spark.read\
    .parquet(f"{iteration_save_path}disambiguator_training_data/hard_samples/hard_negatives_family_name") \
    .select(F.col('orcids').alias('pairs'), F.lit('family_negative').alias('sample_type'))
hard_family.cache().count()

Out[15]: 937638

In [0]:
hard_output \
    .union(hard_family.select(*hard_output.columns)) \
    .union(hard_given.select(*hard_output.columns)) \
    .union(random_1.select(*hard_output.columns)) \
    .union(random_2.select(*hard_output.columns)) \
    .union(random_3.select(*hard_output.columns)) \
    .withColumn('final_pair', get_pair_of_hard_samples(F.col('pairs'))) \
        .select('final_pair', 'sample_type') \
    .withColumn('pair_len', F.size(F.col('final_pair'))) \
    .filter(F.col('pair_len')==2) \
    .select('sample_type', F.col('final_pair')[0].alias('orcid_1'), F.col('final_pair')[1].alias('orcid_2')) \
    .write.mode('overwrite') \
    .parquet(f"{iteration_save_path}disambiguator_training_data/final_negative_sample_orcids")

### Creating the training data for disambiguator

* join sample data on orcid for first ID
* then sample data
* join on second ID to first orcid
* sample that data one more time
* dataset complete (do for both positive and negative)

In [0]:
w1 = Window().partitionBy('orcid_1').orderBy(F.rand())
w2 = Window().partitionBy('orcid_2').orderBy(F.rand())

In [0]:
pos_samples = spark.read.parquet(f"{iteration_save_path}disambiguator_training_data/final_positive_sample_orcids")
pos_samples.cache().count()

Out[6]: 59383500

In [0]:
neg_samples = spark.read.parquet(f"{iteration_save_path}disambiguator_training_data/final_negative_sample_orcids")
neg_samples.cache().count()

Out[7]: 46994224

In [0]:
dis_data_proc_1 = spark.read.parquet(f"{iteration_save_path}disambiguator_processed_training_samples") \
    .select(F.col('work_id').alias('work_id_1'),F.col('orcid').alias('orcid_1'), F.col('new_training_sample').alias('new_training_sample_1'))
dis_data_proc_1.cache().count()

Out[8]: 190816087

In [0]:
dis_data_proc_2 = spark.read.parquet(f"{iteration_save_path}disambiguator_processed_training_samples") \
    .select(F.col('work_id').alias('work_id_2'),F.col('orcid').alias('orcid_2'), F.col('new_training_sample').alias('new_training_sample_2'))
dis_data_proc_2.cache().count()

Out[9]: 190816087

In [0]:
pos_samples \
    .join(dis_data_proc_1, how='inner', on='orcid_1') \
    .withColumn('inst_orc', F.row_number().over(w1)) \
    .filter(F.col('inst_orc')<3) \
    .select('sample_type','orcid_1','orcid_2','work_id_1','new_training_sample_1') \
    .join(dis_data_proc_2, how='inner', on='orcid_2') \
    .withColumn('inst_orc', F.row_number().over(w2)) \
    .filter(F.col('inst_orc')<3) \
    .select('sample_type','orcid_1','orcid_2','work_id_1','work_id_2','new_training_sample_1','new_training_sample_2') \
    .write.mode('overwrite') \
    .parquet(f"{iteration_save_path}disambiguator_training_data/final_positive_samples")

In [0]:
final_pos_samples = spark.read.parquet(f"{iteration_save_path}disambiguator_training_data/final_positive_samples")
final_pos_samples.cache().count()

Out[11]: 8845254

In [0]:
neg_samples \
    .join(dis_data_proc_1, how='inner', on='orcid_1') \
    .withColumn('inst_orc', F.row_number().over(w1)) \
    .filter(F.col('inst_orc')<4) \
    .select('sample_type','orcid_1','orcid_2','work_id_1','new_training_sample_1') \
    .join(dis_data_proc_2, how='inner', on='orcid_2') \
    .withColumn('inst_orc', F.row_number().over(w2)) \
    .filter(F.col('inst_orc')<4) \
    .select('sample_type','orcid_1','orcid_2','work_id_1','work_id_2','new_training_sample_1','new_training_sample_2') \
    .write.mode('overwrite') \
    .parquet(f"{iteration_save_path}disambiguator_training_data/final_negative_samples")

In [0]:
final_neg_samples = spark.read.parquet(f"{iteration_save_path}disambiguator_training_data/final_negative_samples")
final_neg_samples.cache().count()

Out[15]: 6494792

In [0]:
final_pos_samples.union(final_neg_samples.select(*final_pos_samples.columns)).orderBy(F.rand()) \
    .write.mode('overwrite') \
    .parquet(f"{iteration_save_path}disambiguator_training_data/all_final_sample_data")

### Train, Val, Test

In [0]:
all_sample_data = spark.read.parquet(f"{iteration_save_path}disambiguator_training_data/all_final_sample_data")
all_sample_data.cache().count()

Out[17]: 15340046

In [0]:
all_sample_data.groupBy('sample_type').count().show()

+---------------+-------+
|    sample_type|  count|
+---------------+-------+
|       positive|8845254|
|family_negative| 149869|
|output_negative| 140818|
|random_negative|6146597|
| given_negative|  57508|
+---------------+-------+



In [0]:
train, val, test = all_sample_data.randomSplit([0.998, 0.001, 0.001], seed=0)

In [0]:
val_df = val.toPandas()

In [0]:
test_df = test.toPandas()

In [0]:
orcids_to_skip = list(set(val_df['orcid_1'].tolist() + val_df['orcid_2'].tolist() + test_df['orcid_1'].tolist() + test_df['orcid_2'].tolist()))
len(orcids_to_skip)

Out[37]: 43725

In [0]:
train \
    .filter(~F.col('orcid_1').isin(orcids_to_skip)) \
    .filter(~F.col('orcid_2').isin(orcids_to_skip)) \
    .orderBy(F.rand()).write.mode('overwrite') \
    .parquet(f"{iteration_save_path}disambiguator_training_data/training_data/train")

In [0]:
val.orderBy(F.rand()).write.mode('overwrite') \
    .parquet(f"{iteration_save_path}disambiguator_training_data/training_data/val")

In [0]:
test.orderBy(F.rand()).write.mode('overwrite') \
    .parquet(f"{iteration_save_path}disambiguator_training_data/training_data/test")

In [0]:
spark.read.parquet(f"{iteration_save_path}disambiguator_training_data/training_data/train").count()

Out[41]: 15114415

In [0]:
spark.read.parquet(f"{iteration_save_path}disambiguator_training_data/training_data/val").count()

Out[42]: 15355

In [0]:
spark.read.parquet(f"{iteration_save_path}disambiguator_training_data/training_data/test").count()

Out[43]: 15471