In [0]:
import pickle
import boto3
import re
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt



In [0]:
from pyspark.sql import SparkSession
sc = spark.sparkContext
from pyspark.sql import SQLContext
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import IntegerType, StringType, FloatType, ArrayType, DoubleType, StructType, StructField
sqlContext = SQLContext(sc)



In [0]:
# These paths should be changed to wherever you want to save the general data and where you want to save
# iteration specific data
base_save_path = "./current_directory/"
iteration_save_path = "./current_directory/institutional_affiliation_classification/"

### Getting all data (From saved OpenAlex DB snapshot)

In [0]:
institutions = spark.read.parquet(f"{base_save_path}static_institutions") \
    .filter(F.col('ror_id')!='')

In [0]:
institutions.cache().count()

Out[5]: 102392

In [0]:
affiliations = spark.read.parquet(f"{base_save_path}static_affiliations")

In [0]:
affiliations.cache().count()

Out[5]: 634179075

#### Getting ROR aff strings

In [0]:
dedup_affs = affiliations.select(F.trim(F.col('original_affiliation')).alias('original_affiliation'), 'affiliation_id')\
.filter(F.col('original_affiliation').isNotNull())\
.filter(F.col('original_affiliation')!='')\
.withColumn('aff_len', F.length(F.col('original_affiliation')))\
.filter(F.col('aff_len')>2)\
.groupby(['original_affiliation','affiliation_id']) \
.agg(F.count(F.col('affiliation_id')).alias('aff_string_counts'))

In [0]:
dedup_affs.cache().count()

Out[7]: 71022311

In [0]:
ror_data = spark.read.parquet(f"{iteration_save_path}ror_strings.parquet") \
.select('original_affiliation','affiliation_id')

In [0]:
ror_data.cache().count()

Out[12]: 9545485

### Gathering training data

Since we are looking at all institutions, we need to up-sample the institutions that don't have many affiliation strings and down-sample the institutions that have large numbers of strings. There was a balance here that needed to be acheived. The more samples that are taken for each institution, the more overall training data we will have and the longer our model will take to train. However, more samples also means more ways of an institution showing up in an affiliation string. The number of samples was set to 50 as it was determined this was a good optimization point based on affiliation string count distribution and time it would take to train the model. However, unlike in V1 where we tried to keep all institutions at 50, for V2 we gave additional samples for institutions with more strings available. Specifically, we allowed those institutions to have up to 25 additional strings, for a total of 75.

In [0]:
num_samples_to_get = 50 

In [0]:
w1 = Window.partitionBy('affiliation_id')

filled_affiliations = dedup_affs \
    .join(ror_data.select('affiliation_id'), how='inner', on='affiliation_id') \
    .select('original_affiliation','affiliation_id') \
    .union(ror_data.select('original_affiliation','affiliation_id')) \
    .filter(~F.col('affiliation_id').isNull()) \
    .dropDuplicates() \
    .withColumn('random_prob', F.rand(seed=20)) \
    .withColumn('id_count', F.count(F.col('affiliation_id')).over(w1)) \
    .withColumn('scaled_count', F.lit(1)-((F.col('id_count') - F.lit(num_samples_to_get))/(F.lit(3500000) - F.lit(num_samples_to_get)))) \
    .withColumn('final_prob', F.col('random_prob')*F.col('scaled_count'))

In [0]:
filled_affiliations.select('affiliation_id').distinct().count()

Out[38]: 102392

In [0]:
less_than = filled_affiliations.dropDuplicates(subset=['affiliation_id']).filter(F.col('id_count') < num_samples_to_get).toPandas()
less_than.shape

Out[39]: (29482, 6)

In [0]:
less_than.sample(10)

Unnamed: 0,original_affiliation,affiliation_id,random_prob,id_count,scaled_count,final_prob
18956,"Kryton International (Canada), Vancouver, Brit...",4210107101,0.229724,44,1.000002,0.229724
23769,Prague Security Studies Institute Hlavni mesto...,2801162436,0.429828,36,1.000004,0.42983
3498,"Griffin Foundation, Naples, FL, USA.",4210097069,0.579718,16,1.00001,0.579723
1590,"SPDI, Shenzhen, Guangdong",4210153611,0.794875,35,1.000004,0.794879
8944,"Heart of Passion, United States",4210094562,0.812761,13,1.000011,0.81277
8078,Research Center Conoship International (Nether...,4210151141,0.29833,23,1.000008,0.298333
29430,Axone (Switzerland),4210159503,0.824708,24,1.000007,0.824714
13317,Stiftung Berliner Sparkasse,4210098415,0.639455,12,1.000011,0.639462
5702,"Shwachman Diamond Syndrome Foundation, United ...",4210139487,0.981985,13,1.000011,0.981995
8607,Institute of Electronic Business Berlin Germany,4210086833,0.506604,17,1.000009,0.506609


In [0]:
temp_df_list = []
for aff_id in less_than['affiliation_id'].unique():
    temp_df = less_than[less_than['affiliation_id']==aff_id].copy()
    help_df = temp_df.sample(num_samples_to_get - temp_df.shape[0], replace=True)
    temp_df_list.append(pd.concat([temp_df, help_df], axis=0))
less_than_df = pd.concat(temp_df_list, axis=0)

In [0]:
less_than_df.shape

Out[42]: (1474100, 6)

In [0]:
# only install fsspec and s3fs
less_than_df[['original_affiliation', 'affiliation_id']].to_parquet(f"{iteration_save_path}lower_than_{num_samples_to_get}.parquet")

In [0]:
w1 = Window.partitionBy('affiliation_id').orderBy('random_prob')

more_than = filled_affiliations.filter(F.col('id_count') >= num_samples_to_get) \
.withColumn('row_number', F.row_number().over(w1)) \
.filter(F.col('row_number') <= num_samples_to_get+25)

In [0]:
more_than.cache().count()

Out[46]: 5250661

In [0]:
more_than.select('original_affiliation', 'affiliation_id') \
.coalesce(1).write.mode('overwrite').parquet(f"{iteration_save_path}more_than_{num_samples_to_get}")