In [5]:
import json
import random
from collections import defaultdict
from openai import OpenAI
from tqdm import tqdm
import ast
import re
import time
import os
import pprint
from itertools import product, islice
import tiktoken
import json
import hashlib
import nest_asyncio
import asyncio
import aiohttp
import time
from tqdm.asyncio import tqdm as tqdm_async
import json
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from openai import OpenAI
from openai import RateLimitError


In [6]:
with open("/Users/paolocadei/Documents/Masters/Thesis/Spider2/2_final_structure_all.json", "r") as f:
    data = json.load(f)

with open("/Users/paolocadei/Documents/Masters/Thesis/Spider2/api_keys/api_keys.json", "r") as f:
    api_keys = dict(json.load(f))

client = OpenAI(
        api_key = api_keys["open_ai_key"],
    )

In [3]:
# === SETTINGS ===
SAVE_EVERY = 500
SAVE_PATH = "3_dense_cols_embeddings.json"
MAX_WORKERS = 50
BATCH_SIZE = 10
EMBEDDING_MODEL = "text-embedding-3-small"
MAX_TOKENS_PER_BATCH = 2000

# === INITIALIZATION ===
total_embeddings_done = 0
texts_to_embed = []
new_embeddings = {}
api_call_counter = 0

# === TOKENIZER ===
enc = tiktoken.encoding_for_model(EMBEDDING_MODEL)

def count_tokens(texts):
    return sum(len(enc.encode(text)) for text in texts)

# === SAFE LOAD FUNCTIONS ===

def safe_load_json(path):
    try:
        with open(path, 'r') as f:
            return json.load(f)
    except json.JSONDecodeError:
        print(f"⚠️ Warning: {path} is corrupted or incomplete. Starting fresh.")
        return {}

def save_embeddings_to_file():
    tmp_save_path = SAVE_PATH + ".tmp"
    with open(tmp_save_path, 'w') as f:
        json.dump(new_embeddings, f)
    os.replace(tmp_save_path, SAVE_PATH)

# === LOAD EXISTING EMBEDDINGS ===
if os.path.exists(SAVE_PATH):
    new_embeddings = safe_load_json(SAVE_PATH)
    print(f"🔄 Loaded existing embeddings from {SAVE_PATH}.")
else:
    print(f"🆕 No existing embeddings found. Starting fresh.")

# === EMBEDDING FUNCTION ===

def get_embeddings_batch(texts, model=EMBEDDING_MODEL, client=client):
    global api_call_counter
    tokens = count_tokens(texts)
    try:
        response = client.embeddings.create(input=texts, model=model)
        api_call_counter += 1
        return [item.embedding for item in response.data]
    except RateLimitError:
        print("⚡ Rate limit hit. Sleeping for 5 minutes...")
        time.sleep(300)
        return get_embeddings_batch(texts, model=model, client=client)

# === GATHER TEXTS TO EMBED ===

for database in data:
    if database not in new_embeddings:
        new_embeddings[database] = {}

    for table in data[database]:
        if table not in new_embeddings[database]:
            new_embeddings[database][table] = {'grouped': {}, 'ungrouped': {}}

        # GROUPED
        for template in data[database][table].get('grouped', {}):
            if template not in new_embeddings[database][table]['grouped']:
                new_embeddings[database][table]['grouped'][template] = []

            for group_index, group_entry in enumerate(data[database][table]['grouped'][template]):
                column_descriptions = group_entry['details']['description']

                # Ensure structure alignment
                while len(new_embeddings[database][table]['grouped'][template]) <= group_index:
                    new_embeddings[database][table]['grouped'][template].append({
                        'details': {'column_embeddings': {}}
                    })

                for col_name, col_description in column_descriptions.items():
                    if not col_description:
                        print(col_description)
                        continue

                    already_embedded = col_name in new_embeddings[database][table]['grouped'][template][group_index]['details']['column_embeddings']
                    if already_embedded:
                        total_embeddings_done += 1
                    else:
                        texts_to_embed.append((database, table, ('grouped', template, group_index), col_name, col_description))

        # UNGROUPED
        for ungrouped_key, ungrouped_entry in data[database][table].get('ungrouped', {}).items():
            if ungrouped_key not in new_embeddings[database][table]['ungrouped']:
                new_embeddings[database][table]['ungrouped'][ungrouped_key] = {
                    'details': {'column_embeddings': {}}
                }

            column_descriptions = ungrouped_entry['details']['description']

            for col_name, col_description in column_descriptions.items():
                if not col_description:
                    continue

                already_embedded = col_name in new_embeddings[database][table]['ungrouped'][ungrouped_key]['details']['column_embeddings']
                if already_embedded:
                    total_embeddings_done += 1
                else:
                    texts_to_embed.append((database, table, ('ungrouped', ungrouped_key), col_name, col_description))

# === DISPLAY PROGRESS COUNTS ===
total_columns = total_embeddings_done + len(texts_to_embed)
print(f"\n📋 Total columns: {total_columns}")
print(f"✅ Already embedded: {total_embeddings_done}")
print(f"📝 To embed now: {len(texts_to_embed)}\n")

if len(texts_to_embed) > 0:

    # === SMART BATCHING ===
    smart_batches = []
    current_batch = []
    current_batch_tokens = 0

    for item in texts_to_embed:
        text = item[4]
        tokens = len(enc.encode(text))

        if len(current_batch) >= BATCH_SIZE or (current_batch_tokens + tokens) > MAX_TOKENS_PER_BATCH:
            smart_batches.append(current_batch)
            current_batch = []
            current_batch_tokens = 0

        current_batch.append(item)
        current_batch_tokens += tokens

    if current_batch:
        smart_batches.append(current_batch)

    print(f"📦 Total smart batches created: {len(smart_batches)}")

    # === PARALLEL EXECUTION ===
    executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
    futures = []

    start_time = time.time()

    with tqdm(total=len(texts_to_embed), desc="Embedding columns") as pbar:
        for batch in smart_batches:
            texts_only = [item[4] for item in batch]
            future = executor.submit(get_embeddings_batch, texts_only)
            futures.append((future, batch))

        for future, batch in futures:
            embeddings = future.result()
            for (database, table, location, col_name, _), embedding in zip(batch, embeddings):
                if location[0] == 'grouped':
                    template, index = location[1], location[2]
                    new_embeddings[database][table]['grouped'][template][index]['details']['column_embeddings'][col_name] = embedding
                else:
                    ungrouped_key = location[1]
                    new_embeddings[database][table]['ungrouped'][ungrouped_key]['details']['column_embeddings'][col_name] = embedding

                total_embeddings_done += 1
                pbar.update(1)

                if total_embeddings_done % SAVE_EVERY == 0:
                    save_embeddings_to_file()

    elapsed = time.time() - start_time

    # === FINAL SAVE ===
    save_embeddings_to_file()

    print(f"\n✅ Completed {len(texts_to_embed)} new embeddings in {elapsed:.2f} seconds.")
    print(f"📈 Total embeddings now stored: {total_embeddings_done}/{total_columns}")
    print(f"📞 Total OpenAI API calls made (including retries): {api_call_counter}")

🔄 Loaded existing embeddings from 3_dense_cols_embeddings.json.

📋 Total columns: 134095
✅ Already embedded: 134095
📝 To embed now: 0



### Table embeddings

In [8]:
import os
import json
import time
import tiktoken
from openai import OpenAI, RateLimitError
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm

# === SETTINGS ===
SAVE_EVERY = 50
SAVE_PATH = "3_table_embeddings.json"
MAX_WORKERS = 5
BATCH_SIZE = 10
EMBEDDING_MODEL = "text-embedding-3-small"
MAX_TOKENS_PER_BATCH = 2000

# === INITIALIZATION ===
total_embeddings_done = 0
texts_to_embed = []
new_embeddings = {}
api_call_counter = 0

enc = tiktoken.encoding_for_model(EMBEDDING_MODEL)

def count_tokens(texts):
    return sum(len(enc.encode(text)) for text in texts)

def safe_load_json(path):
    try:
        with open(path, 'r') as f:
            return json.load(f)
    except json.JSONDecodeError:
        print(f"⚠️ Warning: {path} is corrupted or incomplete. Starting fresh.")
        return {}

def save_embeddings_to_file():
    tmp_save_path = SAVE_PATH + ".tmp"
    with open(tmp_save_path, 'w') as f:
        json.dump(new_embeddings, f)
    os.replace(tmp_save_path, SAVE_PATH)

if os.path.exists(SAVE_PATH):
    new_embeddings = safe_load_json(SAVE_PATH)
    print(f"🔄 Loaded existing embeddings from {SAVE_PATH}.")
else:
    print(f"🆕 No existing embeddings found. Starting fresh.")

def get_embeddings_batch(texts, model=EMBEDDING_MODEL, client=client):
    global api_call_counter
    try:
        response = client.embeddings.create(input=texts, model=model)
        api_call_counter += 1
        return [item.embedding for item in response.data]
    except RateLimitError:
        print("⚡ Rate limit hit. Sleeping for 5 minutes...")
        time.sleep(300)
        return get_embeddings_batch(texts, model=model, client=client)

# === PREVIEW FIRST 2 GROUPED AND 2 UNGROUPED DESCRIPTIONS ===
print("\n🔍 Previewing first 2 grouped and 2 ungrouped table descriptions:\n")
grouped_count = 0
ungrouped_count = 0

for database in data:
    for table in data[database]:
        # GROUPED
        for template in data[database][table].get("grouped", {}):
            entries = data[database][table]["grouped"][template]
            if entries:
                desc = entries[0].get("description", "")
                print(f"[GROUPED] {database}.{table} — {template}:\n{desc}\n")
                grouped_count += 1
                if grouped_count == 2:
                    break
        if grouped_count == 2:
            break
    if grouped_count == 2:
        break

for database in data:
    for table in data[database]:
        # UNGROUPED
        for ungrouped_table, entry in data[database][table].get("ungrouped", {}).items():
            desc = entry.get("description", "")
            print(f"[UNGROUPED] {database}.{table} — {ungrouped_table}:\n{desc}\n")
            ungrouped_count += 1
            if ungrouped_count == 2:
                break
        if ungrouped_count == 2:
            break
    if ungrouped_count == 2:
        break

# === GATHER DESCRIPTIONS FROM GROUPED AND UNGROUPED ===
for database in data:
    if database not in new_embeddings:
        new_embeddings[database] = {}

    for table in data[database]:
        if table not in new_embeddings[database]:
            new_embeddings[database][table] = {'grouped': {}, 'ungrouped': {}}

        # GROUPED TEMPLATES
        for template in data[database][table].get('grouped', {}):
            if template in new_embeddings[database][table]['grouped']:
                if 'table_embedding' in new_embeddings[database][table]['grouped'][template]:
                    total_embeddings_done += 1
                    continue

            entries = data[database][table]['grouped'][template]
            if entries:
                desc = entries[0].get('description', '')
                if desc:
                    texts_to_embed.append((database, table, 'grouped', template, desc))

        # UNGROUPED TABLES
        for ungrouped_table, entry in data[database][table].get('ungrouped', {}).items():
            if ungrouped_table in new_embeddings[database][table]['ungrouped']:
                if 'table_embedding' in new_embeddings[database][table]['ungrouped'][ungrouped_table]:
                    total_embeddings_done += 1
                    continue
            desc = entry.get('description', '')
            if desc:
                texts_to_embed.append((database, table, 'ungrouped', ungrouped_table, desc))

# === DISPLAY PROGRESS COUNTS ===
total_tables = total_embeddings_done + len(texts_to_embed)
print(f"\n📋 Total tables: {total_tables}")
print(f"✅ Already embedded: {total_embeddings_done}")
print(f"📝 To embed now: {len(texts_to_embed)}\n")

if len(texts_to_embed) > 0:

    # === SMART BATCHING ===
    smart_batches = []
    current_batch = []
    current_batch_tokens = 0

    for item in texts_to_embed:
        text = item[4]
        tokens = len(enc.encode(text))

        if len(current_batch) >= BATCH_SIZE or (current_batch_tokens + tokens) > MAX_TOKENS_PER_BATCH:
            smart_batches.append(current_batch)
            current_batch = []
            current_batch_tokens = 0

        current_batch.append(item)
        current_batch_tokens += tokens

    if current_batch:
        smart_batches.append(current_batch)

    print(f"📦 Total smart batches created: {len(smart_batches)}")

    # === PARALLEL EXECUTION ===
    executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
    futures = []

    start_time = time.time()

    with tqdm(total=len(texts_to_embed), desc="Embedding table descriptions") as pbar:
        for batch in smart_batches:
            texts_only = [item[4] for item in batch]
            future = executor.submit(get_embeddings_batch, texts_only)
            futures.append((future, batch))

        for future, batch in futures:
            embeddings = future.result()
            for (database, table, kind, name, _), embedding in zip(batch, embeddings):
                if name not in new_embeddings[database][table][kind]:
                    new_embeddings[database][table][kind][name] = {}
                new_embeddings[database][table][kind][name]['table_embedding'] = embedding
                total_embeddings_done += 1
                pbar.update(1)

                if total_embeddings_done % SAVE_EVERY == 0:
                    save_embeddings_to_file()

    elapsed = time.time() - start_time

    # === FINAL SAVE ===
    save_embeddings_to_file()

    print(f"\n✅ Completed {len(texts_to_embed)} new embeddings in {elapsed:.2f} seconds.")
    print(f"📈 Total embeddings now stored: {total_embeddings_done}/{total_tables}")
    print(f"📞 Total OpenAI API calls made (including retries): {api_call_counter}")


🔄 Loaded existing embeddings from 3_table_embeddings.json.

🔍 Previewing first 2 grouped and 2 ungrouped table descriptions:

[GROUPED] NEW_YORK.NEW_YORK — TLC_GREEN_TRIPS_201{variable0}.json:
This group of tables captures data on green taxi trips in New York City, detailing trip metrics such as fare amounts, location coordinates, timestamps, passenger counts, and payment methods for different years, providing insights for transportation analysis and urban mobility planning.

[GROUPED] NEW_YORK.NEW_YORK — TLC_YELLOW_TRIPS_20{variable0}.json:
This group of tables captures taxi trip data in New York City, detailing metrics such as trip distances, payment methods, surcharges, and passenger counts over different years, facilitating analysis of transportation trends and economic factors.

[UNGROUPED] NEW_YORK.NEW_YORK — TREE_CENSUS_2015.json:
The TREE_CENSUS_2015 table contains detailed data about street trees in New York City, including their locations, health statuses, and various conditi

In [12]:
new_embeddings['NEW_YORK']['NEW_YORK']['ungrouped']['TREE_CENSUS_2015.json']

{'table_embedding': [0.0025230799801647663,
  0.017231358215212822,
  0.023056533187627792,
  0.026184221729636192,
  0.013568822294473648,
  0.0016815689159557223,
  -0.00956910103559494,
  -0.03141641616821289,
  0.018463829532265663,
  0.012510756030678749,
  0.003150943201035261,
  -0.010127201676368713,
  0.007795968558639288,
  0.009214473888278008,
  -0.00036625354550778866,
  -0.012045672163367271,
  0.004412483423948288,
  5.436576248030178e-05,
  0.01467339601367712,
  0.013429297134280205,
  0.06483269482851028,
  0.038206640630960464,
  0.01912657357752323,
  -0.020754367113113403,
  0.05711229890584946,
  0.021405484527349472,
  0.006685580592602491,
  0.02434714138507843,
  0.043438833206892014,
  -0.012103808112442493,
  -0.00015287815767806023,
  -0.0028282913845032454,
  0.00302013847976923,
  0.05790294334292412,
  0.01958003081381321,
  -0.017405763268470764,
  -0.00854591652750969,
  -0.02457968331873417,
  0.039671655744314194,
  -0.014045532792806625,
  0.00334569