In [None]:
import os
import re
import json
import pickle
import unidecode
import redshift_connector
import pandas as pd
import numpy as np
import tensorflow as tf
from langdetect import detect
from transformers import TFAutoModelForSequenceClassification, DistilBertTokenizer
from transformers import DataCollatorWithPadding, PreTrainedTokenizerFast

### Getting all affiliations and also affiliation counts

In [None]:
# Redshift credentials for querying the OpenAlex database.
# The file is read as two lines: the host first, then the password.
with open("redshift_creds.txt", "r") as f:
    # Strip only the line terminator. Slicing off the last character would
    # truncate the value when the final line has no trailing newline and
    # would leave a stray "\r" on files saved with Windows line endings.
    host = f.readline().rstrip("\r\n")
    password = f.readline().rstrip("\r\n")

In [None]:
# Open the Redshift session used by every query in this notebook
connection_settings = {
    "host": host,
    "database": 'dev',
    "user": 'app_user',
    "password": password,
}
conn = redshift_connector.connect(**connection_settings)

# A single cursor is reused for all statements below
cursor = conn.cursor()

In [None]:
# Institution lookup: affiliation ID plus display name and location columns
# (city, region, country) from mid.institution.
query = """select affiliation_id, display_name, city, region, country
           from mid.institution"""

In [None]:
# ROLLBACK is issued before the query -- presumably to clear any aborted
# transaction left on the connection (TODO confirm).
cursor.execute("ROLLBACK;")
cursor.execute(query)
# Pull the whole result set into a pandas DataFrame
df = cursor.fetch_dataframe()
# Last expression of the cell: displays (rows, columns) in the notebook
df.shape

In [None]:
# Number of rows in mid.affiliation per affiliation ID. These counts are used
# further down as sampling weights for the gold datasets.
query = """select affiliation_id, count(affiliation_id)
           from mid.affiliation
           group by affiliation_id"""

In [None]:
# ROLLBACK is issued before the query -- presumably to clear any aborted
# transaction left on the connection (TODO confirm).
cursor.execute("ROLLBACK;")
cursor.execute(query)
# Rows with missing values are dropped first so the ID column can be cast to int
weights = cursor.fetch_dataframe().dropna()
weights['affiliation_id'] = weights['affiliation_id'].astype('int')

### Using the exact code that will be used for deployment

In [None]:
# Define the path
prefix = './path_to_model_files/' # insert path to model files here
model_path = os.path.join(prefix, 'model_files')


def _load_pickle(relative_name):
    """Unpickle and return the file ``relative_name`` located under ``model_path``."""
    # NOTE: pickle executes code on load -- only point this at trusted model files.
    with open(os.path.join(model_path, relative_name), "rb") as handle:
        return pickle.load(handle)


# Lookup tables used by the hard-coded pre-filter and the country matching
departments_list = _load_pickle("departments_list.pkl")
print("Loaded list of departments")

full_affiliation_dict = _load_pickle("full_affiliation_dict.pkl")
print("Loaded affiliation dictionary")

countries_list_flat = _load_pickle("countries_list_flat.pkl")
print("Loaded flat list of countries")

with open(os.path.join(model_path, "countries.json"), "r") as handle:
    countries_dict = json.load(handle)
print("Loaded countries dictionary")

city_country_list = _load_pickle("city_country_list.pkl")
print("Loaded strings of city/country combinations")

# Each vocab is stored in one direction; the inverse is built by swapping
# keys and values so lookups can go both ways.
affiliation_vocab_basic = _load_pickle("affiliation_vocab_basic.pkl")
inverse_affiliation_vocab_basic = {
    value: key for key, value in affiliation_vocab_basic.items()
}
print("Loaded basic affiliation vocab")

affiliation_vocab_language = _load_pickle("language_model/vocab.pkl")
inverse_affiliation_vocab_language = {
    value: key for key, value in affiliation_vocab_language.items()
}
print("Loaded language affiliation vocab")

# Tokenizers: DistilBERT (plus a padding collator) for the language model,
# and a custom fast tokenizer for the basic model
language_tokenizer = DistilBertTokenizer.from_pretrained(
    "distilbert-base-uncased", return_tensors='tf'
)
data_collator = DataCollatorWithPadding(
    tokenizer=language_tokenizer, return_tensors='tf'
)

basic_tokenizer = PreTrainedTokenizerFast(
    tokenizer_file=os.path.join(model_path, "basic_model_tokenizer")
)

# Models are loaded for inference only, so their weights are frozen
language_model = TFAutoModelForSequenceClassification.from_pretrained(
    os.path.join(model_path, "language_model")
)
language_model.trainable = False

basic_model = tf.keras.models.load_model(
    os.path.join(model_path, "basic_model"), compile=False
)
basic_model.trainable = False

def get_country_in_string(text):
    """
    Looks for countries in the affiliation string to be used in filtering later on.

    Every key of ``countries_dict`` is returned if at least one of its aliases
    occurs as a whole word in ``text`` or in ``text`` with all periods removed.

    Args:
        text: The affiliation string to scan.

    Returns:
        A list of the matching ``countries_dict`` keys, without duplicates.
        Empty when nothing matches.
    """
    # Each alias is tried against the raw string and against a copy with the
    # periods stripped out.
    candidates = (text, text.replace(".", ""))

    # NOTE(review): aliases are interpolated into the pattern unescaped, exactly
    # as before, so regex metacharacters inside an alias keep their regex
    # meaning. Confirm that is intended before switching to re.escape().
    #
    # any() stops at the first hit and is simply False for a country whose
    # alias list is empty (np.max raised ValueError on an empty list).
    return [
        country
        for country, aliases in countries_dict.items()
        if any(
            re.search(fr"\b{alias}\b", candidate)
            for alias in aliases
            for candidate in candidates
        )
    ]

def max_len_and_pad(tok_sent):
    """
    Forces a token-ID list to the fixed input length of the basic model.

    Sequences longer than 128 are cut to their first 128 items; shorter ones
    are right-padded with 0. A new list is returned.
    """
    max_len = 128
    clipped = tok_sent[:max_len]
    padding = [0] * (max_len - len(clipped))
    return clipped + padding


def get_language(orig_aff_string):
    """
    Guesses the language of the affiliation string to be used for filtering later.

    Args:
        orig_aff_string: The raw affiliation string.

    Returns:
        The language code returned by ``detect``, or ``'en'`` when detection
        raises for any reason.
    """
    try:
        string_lang = detect(orig_aff_string)
    except Exception:
        # Best-effort fallback: any detection failure is treated as English.
        # Catching Exception instead of a bare except lets KeyboardInterrupt
        # and SystemExit propagate, so a long run can still be interrupted.
        string_lang = 'en'

    return string_lang

def get_initial_pred(orig_aff_string, string_lang, countries_in_string, comma_split_len):
    """
    Hard-coded pre-filter that decides whether a string is sent to the models.

    Returns None for strings that should not receive an institution (excluded
    language, blank string, bare department/school/ministry, known department
    or city/country string, or a string containing one of the listed
    name tokens) and 0 for strings that should go through the models.
    """
    # Languages that are filtered out up front
    excluded_langs = ['fa','ko','zh-cn','zh-tw','ja','uk','ru','vi']
    if string_lang in excluded_langs:
        return None

    # Blank or whitespace-only strings
    if not str(orig_aff_string).strip():
        return None

    # A lone "Dep..."/"School..."/"Ministry..." segment with no country found
    has_generic_prefix = orig_aff_string.startswith(("Dep", "School", "Ministry"))
    is_single_segment = comma_split_len < 2
    no_country_found = not countries_in_string
    if has_generic_prefix and is_single_segment and no_country_found:
        return None

    # Exact matches against the preloaded exclusion lists
    if orig_aff_string in departments_list:
        return None
    if orig_aff_string in city_country_list:
        return None

    # Whole-word, case-sensitive name tokens
    if re.search(r"\b(LIANG|YANG|LIU|et al|XIE|JIA|ZHANG|QU)\b", orig_aff_string):
        return None

    return 0

def get_language_model_prediction(decoded_text, all_countries):
    """
    Runs the language model on a batch of strings and returns, for each one,
    a ``[prediction, score]`` pair chosen from its five highest-scoring labels.

    ``decoded_text`` and ``all_countries`` are parallel lists: the i-th entry
    of ``all_countries`` holds the countries found in the i-th string.
    """
    # Tokenize (truncated to 512 tokens) and collate into padded TF tensors
    encoded = language_tokenizer(decoded_text, truncation=True, padding=True, max_length=512)
    batch = data_collator(encoded)

    # Logits -> softmax probabilities -> top 5 per string
    model_output = language_model.predict([batch['input_ids'],
                                           batch['attention_mask']])
    probabilities = tf.nn.softmax(model_output.logits).numpy()
    top_scores, top_labels = tf.math.top_k(probabilities, 5)

    score_rows = top_scores.numpy().tolist()
    label_rows = top_labels.numpy().tolist()

    # Country-aware selection among the five candidates of each string
    return [
        list(get_final_basic_or_language_model_pred(row_scores, row_labels, row_countries,
                                                    affiliation_vocab_language,
                                                    inverse_affiliation_vocab_language))
        for row_scores, row_labels, row_countries in zip(score_rows, label_rows, all_countries)
    ]

def get_basic_model_prediction(decoded_text, all_countries):
    """
    Preprocesses the decoded text and gets the output labels and scores for the basic model.

    Args:
        decoded_text: List of affiliation strings to score.
        all_countries: List parallel to ``decoded_text``; the i-th entry holds
            the countries found in the i-th string.

    Returns:
        A list with one ``[prediction, score]`` pair per input string, chosen
        from that string's five highest-scoring labels.
    """
    # Tokenize, then clip/pad every sequence to the fixed model input length
    basic_tok_data = basic_tokenizer(decoded_text)['input_ids']
    basic_tok_data = [max_len_and_pad(x) for x in basic_tok_data]

    # Feed the model the int64 tensor. It used to be built and then ignored,
    # with the raw Python list passed to predict() instead.
    basic_tok_tensor = tf.convert_to_tensor(basic_tok_data, dtype=tf.int64)
    all_scores, all_labels = tf.math.top_k(basic_model.predict(basic_tok_tensor), 5)

    all_scores = all_scores.numpy().tolist()
    all_labels = all_labels.numpy().tolist()

    # Country-aware selection among the five candidates of each string
    final_preds_scores = []
    for scores, labels, countries in zip(all_scores, all_labels, all_countries):
        final_pred, final_score = get_final_basic_or_language_model_pred(scores, labels, countries,
                                                                         affiliation_vocab_basic,
                                                                         inverse_affiliation_vocab_basic)
        final_preds_scores.append([final_pred, final_score])

    return final_preds_scores


def get_final_basic_or_language_model_pred(scores, labels, countries, vocab, inv_vocab):
    """
    Picks one (prediction, score) pair from a model's ranked candidates,
    preferring a candidate whose institution country was found in the string.

    The candidate whose label equals ``vocab[-1]`` is discarded and the rest
    are mapped through ``inv_vocab``. The first remaining candidate is the
    default answer. If that candidate has a country on record and
    ``countries`` is non-empty, the first candidate (in the given order) whose
    country is in ``countries`` is returned instead.
    """
    # Keep (mapped prediction, score) pairs, skipping the label stored under -1
    candidates = [
        (inv_vocab[label], score)
        for label, score in zip(labels, scores)
        if label != vocab[-1]
    ]
    top_pred, top_score = candidates[0]

    # Nothing to compare against: keep the top candidate
    if not full_affiliation_dict[top_pred]['country'] or not countries:
        return top_pred, top_score

    for pred, score in candidates:
        pred_country = full_affiliation_dict[pred]['country']
        # Candidates with no country on record are skipped, not treated as a stop
        if pred_country and pred_country in countries:
            return pred, score

    # No candidate matched a country found in the string
    return top_pred, top_score

def get_final_prediction(basic_pred_score, lang_pred_score, countries, raw_sentence, lang_thresh, basic_thresh):
    """
    Combines the two models' answers into ``[prediction, score, category]``.

    Rules are tried in order: both models agree ('match'); basic score above
    ``basic_thresh`` ('basic_thresh'); language score above ``lang_thresh``
    ('lang_thresh'); a special case for Chinese "Natural Resource" strings
    ('basic_thresh_second'); otherwise no prediction ('nothing').
    """
    lang_choice, lang_confidence = lang_pred_score
    basic_choice, basic_confidence = basic_pred_score

    # Agreement between the models wins outright (language score is reported)
    if lang_choice == basic_choice:
        return [lang_choice, lang_confidence, 'match']

    # Otherwise trust whichever model is confident enough, basic model first
    if basic_confidence > basic_thresh:
        return [basic_choice, basic_confidence, 'basic_thresh']
    if lang_confidence > lang_thresh:
        return [lang_choice, lang_confidence, 'lang_thresh']

    # Special case: low-confidence basic prediction accepted for these strings
    above_floor = basic_confidence > 0.01
    china_found = 'China' in countries
    mentions_natural_resource = 'Natural Resource' in raw_sentence
    if above_floor and china_found and mentions_natural_resource:
        return [basic_choice, basic_confidence, 'basic_thresh_second']

    return [None, 0.0, 'nothing']

def raw_data_to_predictions(df, lang_thresh, basic_thresh):
    """
    High level function to go from a raw input dataframe to the final dataframe with affiliation
    ID prediction.

    Args:
        df: DataFrame with an ``affiliation_string`` column. The helper columns
            ``lang``, ``country_in_string``, ``comma_split_len`` and
            ``affiliation_id`` are added to it in place.
        lang_thresh: Score above which the language model's prediction is
            accepted when the two models disagree.
        basic_thresh: Score above which the basic model's prediction is
            accepted when the two models disagree.

    Returns:
        A new DataFrame with the columns ``affiliation_string`` and
        ``affiliation_id``: one row per input row, in input order, with -1
        wherever no institution was predicted.
    """
    # Implementing the functions above
    df['lang'] = df['affiliation_string'].apply(get_language)
    df['country_in_string'] = df['affiliation_string'].apply(get_country_in_string)
    df['comma_split_len'] = df['affiliation_string'].apply(lambda x: len(x.split(",")))

    # Gets initial indicator of whether or not the string should go through the models
    # (0 = send to the models, missing = filtered out)
    df['affiliation_id'] = df.apply(lambda x: get_initial_pred(x.affiliation_string, x.lang,
                                                               x.country_in_string, x.comma_split_len), axis=1)

    # Filter out strings that won't go through the models; each distinct string is scored once
    to_predict = df[df['affiliation_id']==0.0].drop_duplicates(subset=['affiliation_string']).copy()

    # Nothing survived the pre-filter: skip the models (they cannot take an
    # empty batch) and report "no prediction" for every row
    if to_predict.empty:
        final_df = df[['affiliation_string']].reset_index(drop=True)
        final_df['affiliation_id'] = -1
        return final_df

    to_predict['affiliation_id'] = to_predict['affiliation_id'].astype('int')

    # Decode text so only ASCII characters are used
    to_predict['decoded_text'] = to_predict['affiliation_string'].apply(unidecode.unidecode)

    # Get predictions and scores for each model
    to_predict['lang_pred_score'] = get_language_model_prediction(to_predict['decoded_text'].to_list(),
                                                                  to_predict['country_in_string'].to_list())
    to_predict['basic_pred_score'] = get_basic_model_prediction(to_predict['decoded_text'].to_list(),
                                                                to_predict['country_in_string'].to_list())

    # Get the final prediction for each affiliation string
    to_predict['affiliation_id'] = to_predict.apply(lambda x:
                                                    get_final_prediction(x.basic_pred_score,
                                                                         x.lang_pred_score,
                                                                         x.country_in_string,
                                                                         x.affiliation_string,
                                                                         lang_thresh, basic_thresh)[0], axis=1)

    # Merge predictions to original dataframe to get the same order as the data that was requested.
    # The join key is the affiliation string: it is the only column both frames share
    # (there is no 'index' column), and it is unique in to_predict, so the left merge
    # keeps exactly one row per input row.
    final_df = df[['affiliation_string']].merge(to_predict[['affiliation_string','affiliation_id']],
                                                how='left', on='affiliation_string')

    # Rows that were filtered out or got no prediction become -1
    final_df['affiliation_id'] = final_df['affiliation_id'].fillna(-1).astype('int')
    return final_df


# Progress marker printed once the cell above has finished running
print("Models initialized")

In [None]:
def check_for_correct_pred(pred, target_list):
    """Return 1 when ``pred`` is one of the accepted IDs in ``target_list``, otherwise 0."""
    return int(pred in target_list)

##### Make sure the gold_1000.parquet and gold_500.parquet files are in the current directory.

In [None]:
def _load_gold_set(parquet_file, fallback_count):
    """
    Read one gold dataset and attach the per-institution count used for sampling.

    The count is looked up in ``weights`` by the first ID of each row's
    ``true_affiliation_id``; rows whose ID has no entry get ``fallback_count``.
    """
    gold = pd.read_parquet(parquet_file,
                           columns=['raw_affiliation','true_affiliation_id'])
    gold = gold.rename(columns={'raw_affiliation': 'affiliation_string'})

    # Only the first listed ID is used to look up the weight
    gold['affiliation_id_for_weights'] = gold['true_affiliation_id'].map(lambda ids: ids[0])
    gold = gold.merge(weights, how='left',
                      left_on='affiliation_id_for_weights', right_on='affiliation_id')
    gold['count'] = gold['count'].fillna(fallback_count).astype('int')

    return gold[['affiliation_string','true_affiliation_id','count']].copy()


# Gold 500: strings with empty affiliation IDs in MAG/OpenAlex
data_500 = _load_gold_set("gold_500.parquet", 7)

# Gold 1000: strings with an affiliation ID in MAG/OpenAlex
data_1000 = _load_gold_set("gold_1000.parquet", 57170)

#### Looking at multiple runs of sampling to see what the precision and recall would be for the chosen threshold

In [None]:
%%time
# Using 0.99 and 0.75 thresholds after grid search
for lang_thresh in [0.99]:
    for basic_thresh in [0.75]:
        print(f"Basic: {basic_thresh}   Lang: {lang_thresh}")
        # 15 independent weighted samples show how stable the metrics are
        for i in range(15):
            # Sampling from the gold datasets to get a distribution of samples similar to real data
            sampled_500_preds = data_500.sample(197, weights=data_500['count'])
            sampled_1000_preds = data_1000.sample(803, weights=data_1000['count'])
            input_df = pd.concat([sampled_500_preds, sampled_1000_preds], axis=0)

            # Getting predictions
            final_df = raw_data_to_predictions(input_df, lang_thresh, basic_thresh)

            # Filling in empty affiliations with -1 (if any)
            final_df['affiliation_id'] = final_df['affiliation_id'].fillna(-1).astype('int')

            # The prediction frame carries only the string and the predicted ID,
            # one row per input row in input order, so the gold labels are
            # re-attached by position (a length mismatch raises here).
            final_df['true_affiliation_id'] = input_df['true_affiliation_id'].to_numpy()

            # Checking if pred matches the label
            final_df['pred_correct'] = final_df.apply(lambda x: check_for_correct_pred(x.affiliation_id,
                                                                                     x.true_affiliation_id), axis=1)

            # Flags for rows whose gold label is / is not "no institution" (-1)
            final_df['equals_negative_one'] = final_df['true_affiliation_id'] \
            .apply(lambda x: x[0]==-1).astype('int')

            final_df['not_equals_negative_one'] = final_df['true_affiliation_id'] \
            .apply(lambda x: x[0]!=-1).astype('int')

            # Getting true positives, false positives, and false negatives
            TP = final_df[(final_df['pred_correct']==1) &
                          (final_df['not_equals_negative_one']==1)].shape[0]

            # Not used in the metrics below; kept for inspection in the notebook
            TN = final_df[(final_df['pred_correct']==1) &
                          (final_df['not_equals_negative_one']==0)].shape[0]

            FP =  final_df[(final_df['pred_correct']==0) &
                           (final_df['affiliation_id']!=-1)].shape[0]

            FN = final_df[(final_df['not_equals_negative_one']==1) & (final_df['affiliation_id']==-1)].shape[0]

            # Calculating precision and recall (0.0 when the denominator is empty
            # instead of raising ZeroDivisionError)
            precision = TP/(TP+FP) if (TP+FP) else 0.0
            recall = TP/(TP+FN) if (TP+FN) else 0.0
            print(f"-------Precision: {round(precision, 3)}     Recall: {round(recall, 3)}")