In [0]:
import pickle
import boto3
import re
import json
import pandas as pd
import unicodedata
pd.set_option('display.max_colwidth', None)
import numpy as np
import matplotlib.pyplot as plt

In [0]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import IntegerType, StringType, FloatType, ArrayType, DoubleType, StructType, StructField,LongType

In [0]:
# Root location of the shared tables read further down (`static_works`, `static_abstracts`).
# NOTE(review): the braces look like a template placeholder; replace the whole string with a real
# path before running. Table names are appended directly to this value, so it must end with a
# path separator.
base_save_path = "{save_path_for_openalex_tables}"
# Root location for this iteration's inputs and outputs (topic tables, new titles/abstracts and
# the `language_model/...` datasets). Same placeholder and trailing-separator convention as above.
iteration_save_path = "{save_path_for_most_data}"

### Getting all data

In [0]:
# Topic label table; it is filtered on `long_label` and joined on `micro_cluster_id` further down.
classification_labels = spark.read.parquet(
    f'{iteration_save_path}topic_labels_data_from_cwts_new'
)

# Mark the table for caching and run a count, which is also shown as the cell output.
classification_labels.cache().count()

4521

In [0]:
# Spot-check: print, untruncated, every label row whose `long_label` contains 'Machine Learning'.
(
    classification_labels
    .where(F.col('long_label').contains('Machine Learning'))
    .show(truncate=False)
)

+----------------+--------------------------------+--------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------+
|micro_cluster

In [0]:
# Paper -> cluster assignments: cast the ids, keep only rows where all four ids are non-null
# after casting, then attach the label columns via an inner join on `micro_cluster_id`.
new_topic_labels = (
    spark.read.parquet(f'{iteration_save_path}topics_data_from_cwts_new')
    .select(
        F.col('work_id').cast(LongType()).alias('paper_id'),
        F.col('macro_cluster_id').cast(IntegerType()),
        F.col('meso_cluster_id').cast(IntegerType()),
        F.col('micro_cluster_id').cast(IntegerType()),
    )
    .where(F.col('paper_id').isNotNull())
    .where(F.col('macro_cluster_id').isNotNull())
    .where(F.col('meso_cluster_id').isNotNull())
    .where(F.col('micro_cluster_id').isNotNull())
    .join(classification_labels, on='micro_cluster_id', how='inner')
)

# Cache and count; the count is the cell output.
new_topic_labels.cache().count()

70674439

#### Creating Dataset

In [0]:
# Titles of the works added in this iteration, one row per paper_id.
new_works = (
    spark.read.parquet(f"{iteration_save_path}new_work_titles")
    .dropDuplicates(subset=['paper_id'])
)

# Static titles stacked with the new ones, again reduced to one row per paper_id.
# NOTE(review): when a paper_id occurs in both sources, dropDuplicates does not specify which
# title survives -- confirm that either one is acceptable.
works = (
    spark.read.parquet(f"{base_save_path}static_works")
    .select('paper_id', 'original_title')
    .union(new_works.select('paper_id', 'original_title'))
    .dropDuplicates(subset=['paper_id'])
)

# Cache and count; the count is the cell output.
works.cache().count()

247622936

In [0]:
# Sanity check: largest paper_id present in the combined titles table.
works.agg(F.max('paper_id')).show()

+-------------+
|max(paper_id)|
+-------------+
|   4390089447|
+-------------+



In [0]:
# Abstracts of the works added in this iteration, one row per paper_id.
new_abstracts = (
    spark.read.parquet(f"{iteration_save_path}new_work_abstracts")
    .dropDuplicates(subset=['paper_id'])
)

# Static abstracts stacked with the new ones, again reduced to one row per paper_id.
# NOTE(review): when a paper_id occurs in both sources, dropDuplicates does not specify which
# abstract survives -- confirm that either one is acceptable.
abstracts = (
    spark.read.parquet(f"{base_save_path}static_abstracts")
    .select('paper_id', 'abstract')
    .union(new_abstracts.select('paper_id', 'abstract'))
    .dropDuplicates(subset=['paper_id'])
)

# Cache and count; the count is the cell output.
abstracts.cache().count()

In [0]:
# Sanity check: largest paper_id present in the combined abstracts table.
abstracts.agg(F.max('paper_id')).show()

+-------------+
|max(paper_id)|
+-------------+
|   4390088933|
+-------------+



In [0]:
# Assemble the full training table: every title, its abstract when one exists (left join),
# restricted to papers that have a topic assignment (inner join), one row per paper_id.
(
    works
    .join(abstracts, on='paper_id', how='left')
    .join(
        new_topic_labels.select('paper_id', 'micro_cluster_id', 'short_label', 'long_label', 'keywords'),
        on='paper_id',
        how='inner',
    )
    .dropDuplicates(subset=['paper_id'])
    .write.mode('overwrite')
    .parquet(f"{iteration_save_path}language_model/all_training_data/")
)

#### Transforming Data

In [0]:
def name_to_keep_ind(groups):
    """Return 0 if ``groups`` contains any excluded script name, otherwise 1.

    ``groups`` is an iterable of script-name strings such as the first element returned by
    ``group_non_latin_characters``. Matching is exact and case-sensitive.
    """
    excluded_scripts = (
        'HIRAGANA', 'CJK', 'KATAKANA', 'ARABIC', 'HANGUL', 'THAI', 'DEVANAGARI', 'BENGALI',
        'THAANA', 'GUJARATI', 'CYRILLIC',
    )

    for script_name in groups:
        if script_name in excluded_scripts:
            return 0
    return 1
    
def group_non_latin_characters(text):
    """Summarise which Unicode scripts appear in ``text``.

    Periods and spaces are removed first. Every remaining character is classified by the first
    word of its Unicode name (for example ``'LATIN'``, ``'CJK'``, ``'DIGIT'``).

    Parameters
    ----------
    text : str
        Text to inspect.

    Returns
    -------
    tuple[list[str], int]
        ``(groups, latin_count)``: ``groups`` lists each distinct non-``LATIN`` first word in
        order of first appearance, with ``"UNK"`` standing in for characters that have no
        Unicode name; ``latin_count`` is the number of ``LATIN`` characters.
    """
    groups = []
    latin_count = 0
    for char in text.replace(".", "").replace(" ", ""):
        try:
            script = unicodedata.name(char).split(" ")[0]
        except ValueError:
            # unicodedata.name raises ValueError for characters without a name
            # (e.g. control characters); those are reported once as "UNK".
            script = "UNK"

        if script == 'LATIN':
            latin_count += 1
        elif script not in groups:
            groups.append(script)
    return groups, latin_count

def remove_non_latin_characters(text):
    """Return ``text`` without characters from the excluded scripts.

    A character is dropped when the first word of its Unicode name is in the excluded list
    below, or when it has no Unicode name at all (e.g. control characters such as tab or
    newline). Everything else is kept, including non-Latin characters such as digits,
    punctuation and Greek letters.

    Parameters
    ----------
    text : str
        Text to filter.

    Returns
    -------
    str
        The kept characters, in their original order.
    """
    excluded_scripts = (
        'HIRAGANA', 'CJK', 'KATAKANA', 'ARABIC', 'HANGUL', 'THAI', 'DEVANAGARI', 'BENGALI',
        'THAANA', 'GUJARATI', 'CYRILLIC',
    )
    kept_chars = []
    for char in text:
        try:
            script = unicodedata.name(char).split(" ")[0]
        except ValueError:
            # No Unicode name: drop the character (same outcome as before, but other
            # exception types are no longer silently swallowed).
            continue
        if script not in excluded_scripts:
            kept_chars.append(char)
    return "".join(kept_chars)

# `F.udf` is used instead of a bare `udf` because the notebook only imports
# `pyspark.sql.functions as F`; the bare name is not imported anywhere in the visible cells.
@F.udf(returnType=IntegerType())
def check_for_non_latin_characters(text):
    """Spark UDF: decide whether a title should be kept (1) or discarded (0).

    A title is kept when it contains no character from the excluded scripts, or when it
    contains more than 30 ``LATIN`` characters regardless of what else it contains.

    Parameters
    ----------
    text : str or None
        Title text.

    Returns
    -------
    int
        1 to keep the title, 0 to discard it. Null input returns 0; it used to raise
        ``AttributeError`` inside the UDF, which fails the Spark job.
    """
    if text is None:
        return 0

    groups, latin_chars = group_non_latin_characters(text)
    if name_to_keep_ind(groups) == 1 or latin_chars > 30:
        return 1
    return 0
    
# `F.udf` is used instead of a bare `udf` because the notebook only imports
# `pyspark.sql.functions as F`; the bare name is not imported anywhere in the visible cells.
@F.udf(returnType=StringType())
def clean_title(old_title, keep_title):
    """Spark UDF: strip excluded-script characters and inline markup from a title.

    Parameters
    ----------
    old_title : str or None
        Raw title text.
    keep_title : int or None
        Keep/discard flag (in this notebook, the output of
        ``check_for_non_latin_characters``). Any falsy value discards the title.

    Returns
    -------
    str
        The cleaned title, or ``''`` when ``keep_title`` is falsy or ``old_title`` is null.
    """
    # A null title with a truthy flag used to raise TypeError inside the UDF.
    if not keep_title or old_title is None:
        return ''

    new_title = remove_non_latin_characters(old_title)

    if '<' in new_title:
        # Literal tags removed one after another, in this exact order. Repeated entries are
        # intentional: they are kept so the sequence of replacements is unchanged.
        literal_tags = (
            "<i>", "</i>",
            "<sub>", "</sub>",
            "<sup>", "</sup>",
            "<em>", "</em>",
            "<b>", "</b>",
            "<I>", "</I>",
            "<SUB>", "</SUB>",
            "<scp>", "</scp>",
            "<font>", "</font>",
            "<inf>", "</inf>",
            "<i /> ",
            "<p>", "</p>",
            "<![CDATA[<B>", "</B>]]>",
            "<italic>", "</italic>",
            "<title>", "</title>",
            "<br>", "</br>", "<br/>",
            "<B>", "</B>",
            "<em>", "</em>",
            "<BR>", "</BR>",
            "<title>", "</title>",
            "<strong>", "</strong>",
            "<formula>", "</formula>",
            "<roman>", "</roman>",
            "<SUP>", "</SUP>",
            "<SSUP>", "</SSUP>",
            "<sc>", "</sc>",
            "<subtitle>", "</subtitle>",
            "<emph/>", "<emph>", "</emph>",
            """<p class="Body">""",
            "<TITLE>", "</TITLE>",
            "<sub />", "<sub/>",
            "<mi>", "</mi>",
            "<bold>", "</bold>",
            "<mtext>", "</mtext>",
            "<msub>", "</msub>",
            "<mrow>", "</mrow>",
            "</mfenced>", "</math>",
        )
        for tag in literal_tags:
            new_title = new_title.replace(tag, "")

        if '<mml' in new_title:
            # Split the title around "<mml:math" / "mml:math>" and drop empty fragments.
            all_parts = [x for y in [i.split("mml:math>") for i in new_title.split("<mml:math")] for x in y if x]
            final_parts = []
            for part in all_parts:
                # Text found between a '>' and the next '<' is kept as the fragment's content;
                # fragments without such text are passed through unchanged.
                pull_out = re.findall(r"\>[$%#!^*\w.,/()+-]*\<", part)
                if pull_out:
                    final_pieces = [piece.replace(">", "").replace("<", "") for piece in pull_out]
                    final_parts.append(" "+ "".join(final_pieces) + " ")
                else:
                    final_parts.append(part)

            new_title = "".join(final_parts).strip()

        # Elements removed together with their content (the content must not contain '/').
        if '<xref' in new_title:
            new_title = re.sub(r"\<xref[^/]*\/xref\>", "", new_title)

        if '<inline-formula' in new_title:
            new_title = re.sub(r"\<inline-formula[^/]*\/inline-formula\>", "", new_title)

        if '<title' in new_title:
            new_title = re.sub(r"\<title[^/]*\/title\>", "", new_title)

        # Opening tags that carry attributes: only the tag itself is removed.
        if '<p class=' in new_title:
            new_title = re.sub(r"\<p class=[^>]*\>", "", new_title)

        if '<span class=' in new_title:
            new_title = re.sub(r"\<span class=[^>]*\>", "", new_title)

        if 'mfenced open' in new_title:
            new_title = re.sub(r"\<mfenced open=[^>]*\>", "", new_title)

        if 'math xmlns' in new_title:
            new_title = re.sub(r"\<math xmlns=[^>]*\>", "", new_title)

    if '<' in new_title:
        # NOTE(review): ">i<", ">/i<", ">b<", ">/b<" look like tags with reversed brackets;
        # kept exactly as they were -- confirm the intent.
        new_title = new_title.replace(">i<", "").replace(">/i<", "") \
                             .replace(">b<", "").replace(">/b<", "") \
                             .replace("<inline-formula>", "").replace("</inline-formula>","")

    return new_title

In [0]:
# Window spec for numbering rows inside each micro cluster, ordered by `random_num`.
# `random_num` is added to the DataFrame in the sampling cell before this spec is applied.
w2 = (
    Window
    .partitionBy('micro_cluster_id')
    .orderBy('random_num')
)

In [0]:
# Read back the full training table written by the join cell above.
train_data = spark.read.parquet(
    f"{iteration_save_path}language_model/all_training_data/"
)

# Cache and count; the count is the cell output.
train_data.cache().count()

70674410

#### Only taking 1000 samples of each cluster ID in order to limit the amount of training data and also keep the labeled data balanced

In [0]:
# Cap every micro cluster at a fixed number of randomly chosen papers, clean the titles of the
# sampled rows and write the subset.
#
# The random ordering is seeded so a re-run can select the same rows; it was unseeded before,
# which made the training subset impossible to reproduce.
# NOTE(review): a seeded rand() is only repeatable if the input partitioning is the same between
# runs -- confirm for this cluster setup.
SAMPLING_SEED = 42
MAX_SAMPLES_PER_CLUSTER = 1000

(
    train_data
    .withColumn('random_num', F.rand(seed=SAMPLING_SEED))
    # `w2` (defined in an earlier cell) partitions by micro_cluster_id and orders by random_num.
    .withColumn('cluster_rank', F.row_number().over(w2))
    .filter(F.col('cluster_rank') <= MAX_SAMPLES_PER_CLUSTER)
    # The title UDFs are applied after the per-cluster cap.
    .withColumn('keep_title', check_for_non_latin_characters(F.col('original_title')))
    .withColumn('new_title', clean_title(F.col('original_title'), F.col('keep_title')))
    .select('paper_id', 'new_title', 'abstract', 'micro_cluster_id', 'short_label', 'long_label', 'keywords')
    .write.mode('overwrite')
    .parquet(f"{iteration_save_path}language_model/all_training_data_subset_new/")
)

In [0]:
# Row count of the written subset (at most 1000 rows per micro cluster).
(
    spark.read
    .parquet(f"{iteration_save_path}language_model/all_training_data_subset_new/")
    .count()
)

4521000

In [0]:
# Re-write the subset through a single partition (coalesce(1)) into its own output directory.
(
    spark.read
    .parquet(f"{iteration_save_path}language_model/all_training_data_subset_new/")
    .coalesce(1)
    .write.mode('overwrite')
    .parquet(f"{iteration_save_path}language_model/all_training_data_subset_single_file_new/")
)