Uses [sqlglot](https://github.com/tobymao/sqlglot) to extract all tables referenced in a given SQL query.
Adds these tables as a List[str] feature in the `table_id` column of the spider and bird datasets.

In [44]:
from typing import List, Dict, Literal
import datasets
from tqdm.notebook import tqdm
from sqlglot import parse_one, exp
from sqlglot.optimizer.scope import build_scope

In [45]:
def get_tables(sql: str) -> List[str]:
    # https://github.com/tobymao/sqlglot/blob/main/posts/ast_primer.md
    ast = parse_one(sql)
    root = build_scope(ast)
    table_nodes: List[exp.Table] = [source for scope in root.traverse() for alias, (node, source) in scope.selected_sources.items() if isinstance(source, exp.Table)]
    return [n.name for n in table_nodes]

In [46]:
task_to_split_to_processed_dataset: Dict[str, Dict[Literal['train', 'test', 'validation'], datasets.Dataset]] = {}
for task in ["spider", "bird"]:
    task_to_split_to_processed_dataset[task] = {}
    for split in ["train", "test", "validation"]:
        dataset = datasets.load_dataset(f"target-benchmark/spider-queries-{split}")[split]
        all_table_names = []
        for idx, item in enumerate(tqdm(dataset, total=len(dataset), desc=f"{task} - {split}...")):
            table_names = get_tables(item['answer'])
            if len(table_names) == 0:
                print(split)
                print(idx)
                print(item['answer'])
                raise ValueError
            all_table_names.append(table_names)
        # Remove, then add `table_id`
        # This will be of type datasets.Sequence(datasets.Value("string")
        dataset = dataset.remove_columns("table_id")
        dataset = dataset.add_column(
            name="table_id",
            column=all_table_names,
        )
        task_to_split_to_processed_dataset[task][split] = dataset

spider - train...:   0%|          | 0/6997 [00:00<?, ?it/s]

spider - test...:   0%|          | 0/2147 [00:00<?, ?it/s]

spider - validation...:   0%|          | 0/1034 [00:00<?, ?it/s]

bird - train...:   0%|          | 0/6997 [00:00<?, ?it/s]

bird - test...:   0%|          | 0/2147 [00:00<?, ?it/s]

bird - validation...:   0%|          | 0/1034 [00:00<?, ?it/s]

In [51]:
# Save to csv files
for task, split_to_dataset in task_to_split_to_processed_dataset.items():
    for split, dataset in split_to_dataset.items():
        dataset.to_csv(f"{task}_{split}_with_table_id.csv")

Creating CSV from Arrow format:   0%|          | 0/7 [00:00<?, ?ba/s]

Creating CSV from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Creating CSV from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

Creating CSV from Arrow format:   0%|          | 0/7 [00:00<?, ?ba/s]

Creating CSV from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Creating CSV from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]