In [0]:
import pickle
import boto3
import re
import json
import pandas as pd
import unicodedata
pd.set_option('display.max_colwidth', None)
import numpy as np
import matplotlib.pyplot as plt

In [0]:
from pyspark.sql import functions as F
from pyspark.sql.functions import udf
from pyspark.sql.window import Window
from pyspark.sql.types import IntegerType, StringType, FloatType, ArrayType, DoubleType, StructType, StructField,LongType

In [0]:
base_save_path = "{save_path_for_openalex_tables}"
iteration_save_path = "{save_path_for_most_data}"
model_save_path = "{higher_level_iteration_save_path}"
# Location of the CWTS topic tables (topic_labels_data_from_cwts_new / topics_data_from_cwts_new).
# It is read in the cells below but was never defined in this notebook, so a fresh
# "Restart & Run All" failed with NameError. Fill in like the other placeholders.
model_data_path = "{save_path_for_model_data}"

### Getting all data

In [0]:
# CWTS topic labels (short/long label, keywords, summary, ...) keyed by micro_cluster_id.
classification_labels = spark.read.parquet(f'{model_data_path}topic_labels_data_from_cwts_new').cache()
classification_labels.count()

4521

In [0]:
classification_labels.sample(fraction=0.1).show(n=1, truncate=False)

+----------------+-----------------+-------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------+
|micro_cluster_id|short_label      |long_label                           |keywords                                                                                                                                          |summary                                                                                      

In [0]:
# Mapping from a merged-away work id to the id it was merged into (one row per paper_id).
# NOTE(review): `iteration_save_path[:-15]` strips a fixed-length suffix, presumably to reach a
# parent directory; this breaks silently if the path layout changes.
work_merges = (
    spark.read.parquet(f"{iteration_save_path[:-15]}/work_merges_dist")
    .select('paper_id', F.col('merge_into_id').alias('paper_id_merge_into'))
    .dropDuplicates(subset=['paper_id'])
)
work_merges.cache().count()

12559467

In [0]:
# Paper -> CWTS cluster assignments (macro/meso/micro), joined to the topic labels. Papers that were
# merged into another work are re-pointed at the id they were merged into.
new_topic_labels = (
    spark.read.parquet(f'{model_data_path}topics_data_from_cwts_new')
    .select(
        F.col('work_id').cast(LongType()).alias('paper_id'),
        *[F.col(cluster_col).cast(IntegerType())
          for cluster_col in ('macro_cluster_id', 'meso_cluster_id', 'micro_cluster_id')],
    )
    # Drop rows missing the paper id or any cluster level.
    .dropna(how='any', subset=['paper_id', 'macro_cluster_id', 'meso_cluster_id', 'micro_cluster_id'])
    .join(classification_labels, on='micro_cluster_id', how='inner')
    .join(work_merges, on='paper_id', how='left')
    .select(
        'micro_cluster_id',
        F.coalesce(F.col('paper_id_merge_into'), F.col('paper_id')).alias('paper_id'),
        'macro_cluster_id', 'meso_cluster_id',
        'short_label', 'long_label', 'keywords', 'summary', 'wikipedia_url',
    )
)

new_topic_labels.cache().count()

70674439

In [0]:
# Smallest / largest number of labeled papers per micro-cluster.
new_topic_labels.groupBy('micro_cluster_id').count().agg(F.max('count'), F.min('count')).show()

+----------+----------+
|max(count)|min(count)|
+----------+----------+
|    167455|      1000|
+----------+----------+



### Getting Gold Citations

To create features based on the citation graph, we created "gold" citations: essentially the most highly-cited papers for each micro_cluster_id (topic). To get the gold citations, we grouped all papers by their topic label (from CWTS), grouped the papers they reference, and ranked each referenced paper by how many times it was cited from within that topic. We took the top X from each topic and called those our gold citations. We also removed papers that were highly cited across many different topics, using the `within_clust_per` variable (the share of a referenced paper's citations that come from the topic) calculated below.

For the actual features that are going into the model, we created lists of gold citations for each paper that were either directly cited (1 edge away) or cited in a reference paper (2 edges away). The gold citations will be fed into a neural net as an embedding so that the model can learn a representation of the gold citation.

In [0]:
@udf(returnType=IntegerType())
def get_number_of_core_citations_for_cluster(citation_count, min_cluster_size=1000, max_cluster_size=167455,
                                             min_core_cites=25, max_core_cites=75):
    """
    Calculate how many gold citations a topic can have, based on the amount of labeled data
    available for it (i.e. the size of the cluster). Larger clusters get more gold citations.

    The threshold is linearly interpolated between `min_core_cites` (for a cluster of
    `min_cluster_size` papers) and `max_core_cites` (for `max_cluster_size` papers), then
    truncated to an int. The default sizes are the smallest and largest micro-cluster sizes
    observed when this notebook was run.

    Parameters
    ----------
    citation_count : int or None
        Number of labeled papers in the cluster.
    min_cluster_size, max_cluster_size : int
        Cluster sizes mapped to `min_core_cites` and `max_core_cites`.
    min_core_cites, max_core_cites : int
        Range of the returned threshold.

    Returns
    -------
    int or None
        Threshold in [min_core_cites, max_core_cites], or None for a null count.
    """
    if citation_count is None:
        return None

    size_span = max_cluster_size - min_cluster_size
    scale_per = (citation_count - min_cluster_size) / size_span if size_span else 1.0
    # Clamp so clusters outside the reference size range still get a threshold inside
    # [min_core_cites, max_core_cites] instead of a negative or oversized one.
    scale_per = min(max(scale_per, 0.0), 1.0)

    return int(min_core_cites + (max_core_cites - min_core_cites) * scale_per)

In [0]:
# Number of labeled papers per micro-cluster. This is the same aggregation as the cluster-size check
# above, whose min/max (1000 / 167455) are the defaults of the threshold UDF. `topic_counts` was
# previously used here without ever being defined in the notebook.
topic_counts = new_topic_labels.groupBy('micro_cluster_id').count()

# Per-topic cap on the number of gold citations.
cite_counts_to_use = topic_counts \
    .withColumn('cite_counts_thresh', get_number_of_core_citations_for_cluster(F.col('count'))) \
    .select('micro_cluster_id', 'cite_counts_thresh')

In [None]:
# Newly added citation edges, deduplicated.
new_citations = (
    spark.read.parquet(f"{iteration_save_path[:-15]}/new_work_citations")
    .dropDuplicates()
)

# All citation edges (static + new), restricted to citing papers that have a topic label.
# `micro_paper` is the micro-cluster of the *citing* paper.
citations = (
    spark.read.parquet(f"{base_save_path}static_citations")
    .union(new_citations.select('paper_id', 'paper_reference_id'))
    .dropDuplicates()
    .join(new_topic_labels, on='paper_id', how='inner')
    .select('paper_id', 'paper_reference_id', F.col('micro_cluster_id').alias('micro_paper'))
)

citations.cache().count()

In [0]:
# Total number of distinct citing papers per referenced paper, across all topics.
# countDistinct avoids materialising a set of ids per reference (as size(collect_set(...)) did);
# the cast keeps the previous int column type.
overall_group_by = (
    citations
    .groupBy('paper_reference_id')
    .agg(F.countDistinct('paper_id').cast(IntegerType()).alias('total_citations'))
)

overall_group_by.cache().count()

75675000

In [0]:
# Per (reference, citing topic): number of distinct citing papers from that topic, and the share of
# the reference's total citations that come from that topic (`within_clust_per`).
# countDistinct replaces size(collect_set(...)); the cast keeps `num_citations` an int, because it
# is written out with the gold citations.
grouped_citations = (
    citations
    .groupBy(['paper_reference_id', 'micro_paper'])
    .agg(F.countDistinct('paper_id').cast(IntegerType()).alias('num_citations'))
    .join(overall_group_by, on='paper_reference_id', how='inner')
    .withColumn('within_clust_per', F.col('num_citations') / F.col('total_citations'))
)

grouped_citations.cache().count()

368318699

In [0]:
# Rank each referenced paper within a (citing) topic by how many papers of that topic cite it.
# row_number breaks ties arbitrarily.
w1 = Window.partitionBy(['micro_cluster_id']).orderBy(F.col('num_citations').desc())

# Candidate gold citations: ranked references, the per-topic cap, and only references for which
# more than 5% of their citations come from this topic.
# NOTE(review): the rank is assigned *before* the `within_clust_per > 0.05` filter. References
# removed by the filter still occupy rank slots, so the later `cluster_citation_rank <= cite_counts_thresh`
# cut can keep fewer than `cite_counts_thresh` gold citations for a topic. Confirm this is intended;
# otherwise apply the filter before ranking.
gold_citations = grouped_citations \
    .select('paper_reference_id', F.col('micro_paper').alias('micro_cluster_id'), 
            'num_citations', 'within_clust_per') \
    .withColumn('cluster_citation_rank', F.row_number().over(w1)) \
    .join(cite_counts_to_use, how='inner', on='micro_cluster_id') \
    .filter(F.col('within_clust_per')>0.05)

gold_citations.cache().count()

182468281

The following code takes only the top X gold citations for each topic (X is determined by the function above). In some cases the threshold is greater than the number of gold citations available for a topic; in that case we used all of the gold citations for that topic.

In [0]:
# Keep the top `cite_counts_thresh` references per topic (all of them if the topic has fewer).
(
    gold_citations
    .where(F.col('cluster_citation_rank') <= F.col('cite_counts_thresh'))
    .select(F.col('paper_reference_id').alias('gold_citation'), 'micro_cluster_id',
            'num_citations', 'within_clust_per')
    .write.mode('overwrite')
    .parquet(f'{iteration_save_path}gold_citation_papers')
)

In [0]:
# Same gold citations, rewritten as a single parquet file (one partition).
(
    spark.read.parquet(f'{iteration_save_path}gold_citation_papers')
    .coalesce(1)
    .write.mode('overwrite')
    .parquet(f'{iteration_save_path}gold_citation_papers_single_file')
)

### Using Gold Citations to Create Features

##### Looking at the total number of gold citations

In [0]:
# One row per distinct gold citation. A reference that is gold for several topics keeps whichever
# row dropDuplicates retains.
gold_citation_papers = (
    spark.read.parquet(f'{iteration_save_path}gold_citation_papers')
    .dropDuplicates(subset=['gold_citation'])
    .select(F.col('gold_citation').alias('paper_reference_id'), 'micro_cluster_id', 'within_clust_per')
)
gold_citation_papers.cache().count()

124577

##### Looking at the minimum and maximum number of gold citations for each topic

In [0]:
# Fewest / most gold citations assigned to a single topic.
gold_citation_papers.groupBy('micro_cluster_id').count().agg(F.min('count'), F.max('count')).show()

+----------+----------+
|min(count)|max(count)|
+----------+----------+
|        11|        75|
+----------+----------+



In [0]:
# Two independently aliased copies of the edge list, used for the one- and two-hop joins below.
citations_1, citations_2 = (
    citations.alias(alias_name).select('paper_id', 'paper_reference_id')
    for alias_name in ('citations_1', 'citations_2')
)

##### Creating level 0 citation feature (gold citation is directly referenced in the paper, i.e., 1 edge away)

In [0]:
# Level 0 = gold citations the paper references directly. The left join keeps papers with no gold
# reference; collect_set skips the resulting nulls, so those papers get an empty list.
level_0 = (
    citations.select('paper_id').dropDuplicates()
    .join(citations_1.join(gold_citation_papers, on='paper_reference_id', how='inner'),
          on='paper_id', how='left')
)

(
    level_0
    .groupBy('paper_id')
    .agg(F.collect_set('paper_reference_id').alias('gold_cites'))
    .write.mode('overwrite')
    .parquet(f'{iteration_save_path}level_0_citation_links')
)

In [0]:
spark.read.format('parquet').load(f'{iteration_save_path}level_0_citation_links').count()

59807180

In [0]:
spark.read.parquet(f'{iteration_save_path}level_0_citation_links').sample(fraction=0.001).show(n=40)

+----------+--------------------+
|  paper_id|          gold_cites|
+----------+--------------------+
| 114003884|                  []|
| 118302687|                  []|
| 149839387|                  []|
| 318348859|                  []|
| 571605520|[2149863032, 2105...|
| 589554927|                  []|
| 614453582|                  []|
| 885093613|                  []|
| 901667683|                  []|
|1198778148|        [2164748597]|
|1491458199|[2088864916, 1965...|
|1513977309|                  []|
|1514179409|[1715752662, 2002...|
|1529113798|                  []|
|1534823655|[2102463240, 5943...|
|1554441443|[4292280330, 3115...|
|1571911914|                  []|
|1595689375|[2264936970, 2160...|
|1901106884|        [2025613329]|
|1939317862|                  []|
|1963923169|                  []|
|1964429642|[2083005745, 2162...|
|1965385091|                  []|
|1966097699|                  []|
|1969820049|[2112355650, 2159...|
|1969850333|[4244280225, 4298...|
|1972924757|[2

In [0]:
level_0_citations = (
    spark.read.parquet(f'{iteration_save_path}level_0_citation_links')
    .select('paper_id', F.col('gold_cites').alias('level_0_links'))
    .cache()
)
level_0_citations.count()

59807180

##### Creating level 1 citation feature (gold citation is cited in a referenced paper, i.e., 2 edges away)

In [0]:
# Level 1 = gold citations referenced by the papers this paper cites (two hops): expand each citing
# paper to its references, then look up the gold citations those references cite.
level_1 = (
    citations.select('paper_id').dropDuplicates()
    .join(citations_1, on='paper_id', how='left')
    .select(F.col('paper_id').alias('original_paper_id'),
            F.col('paper_reference_id').alias('paper_id'))
    .join(citations_2.join(gold_citation_papers, on='paper_reference_id', how='inner'),
          on='paper_id', how='left')
)

(
    level_1
    .groupBy('original_paper_id')
    .agg(F.collect_set('paper_reference_id').alias('gold_cites'))
    .write.mode('overwrite')
    .parquet(f'{iteration_save_path}level_1_citation_links')
)

In [0]:
level_1_citations = (
    spark.read.parquet(f'{iteration_save_path}level_1_citation_links')
    .select(F.col('original_paper_id').alias('paper_id'), F.col('gold_cites').alias('level_1_links'))
    .cache()
)
level_1_citations.count()

59807180

In [0]:
spark.read.parquet(f'{iteration_save_path}level_1_citation_links').sample(fraction=0.001).show(n=40)

+-----------------+--------------------+
|original_paper_id|          gold_cites|
+-----------------+--------------------+
|         55610713|                  []|
|         68910409|[391578156, 19838...|
|        117585249|[4256153703, 1516...|
|        134638304|                  []|
|        175965146|                  []|
|        196092515|[2911865844, 2088...|
|        251029635|[1603121691, 2131...|
|        810677605|[2154820256, 2097...|
|        895993603|                  []|
|       1491467088|[2007180942, 1507...|
|       1507661821|[2053186076, 2040...|
|       1532856579|[2148548654, 2162...|
|       1533791920|[1997252211, 2115...|
|       1546224377|[2058717577, 2152...|
|       1571056597|[2165847290, 2005...|
|       1582319857|[2080274663, 2014...|
|       1584196799|[2089218510, 2134...|
|       1593786842|[2128672031, 2105...|
|       1599515711|                  []|
|       1600996141|[1974718993, 2069...|
|       1623015207|                  []|
|       17167795

In [0]:
@udf(returnType=ArrayType(LongType()))
def move_level_0_to_1(level_0, level_1):
    """
    Merge the level 0 gold citations into the level 1 list, without duplicates.

    Every level 0 citation is also added to level 1, so the level 1 feature has a value whenever
    the level 0 feature is present.

    Parameters
    ----------
    level_0, level_1 : list of int or None
        Gold-citation ids; None for a null array column.

    Returns
    -------
    list of int
        Distinct ids from both inputs (order unspecified).
    """
    # Treat a null array as empty instead of failing on `None + list`.
    return list(set((level_0 or []) + (level_1 or [])))

In [0]:
# Final citation features: level_0 unchanged, level_1 extended with the level_0 links.
(
    level_0_citations
    .join(level_1_citations, on='paper_id', how='inner')
    .select('paper_id', 'level_0_links',
            move_level_0_to_1(F.col('level_0_links'), F.col('level_1_links')).alias('level_1_links'))
    .write.mode('overwrite')
    .parquet(f'{iteration_save_path}all_citation_features')
)

##### Looking to see how much of the training data will have citation features

In [None]:
# Papers with at least one level-0 gold citation.
(
    spark.read.parquet(f'{iteration_save_path}all_citation_features')
    .filter(F.size(F.col('level_0_links')) > 0)
    .count()
)

In [None]:
# Papers with at least one level-1 gold citation.
(
    spark.read.parquet(f'{iteration_save_path}all_citation_features')
    .filter(F.size(F.col('level_1_links')) > 0)
    .count()
)

### Pulling all data into one dataframe

In [0]:
new_works = (
    spark.read.parquet(f"{model_save_path}new_work_titles")
    .dropDuplicates(subset=['paper_id'])
)

# Align the static table's columns with the new-works schema (union matches columns by position),
# append the new works, and keep one row per paper.
works = (
    spark.read.parquet(f"{base_save_path}static_works")
    .select(*new_works.columns)
    .union(new_works)
    .select('paper_id', 'original_title', 'publication_date', 'journal_id')
    .dropDuplicates(subset=['paper_id'])
)


works.cache().count()

247623528

In [0]:
new_abstracts = (
    spark.read.parquet(f"{model_save_path}new_work_abstracts")
    .dropDuplicates(subset=['paper_id'])
)

# Static + new abstracts, one row per paper.
abstracts = (
    spark.read.parquet(f"{base_save_path}static_abstracts")
    .select('paper_id', 'abstract')
    .union(new_abstracts.select('paper_id', 'abstract'))
    .dropDuplicates(subset=['paper_id'])
)

abstracts.cache().count()

133620738

In [0]:
citation_features = spark.read.parquet(f'{iteration_save_path}all_citation_features').cache()
citation_features.count()

59807180

In [0]:
# Training table: every labeled paper with its title metadata, plus citation features and abstract
# where available (left joins, so those columns may be null).
(
    works
    .join(new_topic_labels.select('paper_id', 'micro_cluster_id', 'short_label', 'long_label', 'keywords'),
          on='paper_id', how='inner')
    .join(citation_features, on='paper_id', how='left')
    .join(abstracts, on='paper_id', how='left')
    .write.mode('overwrite')
    .parquet(f"{iteration_save_path}all_training_data/")
)

## Transforming Data

##### The following functions are used to determine if the title can be used and also to transform the text of the title if necessary

In [0]:
def name_to_keep_ind(groups):
    """
    Return 1 when none of `groups` is one of the skipped non-Latin scripts, otherwise 0.

    `groups` is an iterable of script labels such as those produced by group_non_latin_characters
    (e.g. 'CJK', 'DIGIT', 'UNK').
    """
    skipped_scripts = ('HIRAGANA', 'CJK', 'KATAKANA', 'ARABIC', 'HANGUL', 'THAI', 'DEVANAGARI',
                       'BENGALI', 'THAANA', 'GUJARATI', 'CYRILLIC')
    for group in groups:
        if group in skipped_scripts:
            return 0
    return 1
    
def group_non_latin_characters(text):
    """
    Summarise which kinds of characters appear in `text`.

    Dots and spaces are ignored. Each character is labelled by the first word of its Unicode name
    (a rough proxy for its script, e.g. 'LATIN', 'CJK', 'DIGIT'). Characters without a Unicode
    name (such as control characters) are labelled 'UNK'.

    Returns
    -------
    tuple(list of str, int)
        The distinct non-LATIN labels in first-seen order, and the number of LATIN characters.
        None or empty text gives ([], 0).
    """
    groups = []
    latin_count = 0
    if not text:
        return groups, latin_count

    for char in text.replace(".", "").replace(" ", ""):
        # name(char, None) returns None instead of raising ValueError for unnamed characters.
        char_name = unicodedata.name(char, None)
        script = "UNK" if char_name is None else char_name.split(" ")[0]
        if script == "LATIN":
            latin_count += 1
        elif script not in groups:
            groups.append(script)
    return groups, latin_count

def remove_non_latin_characters(text):
    """
    Remove characters that belong to the skipped non-Latin scripts.

    A character is classified by the first word of its Unicode name. Characters whose label is in
    the skip list are dropped. Characters with no Unicode name (e.g. control characters such as
    '\\t' or '\\n') are also dropped. Everything else is kept in order.

    Returns '' for None or empty input.
    """
    if not text:
        return ""

    groups_to_skip = {'HIRAGANA', 'CJK', 'KATAKANA', 'ARABIC', 'HANGUL', 'THAI', 'DEVANAGARI', 'BENGALI',
                      'THAANA', 'GUJARATI', 'CYRILLIC'}
    final_char = []
    for char in text:
        # name(char, None) replaces the bare `except: pass`; unnamed characters are skipped as before.
        char_name = unicodedata.name(char, None)
        if char_name is not None and char_name.split(" ")[0] not in groups_to_skip:
            final_char.append(char)
    return "".join(final_char)

@udf(returnType=IntegerType())
def check_for_non_latin_characters(text, latin_char_threshold=30):
    """
    Spark UDF: return 1 if the title can be used, otherwise 0.

    A title is kept when it contains none of the skipped non-Latin scripts (see name_to_keep_ind),
    or when it still has more than `latin_char_threshold` LATIN characters. A null title returns 0,
    so it is treated like an unusable title instead of failing the job.
    """
    if text is None:
        return 0
    groups, latin_chars = group_non_latin_characters(text)
    if name_to_keep_ind(groups) == 1 or latin_chars > latin_char_threshold:
        return 1
    return 0
    
@udf(returnType=StringType())
def clean_title(old_title, keep_title):
    """
    Spark UDF: return a cleaned version of a title, or '' when `keep_title` is falsy.

    When `keep_title` is truthy:
      1. characters from the skipped non-Latin scripts are removed (remove_non_latin_characters);
      2. if the title contains '<', a fixed list of literal HTML/XML/MathML tags is stripped,
         text is extracted from <mml:math> blocks, and several tag patterns (xref, inline-formula,
         title, <p class=...>, <span class=...>, <mfenced open=...>, <math xmlns=...>) are removed
         with regexes;
      3. leftover fragments such as '>i<' / '>b<' and <inline-formula> tags are removed.

    `old_title` is presumably a non-null string whenever `keep_title` is truthy.
    """
    if keep_title:
        new_title = remove_non_latin_characters(old_title)
        # Markup clean-up is only attempted when the title contains a '<'.
        if '<' in new_title:
            # Strip a fixed list of literal tags, applied left to right. '<em>' and '<title>' are
            # listed twice; the second pass only removes occurrences created by earlier replacements.
            new_title = new_title.replace("<i>", "").replace("</i>","")\
                                 .replace("<sub>", "").replace("</sub>","") \
                                 .replace("<sup>", "").replace("</sup>","") \
                                 .replace("<em>", "").replace("</em>","") \
                                 .replace("<b>", "").replace("</b>","") \
                                 .replace("<I>", "").replace("</I>", "") \
                                 .replace("<SUB>", "").replace("</SUB>", "") \
                                 .replace("<scp>", "").replace("</scp>", "") \
                                 .replace("<font>", "").replace("</font>", "") \
                                 .replace("<inf>","").replace("</inf>", "") \
                                 .replace("<i /> ", "") \
                                 .replace("<p>", "").replace("</p>","") \
                                 .replace("<![CDATA[<B>", "").replace("</B>]]>", "") \
                                 .replace("<italic>", "").replace("</italic>","")\
                                 .replace("<title>", "").replace("</title>", "") \
                                 .replace("<br>", "").replace("</br>","").replace("<br/>","") \
                                 .replace("<B>", "").replace("</B>", "") \
                                 .replace("<em>", "").replace("</em>", "") \
                                 .replace("<BR>", "").replace("</BR>", "") \
                                 .replace("<title>", "").replace("</title>", "") \
                                 .replace("<strong>", "").replace("</strong>", "") \
                                 .replace("<formula>", "").replace("</formula>", "") \
                                 .replace("<roman>", "").replace("</roman>", "") \
                                 .replace("<SUP>", "").replace("</SUP>", "") \
                                 .replace("<SSUP>", "").replace("</SSUP>", "") \
                                 .replace("<sc>", "").replace("</sc>", "") \
                                 .replace("<subtitle>", "").replace("</subtitle>", "") \
                                 .replace("<emph/>", "").replace("<emph>", "").replace("</emph>", "") \
                                 .replace("""<p class="Body">""", "") \
                                 .replace("<TITLE>", "").replace("</TITLE>", "") \
                                 .replace("<sub />", "").replace("<sub/>", "") \
                                 .replace("<mi>", "").replace("</mi>", "") \
                                 .replace("<bold>", "").replace("</bold>", "") \
                                 .replace("<mtext>", "").replace("</mtext>", "") \
                                 .replace("<msub>", "").replace("</msub>", "") \
                                 .replace("<mrow>", "").replace("</mrow>", "") \
                                 .replace("</mfenced>", "").replace("</math>", "")

            # MathML: split the title on the "<mml:math" / "mml:math>" delimiters. For each part that
            # contains `>token<` runs, keep only the token texts (joined, padded with spaces);
            # other parts are kept as-is.
            if '<mml' in new_title:
                all_parts = [x for y in [i.split("mml:math>") for i in new_title.split("<mml:math")] for x in y if x]
                final_parts = []
                for part in all_parts:
                    if re.search(r"\>[$%#!^*\w.,/()+-]*\<", part):
                        pull_out = re.findall(r"\>[$%#!^*\w.,/()+-]*\<", part)
                        final_pieces = []
                        for piece in pull_out:
                            final_pieces.append(piece.replace(">", "").replace("<", ""))
                        
                        final_parts.append(" "+ "".join(final_pieces) + " ")
                    else:
                        final_parts.append(part)
                
                new_title = "".join(final_parts).strip()
            else:
                pass

            # Remove whole elements / opening tags with regexes. NOTE: the element patterns use
            # `[^/]*`, so an element whose content or attributes contain '/' is not matched.
            if '<xref' in new_title:
                new_title = re.sub(r"\<xref[^/]*\/xref\>", "", new_title)

            if '<inline-formula' in new_title:
                new_title = re.sub(r"\<inline-formula[^/]*\/inline-formula\>", "", new_title)

            if '<title' in new_title:
                new_title = re.sub(r"\<title[^/]*\/title\>", "", new_title)

            # Only the opening <p class=...> / <span class=...> tags are removed here.
            if '<p class=' in new_title:
                new_title = re.sub(r"\<p class=[^>]*\>", "", new_title)
            
            if '<span class=' in new_title:
                new_title = re.sub(r"\<span class=[^>]*\>", "", new_title)

            if 'mfenced open' in new_title:
                new_title = re.sub(r"\<mfenced open=[^>]*\>", "", new_title)
            
            if 'math xmlns' in new_title:
                new_title = re.sub(r"\<math xmlns=[^>]*\>", "", new_title)

        # Final pass on any remaining '<': drop '>i<', '>/i<', '>b<', '>/b<' fragments and
        # <inline-formula> tags.
        if '<' in new_title:
            new_title = new_title.replace(">i<", "").replace(">/i<", "") \
                                 .replace(">b<", "").replace(">/b<", "") \
                                 .replace("<inline-formula>", "").replace("</inline-formula>","")

        return new_title
    else:
        return ''

In [0]:
def invert_abstract_to_abstract(invert_abstract, min_len=30, max_len=1000):
    """
    Rebuild an abstract from its JSON-encoded inverted index.

    Parameters
    ----------
    invert_abstract : str
        JSON object with 'IndexLength' (number of token positions) and 'InvertedIndex'
        (mapping token -> list of positions).
    min_len, max_len : int
        Abstracts are only rebuilt when min_len < IndexLength < max_len (exclusive bounds).

    Returns
    -------
    str or None
        Tokens joined with single spaces. Positions not covered by any token become a single " ",
        so the result can contain runs of spaces. Returns None when the length is out of range.
    """
    inverted = json.loads(invert_abstract)
    ab_len = inverted['IndexLength']

    if not (min_len < ab_len < max_len):
        return None

    abstract = [" "] * ab_len
    for token, positions in inverted['InvertedIndex'].items():
        for pos in positions:
            # Skip positions outside [0, ab_len) from malformed indexes. Previously these raised
            # IndexError or, if negative, silently overwrote tokens from the end.
            if 0 <= pos < ab_len:
                abstract[pos] = token
    return " ".join(abstract)

def clean_abstract(abstract, inverted=True):
    """
    Prepare an abstract for the model.

    When `inverted` is true and `abstract` is non-empty, `abstract` is treated as a JSON inverted
    index and rebuilt with invert_abstract_to_abstract. Otherwise it is returned unchanged.
    Text normalisation (clean_text) is currently not applied.
    """
    if not (inverted and abstract):
        return abstract
    return invert_abstract_to_abstract(abstract)

def clean_text(text):
    """
    Normalise text: lower-case it, replace every run of characters other than ASCII letters, digits
    and spaces with a space, collapse repeated spaces, and strip.

    Non-string input (e.g. None) returns ''. This replaces the previous bare `except:`, which hid
    every error.
    """
    if not isinstance(text, str):
        return ""
    text = re.sub('[^a-zA-Z0-9 ]+', ' ', text.lower())
    return re.sub(' +', ' ', text).strip()

In [0]:
@udf(returnType=ArrayType(LongType()))
def get_final_citations_feature(curr_cites):
    """
    Spark UDF: return the citation list unchanged, or [] when it is not a list (e.g. a null array
    from the left join with the citation features).
    """
    return curr_cites if isinstance(curr_cites, list) else []

In [0]:
# Random order within each micro-cluster (by `random_num`), used to sample papers per cluster.
w2 = Window.partitionBy(F.col('micro_cluster_id')).orderBy(F.col('random_num').asc())

In [0]:
spark.read.format('parquet').load(f"{iteration_save_path}all_training_data/").count()

70674439

In [0]:
# Sample SAMPLES_PER_CLUSTER random papers per micro-cluster (window `w2` orders by `random_num`),
# then add the title-usability flag, the cleaned title and null-safe citation features.
# Without a seed, F.rand() draws a different sample on every run. A fixed seed makes the sample
# reproducible for the same input data and partitioning.
SAMPLE_SEED = 42
SAMPLES_PER_CLUSTER = 1000

(
    spark.read.parquet(f"{iteration_save_path}all_training_data/")
    .withColumn('random_num', F.rand(seed=SAMPLE_SEED))
    .withColumn('cluster_rank', F.row_number().over(w2))
    .filter(F.col('cluster_rank') <= SAMPLES_PER_CLUSTER)
    .withColumn('keep_title', check_for_non_latin_characters(F.col('original_title')))
    .withColumn('new_title', clean_title(F.col('original_title'), F.col('keep_title')))
    .withColumn('final_level_0_links', get_final_citations_feature(F.col('level_0_links')))
    .withColumn('final_level_1_links', get_final_citations_feature(F.col('level_1_links')))
    .write.mode('overwrite')
    .parquet(f"{iteration_save_path}all_training_data_subset/")
)

In [0]:
spark.read.format('parquet').load(f"{iteration_save_path}all_training_data_subset/").count()

4521000

##### Checking to see how big the citation features need to be in order to make use of most of the data

In [0]:
# Lengths of the citation features, used below to choose feature sizes.
# NOTE(review): the coverage checks below test level_0 < 16 and level_1 < 128, but an earlier note
# here said "8 max len for level_0". Confirm which length the model actually uses.
# Use the null-safe `final_level_*_links` columns. The raw `level_*_links` are null for papers
# without citation features (left join), and F.size(null) returns -1 or null depending on Spark
# configuration.
link_features = (
    spark.read.parquet(f"{iteration_save_path}all_training_data_subset/")
    .select(F.size(F.col('final_level_0_links')).alias('cite_0_len'),
            F.size(F.col('final_level_1_links')).alias('cite_1_len'))
    .toPandas()
)

In [0]:
len(link_features), link_features.shape[1]

(4521000, 2)

In [0]:
int((link_features['cite_0_len'] < 16).sum()) / len(link_features)

0.9968918380889183

In [0]:
int((link_features['cite_1_len'] < 128).sum()) / len(link_features)

0.9675266533952666

##### Saving training data to single file so that it can be used on a single machine

In [0]:
# Columns kept for single-machine training, written as one parquet file (one partition).
(
    spark.read.parquet(f"{iteration_save_path}all_training_data_subset/")
    .select(
        'paper_id', 'publication_date', 'journal_id',
        'final_level_0_links', 'final_level_1_links',
        'abstract', 'micro_cluster_id', 'short_label', 'long_label', 'keywords',
        'new_title',
    )
    .coalesce(1)
    .write.mode('overwrite')
    .parquet(f"{iteration_save_path}all_training_data_subset_single_file/")
)