# Part 1 
## Generate 30 samples that require fewer images

In [None]:
from utils import 

import sqlite3 
import json
import re
from pathlib import Path
from sqlite3 import OperationalError

root = Path('.').cwd().parent
dataset_path = root / 'dataset' / 'mimic_iv_cxr'
db_file = 'mimic_iv_cxr.db'
db_path = root / db_file

def get_connection(db_path):
    """Open the SQLite database at ``db_path`` and register ``func_vqa``.

    ``func_vqa`` is registered as a two-argument SQL function backed by a
    stub that ignores its arguments and always returns ``True``, so queries
    that call it can be executed by SQLite.

    Args:
        db_path: ``pathlib.Path``-like object pointing at the database file.

    Returns:
        An open ``sqlite3.Connection``; the caller is responsible for
        closing it.

    Raises:
        AssertionError: if ``db_path`` does not exist.
    """
    assert db_path.exists(), f"Database file {db_path} does not exist"

    connection = sqlite3.connect(db_path)
    # Stub implementation: both arguments are ignored.
    connection.create_function("func_vqa", 2, lambda s, col_name: True)
    return connection

def clean_query(q):
    """Return ``q`` with two literal substitutions applied.

    ``%y`` is rewritten to ``%Y`` and ``current_time`` is rewritten to
    ``strftime('2105-12-31 23:59:00')``, in that order.
    """
    replacements = (
        ("%y", "%Y"),
        ("current_time", "strftime('2105-12-31 23:59:00')"),
    )
    for old, new in replacements:
        q = q.replace(old, new)
    return q


def substitute_word_outside_quotes(text, target_word, substitute_word, flags=0):
    """Replace whole-word occurrences of ``target_word`` that are not quoted.

    Single- or double-quoted spans are matched first and copied through
    unchanged, so a word that appears inside a quoted literal is left alone.

    Args:
        text: The string to process.
        target_word: Literal word to look for; it is regex-escaped and must
            sit on word boundaries to match.
        substitute_word: Text inserted in place of each unquoted match.
        flags: Extra ``re`` flags (e.g. ``re.IGNORECASE``) combined with
            ``re.VERBOSE``. Defaults to ``0``, i.e. case-sensitive matching.

    Returns:
        ``text`` with every unquoted whole-word match replaced.
    """
    pattern = re.compile(
        r"""(["\'])(?:(?!\1).)*\1|\b{0}\b""".format(
            re.escape(target_word)
        ),
        re.VERBOSE | flags,
    )

    def replacer(match):
        # A match that starts with a quote is a quoted span: keep it verbatim.
        if match.group(0).startswith(('"', "'")):
            return match.group(0)
        return substitute_word

    return pattern.sub(replacer, text)


def pre_process_query(q, sep="###"):
    """Split a query in two around a single set-operator keyword.

    The keywords ``except``, ``intersect`` and ``union`` are tried in that
    order. The first one that occurs exactly once outside quotes (in any
    letter case) is replaced by ``sep`` and the query is split there.

    Args:
        q: The SQL query string.
        sep: Temporary marker used to locate the split point.

    Returns:
        ``(q1, q2, kw)`` with the text before and after the keyword and the
        lower-case keyword itself, or ``(q, None, None)`` when no keyword
        yields exactly two parts.
    """
    lowered = q.lower()
    for kw in ("except", "intersect", "union"):
        if kw not in lowered:
            continue
        # SQL keywords are case-insensitive, so match EXCEPT / Except too.
        marked = substitute_word_outside_quotes(q, kw, sep, re.IGNORECASE)
        if marked == q:
            # The keyword only appeared inside quotes or inside another word.
            continue
        parts = marked.split(sep)
        # Several occurrences (or a query that already contains ``sep``)
        # cannot be expressed as a two-way split; try the next keyword
        # instead of failing on tuple unpacking.
        if len(parts) == 2:
            q1, q2 = parts
            return q1, q2, kw
    return q, None, None

def count_required_images(query, db_pqth):
    """Run ``query`` and derive the number of images it requires.

    The query is first passed through ``pre_process_query``. If it is split
    around a set operator, both halves are executed and the result is the
    sum of their row counts. Otherwise the whole query is executed: when it
    yields a single row whose first column is numeric and the query text
    contains ``count``, that value is returned as-is; in every other case
    the number of returned rows is used.

    Args:
        query: SQL query string to execute.
        db_pqth: Path of the SQLite database file, forwarded to
            ``get_connection``. (The parameter name is misspelled but kept
            so existing keyword callers keep working.)

    Returns:
        The derived count, or ``999999`` when SQLite raises
        ``OperationalError`` while executing the query.
    """
    # Use the path that was passed in rather than the module-level global.
    conn = get_connection(db_pqth)
    try:
        q1, q2, op = pre_process_query(query)
        if op:
            return sum(len(conn.execute(part).fetchall()) for part in (q1, q2))

        rows = conn.execute(query).fetchall()
        if len(rows) == 1 and str(rows[0][0]).isnumeric() and "count" in query.lower():
            # An aggregate COUNT query already reports the number directly.
            return rows[0][0]
        return len(rows)
    except OperationalError:
        return 999999  # large sentinel so failed queries are filtered out later
    finally:
        # One connection is opened per call; close it so handles do not pile up.
        conn.close()

# Load test_with_scope.json, then annotate each sample with whatever
# `sql_parser(query).tables` reports and with its required-image count.
test_data_path = dataset_path / 'test_with_scope.json'
with open(test_data_path, 'r') as f:
    test_data = json.load(f)

for d in test_data:
    query = clean_query(d['query'])
    d['tables'] = sql_parser(query).tables
    d['num_required_images'] = count_required_images(query, db_path)

# Write the annotated samples to a separate file next to the input.
new_test_path = dataset_path / 'test_with_scope_preprocessed.json'
with open(new_test_path, 'w') as f:
    json.dump(test_data, f, indent=2)


In [3]:
"""
Select categories:
Image single 1
Image single 2
Multimodal single
"""

# Values of a sample's 'scope' field that are kept by the filter.
categories = ["IMAGE-SINGLE-1", "IMAGE-SINGLE-2", "MULTIMODAL-SINGLE"]
# Inclusive upper and lower bounds on a sample's 'num_required_images'.
max_num_required_images = 3
min_num_required_images = 1

def filter_dataset(data, categories, min_num_required_images, max_num_required_images):
    """Select the samples that fall in the image-count range and scope set.

    Args:
        data: Iterable of sample dicts with 'num_required_images' and
            'scope' keys.
        categories: Container of accepted 'scope' values.
        min_num_required_images: Inclusive lower bound on the image count.
        max_num_required_images: Inclusive upper bound on the image count.

    Returns:
        A new list with the matching samples, in their original order.
    """
    return [
        sample
        for sample in data
        if min_num_required_images <= sample['num_required_images'] <= max_num_required_images
        and sample['scope'] in categories
    ]

def sort_dataset(data):
    """Return a new list of samples ordered by ascending 'num_required_images'.

    The input is not modified; samples with equal counts keep their
    relative order.
    """
    ordered = list(data)
    ordered.sort(key=lambda sample: sample['num_required_images'])
    return ordered

# Keep only the selected scopes within the image-count bounds, then order
# the survivors by how many images they require.
filtered_test_data = filter_dataset(
    test_data, categories, min_num_required_images, max_num_required_images
)
filtered_test_data = sort_dataset(filtered_test_data)


In [4]:
len(filtered_test_data)

2295

In [20]:
# Persist the filtered, sorted subset alongside the other dataset files.
filtered_path = dataset_path.joinpath('filtered_test_with_scope_preprocessed.json')
with filtered_path.open('w') as f:
    json.dump(filtered_test_data, f, indent=2)


In [35]:
"""
# get first 10 samples from each category in filtered dataset
res = []
for scope in categories:
    data_per_scope = []
    for d in filtered_test_data:
        if d['scope'] == scope
            data_per_scope.append(d)
    data_per_scope = sorted(data_per_scope, key=lambda x: x['num_required_images'])
    print(len(data_per_scope))
    res.extend(data_per_scope[:min(10, len(data_per_scope))])
"""

840
468
987


In [46]:
# Sample up to `max_samples_per_category` items per scope, split between
# samples whose answer is a single integer and samples whose answer is a list.
# Each answer type normally contributes `max_samples_per_answertype` items;
# when one type is short, the other type fills the remaining slots.
res = []
max_samples_per_category = 10
# Misspelled original name, kept as an alias for anything that still reads it.
max_sapmles_per_category = max_samples_per_category
max_samples_per_answertype = 5

for scope in categories:
    # Answered samples of this scope, fewest required images first.
    data_per_scope = sorted(
        (d for d in filtered_test_data if d['scope'] == scope and len(d['answer']) > 0),
        key=lambda x: x['num_required_images'],
    )
    data_per_scope_single_numeric_value = [
        x for x in data_per_scope
        if len(x['answer']) == 1 and isinstance(x['answer'][0], int)
    ]
    data_per_scope_list_value = [x for x in data_per_scope if len(x['answer']) > 1]
    num_single_value = len(data_per_scope_single_numeric_value)
    num_list_value = len(data_per_scope_list_value)

    # The three cases are mutually exclusive. Previously the `else` was
    # attached to the second `if`, so the backfill computed in the first
    # case was immediately clamped back to the per-type limit.
    if num_single_value < max_samples_per_answertype:
        # Too few single-value samples: let list-value samples fill the gap.
        num_list_value = min(max_samples_per_category - num_single_value, num_list_value)
    elif num_list_value < max_samples_per_answertype:
        # Too few list-value samples: let single-value samples fill the gap.
        num_single_value = min(max_samples_per_category - num_list_value, num_single_value)
    else:
        # Enough of both: take the per-type limit from each.
        num_single_value = max_samples_per_answertype
        num_list_value = max_samples_per_answertype

    res.extend(data_per_scope_single_numeric_value[:num_single_value])
    res.extend(data_per_scope_list_value[:num_list_value])
    

In [47]:
len(res)

30

In [49]:
# Save the balanced sample to its own JSON file.
sampled_test = dataset_path.joinpath('sampled_test_with_scope_preprocessed_balenced_answer.json')
with sampled_test.open('w') as f:
    json.dump(res, f, indent=4)