In [1]:
# %pip install transformers
import pandas as pd
from cogdl.oag import oagbert
import torch
import re
import numpy as np
import ipywidgets as widgets
import requests
import json
from dataclasses import dataclass
from typing import Dict, List, Optional
import os
import pymilvus
from pymilvus import (
    connections,
    utility,
    FieldSchema,
    CollectionSchema,
    DataType,
    Collection
)

In [2]:
# Parameters
max_depth = 2
ignore_related = True
ignore_referenced = False
base_works_url = "https://api.openalex.org/works"

In [3]:
@dataclass
class Article:
    # Keeping track of some needed paper details
    id: str
    title: str
    inverted_abstract: Dict[str, List[int]]
    authors: List[str]
    host_venue: str
    affiliations: List[str]
    concepts: List[str]
    references: List[str]
    related: List[str]

    def get_abstract(self) -> str:
        abstract = dict()
        for k, v in self.inverted_abstract.items():
            for i in v:
                abstract[i] = k

        final = ""
        for i in sorted(abstract.keys()):
            final += abstract[i] + " "
        return final
    
    def fetch_references_queries(self):
        # open alex only allows 50 OR joins per request
        queries = list()
        for i in range(0, len(self.references), 50):
            queries.append('|'.join(self.references[i:i+50]))
        return queries
    
    def fetch_related_queries(self):
        # open alex only allows 50 OR joins per request
        queries = list()
        for i in range(0, len(self.related), 50):
            queries.append('|'.join(self.related[i:i+50]))
        return queries
    
    def __str__(self):
        return f"{self.id}: {self.title}\n{self.get_abstract()}"

In [4]:
def fetch_article(result):
    work_id = result["id"].split('/')[-1]
    title = result["title"]
    inverted_abstract = result['abstract_inverted_index']
    authors = [authorship['author']['display_name'] for authorship in result['authorships']]
    host_venue = result['host_venue']['publisher']
    institutions = list()

    for authorship in result['authorships']:
        for institution in authorship['institutions']: 
            if institution['display_name'] not in institutions:
                institutions.append(institution['display_name'])

    concepts = [concept['display_name'] for concept in result['concepts'] if float(concept['score']) > 0.5]
    referenced_works = [work.split('/')[-1] for work in result['referenced_works']]
    related_works = [work.split('/')[-1] for work in result['related_works']]

    return Article(
        work_id,
        title if title else "",
        inverted_abstract if inverted_abstract else {"": [0]},
        authors,
        host_venue if host_venue else "",
        institutions,
        concepts,
        referenced_works,
        related_works
    )

In [5]:
connection = pymilvus.connections.connect(
    alias='default',
    host='localhost',
    port='19530'
)

In [63]:
fields = [
    FieldSchema(name='work_id', dtype=DataType.VARCHAR, max_length=32, is_primary=True),
    FieldSchema(name='embeddings', dtype=DataType.FLOAT_VECTOR, dim=768)
]
collection_name = 'Article_Vectors'
schema = CollectionSchema(fields, "Testing")
paper_trail_collection = Collection(collection_name, schema)

In [62]:
utility.drop_collection(collection_name)

# Search for Article Title
Edit the title variable below to search for a paper. If not exact then returns 25 most relevant papers in the OpenAlex dataset. Select the paper in the dropdown menu.

In [64]:
title = "BERT"
title = title.replace(" ", "%20")
req = requests.get(base_works_url+f"?filter=title.search:{title}")
response = json.loads(req.content)

relevant_titles = [result['title'] for result in response['results']]
title_selector = widgets.Dropdown(
    options=relevant_titles,
    value=relevant_titles[0],
    description="Title: "
)
display(title_selector)

Dropdown(description='Title: ', options=('RoBERTa: A Robustly Optimized BERT Pretraining Approach', 'BERT: Pre…

In [103]:
raise Exception("Please select correct title above. If done, run all cells below this one.")

Exception: Please select correct title above. If done, run all cells below this one.

In [85]:
index = relevant_titles.index(title_selector.value)
papers = dict()
root_id = response['results'][index]['id'].split('/')[-1]

papers[root_id] = fetch_article(response['results'][index])

In [86]:
use_references = ignore_referenced != True
use_related = ignore_related != True

related_works: Dict[int, List[Article]] = {}

def get_relevant_papers(current_depth: int, previous: List[Article]):
    related_works[current_depth] = []
    print(current_depth)
    for parent in previous:
        if use_references and len(parent.references) > 0:
            for query in parent.fetch_references_queries():         
                req = requests.get(base_works_url + f'?filter=openalex_id:{query}')
                res = json.loads(req.content)
                for result in res["results"]:
                    paper_id = result['id'].split('/')[-1]
                    if paper_id not in papers.keys():
                        temp = fetch_article(result)
                        papers[temp.id] = temp
                        related_works[current_depth].append(temp)
            
        if (use_related and len(parent.related) > 0) or len(parent.references) == 0:
            for query in parent.fetch_related_queries():  
                req = requests.get(base_works_url + f'?filter=openalex_id:{query}')
                res = json.loads(req.content)
                for result in res["results"]:
                    paper_id = result['id'].split('/')[-1]
                    if paper_id not in papers.keys():
                        temp = fetch_article(result)
                        papers[temp.id] = temp
                        related_works[current_depth].append(temp)

    if current_depth < max_depth:
        get_relevant_papers(current_depth+1, related_works[current_depth])

In [87]:
get_relevant_papers(1, [papers[root_id]])

1
2


In [14]:
tokenizer, model = oagbert("oagbert-v2")

In [88]:
for key in papers.keys():
    curr_paper = papers[key]
    input_ids, input_masks, token_type_ids, masked_lm_labels, position_ids, position_ids_second, masked_positions, num_spans = model.build_inputs(
        title=curr_paper.title, 
        abstract=curr_paper.get_abstract(), 
        venue=curr_paper.host_venue, 
        authors=curr_paper.authors, 
        concepts=curr_paper.concepts, 
        affiliations=curr_paper.affiliations
    )

    sequence_output, pooled_output = model.bert.forward(
        input_ids=torch.LongTensor(input_ids).unsqueeze(0),
        token_type_ids=torch.LongTensor(token_type_ids).unsqueeze(0),
        attention_mask=torch.LongTensor(input_masks).unsqueeze(0),
        output_all_encoded_layers=False,
        checkpoint_activations=False,
        position_ids=torch.LongTensor(position_ids).unsqueeze(0),
        position_ids_second=torch.LongTensor(position_ids_second).unsqueeze(0)
    )

    pooled_normalized = torch.nn.functional.normalize(pooled_output, p=2, dim=1)

    paper_trail_collection.insert([
            [key], 
            [pooled_normalized.tolist()[0]]
        ])
    
paper_trail_collection.flush()

In [69]:
print(paper_trail_collection.num_entities)

484


In [91]:
index_params = {
    "metric_type": "IP",
    "index_type": "IVF_FLAT",
    "params": {"nlist": 128}
}

paper_trail_collection.create_index(field_name="embeddings", index_params=index_params)

RPC error: [create_index], <MilvusException: (code=1, message=create index failed, collection is loaded, please release it first)>, <Time:{'RPC start': '2023-02-23 10:39:54.327895', 'RPC error': '2023-02-23 10:39:54.329801'}>


MilvusException: <MilvusException: (code=1, message=create index failed, collection is loaded, please release it first)>

In [89]:
paper_trail_collection.load()
root_paper_embeddings = paper_trail_collection.query(
    expr = f'work_id == "{root_id}"',
    output_fields=['embeddings']
)
root_paper_embeddings = torch.Tensor(root_paper_embeddings[0]['embeddings'])
root_paper_embeddings.shape

torch.Size([768])

In [90]:
paper_trail_collection.load()
root_paper_embeddings = paper_trail_collection.query(
    expr = f'work_id == "{root_id}"',
    output_fields=['embeddings']
)
root_paper_embeddings = torch.Tensor([root_paper_embeddings[0]['embeddings']])

paper_keys = list(papers.keys())
paper_keys.remove(root_id)

cols = ["id", "title", "score"]
similarities = pd.DataFrame(columns=cols)

for key in paper_keys:
    paper_embeddings = paper_trail_collection.query(
        expr = f'work_id == "{key}"',
        output_fields=['embeddings']
    )
    paper_embeddings = torch.Tensor([paper_embeddings[0]['embeddings']])
    sim = torch.mm(root_paper_embeddings, paper_embeddings.transpose(0, 1))
    results = {
        "id": [key],
        "title": [papers[key].title],
        "score": [sim.detach().numpy()]
    }

    similarities = pd.concat([similarities, pd.DataFrame(results)], ignore_index=True)

In [92]:
similarities.sort_values(by="score", ascending=False).head(25)

Unnamed: 0,id,title,score
76,W2962784628,Neural Machine Translation of Rare Words with ...,[[0.9945146]]
12,W2963918774,Supervised Learning of Universal Sentence Repr...,[[0.99435556]]
26,W2153579005,Distributed Representations of Words and Phras...,[[0.9941311]]
40,W2158139315,Word Representations: A Simple and General Met...,[[0.99392474]]
41,W2251803266,"Don't count, predict! A systematic comparison ...",[[0.9939003]]
43,W2164019165,Improving Word Representations via Global Cont...,[[0.99379086]]
141,W1973942085,A structured vector space model for word meani...,[[0.99360996]]
317,W2153508793,Automatic Evaluation of Translation Quality fo...,[[0.9935808]]
182,W1897507002,"Entailment, intensionality and text understanding",[[0.99333894]]
47,W2250189634,Linguistic Regularities in Sparse and Explicit...,[[0.99310374]]


# Print Root Paper Abstract and 5 most similar papers

In [93]:
print(papers[root_id])

W2970641574: Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks
BERT (Devlin et al., 2018) and RoBERTa (Liu et al., 2019) has set a new state-of-the-art performance on sentence-pair regression tasks like semantic textual similarity (STS). However, it requires that both sentences are fed into the network, which causes a massive computational overhead: Finding the most similar pair in a collection of 10,000 sentences requires about 50 million inference computations (~65 hours) with BERT. The construction of BERT makes it unsuitable for semantic similarity search as well as for unsupervised tasks like clustering. In this publication, we present Sentence-BERT (SBERT), a modification of the pretrained BERT network that use siamese and triplet network structures to derive semantically meaningful sentence embeddings that can be compared using cosine-similarity. This reduces the effort for finding the most similar pair from 65 hours with BERT / RoBERTa to about 5 seconds with SBERT

In [94]:
ids = similarities.sort_values(by="score", ascending=False).head(5)["id"]
for id in ids.values:
    print(papers[id])
    print("\n\n")

W2962784628: Neural Machine Translation of Rare Words with Subword Units
Neural machine translation (NMT) models typically operate with a fixed vocabulary, but translation is an open-vocabulary problem. Previous work addresses the translation of out-of-vocabulary words by backing off to a dictionary. In this paper, we introduce a simpler and more effective approach, making the NMT model capable of open-vocabulary translation by encoding rare and unknown words as sequences of subword units. This is based on the intuition that various word classes are translatable via smaller units than words, for instance names (via character copying or transliteration), compounds (via compositional translation), and cognates and loanwords (via phonological and morphological transformations). We discuss the suitability of different word segmentation techniques, including simple character ngram models and a segmentation based on the byte pair encoding compression algorithm, and empirically show that subw

# Testing Milvus Search Functionality (cosine similarity)

Slightly different results from above, but still very similar

In [95]:
root_paper_embeddings = root_paper_embeddings.tolist()[0]

In [96]:
len(root_paper_embeddings)

768

In [97]:
search_params = {"metric_type": "IP", "params": {"nprobe": 10}, "offset": 5}

In [98]:
results = paper_trail_collection.search(
	data=[root_paper_embeddings], 
	anns_field="embeddings", 
	param=search_params,
	limit=10, 
	expr=None,
	consistency_level="Strong"
)

In [99]:
results[0].ids

['W2115792525', 'W2130237711', 'W2158139315', 'W2251803266', 'W2164019165', 'W2963216553', 'W2949547296', 'W2085766370', 'W2087556608', 'W2131744502']

In [100]:
results[0].distances

[0.9941201210021973, 0.9940099120140076, 0.9939247369766235, 0.9939004182815552, 0.9937907457351685, 0.9937602281570435, 0.9937376976013184, 0.9937025308609009, 0.9936977624893188, 0.9936720132827759]

In [101]:
for i in range(5):
    print(papers[results[0].ids[i]])
    print(f"Distance: {results[0].distances[i]}\n\n")

W2115792525: The Berkeley FrameNet Project
FrameNet is a three-year NSF-supported project in corpus-based computational lexicography, now in its second year (NSF IRI-9618838, Tools for Lexicon Building). The project's key features are (a) a commitment to corpus evidence for semantic and syntactic generalizations, and (b) the representation of the valences of its target words (mostly nouns, adjectives, and verbs) in which the semantic portion makes use of semantics. The resulting database will contain (a) descriptions of the semantic frames underlying the meanings of the words described, and (b) the valence representation (semantic and syntactic) of several thousand words and phrases, each accompanied by (c) a representative collection of annotated corpus attestations, which jointly exemplify the observed linkings between frame elements and their syntactic realizations (e.g. grammatical function, phrase type, and other syntactic traits). This report will present the project's goals and 

KeyError: 'W2130237711'

I accidentally nuked the previous papers, but the embeddings were kept in the Milvus instance. Since I didn't filter for only direct references to the paper it tried offering similar papers outside the references

# Only Include Direct References

In [102]:
direct_references = papers[root_id].references if papers[root_id].references else papers[root_id].related # this is gross
direct_references = [f'"{ref}"' for ref in direct_references]
direct_references = ', '.join(direct_references)

results = paper_trail_collection.search(
	data=[root_paper_embeddings], 
	anns_field="embeddings", 
	param=search_params,
	limit=10, 
	expr=f"work_id in [{direct_references}]",
	consistency_level="Strong"
)

In [103]:
for i in range(5):
    print(papers[results[0].ids[i]])
    print(f"Distance: {results[0].distances[i]}\n\n")

W2963341956: 
We introduce a new language representation model called BERT, which stands for Bidirectional Encoder Representations from Transformers. Unlike recent language representation models (Peters et al., 2018a; Radford et al., 2018), BERT is designed to pre-train deep bidirectional representations from unlabeled text by jointly conditioning on both left and right context in all layers. As a result, the pre-trained BERT model can be fine-tuned with just one additional output layer to create state-of-the-art models for a wide range of tasks, such as question answering and language inference, without substantial task-specific architecture modifications. BERT is conceptually simple and empirically powerful. It obtains new state-of-the-art results on eleven natural language processing tasks, including pushing the GLUE score to 80.5 (7.7 point absolute improvement), MultiNLI accuracy to 86.7% (4.6% absolute improvement), SQuAD v1.1 question answering Test F1 to 93.2 (1.5 point absolut