In [None]:
import pickle
import boto3
import re
import json
import random
import unicodedata
import unidecode
import datetime
from statistics import mode
from nameparser import HumanName
from collections import Counter
import pandas as pd
import numpy as np
import xgboost as xgb

In [None]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import IntegerType, StringType, FloatType, ArrayType, DoubleType, StructType, StructField, LongType

In [None]:
# Timestamp (minute resolution) used to give each run its own output folder.
curr_date = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M")
# Root location of the production tables (current_authors_table is read from here).
prod_save_path = "<S3path>"
# Run-specific location; this run's intermediate outputs are written beneath it.
temp_save_path = f"<S3path>/{curr_date}"

#### Load Disambiguator Model

In [None]:
# Load the pickled disambiguator model from local disk.
# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only load model artifacts that come from a trusted location.
with open("<local-model-path>/Disambiguator.pkl", "rb") as f:
    disambiguator_model = pickle.load(f)

# Broadcast the model so that UDFs can read it through `.value` (see score_data).
broadcast_disambiguator_model = spark.sparkContext.broadcast(disambiguator_model)

#### Get Latest Data to Disambiguate

In [None]:
def get_secret():
    """Return the AWS secret that holds the database connection settings.

    The retrieval code has been elided from this notebook: as written, the body
    only contains a placeholder comment and returns the name `secret`, which the
    elided code is expected to define.
    """

    ### code for getting AWS secrets ###
    return secret

In [None]:
# Connection settings for the Postgres source (host, port, dbname, username, password are read from it).
secret = get_secret()

In [None]:
# Read the input table from Postgres using the connection settings held in `secret`.
df = (spark.read
    .format("postgresql")
    .option("dbtable", "<input-data-postgres-table>")
    .option("host", secret['host'])
    .option("port", secret['port'])
    .option("database", secret['dbname'])
    .option("user", secret['username'])
    .option("password", secret['password'])
    # Parallel-read settings: split on the `partition` column, bounds 0-21, 6 partitions.
    .option("partitionColumn", "partition")
    .option("lowerBound", 0)
    .option("upperBound", 21)
    .option("numPartitions", 6)
    .option("fetchSize", "15")
    .load()
)

In [None]:
# Snapshot the rows that have a usable author name to parquet.
# The filters must be applied to the DataFrame *before* `.write`: `.write`
# returns a DataFrameWriter, which has no `filter` method.
(df
    .filter(F.col('original_author').isNotNull())
    .filter(F.col('original_author') != '')
    .write.mode('overwrite')
    .parquet(f"<S3path>"))

#### Transform Data for Disambiguation

In [None]:
@udf(returnType=StringType())
def transform_author_name(author):
    """Clean a raw author string and, where it looks like "Surname I. J.", move the surname to the end.

    Steps:
      1. If the string starts with "None " or "Array ", remove that text.
      2. Apply NFKC normalisation, collapse whitespace and parse with HumanName.
      3. Rebuild the name as "first middle last"; a title other than "Dr." or
         empty is kept in front.
      4. Only when no title was parsed and the name has 2-4 tokens: if the first
         token is longer than 3 characters and every other token is either a
         single character, or every other token is two characters ending in
         ".", the first token is moved to the end.

    Returns the resulting name string.
    """
    for artefact in ("None ", "Array "):
        if author.startswith(artefact):
            author = author.replace(artefact, "")
            break

    parsed = HumanName(" ".join(unicodedata.normalize('NFKC', author).split()))
    title = parsed.title

    name_parts = [parsed.first, parsed.middle, parsed.last]
    if title not in ('Dr.', ''):
        name_parts.insert(0, title)
    cleaned = " ".join(" ".join(name_parts).split())

    tokens = cleaned.split(" ")
    # Reordering is only considered for untitled names of 2-4 tokens whose
    # leading token is long enough to be a surname rather than an initial.
    if title != '' or not 2 <= len(tokens) <= 4 or len(tokens[0]) <= 3:
        return cleaned

    trailing = tokens[1:]
    all_bare_initials = all(len(tok) == 1 for tok in trailing)
    all_dotted_initials = all(len(tok) == 2 and tok[1] == "." for tok in trailing)
    if all_bare_initials or all_dotted_initials:
        return " ".join(trailing + tokens[:1])
    return cleaned


@udf(returnType=ArrayType(StringType()))
def transform_coauthors(coauthors):
    """Apply transform_author_name_reg to every coauthor name and return the results in order."""
    return list(map(transform_author_name_reg, coauthors))

def transform_author_name_reg(author):
    """Plain-Python (non-UDF) author-name cleaner, callable from inside other UDFs.

    Steps:
      1. If the string starts with "None " or "Array ", remove that text.
      2. Apply NFKC normalisation, collapse whitespace and parse with HumanName.
      3. Rebuild the name as "first middle last"; a title other than "Dr." or
         empty is kept in front.
      4. Only when no title was parsed and the name has 2-4 tokens: if the first
         token is longer than 3 characters and every other token is either a
         single character, or every other token is two characters ending in
         ".", the first token is moved to the end.

    Returns the resulting name string.
    """
    for artefact in ("None ", "Array "):
        if author.startswith(artefact):
            author = author.replace(artefact, "")
            break

    parsed = HumanName(" ".join(unicodedata.normalize('NFKC', author).split()))
    title = parsed.title

    name_parts = [parsed.first, parsed.middle, parsed.last]
    if title not in ('Dr.', ''):
        name_parts.insert(0, title)
    cleaned = " ".join(" ".join(name_parts).split())

    tokens = cleaned.split(" ")
    # Reordering is only considered for untitled names of 2-4 tokens whose
    # leading token is long enough to be a surname rather than an initial.
    if title != '' or not 2 <= len(tokens) <= 4 or len(tokens[0]) <= 3:
        return cleaned

    trailing = tokens[1:]
    all_bare_initials = all(len(tok) == 1 for tok in trailing)
    all_dotted_initials = all(len(tok) == 2 and tok[1] == "." for tok in trailing)
    if all_bare_initials or all_dotted_initials:
        return " ".join(trailing + tokens[:1])
    return cleaned

@udf(returnType=ArrayType(StringType()))
def remove_current_author(author, coauthors):
    """Drop entries equal to `author` from `coauthors` and keep at most the first 250 of the rest."""
    others = []
    for candidate in coauthors:
        if candidate != author:
            others.append(candidate)
    return others[:250]

@udf(returnType=ArrayType(StringType()))
def transform_list_col_for_nulls_string(col_with_nulls):
    """Return the value unchanged when it is a list; otherwise (e.g. null) return an empty list."""
    return col_with_nulls if isinstance(col_with_nulls, list) else []

@udf(returnType=ArrayType(LongType()))
def transform_list_col_for_nulls_long(col_with_nulls):
    """Return the value unchanged when it is a list; otherwise (e.g. null) return an empty list."""
    return col_with_nulls if isinstance(col_with_nulls, list) else []

@udf(returnType=ArrayType(StringType()))
def remove_current_author(author, coauthors):
    """Drop entries equal to `author` from `coauthors` and keep at most the first 250 of the rest."""
    # NOTE(review): this function is defined more than once in this notebook with the same body.
    others = []
    for candidate in coauthors:
        if candidate != author:
            others.append(candidate)
    return others[:250]

@udf(returnType=ArrayType(StringType()))
def coauthor_transform(coauthors):
    """Reduce each coauthor name to a compact key and return the distinct keys.

    Names with more than one space-separated token become
    "<first character of first token>_<last token>", lower-cased; single-token
    names are kept as they are (not lower-cased). In both cases the characters
    space , . - : / are then removed.
    """
    skip_chars = [" ", ",", ".", "-", ":", "/"]

    def strip_skipped(text):
        return "".join(ch for ch in text if ch not in skip_chars)

    keys = []
    for coauthor in coauthors:
        tokens = coauthor.split(" ")
        if len(tokens) > 1:
            keys.append(strip_skipped(f"{tokens[0][0]}_{tokens[-1]}".lower()))
        else:
            keys.append(strip_skipped(coauthor))

    return list(set(keys))

@udf(returnType=StringType())
def get_orcid_from_list(orcid_list):
    """Return the first element of a list or set of ORCIDs, or '' when it is empty or of another type."""
    if not isinstance(orcid_list, (list, set)):
        return ''
    candidates = list(orcid_list)
    return candidates[0] if candidates else ''

def length_greater_than_6(x):
    """Column predicate: true when the string in `x` is longer than six characters."""
    char_count = F.length(x)
    return char_count > 6

def concept_L0_removed(x):
    """Column predicate: true when the concept id in `x` is not one of the hard-coded ids below.

    Going by the function name these are presumably the level-0 concepts - TODO confirm.
    """
    excluded_concept_ids = [
        17744445, 138885662, 162324750, 144133560, 15744967, 33923547, 71924100,
        86803240, 41008148, 127313418, 185592680, 142362112, 144024400, 127413603,
        205649164, 95457728, 192562407, 121332964, 39432304,
    ]
    return ~x.isin(excluded_concept_ids)

In [None]:
# Current production authors table, read from the production save path.
curr_author_table = spark.read.parquet(f"{prod_save_path}/current_authors_table/")

In [None]:
# Most recent modified_date present in the current authors table (collected to the driver).
authors_table_last_date = curr_author_table.select(F.max('modified_date')).collect()[0][0]

**TODO:** Add a filter that keeps only data newer than the last modified date of the authors table (i.e. replace the join with a date filter).

In [None]:
# Window used to rank rows per work_author_id, longest original_author first.
w1 = Window.partitionBy('work_author_id').orderBy(F.col('name_len').desc())

# Build the rows that still need disambiguation and write them to this run's temp location.
(spark.read
    .parquet(f"<S3path>")
    .select('work_author_id', F.trim(F.col('original_author')).alias('original_author'), 'orcid', 'concepts', 'institutions', 
            'citations', 'coauthors', 'created_date', 'partition')
    .filter(F.col('original_author').isNotNull())
    .filter(F.col('original_author')!='')
    # Keep one row per work_author_id: the one with the longest author string.
    # NOTE(review): row_number() picks arbitrarily among rows with equal name_len.
    .withColumn('name_len', F.length(F.col('original_author')))
    .withColumn('rank', F.row_number().over(w1))
    .filter(F.col('rank')==1)
    # Drop work_author_ids that are already present in the current authors table.
    .join(curr_author_table.select('work_author_id'), how='leftanti', on='work_author_id')
    # Replace non-list (e.g. null) array columns with empty lists.
    .withColumn('citations', transform_list_col_for_nulls_long(F.col('citations')))
    .withColumn('coauthors', transform_list_col_for_nulls_string(F.col('coauthors')))
    .withColumn('concepts', transform_list_col_for_nulls_long(F.col('concepts')))
    .withColumn('institutions', transform_list_col_for_nulls_long(F.col('institutions')))
    # Clean the author and coauthor names, remove the author from the coauthor list,
    # then reduce the coauthors to compact keys.
    .withColumn('author', transform_author_name(F.col('original_author')))
    .withColumn('coauthors', transform_coauthors(F.col('coauthors')))
    .withColumn('coauthors', remove_current_author(F.col('author'),F.col('coauthors')))
    .withColumn('coauthors', coauthor_transform(F.col('coauthors')))
    .withColumn('orcid', F.when(F.col('orcid').isNull(), '').otherwise(F.col('orcid')))
    # paper_id is the part of work_author_id before the first underscore.
    .withColumn('paper_id', F.split(F.col('work_author_id'), "_").getItem(0).cast(LongType()))
    .withColumn('concepts', F.array_distinct(F.col('concepts')))
    # Filtered variants: concepts without the excluded ids, coauthor keys longer than 6 characters.
    .withColumn('concepts_shorter', F.filter(F.col('concepts'), concept_L0_removed))
    .withColumn('coauthors_shorter', F.filter(F.col('coauthors'), length_greater_than_6))
    .select('work_author_id','paper_id','original_author','author','orcid','coauthors_shorter','concepts_shorter',
        'institutions','citations','created_date')
    .write.mode('overwrite') \
    .parquet(f"{temp_save_path}/new_data_to_disambiguate/"))

In [None]:
# Count the rows written to new_data_to_disambiguate; the bare name displays the count as cell output.
new_data_size = spark.read.parquet(f"{temp_save_path}/new_data_to_disambiguate/").count()
new_data_size

In [None]:
# Report whether this run found any rows that still need to be disambiguated.
print("NEW ROWS TO DISAMBIGUATE" if new_data_size > 0 else "NO NEW DATA")

#### Functions

In [None]:
@udf(returnType=IntegerType())
def get_random_int_udf(block_id):
    """Return a random integer between 0 and 1,000,000 inclusive; `block_id` is not used."""
    upper_bound = 1000000
    return random.randint(0, upper_bound)

def length_greater_than_6(x):
    """Column predicate: true when the string in `x` is longer than six characters."""
    char_count = F.length(x)
    return char_count > 6

def concept_L0_removed(x):
    """Column predicate: true when the concept id in `x` is not one of the hard-coded ids below.

    Going by the function name these are presumably the level-0 concepts - TODO confirm.
    """
    excluded_concept_ids = [
        17744445, 138885662, 162324750, 144133560, 15744967, 33923547, 71924100,
        86803240, 41008148, 127313418, 185592680, 142362112, 144024400, 127413603,
        205649164, 95457728, 192562407, 121332964, 39432304,
    ]
    return ~x.isin(excluded_concept_ids)

@udf(returnType=StringType())
def only_get_last(all_names):
    """Return the text after the last space in `all_names` (the whole string when it has no space)."""
    # Splitting once from the right yields the same final piece as a full split on " ".
    return all_names.rsplit(" ", 1)[-1]
    
@udf(returnType=ArrayType(ArrayType(StringType())))
def score_data(full_arr):
    """Score rows with the broadcast model and keep those with a positive-class probability above 0.05.

    Each input row is read as [column 0, column 1, feature values...]; columns 0
    and 1 are passed through unchanged (named block and label here), and the
    remaining columns are cast to float and fed to `predict_proba`.

    Returns a list of [column 0, column 1, probability-as-string] rows.
    """
    rows = np.array(full_arr)
    block_arr, label_arr = rows[:, 0], rows[:, 1]
    features = rows[:, 2:].astype('float')

    model_preds = broadcast_disambiguator_model.value.predict_proba(features)[:, 1]
    keep = model_preds > 0.05

    stacked = np.vstack([block_arr[keep], label_arr[keep], model_preds[keep].astype('str')])
    return stacked.T.tolist()

@udf(returnType=StringType())
def get_starting_letter(names):
    """Return the first character of the first non-empty space-separated token, or "" if there is none."""
    for token in names.split(" "):
        if token:
            return token[0]
    return ""

In [None]:
@udf(returnType=ArrayType(StringType()))
def group_non_latin_characters(text):
    """List the distinct non-LATIN character groups found in `text`, in order of first appearance.

    Periods and spaces are ignored. A character's group is the first word of
    its Unicode name (e.g. "CJK", "HIRAGANA", "CYRILLIC"); characters whose
    group is "LATIN" are skipped. Characters that have no Unicode name are
    reported as "UNK".
    """
    groups = []
    text = text.replace(".", "").replace(" ", "")
    for char in text:
        try:
            script = unicodedata.name(char).split(" ")[0]
        except ValueError:
            # unicodedata.name raises ValueError for characters without a name.
            # Only that case is handled so that unrelated errors are not hidden.
            script = "UNK"
        if script != 'LATIN' and script not in groups:
            groups.append(script)
    return groups

@udf(returnType=IntegerType())
def name_to_keep_ind(groups):
    """Return 0 when `groups` contains any of the listed character groups, otherwise 1."""
    groups_to_skip = {'HIRAGANA', 'CJK', 'KATAKANA', 'ARABIC', 'HANGUL', 'THAI', 'DEVANAGARI', 'BENGALI',
                      'THAANA', 'GUJARATI'}

    return 1 if groups_to_skip.isdisjoint(groups) else 0

In [None]:
@udf(returnType=IntegerType())
def check_block_vs_block(block_1_names_list, block_2_names_list):
    """Return 1 when the two name-match lists are compatible, otherwise 0.

    Both arguments are indexed as pairs of (full names, initials): positions
    0/1 for the first name, 2/3 .. 10/11 for up to five middle names, and the
    final two positions for the last name (the layout returned by
    get_name_match_list).

    Logic:
      * If the first names do not match, the only way to return 1 is the
        "last name swapped to the front" check.
      * Otherwise the last names must match, and then each middle-name slot is
        compared in order. A mismatch returns 0; a slot that is empty on both
        sides (match_block_names reports nothing more to compare) returns 1.
    """
    first_check, _ = match_block_names(block_1_names_list[0], block_1_names_list[1],
                                       block_2_names_list[0], block_2_names_list[1])
    if not first_check:
        swap_check = check_if_last_name_swapped_to_front_creates_match(block_1_names_list, block_2_names_list)
        return 1 if swap_check else 0

    last_check, _ = match_block_names(block_1_names_list[-2], block_1_names_list[-1],
                                      block_2_names_list[-2], block_2_names_list[-1])
    if not last_check:
        return 0

    # Middle-name slots m1..m5 live at index pairs (2,3), (4,5), ..., (10,11).
    # Iterating over the pairs keeps names and initials aligned for every slot;
    # the hand-unrolled version passed index 8 twice for m4, so block 1's full
    # names were used where its initials (index 9) were meant.
    for names_idx in range(2, 12, 2):
        initials_idx = names_idx + 1
        slot_check, more_to_go = match_block_names(
            block_1_names_list[names_idx], block_1_names_list[initials_idx],
            block_2_names_list[names_idx], block_2_names_list[initials_idx])
        if not slot_check:
            return 0
        if not more_to_go:
            return 1
    return 1
        
def get_name_from_name_list(name_list):
    """Build a representative token list (first, middles..., last) from a name-match list.

    For each (names, initials) pair at positions 0/1, 2/3, ..., 10/11 the first
    entry of `names` is used, falling back to the first entry of `initials`;
    the walk stops at the first pair where both are empty. The same choice is
    then made for the final pair (positions -2/-1) and appended when available.
    """
    parts = []
    for idx in range(0, 12, 2):
        source = name_list[idx] or name_list[idx + 1]
        if not source:
            break
        parts.append(source[0])

    last_source = name_list[-2] or name_list[-1]
    if last_source:
        parts.append(last_source[0])

    return parts
        
def check_if_last_name_swapped_to_front_creates_match(block_1, block_2):
    """Return True when both names have exactly two tokens and block 2's tokens, reversed, equal block 1's."""
    name_1 = get_name_from_name_list(block_1)
    if len(name_1) != 2:
        return False

    name_2 = get_name_from_name_list(block_2)
    if len(name_2) != 2:
        return False

    # Move block 2's last token to the front before comparing.
    rotated = name_2[-1:] + name_2[:-1]
    return " ".join(name_1) == " ".join(rotated)
    
def match_block_names(block_1_names, block_1_initials, block_2_names, block_2_initials):
    """Compare one name slot of two blocks.

    Returns a pair (matched, more_to_go):
      * both sides have full names: matched when they share a full name;
      * only one side has full names: matched when the other side has no
        initials either, otherwise when the two initials lists overlap;
      * neither side has full names but both have initials: matched when the
        initials overlap;
      * otherwise (nothing to compare on at least one side): (True, False).
    `more_to_go` is True in every case except the last one.
    """
    def overlap(left, right):
        return any(item in left for item in right)

    if block_1_names and block_2_names:
        return overlap(block_1_names, block_2_names), True

    if block_1_names:
        # Only block 1 has a full name: fall back to initials if block 2 has any.
        if not block_2_initials:
            return True, True
        return overlap(block_1_initials, block_2_initials), True

    if block_2_names:
        # Only block 2 has a full name: fall back to initials if block 1 has any.
        if not block_1_initials:
            return True, True
        return overlap(block_1_initials, block_2_initials), True

    if block_1_initials and block_2_initials:
        return overlap(block_1_initials, block_2_initials), True

    return True, False

@udf(returnType=ArrayType(ArrayType(StringType())))
def get_name_match_list(name):
    """Split a name into first / middle (up to 5) / last slots, with full tokens and initials per slot.

    Two tokenisations are considered: the name with hyphens removed, and (only
    when the name contains a hyphen) the name with hyphens replaced by spaces.
    For every tokenisation that is not empty:
      * the first token fills the first-name slot and the last token fills the
        last-name slot (a single-token name fills both);
      * the tokens in between fill middle slots 1..5 in order; when there are
        more than five of them, the fifth and all later ones are joined with
        spaces into middle slot 5.
    A token longer than one character is recorded as a full name; its first
    character is always recorded as an initial.

    Returns 14 de-duplicated lists in the order
    [fn, fni, m1, m1i, m2, m2i, m3, m3i, m4, m4i, m5, m5i, ln, lni].
    """
    max_middle_slots = 5
    # slots[0] = first name, slots[1..5] = middle names, slots[-1] = last name;
    # each slot is a (full names, initials) pair.
    slots = [([], []) for _ in range(max_middle_slots + 2)]

    def record(slot_index, token):
        full_names, initials = slots[slot_index]
        if len(token) > 1:
            full_names.append(token)
        initials.append(token[0])

    tokenisations = [name.replace("-", "").split()]
    if "-" in name:
        tokenisations.append(name.replace("-", " ").split())

    for tokens in tokenisations:
        if not tokens:
            continue

        record(0, tokens[0])
        record(-1, tokens[-1])

        middle_tokens = tokens[1:-1]
        if len(middle_tokens) > max_middle_slots:
            # Fold every middle token from the fifth onwards into the last middle slot.
            overflow = " ".join(middle_tokens[max_middle_slots - 1:])
            middle_tokens = middle_tokens[:max_middle_slots - 1] + [overflow]

        for slot_index, token in enumerate(middle_tokens, start=1):
            record(slot_index, token)

    return [list(set(values)) for slot in slots for values in slot]

@udf(returnType=StringType())
def transform_author_name(author):
    """Clean a raw author string and, where it looks like "Surname I. J.", move the surname to the end.

    Steps:
      1. If the string starts with "None " or "Array ", remove that text.
      2. Apply NFKC normalisation, collapse whitespace and parse with HumanName.
      3. Rebuild the name as "first middle last"; a title other than "Dr." or
         empty is kept in front.
      4. Only when no title was parsed and the name has 2-4 tokens: if the first
         token is longer than 3 characters and every other token is either a
         single character, or every other token is two characters ending in
         ".", the first token is moved to the end.

    Returns the resulting name string.
    """
    for artefact in ("None ", "Array "):
        if author.startswith(artefact):
            author = author.replace(artefact, "")
            break

    parsed = HumanName(" ".join(unicodedata.normalize('NFKC', author).split()))
    title = parsed.title

    name_parts = [parsed.first, parsed.middle, parsed.last]
    if title not in ('Dr.', ''):
        name_parts.insert(0, title)
    cleaned = " ".join(" ".join(name_parts).split())

    tokens = cleaned.split(" ")
    # Reordering is only considered for untitled names of 2-4 tokens whose
    # leading token is long enough to be a surname rather than an initial.
    if title != '' or not 2 <= len(tokens) <= 4 or len(tokens[0]) <= 3:
        return cleaned

    trailing = tokens[1:]
    all_bare_initials = all(len(tok) == 1 for tok in trailing)
    all_dotted_initials = all(len(tok) == 2 and tok[1] == "." for tok in trailing)
    if all_bare_initials or all_dotted_initials:
        return " ".join(trailing + tokens[:1])
    return cleaned

@udf(returnType=ArrayType(StringType()))
def remove_current_author(author, coauthors):
    """Drop entries equal to `author` from `coauthors` and keep at most the first 250 of the rest."""
    others = []
    for candidate in coauthors:
        if candidate != author:
            others.append(candidate)
    return others[:250]

@udf(returnType=StringType())
def transform_name_for_search(name):
    """Normalise a name for searching: ASCII, lower-case, punctuation and digits stripped.

    The name is NFKC-normalised and transliterated with unidecode, lower-cased,
    then "." "," "|" are turned into spaces while digits and the remaining
    listed punctuation (including "-") are deleted. Whitespace runs are
    collapsed to single spaces and the ends are trimmed.
    """
    becomes_space = ".,|"
    removed = ")(-&$#@%0123456789*^{}+=_~`[]\\<>?/;:'\""
    # Every rule is a single-character substitution, so one translate pass is
    # equivalent to applying the replacements one after another.
    table = str.maketrans(becomes_space, " " * len(becomes_space), removed)

    ascii_name = unidecode.unidecode(unicodedata.normalize('NFKC', name))
    return " ".join(ascii_name.lower().translate(table).split())

@udf(returnType=ArrayType(ArrayType(StringType())))
def create_author_name_list_from_list(name_lists):
    """Merge several name-match lists slot by slot.

    For each slot position (taken from the length of the first entry), collect
    the FIRST element of that slot from every entry whose slot is non-empty,
    and return the distinct values. Inputs that are not a list are converted
    with `.tolist()` first.
    """
    rows = name_lists if isinstance(name_lists, list) else name_lists.tolist()

    merged = []
    for slot in range(len(rows[0])):
        first_values = {row[slot][0] for row in rows if row[slot]}
        merged.append(list(first_values))
    return merged

@udf(returnType=ArrayType(ArrayType(StringType())))
def get_name_match_from_alternate_names(alt_names):
    """Build one merged name-match list from a collection of alternate names.

    Each name is normalised with transform_name_for_search_reg, duplicates are
    dropped, every distinct name is expanded with get_name_match_list_reg, and
    the results are merged with create_author_name_list_from_list_reg.
    """
    distinct_names = {transform_name_for_search_reg(raw) for raw in alt_names}
    per_name_lists = list(map(get_name_match_list_reg, distinct_names))
    return create_author_name_list_from_list_reg(per_name_lists)

def create_author_name_list_from_list_reg(name_lists):
    """Plain-Python (non-UDF) slot-by-slot merge of several name-match lists.

    For each slot position (taken from the length of the first entry), collect
    the FIRST element of that slot from every entry whose slot is non-empty,
    and return the distinct values. Inputs that are not a list are converted
    with `.tolist()` first.
    """
    rows = name_lists if isinstance(name_lists, list) else name_lists.tolist()

    merged = []
    for slot in range(len(rows[0])):
        first_values = {row[slot][0] for row in rows if row[slot]}
        merged.append(list(first_values))
    return merged

def transform_name_for_search_reg(name):
    """Normalise a name into a lower-case, ASCII, letters-and-spaces search key.

    The name is NFKC-normalised, transliterated with ``unidecode``, lower-cased,
    stripped of digits and punctuation, and whitespace-collapsed. Periods,
    commas and pipes become spaces; every other listed symbol is deleted.
    """
    cleaned = unidecode.unidecode(unicodedata.normalize('NFKC', name)).lower()

    # Separators that should split tokens are turned into spaces ...
    for separator in (" ", ".", ",", "|"):
        cleaned = cleaned.replace(separator, " ")

    # ... while every other symbol or digit is removed outright (same order as
    # the original replace chain).
    for unwanted in ")(-&$#@%0123456789*^{}+=_~`[]\\<>?/;:'\"":
        cleaned = cleaned.replace(unwanted, "")

    # Collapse runs of whitespace and trim both ends.
    return " ".join(cleaned.split())

def get_name_match_list_reg(name):
    """Split a cleaned name into first / middle / last components for matching.

    The name is tokenised on whitespace. If it contains a hyphen, two
    tokenisations are processed: one with hyphens removed and one with hyphens
    treated as spaces. For every tokenisation:

    * the first token fills the first-name slot and the last token fills the
      last-name slot (a single-token name therefore fills both);
    * tokens in between fill middle slots 1-4 one by one, and everything from
      the fifth middle token onwards is space-joined into middle slot 5.

    A token longer than one character is stored both in full and as its
    initial; a one-character token is stored as an initial only.

    Returns:
        A list of 14 de-duplicated lists in the order
        ``[first, first_initial, m1, m1_initial, ..., m5, m5_initial,
        last, last_initial]``. Order inside each list is not guaranteed
        because values pass through a ``set``.
    """
    # One (full tokens, initials) bucket pair per slot, in output order.
    first_bucket = ([], [])
    middle_buckets = [([], []) for _ in range(5)]
    last_bucket = ([], [])

    def record(token, bucket):
        """Store ``token`` (when longer than one char) and always its initial."""
        full_tokens, initials = bucket
        if len(token) > 1:
            full_tokens.append(token)
        initials.append(token[0])

    token_variants = [name.replace("-", "").split()]
    if "-" in name:
        token_variants.append(name.replace("-", " ").split())

    for tokens in token_variants:
        if not tokens:
            continue

        record(tokens[0], first_bucket)

        inner_tokens = tokens[1:-1]
        for bucket, token in zip(middle_buckets[:4], inner_tokens[:4]):
            record(token, bucket)
        if len(inner_tokens) >= 5:
            # A single leftover token is used as-is; several are joined into one.
            record(" ".join(inner_tokens[4:]), middle_buckets[4])

        record(tokens[-1], last_bucket)

    result = []
    for full_tokens, initials in [first_bucket, *middle_buckets, last_bucket]:
        result.append(list(set(full_tokens)))
        result.append(list(set(initials)))
    return result

@udf(returnType=StringType())
def get_most_frequent_name(x):
    """Return the most common value in ``x`` as computed by ``statistics.mode``."""
    most_common_name = mode(x)
    return most_common_name

@udf(returnType=StringType())
def get_unique_orcid_for_author_table(list_of_orcids):
    """Pick a single ORCID for an author from a collection of ORCID values.

    Args:
        list_of_orcids: a collection of ORCID strings, possibly containing
            empty/falsy entries. Non-``list`` inputs are converted with
            ``.tolist()`` or, failing that, ``list(...)``.

    Returns:
        The first non-empty ORCID in the collection, or ``""`` when there is
        none.
    """
    if not isinstance(list_of_orcids, list):
        try:
            list_of_orcids = list_of_orcids.tolist()
        except Exception:
            # No usable ``.tolist()``: fall back to plain iteration. ``Exception``
            # (not a bare ``except``) so KeyboardInterrupt/SystemExit propagate.
            list_of_orcids = list(list_of_orcids)

    orcids = [x for x in list_of_orcids if x]

    return orcids[0] if orcids else ""
    
@udf(returnType=IntegerType())
def check_for_unique_orcid_live_clustering(list_of_orcids):
    """Flag whether a collection of ORCIDs contains at most one non-empty value.

    Args:
        list_of_orcids: a collection of ORCID strings, possibly containing
            empty/falsy entries. Non-``list`` inputs are converted with
            ``.tolist()`` or, failing that, ``list(...)``.

    Returns:
        ``1`` when zero or one non-empty ORCID is present, ``0`` when there is
        more than one.
    """
    if not isinstance(list_of_orcids, list):
        try:
            list_of_orcids = list_of_orcids.tolist()
        except Exception:
            # No usable ``.tolist()``: fall back to plain iteration. ``Exception``
            # (not a bare ``except``) so KeyboardInterrupt/SystemExit propagate.
            list_of_orcids = list(list_of_orcids)

    orcids = [x for x in list_of_orcids if x]

    return 0 if len(orcids) > 1 else 1

In [None]:
def get_data_features_scored(df, prefix):
    """Compute pairwise features for candidate matches, persist them, and score them.

    Args:
        df: DataFrame of candidate pairs; un-suffixed columns describe the new
            work/author row and ``*_2`` columns describe the existing row.
        prefix: path fragment (e.g. ``"/names_match/"``) appended to the
            module-level ``temp_save_path`` to build the output locations.

    Side effects:
        Writes ``{temp_save_path}{prefix}all_features/`` and
        ``{temp_save_path}{prefix}data_scored/`` (both overwritten) and prints
        the number of feature rows. Returns ``None``.
    """
    features_path = f"{temp_save_path}{prefix}all_features/"
    scored_path = f"{temp_save_path}{prefix}data_scored/"

    def overlap_size(left, right):
        # Number of elements the two array columns have in common.
        return F.size(F.array_intersect(F.col(left), F.col(right)))

    def union_size(left, right):
        # Number of distinct elements across both array columns.
        return F.size(F.array_union(F.col(left), F.col(right)))

    def overlap_ratio(inter_col, union_col):
        # intersection / union, rounded to 4 decimals; 0.0 when the union is empty.
        return F.round(F.when(F.col(union_col)>0,
                              F.col(inter_col)/F.col(union_col)).otherwise(0.0), 4)

    pair_features = (
        df
        .withColumn('row_label', F.concat_ws("|", F.col('work_author_id'), F.col('work_author_id_2')))
        # Does either work cite the other?
        .withColumn('work_in_citations_2',
                    F.array_contains(F.col('citations_2'), F.col('paper_id')).cast(IntegerType()))
        .withColumn('work_2_in_citations',
                    F.array_contains(F.col('citations'), F.col('paper_id_2')).cast(IntegerType()))
        .withColumn('citation_work_match',
                    F.when((F.col('work_2_in_citations')==1) |
                           (F.col('work_in_citations_2')==1), 1).otherwise(0))
        # Raw overlap counts.
        .withColumn('insts_inter', overlap_size('institutions', 'institutions_2'))
        .withColumn('coauths_inter', overlap_size('coauthors_shorter', 'coauthors_shorter_2'))
        .withColumn('concps_inter', overlap_size('concepts_shorter', 'concepts_shorter_2'))
        .withColumn('cites_inter', overlap_size('citations', 'citations_2'))
        .withColumn('coauths_union', union_size('coauthors_shorter', 'coauthors_shorter_2'))
        .withColumn('concps_union', union_size('concepts_shorter', 'concepts_shorter_2'))
        .withColumn('cites_union', union_size('citations', 'citations_2'))
        # Model features derived from the counts.
        .withColumn('inst_per', F.when(F.col('insts_inter')>0, 1).otherwise(0))
        .withColumn('coauthors_shorter_per', overlap_ratio('coauths_inter', 'coauths_union'))
        .withColumn('concepts_shorter_per', overlap_ratio('concps_inter', 'concps_union'))
        .withColumn('citation_per', overlap_ratio('cites_inter', 'cites_union'))
        .withColumn('exact_match', F.when(F.col('author')==F.col('author_2'), 1).otherwise(0))
        .withColumn('name_len', F.length(F.col('author')))
        .withColumn('name_spaces', F.size(F.split(F.col('author'), " ")))
        .select(F.col('work_author_id').alias('block'), 'row_label', 'inst_per',
                'concepts_shorter_per', 'coauthors_shorter_per',
                (F.col('exact_match')*F.col('name_len')).alias('exact_match_len'),
                (F.col('exact_match')*F.col('name_spaces')).alias('exact_match_spaces'),
                'citation_per', 'citation_work_match')
    )
    pair_features.write.mode('overwrite').parquet(features_path)

    print('features saved: ', spark.read.parquet(features_path).count())

    # Every feature except ``block`` is cast to string so all values fit in one array column.
    stringified = [F.col(feature_name).cast(StringType())
                   for feature_name in ['row_label', 'inst_per', 'concepts_shorter_per',
                                        'coauthors_shorter_per', 'exact_match_len',
                                        'exact_match_spaces', 'citation_per',
                                        'citation_work_match']]

    (
        spark.read.parquet(features_path)
        # get_random_int_udf / score_data are defined elsewhere in the notebook;
        # rows sharing a random_int are collected and scored together.
        .withColumn('random_int', get_random_int_udf(F.col('block')))
        .withColumn('concat_cols', F.array(F.col('block'), *stringified))
        .groupby('random_int')
        .agg(F.collect_list(F.col('concat_cols')).alias('data_to_score'))
        .withColumn('scored_data', score_data(F.col('data_to_score')))
        .select('scored_data')
        .write.mode('overwrite')
        .parquet(scored_path)
    )

In [None]:
def live_clustering_algorithm(scored_data_prefix):
    """Assign scored new works to existing author clusters.

    Reads ``{temp_save_path}{scored_data_prefix}data_scored/`` and uses the
    module-level ``all_new_data`` and ``temp_authors_table`` DataFrames.

    Steps:
        1. Flatten the scored arrays into (work_author_id, work_author_id_2,
           score) rows, one per distinct pair label.
        2. Attach the new work's ORCID and the candidate cluster's ORCID, keep
           pairs whose ORCIDs agree or where either is empty, and keep the
           single best-scoring candidate per new work.
        3. Clusters that would receive at most one distinct non-empty ORCID
           from the new works keep all their matches
           (``matched_to_cluster/orcids_good/``).
        4. Clusters that would receive several distinct ORCIDs keep only their
           single best-scoring new work
           (``matched_to_cluster/orcids_not_good/``).

    Args:
        scored_data_prefix: path fragment (e.g. ``"/names_match/"``) appended
            to ``temp_save_path`` for every input and output location.

    Returns:
        ``None``; all results are written as parquet (overwrite mode).
    """
    base_path = f"{temp_save_path}{scored_data_prefix}"

    w1 = Window.partitionBy('work_author_id').orderBy(F.col('score').desc())
    w2 = Window.partitionBy('author_id').orderBy(F.col('score').desc())

    # 1. Flatten scored data. ``pairs`` is "<work_author_id>|<work_author_id_2>".
    (
        spark.read.parquet(f"{base_path}data_scored/")
        .select(F.explode('scored_data').alias('scored_data'))
        .select(F.col('scored_data').getItem(0).alias('work_author_id'),
                F.col('scored_data').getItem(1).alias('pairs'),
                F.col('scored_data').getItem(2).alias('score').cast(FloatType()))
        .dropDuplicates(subset=['pairs'])
        # Raw string: "\|" is an invalid Python escape sequence; the pattern
        # itself (a literal pipe for the split regex) is unchanged.
        .select('work_author_id',
                F.split(F.col('pairs'), r"\|")[1].alias('work_author_id_2'),
                'score')
        .repartition(250)
        .write.mode('overwrite')
        .parquet(f"{base_path}flat_scored_data/")
    )

    # 2. Best ORCID-compatible candidate per new work.
    (
        spark.read.parquet(f"{base_path}flat_scored_data/")
        .join(all_new_data.select('work_author_id','orcid','author'),
              how='inner', on='work_author_id')
        .join(temp_authors_table.select('work_author_id_2','author_id','orcid_2').distinct(),
              how='inner',on='work_author_id_2')
        .filter((F.col('orcid')==F.col('orcid_2')) |
                (F.col('orcid')=='') |
                (F.col('orcid_2')==''))
        .withColumn('rank', F.row_number().over(w1))
        .filter(F.col('rank')==1)
        .write.mode('overwrite')
        .parquet(f"{base_path}potential_cluster_matches/")
    )

    pot_cluster_matches = spark.read.parquet(f"{base_path}potential_cluster_matches/")

    # orcid_good == 1 when the new works headed to a cluster carry at most one
    # distinct non-empty ORCID between them.
    orcids_check = (
        pot_cluster_matches
        .groupby('author_id')
        .agg(F.collect_set(F.col('orcid')).alias('orcids'))
        .withColumn('orcid_good', check_for_unique_orcid_live_clustering('orcids'))
        .select('author_id','orcid_good')
        .alias('orcids_check')
    )

    # 3. Clusters without an ORCID conflict accept every matched work.
    (
        pot_cluster_matches
        .join(orcids_check.filter(F.col('orcid_good')==1).select('author_id').distinct(),
              how='inner', on='author_id')
        .select('work_author_id', 'author_id')
        .dropDuplicates(subset=['work_author_id'])
        .write.mode('overwrite')
        .parquet(f"{base_path}matched_to_cluster/orcids_good/")
    )

    # 4. Clusters with conflicting ORCIDs: stage the rows, then keep only the
    #    top-scoring work per cluster.
    (
        pot_cluster_matches
        .join(orcids_check.filter(F.col('orcid_good')==0).select('author_id').distinct(),
              how='inner', on='author_id')
        .write.mode('overwrite')
        .parquet(f"{base_path}orcids_not_good/")
    )

    (
        spark.read.parquet(f"{base_path}orcids_not_good/")
        .withColumn('rank', F.row_number().over(w2))
        .filter(F.col('rank')==1)
        .select('work_author_id', 'author_id')
        .dropDuplicates(subset=['work_author_id'])
        .write.mode('overwrite')
        .parquet(f"{base_path}matched_to_cluster/orcids_not_good/")
    )

In [None]:
def create_new_features_table(new_rows_location):
    """Append the features of newly clustered works to the working features table.

    Reads the new (work_author_id, author_id) rows stored under
    ``new_rows_for_author_table/{new_rows_location}/``, pulls the matching
    records from the module-level ``all_new_data``, renames their columns to
    the ``*_2`` layout, unions them onto the module-level
    ``temp_features_table`` and writes the result to
    ``temp_features_table/{new_rows_location}/``.

    Args:
        new_rows_location: sub-folder name used for both the input rows and
            the output table.
    """
    new_rows = (
        spark.read.parquet(f"{temp_save_path}/new_rows_for_author_table/{new_rows_location}/")
        .dropDuplicates()
    )
    newly_matched_ids = new_rows.select('work_author_id').dropDuplicates()

    # DataFrame.union matches columns by position, so this select must keep the
    # same column order as temp_features_table.
    new_feature_rows = (
        all_new_data
        .join(newly_matched_ids, how='inner', on='work_author_id')
        .select(F.col('work_author_id').alias('work_author_id_2'),
                F.col('orcid').alias('orcid_2'),
                F.col('citations').alias('citations_2'),
                F.col('institutions').alias('institutions_2'),
                F.col('author').alias('author_2'),
                F.col('paper_id').alias('paper_id_2'),
                'original_author',
                F.col('concepts_shorter').alias('concepts_shorter_2'),
                F.col('coauthors_shorter').alias('coauthors_shorter_2'))
    )

    (
        temp_features_table
        .union(new_feature_rows)
        .dropDuplicates(subset=['work_author_id_2'])
        .write.mode('overwrite')
        .parquet(f"{temp_save_path}/temp_features_table/{new_rows_location}/")
    )

In [None]:
def create_new_author_table(new_rows_location):
    """Rebuild the working authors table after new works were assigned to clusters.

    Combines the new (work_author_id, author_id) rows stored under
    ``new_rows_for_author_table/{new_rows_location}/`` with the assignments in
    the module-level ``temp_authors_table``, joins them to the module-level
    ``temp_features_table``, re-aggregates every author (ORCID, display name,
    alternate names, name match list) and writes one row per work to
    ``temp_authors_table/{new_rows_location}/``.

    Args:
        new_rows_location: sub-folder name used for both the input rows and
            the output table.
    """
    new_rows = spark.read.parquet(f"{temp_save_path}/new_rows_for_author_table/{new_rows_location}/")

    # DataFrame.union is positional: both sides are (work_author_id, author_id).
    existing_assignments = temp_authors_table.select(
        F.col('work_author_id_2').alias('work_author_id'), 'author_id')
    cluster_df = new_rows.union(existing_assignments)

    # Features use the "*_2" naming; rename so they join to the cluster rows.
    work_level_rows = temp_features_table.select(
        F.col('work_author_id_2').alias('work_author_id'),
        F.col('orcid_2').alias('orcid'),
        'original_author',
        F.col('author_2').alias('author'))

    author_level_rows = (
        work_level_rows
        .join(cluster_df, how='inner', on='work_author_id')
        .filter(F.col('original_author')!="")
        .filter(F.col('original_author').isNotNull())
        .groupby('author_id')
        .agg(F.collect_set(F.col('orcid')).alias('orcid'),
             F.collect_set(F.col('work_author_id')).alias('work_author_id'),
             F.collect_set(F.col('author')).alias('alternate_names'),
             F.collect_set(F.col('author')).alias('names_for_list'),
             F.collect_list(F.col('original_author')).alias('names'))
        .withColumn('orcid', get_unique_orcid_for_author_table(F.col('orcid')))
        .withColumn('display_name', get_most_frequent_name(F.col('names')))
        .withColumn('name_match_list', get_name_match_from_alternate_names('names_for_list'))
    )

    # Back to one row per work, carrying the author-level attributes.
    (
        author_level_rows
        .select(F.explode('work_author_id').alias('work_author_id_2'),
                'author_id',
                F.col('orcid').alias('orcid_2'),
                'display_name',
                'alternate_names',
                'name_match_list')
        .write.mode('overwrite')
        .parquet(f"{temp_save_path}/temp_authors_table/{new_rows_location}/")
    )

#### Load Transformed New Data

In [None]:
# New work/author records for this run, one row per work_author_id.
all_new_data = (
    spark.read.parquet(f"{temp_save_path}/new_data_to_disambiguate/")
    .dropDuplicates(subset=['work_author_id'])
)
# Mark for caching and run an action; the row count is the cell output.
all_new_data.cache().count()

In [None]:
# Current production cluster assignments, minus any work_author_id that is
# present in the new batch (left-anti join).
init_cluster_df = (
    spark.read.parquet(f"{prod_save_path}/current_authors_table/")
    .select('work_author_id', 'author_id')
    .join(all_new_data.select('work_author_id'), how='leftanti', on='work_author_id')
)

#### Get current features table

In [None]:
# Snapshot the production features table, restricted to works that are still
# in init_cluster_df, with id-like columns cast to long.
(
    spark.read.parquet(f"{prod_save_path}/current_features_table/")
    .join(init_cluster_df.select(F.col('work_author_id').alias('work_author_id_2')),
          how='inner', on='work_author_id_2')
    .select('work_author_id_2', 'orcid_2', F.col('citations_2').cast(ArrayType(LongType())),
            'institutions_2', 'author_2', F.col('paper_id_2').cast(LongType()), 'original_author',
            F.col('concepts_shorter_2').cast(ArrayType(LongType())), 'coauthors_shorter_2')
    .write.mode('overwrite')
    .parquet(f"{temp_save_path}/temp_features_table/init/")
)

In [None]:
# Working features table; starts from the "init" snapshot written in the cell
# above and is re-pointed after every clustering round.
temp_features_table = spark.read.parquet(f"{temp_save_path}/temp_features_table/init/")

#### Get current authors table

In [None]:
# Build the initial working authors table: aggregate every existing cluster
# (ORCID, display name, alternate names, name match list) and write one row
# per work.
(
    temp_features_table
    .select(F.col('work_author_id_2').alias('work_author_id'),
            F.col('orcid_2').alias('orcid'),
            'original_author',
            F.col('author_2').alias('author'))
    .join(init_cluster_df, how='inner', on='work_author_id')
    .filter(F.col('original_author')!="")
    .filter(F.col('original_author').isNotNull())
    .groupby('author_id')
    .agg(F.collect_set(F.col('orcid')).alias('orcid'),
         F.collect_set(F.col('work_author_id')).alias('work_author_id'),
         F.collect_set(F.col('author')).alias('alternate_names'),
         F.collect_set(F.col('author')).alias('names_for_list'),
         F.collect_list(F.col('original_author')).alias('names'))
    .withColumn('orcid', get_unique_orcid_for_author_table(F.col('orcid')))
    .withColumn('display_name', get_most_frequent_name(F.col('names')))
    .withColumn('name_match_list', get_name_match_from_alternate_names('names_for_list'))
    .select(F.explode('work_author_id').alias('work_author_id_2'),
            'author_id',
            F.col('orcid').alias('orcid_2'),
            'display_name',
            'alternate_names',
            'name_match_list')
    .write.mode('overwrite')
    .parquet(f"{temp_save_path}/temp_authors_table/init/")
)

In [None]:
# Working authors table; starts from the "init" snapshot written in the cell
# above and is re-pointed after every clustering round.
temp_authors_table = spark.read.parquet(f"{temp_save_path}/temp_authors_table/init/")

#### Get current names matching table

In [None]:
# Production name-matching table; joined on (block, letter) further down when
# looking for candidates for names that had no exact alternate-name match.
author_names_match = spark.read.parquet(f"{prod_save_path}/current_author_names_match/")

#### Check if work_author_id has been disambiguated before (shows up in final authors table)

In [None]:
# Record the previous author_id of every incoming work_author_id that already
# exists in the production authors table.
(
    spark.read.parquet(f"{prod_save_path}/current_authors_table/")
    .select('work_author_id','author_id')
    .join(all_new_data.select('work_author_id'), how='inner', on='work_author_id')
    .select('work_author_id',F.col('author_id').alias('author_id_old'))
    .write.mode('overwrite')
    .parquet(f"{temp_save_path}/temp_authors_to_change_table/")
)

In [None]:
# Number of incoming work_author_ids that were already assigned to an author.
spark.read.parquet(f"{temp_save_path}/temp_authors_to_change_table/").count()

**TODO:** add logic that performs a set of follow-up actions when a work_author_id is being disambiguated again (i.e., it already appears in the final authors table).

#### Check for merges/changes to the final author table since last update (author merges, work removals/adds)

* for work_author_ids that were removed and need a new cluster, keep a record of clusters it should not match to
* need to make sure that when it goes through disambiguation again, it can't match back up to the same cluster
* for removals, work will be taken away from the author quickly but it will go through the full disambiguation process during the next round (won't be a separate process)

#### Checking if ORCID matches to a cluster with an ORCID

In [None]:
# Round 1: a new work that carries an ORCID is assigned to the cluster that
# already holds the same ORCID.
(
    all_new_data
    .filter(F.col('orcid')!='')
    .join(temp_authors_table.select(F.col('orcid_2').alias('orcid'),'author_id'),
          how='inner', on='orcid')
    .select('work_author_id', 'author_id')
    .dropDuplicates(subset=['work_author_id'])
    .write.mode('overwrite')
    .parquet(f"{temp_save_path}/new_rows_for_author_table/orcid_rows_for_author_table/")
)

In [None]:
# Fold the ORCID-matched rows into the working tables.
# Order matters: create_new_author_table reads the module-level
# temp_features_table, so the features table is rebuilt and re-pointed first.
new_loc = 'orcid_rows_for_author_table'
_ = create_new_features_table(new_loc)
print("New features table created")
temp_features_table = spark.read.parquet(f"{temp_save_path}/temp_features_table/{new_loc}/")

_ = create_new_author_table(new_loc)
print("New authors table created")
temp_authors_table = spark.read.parquet(f"{temp_save_path}/temp_authors_table/{new_loc}/")

In [None]:
# Number of new works assigned through an ORCID match.
spark.read.parquet(f"{temp_save_path}/new_rows_for_author_table/orcid_rows_for_author_table/").count()

In [None]:
# Works not yet present in the working authors table go on to round 2.
(
    all_new_data
    .join(temp_authors_table.select(F.col('work_author_id_2').alias('work_author_id')).distinct(),
          how='leftanti', on='work_author_id')
    .select('work_author_id','paper_id','original_author','author','orcid',
            'coauthors_shorter','concepts_shorter','institutions','citations','created_date')
    .write.mode('overwrite')
    .parquet(f"{temp_save_path}/round_2_of_clustering/")
)

#### Check if name has been previously parsed

In [None]:
# (work_author_id, author_id) pairs that must not be matched again; they are
# excluded from candidate matches with a left-anti join in the cells below.
try:
    work_to_clusters_removed = spark.read.parquet(f"{prod_save_path}/works_removed_from_clusters/")
except Exception as read_error:
    # The table may not exist yet; fall back to an empty frame with the same
    # schema. ``Exception`` (not a bare ``except``) so that interrupting the
    # kernel is not mistaken for a missing table, and the reason is printed so
    # an unexpected failure does not pass silently.
    print(f"works_removed_from_clusters not loaded, using empty table: {read_error}")
    work_to_clusters_removed = spark.sparkContext.emptyRDD().toDF(schema=StructType([StructField("work_author_id", StringType()),
                                                                                     StructField("author_id", LongType())]))

In [None]:
round_2_new_data = spark.read.parquet(f"{temp_save_path}/round_2_of_clustering/")

# Candidate pairs: new works whose author string equals one of a cluster's
# alternate names, minus removed (work, cluster) pairs and ORCID conflicts,
# enriched with the existing work's features.
names_match = (
    round_2_new_data
    # paper_id is the part of work_author_id before the first underscore.
    .withColumn('paper_id', F.split(F.col('work_author_id'), "_").getItem(0).cast(LongType()))
    .join(temp_authors_table.select('work_author_id_2',
                                    'orcid_2',
                                    'author_id',
                                    F.explode(F.col('alternate_names')).alias('author')),
          how='inner', on='author')
    .join(work_to_clusters_removed, how='leftanti', on=['work_author_id','author_id'])
    .filter((F.col('orcid')==F.col('orcid_2')) |
            (F.col('orcid')=='') |
            (F.col('orcid_2')==''))
    .join(temp_features_table.drop("orcid_2"), how='inner', on='work_author_id_2')
)

# prepare data for model scoring and score
_ = get_data_features_scored(names_match, "/names_match/")

# send through clustering/matching algorithm
_ = live_clustering_algorithm("/names_match/")

# save new author table rows to file
(
    spark.read.parquet(f"{temp_save_path}/names_match/matched_to_cluster/*")
    .select('work_author_id', 'author_id')
    .dropDuplicates(subset=['work_author_id'])
    .write.mode('overwrite')
    .parquet(f"{temp_save_path}/new_rows_for_author_table/name_match_rows_for_author_table/")
)

In [None]:
# Sanity check: number of author clusters holding more than one distinct
# non-empty ORCID (orcid_good == 0).
temp_authors_table.groupBy('author_id')\
        .agg(F.collect_set(F.col('orcid_2')).alias('orcids')) \
        .withColumn('orcid_good', check_for_unique_orcid_live_clustering('orcids')) \
        .filter(F.col('orcid_good')==0).count()

In [None]:
# Fold the name-matched rows into the working tables.
# Order matters: create_new_author_table reads the module-level
# temp_features_table, so the features table is rebuilt and re-pointed first.
new_loc = 'name_match_rows_for_author_table'
_ = create_new_features_table(new_loc)
print("New features table created")
temp_features_table = spark.read.parquet(f"{temp_save_path}/temp_features_table/{new_loc}/")

_ = create_new_author_table(new_loc)
print("New authors table created")
temp_authors_table = spark.read.parquet(f"{temp_save_path}/temp_authors_table/{new_loc}/")

In [None]:
# Sanity check: number of author clusters holding more than one distinct
# non-empty ORCID (orcid_good == 0), after the tables were rebuilt.
temp_authors_table.groupBy('author_id')\
        .agg(F.collect_set(F.col('orcid_2')).alias('orcids')) \
        .withColumn('orcid_good', check_for_unique_orcid_live_clustering('orcids')) \
        .filter(F.col('orcid_good')==0).count()

In [None]:
# Works still absent from the working authors table go on to round 3.
(
    all_new_data
    .join(temp_authors_table.select(F.col('work_author_id_2').alias('work_author_id')).distinct(),
          how='leftanti', on='work_author_id')
    .select('work_author_id','paper_id','original_author','author','orcid',
            'coauthors_shorter','concepts_shorter','institutions','citations','created_date')
    .write.mode('overwrite')
    .parquet(f"{temp_save_path}/round_3_of_clustering/")
)

#### Run "previously-parsed" code one more time

In [None]:
round_3_new_data = spark.read.parquet(f"{temp_save_path}/round_3_of_clustering/")

print(round_3_new_data.count())

# Same candidate search as the previous name-match round, run against the
# authors/features tables as updated by that round.
names_match_2 = (
    round_3_new_data
    # paper_id is the part of work_author_id before the first underscore.
    .withColumn('paper_id', F.split(F.col('work_author_id'), "_").getItem(0).cast(LongType()))
    .join(temp_authors_table.select('work_author_id_2',
                                    'orcid_2',
                                    'author_id',
                                    F.explode(F.col('alternate_names')).alias('author')),
          how='inner', on='author')
    .join(work_to_clusters_removed, how='leftanti', on=['work_author_id','author_id'])
    .filter((F.col('orcid')==F.col('orcid_2')) |
            (F.col('orcid')=='') |
            (F.col('orcid_2')==''))
    .join(temp_features_table.drop("orcid_2"), how='inner', on='work_author_id_2')
)

# prepare data for model scoring and score
_ = get_data_features_scored(names_match_2, "/names_match_2/")

# send through clustering/matching algorithm
_ = live_clustering_algorithm("/names_match_2/")

# save new author table rows to file
(
    spark.read.parquet(f"{temp_save_path}/names_match_2/matched_to_cluster/*")
    .select('work_author_id', 'author_id')
    .dropDuplicates(subset=['work_author_id'])
    .write.mode('overwrite')
    .parquet(f"{temp_save_path}/new_rows_for_author_table/name_match_rows_for_author_table_2/")
)

In [None]:
# Number of matched rows produced by the second name-match round (both the
# orcids_good and orcids_not_good outputs, before de-duplication).
spark.read.parquet(f"{temp_save_path}/names_match_2/matched_to_cluster/*").count()

In [None]:
# Fold the rows from the second name-match round into the working tables.
# Order matters: create_new_author_table reads the module-level
# temp_features_table, so the features table is rebuilt and re-pointed first.
new_loc = 'name_match_rows_for_author_table_2'
_ = create_new_features_table(new_loc)
print("New features table created")
temp_features_table = spark.read.parquet(f"{temp_save_path}/temp_features_table/{new_loc}/")

_ = create_new_author_table(new_loc)
print("New authors table created")
temp_authors_table = spark.read.parquet(f"{temp_save_path}/temp_authors_table/{new_loc}/")

In [None]:
# Row count of the working features table after the second name-match round.
temp_features_table.count()

In [None]:
# Row count of the working authors table after the second name-match round.
temp_authors_table.count()

In [None]:
# Sanity check: number of author clusters holding more than one distinct
# non-empty ORCID (orcid_good == 0).
temp_authors_table.groupBy('author_id')\
        .agg(F.collect_set(F.col('orcid_2')).alias('orcids')) \
        .withColumn('orcid_good', check_for_unique_orcid_live_clustering('orcids')) \
        .filter(F.col('orcid_good')==0).count()

In [None]:
# Works still absent from the working authors table go on to round 4.
(
    all_new_data
    .join(temp_authors_table.select(F.col('work_author_id_2').alias('work_author_id')).distinct(),
          how='leftanti', on='work_author_id')
    .select('work_author_id','paper_id','original_author','author','orcid',
            'coauthors_shorter','concepts_shorter','institutions','citations','created_date')
    .write.mode('overwrite')
    .parquet(f"{temp_save_path}/round_4_of_clustering/")
)

#### Not previously parsed

In [None]:
# Number of works still unassigned going into round 4.
spark.read.parquet(f"{temp_save_path}/round_4_of_clustering/").count()

In [None]:
# Round-4 works with a usable author string, tagged with the outputs of
# group_non_latin_characters / name_to_keep_ind (both defined elsewhere in
# this notebook).
round_4_new_data = (
    spark.read.parquet(f"{temp_save_path}/round_4_of_clustering/")
    .filter(F.col('author').isNotNull())
    .filter(F.col('author')!='')
    .withColumn('non_latin_groups', group_non_latin_characters(F.col('author')))
    .withColumn('name_to_keep_ind', name_to_keep_ind('non_latin_groups'))
)

In [None]:
# Number of round-4 works that have a non-null, non-empty author string.
round_4_new_data.count()

In [None]:
# Derive the (block, letter) lookup key for every kept name and persist it.
(
    round_4_new_data
    .filter(F.col('name_to_keep_ind')==1)
    .withColumn('transformed_search_name', transform_name_for_search(F.col('author')))
    .withColumn('name_len', F.length(F.col('transformed_search_name')))
    .filter(F.col('name_len')>1)
    .withColumn('name_match_list', get_name_match_list(F.col('transformed_search_name')))
    .withColumn('block', only_get_last(F.col('transformed_search_name')))
    .select('work_author_id','orcid','name_match_list','transformed_search_name', 'block')
    # NOTE(review): `block` is used as a regex pattern here, so every occurrence
    # of it inside the name is removed, not only a trailing one - confirm intended.
    .withColumn('block_removed', F.expr("regexp_replace(transformed_search_name, block, '')"))
    .withColumn('new_block_removed', F.trim(F.expr("regexp_replace(block_removed, '  ', ' ')")))
    .withColumn('letter', get_starting_letter(F.col('new_block_removed')))
    .select('work_author_id','orcid','name_match_list','transformed_search_name','letter', 'block')
    .write.mode('overwrite')
    .parquet(f"{temp_save_path}/for_new_authors_table/names_to_blocks/")
)

In [None]:
# Re-load the block/letter keys written in the previous cell.
no_names_match = spark.read.parquet(f"{temp_save_path}/for_new_authors_table/names_to_blocks/")
# Mark for caching and run an action; the row count is the cell output.
no_names_match.cache().count()

In [None]:
# Candidate pairs: match new names against existing author names on
# (block, letter), keep pairs whose name lists are accepted by
# check_block_vs_block, then pull in the full row for each side.
full_no_names_match_table = (
    no_names_match
    .join(author_names_match, how='inner', on=['block', 'letter'])
    .withColumn(
        'matched_names',
        check_block_vs_block(F.col('name_match_list'), F.col('name_match_list_2')),
    )
    .filter(F.col('matched_names') == 1)
    .select('work_author_id', 'work_author_id_2')
    .dropDuplicates()
    .join(round_4_new_data, how='inner', on='work_author_id')
    .join(temp_features_table, how='inner', on='work_author_id_2')
    # keep a pair only if the ORCIDs agree or at least one side is the empty string
    .filter(
        (F.col('orcid') == F.col('orcid_2'))
        | (F.col('orcid') == '')
        | (F.col('orcid_2') == '')
    )
)

In [None]:
# prepare data for model scoring and score
# The helper is defined earlier in the notebook; its return value is discarded.
# NOTE(review): presumably it writes its output under temp_save_path + "/no_names_match/" - confirm.
_ = get_data_features_scored(full_no_names_match_table, "/no_names_match/")

In [None]:
# send through clustering/matching algorithm
# The helper is defined earlier in the notebook; its return value is discarded.
# NOTE(review): the next cell reads ".../no_names_match/matched_to_cluster/*", so this
# call is presumably what produces those files - confirm.
_ = live_clustering_algorithm("/no_names_match/")

In [None]:
# Persist the (work_author_id -> author_id) assignments found in this round,
# keeping a single row per work_author_id.
(
    spark.read.parquet(f"{temp_save_path}/no_names_match/matched_to_cluster/*")
    .select('work_author_id', 'author_id')
    .dropDuplicates(subset=['work_author_id'])
    .write.mode('overwrite')
    .parquet(f"{temp_save_path}/new_rows_for_author_table/no_name_match_rows_for_author_table/")
)

In [None]:
# Sanity check: number of new author-table rows saved for the no-name-match round.
spark.read.parquet(f"{temp_save_path}/new_rows_for_author_table/no_name_match_rows_for_author_table/").count()

In [None]:
# Rebuild the temp features and authors tables for this round and re-point the
# notebook-level variables at the freshly written parquet folders.
# NOTE(review): the helpers are defined earlier in the notebook; they presumably
# write to temp_features_table/{new_loc}/ and temp_authors_table/{new_loc}/ - confirm.
new_loc = 'no_name_match_rows_for_author_table'
_ = create_new_features_table(new_loc)
print("New features table created")
temp_features_table = spark.read.parquet(f"{temp_save_path}/temp_features_table/{new_loc}/")

_ = create_new_author_table(new_loc)
print("New authors table created")
temp_authors_table = spark.read.parquet(f"{temp_save_path}/temp_authors_table/{new_loc}/")

In [None]:
# Row count of the rebuilt temp features table.
temp_features_table.count()

In [None]:
# Row count of the rebuilt temp authors table.
temp_authors_table.count()

In [None]:
# Count author clusters whose collected ORCID set is flagged (orcid_good == 0)
# by check_for_unique_orcid_live_clustering.
(
    temp_authors_table
    .groupBy('author_id')
    .agg(F.collect_set(F.col('orcid_2')).alias('orcids'))
    .withColumn('orcid_good', check_for_unique_orcid_live_clustering('orcids'))
    .filter(F.col('orcid_good') == 0)
    .count()
)

#### Save all non-clustered data to new location

In [None]:
# Anything in all_new_data whose work_author_id is not yet present in the temp
# authors table is written out as a clustering leftover.
(
    all_new_data
    .join(
        temp_authors_table.select(F.col('work_author_id_2').alias('work_author_id')).distinct(),
        how='leftanti',
        on='work_author_id',
    )
    .select(
        'work_author_id', 'paper_id', 'original_author', 'author', 'orcid',
        'coauthors_shorter', 'concepts_shorter', 'institutions', 'citations',
        'created_date',
    )
    .write.mode('overwrite')
    .parquet(f"{temp_save_path}/end_of_clustering_leftovers/")
)

In [None]:
# Sanity check: number of leftover rows that were not matched to any cluster.
spark.read.parquet(f"{temp_save_path}/end_of_clustering_leftovers/").count()

In [None]:
# Largest author_id currently in the temp authors table; used below as the
# offset for newly created author ids.
max_id = int(
    temp_authors_table
    .select(F.max(F.col('author_id')))
    .first()[0]
)

In [None]:
# Display the current maximum author_id.
max_id

#### Give non-clustered data an author ID (new single cluster)

In [None]:
# Global ordering by work_author_id so row_number() yields 1..N.
# NOTE(review): a window with no partitionBy is evaluated in a single partition.
w1 = Window.orderBy(F.col('work_author_id'))

# Every distinct leftover work_author_id becomes its own cluster, with ids
# max_id + 1, max_id + 2, ...
(
    spark.read.parquet(f"{temp_save_path}/end_of_clustering_leftovers/")
    .select('work_author_id')
    .distinct()
    .withColumn('temp_cluster_num', F.row_number().over(w1))
    .withColumn('author_id', F.lit(max_id) + F.col('temp_cluster_num'))
    .select('work_author_id', 'author_id')
    .write.mode('overwrite')
    .parquet(f"{temp_save_path}/new_rows_for_author_table/new_author_clusters/")
)

In [None]:
# Sanity check: number of brand-new single-record author clusters written.
spark.read.parquet(f"{temp_save_path}/new_rows_for_author_table/new_author_clusters/").count()

In [None]:
# Rebuild the temp features and authors tables once more, now including the new
# single-record clusters, and re-point the notebook-level variables at them.
# NOTE(review): the helpers are defined earlier in the notebook; they presumably
# write to temp_features_table/{new_loc}/ and temp_authors_table/{new_loc}/ - confirm.
new_loc = 'new_author_clusters'
_ = create_new_features_table(new_loc)
print("New features table created")
temp_features_table = spark.read.parquet(f"{temp_save_path}/temp_features_table/{new_loc}/")

_ = create_new_author_table(new_loc)
print("New authors table created")
temp_authors_table = spark.read.parquet(f"{temp_save_path}/temp_authors_table/{new_loc}/")

In [None]:
# Re-run the ORCID consistency check on the final temp authors table: count the
# clusters that check_for_unique_orcid_live_clustering flags with orcid_good == 0.
(
    temp_authors_table
    .groupBy('author_id')
    .agg(F.collect_set(F.col('orcid_2')).alias('orcids'))
    .withColumn('orcid_good', check_for_unique_orcid_live_clustering('orcids'))
    .filter(F.col('orcid_good') == 0)
    .count()
)

In [None]:
# Row count of the final temp features table.
temp_features_table.count()

In [None]:
# Row count of the final temp authors table.
temp_authors_table.count()

### Generate final authors table to write to S3 and Postgres

In [None]:
@udf(returnType=IntegerType())
def check_list_vs_list(list_1, list_2):
    """Flag whether two array columns hold different sets of elements.

    The comparison ignores element order and duplicates.

    Args:
        list_1: first list of values, or None when the column is NULL.
        list_2: second list of values, or None when the column is NULL.

    Returns:
        0 when both lists contain the same set of elements, 1 otherwise.
    """
    # A NULL array column reaches the UDF as None, and set(None) raises TypeError,
    # which would fail the whole Spark job. Treat NULL as an empty list instead.
    set_1 = set(list_1) if list_1 is not None else set()
    set_2 = set(list_2) if list_2 is not None else set()
    return 0 if set_1 == set_2 else 1

In [None]:
# Current production authors table; compared columns get a `_1` suffix.
init_author_table = (
    spark.read.parquet(f"{prod_save_path}/current_authors_table/")
    .select(
        'work_author_id',
        F.col('author_id').alias('author_id_1'),
        F.col('display_name').alias('display_name_1'),
        F.col('alternate_names').alias('alternate_names_1'),
        F.col('orcid').alias('orcid_1'),
        'created_date',
        'modified_date',
    )
)

# Authors table produced by this run; compared columns get a `_2` suffix.
final_author_table = (
    spark.read.parquet(f"{temp_save_path}/temp_authors_table/new_author_clusters/")
    .select(
        F.col('work_author_id_2').alias('work_author_id'),
        F.col('author_id').alias('author_id_2'),
        F.col('display_name').alias('display_name_2'),
        F.col('alternate_names').alias('alternate_names_2'),
        'orcid_2',
    )
)

In [None]:
# Row count of the production authors table.
init_author_table.count()

In [None]:
# Count after de-duplicating on work_author_id; compare with the count above to spot duplicate keys.
init_author_table.dropDuplicates(subset=['work_author_id']).count()

In [None]:
# Row count of the authors table produced by this run.
final_author_table.count()

In [None]:
# Count after de-duplicating on work_author_id; compare with the count above to spot duplicate keys.
final_author_table.dropDuplicates(subset=['work_author_id']).count()

In [None]:
# Take the final author table and compare it column by column with the initial
# (production) table to see whether anything changed for each work_author_id.
# Each *_compare column is 0 when the two sides agree and 1 when they differ;
# total_changes is their sum.
#
# The scalar comparisons use eqNullSafe rather than ==. With ==, a row where both
# sides are NULL evaluates to NULL, falls through to otherwise(1), and is reported
# as modified on every run even though nothing changed. eqNullSafe treats
# NULL <=> NULL as equal and NULL <=> value as different.
compare_tables = (
    final_author_table
    .join(init_author_table, how='inner', on='work_author_id')
    .withColumn(
        'orcid_compare',
        F.when(F.col('orcid_1').eqNullSafe(F.col('orcid_2')), 0).otherwise(1),
    )
    .withColumn(
        'display_name_compare',
        F.when(F.col('display_name_1').eqNullSafe(F.col('display_name_2')), 0).otherwise(1),
    )
    .withColumn(
        'author_id_compare',
        F.when(F.col('author_id_1').eqNullSafe(F.col('author_id_2')), 0).otherwise(1),
    )
    .withColumn(
        'alternate_names_compare',
        check_list_vs_list(F.col('alternate_names_1'), F.col('alternate_names_2')),
    )
    .withColumn(
        'total_changes',
        F.col('orcid_compare') + F.col('display_name_compare')
        + F.col('author_id_compare') + F.col('alternate_names_compare'),
    )
)

In [None]:
# Unchanged rows (total_changes == 0): carry the production values and both
# original timestamps through untouched.
(
    compare_tables
    .filter(F.col('total_changes') == 0)
    .select(
        'work_author_id',
        F.col('author_id_1').alias('author_id'),
        F.col('display_name_1').alias('display_name'),
        F.col('alternate_names_1').alias('alternate_names'),
        F.col('orcid_1').alias('orcid'),
        'created_date',
        'modified_date',
    )
    .write.mode('overwrite')
    .parquet(f"{temp_save_path}/final_author_table_part/no_changes/")
)

In [None]:
# Rows that already existed but differ in at least one compared column: take the
# new values, keep the original created_date, and stamp modified_date with now.
(
    compare_tables
    .filter(F.col('total_changes') > 0)
    .select(
        'work_author_id',
        F.col('author_id_2').alias('author_id'),
        F.col('display_name_2').alias('display_name'),
        F.col('alternate_names_2').alias('alternate_names'),
        F.col('orcid_2').alias('orcid'),
        'created_date',
    )
    .withColumn("modified_date", F.current_timestamp())
    .write.mode('overwrite')
    .parquet(f"{temp_save_path}/final_author_table_part/modified/")
)

In [None]:
# Rows whose work_author_id is absent from the production table are new: both
# created_date and modified_date are set to the current timestamp.
(
    final_author_table
    .join(init_author_table.select('work_author_id'), how='leftanti', on='work_author_id')
    .select(
        'work_author_id',
        F.col('author_id_2').alias('author_id'),
        F.col('display_name_2').alias('display_name'),
        F.col('alternate_names_2').alias('alternate_names'),
        F.col('orcid_2').alias('orcid'),
    )
    .withColumn("created_date", F.current_timestamp())
    .withColumn("modified_date", F.current_timestamp())
    .write.mode('overwrite')
    .parquet(f"{temp_save_path}/final_author_table_part/created/")
)
    

In [None]:
# Report how many rows were written to the no_changes part.
print("No row changes: ", spark.read.parquet(f"{temp_save_path}/final_author_table_part/no_changes/").count())

In [None]:
# Report how many rows were written to the modified part.
print("Modified rows: ", spark.read.parquet(f"{temp_save_path}/final_author_table_part/modified/").count())

In [None]:
# Report how many rows were written to the created part.
print("New cluster rows: ", spark.read.parquet(f"{temp_save_path}/final_author_table_part/created/").count())

In [None]:
# Report the combined row count across all final_author_table_part sub-folders.
print("Total rows: ", spark.read.parquet(f"{temp_save_path}/final_author_table_part/*").count())

### Save all final tables to S3

In [None]:
# Combine every final_author_table_part sub-folder and overwrite the production
# authors table with the result, written as 250 partitions.
(
    spark.read.parquet(f"{temp_save_path}/final_author_table_part/*")
    .repartition(250)
    .write.mode('overwrite')
    .parquet(f"{prod_save_path}/current_authors_table/")
)

In [None]:
# Overwrite the production features table with this run's final temp features
# table, written as 250 partitions.
(
    spark.read.parquet(f"{temp_save_path}/temp_features_table/new_author_clusters/")
    .repartition(250)
    .write.mode('overwrite')
    .parquet(f"{prod_save_path}/current_features_table/")
)

In [None]:
# Sanity check: row count of the production features table after the overwrite.
spark.read.parquet(f"{prod_save_path}/current_features_table/").count()

In [None]:
# Regenerate the production name-matching lookup (block / letter keys plus the
# name match list) from the final features table.
(
    spark.read.parquet(f"{temp_save_path}/temp_features_table/new_author_clusters/")
    .select('work_author_id_2', 'orcid_2', F.col('author_2').alias('transformed_name'))
    .filter(F.col('transformed_name') != "")
    .filter(F.col('transformed_name').isNotNull())
    .withColumn('transformed_search_name', transform_name_for_search(F.col('transformed_name')))
    .withColumn('name_len', F.length(F.col('transformed_search_name')))
    # search names of length 0 or 1 are discarded
    .filter(F.col('name_len') > 1)
    .withColumn('name_match_list_2', get_name_match_list(F.col('transformed_search_name')))
    .withColumn('block', only_get_last(F.col('transformed_search_name')))
    .select('work_author_id_2', 'name_match_list_2', 'orcid_2', 'transformed_search_name', 'block')
    # NOTE(review): `block` is passed to regexp_replace as a regex pattern and every
    # occurrence is removed; assumes block contains no regex metacharacters - confirm.
    .withColumn('block_removed', F.expr("regexp_replace(transformed_search_name, block, '')"))
    # collapse double spaces left behind by the removal, then trim the ends
    .withColumn('new_block_removed', F.trim(F.expr("regexp_replace(block_removed, '  ', ' ')")))
    .withColumn('letter', get_starting_letter(F.col('new_block_removed')))
    .select('work_author_id_2', 'orcid_2', 'name_match_list_2', 'block', 'letter')
    .dropDuplicates()
    .repartition(250)
    .write.mode('overwrite')
    .parquet(f"{prod_save_path}/current_author_names_match/")
)