## Create specter data
Code here was adapted from specter's github page

Here, we are taking our datasets and converting them to specter embeddings for training/testing.

In [7]:
import pandas as pd
import json#
import os

In [None]:
# First set your data directory. Recommend keeping data somewhere other than where you keep code!
data_dir = os.path.abspath(r'C:\Users\aday\OneDrive - SAGE Publishing\PROJECT_DATA\pubmed_case_report_classifier\data')

In [8]:
datasets = ['train','dev','test']


for dataset in datasets:
    dataset_json = {}
    dataset_pids = []
    data = pd.read_csv(os.path.join(data_dir,'{}.csv'.format(dataset)), dtype=str)
    
    dataset_pids = list(data.doi.values)
    for i,row in data.iterrows():
        title = row['articletitle']
        abstract = row['abstract']
        doi = row['doi']
        dataset_json[doi] = {'title': title,
                            'abstract':abstract,
                            'paper_id':doi}
        doistr = '\r\n'.join(dataset_pids)
        
    output_id_path = os.path.join(data_dir,'specter_ids_{}.json'.format(dataset))
    with open(output_id_path,'w') as f:
        f.write(doistr)
    output_data_path = os.path.join(data_dir,'specter_data_{}.json'.format(dataset))
    with open(output_data_path,'w') as f:
        json.dump(dataset_json,f)
    print(f'Written {dataset}:', len(dataset_pids), len(dataset_json))

Written train: 16974 16974
Written dev: 2140 2140
Written test: 2166 2166


# Embed with API

In [9]:
# write out
# code adapted from: https://github.com/allenai/paper-embedding-public-apis
from typing import Dict, List
import json
import requests

URL = "https://model-apis.semanticscholar.org/specter/v1/invoke"
MAX_BATCH_SIZE = 16

def chunks(lst, chunk_size=MAX_BATCH_SIZE):
    for i in range(0, len(lst), chunk_size):
        yield lst[i : i + chunk_size]

def embed(papers):
    embeddings_by_paper_id: Dict[str, List[float]] = {}
    for chunk in chunks(papers):
        # Allow Python requests to convert the data above to JSON
        response = requests.post(URL, json=chunk)
        if response.status_code != 200:
            raise RuntimeError("Sorry, something went wrong, please try later!")
        for paper in response.json()["preds"]:
            embeddings_by_paper_id[paper["paper_id"]] = paper["embedding"]
    return embeddings_by_paper_id

In [10]:
%%time
datasets = ['train','dev','test']

for dataset in datasets:
    dataset_json = []
    data = pd.read_csv(os.path.join(data_dir,'{}.csv'.format(dataset)), dtype=str)
    for i,row in data.iterrows():
        title = row['articletitle']
        abstract = row['abstract']
        doi = row['doi']
        # ensure that everything is right type
        if type(abstract)==type(title)==type(doi)==str:
            dataset_json.append({'title': title,
                                'abstract':abstract,
                                'paper_id':doi})
#             doistr = '\r\n'.join(dataset_pids)
    print('EMBEDDING:', len(dataset_json))
    all_embeddings = embed(dataset_json)
    dataset_path = os.path.join(data_dir,f'{dataset} embeddings.json')
    with open(dataset_path, 'w') as f:
        json.dump(all_embeddings,f)

EMBEDDING: 16974
EMBEDDING: 2140
EMBEDDING: 2166
Wall time: 1h 3min 50s


# check embeddings
Embeddings should be 768-length vectors.

In [11]:
test_data_path = os.path.join(data_dir,'test embeddings.json')
with open(test_data_path,'r') as f:
    test_data = json.load(f)
len(test_data)

2166

In [12]:
data = pd.read_csv(os.path.join(data_dir,'test.csv'))
for i,row in data.iterrows():
    doi = row['doi']
    embedding = test_data[doi]
    assert len(embedding)==768
    print(len(embedding),embedding)
    break

768 [-1.0540858507156372, -3.4259109497070312, 3.916700601577759, 6.795539855957031, 2.9405980110168457, -0.9851967096328735, 2.230900287628174, 0.9079151749610901, 2.8595125675201416, -1.8966714143753052, 6.147038459777832, -4.22273063659668, 2.4962775707244873, -0.5903092622756958, -0.44812774658203125, -6.857989311218262, 1.9606411457061768, -2.4655890464782715, 2.719186782836914, 0.38086390495300293, -0.3229959011077881, 0.38820257782936096, -3.6748452186584473, -5.199097156524658, 3.9746737480163574, 1.2412915229797363, 1.1937206983566284, -2.438277244567871, 0.5980643033981323, 1.471097469329834, 0.45259472727775574, -2.23211669921875, -1.8075674772262573, 1.3865406513214111, 3.0615172386169434, -3.0413858890533447, 0.20692133903503418, 2.0704336166381836, -4.731867790222168, 0.7591934204101562, -1.2635823488235474, -5.734018325805664, 7.364973068237305, -0.945850670337677, 0.93680739402771, -0.845180869102478, 2.342761993408203, -2.7090094089508057, -2.059267997741699, 1.1609206