In [1]:
import pickle
import boto3
import re
import json
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import IntegerType, StringType, FloatType, ArrayType, DoubleType, StructType, StructField

In [None]:
# S3 location placeholder; not referenced by any of the cells visible here.
base_save_path = "<S3path>"
# S3 prefix for this iteration's data. The cells below append relative
# sub-paths to it directly inside f-strings (no separator is inserted), so the
# real value must end with a trailing "/".
iteration_save_path = "<S3path>"

In [None]:
# Load the name rows, strip surrounding whitespace from each of the four
# string columns, and keep a single row per distinct
# (given_name, family_name, output_name) combination.
all_names = (
    spark.read
    .parquet(f"{iteration_save_path}name_embedding_training_data/given_family_output_names")
    .select(*[
        F.trim(F.col(column)).alias(column)
        for column in ('given_name', 'family_name', 'output_name', 'raw_input')
    ])
    .dropDuplicates(['given_name', 'family_name', 'output_name'])
)
# count() forces the cache to be populated; the row count is the cell output.
all_names.cache().count()

In [None]:
# One row per given name, carrying the list of every output_name that shares it.
(
    all_names
    .groupBy('given_name')
    .agg(F.collect_list('output_name').alias('output_name'))
    .write
    .mode('overwrite')
    .parquet(f"{iteration_save_path}name_embedding_training_data/hard_samples/given_name_matches")
)

In [None]:
# One row per family name, carrying the list of every output_name that shares it.
(
    all_names
    .groupBy('family_name')
    .agg(F.collect_list('output_name').alias('output_name'))
    .write
    .mode('overwrite')
    .parquet(f"{iteration_save_path}name_embedding_training_data/hard_samples/family_name_matches")
)

In [None]:
# Read back the per-given-name groups written above and cache them.
given_matches = spark.read.parquet(
    f"{iteration_save_path}name_embedding_training_data/hard_samples/given_name_matches"
)

# count() materializes the cache; the number of groups is the cell output.
given_matches.cache().count()

In [None]:
# Read back the per-family-name groups written above and cache them.
family_matches = spark.read.parquet(
    f"{iteration_save_path}name_embedding_training_data/hard_samples/family_name_matches"
)

# count() materializes the cache; the number of groups is the cell output.
family_matches.cache().count()

In [None]:
def get_pair_of_hard_samples_(list_of_like_names, min_group_size=8):
    """Pick two different names at random from a group of similar names.

    Args:
        list_of_like_names: List of name strings belonging to one group
            (may contain repeated values). ``None`` is treated as empty.
        min_group_size: Groups with fewer entries than this (counting
            repeats) are skipped. Defaults to 8, the original threshold.

    Returns:
        A two-element list holding two distinct names drawn uniformly at
        random from the group, or ``[]`` when the group is too small or does
        not contain at least two distinct names.
    """
    if list_of_like_names is None or len(list_of_like_names) < min_group_size:
        return []

    # De-duplicate (keeping first-seen order) so the sampled pair can never be
    # the same name twice, which a hard-negative pair must not be.
    distinct_names = list(dict.fromkeys(list_of_like_names))
    if len(distinct_names) < 2:
        return []

    # random.sample draws without replacement, so no collision fallback is
    # needed and every distinct pair is equally likely.
    return random.sample(distinct_names, 2)

# Spark treats UDFs as deterministic by default and may then duplicate or
# eliminate invocations during optimization. The wrapped function draws names
# at random, so the UDF is explicitly flagged as non-deterministic.
get_pair_of_hard_samples = F.udf(
    get_pair_of_hard_samples_, ArrayType(StringType())
).asNondeterministic()

In [None]:
# Draw a random pair per given-name group, keep only the groups that actually
# produced a two-element pair, and persist just the pair column.
(
    given_matches
    .withColumn("hard_pairs", get_pair_of_hard_samples('output_name'))
    .withColumn("hard_pair_len", F.size('hard_pairs'))
    .where(F.col('hard_pair_len') == 2)
    .select('hard_pairs')
    .write
    .mode('overwrite')
    .parquet(f"{iteration_save_path}name_embedding_training_data/hard_samples/given_hard_pairs")
)

In [None]:
# Draw a random pair per family-name group, keep only the groups that actually
# produced a two-element pair, and persist just the pair column.
(
    family_matches
    .withColumn("hard_pairs", get_pair_of_hard_samples('output_name'))
    .withColumn("hard_pair_len", F.size('hard_pairs'))
    .where(F.col('hard_pair_len') == 2)
    .select('hard_pairs')
    .write
    .mode('overwrite')
    .parquet(f"{iteration_save_path}name_embedding_training_data/hard_samples/family_hard_pairs")
)

In [None]:
# Stack the family-based and given-based pairs, drop exact duplicate rows, and
# write the result as a single parquet part file.
(
    spark.read
    .parquet(f"{iteration_save_path}name_embedding_training_data/hard_samples/family_hard_pairs")
    .union(
        spark.read.parquet(f"{iteration_save_path}name_embedding_training_data/hard_samples/given_hard_pairs")
    )
    .drop_duplicates()
    .coalesce(1)
    .write
    .mode('overwrite')
    .parquet(f"{iteration_save_path}name_embedding_training_data/hard_samples/final_hard_negative_pairs")
)

In [None]:
# Bare expression: builds the (un-deduplicated) union of both pair sets and
# leaves the resulting DataFrame as the cell's value; no Spark action is run.
(
    spark.read
    .parquet(f"{iteration_save_path}name_embedding_training_data/hard_samples/family_hard_pairs")
    .union(
        spark.read.parquet(f"{iteration_save_path}name_embedding_training_data/hard_samples/given_hard_pairs")
    )
)

### For Disambiguator

In [None]:
# Load the ORCID name dump, trim both name parts, derive the display name as
# "<given_name> <family_name>", and drop fully identical rows.
orcid_names = (
    spark.read
    .parquet(f"{iteration_save_path}orcid_names_data_dump.parquet")
    .select(
        'orcid',
        F.trim(F.col('given_names')).alias('given_name'),
        F.trim(F.col('family_name')).alias('family_name'),
    )
    .withColumn('output_name', F.concat_ws(' ', F.col('given_name'), F.col('family_name')))
    .dropDuplicates()
)
# count() forces the cache to be populated; the row count is the cell output.
orcid_names.cache().count()

In [None]:
# Write every ORCID paired with itself as a 'positive' sample, four copies in
# total. The first pass overwrites the target and only the later passes append,
# so re-running this cell replaces the previous output instead of stacking
# another four copies on top of it.
for i in range(4):
    (
        orcid_names
        .select(
            F.lit('positive').alias('sample_type'),
            F.col('orcid').alias('orcid_1'),
            F.col('orcid').alias('orcid_2'),
        )
        .coalesce(50)
        .write
        .mode('overwrite' if i == 0 else 'append')
        .parquet(f"{iteration_save_path}disambiguator_training_data/final_positive_sample_orcids")
    )

In [None]:
# Read the positive samples back and cache them; the count is the cell output.
pos_samples = spark.read.parquet(
    f"{iteration_save_path}disambiguator_training_data/final_positive_sample_orcids"
)
pos_samples.cache().count()

In [None]:
# Every row shares the constant partition key 'a', so this is one global
# window; ordering by F.rand() makes row_number() a random position per row.
w = Window.partitionBy(F.lit('a')).orderBy(F.rand())

(
    orcid_names
    .withColumn("row_num", F.row_number().over(w))
    .select('orcid', 'row_num')
    .coalesce(20)
    .write
    .mode('overwrite')
    .parquet(f"{iteration_save_path}disambiguator_training_data/easy_samples/random_1")
)

In [None]:
# Every row shares the constant partition key 'a', so this is one global
# window; ordering by F.rand() makes row_number() a random position per row.
w = Window.partitionBy(F.lit('a')).orderBy(F.rand())

(
    orcid_names
    .withColumn("row_num", F.row_number().over(w))
    .select('orcid', 'row_num')
    .coalesce(20)
    .write
    .mode('overwrite')
    .parquet(f"{iteration_save_path}disambiguator_training_data/easy_samples/random_2")
)

In [None]:
# Every row shares the constant partition key 'a', so this is one global
# window; ordering by F.rand() makes row_number() a random position per row.
w = Window.partitionBy(F.lit('a')).orderBy(F.rand())

(
    orcid_names
    .withColumn("row_num", F.row_number().over(w))
    .select('orcid', 'row_num')
    .coalesce(20)
    .write
    .mode('overwrite')
    .parquet(f"{iteration_save_path}disambiguator_training_data/easy_samples/random_3")
)

In [None]:
# Among rows that have both name parts, collect the ORCIDs sharing the same
# full output_name and keep only names carried by more than one ORCID entry.
(
    orcid_names
    .where(F.col('given_name').isNotNull())
    .where(F.col('family_name').isNotNull())
    .groupBy('output_name')
    .agg(F.collect_list('orcid').alias('orcids'))
    .withColumn('orcid_len', F.size('orcids'))
    .where(F.col('orcid_len') > 1)
    .coalesce(20)
    .write
    .mode('overwrite')
    .parquet(f"{iteration_save_path}disambiguator_training_data/hard_samples/hard_negatives_output_name")
)

In [None]:
# Collect the ORCIDs sharing each given name; keep groups with more than one.
# NOTE(review): null given names are not filtered out here (the output_name
# variant does filter them), so they end up grouped together - confirm intended.
(
    orcid_names
    .groupBy('given_name')
    .agg(F.collect_list('orcid').alias('orcids'))
    .withColumn('orcid_len', F.size('orcids'))
    .where(F.col('orcid_len') > 1)
    .coalesce(20)
    .write
    .mode('overwrite')
    .parquet(f"{iteration_save_path}disambiguator_training_data/hard_samples/hard_negatives_given_name")
)

In [None]:
# Collect the ORCIDs sharing each family name; keep groups with more than one.
# NOTE(review): null family names are not filtered out here (the output_name
# variant does filter them), so they end up grouped together - confirm intended.
(
    orcid_names
    .groupBy('family_name')
    .agg(F.collect_list('orcid').alias('orcids'))
    .withColumn('orcid_len', F.size('orcids'))
    .where(F.col('orcid_len') > 1)
    .coalesce(20)
    .write
    .mode('overwrite')
    .parquet(f"{iteration_save_path}disambiguator_training_data/hard_samples/hard_negatives_family_name")
)

In [None]:
# Spot check: print up to 10 untruncated rows from a ~0.1% random sample of
# the output_name groups that contain more than one ORCID entry.
(
    orcid_names
    .groupBy('output_name')
    .agg(F.collect_list('orcid').alias('orcids'))
    .withColumn('orcid_len', F.size('orcids'))
    .where(F.col('orcid_len') > 1)
    .sample(0.001)
    .show(10, truncate=False)
)