In [1]:
# %pip install transformers
import pandas as pd
from cogdl.oag import oagbert
import torch
import re
import numpy as np
import ipywidgets as widgets
import requests
import json
from dataclasses import dataclass
from typing import Dict, List, Optional
import os

In [15]:
# Parameters
max_depth = 2
ignore_related = True
ignore_referenced = False
base_works_url = "https://api.openalex.org/works"

In [16]:
@dataclass
class Article:
    # Keeping track of some needed paper details
    id: str
    title: str
    inverted_abstract: Dict[str, List[int]]
    authors: List[str]
    host_venue: str
    affiliations: List[str]
    concepts: List[str]
    references: List[str]
    related: List[str]

    def get_abstract(self) -> str:
        abstract = dict()
        for k, v in self.inverted_abstract.items():
            for i in v:
                abstract[i] = k

        final = ""
        for i in sorted(abstract.keys()):
            final += abstract[i] + " "
        return final
    
    def fetch_references_queries(self):
        # open alex only allows 50 OR joins per request
        queries = list()
        for i in range(0, len(self.references), 50):
            queries.append('|'.join(self.references[i:i+50]))
        return queries
    
    def fetch_related_queries(self):
        # open alex only allows 50 OR joins per request
        queries = list()
        for i in range(0, len(self.related), 50):
            queries.append('|'.join(self.related[i:i+50]))
        return queries
    
    def __str__(self):
        return f"{self.id}: {self.title}\n{self.get_abstract()}"

In [17]:
def fetch_article(result):
    work_id = result["id"].split('/')[-1]
    title = result["title"]
    inverted_abstract = result['abstract_inverted_index']
    authors = [authorship['author']['display_name'] for authorship in result['authorships']]
    host_venue = result['host_venue']['publisher']
    institutions = list()

    for authorship in result['authorships']:
        for institution in authorship['institutions']: 
            if institution['display_name'] not in institutions:
                institutions.append(institution['display_name'])

    concepts = [concept['display_name'] for concept in result['concepts'] if float(concept['score']) > 0.5]
    referenced_works = [work.split('/')[-1] for work in result['referenced_works']]
    related_works = [work.split('/')[-1] for work in result['related_works']]

    return Article(
        work_id,
        title if title else "",
        inverted_abstract if inverted_abstract else {"": [0]},
        authors,
        host_venue if host_venue else "",
        institutions,
        concepts,
        referenced_works,
        related_works
    )

# Search for Article Title
Edit the title variable below to search for a paper. If not exact then returns 25 most relevant papers in the OpenAlex dataset. Select the paper in the dropdown menu.

In [31]:
title = "Attention is all you need"
title = title.replace(" ", "%20")
req = requests.get(base_works_url+f"?filter=title.search:{title}")
response = json.loads(req.content)

relevant_titles = [result['title'] for result in response['results']]
title_selector = widgets.Dropdown(
    options=relevant_titles,
    value=relevant_titles[0],
    description="Title: "
)
display(title_selector)

Dropdown(description='Title: ', options=('Attention is All you Need', 'Attention Is All You Need', 'Channel At…

In [32]:
raise Exception("Please select correct title above. If done, run all cells below this one.")

Exception: Please select correct title above. If done, run all cells below this one.

In [33]:
index = relevant_titles.index(title_selector.value)
papers = dict()
root_id = response['results'][index]['id'].split('/')[-1]

papers[root_id] = fetch_article(response['results'][index])

In [34]:
use_references = ignore_referenced != True
use_related = ignore_related != True

related_works: Dict[int, List[Article]] = {}

def get_relevant_papers(current_depth: int, previous: List[Article]):
    related_works[current_depth] = []
    print(current_depth)
    for parent in previous:
        if use_references and len(parent.references) > 0:
            for query in parent.fetch_references_queries():         
                req = requests.get(base_works_url + f'?filter=openalex_id:{query}')
                res = json.loads(req.content)
                for result in res["results"]:
                    paper_id = result['id'].split('/')[-1]
                    if paper_id not in papers.keys():
                        temp = fetch_article(result)
                        papers[temp.id] = temp
                        related_works[current_depth].append(temp)
            
        if (use_related and len(parent.related) > 0) or len(parent.references) == 0:
            for query in parent.fetch_related_queries():  
                req = requests.get(base_works_url + f'?filter=openalex_id:{query}')
                res = json.loads(req.content)
                for result in res["results"]:
                    paper_id = result['id'].split('/')[-1]
                    if paper_id not in papers.keys():
                        temp = fetch_article(result)
                        papers[temp.id] = temp
                        related_works[current_depth].append(temp)

    if current_depth < max_depth:
        get_relevant_papers(current_depth+1, related_works[current_depth])

In [35]:
get_relevant_papers(1, [papers[root_id]])

1
2


In [36]:
tokenizer, model = oagbert("oagbert-v2")

In [37]:
if not os.path.exists("./embeddings/"):
    os.mkdir("./embeddings/")

files = os.listdir("./embeddings/")
files = [file.split('.')[0] for file in files]

for key in papers.keys():
    if key not in files:
        curr_paper = papers[key]
        input_ids, input_masks, token_type_ids, masked_lm_labels, position_ids, position_ids_second, masked_positions, num_spans = model.build_inputs(
            title=curr_paper.title, 
            abstract=curr_paper.get_abstract(), 
            venue=curr_paper.host_venue, 
            authors=curr_paper.authors, 
            concepts=curr_paper.concepts, 
            affiliations=curr_paper.affiliations
        )

        sequence_output, pooled_output = model.bert.forward(
            input_ids=torch.LongTensor(input_ids).unsqueeze(0),
            token_type_ids=torch.LongTensor(token_type_ids).unsqueeze(0),
            attention_mask=torch.LongTensor(input_masks).unsqueeze(0),
            output_all_encoded_layers=False,
            checkpoint_activations=False,
            position_ids=torch.LongTensor(position_ids).unsqueeze(0),
            position_ids_second=torch.LongTensor(position_ids_second).unsqueeze(0)
        )

        torch.save(pooled_output, f"./embeddings/{curr_paper.id}.pt")
        

In [38]:
root_paper_embeddings = torch.load(f"./embeddings/{root_id}.pt")
root_paper_embeddings = torch.nn.functional.normalize(root_paper_embeddings, p=2, dim=1)

paper_keys = list(papers.keys())
paper_keys.remove(root_id)

cols = ["id", "title", "score"]
similarities = pd.DataFrame(columns=cols)

for key in paper_keys:
    paper_embeddings = torch.load(f"./embeddings/{key}.pt")
    paper_embeddings = torch.nn.functional.normalize(paper_embeddings, p=2, dim=1)
    sim = torch.mm(root_paper_embeddings, paper_embeddings.transpose(0, 1))
    results = {
        "id": [key],
        "title": [papers[key].title],
        "score": [sim.detach().numpy()]
    }

    similarities = pd.concat([similarities, pd.DataFrame(results)], ignore_index=True)

In [39]:
similarities.sort_values(by="score", ascending=False).head(25)

Unnamed: 0,id,title,score
241,W2136939460,LSTM can Solve Hard Long Time Lag Problems,[[0.6680576]]
134,W2015861736,Convolutional networks and applications in vision,[[0.66090643]]
100,W2164019165,Improving Word Representations via Global Cont...,[[0.62542415]]
12,W2962739339,Deep Contextualized Word Representations,[[0.61234534]]
237,W2402268235,LSTM neural networks for language modeling,[[0.61059386]]
298,W2978017171,"DistilBERT, a distilled version of BERT: small...",[[0.60469294]]
14,W1902237438,Effective Approaches to Attention-based Neural...,[[0.60294294]]
0,W2194775991,Deep Residual Learning for Image Recognition,[[0.5975584]]
116,W2963918774,Supervised Learning of Universal Sentence Repr...,[[0.59685653]]
296,W2996428491,ALBERT: A Lite BERT for Self-supervised Learni...,[[0.5942767]]


# Print Root Paper Abstract and 5 most similar papers

In [44]:
print(papers[root_id])

W2963403868: Attention is All you Need
The dominant sequence transduction models are based on complex recurrent orconvolutional neural networks in an encoder and decoder configuration. The best performing such models also connect the encoder and decoder through an attentionm echanisms. We propose a novel, simple network architecture based solely onan attention mechanism, dispensing with recurrence and convolutions entirely.Experiments on two machine translation tasks show these models to be superiorin quality while being more parallelizable and requiring significantly less timeto train. Our single model with 165 million parameters, achieves 27.5 BLEU onEnglish-to-German translation, improving over the existing best ensemble result by over 1 BLEU. On English-to-French translation, we outperform the previoussingle state-of-the-art with model by 0.7 BLEU, achieving a BLEU score of 41.1. 


In [46]:
ids = similarities.sort_values(by="score", ascending=False).head(5)["id"]
for id in ids.values:
    print(papers[id])
    print("\n\n")

W2136939460: LSTM can Solve Hard Long Time Lag Problems
Standard recurrent nets cannot deal with long minimal time lags between relevant signals. Several recent NIPS papers propose alternative methods. We first show: problems used to promote various previous algorithms can be solved more quickly by random weight guessing than by the proposed algorithms. We then use LSTM, our own recent algorithm, to solve a hard problem that can neither be quickly solved by random search nor by any other recurrent net algorithm we are aware of. 



W2015861736: Convolutional networks and applications in vision
Intelligent tasks, such as visual perception, auditory perception, and language understanding require the construction of good internal representations of the world (or features)? which must be invariant to irrelevant variations of the input while, preserving relevant information. A major question for Machine Learning is how to learn such good features automatically. Convolutional Networks (ConvN