In [None]:

import os
import json
import pandas as pd
from typing import TypedDict, Optional
import pandas as pd
from langgraph.graph import StateGraph, END
from langchain_openai import AzureChatOpenAI
from langchain_openai import AzureOpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.messages import HumanMessage
from langchain.prompts import PromptTemplate
from hashlib import md5
from IPython.display import Image, display
import re

# Azure OpenAI connection settings.
#
# SECURITY: the API key is read from the AZURE_OPENAI_API_KEY environment
# variable instead of being embedded in the notebook. The key that used to be
# hard-coded here has been exposed to everyone with access to this file and
# must be rotated.
#
# Every other value keeps its previous literal as the default, so the notebook
# behaves as before once the key is exported.
AZURE_OPENAI_CONFIG = {
    "deployment_name": os.environ.get("AZURE_OPENAI_DEPLOYMENT", "gpt4-deployment"),
    "embedding_deployment": os.environ.get(
        "AZURE_OPENAI_EMBEDDING_DEPLOYMENT", "text-embedding-3-large"
    ),
    "api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
    "api_base": os.environ.get(
        "AZURE_OPENAI_ENDPOINT",
        "https://ai-testgeneration707727059630.openai.azure.com/",
    ),
    "api_version": os.environ.get("AZURE_OPENAI_API_VERSION", "2024-12-01-preview"),
}

if not AZURE_OPENAI_CONFIG["api_key"]:
    # Surface the misconfiguration early; NOTE(review): the Azure client is
    # expected to reject a missing key when it is constructed - confirm.
    print("WARNING: AZURE_OPENAI_API_KEY is not set; Azure OpenAI calls will fail.")

# Input/output locations. Override through the environment rather than editing
# the notebook; the defaults are the original machine-specific paths.
DATA_FOLDER = os.environ.get("TESTDATA_DATA_FOLDER", 'C:/QProjects/TestData_AI/New_data')
DDL_FOLDER = os.environ.get("TESTDATA_DDL_FOLDER", 'C:/QProjects/TestData_AI/ddls')
OUTPUT_FOLDER = os.environ.get("TESTDATA_OUTPUT_FOLDER", "generated_test_data_new_data_test")

# Number of rows requested per table by generate_test_data_node.
NUM_RECORDS = 10


class GenerationState(TypedDict):
    """State dictionary handed from node to node in the LangGraph workflow.

    Each key may be ``None`` until the node that produces it has run.
    """

    # table name -> DataFrame read from CSV (written by load_csvs_node)
    tables: Optional[dict[str, pd.DataFrame]]
    # table name -> raw DDL text (written by load_ddls_node)
    ddls: Optional[dict[str, str]]
    # table name -> {FK column -> list of candidate values}
    # (written by infer_foreign_keys_node; note the nested dict)
    foreign_keys: Optional[dict[str, dict[str, list]]]
    # table name -> object returned by FAISS.from_texts (written by build_vectorstores_node)
    vectorstores: Optional[dict]
    # table name -> DataFrame of generated rows (written by generate_test_data_node)
    generated: Optional[dict[str, pd.DataFrame]]

# Endpoint, key and API version are the same for the chat and embedding clients,
# so they are collected once and unpacked into both constructors.
_azure_connection = {
    "azure_endpoint": AZURE_OPENAI_CONFIG["api_base"],
    "api_key": AZURE_OPENAI_CONFIG["api_key"],
    "api_version": AZURE_OPENAI_CONFIG["api_version"],
}

# Chat model called by generate_test_data_node.
llm = AzureChatOpenAI(
    deployment_name=AZURE_OPENAI_CONFIG["deployment_name"],
    temperature=0,
    **_azure_connection,
)

# Embedding model handed to FAISS.from_texts in build_vectorstores_node.
embedding_model = AzureOpenAIEmbeddings(
    deployment=AZURE_OPENAI_CONFIG["embedding_deployment"],
    **_azure_connection,
)

In [None]:
def load_csvs_node(state: GenerationState) -> GenerationState:
    """Load every non-empty ``.csv`` file in ``DATA_FOLDER`` into a DataFrame.

    The table name is the file name without its extension. Files are visited in
    sorted order so the resulting ``tables`` dict has a deterministic key order
    (``os.listdir`` order is arbitrary).

    Args:
        state: Current workflow state; it is not mutated.

    Returns:
        A copy of ``state`` whose ``"tables"`` key maps table name -> DataFrame.

    Raises:
        ValueError: If ``DATA_FOLDER`` is not an existing directory, or if no
            CSV file with at least one data row was found.
    """
    print("Running load_csvs_node...")

    if not os.path.isdir(DATA_FOLDER):
        raise ValueError(f"DATA_FOLDER is not an existing directory: {DATA_FOLDER}")

    files = sorted(os.listdir(DATA_FOLDER))
    print("Files in CSV folder:", files)

    tables = {}
    for file in files:
        # splitext removes only the final extension; str.replace would also
        # strip ".csv" occurring in the middle of a file name.
        table_name, extension = os.path.splitext(file)
        if extension.lower() != ".csv":
            continue

        path = os.path.join(DATA_FOLDER, file)
        try:
            df = pd.read_csv(path)
        except pd.errors.EmptyDataError:
            # A zero-byte file has no header to parse; treat it like any other
            # empty table instead of aborting the whole run.
            print(f"Skipping {file}: file is empty")
            continue

        print(f"Loaded {file} with shape {df.shape}")
        if not df.empty:
            tables[table_name] = df

    if not tables:
        raise ValueError("No non-empty CSV files found to load.")

    print("Returning updated state with tables:", list(tables.keys()))
    print("load_csvs_node=======")
    return {**state, "tables": tables}

In [None]:
def load_ddls_node(state: GenerationState) -> GenerationState:
    """Read every ``.sql`` file in ``DDL_FOLDER`` into the ``"ddls"`` state key.

    The table name is the file name without its extension.

    Args:
        state: Current workflow state; it is not mutated.

    Returns:
        A copy of ``state`` whose ``"ddls"`` key maps table name -> DDL text.
        The mapping is empty when the folder contains no ``.sql`` files.

    Raises:
        OSError: Propagated from ``os.listdir`` / ``open`` (for example when
            ``DDL_FOLDER`` does not exist).
    """
    ddls = {}
    # Sorted for a deterministic key order; os.listdir order is arbitrary.
    for file in sorted(os.listdir(DDL_FOLDER)):
        # splitext removes only the final extension; str.replace would also
        # strip ".sql" occurring in the middle of a file name.
        table_name, extension = os.path.splitext(file)
        if extension.lower() != ".sql":
            continue
        # Explicit encoding keeps results identical across platforms;
        # "utf-8-sig" also drops a leading BOM, and errors="replace" keeps the
        # load best-effort for files saved in a legacy code page.
        with open(
            os.path.join(DDL_FOLDER, file), "r", encoding="utf-8-sig", errors="replace"
        ) as f:
            ddls[table_name] = f.read()
    print("load_ddls_node=======")
    return {**state, "ddls": ddls}

In [None]:
def _clean_identifier(name: str) -> str:
    """Strip surrounding whitespace and SQL identifier quoting (`x`, "x", [x])."""
    return name.strip().strip('`"[]').strip()


def extract_keys_from_ddl(ddl: str):
    """Extract table-level PRIMARY KEY and FOREIGN KEY constraints from DDL text.

    Only the constraint forms ``PRIMARY KEY (...)`` and
    ``FOREIGN KEY (...) REFERENCES table (...)`` are recognised; matching is
    case-insensitive. Column-level ``REFERENCES`` clauses without a
    ``FOREIGN KEY (...)`` prefix are not detected.

    Args:
        ddl: DDL text for one table. ``None`` or ``""`` yields empty results.

    Returns:
        ``(primary_keys, foreign_keys)`` where ``primary_keys`` is a list of
        column names and ``foreign_keys`` maps each FK column to
        ``{"ref_table": ..., "ref_column": ...}``. A schema prefix on the
        referenced table is dropped. Composite foreign keys are expanded into
        one entry per column, paired positionally with the referenced columns.
    """
    pk_pattern = r"PRIMARY\s+KEY\s*\(([^)]+)\)"
    # "\s*" before the column list accepts both "REFERENCES t(id)" and the
    # equally common "REFERENCES t (id)".
    fk_pattern = r"FOREIGN\s+KEY\s*\(([^)]+)\)\s*REFERENCES\s+([^\s(]+)\s*\(([^)]+)\)"

    primary_keys = []
    foreign_keys = {}
    ddl = ddl or ""

    for match in re.finditer(pk_pattern, ddl, re.IGNORECASE):
        primary_keys.extend(_clean_identifier(col) for col in match.group(1).split(","))

    for match in re.finditer(fk_pattern, ddl, re.IGNORECASE):
        fk_columns = [_clean_identifier(col) for col in match.group(1).split(",")]
        ref_table = _clean_identifier(match.group(2).split('.')[-1])  # remove schema if present
        ref_columns = [_clean_identifier(col) for col in match.group(3).split(",")]
        # zip pairs the n-th FK column with the n-th referenced column; surplus
        # names on either side (invalid DDL) are ignored.
        for fk_column, ref_column in zip(fk_columns, ref_columns):
            foreign_keys[fk_column] = {"ref_table": ref_table, "ref_column": ref_column}
    print(primary_keys, "primamry keys ===")
    print(foreign_keys, "foreign_keys ===")

    return primary_keys, foreign_keys

def infer_foreign_keys_node(state: GenerationState) -> GenerationState:
    """Collect, per table, the existing parent values each FK column may take.

    For every loaded table the matching DDL (if any) is parsed with
    ``extract_keys_from_ddl``. For each foreign key whose referenced table and
    column were loaded, the distinct non-null values of the referenced column
    are stored as strings.

    Args:
        state: Workflow state with ``"tables"`` populated; ``"ddls"`` may be
            ``None`` or lack entries for some tables.

    Returns:
        A copy of ``state`` whose ``"foreign_keys"`` key maps
        table name -> {FK column -> list of referenced values}. Every loaded
        table has an entry, possibly empty.
    """
    print("Running infer_foreign_keys_node...")
    tables = state['tables']
    ddls = state["ddls"] or {}
    fk_map = {}
    for table_name in tables:
        fk_map[table_name] = {}
        _, fk_defs = extract_keys_from_ddl(ddls.get(table_name, ""))

        for fk_col, ref_info in fk_defs.items():
            ref_table = ref_info["ref_table"]
            ref_column = ref_info["ref_column"]
            if ref_table not in tables:
                print(f"Skipping FK {table_name}.{fk_col}: table '{ref_table}' was not loaded")
                continue
            ref_df = tables[ref_table]
            if ref_column not in ref_df.columns:
                # The DDL and the CSV disagree; skip instead of raising KeyError.
                print(
                    f"Skipping FK {table_name}.{fk_col}: "
                    f"column '{ref_column}' not found in '{ref_table}'"
                )
                continue
            fk_map[table_name][fk_col] = (
                ref_df[ref_column].dropna().astype(str).unique().tolist()
            )

    # Print only the FK column names per table; the value lists can be large
    # and may contain real data that should not end up in notebook output.
    for k, v in fk_map.items():
        print(f"{k}: {list(v.keys())}")
    return {**state, "foreign_keys": fk_map}

In [None]:
def build_vectorstores_node(state: GenerationState) -> GenerationState:
    """Build one FAISS store per table from its rows rendered as text.

    Each row becomes a single document: its cell values cast to ``str`` and
    joined with ``", "``. The stores are returned under the ``"vectorstores"``
    key of a copy of ``state``.
    """
    print('==============build_vectorstores_node')
    stores_by_table = {}
    for table_name, frame in state['tables'].items():
        text_rows = frame.astype(str).values.tolist()
        documents = [", ".join(cells) for cells in text_rows]
        stores_by_table[table_name] = FAISS.from_texts(documents, embedding_model)
    updated_state = dict(state)
    updated_state["vectorstores"] = stores_by_table
    return updated_state

def find_parent_table(fk_col: str, fk_map: dict[str, dict[str, list]]) -> Optional[str]:
    """Return the first table in ``fk_map`` that has an entry for ``fk_col``.

    NOTE(review): ``fk_map`` is keyed by the table that *declares* the foreign
    key, so the result is a table owning ``fk_col`` as an FK column, which is
    not necessarily the referenced (parent) table the name suggests - confirm
    against the caller's intent.

    Args:
        fk_col: Foreign-key column name to look up.
        fk_map: table name -> {FK column -> values}. May be empty or ``None``.

    Returns:
        The matching table name, or ``None`` when no table lists ``fk_col``.
    """
    print('==============find_parent_table')
    if not fk_map:
        return None
    for table, mappings in fk_map.items():
        if fk_col in mappings:
            return table
    return None

In [None]:
def _row_fingerprint(row) -> str:
    """Return an md5 hex digest of ``row`` serialised as key-sorted JSON."""
    return md5(json.dumps(row, sort_keys=True, default=str).encode()).hexdigest()


def _parse_json_row(text):
    """Parse a model reply into a dict row; return None if it is not a JSON object."""
    if not isinstance(text, str):
        return None
    cleaned = text.strip()
    # Models often wrap JSON in a Markdown code fence despite being told not to.
    fenced = re.match(r"^```(?:json)?\s*(.*?)\s*```$", cleaned, re.DOTALL | re.IGNORECASE)
    if fenced:
        cleaned = fenced.group(1)
    try:
        parsed = json.loads(cleaned)
    except json.JSONDecodeError:
        return None
    # A list or scalar is valid JSON but cannot be used as a table row.
    return parsed if isinstance(parsed, dict) else None


def _resolve_fk_values(table, fk_map, generated_data):
    """Return {FK column -> allowed values} for ``table``.

    Starts from the values recorded in ``fk_map`` and replaces them with values
    from already generated data when ``find_parent_table`` points at a table
    that has been generated and contains the column.
    """
    fk_values = dict(fk_map.get(table, {}))
    for fk_col in fk_values:
        parent_table = find_parent_table(fk_col, fk_map)
        parent_df = generated_data.get(parent_table)
        # Guard against a generated frame without this column (e.g. no rows
        # were produced), which would otherwise raise KeyError.
        if parent_df is None or fk_col not in parent_df.columns:
            continue
        parent_values = parent_df[fk_col].dropna().unique().tolist()
        if parent_values:
            fk_values[fk_col] = parent_values
    return fk_values


def generate_test_data_node(state: GenerationState) -> GenerationState:
    """Ask the LLM for ``NUM_RECORDS`` distinct rows and store them per table.

    For each processed table the prompt contains the table's DDL, three sample
    rows retrieved from its vector store, the allowed foreign-key values and
    the rows generated so far. Each record gets up to five attempts; replies
    that are not a JSON object, or that repeat an earlier row, are discarded.
    A table may therefore end up with fewer than ``NUM_RECORDS`` rows.

    Args:
        state: Workflow state with ``"tables"`` and ``"vectorstores"``
            populated; ``"ddls"`` and ``"foreign_keys"`` may be ``None``.

    Returns:
        A copy of ``state`` whose ``"generated"`` key maps
        table name -> DataFrame of generated rows.
    """
    print('==============enerate_test_data_node')

    max_attempts = 5         # LLM calls allowed per requested record
    max_rows_in_prompt = 20  # cap on previously generated rows echoed back

    generated_data = {}
    tables = list(state["tables"].keys())
    fk_map = state["foreign_keys"] or {}
    ddls = state["ddls"] or {}

    prompt_template = PromptTemplate.from_template("""
You are a test data generator. Generate ONE realistic, non-duplicate row for the table `{table_name}`.

DDL Definition:
{ddl}

Sample Data Context:
{examples}

Foreign Key Constraints (if any):
{fk_values}

Rows already generated for this table (your row MUST differ from all of them):
{existing_rows}

Guidelines:
- Follow the DDL strictly (types, nullability, constraints)
- Use realistic names, emails, products, descriptions, prices, timestamps, etc.
- Avoid duplicates (no exact same rows)
- Sometimes leave nullable fields blank
- Respect relationships and existing FK values
- Ensure unique constraints (like emails, phone numbers, user_ids) are followed

Return ONLY a valid JSON object, without extra commentary or markdown.
""")

    # NOTE(review): only the first table is processed. This looks like a
    # leftover debugging limit - confirm before widening to `tables`, since
    # doing so multiplies the number of LLM calls.
    for table in tables[:1]:
        table_rows = []
        existing_hashes = set()

        # These inputs do not change between attempts, so compute them once per
        # table rather than once per LLM call.
        sample_contexts = state["vectorstores"][table].similarity_search("generate", k=3)
        context = "\n".join(doc.page_content for doc in sample_contexts)
        ddl_text = ddls.get(table, "")
        fk_json = json.dumps(_resolve_fk_values(table, fk_map, generated_data), default=str)

        for _record in range(NUM_RECORDS):
            for _attempt in range(max_attempts):
                # Showing the rows produced so far is what lets the prompt vary
                # between records; with an unchanged prompt and temperature 0
                # the model has no reason to return anything new.
                recent_rows = table_rows[-max_rows_in_prompt:]
                prompt = prompt_template.format(
                    table_name=table,
                    ddl=ddl_text,
                    examples=context,
                    fk_values=fk_json,
                    existing_rows=json.dumps(recent_rows, default=str) if recent_rows else "(none yet)",
                )

                # `.invoke` replaces the deprecated direct call `llm([...])`.
                response = llm.invoke([HumanMessage(content=prompt)])
                new_row = _parse_json_row(response.content)
                if new_row is None:
                    continue

                fingerprint = _row_fingerprint(new_row)
                if fingerprint in existing_hashes:
                    continue
                existing_hashes.add(fingerprint)
                table_rows.append(new_row)
                break

        generated_data[table] = pd.DataFrame(table_rows)
    return {**state, "generated": generated_data}
    

def save_outputs_node(state: GenerationState) -> GenerationState:
    """Write each generated table to ``OUTPUT_FOLDER/<table>_generated.csv``.

    The folder is created when missing. Files are written without the
    DataFrame index, and ``state`` is returned unchanged.
    """
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)
    print("Saving Generated Data to CSV files...")
    generated_tables = state["generated"]
    for table_name in generated_tables:
        csv_name = f"{table_name}_generated.csv"
        destination = os.path.join(OUTPUT_FOLDER, csv_name)
        generated_tables[table_name].to_csv(destination, index=False)
    return state

In [None]:
# ================================
# BUILD LANGGRAPH FLOW
# ================================
workflow = StateGraph(GenerationState)

# Register one node per pipeline stage.
workflow.add_node("load_csvs", load_csvs_node)
workflow.add_node("load_ddls", load_ddls_node)
workflow.add_node("infer_foreign_keys", infer_foreign_keys_node)
workflow.add_node("build_vectorstores", build_vectorstores_node)
workflow.add_node("generate_data", generate_test_data_node)
workflow.add_node("save_data", save_outputs_node)

# Wire the nodes into a single linear chain:
# load_csvs -> load_ddls -> infer_foreign_keys -> build_vectorstores
#           -> generate_data -> save_data -> END
# (The load_csvs -> load_ddls edge used to be registered twice; once is enough.)
workflow.set_entry_point("load_csvs")
workflow.add_edge("load_csvs", "load_ddls")
workflow.add_edge("load_ddls", "infer_foreign_keys")
workflow.add_edge("infer_foreign_keys", "build_vectorstores")
workflow.add_edge("build_vectorstores", "generate_data")
workflow.add_edge("generate_data", "save_data")
workflow.add_edge("save_data", END)

# ================================
# COMPILE WORKFLOW
# ================================
# Compiling only builds the runnable graph; execution happens when
# `flow.invoke(...)` is called.
flow = workflow.compile()
# Optional: uncomment to render the graph as a PNG in the notebook.
# display(Image(flow.get_graph().draw_mermaid_png()))

In [None]:
print("STARTING...")

# Every key starts as None; each node fills in its own key as the graph runs.
initial_state: GenerationState = {
    "tables": None,
    "ddls": None,
    "foreign_keys": None,
    "vectorstores": None,
    "generated": None,
}

# Run the compiled graph once and keep whatever state it returns.
final_state = flow.invoke(initial_state)