In [1]:
# %pip install transformers
import pandas as pd
from cogdl.oag import oagbert
import torch
import re
import numpy as np
import ipywidgets as widgets
import requests
import json
from dataclasses import dataclass
from typing import Dict, List, Optional
import os
import pymilvus
from pymilvus import (
    connections,
    utility,
    FieldSchema,
    CollectionSchema,
    DataType,
    Collection
)

In [2]:
# Parameters
max_depth = 2
ignore_related = True
ignore_referenced = False
base_works_url = "https://api.openalex.org/works"

In [3]:
@dataclass
class Article:
    # Keeping track of some needed paper details
    id: str
    title: str
    inverted_abstract: Dict[str, List[int]]
    authors: List[str]
    host_venue: str
    affiliations: List[str]
    concepts: List[str]
    references: List[str]
    related: List[str]

    def get_abstract(self) -> str:
        abstract = dict()
        for k, v in self.inverted_abstract.items():
            for i in v:
                abstract[i] = k

        final = ""
        for i in sorted(abstract.keys()):
            final += abstract[i] + " "
        return final
    
    def fetch_references_queries(self):
        # open alex only allows 50 OR joins per request
        queries = list()
        for i in range(0, len(self.references), 50):
            queries.append('|'.join(self.references[i:i+50]))
        return queries
    
    def fetch_related_queries(self):
        # open alex only allows 50 OR joins per request
        queries = list()
        for i in range(0, len(self.related), 50):
            queries.append('|'.join(self.related[i:i+50]))
        return queries
    
    def __str__(self):
        return f"{self.id}: {self.title}\n{self.get_abstract()}"

In [4]:
def fetch_article(result):
    work_id = result["id"].split('/')[-1]
    title = result["title"]
    inverted_abstract = result['abstract_inverted_index']
    authors = [authorship['author']['display_name'] for authorship in result['authorships']]
    host_venue = result['host_venue']['publisher']
    institutions = list()

    for authorship in result['authorships']:
        for institution in authorship['institutions']: 
            if institution['display_name'] not in institutions:
                institutions.append(institution['display_name'])

    concepts = [concept['display_name'] for concept in result['concepts'] if float(concept['score']) > 0.5]
    referenced_works = [work.split('/')[-1] for work in result['referenced_works']]
    related_works = [work.split('/')[-1] for work in result['related_works']]

    return Article(
        work_id,
        title if title else "",
        inverted_abstract if inverted_abstract else {"": [0]},
        authors,
        host_venue if host_venue else "",
        institutions,
        concepts,
        referenced_works,
        related_works
    )

In [5]:
connection = pymilvus.connections.connect(
    alias='default',
    host='localhost',
    port='19530'
)

In [8]:
fields = [
    FieldSchema(name='pk', dtype=DataType.VARCHAR, max_length=32, is_primary=True),
    FieldSchema(name='embeddings', dtype=DataType.FLOAT_VECTOR, dim=768)
]
collection_name = 'paper_trail_test'
schema = CollectionSchema(fields, "Testing")
paper_trail_collection = Collection(collection_name, schema)

In [7]:
utility.drop_collection(collection_name)

# Search for Article Title
Edit the title variable below to search for a paper. If not exact then returns 25 most relevant papers in the OpenAlex dataset. Select the paper in the dropdown menu.

In [31]:
title = "babies"
title = title.replace(" ", "%20")
req = requests.get(base_works_url+f"?filter=title.search:{title}")
response = json.loads(req.content)

relevant_titles = [result['title'] for result in response['results']]
title_selector = widgets.Dropdown(
    options=relevant_titles,
    value=relevant_titles[0],
    description="Title: "
)
display(title_selector)

Dropdown(description='Title: ', options=("Mama's Baby, Papa's Maybe: An American Grammar Book", 'The baby boom…

In [103]:
raise Exception("Please select correct title above. If done, run all cells below this one.")

Exception: Please select correct title above. If done, run all cells below this one.

In [32]:
index = relevant_titles.index(title_selector.value)
papers = dict()
root_id = response['results'][index]['id'].split('/')[-1]

papers[root_id] = fetch_article(response['results'][index])

In [33]:
use_references = ignore_referenced != True
use_related = ignore_related != True

related_works: Dict[int, List[Article]] = {}

def get_relevant_papers(current_depth: int, previous: List[Article]):
    related_works[current_depth] = []
    print(current_depth)
    for parent in previous:
        if use_references and len(parent.references) > 0:
            for query in parent.fetch_references_queries():         
                req = requests.get(base_works_url + f'?filter=openalex_id:{query}')
                res = json.loads(req.content)
                for result in res["results"]:
                    paper_id = result['id'].split('/')[-1]
                    if paper_id not in papers.keys():
                        temp = fetch_article(result)
                        papers[temp.id] = temp
                        related_works[current_depth].append(temp)
            
        if (use_related and len(parent.related) > 0) or len(parent.references) == 0:
            for query in parent.fetch_related_queries():  
                req = requests.get(base_works_url + f'?filter=openalex_id:{query}')
                res = json.loads(req.content)
                for result in res["results"]:
                    paper_id = result['id'].split('/')[-1]
                    if paper_id not in papers.keys():
                        temp = fetch_article(result)
                        papers[temp.id] = temp
                        related_works[current_depth].append(temp)

    if current_depth < max_depth:
        get_relevant_papers(current_depth+1, related_works[current_depth])

In [34]:
get_relevant_papers(1, [papers[root_id]])

1
2


In [35]:
tokenizer, model = oagbert("oagbert-v2")

In [36]:
if not os.path.exists("./embeddings/"):
    os.mkdir("./embeddings/")

files = os.listdir("./embeddings/")
files = [file.split('.')[0] for file in files]

for key in papers.keys():
    curr_paper = papers[key]
    input_ids, input_masks, token_type_ids, masked_lm_labels, position_ids, position_ids_second, masked_positions, num_spans = model.build_inputs(
        title=curr_paper.title, 
        abstract=curr_paper.get_abstract(), 
        venue=curr_paper.host_venue, 
        authors=curr_paper.authors, 
        concepts=curr_paper.concepts, 
        affiliations=curr_paper.affiliations
    )

    sequence_output, pooled_output = model.bert.forward(
        input_ids=torch.LongTensor(input_ids).unsqueeze(0),
        token_type_ids=torch.LongTensor(token_type_ids).unsqueeze(0),
        attention_mask=torch.LongTensor(input_masks).unsqueeze(0),
        output_all_encoded_layers=False,
        checkpoint_activations=False,
        position_ids=torch.LongTensor(position_ids).unsqueeze(0),
        position_ids_second=torch.LongTensor(position_ids_second).unsqueeze(0)
    )

    pooled_normalized = torch.nn.functional.normalize(pooled_output, p=2, dim=1)

    paper_trail_collection.insert([
            [key], 
            [pooled_normalized.tolist()[0]]
        ])
    
paper_trail_collection.flush()

In [37]:
print(paper_trail_collection.num_entities)

209


In [38]:
index_params = {
    "metric_type": "IP",
    "index_type": "IVF_FLAT",
    "params": {"nlist": 128}
}

paper_trail_collection.create_index(field_name="embeddings", index_params=index_params)

RPC error: [create_index], <MilvusException: (code=1, message=create index failed, collection is loaded, please release it first)>, <Time:{'RPC start': '2023-02-17 14:57:36.789319', 'RPC error': '2023-02-17 14:57:36.791198'}>


MilvusException: <MilvusException: (code=1, message=create index failed, collection is loaded, please release it first)>

In [39]:
paper_trail_collection.load()
root_paper_embeddings = paper_trail_collection.query(
    expr = f'pk == "{root_id}"',
    output_fields=['embeddings']
)
root_paper_embeddings = torch.Tensor(root_paper_embeddings[0]['embeddings'])
root_paper_embeddings.shape

torch.Size([768])

In [40]:
paper_trail_collection.load()
root_paper_embeddings = paper_trail_collection.query(
    expr = f'pk == "{root_id}"',
    output_fields=['embeddings']
)
root_paper_embeddings = torch.Tensor([root_paper_embeddings[0]['embeddings']])

paper_keys = list(papers.keys())
paper_keys.remove(root_id)

cols = ["id", "title", "score"]
similarities = pd.DataFrame(columns=cols)

for key in paper_keys:
    paper_embeddings = paper_trail_collection.query(
        expr = f'pk == "{key}"',
        output_fields=['embeddings']
    )
    paper_embeddings = torch.Tensor([paper_embeddings[0]['embeddings']])
    sim = torch.mm(root_paper_embeddings, paper_embeddings.transpose(0, 1))
    results = {
        "id": [key],
        "title": [papers[key].title],
        "score": [sim.detach().numpy()]
    }

    similarities = pd.concat([similarities, pd.DataFrame(results)], ignore_index=True)

In [41]:
similarities.sort_values(by="score", ascending=False).head(25)

Unnamed: 0,id,title,score
19,W1977791619,Short-Run Instability in Single-Family Housing...,[[0.9968647]]
13,W2031190214,The Effects of Anticipated Inflation on Housin...,[[0.9968407]]
47,W2129610866,The Effects of Property Taxes and Local Public...,[[0.9967787]]
45,W2081714882,A Pure Theory of Local Expenditures,[[0.99646693]]
51,W2002754280,The Income Tax and Charitable Contributions,[[0.99618965]]
50,W1999267230,The Demand for Housing: A Study in Specificati...,[[0.99566674]]
20,W3124557698,The Economics of Tenure Choice: 1955-79,[[0.9952711]]
6,W2113898792,The Estimation of Simultaneous Equation Models...,[[0.9950483]]
38,W2078902383,Short-Run Analysis of Fiscal Policy in a Simpl...,[[0.9950331]]
40,W2008857635,External and Internal Adjustment Costs and the...,[[0.9948812]]


# Print Root Paper Abstract and 5 most similar papers

In [42]:
print(papers[root_id])

W3122369839: The baby boom, the baby bust, and the housing market
This paper explores the impact of demographic changes on the housing market in the US, 1st by reviewing the facts about the Baby Boom, 2nd by linking age and housing demand using census data for 1970 and 1980, 3rd by computing the effect of demand on price of housing and on the quantity of residential capital, and last by constructing a theoretical model to plot the predictability of the jump in demand caused by the Baby Boom. The Baby Boom in the U.S. lasted from 1946-1964, with a peak in 1957 when 4.3 million babies were born. In 1980 19.7% of the population were aged 20-30, compared to 13.3% in 1960. Demand for housing was modeled for a given household from census data, resulting in the finding that demand rises sharply at age 20-30, then declines after age 40 by 1% per year. Thus between 1970 and 1980 the real value of housing for an adult at any given age jumped 50%, while the real disposable personal income per cap

In [43]:
ids = similarities.sort_values(by="score", ascending=False).head(5)["id"]
for id in ids.values:
    print(papers[id])
    print("\n\n")

W1977791619: Short-Run Instability in Single-Family Housing Starts
Abstract This article provides a statistical analysis of an inventory-theoretic approach to the study of the speculative sector of the single-family housing starts. Seen as a source of supply of new housing units the speculative housing starts are found to be heavily affected by the short-term interest rate and the monetary base. 



W2031190214: The Effects of Anticipated Inflation on Housing Market Equilibrium
An increase in the anticipated rate of inflation causes distortions in the housing market due to a nonindexed tax system. Since nominal rather than real interest payments are tax deductible, an increase in inflation decreases the aftertax cost of capital for homeowners, which in turn increases the demand for housing and increases its real price. This tax gain is shown to be larger for rental housing than for owner-occupied housing. In a competitive market, this implies that although the real price of housing inc

# Testing Milvus Search Functionality (cosine similarity)

Slightly different results from above, but still very similar

In [44]:
root_paper_embeddings = root_paper_embeddings.tolist()[0]

In [45]:
len(root_paper_embeddings)

768

In [46]:
search_params = {"metric_type": "IP", "params": {"nprobe": 10}, "offset": 5}

In [47]:
results = paper_trail_collection.search(
	data=[root_paper_embeddings], 
	anns_field="embeddings", 
	param=search_params,
	limit=10, 
	expr=None,
	consistency_level="Strong"
)

In [48]:
results[0].ids

['W2070001648', 'W1993829876', 'W2058686713', 'W2140347937', 'W1599416785', 'W1977791619', 'W3121418881', 'W2031190214', 'W2129610866', 'W2008636598']

In [49]:
results[0].distances

[0.9973241090774536, 0.997285008430481, 0.9972245693206787, 0.9971573352813721, 0.9971568584442139, 0.9968646168708801, 0.9968504309654236, 0.9968406558036804, 0.9967787265777588, 0.9967279434204102]

In [50]:
for i in range(5):
    print(papers[results[0].ids[i]])
    print(f"Distance: {results[0].distances[i]}\n\n")

KeyError: 'W2070001648'

I accidentally nuked the previous papers, but the embeddings were kept in the Milvus instance. Since I didn't filter for only direct references to the paper it tried offering similar papers outside the references

# Only Include Direct References

In [None]:
direct_references = papers[root_id].references if papers[root_id].references else papers[root_id].related # this is gross
direct_references = [f'"{ref}"' for ref in direct_references]
direct_references = ', '.join(direct_references)

results = paper_trail_collection.search(
	data=[root_paper_embeddings], 
	anns_field="embeddings", 
	param=search_params,
	limit=10, 
	expr=f"pk in [{direct_references}]",
	consistency_level="Strong"
)

In [None]:
for i in range(5):
    print(papers[results[0].ids[i]])
    print(f"Distance: {results[0].distances[i]}\n\n")

W2006758266: ACCELERATING INFLATION OR RISING UNEMPLOYMENT -IS THERE AN ALTERNATIVE?
Empirical observations suggest the existence of an unstable inflation-unemployment trade-off: whereas high levels of employment generate accelerating inflation rates, stability of the price level seems to require growing rates of unemployment. Under these conditions, macroeconomic policy has to produce countervailing economic fluctuations in order to limit the extent of the two-sided instability. This is discussed within the framework of a macroeconomic model, in which prices and wages are determined by potential competition, and in which employment depends on monetary demand. The model is confronted with prevailing natural rate theories in which a stable equilibrium rate of employment is determined by supply conditions. 
Distance: 0.08184956759214401


W1502198272: 
 
Distance: 0.07037132978439331


W1986173648: Adjustment Costs and Immiserizing Growth in LDCs
Using a general equilibrium model of a sm

In [51]:
root_paper_embeddings

[-0.025229308754205704,
 0.008874981664121151,
 0.031361814588308334,
 -0.026391707360744476,
 -0.028850480914115906,
 -0.029651520773768425,
 0.05034702271223068,
 -0.04259222000837326,
 0.009890438988804817,
 -0.005985609255731106,
 -0.0015636623138561845,
 0.004788116551935673,
 0.0015794093487784266,
 -0.009480288252234459,
 -0.03993260860443115,
 0.027945488691329956,
 -0.0013326369225978851,
 0.009243334643542767,
 -0.032626789063215256,
 0.004350637085735798,
 -0.03266753628849983,
 0.03139306604862213,
 -0.03789612650871277,
 -0.041920147836208344,
 -0.04251689836382866,
 0.016640841960906982,
 0.03893373906612396,
 0.025439921766519547,
 -0.053355056792497635,
 -0.023595143109560013,
 0.05321202799677849,
 0.011292262934148312,
 0.053757183253765106,
 0.029380492866039276,
 0.030563652515411377,
 -0.0003833558293990791,
 0.02687254548072815,
 -0.016594110056757927,
 -0.0012419149279594421,
 0.012365070171654224,
 0.06617121398448944,
 -0.0415433794260025,
 -0.06538936495780945