In [None]:
import pickle
import boto3
import re
import json
import random
import pandas as pd
import numpy as np

from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import IntegerType, StringType, FloatType, ArrayType, DoubleType, StructType, StructField
from pyspark.ml.functions import array_to_vector

In [None]:
# Root S3 prefixes (placeholders). Paths in this notebook are built as
# f"{iteration_save_path}<relative path>" with no separator added, so the
# prefix itself has to end with the path separator.
# NOTE(review): base_save_path is not referenced anywhere in the visible notebook.
base_save_path = "<S3path>"
iteration_save_path = "<S3path>"

#### Getting hard pairs using name embedding

In [None]:
def get_pair_of_hard_samples_(list_of_like_names):
    """Pick two different identifiers at random from one group of look-alike entries.

    Args:
        list_of_like_names: list of string identifiers belonging to one group.
            May be None when the source column is null.

    Returns:
        A two-element list holding two distinct identifiers, or an empty list
        when the group has fewer than 8 entries or fewer than 2 distinct values.
    """
    # Small (or missing) groups are skipped; the threshold of 8 applies to the
    # raw list length, exactly as before.
    if list_of_like_names is None or len(list_of_like_names) < 8:
        return []

    # De-duplicate (keeping first-seen order) so the two picks can never be the
    # same identifier, which would make a meaningless pair.
    distinct_names = list(dict.fromkeys(list_of_like_names))
    if len(distinct_names) < 2:
        return []

    # random.sample draws without replacement, so no "same pick twice" fallback
    # to the first two entries is needed any more.
    return random.sample(distinct_names, 2)

# Spark column wrapper around get_pair_of_hard_samples_ (returns array<string>).
# The wrapped function draws random samples, so it is flagged as
# non-deterministic: Spark assumes UDFs are deterministic by default and may
# otherwise evaluate the function more than once per row (e.g. once for each
# of size(...), [0] and [1] downstream), yielding elements from different draws.
get_pair_of_hard_samples = F.udf(get_pair_of_hard_samples_, ArrayType(StringType())).asNondeterministic()

In [None]:
def turn_into_neg_pair_(num_1, num_2):
    """Bundle the two arguments into a two-element list, first argument first."""
    pair = [num_1]
    pair.append(num_2)
    return pair

# Spark column wrapper around turn_into_neg_pair_; declared to return array<string>.
# NOTE(review): not called anywhere in the visible notebook.
turn_into_neg_pair = F.udf(turn_into_neg_pair_, ArrayType(StringType()))

In [None]:
# Precomputed ORCID pairs read from the hard_embeddings parquet, each row
# labelled with sample_type = 'embedding_negative'.
hard_embedding_pairs = (
    spark.read
    .parquet(f"{iteration_save_path}final_model_data/training_data_creation/hard_embedding_orcid_pairs/hard_embeddings.parquet")
    .select('orcid_1', 'orcid_2', F.lit('embedding_negative').alias('sample_type'))
    .cache()
)
# Running count() populates the cache and shows the number of rows.
hard_embedding_pairs.count()

In [None]:
# Groups of ORCIDs from the hard_negatives_output_name parquet; the 'orcids'
# column is exposed as 'pairs' and every row is labelled 'output_negative'.
hard_output = (
    spark.read
    .parquet(f"{iteration_save_path}disambiguator_training_data/hard_samples/hard_negatives_output_name")
    .select(F.col('orcids').alias('pairs'), F.lit('output_negative').alias('sample_type'))
    .cache()
)
# Running count() populates the cache and shows the number of rows.
hard_output.count()

In [None]:
# Groups of ORCIDs from the hard_negatives_given_name parquet; the 'orcids'
# column is exposed as 'pairs' and every row is labelled 'given_negative'.
hard_given = (
    spark.read
    .parquet(f"{iteration_save_path}disambiguator_training_data/hard_samples/hard_negatives_given_name")
    .select(F.col('orcids').alias('pairs'), F.lit('given_negative').alias('sample_type'))
    .cache()
)
# Running count() populates the cache and shows the number of rows.
hard_given.count()

In [None]:
# Groups of ORCIDs from the hard_negatives_family_name parquet; the 'orcids'
# column is exposed as 'pairs' and every row is labelled 'family_negative'.
hard_family = (
    spark.read
    .parquet(f"{iteration_save_path}disambiguator_training_data/hard_samples/hard_negatives_family_name")
    .select(F.col('orcids').alias('pairs'), F.lit('family_negative').alias('sample_type'))
    .cache()
)
# Running count() populates the cache and shows the number of rows.
hard_family.count()

In [None]:
# Build the full set of negative ORCID pairs and write it out:
#   1. stack the output/family/given name groups,
#   2. draw one pair per group with get_pair_of_hard_samples,
#   3. keep only groups that actually produced a pair (array of size 2),
#   4. append the embedding-based pairs.
(
    hard_output
    .union(hard_family.select(*hard_output.columns))
    .union(hard_given.select(*hard_output.columns))
    .withColumn('final_pair', get_pair_of_hard_samples(F.col('pairs')))
    .select('final_pair', 'sample_type')
    .withColumn('pair_len', F.size(F.col('final_pair')))
    .filter(F.col('pair_len')==2)
    .select('sample_type', F.col('final_pair')[0].alias('orcid_1'), F.col('final_pair')[1].alias('orcid_2'))
    .union(hard_embedding_pairs.select('sample_type','orcid_1','orcid_2'))
    .write.mode('overwrite')
    .parquet(f"{iteration_save_path}final_model_data/training_data_creation/final_negative_sample_orcids")
)

In [None]:
# Write the name-based hard negative pairs (no embedding pairs) as a single
# parquet file under datasets_to_share.
# NOTE(review): the pairs are drawn again here by the random UDF, so they are
# not the same pairs as any earlier run of this pipeline — confirm intended.
(
    hard_output
    .union(hard_family.select(*hard_output.columns))
    .union(hard_given.select(*hard_output.columns))
    .withColumn('final_pair', get_pair_of_hard_samples(F.col('pairs')))
    .select('final_pair', 'sample_type')
    .withColumn('pair_len', F.size(F.col('final_pair')))
    .filter(F.col('pair_len')==2)
    .select('sample_type', F.col('final_pair')[0].alias('orcid_1'), F.col('final_pair')[1].alias('orcid_2'))
    .coalesce(1)
    .write.mode('overwrite')
    .parquet(f"{iteration_save_path}datasets_to_share/hard_negative_samples")
)

### Creating the training data for the embedding model

In [None]:
# Concept ids that remove_level_0_ids_ strips from concept lists.
level_0_ids = [
    '17744445', '138885662', '162324750', '144133560', '15744967',
    '33923547', '71924100', '86803240', '41008148', '127313418',
    '185592680', '142362112', '144024400', '127413603', '205649164',
    '95457728', '192562407', '121332964', '39432304',
]
# Broadcast the list so the UDF can read it through .value.
broadcast_level_0_ids = spark.sparkContext.broadcast(level_0_ids)

In [None]:
def remove_level_0_ids_(old_concepts_list):
    """Return the concept ids that are not in the broadcast level-0 id list.

    Args:
        old_concepts_list: list of concept id strings, or None when the column
            is null.

    Returns:
        A new list with the same order, minus every id found in
        broadcast_level_0_ids.value. None input yields an empty list instead of
        raising inside the UDF.
    """
    if old_concepts_list is None:
        return []
    # Look the broadcast value up once rather than once per element.
    excluded_ids = broadcast_level_0_ids.value
    return [concept_id for concept_id in old_concepts_list if concept_id not in excluded_ids]

# Spark column wrapper around remove_level_0_ids_; declared to return array<string>.
remove_level_0_ids = F.udf(remove_level_0_ids_, ArrayType(StringType()))

In [None]:
def cosine_similarity_(v1, v2):
    """Cosine of the angle between two vectors.

    Both arguments must expose ``dot(other)`` and ``norm(2)``. The result is
    the dot product divided by the product of the two L2 norms, as a float.
    """
    dot_product = v1.dot(v2)
    norm_product = v1.norm(2) * v2.norm(2)
    return float(dot_product / norm_product)

# Spark column wrapper around cosine_similarity_; declared to return float.
# NOTE(review): not called anywhere in the visible notebook.
cosine_similarity_udf = F.udf(cosine_similarity_, FloatType())

In [None]:
def get_percentage_of_intersection_and_sum_(list_1, list_2):
    """Compute overlap statistics between two lists, counting distinct values.

    Args:
        list_1: first list of hashable values; None or empty counts as empty.
        list_2: second list of hashable values; None or empty counts as empty.

    Returns:
        A list of four floats:
        [distinct count of list_1, distinct count of list_2,
         size of the intersection, size of the union].
    """
    # Always work on sets so that the two length fields mean "distinct values"
    # in every case. Previously the one-sided branches used the raw list length
    # while the two-sided branch used the de-duplicated length.
    unique_1 = set(list_1) if list_1 else set()
    unique_2 = set(list_2) if list_2 else set()

    return [
        float(len(unique_1)),
        float(len(unique_2)),
        float(len(unique_1 & unique_2)),
        float(len(unique_1 | unique_2)),
    ]

# Spark column wrapper around get_percentage_of_intersection_and_sum_;
# declared to return array<float>.
get_percentage_of_intersection_and_sum = F.udf(get_percentage_of_intersection_and_sum_, ArrayType(FloatType()))

In [None]:
def does_either_work_show_in_citations_(paper_id_1, paper_id_2, citation_1, citation_2):
    """Flag whether either paper appears in the other's citation list.

    Args:
        paper_id_1: id of the first paper.
        paper_id_2: id of the second paper.
        citation_1: citation list of the first paper, or None.
        citation_2: citation list of the second paper, or None.

    Returns:
        1 if paper_id_1 is in citation_2 or paper_id_2 is in citation_1,
        otherwise 0. A missing (None) citation list is treated as containing
        nothing rather than raising a TypeError.
    """
    first_in_second = citation_2 is not None and paper_id_1 in citation_2
    second_in_first = citation_1 is not None and paper_id_2 in citation_1
    return 1 if (first_in_second or second_in_first) else 0

# Spark column wrapper around does_either_work_show_in_citations_; declared to return int.
does_either_work_show_in_citations = F.udf(does_either_work_show_in_citations_, IntegerType())

In [None]:
def get_paper_id_(work_id):
    """Return the part of ``work_id`` that precedes its first underscore.

    When there is no underscore the whole string is returned.
    """
    prefix, _separator, _rest = work_id.partition("_")
    return prefix

# Spark column wrapper around get_paper_id_; declared to return string.
get_paper_id = F.udf(get_paper_id_, StringType())

In [None]:
# One window per ORCID column; rows inside each partition are put in a random
# (unseeded) order so that row_number() over them picks random rows per ORCID.
w1 = Window.partitionBy('orcid_1').orderBy(F.rand())
w2 = Window.partitionBy('orcid_2').orderBy(F.rand())

In [None]:
# Positive ORCID pairs, cached; count() populates the cache and shows the size.
pos_samples = spark.read.parquet(
    f"{iteration_save_path}disambiguator_training_data/final_positive_sample_orcids"
).cache()
pos_samples.count()

In [None]:
# Negative ORCID pairs, cached; count() populates the cache and shows the size.
neg_samples = spark.read.parquet(
    f"{iteration_save_path}final_model_data/training_data_creation/final_negative_sample_orcids"
).cache()
neg_samples.count()

In [None]:
# Per-work-author records with every column suffixed "_1", used as the
# left-hand side of each sample pair.
dis_data_proc_1 = (
    spark.read.parquet(f"{iteration_save_path}final_model_data/all_sample_data_for_all_work_authors")
    .select(
        F.col('work_id').alias('work_id_1'),
        F.col('orcid').alias('orcid_1'),
        F.col('coauthors').alias('coauthors_1'),
        F.col('citations').alias('citations_1'),
        F.col('institutions').alias('institutions_1'),
        F.col('original_author').alias('author_1'),
        F.col('concepts').alias('concepts_1'),
    )
    # Drop rows without an ORCID. The select above renamed 'orcid' to
    # 'orcid_1', so filter on the name that exists at this point instead of
    # relying on Spark to resolve the projected-away original name.
    .filter(F.col('orcid_1') != '')
    # Random 1-based rank of each record within its ORCID (see w1).
    .withColumn('inst_orc_1', F.row_number().over(w1))
)
# cache() + count() materialises the frame and shows the number of rows.
dis_data_proc_1.cache().count()

In [None]:
# Per-work-author records with every column suffixed "_2", used as the
# right-hand side of each sample pair.
dis_data_proc_2 = (
    spark.read.parquet(f"{iteration_save_path}final_model_data/all_sample_data_for_all_work_authors")
    .select(
        F.col('work_id').alias('work_id_2'),
        F.col('orcid').alias('orcid_2'),
        F.col('coauthors').alias('coauthors_2'),
        F.col('institutions').alias('institutions_2'),
        F.col('citations').alias('citations_2'),
        F.col('original_author').alias('author_2'),
        F.col('concepts').alias('concepts_2'),
    )
    # Drop rows without an ORCID. The select above renamed 'orcid' to
    # 'orcid_2', so filter on the name that exists at this point instead of
    # relying on Spark to resolve the projected-away original name.
    .filter(F.col('orcid_2') != '')
    # Random 1-based rank of each record within its ORCID (see w2).
    .withColumn('inst_orc_2', F.row_number().over(w2))
)
# cache() + count() materialises the frame and shows the number of rows.
dis_data_proc_2.cache().count()

In [None]:
# Attach one randomly ranked record (rank 1 only) to each side of every
# positive pair, keep an unseeded ~10% sample, and write the result.
(
    pos_samples
    .join(dis_data_proc_1, how='inner', on='orcid_1')
    .filter(F.col('inst_orc_1')<2)
    .join(dis_data_proc_2, how='inner', on='orcid_2')
    .filter(F.col('inst_orc_2')<2)
    .select(
        'sample_type','orcid_1','orcid_2','author_1','author_2','work_id_1','work_id_2',
        'coauthors_1','coauthors_2','concepts_1','concepts_2',
        'institutions_1','institutions_2','citations_1','citations_2',
    )
    .sample(0.1)
    .write.mode('overwrite')
    .parquet(f"{iteration_save_path}final_model_data/training_data_creation/final_positive_samples")
)

In [None]:
# Read the written positive samples back, cached; count() shows the size.
final_pos_samples = spark.read.parquet(
    f"{iteration_save_path}final_model_data/training_data_creation/final_positive_samples"
).cache()
final_pos_samples.count()

In [None]:
# Attach up to three randomly ranked records (ranks 1-3) to each side of every
# negative pair and write the result; no down-sampling is applied here.
(
    neg_samples
    .join(dis_data_proc_1, how='inner', on='orcid_1')
    .filter(F.col('inst_orc_1')<4)
    .join(dis_data_proc_2, how='inner', on='orcid_2')
    .filter(F.col('inst_orc_2')<4)
    .select(
        'sample_type','orcid_1','orcid_2','author_1','author_2','work_id_1','work_id_2',
        'coauthors_1','coauthors_2','concepts_1','concepts_2',
        'institutions_1','institutions_2','citations_1','citations_2',
    )
    .write.mode('overwrite')
    .parquet(f"{iteration_save_path}final_model_data/training_data_creation/final_negative_samples")
)

In [None]:
# Read the written negative samples back, cached; count() shows the size.
final_neg_samples = spark.read.parquet(
    f"{iteration_save_path}final_model_data/training_data_creation/final_negative_samples"
).cache()
final_neg_samples.count()

In [None]:
# Stack positives and negatives (negatives re-ordered to the positives' column
# order, since union is positional), shuffle the rows, and write them out.
(
    final_pos_samples
    .union(final_neg_samples.select(*final_pos_samples.columns))
    .orderBy(F.rand())
    .write.mode('overwrite')
    .parquet(f"{iteration_save_path}final_model_data/training_data_creation/all_final_sample_data")
)

### Train, Val, Test

In [None]:
# Use F.udf: the notebook only imports `functions as F`, so a bare `udf`
# decorator raises NameError on a fresh kernel.
@F.udf(returnType=ArrayType(StringType()))
def remove_short_authors(coauthors):
    """Keep only the coauthor strings that are longer than 6 characters.

    Args:
        coauthors: list of coauthor strings, or None when the column is null.

    Returns:
        The entries of ``coauthors`` with length > 6, in their original order.
        None input yields an empty list instead of raising inside the UDF.
    """
    if coauthors is None:
        return []
    return [name for name in coauthors if len(name) > 6]

In [None]:
# Export the combined samples as a single parquet file under datasets_to_share,
# with work_id_1/work_id_2 published as work_author_id_1/work_author_id_2.
(
    spark.read
    .parquet(f"{iteration_save_path}final_model_data/training_data_creation/all_final_sample_data")
    .select(
        'sample_type','orcid_1','orcid_2',
        F.col('work_id_1').alias('work_author_id_1'),
        F.col('work_id_2').alias('work_author_id_2'),
        'author_1','author_2','coauthors_1','coauthors_2',
        'concepts_1','concepts_2','institutions_1','institutions_2',
        'citations_1','citations_2',
    )
    .coalesce(1)
    .write.mode('overwrite')
    .parquet(f"{iteration_save_path}datasets_to_share/all_possible_training_data")
)

In [None]:
# Combined positive + negative samples, cached; count() shows the size.
temp_df = spark.read.parquet(
    f"{iteration_save_path}final_model_data/training_data_creation/all_final_sample_data"
).cache()
temp_df.count()

In [None]:
# Turn every sample pair into numeric overlap features and write them out.
#   * paper_id_1/2: derived from work_id_1/2 with get_paper_id.
#   * *_shorter / *_shortest columns: filtered variants of the coauthor and
#     concept lists (remove_short_authors, remove_level_0_ids,
#     get_new_concepts_list).
#   * per_int_*: 4-element arrays from get_percentage_of_intersection_and_sum;
#     the final select unpacks them as [0] -> *_1_len, [1] -> *_2_len,
#     [2] -> *_match, [3] -> *_sum.
#   * citation_work_match: result of does_either_work_show_in_citations.
# NOTE(review): get_new_concepts_list is not defined or imported anywhere in
# the visible notebook — this cell fails on a fresh kernel unless it is
# provided elsewhere; confirm where it comes from.
temp_df \
    .withColumn('paper_id_1', get_paper_id(F.col('work_id_1'))) \
    .withColumn('paper_id_2', get_paper_id(F.col('work_id_2'))) \
    .withColumn('coauthors_shorter_1', remove_short_authors(F.col('coauthors_1'))) \
    .withColumn('coauthors_shorter_2', remove_short_authors(F.col('coauthors_2'))) \
    .withColumn('concepts_shorter_1', remove_level_0_ids(F.col('concepts_1'))) \
    .withColumn('concepts_shorter_2', remove_level_0_ids(F.col('concepts_2'))) \
    .withColumn('concepts_shortest_1', get_new_concepts_list(F.col('concepts_1'))) \
    .withColumn('concepts_shortest_2', get_new_concepts_list(F.col('concepts_2'))) \
    .withColumn('per_int_insts', get_percentage_of_intersection_and_sum(F.col('institutions_1'), F.col('institutions_2'))) \
    .withColumn('per_int_coauthors', get_percentage_of_intersection_and_sum(F.col('coauthors_1'), F.col('coauthors_2'))) \
    .withColumn('per_int_coauthors_shorter', get_percentage_of_intersection_and_sum(F.col('coauthors_shorter_1'), 
                                                                                    F.col('coauthors_shorter_2'))) \
    .withColumn('per_int_concepts', get_percentage_of_intersection_and_sum(F.col('concepts_1'), F.col('concepts_2'))) \
    .withColumn('per_int_concepts_shorter', get_percentage_of_intersection_and_sum(F.col('concepts_shorter_1'), 
                                                                                   F.col('concepts_shorter_2'))) \
    .withColumn('per_int_concepts_shortest', get_percentage_of_intersection_and_sum(F.col('concepts_shortest_1'), 
                                                                                    F.col('concepts_shortest_2'))) \
    .withColumn('per_int_citations', get_percentage_of_intersection_and_sum(F.col('citations_1'), F.col('citations_2'))) \
    .withColumn('citation_work_match', does_either_work_show_in_citations(F.col('paper_id_1'), F.col('paper_id_2'), 
                                                                        F.col('citations_1'), F.col('citations_2'))) \
    .select('sample_type','author_1','author_2','orcid_1','orcid_2',
            F.col('per_int_insts')[0].alias('inst_1_len'), F.col('per_int_insts')[1].alias('inst_2_len'),
            F.col('per_int_insts')[2].alias('inst_match'), F.col('per_int_insts')[3].alias('inst_sum'), 
            F.col('per_int_concepts')[0].alias('concepts_1_len'), F.col('per_int_concepts')[1].alias('concepts_2_len'),
            F.col('per_int_concepts')[2].alias('concepts_match'), F.col('per_int_concepts')[3].alias('concepts_sum'), 
            F.col('per_int_concepts_shorter')[0].alias('concepts_shorter_1_len'), 
            F.col('per_int_concepts_shorter')[1].alias('concepts_shorter_2_len'),
            F.col('per_int_concepts_shorter')[2].alias('concepts_shorter_match'), 
            F.col('per_int_concepts_shorter')[3].alias('concepts_shorter_sum'), 
            F.col('per_int_concepts_shortest')[0].alias('concepts_shortest_1_len'), 
            F.col('per_int_concepts_shortest')[1].alias('concepts_shortest_2_len'),
            F.col('per_int_concepts_shortest')[2].alias('concepts_shortest_match'), 
            F.col('per_int_concepts_shortest')[3].alias('concepts_shortest_sum'), 
            F.col('per_int_coauthors_shorter')[0].alias('coauthors_shorter_1_len'), 
            F.col('per_int_coauthors_shorter')[1].alias('coauthors_shorter_2_len'),
            F.col('per_int_coauthors_shorter')[2].alias('coauthors_shorter_match'), 
            F.col('per_int_coauthors_shorter')[3].alias('coauthors_shorter_sum'), 
            F.col('per_int_coauthors')[0].alias('coauthors_1_len'), F.col('per_int_coauthors')[1].alias('coauthors_2_len'),
            F.col('per_int_coauthors')[2].alias('coauthors_match'), F.col('per_int_coauthors')[3].alias('coauthors_sum'), 
            F.col('per_int_citations')[0].alias('citation_1_len'), F.col('per_int_citations')[1].alias('citation_2_len'),
            F.col('per_int_citations')[2].alias('citation_match'), F.col('per_int_citations')[3].alias('citation_sum'), 
            'citation_work_match') \
    .write.mode('overwrite') \
    .parquet(f"{iteration_save_path}final_model_data/training_data_creation/all_final_sample_data_processed")

In [None]:
# Processed feature table, cached; count() populates the cache and shows the size.
all_sample_data = spark.read.parquet(
    f"{iteration_save_path}final_model_data/training_data_creation/all_final_sample_data_processed"
).cache()
all_sample_data.count()

In [None]:
# Show the number of rows for each sample_type value.
all_sample_data.groupBy('sample_type').count().show()

In [None]:
# Spot check: show up to 30 rows from an unseeded ~1% sample of the rows whose
# sample_type is 'positive', comparing the two coauthor match counts.
all_sample_data.filter(F.col('sample_type')=='positive').sample(0.01).select('coauthors_match','coauthors_shorter_match').show(30)

In [None]:
# Random train/val/test split with a fixed seed.
# NOTE(review): the weights sum to 0.995, not 1.0. Spark normalises weights
# that do not sum to 1, so the effective fractions are slightly different from
# the literals — confirm whether 0.995/0.003/0.002 (or similar) was intended.
train, val, test = all_sample_data.randomSplit([0.99, 0.003, 0.002], seed=0)

In [None]:
# Write the validation split as a single parquet file.
(
    val
    .coalesce(1)
    .write.mode('overwrite')
    .parquet(f"{iteration_save_path}final_model_data/training_data/val")
)

In [None]:
# Write the test split as a single parquet file.
(
    test
    .coalesce(1)
    .write.mode('overwrite')
    .parquet(f"{iteration_save_path}final_model_data/training_data/test")
)

In [None]:
# Read the written validation split back, cached; count() shows the size.
val_df = spark.read.parquet(
    f"{iteration_save_path}final_model_data/training_data/val"
).cache()
val_df.count()

In [None]:
# Read the written test split back, cached; count() shows the size.
test_df = spark.read.parquet(
    f"{iteration_save_path}final_model_data/training_data/test"
).cache()
test_df.count()

In [None]:
# Number of rows in the raw training split produced by randomSplit.
train.count()

In [None]:
# Leakage guard: drop every training row that shares an ORCID with any
# validation or test row, in either position, then write the remaining rows
# as a single parquet file.
(
    train
    # validation: same-position matches (orcid_1 vs orcid_1, orcid_2 vs orcid_2)
    .join(val_df, how='leftanti', on='orcid_1')
    .join(val_df, how='leftanti', on='orcid_2')
    # validation: cross-position matches (val orcid_1 vs train orcid_2 and vice versa)
    .join(val_df.select(F.col('orcid_1').alias('orcid_2')), how='leftanti', on='orcid_2')
    .join(val_df.select(F.col('orcid_2').alias('orcid_1')), how='leftanti', on='orcid_1')
    # test: same-position matches
    .join(test_df, how='leftanti', on='orcid_1')
    .join(test_df, how='leftanti', on='orcid_2')
    # test: cross-position matches
    .join(test_df.select(F.col('orcid_1').alias('orcid_2')), how='leftanti', on='orcid_2')
    .join(test_df.select(F.col('orcid_2').alias('orcid_1')), how='leftanti', on='orcid_1')
    .coalesce(1)
    .write.mode('overwrite')
    .parquet(f"{iteration_save_path}final_model_data/training_data/train")
)

In [None]:
# Preview the first 10 rows of the written training set.
spark.read.parquet(f"{iteration_save_path}final_model_data/training_data/train").show(10)