Try to extract coordinates from tables

In [1]:
import json
from pathlib import Path
from ontology_learner.publication import Publication
from dotenv import load_dotenv
import os
from ontology_learner.gpt4_direct.gpt_term_mining import (
    mk_batch_script,
    run_batch_request,
    wait_for_batch_completion
    )
from ontology_learner.gpt4_direct.together_utils import chat
from ontology_learner.coordinate_extraction.extraction_utils import get_coord_prompt
import pandas as pd
from joblib import Parallel, delayed
from tqdm_joblib import tqdm_joblib
from tqdm import tqdm
import time

load_dotenv()
api_key = os.getenv("OPENAI")

datadir = Path(os.getenv('DATADIR'))
print(datadir)

jsondir = datadir / 'json'


def get_messages(prompt):
    return [
        {"role": "system", "content": "You are an expert in neuroimaging research."},
        {"role": "user", "content": prompt}
    ]



/Users/poldrack/Dropbox/data/ontology-learner/data


  from tqdm.autonotebook import tqdm


In [2]:
coord_df = pd.read_csv(datadir / 'coordinate_extraction/coords_df_gpt4.csv')

all_pmids = sorted(list(coord_df.pmcid.unique()))
print(len(all_pmids))



13421


In [3]:
def get_tabledata(pmids, jsondir, use_tqdm=False):
    tabledata = {}
    # use tqdm 
    if use_tqdm:
        iter_pmids = tqdm(pmids)
    else:
        iter_pmids = pmids
    
    for pmid in iter_pmids:
        p = Publication(pmid, datadir=jsondir)
        p.parse_sections()
        if 'TABLE' in p.sections and 'coordinate' in p.sections['TABLE'].lower():
            tabledata[str(pmid)] = p.sections['TABLE']

    return tabledata



In [4]:
if (datadir / 'coordinate_extraction/coord_results_llama3.json').exists():
    with open(datadir / 'coordinate_extraction/coord_results_llama3.json', 'r') as f:
        responses = json.load(f)
    print(f'loaded {len(responses)} responses from file')
else:
    responses = {}

loaded 4996 responses from file


In [5]:
def process_pmid(pmid, tbldata, max_tries=3):
    prompt = get_messages(get_coord_prompt(tbldata))
    ntries = 0
    while ntries < max_tries:
        try:
            result = chat(prompt, temperature=0.0)
            return (pmid, result)
        except Exception as e:
            ntries += 1
            print(f"Error processing {pmid}: {e}")
            print(f"Retrying {ntries} of {max_tries}...")
            # wait 500 ms
            time.sleep(0.5)
    return None


In [15]:
n_papers_to_process = 2500
# pmids need to be strings for json keys
pmids = [str(i) for i in all_pmids if str(i) not in responses][:n_papers_to_process]
tabledata = get_tabledata(pmids, jsondir)
print(f'processing {len(pmids)} papers')

# results = Parallel(n_jobs=10)(
#     delayed(process_pmid)(pmid, responses, tabledata) 
#     for pmid in pmids
# )


with tqdm_joblib(desc="llama3", total=len(pmids)) as progress_bar:
    results = Parallel(n_jobs=24)(delayed(process_pmid)(pmid, tabledata[pmid]) for pmid in pmids)

for result in results:
    if result is not None:
        pmid, response = result
        responses[pmid] = response

with open(datadir / 'coordinate_extraction/coord_results_llama3.json', 'w') as f:
    json.dump(responses, f, indent=4)


processing 953 papers


llama3:   0%|          | 0/953 [00:00<?, ?it/s]

Error processing 7929989: Expecting value: line 42 column 27 (char 973)
Retrying 1 of 3...
Error processing 4600370: Extra data: line 2 column 1 (char 3)
Retrying 1 of 3...
Error processing 6382927: Expecting value: line 17 column 10 (char 351)
Retrying 1 of 3...
Error processing 4600370: Extra data: line 2 column 1 (char 3)
Retrying 2 of 3...
Error processing 7929989: Expecting value: line 42 column 27 (char 973)
Retrying 2 of 3...
Error processing 3040725: Expecting value: line 94 column 14 (char 2354)
Retrying 1 of 3...
Error processing 4174863: Expecting value: line 74 column 18 (char 1965)
Retrying 1 of 3...
Error processing 4600370: Extra data: line 2 column 1 (char 3)
Retrying 3 of 3...
Error processing 6382927: Expecting value: line 17 column 10 (char 351)
Retrying 2 of 3...
Error processing 6382927: Expecting value: line 17 column 10 (char 351)
Retrying 3 of 3...
Error processing 7929989: Expecting value: line 42 column 27 (char 973)
Retrying 3 of 3...
Error processing 5315540

In [16]:
len(responses)

13412

In [8]:
len(responses)

7490