Refine the tasks obtained using literature mining with GPT-4

In [1]:
# autoreload
%load_ext autoreload
%autoreload 2
import json
import fasttext
import os
from ontology_learner.json_utils import parse_jsonl_file, load_jsonl, parse_jsonl_task_line
from refinement_utils import load_original_results
from pathlib import Path
import contextlib
import io
import re
import nltk
from nltk.corpus import stopwords
import umap
import numpy as np
from openai import OpenAI
from llm_query.chat_client import ChatClientFactory

nltk.download('stopwords')
english_stopwords = set(stopwords.words('english'))


datadir = Path('/Users/poldrack/Dropbox/data/ontology-learner/data')



[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/poldrack/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [2]:

results_dir = datadir / 'concept_results'

concept_dict = load_original_results(results_dir)

Loaded 7383 concepts from /Users/poldrack/Dropbox/data/ontology-learner/data/concept_results/batch_673d2908c74c81908932e5e6769a3c55.jsonl


In [3]:
concept_dict

{'food_cravings': {'type': 'construct',
  'description': 'Food cravings are intense desires or urges to consume a specific type of food. This construct encompasses both the emotional and physiological aspects that drive an individual to seek certain foods, often irrespective of actual hunger. Food cravings are distinguished from general hunger in their specificity for certain tastes or textures and are influenced by a range of cognitive and emotional factors.',
  'references': ['Hill, A. J. (2007). The psychology of food craving. Proceedings of the Nutrition Society, 66(2), 277-285.',
   'Weingarten, H. P., & Elston, D. (1991). The phenomenology of food cravings. Appetite, 17(1), 37-45.',
   'Hormes, J. M., & Rozin, P. (2010). Perimenstrual Chocolate Craving: Evidence from Symptom Concealment. Appetite, 31(1), 14-26.'],
  'tasks': ['Food Cravings Questionnaire-Trait (FCQ-T)',
   'Food Cravings Questionnaire-State (FCQ-S)',
   'Food Frequency Questionnaire (FFQ)',
   'Visual Analogue Sc

In [15]:
# seriallize each task structure to a string
# after removing the following keys from the dictionary: 'type', 'system_fingerprint', 'model'
concept_strings = {}
english_stopwords.update(['description', 'references',
                     'conditions', 'disorders', 'custom_id'])
for t in concepts:
    t_dict = t.copy()
    for k in ['type', 'system_fingerprint', 'model']:
        t_dict.pop(k)
    dictstr = json.dumps(t_dict)
    # remove punctuation
    dictstr = re.sub(r'[^\w\s]', '', dictstr)
    dictstr = dictstr.lower().replace('_', ' ')
    # remove one letter words
    dictstr = re.sub(r'\b\w\b', '', dictstr)
    # remove <> tags
    dictstr = re.sub(r'<[^>]*>', '', dictstr)
    # remove stopwords
    dictstr = ' '.join([word for word in dictstr.split() if word not in english_stopwords])
    concept_strings[t['custom_id']] = dictstr

# %%

concept_string_file = datadir / 'concept_results' / 'concept_strings.txt'
with open(concept_string_file, 'w') as f:
    for s in concept_strings.values():
        f.write(s + '\n')

# %%

Learn embedding using fasttext

In [16]:
model = fasttext.train_unsupervised(concept_string_file.as_posix(), dim=100)


Read 0M words
Number of words:  10429
Number of labels: 0
Progress: 100.0% words/sec/thread:  130439 lr:  0.000000 avg.loss:  1.968195 ETA:   0h 0m 0s


In [17]:
embeddings = {}
for k, v in concept_strings.items():
    embeddings[k] = model.get_word_vector(v)

# %%


### Cluster tasks to find similar ones

First perform agglomerative clustering to identify sets.

In [18]:
from sklearn.cluster import AgglomerativeClustering
cluster = AgglomerativeClustering(n_clusters=None, distance_threshold=0.5)
cluster.fit(list(embeddings.values()))
print(f'Found {len(set(cluster.labels_))} clusters')
# %%

Found 1367 clusters


In [19]:
# create a dictionary mapping each cluster to a list of task ids
cluster_dict = {}
cluster_embeddings = {}
task_keys = list(embeddings.keys())

for i, l in enumerate(cluster.labels_):
    if l not in cluster_dict:
        cluster_dict[l] = []
    cluster_dict[l].append(task_keys[i])
    # append the embedding to the cluster embedding
    if l not in cluster_embeddings:
        cluster_embeddings[l] = []
    cluster_embeddings[l].append(embeddings[task_keys[i]])
# %%

In [20]:
cluster_dict

{810: ['food_cravings',
  'sensory-specific_satiety',
  'state-based_hunger',
  'subjective_satiety',
  'satiation',
  'palatability',
  'hedonic_hunger'],
 403: ['action-oriented_body_representation',
  'body_perception',
  'body_representation',
  'body_representations',
  'body_schema',
  'nonaction-oriented_body_representation'],
 618: ['conditioned_fear',
  'trace_fear_conditioning',
  'fear_conditioning',
  'conditioned_fear_memory_formation'],
 913: ['noise_suppression',
  'binaural_separation',
  'binaural_masking_level_difference',
  'speaker_segregation',
  'equalization-cancelation',
  'informational_masking'],
 125: ['reactive_balance_control',
  'gait_control_mechanisms',
  'gait_adaptation',
  'reactive_stability',
  'control_of_walking'],
 197: ['utilitarian_inclinations',
  'moral_decision-making',
  'moral_judgment',
  'moral_conflict',
  'moral_judgement',
  'deontological_considerations',
  'utilitarian_considerations',
  'deontological_inclinations',
  'deontology',

For each cluster, create a prompt that will ask the LLM to identify unique tasks within the set.

In [26]:
system_msg = """
You are an expert in psychology and neuroscience.
You should be as specific and as comprehensive as possible in your responses.
Your response should be a JSON object with no additional text.  
"""


def get_construct_set_prompt(constructs):
    prompt = f"""
# CONTEXT #
Researchers in the field of cognitive neuroscience and psychology study specific 
psychological constructs (or concepts), which are the building blocks of the mind, such as 
memory, attention, theory of mind, and so on.  

# OBJECTIVE #
You will be provided with a list of construct names that were identified by a language model based on
publications.  Each set of names may contain duplicate labels for the same construct.  For example, 
for the construct of "psychedelic experience", there might be the following labels:

'psychedelic response',
'psychoactive experiences',
'substance-induced consciousness alterations',
'cognitive (psychedelic experience)',
'psychedelic-induced experiences',
'psychedelic experience'

Your job is to group together all of the labels within the set that most likely refer to the same construct. 
Each set may include multiple labels that refer to the same construct, and multiple constructs 
may be mentioned in the same set.

# DATA #

Here is the list of constructs to be clustered with a brief description of each:

{constructs}

# RESPONSE #
Please return the results in JSON format. 

The result should include a dictionary of dicts.   Each subdict should refer to a single construct that was identified 
from the set, and its key should be a consensus name derived from the set.  For the example above, the consensus name
might be "psychedelic experience".

Each subdict should include two keys:
- "labels": containing a list of labels that match this task
- "description": containing a brief description of the construct that was identified, based on the descriptions provided for
each of the different labels.

Respond only with JSON, without any additional text or description.
"""
    return prompt




In [27]:
cluster_names = cluster_dict[13]

def create_construct_prompt(cluster_names, construct_dict):
    construct_strs = [f"{t.replace('_', ' ').title()}: {construct_dict[t]['description']}" for t in cluster_names]
    return get_construct_set_prompt("\n".join(construct_strs))

print(create_construct_prompt(cluster_names, concept_dict))



# CONTEXT #
Researchers in the field of cognitive neuroscience and psychology study specific 
psychological constructs (or concepts), which are the building blocks of the mind, such as 
memory, attention, theory of mind, and so on.  

# OBJECTIVE #
You will be provided with a list of construct names that were identified by a language model based on
publications.  Each set of names may contain duplicate labels for the same construct.  For example, 
for the construct of "psychedelic experience", there might be the following labels:

'psychedelic response',
'psychoactive experiences',
'substance-induced consciousness alterations',
'cognitive (psychedelic experience)',
'psychedelic-induced experiences',
'psychedelic experience'

Your job is to group together all of the labels within the set that most likely refer to the same construct. 
Each set may include multiple labels that refer to the same construct, and multiple constructs 
may be mentioned in the same set.

# DATA #

Here is the l

### run using openai batch api

In [28]:
api_key = os.environ.get("OPENAI")


client = ChatClientFactory.create_client("openai", api_key, 
                                            system_msg=system_msg,
                                            model="gpt-4o")



In [30]:
construct_batch_file = datadir / 'construct_annotation_batch.jsonl'
if construct_batch_file.exists():
    construct_batch_file.unlink()

ids = []    
for k, cluster_names in cluster_dict.items():
    prompt = create_construct_prompt(cluster_names, concept_dict)
    id = f'cluster-{k}'
    if id in ids:
        print(f"skipping duplicate id: {id}")
        continue
    ids.append(id)
    batch_request = client.create_batch_request(id, prompt)
    with open(construct_batch_file, 'a') as f:
        f.write(json.dumps(batch_request) + '\n')



In [31]:
batch_client = OpenAI(api_key=api_key)


batch_input_file = batch_client.files.create(file=open(construct_batch_file, "rb"),
                                        purpose="batch")

batch_input_file_id = batch_input_file.id

batch_metadata = batch_client.batches.create(
    input_file_id=batch_input_file_id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
    metadata={
        "description": "construct clustering"
    }
)


In [33]:
print(batch_client.batches.retrieve(batch_metadata.id).status)
import time
while batch_client.batches.retrieve(batch_metadata.id).status != 'completed':
    time.sleep(60)
    print(batch_client.batches.retrieve(batch_metadata.id).status)
# os.system('say "your program has finished"')


in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
in_progress
finalizing
finalizing
finalizing
finalizing
finalizing
finalizing
finalizing
finalizing
finalizing
finalizing
finalizing
finalizing
finalizing
finalizing
finalizing
completed


In [34]:
# NOTE: This step is run once the batch job has completed.
import sys
sys.path.append('..')
from gpt4_batch_utils import get_batch_results, save_batch_results
batch_results = get_batch_results(batch_client, batch_metadata.id)
outdir = datadir / 'construct_refinement_results'
outdir.mkdir(exist_ok=True, parents=True)

outfile = save_batch_results(batch_results, batch_metadata.id, outdir)

In [35]:



results_raw = load_jsonl(outfile)

results_content = {}
for result in results_raw:
    task = result['custom_id'].split('-')[1]
    content = result['response']['body']['choices'][0]['message']['content']
    content = content.replace('json', '').replace("```", "").replace("\n", "")
    try:
        results_content[task] = json.loads(content)
    except json.JSONDecodeError:
        print(f"error decoding {task}")


In [36]:
results_content

{'810': {'food cravings': {'labels': ['food cravings'],
   'description': 'Food cravings are intense desires or urges to consume a specific type of food, driven by emotional and physiological factors, and significant for their specificity to certain tastes or textures.'},
  'sensory-specific satiety': {'labels': ['sensory-specific satiety'],
   'description': 'Sensory-specific satiety refers to the decrease in the pleasantness of a specific flavor or food after consumption to satiation, influencing food choice and promoting dietary variety.'},
  'state-based hunger': {'labels': ['state-based hunger'],
   'description': "State-based hunger refers to the immediate sensation or motivation to eat, often triggered by physiological signals and reflecting the body's current nutritional needs."},
  'subjective satiety': {'labels': ['subjective satiety'],
   'description': "Subjective satiety refers to an individual's personal perception of fullness and satisfaction following food consumption, 