In [1]:
# %pip install transformers
import pandas as pd
from cogdl.oag import oagbert
import torch
import re
import numpy as np
import ipywidgets as widgets
import requests
import json
from dataclasses import dataclass
from typing import Dict, List, Optional
import os
import pymilvus
from pymilvus import (
    connections,
    utility,
    FieldSchema,
    CollectionSchema,
    DataType,
    Collection
)

In [2]:
# Parameters
max_depth = 2
ignore_related = True
ignore_referenced = False
base_works_url = "https://api.openalex.org/works"

In [3]:
@dataclass
class Article:
    # Keeping track of some needed paper details
    id: str
    title: str
    inverted_abstract: Dict[str, List[int]]
    authors: List[str]
    host_venue: str
    affiliations: List[str]
    concepts: List[str]
    references: List[str]
    related: List[str]

    def get_abstract(self) -> str:
        abstract = dict()
        for k, v in self.inverted_abstract.items():
            for i in v:
                abstract[i] = k

        final = ""
        for i in sorted(abstract.keys()):
            final += abstract[i] + " "
        return final
    
    def fetch_references_queries(self):
        # open alex only allows 50 OR joins per request
        queries = list()
        for i in range(0, len(self.references), 50):
            queries.append('|'.join(self.references[i:i+50]))
        return queries
    
    def fetch_related_queries(self):
        # open alex only allows 50 OR joins per request
        queries = list()
        for i in range(0, len(self.related), 50):
            queries.append('|'.join(self.related[i:i+50]))
        return queries
    
    def __str__(self):
        return f"{self.id}: {self.title}\n{self.get_abstract()}"

In [4]:
def fetch_article(result):
    work_id = result["id"].split('/')[-1]
    title = result["title"]
    inverted_abstract = result['abstract_inverted_index']
    authors = [authorship['author']['display_name'] for authorship in result['authorships']]
    host_venue = result['host_venue']['publisher']
    institutions = list()

    for authorship in result['authorships']:
        for institution in authorship['institutions']: 
            if institution['display_name'] not in institutions:
                institutions.append(institution['display_name'])

    concepts = [concept['display_name'] for concept in result['concepts'] if float(concept['score']) > 0.5]
    referenced_works = [work.split('/')[-1] for work in result['referenced_works']]
    related_works = [work.split('/')[-1] for work in result['related_works']]

    return Article(
        work_id,
        title if title else "",
        inverted_abstract if inverted_abstract else {"": [0]},
        authors,
        host_venue if host_venue else "",
        institutions,
        concepts,
        referenced_works,
        related_works
    )

In [5]:
connection = pymilvus.connections.connect(
    alias='default',
    host='localhost',
    port='19530'
)

In [44]:
fields = [
    FieldSchema(name='pk', dtype=DataType.VARCHAR, max_length=32, is_primary=True),
    FieldSchema(name='embeddings', dtype=DataType.FLOAT_VECTOR, dim=768)
]
collection_name = 'paper_trail_test'
schema = CollectionSchema(fields, "Testing")
paper_trail_collection = Collection(collection_name, schema)

In [43]:
utility.drop_collection(collection_name)

# Search for Article Title
Edit the title variable below to search for a paper. If not exact then returns 25 most relevant papers in the OpenAlex dataset. Select the paper in the dropdown menu.

In [102]:
title = "Attention is all you need"
title = title.replace(" ", "%20")
req = requests.get(base_works_url+f"?filter=title.search:{title}")
response = json.loads(req.content)

relevant_titles = [result['title'] for result in response['results']]
title_selector = widgets.Dropdown(
    options=relevant_titles,
    value=relevant_titles[0],
    description="Title: "
)
display(title_selector)

Dropdown(description='Title: ', options=('Attention is All you Need', 'Attention Is All You Need', 'Channel At…

In [103]:
raise Exception("Please select correct title above. If done, run all cells below this one.")

Exception: Please select correct title above. If done, run all cells below this one.

In [104]:
index = relevant_titles.index(title_selector.value)
papers = dict()
root_id = response['results'][index]['id'].split('/')[-1]

papers[root_id] = fetch_article(response['results'][index])

In [105]:
use_references = ignore_referenced != True
use_related = ignore_related != True

related_works: Dict[int, List[Article]] = {}

def get_relevant_papers(current_depth: int, previous: List[Article]):
    related_works[current_depth] = []
    print(current_depth)
    for parent in previous:
        if use_references and len(parent.references) > 0:
            for query in parent.fetch_references_queries():         
                req = requests.get(base_works_url + f'?filter=openalex_id:{query}')
                res = json.loads(req.content)
                for result in res["results"]:
                    paper_id = result['id'].split('/')[-1]
                    if paper_id not in papers.keys():
                        temp = fetch_article(result)
                        papers[temp.id] = temp
                        related_works[current_depth].append(temp)
            
        if (use_related and len(parent.related) > 0) or len(parent.references) == 0:
            for query in parent.fetch_related_queries():  
                req = requests.get(base_works_url + f'?filter=openalex_id:{query}')
                res = json.loads(req.content)
                for result in res["results"]:
                    paper_id = result['id'].split('/')[-1]
                    if paper_id not in papers.keys():
                        temp = fetch_article(result)
                        papers[temp.id] = temp
                        related_works[current_depth].append(temp)

    if current_depth < max_depth:
        get_relevant_papers(current_depth+1, related_works[current_depth])

In [106]:
get_relevant_papers(1, [papers[root_id]])

1
2


KeyboardInterrupt: 

In [None]:
tokenizer, model = oagbert("oagbert-v2")

In [None]:
if not os.path.exists("./embeddings/"):
    os.mkdir("./embeddings/")

files = os.listdir("./embeddings/")
files = [file.split('.')[0] for file in files]

for key in papers.keys():
    curr_paper = papers[key]
    input_ids, input_masks, token_type_ids, masked_lm_labels, position_ids, position_ids_second, masked_positions, num_spans = model.build_inputs(
        title=curr_paper.title, 
        abstract=curr_paper.get_abstract(), 
        venue=curr_paper.host_venue, 
        authors=curr_paper.authors, 
        concepts=curr_paper.concepts, 
        affiliations=curr_paper.affiliations
    )

    sequence_output, pooled_output = model.bert.forward(
        input_ids=torch.LongTensor(input_ids).unsqueeze(0),
        token_type_ids=torch.LongTensor(token_type_ids).unsqueeze(0),
        attention_mask=torch.LongTensor(input_masks).unsqueeze(0),
        output_all_encoded_layers=False,
        checkpoint_activations=False,
        position_ids=torch.LongTensor(position_ids).unsqueeze(0),
        position_ids_second=torch.LongTensor(position_ids_second).unsqueeze(0)
    )

    pooled_normalized = torch.nn.functional.normalize(pooled_output, p=2, dim=1)

    paper_trail_collection.insert([
            [key], 
            [pooled_normalized.tolist()[0]]
        ])
    
paper_trail_collection.flush()

In [None]:
print(paper_trail_collection.num_entities)

784


In [None]:
index_params = {
    "metric_type": "IP",
    "index_type": "IVF_FLAT",
    "params": {"nlist": 128}
}

paper_trail_collection.create_index(field_name="embeddings", index_params=index_params)

Status(code=0, message=)

In [79]:
paper_trail_collection.load()
root_paper_embeddings = paper_trail_collection.query(
    expr = f'pk == "{root_id}"',
    output_fields=['embeddings']
)
root_paper_embeddings = torch.Tensor(root_paper_embeddings[0]['embeddings'])
root_paper_embeddings.shape

torch.Size([768])

In [80]:
paper_trail_collection.load()
root_paper_embeddings = paper_trail_collection.query(
    expr = f'pk == "{root_id}"',
    output_fields=['embeddings']
)
root_paper_embeddings = torch.Tensor([root_paper_embeddings[0]['embeddings']])

paper_keys = list(papers.keys())
paper_keys.remove(root_id)

cols = ["id", "title", "score"]
similarities = pd.DataFrame(columns=cols)

for key in paper_keys:
    paper_embeddings = paper_trail_collection.query(
        expr = f'pk == "{key}"',
        output_fields=['embeddings']
    )
    paper_embeddings = torch.Tensor([paper_embeddings[0]['embeddings']])
    sim = torch.mm(root_paper_embeddings, paper_embeddings.transpose(0, 1))
    results = {
        "id": [key],
        "title": [papers[key].title],
        "score": [sim.detach().numpy()]
    }

    similarities = pd.concat([similarities, pd.DataFrame(results)], ignore_index=True)

In [81]:
similarities.sort_values(by="score", ascending=False).head(25)

Unnamed: 0,id,title,score
69,W2963341956,,[[0.99487865]]
10,W1840435438,A large annotated corpus for learning natural ...,[[0.9947636]]
254,W1948751323,Hypercolumns for object segmentation and fine-...,[[0.9937664]]
236,W2130576126,Memory-based one-step named-entity recognition,[[0.993732]]
402,W1579429723,Limitations of Co-Training for Natural Languag...,[[0.9935994]]
84,W1889268436,Semantic Compositionality through Recursive Ma...,[[0.9934784]]
153,W2146057216,Using Corpus Statistics on Entities to Improve...,[[0.99345046]]
291,W2126725946,Exploiting Similarities among Languages for Ma...,[[0.99340725]]
374,W2132679783,Knowledge-Based Weak Supervision for Informati...,[[0.99332273]]
110,W2964222246,End-to-end Neural Coreference Resolution,[[0.9933091]]


# Print Root Paper Abstract and 5 most similar papers

In [107]:
print(papers[root_id])

W2963403868: Attention is All you Need
The dominant sequence transduction models are based on complex recurrent orconvolutional neural networks in an encoder and decoder configuration. The best performing such models also connect the encoder and decoder through an attentionm echanisms. We propose a novel, simple network architecture based solely onan attention mechanism, dispensing with recurrence and convolutions entirely.Experiments on two machine translation tasks show these models to be superiorin quality while being more parallelizable and requiring significantly less timeto train. Our single model with 165 million parameters, achieves 27.5 BLEU onEnglish-to-German translation, improving over the existing best ensemble result by over 1 BLEU. On English-to-French translation, we outperform the previoussingle state-of-the-art with model by 0.7 BLEU, achieving a BLEU score of 41.1. 


In [83]:
ids = similarities.sort_values(by="score", ascending=False).head(5)["id"]
for id in ids.values:
    print(papers[id])
    print("\n\n")

W2963341956: 
We introduce a new language representation model called BERT, which stands for Bidirectional Encoder Representations from Transformers. Unlike recent language representation models (Peters et al., 2018a; Radford et al., 2018), BERT is designed to pre-train deep bidirectional representations from unlabeled text by jointly conditioning on both left and right context in all layers. As a result, the pre-trained BERT model can be fine-tuned with just one additional output layer to create state-of-the-art models for a wide range of tasks, such as question answering and language inference, without substantial task-specific architecture modifications. BERT is conceptually simple and empirically powerful. It obtains new state-of-the-art results on eleven natural language processing tasks, including pushing the GLUE score to 80.5 (7.7 point absolute improvement), MultiNLI accuracy to 86.7% (4.6% absolute improvement), SQuAD v1.1 question answering Test F1 to 93.2 (1.5 point absolut

# Testing Milvus Search Functionality (cosine similarity)

Slightly different results from above, but still very similar

In [92]:
root_paper_embeddings = root_paper_embeddings.tolist()[0]

AttributeError: 'list' object has no attribute 'tolist'

In [85]:
len(root_paper_embeddings)

768

In [93]:
search_params = {"metric_type": "IP", "params": {"nprobe": 10}, "offset": 5}

In [94]:
results = paper_trail_collection.search(
	data=[root_paper_embeddings], 
	anns_field="embeddings", 
	param=search_params,
	limit=10, 
	expr=None,
	consistency_level="Strong"
)

In [95]:
results[0].ids

['W3082274269', 'W2963748441', 'W1903029394', 'W2964110616', 'W2944815030', 'W1948751323', 'W2251939518', 'W2130576126', 'W1579429723', 'W2259472270']

In [96]:
results[0].distances

[0.9941544532775879, 0.9941114187240601, 0.9940299391746521, 0.9939961433410645, 0.9938169717788696, 0.9937663674354553, 0.9937633275985718, 0.9937319755554199, 0.9935994148254395, 0.9935078024864197]

In [97]:
for i in range(5):
    print(papers[results[0].ids[i]])
    print(f"Distance: {results[0].distances[i]}\n\n")

W3082274269: Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
Transfer learning, where a model is first pre-trained on a data-rich task before being fine-tuned on a downstream task, has emerged as a powerful technique in natural language processing (NLP). The effectiveness of transfer learning has given rise to a diversity of approaches, methodology, and practice. In this paper, we explore the landscape of transfer learning techniques for NLP by introducing a unified framework that converts all text-based language problems into a text-to-text format. Our systematic study compares pre-training objectives, architectures, unlabeled data sets, transfer approaches, and other factors on dozens of language understanding tasks. By combining the insights from our exploration with scale and our new ``Colossal Clean Crawled Corpus'', we achieve state-of-the-art results on many benchmarks covering summarization, question answering, text classification, and more. To

KeyError: 'W2964110616'

I accidentally nuked the previous papers, but the embeddings were kept in the Milvus instance. Since I didn't filter for only direct references to the paper it tried offering similar papers outside the references

# Only Include Direct References

In [100]:
direct_references = papers[root_id].references if papers[root_id].references else papers[root_id].related # this is gross
direct_references = [f'"{ref}"' for ref in direct_references]
direct_references = ', '.join(direct_references)

results = paper_trail_collection.search(
	data=[root_paper_embeddings], 
	anns_field="embeddings", 
	param=search_params,
	limit=10, 
	expr=f"pk in [{direct_references}]",
	consistency_level="Strong"
)

In [101]:
for i in range(5):
    print(papers[results[0].ids[i]])
    print(f"Distance: {results[0].distances[i]}\n\n")

W2117130368: A unified architecture for natural language processing
We describe a single convolutional neural network architecture that, given a sentence, outputs a host of language processing predictions: part-of-speech tags, chunks, named entity tags, semantic roles, semantically similar words and the likelihood that the sentence makes sense (grammatically and semantically) using a language model. The entire network is trained jointly on all these tasks using weight-sharing, an instance of multitask learning. All the tasks use labeled data except the language model which is learnt from unlabeled text and represents a novel form of semi-supervised learning for the shared tasks. We show how both multitask learning and semi-supervised learning improve the generalization of the shared tasks, resulting in state-of-the-art-performance. 
Distance: 0.9932626485824585


W2131744502: Distributed Representations of Sentences and Documents
Many machine learning algorithms require the input to be