In [1]:
import json
import pickle
import pandas as pd
import numpy as np
import random

In [2]:
import os
# Set before TensorFlow is imported in the next cell.
# Presumably this lowers TensorFlow's native log verbosity — confirm the level semantics.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

In [3]:
import tensorflow as tf
from tensorflow.keras import mixed_precision
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import PolynomialDecay
from transformers.keras_callbacks import KerasMetricCallback
import evaluate
from datasets import load_dataset, load_metric, list_metrics
from transformers import create_optimizer
from transformers import create_optimizer, TFAutoModelForSequenceClassification, DistilBertTokenizer
from transformers import DataCollatorWithPadding, TFDistilBertForSequenceClassification
from transformers import TFRobertaForSequenceClassification, RobertaTokenizer, AutoTokenizer
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

In [4]:
# Display the installed TensorFlow version (the recorded run shows 2.13.0).
tf.__version__

'2.13.0'

In [6]:
def get_final_citation_feature_for_model(list_of_links, num_to_take):
    """Cap a list of citation links at ``num_to_take`` entries.

    A list that already fits is returned unchanged (the very same object).
    A longer list is down-sampled without replacement via ``random.sample``,
    so the result depends on the state of the global ``random`` generator.
    """
    needs_sampling = len(list_of_links) > num_to_take
    return random.sample(list_of_links, num_to_take) if needs_sampling else list_of_links

In [7]:
def get_embedding(embedding):
    """Return ``embedding`` as a float32 array, or a zero vector as fallback.

    Only a Python ``list`` is converted; any other input (``None``, arrays,
    NaN, ...) yields a 384-dimensional float32 zero vector.
    """
    if not isinstance(embedding, list):
        return np.zeros(384, dtype=np.float32)
    return np.array(embedding, dtype=np.float32)

In [9]:
def invert_abstract_to_abstract(invert_abstract):
    """Rebuild plain text from a JSON-encoded inverted-index abstract.

    The JSON document must contain ``IndexLength`` (number of token positions)
    and ``InvertedIndex`` (token -> list of positions). Only abstracts with
    more than 30 and fewer than 1000 positions are rebuilt; anything else
    returns ``None``. Positions that no token claims keep a single space.
    """
    payload = json.loads(invert_abstract)
    n_positions = payload['IndexLength']

    # Length gate first: out-of-range abstracts are dropped entirely.
    if not (30 < n_positions < 1000):
        return None

    tokens = [" "] * n_positions
    for word, positions in payload['InvertedIndex'].items():
        for pos in positions:
            tokens[pos] = word
    return " ".join(tokens)

def clean_abstract(abstract, inverted=True):
    """Normalise an abstract, dropping ones that are too short.

    With ``inverted=True`` a truthy value is treated as a JSON inverted index
    and rebuilt through ``invert_abstract_to_abstract``; falsy values pass
    through untouched. With ``inverted=False`` a string shorter than 30
    characters becomes ``None`` and everything else is returned as-is.
    ``clean_text`` is not applied here.
    """
    if inverted:
        return invert_abstract_to_abstract(abstract) if abstract else abstract

    if isinstance(abstract, str) and len(abstract) < 30:
        return None
    return abstract

def clean_text(text):
    """Lowercase ``text`` and reduce it to space-separated alphanumerics.

    Every run of characters other than ASCII letters, digits and spaces is
    replaced by one space, repeated spaces are collapsed, and the result is
    stripped. Input that cannot be processed as a string (e.g. ``None``,
    numbers, bytes) yields ``""``.
    """
    # ``re`` is not imported at the top of the notebook. Without this import
    # the NameError was swallowed by a bare ``except`` and every call
    # returned "" regardless of input.
    import re

    try:
        text = text.lower()
        text = re.sub('[^a-zA-Z0-9 ]+', ' ', text)
        text = re.sub(' +', ' ', text)
        return text.strip()
    except (AttributeError, TypeError):
        # Non-string input: keep the original best-effort contract.
        return ""

In [10]:
def merge_title_and_abstract(title, abstract):
    """Combine title and abstract into the tagged text fed to the tokenizer.

    Only ``str`` values count as present. A missing title is written as
    ``NONE`` when an abstract exists; if neither is a string the result is
    the empty string.
    """
    has_title = isinstance(title, str)
    has_abstract = isinstance(abstract, str)

    if has_title and has_abstract:
        return f"<TITLE> {title}\n<ABSTRACT> {abstract}"
    if has_title:
        return f"<TITLE> {title}"
    if has_abstract:
        return f"<TITLE> NONE\n<ABSTRACT> {abstract}"
    return ""

In [11]:
def create_vocab(df, column, starting_num = 0):
    """Build forward and inverse integer vocabularies for one column.

    Unique values of ``df[column]`` are numbered in order of first appearance,
    starting at ``starting_num``.

    Returns:
        (value -> id, id -> value) as two dicts.
    """
    forward = {}
    for position, value in enumerate(df[column].unique()):
        forward[value] = position + starting_num

    backward = dict((idx, value) for value, idx in forward.items())
    return forward, backward

In [12]:
def transform_citation_feature(citations, emb_vocab, gold_to_label_mapping, num_to_keep):
    """Turn a list of cited paper ids into a fixed-width list of embedding ids.

    Args:
        citations: cited paper ids (may be empty or ``None``).
        emb_vocab: cluster string -> embedding id.
        gold_to_label_mapping: cited paper id -> cluster string; ids without
            a (truthy) entry are ignored.
        num_to_keep: width of the returned feature.

    Returns:
        A list of exactly ``num_to_keep`` ints: the mapped embedding ids
        right-padded with 0, or ``[1, 0, 0, ...]`` when there is nothing to
        encode.

    Raises:
        KeyError: if a mapped cluster string is missing from ``emb_vocab``.
    """
    # Id 1 marks "no usable citations"; id 0 is padding.
    no_citation_row = [1] + [0]*(num_to_keep - 1)
    if not citations:
        return no_citation_row

    mapped_labels = (gold_to_label_mapping.get(cited) for cited in citations)
    feature = [emb_vocab[label] for label in mapped_labels if label]

    # Citations that all fail to map used to produce an all-padding row. Such
    # a row is fully masked by the citation embedding layer (mask_zero=True in
    # create_model), leaving nothing for the pooling step to average, so it is
    # treated like "no citations" instead.
    if not feature:
        return no_citation_row

    # Enforce the fixed width in both directions (truncate, then pad).
    feature = feature[:num_to_keep]
    return feature + [0]*(num_to_keep - len(feature))

In [13]:
def get_sorted_string_of_list(list_to_sort):
    """Return the items of ``list_to_sort`` in ascending order, joined by ``|``.

    The input is left untouched: ``sorted`` works on a copy, whereas the
    previous in-place ``.sort()`` silently reordered the caller's list (here,
    the lists stored in a DataFrame column).
    """
    return "|".join(str(item) for item in sorted(list_to_sort))

In [14]:
def open_pickle(pickle_path):
    """Load and return the object stored in the pickle file at ``pickle_path``."""
    # NOTE: unpickling can execute arbitrary code; only use on trusted files.
    with open(pickle_path, mode='rb') as handle:
        return pickle.load(handle)

In [15]:
def save_pickle(dictionary, file_path):
    """Pickle ``dictionary`` to ``file_path``, overwriting any existing file.

    Returns ``None``.
    """
    with open(file_path, mode='wb') as handle:
        pickle.Pickler(handle).dump(dictionary)

#### Getting Gold Citations Mapping

In [16]:
# Load the gold-citation table from a hard-coded S3 parquet file.
# The recorded run shows 126773 rows x 4 columns.
gold_to_id_mapping = pd.read_parquet("s3://data-pull-from-justins-personal-s3/all_concepts_model_data/V4/citation_model_new_less_gold/gold_citation_papers_single_file/part-00000-tid-3822763855470313237-f26ac12e-de75-40aa-84fb-9fc177530185-12895-1-c000.snappy.parquet")
gold_to_id_mapping.shape

(126773, 4)

In [17]:
# One row per gold citation, with all of its micro_cluster_id values collected into a list.
gold_citation_grouped = gold_to_id_mapping.groupby('gold_citation')['micro_cluster_id'].apply(list).reset_index()
gold_citation_grouped.shape

(124577, 2)

In [18]:
# Canonical "|"-joined string of the sorted cluster ids; used below as the citation vocabulary key.
gold_citation_grouped['cluster_string'] = gold_citation_grouped['micro_cluster_id'].apply(get_sorted_string_of_list)

In [19]:
# Spot-check two random rows (unseeded, so the rows shown change between runs).
gold_citation_grouped.sample(2)

Unnamed: 0,gold_citation,micro_cluster_id,cluster_string
86397,2149443601,[1667],1667
28642,1999013095,[2315],2315


In [20]:
# Lookup: gold citation id -> cluster string.
gold_to_id_mapping_dict = dict(zip(gold_citation_grouped['gold_citation'].tolist(),
                                   gold_citation_grouped['cluster_string'].tolist()))

#### Processing All Training Data

In [16]:
# Identifier of the fine-tuned language model, assembled from the base model name and the task name.
# It is passed to AutoTokenizer.from_pretrained in the tokenization section.
model_name = "bert-base-multilingual-cased"
task = "openalex-topic-classification-title-abstract"
language_model_name = f"OpenAlex/{model_name}-finetuned-{task}"

In [22]:
# NOTE: the path is a literal placeholder (a plain string, not an f-string);
# replace it with the real location of the data before running.
# Only the columns needed downstream are read.
all_data = pd.read_parquet("{path_to_all_data_from_004_spark_file}", 
                           columns=['paper_id','new_title','abstract','final_level_0_links',
                                     'final_level_1_links','journal_id','micro_cluster_id','short_label',
                                     'long_label'])

all_data.shape

(4521000, 9)

##### Process title and abstract

In [23]:
# Empty titles become None so merge_title_and_abstract treats them as missing.
all_data['new_title'] = all_data['new_title'].apply(lambda x: None if x=='' else x)

In [24]:
# Abstracts are handled as plain text here (inverted=False): strings under 30 characters become None.
all_data['abstract_processed'] = all_data['abstract'].apply(lambda x: clean_abstract(x, inverted=False))

In [25]:
# Tagged "<TITLE> ... <ABSTRACT> ..." text that is later tokenized for the language model.
all_data['title_abstract'] = all_data.apply(
    lambda row: merge_title_and_abstract(row.new_title, row.abstract_processed),
    axis=1,
)

##### Process label

In [26]:
# Label string in the form "<micro_cluster_id>: <long_label>"; this is the key of the target vocabulary.
all_data['full_label'] = all_data.apply(lambda x: f"{x.micro_cluster_id}: {x.long_label}", axis=1)

In [27]:
# Distinct (micro_cluster_id, long_label) pairs; not referenced again in the visible part of the notebook.
cluster_labels = all_data[['micro_cluster_id','long_label']].drop_duplicates()

In [None]:
# Label vocabulary: full_label -> class index (starting at 0) and its inverse.
target_vocab, inv_target_vocab = create_vocab(all_data, 'full_label')

In [None]:
# Integer class index for every row.
all_data['label'] = all_data['full_label'].apply(lambda x: target_vocab[x])

##### Process citation features

In [28]:
# Keep at most 16 level-0 links per paper; longer lists are randomly down-sampled.
# NOTE(review): no random seed is set in the visible notebook, so the sample differs between runs.
all_data['level_0_links_feature'] = all_data['final_level_0_links'].apply(lambda x: get_final_citation_feature_for_model(x.tolist(), 16))

In [29]:
# Keep at most 128 level-1 links per paper; longer lists are randomly down-sampled.
# NOTE(review): no random seed is set in the visible notebook, so the sample differs between runs.
all_data['level_1_links_feature'] = all_data['final_level_1_links'].apply(lambda x: get_final_citation_feature_for_model(x.tolist(), 128))

In [31]:
# Ids start at 2 because transform_citation_feature uses 0 as padding and 1 as the "no citations" marker.
citation_feature_vocab, inv_citation_feature_vocab = create_vocab(gold_citation_grouped, 'cluster_string', starting_num = 2)

In [33]:
# Level-0 links encoded as 16 citation-vocabulary ids per paper.
all_data['level_0_citation'] = all_data.apply(
    lambda row: transform_citation_feature(
        row.level_0_links_feature, citation_feature_vocab, gold_to_id_mapping_dict, 16
    ),
    axis=1,
)

In [34]:
# Level-1 links encoded as 128 citation-vocabulary ids per paper.
all_data['level_1_citation'] = all_data.apply(
    lambda row: transform_citation_feature(
        row.level_1_links_feature, citation_feature_vocab, gold_to_id_mapping_dict, 128
    ),
    axis=1,
)

##### Save data and artifacts

In [35]:
# Checkpoint the processed columns to disk; the tokenization section reloads this file.
all_data[['paper_id','full_label','title_abstract','level_0_citation',
               'level_1_citation','label','journal_id']]\
    .to_parquet("./model_artifacts/all_data_temp.parquet")

In [36]:
# Persist the vocabularies and the gold-citation mapping so that the training
# section can reload them. save_pickle returns None, so nothing is bound:
# assigning to `_` would only shadow IPython's last-output variable.
save_pickle(target_vocab, './model_artifacts/target_vocab.pkl')
save_pickle(inv_target_vocab , './model_artifacts/inv_target_vocab.pkl')
save_pickle(citation_feature_vocab, './model_artifacts/citation_feature_vocab.pkl')
save_pickle(inv_citation_feature_vocab , './model_artifacts/inv_citation_feature_vocab.pkl')
save_pickle(gold_to_id_mapping_dict , './model_artifacts/gold_to_id_mapping_dict.pkl')

##### Tokenize title and abstract

In [17]:
# Reload the checkpoint written above, so tokenization can start from the saved file.
all_data = pd.read_parquet("./model_artifacts/all_data_temp.parquet")

In [18]:
# Tokenizer that ships with the fine-tuned language model. Truncation is
# requested where the tokenizer is called; `truncate=True` is not a loading
# option and had no effect on tokenization, so it is dropped here.
tokenizer = AutoTokenizer.from_pretrained(language_model_name)

In [19]:
# Tokenize every title/abstract to exactly 512 positions. The model inputs are
# declared with shape (512,), so pad to max_length: padding='longest' only
# yields 512 columns if at least one text happens to reach that length.
title_abs_tok = tokenizer(all_data['title_abstract'].tolist(), max_length=512, truncation=True, padding='max_length')

In [20]:
# Store one numpy array per row for the token ids and the attention mask.
all_data['input_ids'] = [np.array(x) for x in title_abs_tok['input_ids']]
all_data['attention_mask'] = [np.array(x) for x in title_abs_tok['attention_mask']]

In [22]:
# Final processed dataset, now including the tokenized text columns.
all_data[['paper_id','full_label','title_abstract','level_0_citation',
               'level_1_citation','label','input_ids','attention_mask','journal_id']]\
    .to_parquet("./model_artifacts/all_data.parquet")

#### Processing journal data

In [None]:
def get_journal_emb(journal_name):
    """Embed a journal name, or return a 384-dim float32 zero vector.

    The name is encoded with ``emb_model`` only when
    ``check_for_non_latin_characters(journal_name)`` returns 1.

    NOTE(review): neither ``emb_model`` nor ``check_for_non_latin_characters``
    is defined in the visible part of this notebook — presumably they come
    from the separate journal-embedding preprocessing file; confirm before
    calling this function.
    """
    if check_for_non_latin_characters(journal_name) != 1:
        return np.zeros(384, dtype=np.float32)
    return emb_model.encode(str(journal_name))

In [16]:
# Load the fully processed dataset written by the tokenization section.
all_processed_data = pd.read_parquet("./model_artifacts/all_data.parquet")

In [17]:
# Missing journal ids become -1 so the column can be cast to integer.
all_processed_data['journal_id'] = all_processed_data['journal_id'].fillna(-1).astype('int')

In [25]:
# These journal embeddings were created using the 005 file in the data preprocessing folder.
# They are looked up by journal_id below. Unpickling can execute arbitrary code, so only load trusted files.
journal_embs = open_pickle('./journal_embs.pkl')

In [27]:
# Shuffle every row once with a fixed seed so the per-label split below is reproducible.
shuffled_data = all_processed_data.sample(n=len(all_processed_data), random_state=0)
shuffled_data.shape

(4521000, 9)

In [28]:
# Attach the journal embedding; journal ids without an embedding get a 384-dim float32 zero vector.
shuffled_data['journal_emb'] = shuffled_data['journal_id'].apply(lambda x: journal_embs.get(x, np.zeros(384, dtype=np.float32)))

##### Creating random even splits of the training data (so labels are balanced)

In [29]:
# 1-based running count of rows within each label, in shuffled order; drives the split below.
shuffled_data['row_num'] = shuffled_data.groupby('label').cumcount() + 1

In [30]:
# Per label: rows 1-950 go to train, rows 951-975 to validation, and everything after 975 to test.
train = shuffled_data[shuffled_data['row_num']<=950].copy()
val = shuffled_data[(shuffled_data['row_num']>950) & (shuffled_data['row_num']<=975)].copy()
test = shuffled_data[shuffled_data['row_num']>975].copy()
print(train.shape)
print(val.shape)
print(test.shape)

(4294950, 11)
(113025, 11)
(113025, 11)


In [31]:
# Write the three splits to disk; the training section loads train and val from these files.
train.to_parquet("./model_artifacts/train.parquet")
val.to_parquet("./model_artifacts/val.parquet")
test.to_parquet("./model_artifacts/test.parquet")

#### Load Data and Train Model

In [6]:
def create_model(num_classes, emb_table_size):
    """Build and compile the topic classifier.

    The model combines a frozen fine-tuned language model (title/abstract),
    a shared embedding table for the two citation features, and the journal
    embedding, followed by three dense stages and a ``num_classes``-wide
    output layer.

    Args:
        num_classes: number of output classes; also determines the class
            prior ``1/num_classes`` used to initialise the output layer.
        emb_table_size: ``input_dim`` of the citation embedding table.

    Returns:
        A compiled ``tf.keras.Model`` with inputs
        ``[ids, mask, citation_0, citation_1, journal_emb]``.
    """
    # Load finetuned language model
    model_name = "bert-base-multilingual-cased"
    task = "openalex-topic-classification-title-abstract"
    language_model_name = f"OpenAlex/{model_name}-finetuned-{task}"
    language_model = TFAutoModelForSequenceClassification.from_pretrained(language_model_name,
                                                                          output_hidden_states=True)
    # The language model is used as a frozen feature extractor.
    language_model.trainable = False

    # Inputs
    ids = tf.keras.layers.Input((512,), dtype=tf.int64, name='ids')
    mask = tf.keras.layers.Input((512,), dtype=tf.int64, name='mask')
    citation_0 = tf.keras.layers.Input((16,), dtype=tf.int64, name='citation_0')
    citation_1 = tf.keras.layers.Input((128,), dtype=tf.int64, name='citation_1')
    journal = tf.keras.layers.Input((384,), dtype=tf.float32, name='journal_emb')

    # Text branch: average the last hidden state over the sequence axis.
    # NOTE(review): the attention mask is not passed to this pooling layer, so
    # padded positions appear to be included in the average — confirm intended.
    language_model_output = language_model(input_ids=ids, attention_mask=mask).hidden_states[-1]
    pooled_language_model_output = tf.keras.layers.GlobalAveragePooling1D()(language_model_output)

    # Citation branch: one embedding table shared by both citation inputs;
    # id 0 is treated as padding (mask_zero=True).
    citation_emb_layer = tf.keras.layers.Embedding(input_dim=emb_table_size, output_dim=256, mask_zero=True,
                                                   trainable=True, name='citation_emb_layer')

    citation_0_emb = citation_emb_layer(citation_0)
    citation_1_emb = citation_emb_layer(citation_1)

    pooled_citation_0 = tf.keras.layers.GlobalAveragePooling1D()(citation_0_emb)
    pooled_citation_1 = tf.keras.layers.GlobalAveragePooling1D()(citation_1_emb)

    concat_data = tf.keras.layers.Concatenate(name='concat_data', axis=-1)([pooled_language_model_output,
                                                                            pooled_citation_0,
                                                                            pooled_citation_1, journal])

    # Dense stages 1-3: Dense(relu, L2) -> Dropout(0.20) -> LayerNormalization.
    # Layer names (dense_1, dropout_1, layer_norm_1, ...) match the originals.
    dense_output = concat_data
    for stage, units in enumerate((2048, 1024, 512), start=1):
        dense_output = tf.keras.layers.Dense(units, activation='relu',
                                             kernel_regularizer='L2', name=f"dense_{stage}")(dense_output)
        dense_output = tf.keras.layers.Dropout(0.20, name=f"dropout_{stage}")(dense_output)
        dense_output = tf.keras.layers.LayerNormalization(epsilon=1e-6, name=f"layer_norm_{stage}")(dense_output)

    # Initialise the output layer from the uniform class prior. This uses the
    # num_classes argument rather than the notebook-global target_vocab, so
    # the function no longer depends on hidden state (the call site passes
    # len(target_vocab), which gives the same value).
    class_prior = 1/num_classes
    last_layer_weight_init = tf.keras.initializers.Constant(class_prior)
    last_layer_bias_init = tf.keras.initializers.Constant(-np.log((1-class_prior)/class_prior))

    # NOTE(review): sigmoid outputs are paired with a categorical focal loss
    # below — confirm this combination is intended.
    output_layer = tf.keras.layers.Dense(num_classes, kernel_initializer=last_layer_weight_init,
                                         bias_initializer=last_layer_bias_init,
                                         activation='sigmoid', name='output_layer')(dense_output)
    model = tf.keras.Model(inputs=[ids, mask, citation_0, citation_1, journal], outputs=output_layer)

    loss_fn = tf.keras.losses.CategoricalFocalCrossentropy()

    # Compile the model
    model.compile(optimizer=tf.keras.optimizers.AdamW(),
                  loss=loss_fn,
                  metrics=[tf.keras.metrics.CategoricalAccuracy(),
                           tf.keras.metrics.TopKCategoricalAccuracy(k=2, name='top_2_categorical_accuracy'),
                           tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='top_5_categorical_accuracy'),
                           tf.keras.metrics.TopKCategoricalAccuracy(k=10, name='top_10_categorical_accuracy')])

    return model

In [32]:
# Reload the artifacts saved during preprocessing, so training can run without redoing that stage.
# Unpickling can execute arbitrary code, so only load trusted files.
target_vocab = open_pickle('./model_artifacts/target_vocab.pkl')
inv_target_vocab = open_pickle('./model_artifacts/inv_target_vocab.pkl')
citation_feature_vocab = open_pickle('./model_artifacts/citation_feature_vocab.pkl')
inv_citation_feature_vocab = open_pickle('./model_artifacts/inv_citation_feature_vocab.pkl')
gold_to_id_mapping_dict = open_pickle('./model_artifacts/gold_to_id_mapping_dict.pkl')

In [19]:
def preprocess_labels_and_data(examples):
    """
    One-hot encode the label and randomly blank out inputs of one example.

    The training data is not representative of production data, so the model
    has to learn to predict when some inputs are missing. A citation feature
    whose first id is 1 is considered empty.

    - At least one citation feature present: the text is blanked with
      probability 0.15 and, independently, the journal embedding is zeroed
      with probability 0.25.
    - Both citation features empty and the journal embedding has a non-zero
      mean: the text is blanked with probability 0.10.

    Blanked text is ids ``[101, 102]`` followed by 510 zeros, with a matching
    mask (presumably the tokenizer's start/end special tokens — confirm).
    Uses the global ``random`` generator and the global ``target_vocab``.
    """
    examples['label'] = tf.keras.utils.to_categorical(examples['label'], num_classes=len(target_vocab))

    level_0_empty = examples['citation_0'][0] == 1
    level_1_empty = examples['citation_1'][0] == 1

    if level_0_empty and level_1_empty:
        # No citations: only drop the text when a journal embedding remains.
        if np.mean(examples['journal_emb']) != 0.0 and random.random() < 0.10:
            examples['ids'] = [101, 102] + [0]*510
            examples['mask'] = [1, 1] + [0]*510
    else:
        # Citations available: text and journal are dropped independently.
        if random.random() < 0.15:
            examples['ids'] = [101, 102] + [0]*510
            examples['mask'] = [1, 1] + [0]*510

        if random.random() < 0.25:
            examples['journal_emb'] = [0.0]*384

    return examples

In [20]:
def preprocess_labels(examples):
    """One-hot encode ``examples['label']`` over ``len(target_vocab)`` classes.

    Only the label is changed (no random blanking of inputs).
    """
    n_classes = len(target_vocab)
    one_hot = tf.keras.utils.to_categorical(examples['label'], num_classes=n_classes)
    examples['label'] = one_hot
    return examples

In [21]:
# Using the HuggingFace library to load the dataset
# Columns are renamed to the model's input names; every training example then
# passes through preprocess_labels_and_data (one-hot label + random input blanking).
all_dataset = load_dataset("parquet", data_files={'train': ["./model_artifacts/train.parquet"]}) \
    .rename_column("level_0_citation", "citation_0") \
    .rename_column("level_1_citation", "citation_1") \
    .rename_column("input_ids", "ids") \
    .rename_column("attention_mask", "mask").map(preprocess_labels_and_data)

In [22]:
# Using the HuggingFace library to load the dataset
# Same renaming as for the training set, but validation examples only get
# their label one-hot encoded (no random input blanking).
all_dataset_val = load_dataset("parquet", data_files={'val': ["./model_artifacts/val.parquet"]}) \
    .rename_column("level_0_citation", "citation_0") \
    .rename_column("level_1_citation", "citation_1") \
    .rename_column("input_ids", "ids") \
    .rename_column("attention_mask", "mask").map(preprocess_labels)

In [None]:
# Allow for use of multiple GPUs
strategy = tf.distribute.MirroredStrategy()
batch_size = 2048

# The model and the tf.data pipelines are created inside the strategy scope.
with strategy.scope():

    # Embedding table size = vocabulary size + 2: citation ids start at 2,
    # with 0 reserved for padding and 1 for "no citations".
    concept_model = create_model(len(target_vocab), len(citation_feature_vocab)+2)

    tf_train_dataset = all_dataset['train'].to_tf_dataset(
        columns=['ids','mask','citation_0','citation_1','journal_emb'],
        label_cols=["label"],
        batch_size=batch_size,
        shuffle=True
    )

    # NOTE(review): shuffle=True is also set for the validation data — confirm this is wanted.
    tf_val_dataset = all_dataset_val['val'].to_tf_dataset(
        columns=['ids','mask','citation_0','citation_1','journal_emb'],
        label_cols=["label"],
        batch_size=batch_size,
        shuffle=True
    )


In [8]:
# Print the layer-by-layer summary of the compiled model.
concept_model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 ids (InputLayer)            [(None, 512)]                0         []                            
                                                                                                  
 mask (InputLayer)           [(None, 512)]                0         []                            
                                                                                                  
 citation_0 (InputLayer)     [(None, 16)]                 0         []                            
                                                                                                  
 citation_1 (InputLayer)     [(None, 128)]                0         []                            
                                                                                              

In [None]:
# Early stopping watches val_loss with min_delta=0.0001 and patience=2.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=2)

# Save the full model (not only weights) after every epoch; the file name
# records the epoch, val_loss, validation accuracy and validation top-5 accuracy.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath='./model_checkpoints/{epoch:02d}-{val_loss:.3f}-{val_categorical_accuracy:.4f}-{val_top_5_categorical_accuracy:.4f}.keras',
        save_weights_only=False,
        save_best_only=False)

callbacks = [early_stopping, model_checkpoint_callback]

In [None]:
# Train for up to 20 epochs, validating each epoch; the callbacks handle early stopping and checkpointing.
history = concept_model.fit(tf_train_dataset, epochs=20, validation_data=tf_val_dataset, callbacks=callbacks)