# Try to extract brain coordinates from publication tables using an LLM

In [1]:
import json
from pathlib import Path
from ontology_learner.publication import Publication
from tqdm import tqdm
from dotenv import load_dotenv
import os
from ontology_learner.gpt4_direct.gpt_term_mining import (
    mk_batch_script,
    run_batch_request,
    wait_for_batch_completion
    )
from ontology_learner.gpt4_direct.together_utils import chat
from ontology_learner.extraction_utils import get_coord_prompt
import pandas as pd

load_dotenv()
# NOTE(review): not referenced in the visible cells; presumably used further
# down the notebook -- confirm before removing.
api_key = os.getenv("OPENAI")

# Fail fast with a readable message: Path(None) would otherwise raise an
# opaque TypeError when DATADIR is missing from the environment / .env file.
datadir_env = os.getenv('DATADIR')
if not datadir_env:
    raise RuntimeError("DATADIR is not set; add it to the environment or .env file")
datadir = Path(datadir_env)
print(datadir)

# per-publication JSON files read by Publication(...)
jsondir = datadir / 'json'


def get_messages(prompt):
    """Wrap a user prompt in a two-message chat-completion conversation.

    Args:
        prompt: text sent as the user turn.

    Returns:
        A new list with a fixed neuroimaging-expert system message followed by
        the user message containing ``prompt``.
    """
    system_message = {"role": "system", "content": "You are an expert in neuroimaging research."}
    user_message = {"role": "user", "content": prompt}
    return [system_message, user_message]



/Users/poldrack/Dropbox/data/ontology-learner/data


In [2]:
coords_csv = datadir / 'coordinate_extraction/coords_df.csv'
coord_df = pd.read_csv(coords_csv)

# Unique PMIDs, sorted so that slicing the list later gives a stable order.
pmids = sorted(coord_df['pmid'].unique())
print(len(pmids))



13421


In [3]:
def get_tabledata(pmids):
    """Collect the TABLE section text of publications that mention coordinates.

    Each pmid is loaded with ``Publication(pmid, datadir=jsondir)`` and its
    sections are parsed. A paper is kept only if it has a 'TABLE' section whose
    text contains "coordinate" (case-insensitive).

    Args:
        pmids: iterable of PubMed IDs (int or str).

    Returns:
        dict mapping ``str(pmid)`` to the raw TABLE section text. String keys
        keep the result compatible with JSON-keyed caches.
    """
    tables_by_pmid = {}

    for pmid in tqdm(pmids):
        publication = Publication(pmid, datadir=jsondir)
        publication.parse_sections()
        if 'TABLE' not in publication.sections:
            continue
        table_text = publication.sections['TABLE']
        if 'coordinate' in table_text.lower():
            tables_by_pmid[str(pmid)] = table_text

    return tables_by_pmid



In [9]:
responses_path = datadir / 'coordinate_extraction/coord_results_llama3.json'

# Resume from the cached results if a previous run saved them.
if not responses_path.exists():
    responses = {}
else:
    with open(responses_path, 'r') as f:
        responses = json.load(f)
    print(f'loaded {len(responses)} responses from file')

loaded 500 responses from file


In [11]:
import time
from joblib import Parallel, delayed

n_papers_to_process = 20
max_tries = 3
n_jobs = 10  # number of concurrent chat() requests
retry_delay_seconds = 2  # base delay between retries; grows linearly per attempt
results_path = datadir / 'coordinate_extraction/coord_results_llama3.json'

# pmids need to be strings for json keys. Compare str(i) against `responses`:
# the pmids from the CSV are numeric, so `i not in responses` never matches the
# string keys, and every paper would look unprocessed. Bind a new name instead
# of overwriting `pmids`, so re-running this cell selects the next batch from
# the full list rather than from the previous (already processed) batch.
pending_pmids = [str(i) for i in pmids if str(i) not in responses][:n_papers_to_process]
tabledata = get_tabledata(pending_pmids)
# NOTE(review): get_tabledata() only keeps papers whose TABLE section mentions
# "coordinate". The others never enter `responses`, so they are re-selected on
# every run and can use up the whole batch -- consider tracking them separately.


def process_pmid(pmid, responses, tabledata, max_tries):
    """Ask the LLM to extract coordinates from one paper's table text.

    Args:
        pmid: PubMed ID as a string (the key type used by `responses`/`tabledata`).
        responses: already-collected results; pmids present here are skipped.
        tabledata: mapping pmid -> TABLE section text.
        max_tries: maximum number of chat() attempts.

    Returns:
        ``(pmid, result)`` on success. Returns ``None`` if the pmid already has
        a response, has no entry in ``tabledata``, or every attempt raised.
    """
    # Papers without a coordinate table are absent from tabledata. Indexing
    # them would raise KeyError outside the retry loop and abort the whole run.
    if pmid in responses or pmid not in tabledata:
        return None
    prompt = get_messages(get_coord_prompt(tabledata[pmid]))
    for attempt in range(1, max_tries + 1):
        try:
            return (pmid, chat(prompt, temperature=0.0))
        except Exception as e:
            print(f"Error processing {pmid}: {e}")
            if attempt < max_tries:
                print(f"Retrying {attempt} of {max_tries}...")
                # Back off so transient / rate-limit errors have time to clear.
                time.sleep(retry_delay_seconds * attempt)
    print(f"Giving up on {pmid} after {max_tries} attempts")
    return None


# Only dispatch papers that actually have coordinate-table text.
results = Parallel(n_jobs=n_jobs)(
    delayed(process_pmid)(pmid, responses, tabledata, max_tries)
    for pmid in tqdm(tabledata)
)

n_before = len(responses)
for result in results:
    if result is not None:
        pmid, response = result
        responses[pmid] = response
print(f'added {len(responses) - n_before} responses ({len(responses)} total)')

# Write to a temporary file, then swap it in. An interrupted dump therefore
# cannot corrupt the cache that holds every previous response.
tmp_path = results_path.with_name(results_path.name + '.tmp')
with open(tmp_path, 'w') as f:
    json.dump(responses, f, indent=4)
tmp_path.replace(results_path)


100%|██████████| 20/20 [00:00<00:00, 1260.95it/s]
100%|██████████| 20/20 [00:18<00:00,  1.09it/s]


In [12]:
# Number of cached responses after the batch-extraction cell above.
len(responses)

520

Run using Llama 3.3 via the Together API.

In [10]:
# Normalize pmid keys to str. json.dump would stringify int keys on disk
# anyway, but the in-memory dict would keep them as ints. Membership checks
# that use str(pmid) would then miss those entries, so rebind `responses` too.
responses_fixed = {str(k): v for k, v in responses.items()}
# Guard against silent data loss if e.g. both 123 and '123' were present.
assert len(responses_fixed) == len(responses), 'duplicate pmids after converting keys to str'
responses = responses_fixed
with open(datadir / 'coordinate_extraction/coord_results_llama3.json', 'w') as f:
    json.dump(responses_fixed, f, indent=4)


In [6]:
# Number of responses currently held in memory.
len(responses)

500

In [18]:
# Preview a single entry rather than the full dict. Dumping every response
# floods the notebook output and bloats the saved .ipynb.
{pmid: responses[pmid] for pmid in list(responses)[:1]}

{'7563784': [{'TEA vs. TGA': {'coordinates': [{'x': -30,
      'y': -30,
      'z': -25,
      'cluster_size': None,
      'label': 'Parahippocampal gyrus (l)',
      'statistic': 'T Value',
      'coordinate_type': 'MNI'},
     {'x': -20,
      'y': -20,
      'z': -20,
      'cluster_size': None,
      'label': 'Parahippocampal gyrus (l)',
      'statistic': 'T Value',
      'coordinate_type': 'MNI'},
     {'x': -30,
      'y': -5,
      'z': -40,
      'cluster_size': None,
      'label': 'Uncus (l)',
      'statistic': 'T Value',
      'coordinate_type': 'MNI'},
     {'x': -25,
      'y': -5,
      'z': -25,
      'cluster_size': None,
      'label': 'Uncus (l)',
      'statistic': 'T Value',
      'coordinate_type': 'MNI'},
     {'x': -25,
      'y': -20,
      'z': -30,
      'cluster_size': None,
      'label': 'Parahippocampal gyrus (l)',
      'statistic': 'T Value',
      'coordinate_type': 'MNI'},
     {'x': -25,
      'y': -20,
      'z': -25,
      'cluster_size': None,
  