In [None]:
import json
import random
from collections import defaultdict
from openai import OpenAI
from tqdm import tqdm
import ast
import re
import time
import os
import pprint
from itertools import product, islice
import tiktoken

In [None]:
# Read the API credentials from the local JSON key file.
with open("/Users/paolocadei/Documents/Masters/Thesis/Spider2/api_keys/api_keys.json", "r") as f:
    api_keys = dict(json.load(f))

# OpenAI client used by the batch-submission and single-call cells further down.
client = OpenAI(api_key=api_keys["open_ai_key"])

# Chat model identifier used throughout the notebook.
model = "gpt-4o-mini-2024-07-18"

In [None]:
# Load the preprocessed schema file; the markdown section further down describes the
# nested layout (database -> database2 -> grouped / ungrouped tables) it is expected to have.
with open("/Users/paolocadei/Documents/Masters/Thesis/Spider2/0_final_preprocessed.json", "r") as f:
    schema_groups = json.load(f)

In [None]:
# Choose the number of databases you would like to sample
sample_size = len(schema_groups)

# Fixed seed so that Restart & Run All always selects (and orders) the same databases.
# The order matters downstream: it determines how requests are packed into batch files.
SAMPLE_SEED = 42

# Randomly sample keys (database names) from the full schema_groups.
# A dedicated Random instance leaves the global random state untouched, and min() guards
# against asking for more databases than exist (random.sample would raise ValueError).
sampled_keys = random.Random(SAMPLE_SEED).sample(
    list(schema_groups.keys()),
    min(sample_size, len(schema_groups)),
)

# Build the new sampled dictionary
sampled_schema_groups = {k: schema_groups[k] for k in sampled_keys}

In [None]:
# Preview only the first sampled database names: echoing the whole nested schema
# (every table, column and sample row) floods the cell output and bloats the notebook.
list(sampled_schema_groups)[:10]

In [None]:
def count_tokens(text, model="gpt-4o-mini-2024-07-18"):
    """
    Return how many tokens `text` occupies under the tokenizer tiktoken maps to `model`.

    When tiktoken raises KeyError for the model name, the "cl100k_base" encoding is used instead.
    """
    try:
        tokenizer = tiktoken.encoding_for_model(model)
    except KeyError:
        tokenizer = tiktoken.get_encoding("cl100k_base")

    token_ids = tokenizer.encode(text)
    return len(token_ids)

def sample_top_n_columns(columns, descriptions, sample_row, n):
    """
    Pick up to `n` columns, preferring columns that have a non-empty description.

    Args:
        columns: mapping of column name -> column type.
        descriptions: sequence of column descriptions, positionally aligned with `columns`.
            May be None or shorter than `columns`; the missing descriptions are treated as None
            so the corresponding columns are still eligible (a plain zip would drop them).
        sample_row: accepted for signature compatibility; not used here.
        n: maximum number of columns to return (expected to be >= 0).

    Returns:
        A list of ((column_name, column_type), description) tuples: described columns first
        (in their original order), topped up with undescribed ones until `n` is reached.
    """
    column_items = list(columns.items())
    aligned_descriptions = list(descriptions) if descriptions is not None else []

    # Pad so that every column has a (possibly empty) description slot.
    shortfall = len(column_items) - len(aligned_descriptions)
    if shortfall > 0:
        aligned_descriptions.extend([None] * shortfall)

    zipped = list(zip(column_items, aligned_descriptions))

    described = [item for item in zipped if item[1]]
    others = [item for item in zipped if not item[1]]

    sampled = described[:n]
    if len(sampled) < n:
        sampled += others[: n - len(sampled)]

    return sampled

def create_prompt(
    database,
    database2,
    table_template,
    columns,
    descriptions,
    sample_row,
    model="gpt-4o-mini-2024-07-18",
    token_limit=20000,
    allow_sampling_if_too_long=True,
    is_grouped=True,
    combinations=None,
    variable_order=None,
):
    """
    Build the user prompt that describes one table (or one grouped table template).

    The prompt starts with a header (database, table name or template and, for grouped
    templates, the variables plus the first five combinations) followed by one bullet per
    selected column showing its type, sample value(s) from `sample_row` and its description.

    Args:
        database, database2: the two folder levels identifying the database.
        table_template: table name, or the template name for grouped tables.
        columns: mapping of column name -> column type.
        descriptions: column descriptions, forwarded to `sample_top_n_columns`.
        sample_row: a dict or a list of dicts; anything else is treated as "no samples".
        model: model name forwarded to `count_tokens`.
        token_limit: maximum prompt size in tokens when sampling is allowed.
        allow_sampling_if_too_long: when False, every column is included and `token_limit`
            is ignored. When True, the number of columns is reduced one at a time until the
            prompt fits.
        is_grouped: selects the grouped-template header instead of the single-table header.
        combinations, variable_order: only used in the grouped header; None is treated as
            empty so the function can be called with its defaults.

    Returns:
        The prompt string. If even the header alone exceeds `token_limit`, a warning is
        printed and the header-only prompt is returned.
    """
    # Normalize sample_row to always be a list of dicts
    if isinstance(sample_row, dict):
        sample_row = [sample_row]
    elif not isinstance(sample_row, list):
        sample_row = []
    sample_row = [row for row in sample_row if isinstance(row, dict)]

    # The grouped header joins / slices these; with the None defaults that raised TypeError.
    if variable_order is None:
        variable_order = []
    if combinations is None:
        combinations = []

    def get_column_sample(col_name):
        # Return up to 2 sample values from sample_row
        if not sample_row:
            return ["N/A"]
        return [row.get(col_name, "N/A") for row in sample_row[:2]]

    def build_prompt(selected_columns, all_columns_used=False):
        if is_grouped:
            prompt = (
                f"**Database**: {database}.{database2}\n"
                f"**Table template**: {table_template}\n"
                f"Variables: {', '.join(variable_order)}\n"
                f"Sample values: {combinations[:5]}\n"
                "Each value generates a specific table instance (e.g. 'TLC_YELLOW_TRIPS_2016.json').\n\n"
            )
        else:
            prompt = (
                f"**Database**: {database}.{database2}\n"
                f"**Table name**: {table_template}\n\n"
            )

        column_text = "**Columns**:\n"
        for (col_name, col_type), desc in selected_columns:
            samples = get_column_sample(col_name)
            # Two sample values are only shown when no column had to be dropped.
            if all_columns_used and len(samples) > 1:
                sample_str = f"[{samples[0]}, {samples[1]}]"
            else:
                sample_str = samples[0]
            desc_text = desc if desc else ""
            column_text += f"- {col_name} (type: {col_type}, sample: {sample_str}) – {desc_text}\n"

        return prompt + column_text

    total_columns = len(columns)

    if not allow_sampling_if_too_long:
        # No size check in this mode, so there is no need to tokenize the prompt.
        selected_columns = sample_top_n_columns(columns, descriptions, sample_row, total_columns)
        return build_prompt(selected_columns, all_columns_used=True)

    # Try the full column list first, then progressively fewer columns (down to none).
    for n_cols in range(total_columns, -1, -1):
        selected_columns = sample_top_n_columns(columns, descriptions, sample_row, n_cols)
        all_used = (n_cols == total_columns)
        prompt = build_prompt(selected_columns, all_columns_used=all_used)
        token_count = count_tokens(prompt, model=model)

        if token_count <= token_limit:
            return prompt

    print("⚠️ Prompt exceeds limit even with 0 columns. Returning fallback.")
    return build_prompt([], all_columns_used=False)



def clean_and_parse_result(result):
    """
    Turn a raw LLM reply into a Python object (dict or list).

    A leading ``` / ```json fence and a trailing ``` fence are stripped first. The remaining
    text is parsed as JSON, then as a Python literal. If both fail, the widest `[...]` or
    `{...}` substring is extracted and parsed the same way.

    Raises:
        ValueError: when no list/dict-like substring can be found.
        Exception: whatever `ast.literal_eval` raised when the extracted substring is
            not parseable either.
    """
    # Drop markdown code fences around the payload.
    text = re.sub(r"^```(?:json)?\n?", "", result.strip())
    text = re.sub(r"\n?```$", "", text.strip())

    # Whole-text attempts: strict JSON first, then a Python literal.
    for parse, expected_errors in (
        (json.loads, json.JSONDecodeError),
        (ast.literal_eval, Exception),
    ):
        try:
            return parse(text)
        except expected_errors:
            continue

    # Fallback: pull out a list or dict substring and retry on that.
    match = re.search(r"(\[.*\]|\{.*\})", text, re.DOTALL)
    if not match:
        print("⚠️ Could not parse result at all.")
        raise ValueError("Could not parse cleaned result.")

    snippet = match.group(0)
    try:
        return json.loads(snippet)
    except json.JSONDecodeError:
        try:
            return ast.literal_eval(snippet)
        except Exception as e:
            print("⚠️ Failed to parse fallback snippet")
            print(snippet)
            raise e




## Creating the prompts for table description generation

Initial structure of schema group:

- database (1st folder)
    - database2 (2nd folder)
        - grouped
            - table template (.json)
                - **lists** of dictionaries that share the same template with keys
                    - variables <- dictionary with key = variable0, variable1, etc...: value = list of values that variables can take
                    - combinations <- possible combinations the variables can take in the template
                    - variable_order <- order in which the variables should be inserted in the template
                    - details <- dictionary
                        - columns <- column_name: column_type as string
                        - description <- column_description
                        - sample_row <- a sample of one row from it if available
        - ungrouped
            - table_name (.json)
                - columns <- column_name: column_type as string
                - description <- column_description
                - sample_row <- a sample of one row from it if available

After this step:

- database (1st folder)
    - database2 (2nd folder)
        - grouped
            - table template (.json)
                - **lists** of dictionaries that share the same template with keys
                    - variables <- dictionary with key = variable0, variable1, etc...: value = list of values that variables can take
                    - combinations <- possible combinations the variables can take in the template
                    - variable_order <- order in which the variables should be inserted in the template
                    - details <- dictionary
                        - columns <- column_name: column_type as string
                        - description <- column_description
                        - sample_row <- a sample of one row from it if available
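                    - prompt <- user prompt sent to the LLM to request a description of the table
                    - system_role <- system prompt sent to the LLM together with the prompt
                    - description <- description output by the LLM for the table (injected later, once the responses are merged)
                    - keywords <- search keywords output by the LLM for the table (injected later, once the responses are merged)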
        - ungrouped
            - table_name (.json)
                - columns <- column_name: column_type as string
                - description <- column_description
                - sample_row <- a sample of one row from it if available
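                - prompt <- user prompt sent to the LLM to request a description of the table
                - system_role <- system prompt sent to the LLM together with the prompt
                - description <- description output by the LLM for the table (injected later; it overwrites the column-level `description` key above)
                - keywords <- search keywords output by the LLM for the table (injected later, once the responses are merged)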

In [None]:
outputs = []

# System prompt attached to every grouped table template.
GROUPED_SYSTEM_ROLE = (
    "You are generating metadata for a **grouped database table template**.\n"
    "Each table in the group follows the same schema but varies by year, region, or other dimensions.\n"
    "Your task is to describe what kind of data is captured across the group — not for a single instance.\n\n"
    "Instructions:\n"
    "- Write a concise 1–2 sentence **description** of what this group of tables contains, based on the database name, table name, column names, types, description and sample row.\n"
    "- Then provide at least 10 **keywords or phrases** users might use to search for tables like this.\n"
    "- Focus on concepts, entities, or analysis that the data enables. Avoid file names or exact column names.\n\n"
    "**Output format**:\n"
    "{\n"
    '  "description": "...",\n'
    '  "keywords": ["...", "...", ...]\n'
    "}"
)

# System prompt attached to every stand-alone (ungrouped) table.
UNGROUPED_SYSTEM_ROLE = (
    "You are generating metadata for a **single database table**.\n"
    "This table is not part of a larger group or template.\n"
    "Your task is to describe exactly what this table contains and how users might refer to it, based on the database name, table name, column names, types, descriptions and sample row\n\n"
    "Instructions:\n"
    "- Write a concise 1–2 sentence **description** of what this table contains.\n"
    "- Then provide at least 10 **keywords or phrases** users might use to search for this table.\n"
    "- Focus on meaningful terms and real-world concepts. Avoid repeating column names unless standard terminology.\n\n"
    "**Output format**:\n"
    "{\n"
    '  "description": "...",\n'
    '  "keywords": ["...", "...", ...]\n'
    "}"
)

# database level
for database in tqdm(sampled_schema_groups, desc="Databases"):

    # this is the second folder, constituting the smaller database
    for database2, schema in sampled_schema_groups[database].items():

        # GROUPED SCHEMAS: each template maps to a list of entries sharing that template
        for tables, group_entries in schema['grouped'].items():
            for entry in group_entries:
                combinations = entry['combinations']
                variable_order = entry['variable_order']

                details = entry['details']
                columns = details['columns']
                col_desc = details['description']
                sample_row = details['sample_row']

                prompt = create_prompt(
                    database=database,
                    database2=database2,
                    table_template=tables,
                    columns=columns,
                    descriptions=col_desc,
                    sample_row=sample_row,
                    allow_sampling_if_too_long=True,
                    is_grouped=True,
                    combinations=combinations,
                    variable_order=variable_order,
                )

                # Store both halves of the request next to the schema they describe.
                entry['prompt'] = prompt
                entry['system_role'] = GROUPED_SYSTEM_ROLE

        # UNGROUPED SCHEMAS: one entry per table name
        for table, entry in schema['ungrouped'].items():
            columns = entry['columns']
            col_desc = entry['description']
            sample_row = entry['sample_row']

            prompt = create_prompt(
                database=database,
                database2=database2,
                table_template=table,
                columns=columns,
                descriptions=col_desc,
                sample_row=sample_row,
                allow_sampling_if_too_long=True,
                is_grouped=False,
            )

            entry['prompt'] = prompt
            entry['system_role'] = UNGROUPED_SYSTEM_ROLE


## Building the batch of prompts to send to OpenAI

In [None]:
def generate_jsonl_batch_file(
    sampled_schema_groups,
    output_path="1_batch_requests.jsonl",
    model="gpt-4o-mini-2024-07-18",
    max_tokens=1000
):
    """
    Write one chat-completion batch request per table entry that has a prompt.

    Grouped entries get the custom id `database::database2::table::grouped::idx`,
    ungrouped ones `database::database2::ungrouped::table`. Entries whose "prompt" is
    missing or empty are skipped. The requests are written to `output_path`, one JSON
    object per line, and a summary line is printed.
    """
    def make_request(custom_id, entry):
        # Same request envelope for grouped and ungrouped entries.
        return {
            "custom_id": custom_id,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": model,
                "messages": [
                    {"role": "system", "content": entry.get("system_role")},
                    {"role": "user", "content": entry.get("prompt")}
                ],
                "max_tokens": max_tokens
            }
        }

    jsonl_lines = []

    for database in tqdm(sampled_schema_groups, desc="Generating JSONL"):
        for database2 in sampled_schema_groups[database]:
            schema = sampled_schema_groups[database][database2]

            # GROUPED TABLES
            for table in schema.get("grouped", {}):
                for idx, entry in enumerate(schema["grouped"][table]):
                    if not entry.get("prompt"):
                        continue
                    jsonl_lines.append(
                        make_request(f"{database}::{database2}::{table}::grouped::{idx}", entry)
                    )

            # UNGROUPED TABLES
            for table in schema.get("ungrouped", {}):
                entry = schema["ungrouped"][table]
                if not entry.get("prompt"):
                    continue
                jsonl_lines.append(
                    make_request(f"{database}::{database2}::ungrouped::{table}", entry)
                )

    # Save to .jsonl
    with open(output_path, "w") as f:
        f.writelines(json.dumps(request) + "\n" for request in jsonl_lines)

    print(f"✅ Batch file saved to: {output_path} with {len(jsonl_lines)} requests.")


In [None]:
# Write every prepared prompt into a single JSONL request file in the working directory.
# max_tokens is set to 450 here instead of the function default of 1000.
generate_jsonl_batch_file(
    sampled_schema_groups,
    output_path="1_openai_batch_prompts.jsonl",
    model="gpt-4o-mini-2024-07-18",  # or another supported chat model
    max_tokens=450
)

## Creating multiple file batches

In [None]:
def split_jsonl_by_token_limit(input_path, token_limit, model="gpt-4o-mini-2024-07-18"):
    """
    Split a JSONL request file into consecutive batches of at most `token_limit` prompt tokens.

    The size of a request is the token count of all its message contents joined by spaces.
    Requests keep their file order. A single request larger than `token_limit` is placed in
    a batch of its own (it cannot be split), and no empty batch is ever produced.

    Args:
        input_path: path of the JSONL file with one request per line (blank lines are ignored).
        token_limit: maximum number of prompt tokens per batch.
        model: model name used to select the tiktoken encoding; "cl100k_base" is used when
            tiktoken raises KeyError for the model, mirroring `count_tokens`.

    Returns:
        A list of batches, each a list of request dicts. A per-batch summary is printed.
    """
    try:
        tokenizer = tiktoken.encoding_for_model(model)
    except KeyError:
        tokenizer = tiktoken.get_encoding("cl100k_base")

    batches = []
    current_batch = []
    current_tokens = 0
    token_counts = []  # to store total tokens per batch

    with open(input_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue

            item = json.loads(line)
            messages = item["body"]["messages"]
            full_prompt = " ".join(m["content"] for m in messages)
            token_count = len(tokenizer.encode(full_prompt))

            # Only close the batch when it already holds something: otherwise an oversized
            # first request would emit an empty batch (and later an empty batch file).
            if current_batch and current_tokens + token_count > token_limit:
                batches.append(current_batch)
                token_counts.append(current_tokens)
                current_batch = []
                current_tokens = 0

            current_batch.append(item)
            current_tokens += token_count

        if current_batch:
            batches.append(current_batch)
            token_counts.append(current_tokens)

    # Print summary
    for i, (batch, tokens) in enumerate(zip(batches, token_counts), 1):
        print(f"📦 Batch {i}: {len(batch)} requests, {tokens} tokens")

    return batches

def save_batches_to_files(batches, output_dir="openai_batches", prefix="batch"):
    """
    Write each batch to `<output_dir>/<prefix>_<k>.jsonl` (k starting at 1), one JSON object per line.

    The output directory is created when missing. Returns the written paths in batch order.
    """
    os.makedirs(output_dir, exist_ok=True)

    written_paths = []
    for number, requests in enumerate(batches, start=1):
        target = os.path.join(output_dir, f"{prefix}_{number}.jsonl")
        with open(target, "w", encoding="utf-8") as out:
            out.writelines(json.dumps(request) + "\n" for request in requests)
        written_paths.append(target)

    return written_paths


# Split the full request file into batches of at most 500000 prompt tokens and write
# each batch to its own JSONL file under "openai_batches".
# NOTE(review): the input path is absolute, whereas the cell above writes
# "1_openai_batch_prompts.jsonl" relative to the working directory; the two only refer to
# the same file when the notebook runs from that folder. Confirm.
batches = split_jsonl_by_token_limit(
    input_path="/Users/paolocadei/Documents/Masters/Thesis/Spider2/1_openai_batch_prompts.jsonl",
    token_limit=500000
)

jsonl_paths = save_batches_to_files(batches, output_dir="openai_batches")

## Sending batches to OpenAI

In [None]:
import os
import time
import json
from datetime import datetime

def submit_batches_sequentially(
    client,
    jsonl_paths,
    output_dir="1_batch_outputs",
    log_file="1_completed_batches.txt",
    failed_log_file="1_failed_batches.jsonl",
    poll_interval=100
):
    """
    Upload and run each JSONL batch file one after the other, waiting for each to finish.

    For every path the file is uploaded with purpose "batch", a batch job is created for
    "/v1/chat/completions" and its status is polled every `poll_interval` seconds.
    When the batch completes and has an output file, each response's message content is
    parsed as JSON and written to `<output_dir>/<batch id>_output.jsonl` as
    {"custom_id", "description", "keywords"}; lines that cannot be parsed are reported and
    skipped. The batch id is then appended to `log_file`.

    Any exception raised while handling a path (including a batch that ends as failed,
    cancelled or expired) is caught, printed and appended to `failed_log_file` as a JSON
    line with the file path, a timestamp and the error message; processing then continues
    with the next path.

    Args:
        client: OpenAI client exposing `files.create`, `files.content`,
            `batches.create` and `batches.retrieve`.
        jsonl_paths: iterable of batch file paths to submit.
        output_dir: directory for the parsed outputs (created when missing).
        log_file: text file receiving the ids of processed batches.
        failed_log_file: JSONL file receiving one entry per failed path.
        poll_interval: seconds to sleep between status checks.
    """
    # Terminal states that can never turn into "completed"; without "expired" the polling
    # loop would spin forever on a batch that ran out of its completion window.
    unsuccessful_states = ("failed", "cancelled", "expired")

    os.makedirs(output_dir, exist_ok=True)

    for path in jsonl_paths:
        print(f"🚀 Submitting: {path}")
        try:
            # Upload file (the handle is closed as soon as the upload call returns)
            with open(path, "rb") as upload_f:
                file_obj = client.files.create(file=upload_f, purpose="batch")

            # Submit batch
            batch = client.batches.create(
                input_file_id=file_obj.id,
                endpoint="/v1/chat/completions",
                completion_window="24h",
                metadata={"description": f"Auto-submitted from {path}"}
            )

            print(f"🆔 Submitted batch ID: {batch.id}")

            # Wait for completion
            while True:
                status = client.batches.retrieve(batch.id).status
                print(f"⏳ Waiting... Current status of {batch.id}: {status}")
                if status == "completed":
                    print(f"✅ Batch {batch.id} completed.")
                    break
                elif status in unsuccessful_states:
                    print(f"❌ Batch {batch.id} ended with status: {status}")
                    raise RuntimeError(f"Batch {batch.id} failed with status: {status}")
                time.sleep(poll_interval)

            # Fetch output and save to JSONL
            batch = client.batches.retrieve(batch.id)
            response_id = getattr(batch, "output_file_id", None)

            if response_id:
                file_response = client.files.content(response_id)
                output_path = os.path.join(output_dir, f"{batch.id}_output.jsonl")

                with open(output_path, "w", encoding="utf-8") as out_f:
                    for line in file_response.iter_lines():
                        if line:
                            try:
                                line_json = json.loads(line)
                                content_str = line_json["response"]["body"]["choices"][0]["message"]["content"]
                                assistant_content = json.loads(content_str)
                                formatted = {
                                    "custom_id": line_json.get("custom_id"),
                                    "description": assistant_content.get("description"),
                                    "keywords": assistant_content.get("keywords")
                                }
                                out_f.write(json.dumps(formatted) + "\n")
                            except Exception as e:
                                # Best effort: one bad response must not discard the rest.
                                print(f"❗ Error parsing line: {e}")

                print(f"📝 Saved output to: {output_path}")

            # Log batch ID
            with open(log_file, "a") as log:
                log.write(batch.id + "\n")
                print(f"🗂️ Logged batch ID: {batch.id} to {log_file}")

        except Exception as e:
            error_entry = {
                "file_path": path,
                "timestamp": datetime.now().isoformat(),
                "error_message": str(e)
            }
            print(f"❌ Failed to process batch from {path}: {e}")
            with open(failed_log_file, "a") as f:
                f.write(json.dumps(error_entry) + "\n")

In [None]:
def get_jsonl_paths_from_folder(folder_path):
    """Return the paths of the `.jsonl` files directly inside `folder_path`, sorted by file name."""
    jsonl_paths = []
    for fname in sorted(os.listdir(folder_path)):
        if not fname.endswith(".jsonl"):
            continue
        jsonl_paths.append(os.path.join(folder_path, fname))
    return jsonl_paths

# uncomment this for all the batches
# # Path to the folder with batches
# folder_path = "/Users/paolocadei/Documents/Masters/Thesis/Spider2/openai_batches"
#
# # Collect all batch paths
# file_paths = get_jsonl_paths_from_folder(folder_path)

# uncomment this for the failed batches
file_paths = []

with open("1_failed_batches.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        if line.strip():
            try:
                entry = json.loads(line)
                file_path = entry.get("file_path")
                # The failed log is append-only, so a batch file that failed several times
                # is listed several times; keep the first occurrence only so it is not
                # submitted (and billed) more than once in this run.
                if file_path and file_path not in file_paths:
                    file_paths.append(file_path)
            except json.JSONDecodeError as e:
                print(f"⚠️ Skipping invalid line: {e}")

print(file_paths)  # batch files that are about to be (re)submitted

# Submit all batches sequentially
submit_batches_sequentially(
    client,
    jsonl_paths=file_paths,
    output_dir="1_batch_outputs",
    log_file="1_completed_batches.txt",
    failed_log_file="1_failed_batches.jsonl",
    poll_interval=100  # seconds between status checks
)



## Resubmit the requests from failed batches as individual API calls

In [None]:
# Upper bound on the number of successful single API calls made by this cell.
MAX_PROMPTS = None  # 👈 Set to None to run all
# JSONL file that receives one {"custom_id", "description", "keywords"} line per successful
# call; it is also read on start-up so already processed custom ids are skipped.
SUCCESS_LOG = "1_manually_recovered_prompts.jsonl"
# JSONL file that receives one {"custom_id", "error"} line per failed call.
FAILURE_LOG = "1_manual_failures.jsonl"

def load_prompts_from_failed_batch_file(file_path):
    """Parse a JSONL batch file into a list of request objects, ignoring blank lines."""
    with open(file_path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

def open_ai_call(client, model, system_prompt, user_prompt, max_tokens=500):
    """
    Send one system + user message pair to the chat completions endpoint.

    Returns the message content of the first choice in the response.
    """
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    completion = client.chat.completions.create(
        model=model,
        messages=conversation,
        max_tokens=max_tokens,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content

def clean_and_parse_result(result):
    """
    Turn a raw LLM reply into a Python object (dict or list).

    Code fences (``` or ```json) around the reply are stripped, then the text is parsed as
    JSON and, failing that, as a Python literal. As a last resort the widest `[...]` or
    `{...}` substring is extracted and parsed the same way.

    Note: this definition replaces any function of the same name defined in an earlier cell.

    Raises:
        ValueError: when no list/dict-like substring can be found.
        Exception: whatever `ast.literal_eval` raised when the extracted substring is
            not parseable either.
    """
    text = re.sub(r"^```(?:json)?\n?", "", result.strip())
    text = re.sub(r"\n?```$", "", text.strip())

    # Whole-text attempts: strict JSON first, then a Python literal.
    for parse, expected_errors in (
        (json.loads, json.JSONDecodeError),
        (ast.literal_eval, Exception),
    ):
        try:
            return parse(text)
        except expected_errors:
            continue

    match = re.search(r"(\[.*\]|\{.*\})", text, re.DOTALL)
    if not match:
        raise ValueError("⚠️ Could not parse cleaned result.")

    snippet = match.group(0)
    try:
        return json.loads(snippet)
    except json.JSONDecodeError:
        try:
            return ast.literal_eval(snippet)
        except Exception as e:
            print("⚠️ Failed to parse fallback snippet")
            print(snippet)
            raise e

# Load already successful custom_ids so that re-running this cell resumes where it stopped
processed_custom_ids = set()
if os.path.exists(SUCCESS_LOG):
    with open(SUCCESS_LOG, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                try:
                    entry = json.loads(line)
                    processed_custom_ids.add(entry.get("custom_id"))
                except json.JSONDecodeError:
                    continue

# Collect the batch files recorded in the failed-batch log.
# The log is append-only, so the same file may be listed several times; each path is kept
# once so that its prompts are not loaded (and retried) repeatedly.
failed_file_paths = []
with open("1_failed_batches.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        if line.strip():
            try:
                entry = json.loads(line)
                file_path = entry.get("file_path")
                if file_path and file_path not in failed_file_paths:
                    failed_file_paths.append(file_path)
            except json.JSONDecodeError as e:
                print(f"⚠️ Skipping invalid line: {e}")

print(f"\n📦 Batch files to process ({len(failed_file_paths)} total):")
for path in failed_file_paths:
    print(f"  - {path}")


# Load all prompts from all failed batch files
all_prompts = []
for file_path in failed_file_paths:
    try:
        prompts = load_prompts_from_failed_batch_file(file_path)
        all_prompts.extend(prompts)
    except Exception as e:
        print(f"❌ Skipped file {file_path}: {e}")

print(f"🧩 {len(all_prompts)} total prompts loaded from {len(failed_file_paths)} batch files.")
print(f"🔄 Resuming from {len(processed_custom_ids)} already processed prompts.")

# Replay every outstanding request as an individual chat completion call.
calls_made = 0
with open(SUCCESS_LOG, "a", encoding="utf-8") as success_f, open(FAILURE_LOG, "a", encoding="utf-8") as fail_f:
    for request in tqdm(all_prompts, desc="🚀 Processing prompts"):
        if MAX_PROMPTS is not None and calls_made >= MAX_PROMPTS:
            print("✅ Reached MAX_PROMPTS limit.")
            break

        custom_id = request.get("custom_id")
        if custom_id in processed_custom_ids:
            continue

        body = request.get("body", {})
        model = body.get("model")
        max_tokens = body.get("max_tokens", 300)
        messages = body.get("messages", [])

        system_prompt = next((m["content"] for m in messages if m["role"] == "system"), "")
        user_prompt = next((m["content"] for m in messages if m["role"] == "user"), "")

        try:
            raw_output = open_ai_call(client, model, system_prompt, user_prompt, max_tokens=max_tokens)
            parsed = clean_and_parse_result(raw_output)

            formatted = {
                "custom_id": custom_id,
                "description": parsed.get("description"),
                "keywords": parsed.get("keywords")
            }

            # Flush after every record so an interrupted run loses nothing.
            success_f.write(json.dumps(formatted) + "\n")
            success_f.flush()
            calls_made += 1
            processed_custom_ids.add(custom_id)

        except Exception as e:
            print(f"❌ Failed for {custom_id}: {e}")
            fail_f.write(json.dumps({"custom_id": custom_id, "error": str(e)}) + "\n")
            fail_f.flush()


## Putting everything together

### Taking all the information from the batched requests

In [None]:
def merge_batches_from_ids(batch_id_file="1_completed_batches.txt", output_path="1_merged_openai_output.jsonl"):
    """
    Download the outputs of all logged batches and merge them into one JSONL file.

    Batch ids are read from `batch_id_file` (one per line, blank lines ignored). For each
    batch that has an output file, every response's message content is parsed as JSON and
    reduced to {"custom_id", "description", "keywords"}. Batches or lines that fail are
    reported and skipped. Uses the notebook-level `client`.
    """
    # Read all batch IDs from file
    with open(batch_id_file, "r") as f:
        batch_ids = [stripped for stripped in (raw.strip() for raw in f) if stripped]

    print(f"🔍 Found {len(batch_ids)} batch IDs to process")

    merged_lines = []
    for batch_id in batch_ids:
        try:
            print(f"📦 Retrieving batch: {batch_id}")
            batch = client.batches.retrieve(batch_id)
            response_id = dict(batch).get("output_file_id")

            if not response_id:
                print(f"⚠️ No response file for batch {batch_id}")
                continue

            file_response = client.files.content(response_id)

            for line in file_response.iter_lines():
                if not line:
                    continue
                try:
                    line_json = json.loads(line)
                    message = line_json["response"]["body"]["choices"][0]["message"]
                    assistant_content = json.loads(message["content"])
                    merged_lines.append({
                        "custom_id": line_json.get("custom_id"),
                        "description": assistant_content.get("description"),
                        "keywords": assistant_content.get("keywords")
                    })
                except Exception as e:
                    print(f"❌ Failed to parse line in batch {batch_id}: {e}")

        except Exception as e:
            print(f"❌ Failed to process batch {batch_id}: {e}")

    # Save merged results
    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(item) + "\n" for item in merged_lines)

    print(f"✅ Merged {len(merged_lines)} results into: {output_path}")

# Merge the outputs of every batch id listed in the completed-batches log into one JSONL file.
merge_batches_from_ids(
    batch_id_file="1_completed_batches.txt",
    output_path="1_merged_openai_output.jsonl"
)

### Injecting the json of schema information with the descriptions

This includes both the batched descriptions and the singular API calls.

In [None]:
import json

def inject_metadata_from_file(file_path, sampled_schema_groups):
    """
    Copy LLM descriptions and keywords from a JSONL results file into the schema dict.

    Each non-blank line must be a JSON object with "custom_id", "description" and
    "keywords"; lines where any of the three is missing or empty are ignored. The custom id
    selects the target entry:
      - `database::database2::ungrouped::table`
      - `database::database2::table::grouped::index`
    Ids that do not match these formats, or that point at entries that do not exist, are
    reported and skipped instead of aborting the whole injection.

    NOTE(review): for ungrouped tables the prompt-building cell reads the column
    descriptions from the same "description" key that is assigned here, so this overwrites
    them. Confirm this is intended.

    Args:
        file_path: path of the JSONL results file.
        sampled_schema_groups: nested schema dict; it is updated in place.

    Returns:
        The same `sampled_schema_groups` object.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue

            entry = json.loads(line)
            custom_id = entry.get("custom_id")
            description = entry.get("description")
            keywords = entry.get("keywords")

            if not custom_id or not description or not keywords:
                continue

            parts = custom_id.split("::")

            try:
                if "ungrouped" in parts:
                    # Format: database::database2::ungrouped::table
                    database, database2, _, table = parts
                    target = sampled_schema_groups[database][database2]["ungrouped"][table]
                elif "grouped" in parts:
                    # Format: database::database2::table::grouped::index
                    database, database2, table, _, index = parts
                    target = sampled_schema_groups[database][database2]["grouped"][table][int(index)]
                else:
                    print(f"⚠️ Unknown custom_id format: {custom_id}")
                    continue

                target["description"] = description
                target["keywords"] = keywords

            # ValueError covers ids with the wrong number of "::" parts (tuple unpacking)
            # and non-numeric grouped indexes (int()).
            except (KeyError, IndexError, ValueError) as e:
                print(f"❗ Failed to update {custom_id}: {e}")

    return sampled_schema_groups


# Load the sampled_schema_groups
# NOTE(review): this file is read here and overwritten below, but no cell visible in this
# notebook creates it beforehand; confirm where the first version comes from.
with open("1_sampled_schema_groups_with_metadata.json", "r", encoding="utf-8") as f:
    sampled_schema_groups = json.load(f)

# Inject metadata from both files: the merged batch outputs and the individually
# recovered prompts (the second file is applied last, so it wins for a repeated custom id)
input_files = [
    "/Users/paolocadei/Documents/Masters/Thesis/Spider2/1_merged_openai_output.jsonl",
    "/Users/paolocadei/Documents/Masters/Thesis/Spider2/1_manually_recovered_prompts.jsonl"
]

for file_path in input_files:
    sampled_schema_groups = inject_metadata_from_file(file_path, sampled_schema_groups)

# Save the updated version (overwrites the file loaded above)
with open("1_sampled_schema_groups_with_metadata.json", "w", encoding="utf-8") as out_f:
    json.dump(sampled_schema_groups, out_f, indent=2)

print("✅ Metadata from both files injected successfully into sampled_schema_groups.")


## Final Checking

Checks if:
- any table descriptions are missing
- any table instances are missing compared to the folder structure in databases

In [None]:
with open('/Users/paolocadei/Documents/Masters/Thesis/Spider2/1_sampled_schema_groups_with_metadata.json') as f:
    data = dict(json.load(f))

# Number of table entries that did not receive LLM metadata.
count = 0

for database in data.keys():

    for table in data[database].keys():

        # Grouped templates: one list of entries per template
        for template in data[database][table]['grouped'].keys():

            for t, entry in enumerate(data[database][table]['grouped'][template]):

                if 'description' not in entry or 'keywords' not in entry:
                    count += 1
                    print(database, table, template, t)

        # Ungrouped tables: 'description' already holds the column descriptions before any
        # metadata is injected, so its presence alone proves nothing. 'keywords' is only
        # ever written together with the LLM description, which makes it the reliable marker.
        for t, entry in data[database][table]['ungrouped'].items():

            if 'description' not in entry or 'keywords' not in entry:
                print(database, table, t)
                count += 1

print(count)
        
    

In [None]:
def extract_all_generated_paths(nested):
    """
    List every table path (`database/table/filename`) described by the nested schema.

    Grouped templates are expanded: for each entry, every combination is zipped with the
    entry's `variable_order` and substituted into the template via `str.format`.
    Combinations that cannot be formatted are reported and skipped. Ungrouped file names
    are appended as they are, after the grouped paths of the same folder.
    """
    all_paths = []

    for database, tables in nested.items():
        for table, content in tables.items():
            folder = f"{database}/{table}"

            # Grouped files: each template maps to a list of entries
            for template, template_entries in content.get("grouped", {}).items():
                for entry in template_entries:
                    combinations = entry["combinations"]
                    variable_order = entry["variable_order"]

                    for combo in combinations:
                        subs = dict(zip(variable_order, combo))

                        try:
                            filename = template.format(**subs)
                        except KeyError:
                            print(f"❌ Missing key in format for {database}/{table} with template '{template}'")
                            print(f"   combo: {combo}")
                            print(f"   subs: {subs}")
                            continue
                        except Exception as e:
                            print(f"❌ Error formatting template '{template}' with combo {combo}: {e}")
                            continue

                        all_paths.append(f"{folder}/{filename}")

            # Ungrouped files
            all_paths.extend(
                f"{folder}/{filename}" for filename in content.get("ungrouped", [])
            )

    return all_paths



# Every table path implied by the schema file (grouped templates expanded).
final_paths = extract_all_generated_paths(data)

all_file_paths = final_paths #+ unique_paths

with open("/Users/paolocadei/Documents/Masters/Thesis/Spider2/0_filesystem_structure.json", "r") as f:
    ground_truth = json.load(f)

# Flatten the three-level ground-truth structure into "level1/level2/name" paths.
ground_truth_paths = [f'{k1}/{k2}/{v}' for k1 in ground_truth.keys() for k2 in ground_truth[k1].keys() for v in ground_truth[k1][k2].keys()]

# Ground-truth tables that the schema does not account for.
missing_paths = set(ground_truth_paths) - set(all_file_paths)

if not missing_paths:
    print('All the tables are there!')
else:
    # The original message contained a stray "f" inside the f-string ("Missing: f3").
    print(f'Missing: {len(missing_paths)}')