In [215]:
import pandas as pd
import numpy as np
import datetime
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

# ML Packages
from xgboost import XGBClassifier
import xgboost
import statsmodels.api as sm

from sklearn import metrics, model_selection, svm
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_curve, auc, f1_score
from sklearn.metrics import mean_squared_error as MSE
from sklearn.model_selection import train_test_split, GridSearchCV, RandomizedSearchCV
from sklearn.preprocessing import StandardScaler, LabelBinarizer, LabelEncoder
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import BaggingClassifier, RandomForestRegressor
from sklearn.decomposition import PCA
from keras.applications.resnet import preprocess_input

In [197]:
# Loading raw tweet data
data = pd.read_json('data/aminer/dblp-ref-0.json', lines=True)
data = data.append(pd.read_json('data/aminer/dblp-ref-1.json', lines=True))
data = data.append(pd.read_json('data/aminer/dblp-ref-2.json', lines=True))
data = data.append(pd.read_json('data/aminer/dblp-ref-3.json', lines=True))

In [198]:
data

Unnamed: 0,abstract,authors,n_citation,references,title,venue,year,id
0,The purpose of this study is to develop a lear...,"[Makoto Satoh, Ryo Muramatsu, Mizue Kayama, Ka...",0,"[51c7e02e-f5ed-431a-8cf5-f761f266d4be, 69b625b...",Preliminary Design of a Network Protocol Learn...,international conference on human-computer int...,2013,00127ee2-cb05-48ce-bc49-9de556b93346
1,This paper describes the design and implementa...,"[Gareth Beale, Graeme Earl]",50,"[10482dd3-4642-4193-842f-85f3b70fcf65, 3133714...",A methodology for the physically accurate visu...,visual analytics science and technology,2011,001c58d3-26ad-46b3-ab3a-c1e557d16821
2,This article applied GARCH model instead AR or...,"[Altaf Hossain, Faisal Zaman, Mohammed Nasser,...",50,"[2d84c0f2-e656-4ce7-b018-90eda1c132fe, a083a1b...","Comparison of GARCH, Neural Network and Suppor...",pattern recognition and machine intelligence,2009,001c8744-73c4-4b04-9364-22d31a10dbf1
3,,"[Jea-Bum Park, Byungmok Kim, Jian Shen, Sun-Yo...",0,"[8c78e4b0-632b-4293-b491-85b1976675e6, 9cdc54f...",Development of Remote Monitoring and Control D...,,2011,00338203-9eb3-40c5-9f31-cbac73a519ec
4,,"[Giovanna Guerrini, Isabella Merlo]",2,,Reasonig about Set-Oriented Methods in Object ...,,1998,0040b022-1472-4f70-a753-74832df65266
...,...,...,...,...,...,...,...,...
79002,,"[Hassan Charaf, Peter Ekler, Tamás Mészáros, I...",50,,Mobile Platforms and Multi-Mobile Platform Dev...,Acta Cybernetica,2014,ff5ce050-ea8d-40e8-a25f-c629bed2ff9c
79003,,"[Saul Blecker, Stuart D. Katz, Leora I. Horwit...",0,,Comparison of Approaches for Heart Failure Cas...,,2016,ff5f5e4d-b650-496a-bfdd-91affb718488
79004,,"[Dzmitry Bahdanau, Tom Bosc, Stanisław Jastrzę...",0,,Learning to Compute Word Embeddings on the Fly,,2017,ff8fba62-4bf4-40cd-8555-46b8c64dddd7
79005,,"[Kirsti Askedal, Leif Skiftenes Flak, Eirik Ab...",0,,Reviewing Effects of ICT in Primary Healthcare...,,2017,ff90ffea-c94e-4ac5-a36a-05e1eccd6a76


In [199]:
# Checking out some of the larger venues
data['venue'].value_counts().sort_values(ascending=False).head(20)

                                                                            506699
Lecture Notes in Computer Science                                            32137
international conference on acoustics, speech, and signal processing         26621
international conference on robotics and automation                          19943
international conference on image processing                                 18336
international conference on communications                                   17679
international symposium on circuits and systems                              16945
global communications conference                                             15850
international geoscience and remote sensing symposium                        15390
intelligent robots and systems                                               14698
conference of the international speech communication association             13510
human factors in computing systems                                           13120
Appl

In [200]:
# selection = data.loc[data['categories'] == 'cs.LG']
selection = data.loc[data['venue'] == 'international joint conference on artificial intelligence']
selection = selection.append(data.loc[data['venue'] == 'national conference on artificial intelligence'])
selection = selection.dropna(axis=0)
selection = selection.reset_index()

In [201]:
selection

Unnamed: 0,index,abstract,authors,n_citation,references,title,venue,year,id
0,64,The paper discusses the design principles and ...,"[B. K. Bog, K. Sparck Jones]",50,"[11aceee0-863e-4bae-81d1-899bc1edaac3, 982ac74...",A general semantic analyser for data base access,international joint conference on artificial i...,1981,03cda805-9746-48bb-a04d-02c2dac201c7
1,78,The acquisition of concepts induced by structu...,[Pierre E. Bonzon],0,"[84d1c991-a589-4b22-8341-f6f8e5bade27, f9bb0c0...",Learning of abstractions from structural descr...,international joint conference on artificial i...,1979,04f16960-42b0-4f58-a1ff-23844582c40c
2,358,Knowledge-based recommenders support users in ...,"[Alexander Felfernig, Gerhard Friedrich, Monik...",66,"[2a73f987-98cc-4d8b-bf27-9d8d0834c57c, 2b31c2c...",Plausible repairs for inconsistent requirements,international joint conference on artificial i...,2009,1740b58a-bdb0-4171-9255-c32b1784ba6e
3,475,We investigate the problem of mining closed se...,"[Gemma C. Garriga, Roni Khardon, Luc De Raedt]",50,"[0df8cc1d-5dd6-4ecc-89cf-87b65286063d, 29ecc9e...",On mining closed sets in multi-relational data,international joint conference on artificial i...,2007,2044d4cf-a873-4102-b46e-8a590df70c54
4,496,Voting is a simple mechanism to aggregate the ...,[Toby Walsh],87,"[0181b1dc-346e-4c70-bff0-2768954ecd5f, 06678ba...",Where are the really hard manipulation problem...,international joint conference on artificial i...,2009,21ab7cbc-12ec-4def-a407-48fbee0fcb06
...,...,...,...,...,...,...,...,...,...
13753,67141,We prove that the scale map of the zero-crossi...,"[Alan L. Yuille, T. Poggio]",50,"[14582bca-c63d-4f41-8e8e-dedffb742728, 2fa2e5b...",Fingerprints theorems,national conference on artificial intelligence,1984,101b8e77-5e6e-4a3a-b2ef-f48eeb3e0af7
13754,69085,Optimal heuristic searches such as A* search a...,"[Maxim Likhachev, Anthony Stentz]",73,"[53798482-4334-42ed-962d-622d50d0501b, 65cc167...",R* search,national conference on artificial intelligence,2008,8d0fda05-4912-4f16-bea2-ba668c209252
13755,69297,Online services such as web search and e-comme...,"[Adish Singla, Eric Horvitz, Ece Kamar, Ryen W...",50,"[1a54bb93-00e6-4295-9b4b-52c54d13dbf2, 29f1154...",Stochastic privacy,national conference on artificial intelligence,2014,9b30e138-601a-4c4a-922e-dfcf3972c819
13756,76170,"We introduce Wubble World, a virtual environme...","[Daniel Hewlett, Shane Hoversten, Wesley Kerr,...",50,"[361ed44b-6128-4422-9e52-cf8b18524bd1, 91526f3...",Wubble World,national conference on artificial intelligence,2007,258fff35-d7f9-4fde-a9c0-3f88d7301bff


In [202]:
# extracting tweets to a list of documents
docs = selection.abstract.tolist()

In [177]:
# Pulling out document embeddings using SentenceTransformer
# Can take hours to run this cell!
start = datetime.datetime.now()
model = SentenceTransformer('sentence-transformers/all-roberta-large-v1') # all-roberta-large-v1 # all-MiniLM-L6-v2
embeddings = model.encode(docs)

end = datetime.datetime.now(); elapsed = end-start
print('Training took a total of {}'.format(elapsed))

Training took a total of 4:30:30.052754


In [194]:
# Check shape to see if it worked
embeddings.shape

(13758, 1024)

In [180]:
embeddings = pd.DataFrame(embeddings).add_suffix('_ABS_LG')

In [181]:
# # Saving our embeddings
# extended_df = selection.join(pd.DataFrame(embeddings))
# pd.DataFrame(extended_df).to_csv('data/abs_emb_lg.csv')

In [None]:
# # Combining to extended_df with title and abstract embeddings
# extended_df = extended_df.join(embeddings)

In [203]:
# loading extended_df
extended_df = pd.read_csv('data/ai_data_and_embeddings.csv')

In [204]:
print(extended_df.shape)
extended_df.head(2)

(13758, 778)


Unnamed: 0.1,Unnamed: 0,index,abstract,authors,n_citation,references,title,venue,year,id,...,374_Title,375_Title,376_Title,377_Title,378_Title,379_Title,380_Title,381_Title,382_Title,383_Title
0,0,64,The paper discusses the design principles and ...,"['B. K. Bog', 'K. Sparck Jones']",50,"['11aceee0-863e-4bae-81d1-899bc1edaac3', '982a...",A general semantic analyser for data base access,international joint conference on artificial i...,1981,03cda805-9746-48bb-a04d-02c2dac201c7,...,-0.006491,0.007468,-0.065261,0.005759,-0.03451,-0.024764,-0.012584,0.049668,0.031229,0.018817
1,1,78,The acquisition of concepts induced by structu...,['Pierre E. Bonzon'],0,"['84d1c991-a589-4b22-8341-f6f8e5bade27', 'f9bb...",Learning of abstractions from structural descr...,international joint conference on artificial i...,1979,04f16960-42b0-4f58-a1ff-23844582c40c,...,-0.023129,-0.08961,0.029168,-0.046427,0.017737,0.054145,-0.021988,0.073036,0.046273,-0.018433


### Train_test_split and feature selection

In [205]:
# Assigning X and y values
X = extended_df.drop(['n_citation', 'abstract', 'index', 'authors', 'references', 'title', 'venue', 'year', 'id'], axis=1) #
y = extended_df['n_citation']

In [206]:
# Log transform of y
y_log = np.log1p(y)

# Train-test Split
X_train, X_test, y_train, y_test = train_test_split(X, y_log, random_state=42) # can switch y to y_log

# Creating Validation Set
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, random_state=42)

In [207]:
# Selecting the top 256 correlating features with target 'retweets' within X_train
top_correlating_features = X_train.corrwith(y_train).sort_values(ascending=False).head(256)
display(top_correlating_features.head(20)) # .filter(regex='_TOP')

# Filtering X_train, X_val, and X_test to just our 256 top correlating features
X_train = X_train.filter(pd.DataFrame(top_correlating_features).transpose().columns)
X_val = X_val.filter(pd.DataFrame(top_correlating_features).transpose().columns)
X_test = X_test.filter(pd.DataFrame(top_correlating_features).transpose().columns)

187          0.109360
196          0.107690
382          0.098432
45           0.091720
161          0.090510
196_Title    0.086272
123          0.084453
382_Title    0.082792
245          0.081748
83           0.081423
34           0.080386
211          0.075705
243          0.074542
243_Title    0.074052
249_Title    0.072514
137          0.072252
187_Title    0.071402
352          0.071124
349_Title    0.070135
69           0.068164
dtype: float64

## Machine Learning

In [208]:
X_train.shape

(7738, 256)

In [209]:
### Model Evaluation Functions
def plot_history(history):
  plt.figure()
  plt.xlabel('Epoch')
  plt.ylabel('Mean Abs Error')
  plt.plot(history.epoch, np.array(history.history['root_mean_squared_error']), label='Train')
  plt.plot(history.epoch, np.array(history.history['val_root_mean_squared_error']), label = 'Val')
  plt.legend()
  plt.ylim([0,max(history.history['val_root_mean_squared_error'])])

def plot_prediction(test_labels, test_predictions):
  f = plt.figure(figsize=(8, 4))

  f1 = f.add_subplot(1,2,1)
  plt.scatter(test_labels, test_predictions)
  plt.xlabel('True Values')
  plt.ylabel('Predictions')
  plt.axis('equal')
  plt.xlim(plt.xlim())
  plt.ylim(plt.ylim())

  error = test_predictions - test_labels
  f2 = f.add_subplot(1,2,2)
  plt.hist(error, bins = 50)
  plt.xlabel("Prediction Error")
  plt.ylabel("Count")
    
  plt.tight_layout()
  plt.show()

In [210]:
simple_baseline = pd.DataFrame(y_train)
simple_baseline['nulls'] = 0
simple_rmse = np.sqrt(MSE(y_train, simple_baseline['nulls']))
print("ALL 0 RMSE : % f" %(simple_rmse))

ALL 0 RMSE :  3.337606


In [211]:
simple_baseline = pd.DataFrame(y_train)
simple_baseline['nulls'] = y_train.mean()
simple_rmse = np.sqrt(MSE(y_train, simple_baseline['nulls']))
print("AVG CONST RMSE : % f" %(simple_rmse))

AVG CONST RMSE :  1.544978


In [222]:
# Linear Regression
linreg = LinearRegression().fit(X_train, y_train)
pred = linreg.predict(X_val); linreg_rmse = np.sqrt(MSE(y_val, pred)); print("LR RMSE : % f" %(linreg_rmse))
predictions = pd.DataFrame(pred).rename({0:'linreg_baseline'}, axis=1) # recording prediction
R2 = round(pd.DataFrame(pred).join(y_val.reset_index(drop=True)).corr()['n_citation'][0]**2, 4)
print(f'R2 = {R2}')

LR RMSE :  1.510595
R2 = 0.0432


In [223]:
start = datetime.datetime.now()

progress = dict(); eval_metric = ["rmse"]; eval_set = [(X_val, y_val)]
XGB = xgboost.XGBRegressor(early_stopping_rounds=10) # , early_stopping_rounds=1
XGBhist = XGB.fit(X_train, y_train, eval_metric=eval_metric, eval_set=eval_set)

end = datetime.datetime.now(); elapsed = end-start; print('Training took a total of {}'.format(elapsed))
pred = XGB.predict(X_val); xgb_b_rmse = np.sqrt(MSE(y_val, pred)); print("LR RMSE : % f" %(xgb_b_rmse))

R2 = round(pd.DataFrame(pred).join(y_val.reset_index(drop=True)).corr()['n_citation'][0]**2, 4)
print(f'R2 = {R2}')

[0]	validation_0-rmse:2.28828
[1]	validation_0-rmse:1.93593
[2]	validation_0-rmse:1.73541




[3]	validation_0-rmse:1.62786
[4]	validation_0-rmse:1.57510
[5]	validation_0-rmse:1.55123
[6]	validation_0-rmse:1.54077
[7]	validation_0-rmse:1.53780
[8]	validation_0-rmse:1.53412
[9]	validation_0-rmse:1.53560
[10]	validation_0-rmse:1.53328
[11]	validation_0-rmse:1.53272
[12]	validation_0-rmse:1.53495
[13]	validation_0-rmse:1.53512
[14]	validation_0-rmse:1.53778
[15]	validation_0-rmse:1.54105
[16]	validation_0-rmse:1.54218
[17]	validation_0-rmse:1.54439
[18]	validation_0-rmse:1.54291
[19]	validation_0-rmse:1.54314
[20]	validation_0-rmse:1.54176
Training took a total of 0:00:01.317995
LR RMSE :  1.532724
R2 = 0.023


Looks like LR is beating our XGBoost model when interpreting our validation set. I'm not feeling super impressed with the explainability here, and I think to have more explainability we would need to zoom out and think about systemic changes in our approach.

### Conclusion
Extraction of whitepaper document embeddings using ROBERTA-Large and linear regression statical prediction can explain about 4.3% of variance (R2) in citation count.

To know more we may need to look at the structure of the academic citation graph, or metadata about researchers. Citation Count is also not a great target, and digging deeper into academic impact measures may be the way to go. Fortunately these methods can be prototyped on many different NLP tasks and citation graphs, some of which may have better targets for prediction.

### The End :)