In [13]:
from pathlib import Path
import json

# Resolve every path relative to the directory the kernel was started in.
root = Path.cwd()

# M3AE model (the file name suggests the English X-Mode VQA run).
m3ae_en = root / "experiments/xmode/en/xmode-vqa-m3ae-english.json"

# Pin the encoding: without it open() falls back to the platform's locale
# encoding, which can mis-decode non-ASCII characters in the JSON on
# machines where that default is not UTF-8.
with open(m3ae_en, "r", encoding="utf-8") as f:
    data = json.load(f)


In [14]:
# Inspect the list of messages stored under 'xmode' for the first record.
data[0]['xmode']

[{'content': [{'question': "given the last study of patient 10284038 in 2105, is the cardiac silhouette's width larger than half of the total thorax width?",
    'database_schema': '\nCREATE TABLE "TB_CXR" (\n\trow_id INTEGER, \n\tsubject_id INTEGER, \n\thadm_id REAL, \n\tstudy_id INTEGER, \n\timage_id TEXT, \n\tviewposition TEXT, \n\tstudydatetime TEXT\n)'}],
  'type': 'human',
  'id': 'fe46614c-50ce-4eab-b18f-c3f9434d31c2'},
 {'content': "{'status': 'success', 'data': [{'study_id': 54252229, 'image_id': '3684deb8-616530e3-55256cd6-044fe3a2-4ac97ba7'}]}",
  'additional_kwargs': {'idx': 1,
   'args': {'problem': 'What is the study_id and image_id for the last study of patient 10284038 in 2105?',
    'context': 'CREATE TABLE "TB_CXR" (\n\trow_id INTEGER, \n\tsubject_id INTEGER, \n\thadm_id REAL, \n\tstudy_id INTEGER, \n\timage_id TEXT, \n\tviewposition TEXT, \n\tstudydatetime TEXT\n)'}},
  'type': 'function',
  'name': 'text2SQL',
  'id': '54a377dd-b108-41a4-b1b2-4f1f622f3ab5'},
 {'cont

In [15]:
def _format_kwarg(key, value):
    """Render a single ``key=value`` pair of a plan step.

    Non-string values are converted with ``str()`` first, so numbers, ``None``
    and empty containers no longer crash the formatter.
    """
    text = value if isinstance(value, str) else str(value)
    # startswith() is safe for the empty string, unlike indexing text[0].
    if text.startswith('$'):
        # "$..." values are emitted unquoted (presumably a reference to the
        # output of an earlier step -- confirm against the plan consumer).
        return f"{key}={text}"
    # Use single quotes only when the value itself contains a double quote.
    # NOTE(review): a value containing both quote kinds is emitted unescaped,
    # exactly as before.
    quote = "'" if '"' in text else '"'
    return f"{key}={quote}{text}{quote}"


def format_func_str(func_name, func_args):
    """Render one function call as a numbered plan step.

    Args:
        func_name: Name of the called function.
        func_args: Mapping with the keys ``'idx'`` (step number) and
            ``'args'`` (a list of positional arguments, or a mapping of
            keyword arguments).

    Returns:
        str: ``"{idx}. {func_name}({args})"`` with the arguments joined by
        ``", "``.

    Raises:
        KeyError: if ``func_args`` lacks ``'idx'`` or ``'args'``.
    """
    idx = func_args['idx']
    args = func_args['args']
    if isinstance(args, list):
        # Positional arguments are emitted verbatim; str() keeps join() from
        # failing on non-string items.
        rendered = ', '.join(str(item) for item in args)
    else:
        rendered = ', '.join(_format_kwarg(k, v) for k, v in args.items())
    return f"{idx}. {func_name}({rendered})"

def format_human_message_str(msg):
    """Serialise a human message dict into a ``[HumanMessage(...)]`` string.

    Reads ``msg['content']`` and ``msg['id']`` and embeds their string forms
    in the returned text.
    """
    rendered_content = str(msg['content'])
    # Local name avoids shadowing the builtin ``id``.
    message_id = msg['id']
    return '[HumanMessage(content={}, id={})]'.format(rendered_content, message_id)

def extract_plan(row):
    """Build the user/assistant plan pair from one record's message list.

    Walks ``row['xmode']`` in order. The formatted text of the most recent
    'human' message becomes the user part; every 'function' message becomes
    one numbered step of the assistant part. Scanning stops after the first
    message whose ``name`` is 'join' (that message is still included when it
    is a 'function' message).

    Returns:
        dict: ``{"user": <str>, "assistant": <str>}`` where the assistant
        text is the steps joined by newlines and terminated by
        ``<END_OF_PLAN>``.

    Raises:
        AssertionError: if no 'function' message was collected.
    """
    user_text = ''
    steps = []
    for message in row['xmode']:
        kind = message['type']
        if kind == 'human':
            user_text = format_human_message_str(message)
        if kind == 'function':
            steps.append(format_func_str(message['name'], message['additional_kwargs']))
        # The 'join' step closes the plan; anything after it is ignored.
        if message.get('name') == 'join':
            break
    assert len(steps) > 0, "No AI plan found"
    assistant_text = '\n'.join(steps + ['<END_OF_PLAN>'])
    return {"user": user_text, "assistant": assistant_text}

def parse_data(data):
    """Return a new list of records, each extended with a ``'plan'`` entry.

    The previous version made a shallow ``data.copy()`` and then assigned
    ``row['plan']`` on the shared row dicts, so the caller's input was
    mutated despite the copy. Each row is now shallow-copied into a fresh
    dict, leaving ``data`` and its rows untouched.

    Args:
        data: Iterable of record mappings accepted by ``extract_plan``.

    Returns:
        list[dict]: One new dict per input row with all original keys plus
        ``'plan'`` (the value returned by ``extract_plan`` for that row).
    """
    return [{**row, 'plan': extract_plan(row)} for row in data]
# Attach a plan to every record, then preview just the extracted plans.
res = parse_data(data)
[record['plan'] for record in res]

[{'user': '[HumanMessage(content=[{\'question\': "given the last study of patient 10284038 in 2105, is the cardiac silhouette\'s width larger than half of the total thorax width?", \'database_schema\': \'\\nCREATE TABLE "TB_CXR" (\\n\\trow_id INTEGER, \\n\\tsubject_id INTEGER, \\n\\thadm_id REAL, \\n\\tstudy_id INTEGER, \\n\\timage_id TEXT, \\n\\tviewposition TEXT, \\n\\tstudydatetime TEXT\\n)\'}], id=fe46614c-50ce-4eab-b18f-c3f9434d31c2)]',
  'assistant': '1. text2SQL(problem="What is the study_id and image_id for the last study of patient 10284038 in 2105?", context=\'CREATE TABLE "TB_CXR" (\n\trow_id INTEGER, \n\tsubject_id INTEGER, \n\thadm_id REAL, \n\tstudy_id INTEGER, \n\timage_id TEXT, \n\tviewposition TEXT, \n\tstudydatetime TEXT\n)\')\n2. image_analysis(question="Is the cardiac silhouette\'s width larger than half of the total thorax width?", context=$1)\n3. join()\n<END_OF_PLAN>'},
 {'user': '[HumanMessage(content=[{\'question\': \'given the last study of patient 19243401 th

In [16]:
# Find the longest assistant plan in res and remember its index.
# `idx` starts as None so an empty `res` (or only empty plans) no longer
# raises NameError at the print; enumerate() supplies the position directly
# instead of re-scanning the list with res.index(r) on every hit.
max_len = 0
idx = None
for i, r in enumerate(res):
    plan_len = len(r['plan']['assistant'])
    # Strict '>' keeps the first plan among equally long ones.
    if plan_len > max_len:
        max_len = plan_len
        idx = i
print(max_len, idx)

2514 25


In [17]:
# Find the shortest assistant plan in res and remember its index.
# The search is seeded with None rather than with a value computed in a
# different cell, so this cell runs correctly on its own; enumerate()
# supplies the position instead of a res.index(r) re-scan.
min_len = None
min_idx = None
for i, r in enumerate(res):
    plan_len = len(r['plan']['assistant'])
    # Strict '<' keeps the first plan among equally short ones.
    if min_len is None or plan_len < min_len:
        min_len = plan_len
        min_idx = i
print(min_len, min_idx)

338 6


In [18]:
# Find a plan of median length in res.
# NOTE: despite its name, `median_idx` holds the median plan *length* (the
# upper median for an even number of plans); the matching position in `res`
# is only printed.
lens = sorted(len(r['plan']['assistant']) for r in res)
median_idx = lens[len(lens) // 2]
for position, r in enumerate(res):
    if len(r['plan']['assistant']) == median_idx:
        print(median_idx, position)
        break
    


801 16


In [19]:
all_planer_path = root / "experiments/xmode/en/all_plans_30_en.json"
sampled_plans_path = root / "experiments/xmode/en/sampled_plans_30_en.json"

# Sample the shortest, median-length and longest plans. The positions are
# derived from `res` instead of being hard-coded, so they stay correct when
# the input data changes. list.index() returns the first match, matching
# the "first plan wins" rule of the search cells above.
plan_lengths = [len(r['plan']['assistant']) for r in res]
shortest_pos = plan_lengths.index(min(plan_lengths))
median_pos = plan_lengths.index(sorted(plan_lengths)[len(plan_lengths) // 2])
longest_pos = plan_lengths.index(max(plan_lengths))
sampled_plans = [res[i] for i in (shortest_pos, median_pos, longest_pos)]

# Write every annotated record, then only the three sampled ones.
with open(all_planer_path, "w") as f:
    json.dump(res, f, indent=4)
with open(sampled_plans_path, "w") as f:
    json.dump(sampled_plans, f, indent=4)

In [60]:
# Display every annotated record.
# NOTE(review): this dumps the whole list; consider res[:1] for a preview.
res

[{'db_id': 'mimic_iv_cxr',
  'split': 'test',
  'id': 5,
  'question': "given the last study of patient 10284038 in 2105, is the cardiac silhouette's width larger than half of the total thorax width?",
  'template': 'given the last study of patient 10284038 in 2105, is the width of the cardiac silhouette wider than 1/2 of the thorax width?',
  'query': 'select func_vqa("is the cardiac silhouette\'s width larger than half of the total thorax width?", t1.study_id) from ( select tb_cxr.study_id from tb_cxr where tb_cxr.study_id in ( select distinct tb_cxr.study_id from tb_cxr where tb_cxr.subject_id = 10284038 and strftime(\'%y\',tb_cxr.studydatetime) = \'2105\' order by tb_cxr.studydatetime desc limit 1 ) ) as t1',
  'value': {'patient_id': 10284038},
  'q_tag': 'given the [time_filter_exact1] study of patient {patient_id} [time_filter_global1], is the width of the cardiac silhouette wider than 1/2 of the thorax width?',
  't_tag': ['abs-year-in', '', '', 'exact-last', ''],
  'o_tag': {}