In [0]:
import pickle
import boto3
import re
import json
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt



In [0]:
# Spark SQL building blocks used by the cells in this notebook.
from pyspark.sql import SparkSession, SQLContext
from pyspark.sql import functions as F
from pyspark.sql.types import (
    ArrayType,
    DoubleType,
    FloatType,
    IntegerType,
    StringType,
    StructField,
    StructType,
)
from pyspark.sql.window import Window

# `spark` is never created in the visible cells; presumably the notebook runtime injects it.
sc = spark.sparkContext
sqlContext = SQLContext(sc)



In [0]:
# S3 prefix of what looks (from the path) like an OpenAlex snapshot copy; it is not
# referenced by any of the cells visible here.
base_save_path = "s3://openalex-data-copy/snapshot_2023_02_15/"
# S3 prefix that every parquet read and write in the visible cells is rooted at.
iteration_save_path = "s3://author-disambiguation/V3/"

In [0]:
def get_fake_raw_input_(diff_lang_name):
    """Produce a synthetic "raw input" spelling of a name by sometimes reordering its tokens.

    The name is trimmed and split on whitespace. The trimmed name is returned
    unchanged about 70% of the time; otherwise its tokens are reordered:

    * two tokens   -- swapped ("A B" -> "B A") with probability 0.3;
    * three tokens -- last token moved to the front ("A B C" -> "C A B") with
      probability 0.15, or first token moved to the end ("A B C" -> "B C A")
      with probability 0.15;
    * any other token count -- always returned unchanged.

    Args:
        diff_lang_name: Name string, or None (a null column value reaches a
            Python UDF as None).

    Returns:
        The trimmed or reordered name; "" when the input is None or blank.
    """
    # A null row would otherwise raise AttributeError on .strip() and abort the job.
    if diff_lang_name is None:
        return ""

    stripped_name = diff_lang_name.strip()
    # split() without an argument splits on any run of whitespace, so repeated spaces
    # no longer create empty tokens (which inflated the token count and leaked
    # leading/trailing spaces into the reordered output).
    tokens = stripped_name.split()

    if len(tokens) == 2:
        first, second = tokens
        return stripped_name if random.random() < 0.7 else f"{second} {first}"

    if len(tokens) == 3:
        first, middle, last = tokens
        draw = random.random()
        if draw < 0.7:
            return stripped_name
        if draw < 0.85:
            return f"{last} {first} {middle}"
        return f"{middle} {last} {first}"

    # Blank, single-token and 4+-token names are passed through untouched.
    return stripped_name
    
# The wrapped function draws from random.random(), so flag the UDF as non-deterministic:
# Spark treats UDFs as deterministic by default and may otherwise drop or repeat calls.
get_fake_raw_input = F.udf(get_fake_raw_input_, StringType()).asNondeterministic()

In [0]:
# Main training pairs; cache() returns the same DataFrame, so the name stays bound to it.
all_train_data = (
    spark.read
    .parquet(f"{iteration_save_path}all_processed_data_for_model/all_training_data")
    .cache()
)
# count() is the action that actually populates the cache (and displays the row count).
all_train_data.count()

Out[5]: 222619352

In [0]:
# Rows from the "..._other_languages" dataset: the original author string becomes the
# target name, and the UDF generates a synthetic raw spelling for it.
all_train_data_diff_langs = (
    spark.read
    .parquet(f"{iteration_save_path}all_processed_data_for_model/all_training_data_other_languages")
    .select(F.col('original_author').alias('output_name'))
    .withColumn('raw_input', get_fake_raw_input(F.col('output_name')))
    .where(F.col('raw_input') != "")
)

# Materialize the cached copy and display how many rows survived the filter.
all_train_data_diff_langs.cache().count()

Out[6]: 5715600

In [0]:
# Put the synthetic rows into the column order of the main set (union is positional),
# append them, drop exact duplicate rows and persist the combined dataset.
(
    all_train_data
    .union(all_train_data_diff_langs.select(*all_train_data.columns))
    .distinct()
    .write
    .mode('overwrite')
    .parquet(f"{iteration_save_path}name_embedding_training_data/all_data")
)

In [0]:
# Reload the combined set, keep only the (raw spelling, trimmed target name) pair and
# de-duplicate again, since trimming can make previously distinct rows identical.
all_data = (
    spark.read
    .parquet(f"{iteration_save_path}name_embedding_training_data/all_data")
    .select('raw_input', F.trim(F.col('output_name')).alias('output_name'))
    .distinct()
    .cache()
)
# Action that fills the cache and displays the number of distinct pairs.
all_data.count()

Out[11]: 224520730

In [0]:
# Number of distinct target names.
all_data.select('output_name').distinct().count()

Out[12]: 54682278

In [0]:
# Same count after trimming output_name once more; a match with the untrimmed count
# shows that no stray leading/trailing whitespace is left in the column.
all_data.select(F.trim(F.col('output_name')).alias('output_name')).distinct().count()

Out[13]: 54682278

In [0]:
# Eyeball a tiny random sample of (raw_input, output_name) pairs without truncation.
all_data.sample(fraction=0.00001).show(n=50, truncate=False)

+-------------------------------+-------------------------------+
|raw_input                      |output_name                    |
+-------------------------------+-------------------------------+
|Sara Wallin                    |Sara Wallin                    |
|James A. Pope Marine           |James A. Pope Marine           |
|Rodríguez de Ancos Marcos      |Marcos Rodríguez De Ancos      |
|JVD Bilt                       |J. V. D. Bilt                  |
|JJ Guo                         |J. J Guo                       |
|Paul Mange Johansen            |Paul Mange Johansen            |
|Desai, Bhalloo                 |Bhalloo Desai                  |
|Rudd, Jeffrey A.               |Jeffrey A. Rudd                |
|L. Telezhenko                  |L. Telezhenko                  |
|Kang, Seong-Oun                |Seong-Oun Kang                 |
|Razel Carol                    |Carol Razel                    |
|Punjabi V.                     |Vina Punjabi                   |
|Amraeiniy

In [0]:
# Window over the whole dataset (no partitionBy) in random order.
# NOTE(review): Spark evaluates an unpartitioned window on a single partition, and the
# rand() ordering is unseeded, so labels differ between runs -- confirm both are acceptable.
w1 = Window.orderBy(F.rand())

# Give every distinct target name an integer label: its 1-based position in the
# random ordering. The result is persisted so later cells can read stable labels.
(
    all_data
    .select('output_name')
    .distinct()
    .withColumn('name_label', F.row_number().over(w1))
    .write
    .mode('overwrite')
    .parquet(f"{iteration_save_path}name_embedding_training_data/name_labels")
)

In [0]:
# Read the persisted labels back; cache() returns the same DataFrame object.
name_labels = (
    spark.read
    .parquet(f"{iteration_save_path}name_embedding_training_data/name_labels")
    .cache()
)
# Populate the cache and display the number of labelled names.
name_labels.count()

Out[29]: 54682278

In [0]:
# Split the label table (not the raw pairs) with weights 0.9995 / 0.0003 / 0.0002 and a
# fixed seed. Later cells join each split back to all_data on output_name, so all
# spellings of a given name land in the same split.
train, val, test = name_labels.randomSplit([0.9995, 0.0003, 0.0002], seed=0)

In [0]:
# Build the training split: one row per labelled name, holding all of its raw spellings.
(
    all_data
    .join(train, how='inner', on='output_name')
    .groupBy(['output_name', 'name_label'])
    .agg(F.collect_list(F.col('raw_input')).alias('raw_input'))
    # collect_list gives no ordering guarantee after the groupBy shuffle, so a global
    # orderBy(F.rand()) before it is an expensive sort with no dependable effect.
    # Shuffling the collected array randomizes the spelling order reliably.
    .withColumn('raw_input', F.shuffle(F.col('raw_input')))
    # repartition rather than coalesce: coalesce would also cap the parallelism of the
    # stage that computes the join/aggregation, not just the number of output files.
    .repartition(50)
    .write
    .mode('overwrite')
    .parquet(f"{iteration_save_path}name_embedding_training_data/train_data")
)

In [0]:
# Build the validation split: one row per labelled name, holding all of its raw spellings.
(
    all_data
    .join(val, how='inner', on='output_name')
    .groupBy(['output_name', 'name_label'])
    .agg(F.collect_list(F.col('raw_input')).alias('raw_input'))
    # collect_list gives no ordering guarantee after the groupBy shuffle, so a global
    # orderBy(F.rand()) before it is an expensive sort with no dependable effect.
    # Shuffling the collected array randomizes the spelling order reliably.
    .withColumn('raw_input', F.shuffle(F.col('raw_input')))
    # repartition rather than coalesce: coalesce would also cap the parallelism of the
    # stage that computes the join/aggregation, not just the number of output files.
    .repartition(5)
    .write
    .mode('overwrite')
    .parquet(f"{iteration_save_path}name_embedding_training_data/val_data")
)

In [0]:
# Build the test split: one row per labelled name, holding all of its raw spellings.
(
    all_data
    .join(test, how='inner', on='output_name')
    .groupBy(['output_name', 'name_label'])
    .agg(F.collect_list(F.col('raw_input')).alias('raw_input'))
    # collect_list gives no ordering guarantee after the groupBy shuffle, so a global
    # orderBy(F.rand()) before it is an expensive sort with no dependable effect.
    # Shuffling the collected array randomizes the spelling order reliably.
    .withColumn('raw_input', F.shuffle(F.col('raw_input')))
    # repartition rather than coalesce: coalesce(1) would pull the stage that computes
    # the join/aggregation down to a single task, not just produce a single file.
    .repartition(1)
    .write
    .mode('overwrite')
    .parquet(f"{iteration_save_path}name_embedding_training_data/test_data")
)

In [0]:
# Row count of the written training split; the write groups by (output_name, name_label),
# so this is the number of labelled names in that split.
spark.read.parquet(f"{iteration_save_path}name_embedding_training_data/train_data").count()

Out[34]: 54655132

In [0]:
# Eyeball a random ~10% sample of the validation split (first 20 rows, default truncation).
(
    spark.read
    .parquet(f"{iteration_save_path}name_embedding_training_data/val_data")
    .sample(fraction=0.1)
    .show(n=20)
)

+--------------------+----------+--------------------+
|         output_name|name_label|           raw_input|
+--------------------+----------+--------------------+
|             A Bizze|  35908451|[A Bizze, Bizze A...|
|            A Florou|  26569623|[Florou, A, Ageli...|
|           A Gaganov|  52498660|[Gaganov A., A Ga...|
|            A P Sime|  46836439|[Sime, A P, Sime ...|
|            A Paluka|   2803044|[A. Paluka, Paluk...|
|       A T. Sillitoe|  35405850|[A T. Sillitoe, A...|
| A. Cancino Marambio|  51198746|[A. Cancino Maram...|
|        A. F. Norman|  30860599|[Norman A. F., Al...|
| A. G. Tereshchenkov|  23412116|[Tereshchenkov A....|
|A. Jean-Antoine P...|   2184430|[Jean-Antoine Pic...|
|        A. M. Nassef|  24222910|[Nassef, A. M., N...|
|         A. Mikheeva|  25468706|[A Mikheeva, A MI...|
|       A. P. Bhaduri|  25307058|[A. P. BHADURI, A...|
|A. P. Yepes Quintero|  12081211|[AP Yepes Quinter...|
|       A. Paul Press|    338995|[A. P. Press, A. ...|
|      A. 

In [0]:
# Row count of the written test split; the write groups by (output_name, name_label),
# so this is the number of labelled names in that split.
spark.read.parquet(f"{iteration_save_path}name_embedding_training_data/test_data").count()

Out[37]: 10791