In [0]:
import pickle
import boto3
import re
import json
import random
import unicodedata
import unidecode
import pandas as pd
import numpy as np
import time
from datetime import datetime, timedelta
from nameparser import HumanName

from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import IntegerType, StringType, FloatType, ArrayType, DoubleType, StructType, StructField



In [0]:
# Drop everything currently cached in this Spark session so the notebook starts from a clean cache.
spark.catalog.clearCache()

In [0]:
# Root storage locations used throughout the notebook; "<S3path>" is a placeholder to fill in.
# Both values are used as f-string prefixes immediately followed by a relative path
# (e.g. f"{iteration_save_path}final_model_data/..."), so each must end with "/".
base_save_path = "<S3path>"  # read from for the `static_affiliations` table
iteration_save_path = "<S3path>"  # read from / written to for everything under `final_model_data/`

#### Getting all author data

In [0]:
# Translation table applied after lower-casing: ".", "," and "|" become a space
# (token separators); digits and the listed punctuation are deleted outright.
# "-" is deleted rather than turned into a space, so hyphenated names fuse into
# one token. "!" is not in either list and is therefore kept.
_NAME_SEARCH_TRANSLATION = str.maketrans(
    ".,|",
    "   ",
    "()-&$#@%0123456789*^{}+=_~`[]\\<>?/;:'\"",
)


@udf(returnType=StringType())
def transform_name_for_search(name):
    """Normalise a raw author name into the form used for name matching.

    Steps: Unicode NFKC normalisation, transliteration to ASCII with
    ``unidecode``, lower-casing, separator/punctuation/digit clean-up via
    ``_NAME_SEARCH_TRANSLATION``, and collapsing of all whitespace runs to a
    single space (leading/trailing whitespace removed).

    Args:
        name: raw author name, or ``None`` for a null column value.

    Returns:
        The cleaned name, or ``None`` when ``name`` is ``None`` so that null
        rows stay null instead of raising inside the UDF.
    """
    if name is None:
        return None

    ascii_name = unidecode.unidecode(unicodedata.normalize('NFKC', name))
    cleaned = ascii_name.lower().translate(_NAME_SEARCH_TRANSLATION)
    # split()/join collapses the extra spaces left behind by the replacements above
    return " ".join(cleaned.split())

#### Now explode and run matching function on all rows to get only names that could potentially match

In [0]:
@udf(returnType=IntegerType())
def check_block_vs_block(block_1_names_list, block_2_names_list):
    """Decide whether two parsed names are compatible enough to be a candidate pair.

    Both arguments are the 14-element lists produced by ``get_name_match_list``:
    ``[fn, fni, m1, m1i, m2, m2i, m3, m3i, m4, m4i, m5, m5i, ln, lni]`` where each
    even index holds full-name variants for a slot and the following odd index
    holds that slot's initials.

    Logic:
      * If the first-name slots do not match, the only way to succeed is the
        "last name written first" check on two-token names.
      * Otherwise the last-name slots must match, and then the middle-name
        slots are compared in order. A failed slot rejects the pair; a slot
        that reports nothing more to compare accepts it.

    Returns:
        1 if the names could refer to the same person, otherwise 0.
    """
    first_check, _ = match_block_names(block_1_names_list[0], block_1_names_list[1],
                                       block_2_names_list[0], block_2_names_list[1])
    if not first_check:
        swap_check = check_if_last_name_swapped_to_front_creates_match(block_1_names_list,
                                                                       block_2_names_list)
        return 1 if swap_check else 0

    last_check, _ = match_block_names(block_1_names_list[-2], block_1_names_list[-1],
                                      block_2_names_list[-2], block_2_names_list[-1])
    if not last_check:
        return 0

    # Middle-name slots m1..m5 live at (names, initials) index pairs (2,3) ... (10,11).
    # Each slot is compared using its own names list AND its own initials list.
    for names_idx in range(2, 12, 2):
        initials_idx = names_idx + 1
        slot_check, more_to_go = match_block_names(block_1_names_list[names_idx],
                                                   block_1_names_list[initials_idx],
                                                   block_2_names_list[names_idx],
                                                   block_2_names_list[initials_idx])
        if not slot_check:
            return 0
        if not more_to_go:
            # at least one side has run out of middle names: nothing left to contradict
            return 1

    # every middle-name slot matched
    return 1
        
def get_name_from_name_list(name_list):
    """Rebuild one representative token sequence from a parsed-name list.

    ``name_list`` is indexed in (full forms, initials) pairs: indices 0-11 hold
    the first-name and middle-name slots and the final two entries hold the
    last-name slot. For each slot the first full form is used if there is one,
    otherwise the first initial.

    Walking the first six slots stops at the first slot that has neither a full
    form nor an initial; the last-name slot is then appended if it has anything.

    Returns:
        List of the chosen tokens, in slot order (possibly empty).
    """
    tokens = []

    for slot_start in range(0, 12, 2):
        # full forms win; fall back to initials only when there are none
        candidates = name_list[slot_start] or name_list[slot_start + 1]
        if not candidates:
            break
        tokens.append(candidates[0])

    last_candidates = name_list[-2] or name_list[-1]
    if last_candidates:
        tokens.append(last_candidates[0])

    return tokens
        
def check_if_last_name_swapped_to_front_creates_match(block_1, block_2):
    """Check whether two two-token names are the same name with the token order reversed.

    Each argument is a parsed-name list accepted by ``get_name_from_name_list``.
    The check only applies when both names reduce to exactly two tokens; the
    second name is rotated so its final token comes first and is then compared
    with the first name.

    Returns:
        True if the rotated second name equals the first name, else False.
    """
    tokens_1 = get_name_from_name_list(block_1)
    if len(tokens_1) != 2:
        return False

    tokens_2 = get_name_from_name_list(block_2)
    if len(tokens_2) != 2:
        return False

    rotated_2 = tokens_2[-1:] + tokens_2[:-1]
    return " ".join(tokens_1) == " ".join(rotated_2)
    
def match_block_names(block_1_names, block_1_initials, block_2_names, block_2_initials):
    """Compare one name slot (e.g. first name) of two parsed names.

    Args:
        block_1_names / block_2_names: full-name variants for the slot (may be empty).
        block_1_initials / block_2_initials: initials for the slot (may be empty).

    Returns:
        ``(matched, more_to_go)``:
          * both sides have full names -> matched iff they share a full name;
          * exactly one side has full names -> matched iff the initials overlap,
            or automatically when the side without full names has no initials;
          * neither side has full names -> matched iff the initials overlap when
            both have initials; otherwise ``(True, False)``.
        ``more_to_go`` is False only in that final case.
    """
    def _initials_overlap():
        return any(initial in block_1_initials for initial in block_2_initials)

    has_names_1 = bool(block_1_names)
    has_names_2 = bool(block_2_names)

    if has_names_1 and has_names_2:
        shares_full_name = any(candidate in block_1_names for candidate in block_2_names)
        return shares_full_name, True

    if has_names_1 or has_names_2:
        # Exactly one side spelled the name out; the other side's initials decide.
        other_side_initials = block_2_initials if has_names_1 else block_1_initials
        if not other_side_initials:
            return True, True
        return _initials_overlap(), True

    if block_1_initials and block_2_initials:
        return _initials_overlap(), True

    return True, False

In [0]:
@udf(returnType=ArrayType(ArrayType(StringType())))
def get_name_match_list(name):
    """Parse a cleaned name into per-slot lists of full forms and initials.

    The name is tokenised on whitespace with hyphens removed and, if it contains
    a hyphen, a second time with hyphens treated as spaces; tokens from both
    readings are pooled. Tokens are assigned to seven slots:

      * slot 0: first token (first name)
      * slots 1-5: middle tokens in order; when there are more than five, the
        fifth and all later middle tokens are joined with spaces into slot 5
      * slot 6: last token (last name)

    A single-token name fills both the first-name and the last-name slot.
    A token longer than one character is recorded as a full form and as an
    initial (its first character); a one-character token only as an initial.

    Args:
        name: cleaned name string; ``None`` is treated like an empty name.

    Returns:
        14 de-duplicated lists in the order
        ``[fn, fni, m1, m1i, m2, m2i, m3, m3i, m4, m4i, m5, m5i, ln, lni]``.
        Order inside each list is unspecified (it comes from a ``set``).
    """
    slot_count = 7
    last_slot = slot_count - 1
    max_middle_slots = slot_count - 2

    full_forms = [[] for _ in range(slot_count)]
    initials = [[] for _ in range(slot_count)]

    def _record(slot, token):
        # one-character tokens are initials only
        if len(token) > 1:
            full_forms[slot].append(token)
        initials[slot].append(token[0])

    if name is None:
        name = ""

    tokenizations = [name.replace("-", "").split()]
    if "-" in name:
        tokenizations.append(name.replace("-", " ").split())

    for tokens in tokenizations:
        if not tokens:
            continue

        _record(0, tokens[0])
        _record(last_slot, tokens[-1])

        middle = tokens[1:-1]
        if len(middle) > max_middle_slots:
            # overflow: everything from the 5th middle token on shares the last middle slot
            overflow = " ".join(middle[max_middle_slots - 1:])
            middle = middle[:max_middle_slots - 1] + [overflow]

        for slot, token in enumerate(middle, start=1):
            _record(slot, token)

    result = []
    for slot in range(slot_count):
        result.append(list(set(full_forms[slot])))
        result.append(list(set(initials[slot])))
    return result

@udf(returnType=StringType())
def string_of_both_names(name_1, name_2):
    """Build an order-independent key for a pair of strings.

    The two values are sorted and joined with "|", so (a, b) and (b, a)
    produce the same key.
    """
    ordered_pair = sorted((name_1, name_2))
    return "|".join(ordered_pair)

In [None]:
# PLACEHOLDER: assign an integer here before running; the line is incomplete as written.
# The loop in the next cell processes one `random_part=<i>` directory for each i in range(num_partitions).
num_partitions = # number of partitions of the data (partitioned by names)

In [None]:
# Build candidate "same author" pairs for every name partition.
#
# For each `random_part=<i>` partition the loop:
#   1. pairs names that share a block and the same `letters_in_name` value,
#   2. keeps only the pairs accepted by `check_block_vs_block`,
#   3. maps the surviving search names back to raw `original_author` strings,
#   4. expands those to pairs of distinct work IDs ("<paper_id>_<author_sequence_number>").
# Each step is written to parquet under block_creation/ and read back by the next one.

block_creation_path = f"{iteration_save_path}final_model_data/block_creation/"

spark.catalog.clearCache()

# ---------------------------------------------------------------------------
# Partition-independent lookup tables.
# These do not depend on `part_num`, so they are built and cached once here
# instead of being re-read (and the name UDF re-run) on every iteration.
# ---------------------------------------------------------------------------

# reading in author tables
author_names = spark.read.parquet(f"{iteration_save_path}final_model_data/all_authors_for_each_work_indexed") \
    .withColumn('transformed_search_name', transform_name_for_search(F.col('author_name')))

author_names.cache().count()

# raw author string -> search name; shared base for the "_1" and "_2" views below
name_transformations = spark.read \
    .parquet(f"{iteration_save_path}final_model_data/author_name_transformations") \
    .select('original_author', F.col('transformed_name').alias('author_name')) \
    .join(author_names, how='inner', on='author_name') \
    .select(F.col('original_author').alias('nt_original_author'),
            F.col('transformed_search_name').alias('nt_search_name'))

name_transformations.cache().count()

name_transformations_1 = name_transformations \
    .select(F.col('nt_original_author').alias('original_author_1'), F.col('nt_search_name').alias('name_to_match_1'))

name_transformations_2 = name_transformations \
    .select(F.col('nt_original_author').alias('original_author_2'), F.col('nt_search_name').alias('name_to_match_2'))

# work ID + trimmed raw author string; shared base for the "_1" and "_2" views below
aff_data = spark.read.parquet(f"{base_save_path}static_affiliations") \
    .select(F.concat_ws("_", F.col('paper_id'), F.col('author_sequence_number')).alias('aff_work_id'),
            F.trim(F.col('original_author')).alias('aff_original_author')) \
    .filter(F.col('aff_original_author')!="")

aff_data.cache().count()

aff_data_1 = aff_data \
    .select(F.col('aff_work_id').alias('work_id_1'), F.col('aff_original_author').alias('original_author_1'))

aff_data_2 = aff_data \
    .select(F.col('aff_work_id').alias('work_id_2'), F.col('aff_original_author').alias('original_author_2'))

for i in range(num_partitions):
    start_time = time.time()
    # NOTE(review): fixed 4-hour shift of the cluster clock for the progress log only --
    # presumably UTC -> a local timezone; confirm.
    print(i, (datetime.now() - timedelta(hours=4)).strftime("%m/%d/%y %H:%M"))
    part_num = i

    names_to_check_path = f"{block_creation_path}names_blocked_to_name_check_parts/random_part={part_num}/"
    candidate_pairs_path = f"{block_creation_path}temp_all_names_blocked_and_matched/random_part={part_num}/"
    matched_names_path = f"{block_creation_path}all_names_blocked_and_matched/random_part={part_num}/"
    matched_work_ids_path = f"{block_creation_path}all_names_blocked_and_matched_work_ids/random_part={part_num}/"

    # reading author dataframes
    # `letters_in_name` = first character left after removing every match of `block`
    # (used as a regex pattern) from the name; empty string when the block is empty.
    author_names_blocked_explode = spark.read.parquet(names_to_check_path) \
        .select('block', F.col('transformed_search_name').alias('name_to_match_1')) \
        .dropDuplicates() \
        .withColumn('check_letter_full', F.expr("regexp_replace(name_to_match_1, block, '')")) \
        .withColumn('letters_in_name', F.when(F.col('block')!='', F.substring(F.col('check_letter_full'), 1, 1)).otherwise("")) \
        .select('block','name_to_match_1', 'letters_in_name')

    author_names_blocked_explode.cache().count()

    # reading in second version (uses the `letters_in_name` column stored in the data)
    author_names_blocked = spark.read.parquet(names_to_check_path) \
        .select('block', F.col('transformed_search_name').alias('name_to_match_2'), 'letters_in_name')

    author_names_blocked.cache().count()

    # writing temp data to file; `string_rep` is order-independent, so each unordered
    # pair of names is kept once per block
    author_names_blocked_explode.join(author_names_blocked, how='inner', on=['block','letters_in_name']) \
        .select('block','name_to_match_1','name_to_match_2') \
        .withColumn('string_rep', string_of_both_names(F.col('name_to_match_1'), F.col('name_to_match_2'))) \
        .dropDuplicates(subset=['block','string_rep']) \
        .select('block','name_to_match_1','name_to_match_2') \
        .write.mode('overwrite') \
        .parquet(candidate_pairs_path)

    # reading in temp data and getting match lists created and checked
    temp_names_blocked = spark.read.parquet(candidate_pairs_path) \
        .withColumn('name_match_list_1', get_name_match_list(F.col('name_to_match_1'))) \
        .withColumn('name_match_list_2', get_name_match_list(F.col('name_to_match_2'))) \
        .withColumn('matched_names', check_block_vs_block(F.col('name_match_list_1'), F.col('name_match_list_2'))) \
        .filter(F.col('matched_names')==1) \
        .select('block','name_to_match_1','name_to_match_2')

    temp_names_blocked.cache().count()

    # getting names saved with raw author name text
    temp_names_blocked \
        .join(name_transformations_1, how='inner', on='name_to_match_1') \
        .join(name_transformations_2, how='inner', on='name_to_match_2') \
        .select('block','original_author_1','original_author_2') \
        .write.mode('overwrite') \
        .parquet(matched_names_path)

    # attaching work IDs
    for_attach_work_ids = spark.read.parquet(matched_names_path)
    for_attach_work_ids.cache().count()

    for_attach_work_ids \
        .join(aff_data_1, how='inner', on='original_author_1') \
        .join(aff_data_2, how='inner', on='original_author_2') \
        .filter(F.col('work_id_1')!=F.col('work_id_2')) \
        .withColumn('string_rep', string_of_both_names(F.col('work_id_1'), F.col('work_id_2'))) \
        .dropDuplicates(subset=['block','string_rep']) \
        .select('block','work_id_1','work_id_2') \
        .write.mode('overwrite') \
        .parquet(matched_work_ids_path)

    # release only this partition's cached frames; the shared lookup tables stay cached
    for partition_frame in (author_names_blocked_explode, author_names_blocked,
                            temp_names_blocked, for_attach_work_ids):
        partition_frame.unpersist()

    print(f"-------------- total time: {round((time.time()-start_time)/60/60, 3)} hours")