In [1]:
from sentence_transformers import SentenceTransformer, models
from ogb.nodeproppred import NodePropPredDataset
import pandas as pd
import numpy as np
import pickle
import math
from tqdm import tqdm

In [2]:
# MODEL = "/nlp/scr/ananjan/graph_models/bert_all/checkpoint-2842"
# MODEL = "bert-base-uncased"
# MODEL = "/nlp/scr/ananjan/graph_models/roberta_all/checkpoint-2842"
# MODEL = "/nlp/scr/ananjan/graph_models/minilm_all/checkpoint-5684"
MODEL = "/nlp/scr/ananjan/graph_models/mpnet_all/checkpoint-5684"
DATASET = "ogbn-arxiv"
DATASET_ROOT = "/nlp/scr/ananjan/graph_datasets/"
OUTPUT_FILE = "/nlp/scr/ananjan/graph_embeddings/finetuned/mpnet_arxiv.pkl"
MODE = 'all'

In [3]:
# Load Node-Paper Id Mappings

nodeidx2paperid = pd.read_csv('/nlp/scr/ananjan/graph_datasets/ogbn_arxiv/mapping/nodeidx2paperid.csv')
nodeidx2paperid.head()

Unnamed: 0,node idx,paper id
0,0,9657784
1,1,39886162
2,2,116214155
3,3,121432379
4,4,231147053


In [4]:
# Load Paper Mappings

titleabs = pd.read_csv('/nlp/scr/ananjan/graph_datasets/ogbn_arxiv/mapping/titleabs.tsv', sep='\t')
titleabs.head()

Unnamed: 0,paperid,title,abstract
0,200971.0,ontology as a source for rule generation,This paper discloses the potential of OWL (Web...
1,549074.0,a novel methodology for thermal analysis a 3 d...,The semiconductor industry is reaching a fasci...
2,630234.0,spreadsheets on the move an evaluation of mobi...,The power of mobile devices has increased dram...
3,803423.0,multi view metric learning for multi view vide...,Traditional methods on video summarization are...
4,1102481.0,big data analytics in future internet of things,Current research on Internet of Things (IoT) m...


In [5]:
# Make reverse index for text df

reverse_index= {}

paperids = titleabs["paperid"].tolist()
for idx, paperid in enumerate(paperids):
    if (not math.isnan(paperid)):
        reverse_index[int(paperid)] = idx

In [6]:
# Dataset Creation

mapped_text = []

for idx in tqdm(range(len(nodeidx2paperid))):
    paper_id = nodeidx2paperid.iloc[idx]['paper id']
    reference_idx = reverse_index[paper_id]
    title = titleabs.iloc[reference_idx]['title']
    abstract = titleabs.iloc[reference_idx]['abstract']
    if (MODE == 'title'):
        mapped_text.append("Title: " + title)
    elif (MODE == 'abstract'):
        mapped_text.append(" Abstract: " + abstract)
    else:
        mapped_text.append("Title: " + title + " Abstract: " + abstract)

100%|████████████████████████████████████████████████████████████████████████████████████████████████| 169343/169343 [00:12<00:00, 13101.80it/s]


In [7]:
# Initialize sentence embedding model

word_embedding_model = models.Transformer(MODEL)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode = 'mean')
emb_model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

Some weights of MPNetModel were not initialized from the model checkpoint at /nlp/scr/ananjan/graph_models/mpnet_all/checkpoint-5684 and are newly initialized: ['mpnet.pooler.dense.bias', 'mpnet.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
# Get NumPy embeddings

embeddings = emb_model.encode(sentences = mapped_text,
                             batch_size = 32,
                             show_progress_bar = True,
                             convert_to_numpy = True,
                             normalize_embeddings = True)

Batches:   0%|          | 0/5292 [00:00<?, ?it/s]

In [9]:
# Dump Embeddings

with open(OUTPUT_FILE, 'wb') as f:
    pickle.dump(embeddings, f)