In [1]:
import json
import pathlib
import pickle
import random
import re
import unicodedata

import numpy as np
import pandas as pd
import torch

In [3]:
# !pip install tensorflow==2.13.0

In [4]:
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, pipeline, AutoTokenizer, TextClassificationPipeline, AutoModelForSequenceClassification
from sentence_transformers import SentenceTransformer

2024-01-26 16:39:42.381705: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-01-26 16:39:42.430216: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [5]:
def name_to_keep_ind(groups):
    """Return 1 when none of `groups` is a poorly-handled script, otherwise 0.

    `groups` is an iterable of script prefixes, i.e. the first word of
    `unicodedata.name(char)` (as produced by `group_non_latin_characters`).
    """
    # Script prefixes whose text the downstream models do not handle well
    poorly_handled = ('HIRAGANA', 'CJK', 'KATAKANA', 'ARABIC', 'HANGUL', 'THAI', 'DEVANAGARI',
                      'BENGALI', 'THAANA', 'GUJARATI', 'CYRILLIC')
    has_poor_script = any(group in poorly_handled for group in groups)
    return 0 if has_poor_script else 1
    
def remove_non_latin_characters(text):
    """Drop characters belonging to the poorly-handled scripts from `text`.

    Despite the name, only the scripts listed below are removed; other
    non-Latin scripts (e.g. Greek) are kept. Characters that have no Unicode
    name (control characters such as '\\n' or '\\t') are dropped as well.

    Returns the remaining characters joined into a new string.
    """
    groups_to_skip = ('HIRAGANA', 'CJK', 'KATAKANA', 'ARABIC', 'HANGUL', 'THAI', 'DEVANAGARI',
                      'BENGALI', 'THAANA', 'GUJARATI', 'CYRILLIC')
    kept = []
    for char in text:
        try:
            script = unicodedata.name(char).split(" ")[0]
        except ValueError:
            # unicodedata.name raises ValueError for unnamed characters; skip them
            # (previously a bare `except:` that also hid unrelated errors).
            continue
        if script not in groups_to_skip:
            kept.append(char)
    return "".join(kept)
    
def group_non_latin_characters(text):
    """Summarise which Unicode scripts appear in `text`.

    Dots and spaces are ignored. Each remaining character is classified by the
    first word of its Unicode name; characters without a name count as "UNK".

    Returns:
        (groups, latin_count): `groups` lists each non-LATIN prefix once, in
        first-seen order; `latin_count` is the number of LATIN characters.
    """
    groups = []
    latin_count = 0
    for char in text.replace(".", "").replace(" ", ""):
        try:
            script = unicodedata.name(char).split(" ")[0]
        except ValueError:
            # No Unicode name (e.g. control characters)
            script = "UNK"
        if script == 'LATIN':
            latin_count += 1
        elif script not in groups:
            groups.append(script)
    return groups, latin_count

def check_for_non_latin_characters(text):
    """Return 1 if `text` should be kept for the models, else 0.

    Despite the name, 1 means "keep": the text (coerced with `str`) either
    contains none of the poorly-handled scripts, or contains more than 20
    Latin-script characters (dots and spaces ignored).
    """
    script_groups, latin_count = group_non_latin_characters(str(text))
    if name_to_keep_ind(script_groups) == 1 or latin_count > 20:
        return 1
    return 0

In [6]:
# Length of the zero-vector fallback; presumably equals emb_model's output size — TODO confirm
_JOURNAL_EMB_DIM = 384


def get_journal_emb(journal_name):
    """Embed a journal/venue name for use as a model feature.

    Returns a float32 zero vector of length 384 when the name is not a string,
    is blank after stripping, contains 'eBooks' (mostly non-descriptive names),
    or is dominated by scripts the embedding model handles poorly. Otherwise
    returns `emb_model.encode(name)`.

    NOTE(review): `emb_model` is a module-level object defined outside this
    cell (presumably the imported SentenceTransformer) — confirm.
    """
    if not isinstance(journal_name, str):
        return np.zeros(_JOURNAL_EMB_DIM, dtype=np.float32)

    journal_name = journal_name.strip()

    # Test for blanks before the script check: check_for_non_latin_characters('')
    # returns 1, which would otherwise send an empty string to the encoder.
    if journal_name == '' or 'eBooks' in journal_name:
        return np.zeros(_JOURNAL_EMB_DIM, dtype=np.float32)

    if check_for_non_latin_characters(journal_name) == 1:
        return emb_model.encode(journal_name)

    return np.zeros(_JOURNAL_EMB_DIM, dtype=np.float32)

In [7]:
def save_pickle(dictionary, file_path):
    """Serialise `dictionary` (any picklable object) to `file_path`, overwriting it."""
    with open(file_path, mode='wb') as handle:
        pickle.Pickler(handle).dump(dictionary)
        
def open_pickle(pickle_path):
    """Load and return the object stored in the pickle file at `pickle_path`."""
    # NOTE: unpickling can execute arbitrary code — only load trusted files.
    with open(pickle_path, mode='rb') as handle:
        return pickle.Unpickler(handle).load()

In [47]:
def tokenize(seq, **kwargs):
    """Tokenise `seq`, padded/truncated to 512 tokens, and return [input_ids, attention_mask].

    Extra keyword arguments (e.g. `return_tensors='pt'`) are forwarded to the tokenizer.
    NOTE(review): `tokenizer` is a module-level object defined outside this cell.
    """
    encoded = tokenizer(seq, max_length=512, padding='max_length', truncation=True, **kwargs)
    input_ids = encoded['input_ids']
    attention_mask = encoded['attention_mask']
    return [input_ids, attention_mask]

In [9]:
def move_level_0_to_1(level_0, level_1):
    """Return the de-duplicated union of two lists as a list (order is not guaranteed)."""
    unique_ids = set(level_0 + level_1)
    return list(unique_ids)

In [10]:
def get_final_citations_for_model(list_of_links, num_to_take):
    """Return `list_of_links` itself if it has at most `num_to_take` items,
    otherwise a random sample (without replacement) of `num_to_take` of them."""
    needs_sampling = len(list_of_links) > num_to_take
    return random.sample(list_of_links, num_to_take) if needs_sampling else list_of_links

def get_final_citations_feature(citations, num_to_keep):
    """Build a fixed-length citation id feature of `num_to_keep` entries.

    - Empty/falsy `citations`: returns [1, 0, 0, ...] (id 1 flags "no citations").
    - Otherwise: samples at most `num_to_keep` citations, maps each through
      `gold_to_label_mapping` (dropping falsy/missing mappings), converts them
      with `citation_feature_vocab`, and right-pads with 0.

    NOTE(review): if citations exist but none map, the result is all zeros
    rather than the [1, 0, ...] "no citations" marker — confirm this is intended.
    `gold_to_label_mapping` / `citation_feature_vocab` are module-level objects
    defined outside this cell.
    """
    if not citations:
        return [1] + [0] * (num_to_keep - 1)

    sampled = get_final_citations_for_model(citations, num_to_keep)
    mapped = [gold_to_label_mapping.get(cite) for cite in sampled]
    feature = [citation_feature_vocab[label] for label in mapped if label]

    missing = num_to_keep - len(feature)
    if missing > 0:
        return feature + [0] * missing
    return feature

In [11]:
def merge_title_and_abstract(title, abstract):
    """Format title and abstract into the tagged text fed to the language model.

    The abstract is used only if it is a string of at least 30 characters, and
    is cut to 2500 characters. A non-string title becomes "NONE" when there is
    a usable abstract; with neither, an empty string is returned.
    """
    usable_abstract = isinstance(abstract, str) and len(abstract) >= 30

    if not isinstance(title, str):
        if usable_abstract:
            return f"<TITLE> NONE\n<ABSTRACT> {abstract[:2500]}"
        return ""

    if usable_abstract:
        return f"<TITLE> {title}\n<ABSTRACT> {abstract[:2500]}"
    return f"<TITLE> {title}"

In [12]:
# Literal markup fragments stripped from titles, applied in exactly this order.
# Order matters (e.g. "<![CDATA[<B>" must go before the bare "<B>"). "<em>" and
# "<title>" appear twice, as in the original chain: the second pass removes
# occurrences exposed by the first.
_TITLE_LITERAL_TAGS = (
    "<i>", "</i>",
    "<sub>", "</sub>",
    "<sup>", "</sup>",
    "<em>", "</em>",
    "<b>", "</b>",
    "<I>", "</I>",
    "<SUB>", "</SUB>",
    "<scp>", "</scp>",
    "<font>", "</font>",
    "<inf>", "</inf>",
    "<i /> ",
    "<p>", "</p>",
    "<![CDATA[<B>", "</B>]]>",
    "<italic>", "</italic>",
    "<title>", "</title>",
    "<br>", "</br>", "<br/>",
    "<B>", "</B>",
    "<em>", "</em>",
    "<BR>", "</BR>",
    "<title>", "</title>",
    "<strong>", "</strong>",
    "<formula>", "</formula>",
    "<roman>", "</roman>",
    "<SUP>", "</SUP>",
    "<SSUP>", "</SSUP>",
    "<sc>", "</sc>",
    "<subtitle>", "</subtitle>",
    "<emph/>", "<emph>", "</emph>",
    """<p class="Body">""",
    "<TITLE>", "</TITLE>",
    "<sub />", "<sub/>",
    "<mi>", "</mi>",
    "<bold>", "</bold>",
    "<mtext>", "</mtext>",
    "<msub>", "</msub>",
    "<mrow>", "</mrow>",
    "</mfenced>", "</math>",
)

# (substring that triggers the regex, regex removing the whole element/opening tag)
_TITLE_REGEX_STRIPS = (
    ('<xref', r"\<xref[^/]*\/xref\>"),
    ('<inline-formula', r"\<inline-formula[^/]*\/inline-formula\>"),
    ('<title', r"\<title[^/]*\/title\>"),
    ('<p class=', r"\<p class=[^>]*\>"),
    ('<span class=', r"\<span class=[^>]*\>"),
    ('mfenced open', r"\<mfenced open=[^>]*\>"),
    ('math xmlns', r"\<math xmlns=[^>]*\>"),
)

# Leftover fragments removed in a final pass
_TITLE_FINAL_FRAGMENTS = (">i<", ">/i<", ">b<", ">/b<", "<inline-formula>", "</inline-formula>")

# Text run between two MathML tags, e.g. ">x<" in "<mml:mi>x</mml:mi>"
_MML_TEXT_RUN = r"\>[$%#!^*\w.,/()+-]*\<"


def _flatten_mathml(title):
    """Replace <mml:math> fragments with the text found between their tags (space-padded)."""
    parts = [x for y in [i.split("mml:math>") for i in title.split("<mml:math")] for x in y if x]
    flattened = []
    for part in parts:
        runs = re.findall(_MML_TEXT_RUN, part)
        if runs:
            text = "".join(run.replace(">", "").replace("<", "") for run in runs)
            flattened.append(" " + text + " ")
        else:
            flattened.append(part)
    return "".join(flattened).strip()


def clean_title(old_title):
    """Clean a work title for the model.

    Returns '' when the title is not a string (e.g. None/NaN, which previously
    crashed) or is dominated by poorly-handled scripts. Otherwise it removes
    characters of those scripts, then strips HTML/JATS/MathML markup.
    """
    if not isinstance(old_title, str):
        return ''
    if check_for_non_latin_characters(old_title) != 1:
        return ''

    new_title = remove_non_latin_characters(old_title)

    if '<' in new_title:
        for tag in _TITLE_LITERAL_TAGS:
            new_title = new_title.replace(tag, "")

        if '<mml' in new_title:
            new_title = _flatten_mathml(new_title)

        # Each check looks at the title as already modified by the previous steps.
        for marker, pattern in _TITLE_REGEX_STRIPS:
            if marker in new_title:
                new_title = re.sub(pattern, "", new_title)

    if '<' in new_title:
        for fragment in _TITLE_FINAL_FRAGMENTS:
            new_title = new_title.replace(fragment, "")

    return new_title

In [13]:
def _fill_inverted_index(inverted_index, size):
    """Rebuild a word list of length `size` from a {word: [positions]} mapping.

    Unclaimed slots stay " ". Positions that do not fit or are not valid list
    indices are skipped instead of failing the whole row.
    """
    words = [" "] * size
    for word, positions in inverted_index.items():
        for pos in positions:
            try:
                words[pos] = word
            except (IndexError, TypeError):
                continue
    return words


def clean_abstract(raw_abstract, inverted=False):
    """Return a cleaned abstract (at most 2500 characters) or None.

    inverted=False: `raw_abstract` is plain text; it is kept only if it is a
    string longer than 30 characters (None/NaN previously raised TypeError).

    inverted=True: `raw_abstract` is an inverted index, given as a dict or a
    JSON string, in one of two layouts:
      - {"IndexLength": n, "InvertedIndex": {word: [positions]}}; needs n > 15.
      - {word: [positions]}; needs more than 15 distinct words and is rebuilt
        into 1200 slots, dropping positions beyond that.
    Malformed JSON or a non-dict payload yields None.
    """
    if not inverted:
        if isinstance(raw_abstract, str) and len(raw_abstract) > 30:
            return raw_abstract[:2500]
        return None

    if isinstance(raw_abstract, dict):
        invert_abstract = raw_abstract
    elif isinstance(raw_abstract, str):
        try:
            invert_abstract = json.loads(raw_abstract)
        except json.JSONDecodeError:
            return None
    else:
        return None

    if not isinstance(invert_abstract, dict):
        return None

    if invert_abstract.get('IndexLength'):
        ab_len = invert_abstract['IndexLength']
        if ab_len <= 15:
            return None
        words = _fill_inverted_index(invert_abstract['InvertedIndex'], ab_len)
    else:
        if len(invert_abstract) <= 15:
            return None
        words = _fill_inverted_index(invert_abstract, 1200)

    return " ".join(words)[:2500]

In [14]:
def create_input_feature(features):
    """Turn one row's [citation_0, citation_1, journal_emb] into batch-of-one tensors.

    The two citation id lists become int32 numpy arrays; the journal embedding
    is passed to `tf.convert_to_tensor` unchanged. Each tensor gets a leading
    axis of size 1 via `tf.expand_dims(..., axis=0)`.
    """
    citation_0_ids = np.array(features[0], dtype=np.int32)
    citation_1_ids = np.array(features[1], dtype=np.int32)
    journal_emb = features[2]

    batched = []
    for part in (citation_0_ids, citation_1_ids, journal_emb):
        batched.append(tf.expand_dims(tf.convert_to_tensor(part), axis=0))
    return batched

In [15]:
def get_gold_citations_from_all_citations(all_citations, gold_dict, non_gold_dict):
    """Split a work's citations into level-0 gold ids and level-1 linked ids.

    Non-list input yields ([], []). More than 200 citations are randomly sampled
    down to 200.

    Returns:
        level_0_gold: citations with a truthy entry in `gold_dict`, in order.
        level_1_gold: the flattened `gold_dict` values of those citations,
            followed by the flattened truthy `non_gold_dict` values of all citations.
    """
    if not isinstance(all_citations, list):
        return [], []
    if len(all_citations) > 200:
        all_citations = random.sample(all_citations, 200)

    level_0_gold = []
    linked_via_gold = []
    for cite in all_citations:
        linked = gold_dict.get(cite)
        if linked:
            level_0_gold.append(cite)
            linked_via_gold.extend(linked)

    linked_via_non_gold = []
    for cite in all_citations:
        linked = non_gold_dict.get(cite)
        if linked:
            linked_via_non_gold.extend(linked)

    return level_0_gold, linked_via_gold + linked_via_non_gold

In [16]:
def get_lang_model_output(input_ids, attention_mask):
    """
    Return the last hidden layer of the XLA-compiled language model.

    Input:
    input_ids: tokenized title/abstract
    attention_mask: attention mask for the same tokens

    Output:
    `.hidden_states[-1]` of `xla_predict_lang_model`'s output
    (NOTE(review): `xla_predict_lang_model` is defined outside this cell).
    """
    model_output = xla_predict_lang_model(input_ids=input_ids, attention_mask=attention_mask)
    all_hidden_states = model_output.hidden_states
    return all_hidden_states[-1]
    

def create_model(num_classes, emb_table_size, model_chkpt, topk=5):
    """
    Build the topic-classification head, load its trained weights and freeze it.

    Four inputs are combined:
      - citation_0: 16 citation ids (int64)
      - citation_1: 128 citation ids (int64)
      - journal_emb: 384-d journal-name embedding (float32)
      - lang_model_output: (512, 768) per-token language-model states (float32)
    Both citation inputs share one embedding table. Every sequence input is
    mean-pooled; the pooled vectors are concatenated and passed through three
    Dense -> Dropout -> LayerNorm stages and a sigmoid output layer.

    Input:
    num_classes: number of classes (width of the sigmoid output layer)
    emb_table_size: number of rows in the shared citation embedding table
    model_chkpt: path of the weights restored with `model.load_weights`
    topk: number of highest-scoring classes to return

    Output:
    model: frozen tf.keras.Model whose output is the `tf.math.top_k` result of
           the sigmoid scores (`.values` = scores, `.indices` = class ids)
    """
    # Inputs
    # NOTE(review): citation inputs are int64 while create_input_feature builds int32 arrays — confirm the cast is handled.
    citation_0 = tf.keras.layers.Input((16,), dtype=tf.int64, name='citation_0')
    citation_1 = tf.keras.layers.Input((128,), dtype=tf.int64, name='citation_1')
    journal = tf.keras.layers.Input((384,), dtype=tf.float32, name='journal_emb')
    language_model_output = tf.keras.layers.Input((512, 768,), dtype=tf.float32, name='lang_model_output')
    
    # Functional-API model; the sigmoid output below gives independent per-class (multi-label) scores.
    # No mask is attached to this input, so the mean includes padded token positions.
    pooled_language_model_output = tf.keras.layers.GlobalAveragePooling1D()(language_model_output)
    # Shared by both citation inputs; mask_zero=True treats id 0 as padding and masks it out of the pooling.
    citation_emb_layer = tf.keras.layers.Embedding(input_dim=emb_table_size, output_dim=256, mask_zero=True, 
                                                   trainable=True, name='citation_emb_layer')

    citation_0_emb = citation_emb_layer(citation_0)
    citation_1_emb = citation_emb_layer(citation_1)

    pooled_citation_0 = tf.keras.layers.GlobalAveragePooling1D()(citation_0_emb)
    pooled_citation_1 = tf.keras.layers.GlobalAveragePooling1D()(citation_1_emb)

    # 768 + 256 + 256 + 384 features per example
    concat_data = tf.keras.layers.Concatenate(name='concat_data', axis=-1)([pooled_language_model_output, pooled_citation_0, 
                                                                            pooled_citation_1, journal])

    # Dense layer 1
    dense_output = tf.keras.layers.Dense(2048, activation='relu', kernel_regularizer='L2', name="dense_1")(concat_data)
    dense_output = tf.keras.layers.Dropout(0.20, name="dropout_1")(dense_output)
    dense_output = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="layer_norm_1")(dense_output)
    
    # Dense layer 2
    dense_output = tf.keras.layers.Dense(1024, activation='relu', kernel_regularizer='L2', name="dense_2")(dense_output)
    dense_output = tf.keras.layers.Dropout(0.20, name="dropout_2")(dense_output)
    dense_output = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="layer_norm_2")(dense_output)

    # Dense layer 3
    dense_output_l3 = tf.keras.layers.Dense(512, activation='relu', kernel_regularizer='L2', name="dense_3")(dense_output)
    dense_output = tf.keras.layers.Dropout(0.20, name="dropout_3")(dense_output_l3)
    dense_output = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="layer_norm_3")(dense_output)
    
    output_layer = tf.keras.layers.Dense(num_classes, activation='sigmoid', name='output_layer')(dense_output)
    # Returns (values, indices) of the topk highest sigmoid scores
    topk_outputs = tf.math.top_k(output_layer, k=topk)
    
    model = tf.keras.Model(inputs=[citation_0, citation_1, journal, language_model_output], 
                           outputs=topk_outputs)

    # load_weights matches the checkpoint against this layer structure/naming — likely breaks if layers are reordered or renamed.
    model.load_weights(model_chkpt)
    # Inference only
    model.trainable = False

    return model

def _no_prediction():
    """Sentinel returned when no topic should be assigned."""
    return [-1], [0.0], [None]


def _keep_where(topic_ids, score, labels, keep):
    """Filter the parallel id/score/label lists with ``keep(topic_id, score)``.

    All three lists are filtered together so they always stay aligned.
    Returns the no-prediction sentinel when nothing survives.
    """
    kept = [(t, s, lab) for t, s, lab in zip(topic_ids, score, labels) if keep(t, s)]
    if not kept:
        return _no_prediction()
    return [k[0] for k in kept], [k[1] for k in kept], [k[2] for k in kept]


def get_final_ids_and_scores_bad(topic_ids, score, labels, title, abstract, threshold=0.04):
    """
    Apply per-cluster rules to the raw predictions (some clusters performed worse than others).

    Rules, checked in order:
      - 13241 predicted: no prediction.
      - 12705/13003 predicted: kept only for usable titles of more than 9 words
        whose abstract (if any) is also usable; then filter by `threshold`.
      - 12718/14377/13686/13723 predicted: drop those ids; keep other ids scoring > 0.80.
      - 13064/13537 predicted: none for title 'Frontmatter'; otherwise those ids
        need > 0.95 and other ids need > `threshold`.
      - 11893/13459 predicted: if one of them is the top prediction, at least one
        of them must score > 0.95; then filter by `threshold`.
      - Otherwise: the available title/abstract must be usable; filter by `threshold`.
    "Usable" means check_for_non_latin_characters(...) == 1.

    Fixes compared to the previous version: the 11893/13459 rule raised
    TypeError (`list & bool`), and two branches filtered scores with 0.05 while
    ids/labels used `threshold`, which misaligned the returned lists.

    Input:
    topic_ids: all ids for raw prediction output
    score: all scores for raw prediction output (parallel to topic_ids)
    labels: all labels for raw prediction output (parallel to topic_ids)
    title: cleaned title of the work ('' when missing)
    abstract: abstract of the work (non-str when missing)
    threshold: minimum score for the generic filters

    Output:
    (final_ids, final_scores, final_labels), or ([-1], [0.0], [None]) for no prediction
    """
    def above_threshold(_topic_id, topic_score):
        return topic_score > threshold

    if 13241 in topic_ids:
        return _no_prediction()

    if any(t in topic_ids for t in (12705, 13003)):
        long_usable_title = (title != ''
                             and check_for_non_latin_characters(title) == 1
                             and len(title.split(" ")) > 9)
        if not long_usable_title:
            return _no_prediction()
        if isinstance(abstract, str) and check_for_non_latin_characters(abstract) != 1:
            return _no_prediction()
        return _keep_where(topic_ids, score, labels, above_threshold)

    low_trust = (12718, 14377, 13686, 13723)
    if any(t in topic_ids for t in low_trust):
        return _keep_where(topic_ids, score, labels,
                           lambda t, s: t not in low_trust and s > 0.80)

    strict = (13064, 13537)
    if any(t in topic_ids for t in strict):
        if title == 'Frontmatter':
            return _no_prediction()
        return _keep_where(topic_ids, score, labels,
                           lambda t, s: s > 0.95 if t in strict else s > threshold)

    gated = (11893, 13459)
    if any(t in topic_ids for t in gated):
        if topic_ids[0] in gated:
            gated_scores = [s for t, s in zip(topic_ids, score) if t in gated]
            if not any(s > 0.95 for s in gated_scores):
                return _no_prediction()
        return _keep_where(topic_ids, score, labels, above_threshold)

    has_title = title != ''
    has_abstract = isinstance(abstract, str)
    if has_title and has_abstract:
        usable = (check_for_non_latin_characters(title) == 1
                  and check_for_non_latin_characters(abstract) == 1)
    elif has_title:
        usable = check_for_non_latin_characters(title) == 1
    elif has_abstract:
        usable = check_for_non_latin_characters(abstract) == 1
    else:
        usable = False

    if not usable:
        return _no_prediction()
    return _keep_where(topic_ids, score, labels, above_threshold)

def process_data_as_df(new_df):
    """
    Function to process data as a dataframe (in batch).

    Cleans titles/abstracts, builds citation and journal features, runs the
    language model and the classifier on rows with enough data, and attaches
    the raw top-k predictions.

    Fixes compared to the previous version:
      - The language model used to run on every row while the other model
        inputs held only the rows to score, so their row counts disagreed
        whenever a row was skipped. It now runs only on the rows to score.
      - With nothing to score, the empty frame was given lists sized by the
        unscored rows (ValueError); it now gets empty columns.

    NOTE(review): relies on module-level objects defined outside this cell:
    `pt_model`, `tokenizer` (via `tokenize`), `xla_predict`, `gold_dict`,
    `non_gold_dict`, `gold_to_label_mapping`, `citation_feature_vocab`, `emb_model`.

    Input:
    new_df: dataframe with columns UID, title, abstract_inverted_index,
            inverted, referenced_works, journal_display_name

    Output:
    dataframe with columns UID, title, abstract_inverted_index (cleaned),
    preds and scores; rows without usable data get preds [-1] and scores [0.0]
    """
    input_df = new_df.copy()

    # OpenAlex work URLs -> integer work ids
    input_df['referenced_works'] = input_df['referenced_works'].apply(
        lambda x: [int(i.split("https://openalex.org/W")[1]) for i in x])

    # Clean title and abstract and build the language-model text
    input_df['title'] = input_df['title'].apply(clean_title)
    input_df['abstract_inverted_index'] = input_df.apply(
        lambda x: clean_abstract(x.abstract_inverted_index, x.inverted), axis=1)
    title_abstract = input_df.apply(
        lambda x: merge_title_and_abstract(x.title, x.abstract_inverted_index), axis=1).tolist()

    # Take citations and return only gold citations (and then convert to label ids)
    input_df['referenced_works'] = input_df['referenced_works'].apply(
        lambda x: get_gold_citations_from_all_citations(x, gold_dict, non_gold_dict))
    input_df['citation_0'] = input_df['referenced_works'].apply(lambda x: get_final_citations_feature(x[0], 16))
    input_df['citation_1'] = input_df['referenced_works'].apply(lambda x: get_final_citations_feature(x[1], 128))

    # Take in journal name and output journal embedding
    input_df['journal_emb'] = input_df['journal_display_name'].apply(get_journal_emb)

    # A row is skipped only when it has no title, no abstract and neither citation
    # feature has data (first entry 1 is the "no citations" marker).
    input_df['score_data'] = input_df.apply(
        lambda x: 0 if (x.title == ""
                        and not x.abstract_inverted_index
                        and x.citation_0[0] == 1
                        and x.citation_1[0] == 1) else 1, axis=1)

    score_mask = (input_df['score_data'] == 1).to_numpy()
    data_to_score = input_df[score_mask].copy()
    data_to_not_score = input_df[~score_mask][['UID']].copy()

    if data_to_score.shape[0] > 0:
        # Language-model features for the scored rows only, in the same row order
        texts_to_score = [text for text, keep in zip(title_abstract, score_mask) if keep]
        tok_inputs_pt = tokenize(texts_to_score, return_tensors='pt')
        with torch.no_grad():
            last_output = pt_model(*tok_inputs_pt).hidden_states[-1]
        lang_model_output = last_output.numpy()

        # Transform into output for model
        data_to_score['input_feature'] = data_to_score.apply(
            lambda x: create_input_feature([x.citation_0, x.citation_1, x.journal_emb]), axis=1)
        input_features = data_to_score['input_feature'].tolist()

        all_rows = [tf.convert_to_tensor([x[0][0] for x in input_features]),
                    tf.convert_to_tensor([x[1][0] for x in input_features]),
                    tf.convert_to_tensor([x[2][0] for x in input_features]),
                    tf.convert_to_tensor(lang_model_output)]

        preds = xla_predict(all_rows)

        data_to_score['preds'] = preds.indices.numpy().tolist()
        data_to_score['scores'] = preds.values.numpy().tolist()
    else:
        # Nothing to score: give the empty frame empty prediction columns
        data_to_score['preds'] = []
        data_to_score['scores'] = []

    data_to_not_score['preds'] = [[-1]] * data_to_not_score.shape[0]
    data_to_not_score['scores'] = [[0.0000]] * data_to_not_score.shape[0]

    all_predictions = pd.concat([data_to_score[['UID', 'preds', 'scores']],
                                 data_to_not_score[['UID', 'preds', 'scores']]], axis=0)
    return input_df[['UID', 'title', 'abstract_inverted_index']].merge(all_predictions, how='left', on='UID')

def last_pred_check(old_preds, old_scores, old_labels):
    """
    Trim the predictions to the confident head of the list, based on scores.

    Rules:
      - If any score is > 0.9, keep exactly those predictions.
      - One prediction: keep it.
      - Two predictions: drop the second if its score is below half the first.
      - Otherwise: walk down from the first prediction and stop at the first
        score below 0.85 x the mean of the scores kept so far.
    Empty input returns ([], [], []) instead of raising IndexError.

    Input:
    old_preds: all ids for prediction output
    old_scores: all scores for prediction output (assumed sorted descending, as from top-k)
    old_labels: all labels for prediction output

    Output:
    final_ids, final_scores, final_labels: post-processed parallel lists
    """
    pred_scores = [[p, s, lab] for p, s, lab in zip(old_preds, old_scores, old_labels)]
    if not pred_scores:
        return [], [], []

    confident = [row for row in pred_scores if row[1] > 0.9]
    if confident:
        kept = confident
    elif len(pred_scores) == 1:
        kept = pred_scores
    elif len(pred_scores) == 2:
        first_score, second_score = pred_scores[0][1], pred_scores[1][1]
        kept = pred_scores[:1] if second_score < first_score / 2 else pred_scores
    else:
        kept = pred_scores[:1]
        running_sum = pred_scores[0][1]
        for row in pred_scores[1:]:
            # Compare against 85% of the running mean of the kept scores
            if row[1] < running_sum / len(kept) * 0.85:
                break
            kept.append(row)
            running_sum += row[1]

    return [row[0] for row in kept], [row[1] for row in kept], [row[2] for row in kept]

In [17]:
# from typing import Dict

In [18]:
# Testing a custom pipeline for speed
class CustomHiddenOutputPipeline(TextClassificationPipeline):
    """Text-classification pipeline that returns the last hidden layer instead of labels.

    The model is loaded with `output_hidden_states=True` and the tokenizer from
    the same `model_path`. Plain-text inputs are padded/truncated to 512 tokens.

    The previous version was not valid Python: the `tokenizer=` argument had been
    spliced into `preprocess`. It is restored here.
    """

    def __init__(self, model_path, *args, **kwargs):
        super().__init__(
            model=TFAutoModelForSequenceClassification.from_pretrained(model_path, output_hidden_states=True),
            tokenizer=AutoTokenizer.from_pretrained(model_path),
            *args,
            **kwargs
        )
        # Inference only
        self.model.trainable = False

    def preprocess(self, inputs, **tokenizer_kwargs):
        return_tensors = self.framework
        if isinstance(inputs, dict):
            return self.tokenizer(**inputs, return_tensors=return_tensors, **tokenizer_kwargs)
        elif isinstance(inputs, list) and len(inputs) == 1 and isinstance(inputs[0], list) and len(inputs[0]) == 2:
            # It used to be valid to use a list of list of list for text pairs, keeping this path for BC
            return self.tokenizer(
                text=inputs[0][0], text_pair=inputs[0][1], return_tensors=return_tensors, **tokenizer_kwargs
            )
        elif isinstance(inputs, list):
            # This is likely an invalid usage of the pipeline attempting to pass text pairs.
            raise ValueError(
                "The pipeline received invalid inputs, if you are trying to send text pairs, you can try to send a"
                ' dictionary `{"text": "My text", "text_pair": "My pair"}` in order to send a text pair.'
            )
        # Fixed-length inputs (512 tokens) for the model
        return self.tokenizer(inputs, return_tensors=return_tensors, max_length=512, truncation=True, padding='max_length')

    def _forward(self, model_inputs):
        # Forward
        outputs = self.model(**model_inputs)
        # Tuple of hidden states (available because output_hidden_states=True)
        hidden_state = outputs.hidden_states

        return {
            "hidden_state": hidden_state
        }

    def postprocess(self, model_outputs, **postprocess_params):
        # Last layer only; extra postprocess params from the base class are ignored
        outputs = model_outputs["hidden_state"][-1].numpy()
        return {"hidden_state": outputs}

In [106]:
# Hidden-state pipeline built directly from the Hugging Face model id.
test_pipeline = CustomHiddenOutputPipeline(
    'OpenAlex/bert-base-multilingual-cased-finetuned-openalex-topic-classification-title-abstract'
)

Some layers from the model checkpoint at OpenAlex/bert-base-multilingual-cased-finetuned-openalex-topic-classification-title-abstract were not used when initializing TFBertForSequenceClassification: ['dropout_37']
- This IS expected if you are initializing TFBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFBertForSequenceClassification were initialized from the model checkpoint at OpenAlex/bert-base-multilingual-cased-finetuned-openalex-topic-classification-title-abstract.
If your task is similar to the task the model of the checkpoint was trained on, you can already use

In [41]:
##### The following are empty inputs for each feature (title, abstract, citation_0, citation_1, journal_emb)
# [101, 102] + [0]*510
# [1, 1] + [0]*510
# [1]+[0]*15,
# [1]+[0]*127, 
# np.zeros(384, dtype=np.float32)

artifacts_folder = './full_model_iter6' # Change this to location of model artifacts

# NOTE(review): open_pickle presumably unpickles these files; unpickling can execute
# arbitrary code, so only point `artifacts_folder` at trusted artifacts.
target_vocab = open_pickle(f'{artifacts_folder}/model_artifacts/target_vocab.pkl')
inv_target_vocab = open_pickle(f'{artifacts_folder}/model_artifacts/inv_target_vocab.pkl')
citation_feature_vocab = open_pickle(f'{artifacts_folder}/model_artifacts/citation_feature_vocab.pkl')
gold_to_label_mapping = open_pickle(f'{artifacts_folder}/model_artifacts/gold_to_id_mapping_dict.pkl')
emb_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
language_model_name = \
    "OpenAlex/bert-base-multilingual-cased-finetuned-openalex-topic-classification-title-abstract"
# Truncation/padding are per-call tokenizer arguments; `from_pretrained` silently ignored
# the previous `truncate=True, pad=True`, so they are not passed here.
tokenizer = AutoTokenizer.from_pretrained(language_model_name)
gold_dict = open_pickle(f'{artifacts_folder}/model_artifacts/gold_citations_dict.pkl')
non_gold_dict = open_pickle(f'{artifacts_folder}/model_artifacts/non_gold_citations_dict.pkl')

In [77]:
# Rebuild the hidden-state pipeline from the shared model-name constant.
test_pipeline = CustomHiddenOutputPipeline(model_path=language_model_name)

Some layers from the model checkpoint at OpenAlex/bert-base-multilingual-cased-finetuned-openalex-topic-classification-title-abstract were not used when initializing TFBertForSequenceClassification: ['dropout_37']
- This IS expected if you are initializing TFBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFBertForSequenceClassification were initialized from the model checkpoint at OpenAlex/bert-base-multilingual-cased-finetuned-openalex-topic-classification-title-abstract.
If your task is similar to the task the model of the checkpoint was trained on, you can already use

In [68]:
# Loading the models
# `language_model_name` comes from the artifact-loading cell; it is not redefined here so
# the model id has a single source of truth.

# NOTE(review): the `+2` presumably reserves extra ids beyond the citation vocab
# (e.g. padding/unknown) — confirm against create_model.
pred_model = create_model(len(target_vocab), 
                          len(citation_feature_vocab)+2,
                          "citation_part_only.keras", topk=5)
xla_predict = tf.function(pred_model, jit_compile=True)

language_model = TFAutoModelForSequenceClassification.from_pretrained(language_model_name, output_hidden_states=True)
# Inference only: freeze the weights before wrapping the model in an XLA-compiled tf.function.
language_model.trainable = False
xla_predict_lang_model = tf.function(language_model, jit_compile=True)

Some layers from the model checkpoint at OpenAlex/bert-base-multilingual-cased-finetuned-openalex-topic-classification-title-abstract were not used when initializing TFBertForSequenceClassification: ['dropout_37']
- This IS expected if you are initializing TFBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFBertForSequenceClassification were initialized from the model checkpoint at OpenAlex/bert-base-multilingual-cased-finetuned-openalex-topic-classification-title-abstract.
If your task is similar to the task the model of the checkpoint was trained on, you can already use

In [106]:
from optimum.bettertransformer import BetterTransformer

In [22]:
# Sample input in the "<TITLE> ... <ABSTRACT> ..." text layout used elsewhere in this notebook.
test_text = """<TITLE> The Shape of the Olfactory Bulb Predicts Olfactory Function
<ABSTRACT> The olfactory bulb (OB) plays a key role in the processing of olfactory information. A large body of research has shown that OB volumes correlate with olfactory function, which provides diagnostic and prognostic information in olfactory dysfunction. Still, the potential value of the OB shape remains unclear. Based on our clinical experience we hypothesized that the shape of the OB predicts olfactory function, and that it is linked to olfactory loss, age, and gender. The aim of this study was to produce a classification of OB shape in the human brain, scalable to clinical and research applications. Results from patients with the five most frequent causes of olfactory dysfunction (n = 192) as well as age/gender-matched healthy controls (n = 77) were included. Olfactory function was examined in great detail using the extended "Sniffin' Sticks" test. A high-resolution structural T2-weighted MRI scan was obtained for all. The planimetric contours (surface in mm2) of OB were delineated manually, and then all surfaces were added and multiplied to obtain the OB volume in mm3. OB shapes were outlined manually and characterized on a selected slice through the posterior coronal plane tangential to the eyeballs. We looked at OB shapes in terms of convexity and defined two patterns/seven categories based on OB contours: convex (olive, circle, and plano-convex) and non-convex (banana, irregular, plane, and scattered). Categorization of OB shapes is possible with a substantial inter-rater agreement (Cohen's Kappa = 0.73). Our results suggested that non-convex OB patterns were significantly more often observed in patients than in controls. OB shapes were correlated with olfactory function in the whole group, independent of age, gender, and OB volume. OB shapes seemed to change with age in healthy subjects. 
Importantly, the results indicated that OB shapes were associated with certain causes of olfactory disorders, i.e., an irregular OB shape was significantly more often observed in post-traumatic olfactory loss. Our study provides evidence that the shape of the OB can be used as a biomarker for olfactory dysfunction."""

In [48]:
# Tokenize the same 50-character prefix for PyTorch ('pt') first, then TensorFlow ('tf').
tok_inputs_pt, tok_inputs_tf = (
    tokenize(test_text[:50], return_tensors=framework) for framework in ('pt', 'tf')
)

In [126]:
# PyTorch counterpart of the language model, also exposing per-layer hidden states.
pt_model = AutoModelForSequenceClassification.from_pretrained(
    language_model_name,
    output_hidden_states=True,
)

config.json:   0%|          | 0.00/617k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/725M [00:00<?, ?B/s]

In [24]:
import requests

In [25]:
%%time
open_req = "https://api.openalex.org/works/W4205779344"
# Bound the wait on the network and surface HTTP errors instead of parsing an error body.
response = requests.get(open_req, timeout=30)
response.raise_for_status()
resp = response.json()
print(resp['id'])

# Guard against a null `primary_location` as well as a null `source`.
primary_source = (resp.get('primary_location') or {}).get('source')
journal_display_name = primary_source['display_name'] if primary_source else ""


input_json = [{'title': resp['title'], 
               'abstract_inverted_index': resp['abstract_inverted_index'], 
               'journal_display_name': journal_display_name, 
               'referenced_works': resp['referenced_works'],
               'inverted': True}]

https://openalex.org/W4205779344
CPU times: user 33.9 ms, sys: 667 µs, total: 34.6 ms
Wall time: 141 ms


In [26]:
# One row per request record; the positional index becomes the `UID` column.
input_df = (
    pd.DataFrame.from_dict(input_json)
    .reset_index()
    .rename(columns={'index': 'UID'})
)

In [27]:
# (rows, columns) of the request frame
len(input_df.index), len(input_df.columns)

(1, 6)

In [52]:
%%time
# Score the one-row request frame with process_data_as_df (defined earlier in the notebook).
# NOTE(review): the structure of `final_preds` depends on process_data_as_df — confirm there.
final_preds = process_data_as_df(input_df)

<TITLE> The Shape of the Olfactory Bulb Predicts Olfactory Function
<ABSTRACT> The olfactory bulb (OB) plays a key role in the processing of olfactory information. A large body of research has shown that OB volumes correlate with olfactory function, which provides diagnostic and prognostic information in olfactory dysfunction. Still, the potential value of the OB shape remains unclear. Based on our clinical experience we hypothesized that the shape of the OB predicts olfactory function, and that it is linked to olfactory loss, age, and gender. The aim of this study was to produce a classification of OB shape in the human brain, scalable to clinical and research applications. Results from patients with the five most frequent causes of olfactory dysfunction (n = 192) as well as age/gender-matched healthy controls (n = 77) were included. Olfactory function was examined in great detail using the extended "Sniffin' Sticks" test. A high-resolution structural T2-weighted MRI scan was obtained

#### Loading only the language model from Hugging Face

In [74]:
# Plain text-classification pipeline over the same fine-tuned model, returning the top 5 labels.
# Reuses `language_model_name` instead of repeating the model-id literal.
classifier_multi = pipeline(model=language_model_name, top_k=5)

Some layers from the model checkpoint at OpenAlex/bert-base-multilingual-cased-finetuned-openalex-topic-classification-title-abstract were not used when initializing TFBertForSequenceClassification: ['dropout_37']
- This IS expected if you are initializing TFBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFBertForSequenceClassification were initialized from the model checkpoint at OpenAlex/bert-base-multilingual-cased-finetuned-openalex-topic-classification-title-abstract.
If your task is similar to the task the model of the checkpoint was trained on, you can already use

In [77]:
%%time
# Title-only input (no <ABSTRACT> section); the pipeline returns the top-5 topic labels.
classifier_multi("<TITLE>Supplemental Material: Estimating paleotidal constituents from Pliocene “tidal gauges”—an example from the paleo-Orinoco Delta, Trinidad")

CPU times: user 486 ms, sys: 0 ns, total: 486 ms
Wall time: 220 ms


[[{'label': '3404: Geodynamics of the Northern Andes and Caribbean Region',
   'score': 0.4785984754562378},
  {'label': "965: Sedimentary Processes in Earth's Geology",
   'score': 0.1968356966972351},
  {'label': '2014: Biogeography and Conservation of Neotropical Freshwater Fishes',
   'score': 0.051199547946453094},
  {'label': '109: Paleoredox and Paleoproductivity Proxies',
   'score': 0.01857946440577507},
  {'label': '3205: Geodynamic Evolution of Western Mediterranean Region',
   'score': 0.016884787008166313}]]