In [None]:
import os
import pickle
import unicodedata
from datetime import datetime
from glob import glob
from math import ceil

import boto3
import numpy as np
import pandas as pd
import unidecode
from sklearn.metrics import confusion_matrix
from sklearn.metrics.pairwise import cosine_similarity

pd.set_option('display.max_colwidth', None)

### Getting Data Straight From S3 (PySpark)

In [None]:
# Columns read from the parquet files (in addition to 'sample_type'):
# the two author names followed by the raw match/sum counts per signal.
cols_for_pred = [
    'author_1',
    'author_2',
    'inst_match',
    'inst_sum',
    'concepts_shortest_match',
    'concepts_shortest_sum',
    'concepts_shorter_match',
    'concepts_shorter_sum',
    'concepts_match',
    'concepts_sum',
    'coauthors_shorter_match',
    'coauthors_shorter_sum',
    'coauthors_match',
    'coauthors_sum',
    'citation_match',
    'citation_sum',
    'citation_work_match',
]

In [None]:
def does_either_work_show_in_citations(paper_id_1, paper_id_2, citation_1, citation_2):
    """Return 1 if either paper id appears in the other side's citations, else 0.

    ``paper_id_1`` is looked up in ``citation_2`` first; ``paper_id_2`` is only
    looked up in ``citation_1`` when that first membership test fails.
    """
    cited_by_other = paper_id_1 in citation_2 or paper_id_2 in citation_1
    return 1 if cited_by_other else 0

In [None]:
# Characters that act as separators between name parts: each becomes a space.
_NAME_SEPARATOR_CHARS = ".,|"
# Characters that are stripped outright (brackets, digits, assorted punctuation).
_NAME_DROPPED_CHARS = ")(-&$#@%0123456789*^{}+=_~`[]\\<>?/;:'\""
# Single-pass translation table equivalent to the former chain of str.replace calls.
_NAME_CLEANUP_TABLE = str.maketrans(
    _NAME_SEPARATOR_CHARS, " " * len(_NAME_SEPARATOR_CHARS), _NAME_DROPPED_CHARS
)


def transform_name_for_search(name):
    """Normalize an author name into a lowercase, punctuation-free search key.

    Steps, in order:
      1. Unicode NFKC normalization, then ``unidecode`` transliteration.
      2. Lowercasing.
      3. ``.``, ``,`` and ``|`` become spaces; the characters in
         ``_NAME_DROPPED_CHARS`` are removed.  Any character not listed
         (for example ``!``) is left untouched.
      4. Runs of whitespace are collapsed to single spaces and the ends trimmed.

    Requires ``unicodedata`` (standard library) and ``unidecode`` to be
    imported at module level.

    Parameters
    ----------
    name : str
        Raw author name.

    Returns
    -------
    str
        The cleaned name.
    """
    name = unidecode.unidecode(unicodedata.normalize('NFKC', name))
    name = name.lower().translate(_NAME_CLEANUP_TABLE)
    # split()/join collapses every whitespace run, including the spaces
    # introduced for the separator characters above.
    return " ".join(name.split())

In [None]:
def _record_name_token(names, initials, token):
    """Add ``token``'s first character to ``initials`` and, when the token is
    longer than one character, the token itself to ``names``."""
    if len(token) > 1:
        names.append(token)
    initials.append(token[0])


def get_name_match_list(name):
    """Split an author name into positional name/initial lists used for matching.

    The name is tokenized on whitespace twice when it contains a hyphen:
    once with hyphens removed ("jean-pierre" -> "jeanpierre") and once with
    hyphens turned into spaces ("jean-pierre" -> "jean", "pierre").  Without a
    hyphen only the first tokenization is used.  Results from both
    tokenizations are merged.

    For each tokenization the tokens are assigned to seven slots:
    first name, middle names m1..m5, and last name.

    * The first token goes to the first-name slot and the final token to the
      last-name slot (a single-token name therefore fills both).
    * Tokens in between fill m1, m2, ... in order.
    * With more than five middle tokens, m1..m4 take the first four and all
      remaining middle tokens are joined with spaces into one m5 entry.

    A token longer than one character is recorded both as a full name and as
    an initial; a one-character token is recorded only as an initial.

    Parameters
    ----------
    name : str
        Author name (expected to be already normalized by the caller).

    Returns
    -------
    list[list[str]]
        Fourteen lists in the order
        ``[fn, fni, m1, m1i, m2, m2i, m3, m3i, m4, m4i, m5, m5i, ln, lni]``
        where the ``*i`` lists hold initials.  Each list is de-duplicated and
        keeps first-seen order, so the hyphen-removed variant always precedes
        the hyphen-split variant (previously the order came from ``set`` and
        could vary between interpreter runs).
    """
    tokenizations = [name.replace("-", "").split()]
    if "-" in name:
        tokenizations.append(name.replace("-", " ").split())

    # One (names, initials) pair per slot: first, m1..m5, last.
    slots = [([], []) for _ in range(7)]

    for tokens in tokenizations:
        if not tokens:
            continue

        middle = tokens[1:-1]
        if len(middle) > 5:
            # Overflow: everything past the fourth middle token shares slot m5.
            middle = middle[:4] + [" ".join(middle[4:])]

        # zip stops after first + middle, so the last-name slot is never
        # touched here; it is filled explicitly below.
        for (names, initials), token in zip(slots, [tokens[0]] + middle):
            _record_name_token(names, initials, token)
        _record_name_token(slots[-1][0], slots[-1][1], tokens[-1])

    # dict.fromkeys de-duplicates while preserving insertion order.
    return [list(dict.fromkeys(values)) for pair in slots for values in pair]

In [None]:
def check_block_vs_block(block_1_names_list, block_2_names_list):
    """Decide whether two name lists are compatible (1) or not (0).

    Both arguments are 14-element lists laid out as
    ``[fn, fni, m1, m1i, m2, m2i, m3, m3i, m4, m4i, m5, m5i, ln, lni]``
    (full names at even indices, initials at the following odd index).

    Procedure:
      1. Compare first names.  If they do not match, the only way to get a
         match is the "last name swapped to the front" check.
      2. Compare last names; a mismatch returns 0.
      3. Compare middle slots m1..m5 in order.  A mismatch returns 0.  When
         ``match_block_names`` reports there is nothing further to compare
         (second return value False), return 1 immediately.
      4. If every middle slot matched, return 1.

    The m4 comparison now passes block 1's m4 *initials* (index 9); it
    previously passed the m4 names (index 8) twice.

    Returns
    -------
    int
        1 for a match, 0 otherwise.
    """
    # check first names
    first_check, _ = match_block_names(block_1_names_list[0], block_1_names_list[1],
                                       block_2_names_list[0], block_2_names_list[1])
    if not first_check:
        swap_check = check_if_last_name_swapped_to_front_creates_match(block_1_names_list,
                                                                       block_2_names_list)
        return 1 if swap_check else 0

    # check last names
    last_check, _ = match_block_names(block_1_names_list[-2], block_1_names_list[-1],
                                      block_2_names_list[-2], block_2_names_list[-1])
    if not last_check:
        return 0

    # check middle names m1..m5: names live at index i, initials at i + 1
    for i in range(2, 12, 2):
        middle_check, more_to_go = match_block_names(block_1_names_list[i], block_1_names_list[i + 1],
                                                     block_2_names_list[i], block_2_names_list[i + 1])
        if not middle_check:
            return 0
        if not more_to_go:
            return 1

    return 1
        
def get_name_from_name_list(name_list):
    """Rebuild a flat list of name parts from a 14-element name list.

    Walks the first-name and middle-name slots (index pairs 0/1 .. 10/11),
    taking the first full name of each slot, or the first initial when the
    slot has no full name.  The walk stops at the first slot that has neither.
    The last name (index -2, falling back to the initial at index -1) is then
    appended if present.
    """
    parts = []
    for pos in range(0, 12, 2):
        # Prefer the full name; fall back to the initial.
        candidates = name_list[pos] or name_list[pos + 1]
        if not candidates:
            break
        parts.append(candidates[0])

    last_candidates = name_list[-2] or name_list[-1]
    if last_candidates:
        parts.append(last_candidates[0])

    return parts
        
def check_if_last_name_swapped_to_front_creates_match(block_1, block_2):
    """Return True when both names have exactly two parts and block 2, with its
    final part moved to the front, spells the same name as block 1.

    Block 2 is only resolved into name parts when block 1 has two parts.
    """
    name_1 = get_name_from_name_list(block_1)
    if len(name_1) != 2:
        return False

    name_2 = get_name_from_name_list(block_2)
    if len(name_2) != 2:
        return False

    # Rotate "first last" into "last first" before comparing.
    rotated_name_2 = name_2[-1:] + name_2[:-1]
    return " ".join(name_1) == " ".join(rotated_name_2)
    
def match_block_names(block_1_names, block_1_initials, block_2_names, block_2_initials):
    """Compare one name slot of two authors.

    Returns a pair ``(matched, more_to_go)``:

    * Both sides have full names: matched iff they share a full name.
    * Only one side has full names: if the other side has initials, matched
      iff the two initials lists overlap; if the other side has nothing,
      the slot counts as matched.
    * Neither side has full names but both have initials: matched iff the
      initials overlap.
    * Otherwise (no full names and at least one side without initials):
      ``(True, False)``, the only case where ``more_to_go`` is False.
    """
    def _overlaps(left, right):
        return any(item in left for item in right)

    if block_1_names and block_2_names:
        return _overlaps(block_1_names, block_2_names), True

    if block_1_names or block_2_names:
        # Exactly one side has a full name; look at the other side's initials.
        other_side_initials = block_2_initials if block_1_names else block_1_initials
        if not other_side_initials:
            return True, True
        return _overlaps(block_1_initials, block_2_initials), True

    if block_1_initials and block_2_initials:
        return _overlaps(block_1_initials, block_2_initials), True

    return True, False

In [None]:
def get_cosine_sim_between_name_cols(col_1, col_2):
    """Return the row-wise cosine similarity between the embeddings of two columns.

    Each column is encoded with the module-level ``emb_model`` and the
    embeddings are compared pairwise (row i of ``col_1`` against row i of
    ``col_2``).  Each similarity is rounded to 4 decimal places.

    NOTE(review): ``emb_model`` is not defined in the visible part of this
    notebook; it must exist in the kernel before this function is called.
    Presumably it is a sentence-embedding model exposing ``.encode`` - confirm.
    """
    embeddings_1 = emb_model.encode(col_1)
    embeddings_2 = emb_model.encode(col_2)

    similarities = []
    for vec_1, vec_2 in zip(embeddings_1, embeddings_2):
        # cosine_similarity expects 2-D inputs, hence the (1, -1) reshape.
        pair_similarity = cosine_similarity(vec_1.reshape(1, -1), vec_2.reshape(1, -1))
        similarities.append(round(pair_similarity[0][0], 4))
    return similarities

In [None]:
def get_dataframe_from_S3(filename, data_type = 'train', random_state=None):
    """Load a parquet file of author pairs and derive the model features.

    Reads ``sample_type`` plus the columns in ``cols_for_pred`` from
    ``filename`` with ``pd.read_parquet`` and adds:

    * ``author_1_name_list`` / ``author_2_name_list``: ``get_name_match_list``
      of each author name.
    * ``author_name_check``: ``check_block_vs_block`` on the two name lists.
    * ``exact_match``: 1 when the two name strings are equal, else 0.
    * ``name_1_len`` / ``name_1_spaces``: character count of ``author_1`` and
      its number of space-separated pieces.
    * ``exact_match_len`` / ``exact_match_spaces``: the two values above,
      zeroed when the names are not an exact match.
    * ``inst_per``: 1 when ``inst_match`` > 0, else 0.
    * ``<signal>_per`` for concepts, concepts_shorter, concepts_shortest,
      coauthors, coauthors_shorter and citation: ``match / sum`` rounded to
      4 decimals.
    * ``label``: 1 when ``sample_type == 'positive'``, else 0.

    When ``data_type == 'train'`` the frame is additionally restricted to rows
    with ``author_name_check == 1`` and resampled: ``n`` rows with label 0 and
    ``ceil(0.4 * n)`` rows with label 1, where ``n`` is the size of the
    smaller class, then shuffled down to ``int(1.4 * n)`` rows.

    Parameters
    ----------
    filename : str
        Path/URI of the parquet file.
    data_type : str, default 'train'
        Only the value ``'train'`` triggers filtering and resampling.
    random_state : int or numpy RandomState/Generator, optional
        Forwarded to every ``DataFrame.sample`` call so the training sample
        can be reproduced.  ``None`` (default) keeps the previous unseeded
        behaviour.

    Returns
    -------
    pandas.DataFrame
        The feature frame with remaining NaN values replaced by 0.0.
    """
    cols_to_get = ['sample_type'] + cols_for_pred
    df = pd.read_parquet(filename, columns=cols_to_get)

    # Name-based features.
    df['author_1_name_list'] = df['author_1'].apply(get_name_match_list)
    df['author_2_name_list'] = df['author_2'].apply(get_name_match_list)
    df['author_name_check'] = df.apply(lambda x: check_block_vs_block(x.author_1_name_list, x.author_2_name_list),
                                       axis=1)
    df['exact_match'] = df.apply(lambda x: 1 if x.author_1==x.author_2 else 0, axis=1)
    df['name_1_len'] = df['author_1'].apply(len)
    df['name_1_spaces'] = df['author_1'].apply(lambda x: len(x.split(" ")))
    df['exact_match_len'] = df['exact_match'] * df['name_1_len']
    df['exact_match_spaces'] = df['exact_match'] * df['name_1_spaces']

    # Overlap features: institutions are a binary flag, the rest are ratios.
    df['inst_per'] = df['inst_match'].apply(lambda x: 1 if x > 0 else 0)
    for signal in ('concepts', 'concepts_shorter', 'concepts_shortest',
                   'coauthors', 'coauthors_shorter', 'citation'):
        # A zero denominator yields NaN here; NaNs are filled at the end.
        ratio = df[f'{signal}_match'] / df[f'{signal}_sum']
        df[f'{signal}_per'] = ratio.apply(lambda x: round(x, 4))

    print(df.shape)

    df['label'] = df['sample_type'].apply(lambda x: 1 if x=='positive' else 0)

    if data_type == 'train':
        # Train only on pairs whose names are compatible.
        df = df[df['author_name_check']==1].copy()
        print(df.shape)

        # Size of the smaller class drives the resampling.
        num_to_sample = min(df['label'].value_counts())

        first_df = df[df['label']==0].copy().sample(ceil(num_to_sample), random_state=random_state)
        second_df = df[df['label']==1].copy().sample(ceil(num_to_sample*0.4), random_state=random_state)
        # Final sample() shuffles the concatenated classes.
        df = pd.concat([first_df, second_df], axis=0).sample(int(num_to_sample*1.4),
                                                             random_state=random_state)

    return df.fillna(0.0)

In [None]:
%%time
# Build the training frame; 'train' mode also filters on author_name_check
# and resamples the two label classes inside get_dataframe_from_S3.
train_df = get_dataframe_from_S3("<path-to-training-data>", 'train')
train_df.shape

In [None]:
# Build the validation frame; any data_type other than 'train' skips the
# name-check filter and the class resampling.
val_df = get_dataframe_from_S3("<path-to-validation-data>", 'val')
val_df.shape

In [None]:
# Build the test frame (no filtering or resampling, same as validation).
test_df = get_dataframe_from_S3("<path-to-testing-data>", 'test')
test_df.shape

In [None]:
# Persist the prepared training frame as parquet under ./datasets_to_share/.
train_df.to_parquet("./datasets_to_share/disambiguator_training_data/train.parquet")

In [None]:
# Persist the prepared validation frame as parquet under ./datasets_to_share/.
val_df.to_parquet("./datasets_to_share/disambiguator_training_data/val.parquet")

In [None]:
# Persist the prepared test frame as parquet under ./datasets_to_share/.
test_df.to_parquet("./datasets_to_share/disambiguator_training_data/test.parquet")

### Training XGB Model

In [None]:
# !pip install xgboost

In [None]:
import pandas as pd
import numpy as np
import xgboost as xgb
from xgboost.sklearn import XGBClassifier
from sklearn import metrics
from sklearn.model_selection import GridSearchCV
import pickle

In [None]:
# Name of the label column; read as a global by modelfit and the grid search.
target = 'label'

In [None]:
# test_results = pd.read_csv('test_results.csv')
def modelfit(alg, dtrain, dtest, predictors, useTrainCV=True, cv_folds=5, early_stopping_rounds=50):
    """Fit ``alg`` on ``dtrain`` and print train / evaluation metrics.

    The label column is taken from the module-level ``target`` variable.

    Parameters
    ----------
    alg : XGBClassifier
        Estimator to fit; it is modified in place and also returned.
    dtrain, dtest : pandas.DataFrame
        Training and evaluation frames; both must contain ``predictors`` and
        the ``target`` column.
    predictors : list[str]
        Feature column names.
    useTrainCV : bool, default True
        When True, run ``xgb.cv`` on the training data (AUC metric, with
        early stopping) and set ``alg``'s ``n_estimators`` to the number of
        rows in the returned CV result before fitting.
    cv_folds : int, default 5
        Number of folds passed to ``xgb.cv``.
    early_stopping_rounds : int, default 50
        Early-stopping patience passed to ``xgb.cv``.

    Returns
    -------
    XGBClassifier
        The fitted estimator.

    Notes
    -----
    The lines printed as "Precision" report
    ``sklearn.metrics.average_precision_score``.
    """
    if useTrainCV:
        xgb_param = alg.get_xgb_params()
        xgtrain = xgb.DMatrix(dtrain[predictors].values, label=dtrain[target].values)
        cvresult = xgb.cv(xgb_param, xgtrain, num_boost_round=alg.get_params()['n_estimators'], nfold=cv_folds,
            metrics='auc', early_stopping_rounds=early_stopping_rounds)
        # One row per boosting round kept by CV -> use that as the tree count.
        alg.set_params(n_estimators=cvresult.shape[0],eval_metric='auc')

    # Fit the algorithm on the data
    alg.fit(dtrain[predictors], dtrain[target])

    # Predict training set:
    dtrain_predictions = alg.predict(dtrain[predictors])
    dtrain_predprob = alg.predict_proba(dtrain[predictors])[:,1]

    #Print model report:
    print("\nModel Report")
    print("Accuracy : %.4g" % metrics.accuracy_score(dtrain[target].values, dtrain_predictions))
    print("AUC Score (Train): %f" % metrics.roc_auc_score(dtrain[target], dtrain_predprob))
    print(f'Precision (Train): {metrics.average_precision_score(dtrain[target], dtrain_predprob)}')

    # Predict on testing data:
    dtest_pred_prob = alg.predict_proba(dtest[predictors])[:,1]
    print('AUC Score (Test): %f' % metrics.roc_auc_score(dtest[target], dtest_pred_prob))
    print(f'Precision (Test): {metrics.average_precision_score(dtest[target], dtest_pred_prob)}')

    print("")

    return alg

In [None]:
# predictors = [x for x in train_df.columns if x not in [target, 'sample_type','author_1', 'insts_1', 
#                                                        'concepts_1', 'coauthors_1', 'author_2', 'insts_2', 
#                                                        'concepts_2', 'coauthors_2']]

In [None]:
# Full candidate feature list (the model below trains on the smaller
# `predictors` subset).  A comma was missing after 'concepts_shortest_2_len',
# which silently merged it with 'concepts_shortest_match' into one bogus name.
all_predictors = ['emb_sim', 'inst_match','inst_1_len','inst_2_len','inst_sum', 'inst_per', 'concepts_1_len',
                  'concepts_2_len','concepts_match', 'concepts_sum', 'concepts_per',
                  'concepts_shorter_1_len','concepts_shorter_2_len','concepts_shorter_match', 'concepts_shorter_sum',
                  'concepts_shorter_per', 'concepts_shortest_1_len','concepts_shortest_2_len',
                  'concepts_shortest_match','concepts_shortest_sum','concepts_shortest_per','coauthors_1_len',
                  'coauthors_2_len','coauthors_match','coauthors_sum', 'coauthors_per','citation_match',
                  'citation_1_len','citation_2_len','citation_sum','citation_per','citation_work_match']

In [None]:
# Feature columns actually fed to the grid search and the final model.
predictors = [
    'inst_per',
    'concepts_shorter_per',
    'coauthors_shorter_per',
    'exact_match_len',
    'exact_match_spaces',
    'citation_per',
    'citation_work_match',
]

In [None]:
# Summary statistics of the selected features on the training frame.
train_df[predictors].describe()

In [None]:
# Hyper-parameter grid: 3 * 3 * 3 * 3 * 2 = 162 combinations, each evaluated
# with 5-fold CV below.
param_test1 = {
    'max_depth': [15, 40, 65],
    'min_child_weight': [1, 4, 8],
    'colsample_bytree':[0.4, 0.6, 0.8], 
    'n_estimators': [50, 90, 200],
    'learning_rate':[0.1, 0.2]
    
}
# Grid search scored by average precision.  The base estimator's max_depth,
# min_child_weight, colsample_bytree, n_estimators and learning_rate are
# placeholders that the grid values replace; gamma, subsample and seed stay fixed.
gsearch1 = GridSearchCV(estimator = XGBClassifier( learning_rate =0.1, n_estimators=140, max_depth=5,
                                        min_child_weight=1, gamma=0, subsample=0.8, colsample_bytree=0.8,
                                        objective= 'binary:logistic', nthread=4, scale_pos_weight=1, seed=27), 
                       param_grid = param_test1, scoring='average_precision',n_jobs=4, cv=5)
gsearch1.fit(train_df[predictors],train_df[target])

In [None]:
# Best mean cross-validated score (average precision) found by the grid search.
gsearch1.best_score_

In [None]:
# Parameter combination that achieved the best score.  (A stray trailing
# comma previously wrapped this dict in a 1-tuple.)
gsearch1.best_params_

In [None]:
# Final model.  These hyper-parameters are written out by hand rather than
# read from gsearch1.best_params_ (e.g. max_depth=12 is not a value in the
# grid above).  modelfit uses val_df only for the printed "Test" metrics.
xgb1 = XGBClassifier(
            learning_rate =0.2,
            n_estimators=100,
            max_depth=12,
            min_child_weight=1,
            gamma=0.0,
            subsample=0.7,
            colsample_bytree=0.9,
            objective= 'binary:logistic',
            nthread=4,
            scale_pos_weight=1,
            seed=27)
trained_model = modelfit(xgb1, train_df, val_df, predictors)

In [None]:
# Print each predictor next to its importance in the fitted model
# (feature_importances_ is in the same order as the training columns).
for i, j in zip(predictors, trained_model.feature_importances_.tolist()):
    print(f"{i} - {j}")

In [None]:
# Serialize the fitted classifier with pickle.  Only load this file from a
# trusted location: unpickling can execute arbitrary code.
with open("<local-path-to-model>/Disambiguator.pkl", "wb") as f:
    pickle.dump(trained_model, f)