### Imports

In [1]:
import os
import json
import requests
import traceback

### Gcloud auth 

In [2]:
from google.auth import default
from google.auth.transport.requests import Request
from google.auth.exceptions import DefaultCredentialsError

In [3]:
if "GOOGLE_API_KEY" in os.environ:
    del os.environ["GOOGLE_API_KEY"]
    print(" Removed GOOGLE_API_KEY to use Vertex AI instead")

# Set environment variables for Vertex AI
os.environ["GOOGLE_CLOUD_PROJECT"] = "dataplatr-sandbox"
os.environ["GOOGLE_CLOUD_LOCATION"] = "us-central1"

In [4]:
os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "True"

# Verify ADC is available
try:
    from google.auth import default
    credentials, project = default()
    print(f" ADC credentials found!")
    print(f"   Credential type: {type(credentials).__name__}")
except Exception as e:
    print(f" ADC error: {e}")

 ADC credentials found!
   Credential type: Credentials


In [5]:
import os
import json
import re
import textwrap
import asyncio
from typing import Optional, Tuple
from pydantic import BaseModel, Field
import gradio as gr
import uuid
from datetime import datetime

In [6]:
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark.exceptions import SnowparkSQLException
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col 

In [7]:
from google import genai
from google.genai import types
from google.adk.agents import LlmAgent
from google.adk.runners import Runner
from google.adk.sessions import InMemorySessionService
from google.genai.types import ThinkingConfig, GenerateContentConfig
from google.adk.planners import BuiltInPlanner


### Generation Configuration:

In [8]:
safety_settings = [
    types.SafetySetting(
        category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        threshold=types.HarmBlockThreshold.OFF,
    ),
]

In [9]:
generate_content_config = types.GenerateContentConfig(
   safety_settings=safety_settings,
   temperature=0,
   max_output_tokens=8192,
   top_p=0.95,
)

In [10]:
genai_client = genai.Client(
    vertexai=True,
    project="dataplatr-sandbox",
    location="us-central1"
)

print("Vertex AI client initialized")


Vertex AI client initialized


In [11]:
os.environ["PATH"] += os.pathsep + r"C:\Users\pooja\AppData\Local\Google\Cloud SDK\google-cloud-sdk\bin"


In [12]:
import shutil
print(shutil.which("gcloud"))


C:\Users\pooja\AppData\Local\Google\Cloud SDK\google-cloud-sdk\bin\gcloud.CMD


In [13]:
def check_gcp_authentication_native():
    
    """Check GCP authentication using google.auth library."""
    
    print(" Checking GCP Authentication (Python SDK)...")
    print("=" * 60)

    try:
        # Try loading ADC credentials
        credentials, project_id = default()

        if credentials and credentials.valid:
            print(" Application Default Credentials found and valid!")
        else:
            print(" Credentials found but may need refresh.")
            credentials.refresh(Request())
            print(" Refreshed successfully.")

        print(f" Project ID: {project_id or 'Unknown'}")
        print("=" * 60)
        return True

    except DefaultCredentialsError:
        print(" No valid Application Default Credentials found.")
        print(" Run this command to authenticate:")
        print("   gcloud auth application-default login")
    except Exception as e:
        print(f" Error verifying authentication: {e}")
    finally:
        print("=" * 60)

    return False

# Run test
gcp_auth_ok = check_gcp_authentication_native()


 Checking GCP Authentication (Python SDK)...
 Credentials found but may need refresh.
 Refreshed successfully.
 Project ID: dataplatr-sandbox


In [14]:
os.getenv("GOOGLE_API_KEY") or "AIzaSyA3imT3fHzZobQCw1ZBPDqgLbU1GjjGNP0" 

'AIzaSyA3imT3fHzZobQCw1ZBPDqgLbU1GjjGNP0'

### Global variables:

In [15]:
APP_NAME = "reporting_layer_app"
USER_ID = "test_user_23"
SESSION_ID = "session123"
MODEL_NAME = "snowflake-arctic"

In [16]:
TARGET_DB ="DATAPLATR_DEMO"
L2_SCHEMA = "DWH_SI" 
L3_SCHEMA = "L3_SI"   

### Github Configs:

In [17]:
git_repo_url="https://github.com/poojapapney/snf_elt.git"
GITHUB_REPO = "poojapapney/snf_elt"
BRANCH="PP-Test"
GITHUB_TOKEN="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"

### Snowflake Connection Set up:

In [18]:
# Global snf session variable :

SNOWFLAKE_SESSION = None


def init_snowflake_session(target_db: str, target_schema: str):
    """
    Initializes or retrieves a global Snowflake session.
    Works in both Snowflake-native and local environments.
    """
    global SNOWFLAKE_SESSION

    if SNOWFLAKE_SESSION is not None:
        
        print(" Reusing existing Snowflake session.")
        
        return SNOWFLAKE_SESSION

    try:
        # Try to use active Snowflake Notebook session:
        
        SNOWFLAKE_SESSION = get_active_session()
        print("  Successfully connected to active Snowflake session.")
    except Exception: # if not
        print("  No active session found. Creating a new session for local development...")

        with open("configs.json", "r") as f:
            connection_parameters = json.load(f)

        SNOWFLAKE_SESSION = Session.builder.configs(connection_parameters).create()
        print("  Successfully created a new local Snowflake session.")

    # Set database and schema context:
    
    SNOWFLAKE_SESSION.use_database(target_db)
    SNOWFLAKE_SESSION.use_schema(target_schema)

    print(f" Session active with DB: {target_db}, Schema: {target_schema}")
    return SNOWFLAKE_SESSION


In [19]:
SNOWFLAKE_SESSION= init_snowflake_session(TARGET_DB, L2_SCHEMA)

  No active session found. Creating a new session for local development...


DatabaseError: 250001 (08001): Failed to connect to DB: ORQIFID-LB06580.snowflakecomputing.com:443. Your free trial has ended and all of your virtual warehouses have been suspended. Add billing information in the Snowflake web UI to continue using the full set of Snowflake features.

### Gold Layer Agentic Pipeline:

In [None]:
def get_tables_context(database: str, schema: str) -> list:
    """
    List all base tables in the given schema with their descriptions.

    Returns:
        list of dicts, each containing:
        - table_id: fully qualified name
        - description: table comment or empty string
    """
    
    if SNOWFLAKE_SESSION is None:
        
        raise ValueError("No active Snowflake session. Please call init_snowflake_session() first.")

    try:
        query = f"""
            SELECT TABLE_CATALOG AS DATABASE_NAME,
                   TABLE_SCHEMA,
                   TABLE_NAME,
                   TABLE_TYPE,
                   COMMENT AS DESCRIPTION
            FROM {database}.INFORMATION_SCHEMA.TABLES
            WHERE TABLE_SCHEMA = '{schema}'
              AND TABLE_TYPE = 'BASE TABLE'
            ORDER BY TABLE_NAME
        """

        df = SNOWFLAKE_SESSION.sql(query).to_pandas()

        candidates = [
            {
                "table_id": f"{row.DATABASE_NAME}.{row.TABLE_SCHEMA}.{row.TABLE_NAME}",
                "description": row.DESCRIPTION or ""
            }
            for _, row in df.iterrows()
        ]

        return candidates

    except Exception as e:
        return [{"error": f"Failed to list tables for {database}.{schema}: {e}"}]


In [None]:
CONFIG_TEMPLATE = f"""{{{{ config(
    materialized='table' ,
     schema='{L3_SCHEMA}'
)}}}}
"""

In [None]:
def discover_table_schema_snowflake(table_path: str) -> dict:
    
    """
    Discover Snowflake table schema and sample values.
    """
    global SNOWFLAKE_SESSION
    
    if SNOWFLAKE_SESSION is None or not hasattr(SNOWFLAKE_SESSION, "sql"):
        
        raise ValueError("Snowflake session not initialized. Please call init_snowflake_session() first.")

    db, sch, tbl = table_path.split(".")

    try:
        cols_df = SNOWFLAKE_SESSION.sql(f"""
            SELECT COLUMN_NAME, DATA_TYPE, COMMENT
            FROM {db}.INFORMATION_SCHEMA.COLUMNS
            WHERE TABLE_SCHEMA = '{sch}'
              AND TABLE_NAME = '{tbl}'
            ORDER BY ORDINAL_POSITION
        """).to_pandas()


    except Exception as e:
        
        return {"error": f"Failed to fetch schema for {table_path}: {e}"}

    schema_info = [
        {
            "name": row["COLUMN_NAME"],
            "type": row["DATA_TYPE"],
            "description": row["COMMENT"] or "",
            "sample_values": []
        }
        for _, row in cols_df.iterrows()
    ]

    try:
        df = SNOWFLAKE_SESSION.sql(f"SELECT * FROM {table_path} SAMPLE (0.01) LIMIT 5").to_pandas()
        
        for col in df.columns:
            
            values = df[col].dropna().astype(str).unique().tolist()[:5]
            
            for col_info in schema_info:
                
                if col_info["name"] == col:
                    
                    col_info["sample_values"] = values
                    
                    break
    except Exception as e:
        
        print(f"Could not fetch sample values for {table_path}: {e}")

    return {"table": table_path, "columns": schema_info}

In [None]:
def get_approved_tables_schema(approved_tables: list, database: str, schema: str) -> list:
    """
    Fetch schema metadata and sample values for approved Snowflake tables.
    """
    if not approved_tables:
        raise ValueError("approved_tables list is empty.")

    all_schemas = [
        discover_table_schema_snowflake(f"{database}.{schema}.{table}")
        for table in approved_tables
    ]
    return all_schemas

In [None]:
# def validate_sql(sql: str) -> dict:
#     global SNOWFLAKE_SESSION

#     if SNOWFLAKE_SESSION is None:
#         raise ValueError("No active Snowflake session. Please call init_snowflake_session() first.")

#     try:
#         # Wrap in try/except to catch missing columns/tables
#         test_sql = f"SELECT * FROM ({sql}) LIMIT 0"
#         SNOWFLAKE_SESSION.sql(test_sql).collect()
#         return {"status": "valid", "message": "SQL is syntactically and semantically valid."}
#     except Exception as e:
#         return {"status": "invalid", "error": str(e), "sql_snippet": sql}


In [None]:
def validate_sql(sql: str) -> str:
    """
    Validate a SQL query in Snowflake using EXPLAIN/LIMIT 0 approach.
    Always returns a valid JSON string.
    """
    global SNOWFLAKE_SESSION

    try:
        # Run a validation query:
        
        test_sql = f"SELECT * FROM ({sql}) LIMIT 0"
        SNOWFLAKE_SESSION.sql(test_sql).collect()

        result = {
            "status": "valid",
            "message": " SQL is syntactically and semantically valid."
        }

    except Exception as e:
        result = {
            "status": "invalid",
            "error": str(e),
            "trace": traceback.format_exc(limit=2),
            "sql_snippet": sql[:300]
        }

    # For JSON output:
    
    try:
        return json.dumps(result, ensure_ascii=False)
    except Exception as e:
        
        # Fallback if serialization somehow fails:
        
        return json.dumps({
            "status": "error",
            "message": f"Serialization failed: {e}",
            "fallback_result": str(result)
        })


In [None]:
def ensure_schema(session, database_name: str, schema_name: str):
    """Checks if a schema exists and creates it if not."""
    schema_ref = f'"{database_name}"."{schema_name}"'
    try:
        session.sql(f'USE DATABASE "{database_name}"').collect()
        session.sql(f"DESC SCHEMA {schema_ref}").collect()
        print(f"Schema {schema_ref} already exists.")
    except SnowparkSQLException as e:
        if "does not exist" in str(e).lower():
            print(f"Schema {schema_ref} not found. Creating...")
            session.sql(f"CREATE SCHEMA IF NOT EXISTS {schema_ref}").collect()
            print(f"Schema {schema_ref} created successfully.")
        else:
            raise


ensure_schema(SNOWFLAKE_SESSION, TARGET_DB, L3_SCHEMA)

In [None]:
def parse_model_json(raw_text: str):
    # Remove markdown fences if present
    cleaned = re.sub(r"^```json|```$", "", raw_text.strip(), flags=re.MULTILINE).strip()
    return json.loads(cleaned)

In [None]:
TOOL_REGISTRY = {
    'get_approved_tables_schema': get_approved_tables_schema,
    'validate_sql': validate_sql
}

In [None]:
L2_TABLES_DESCRIPTION = get_tables_context(TARGET_DB, L2_SCHEMA)

In [None]:
L2_TABLES_DESCRIPTION

In [None]:
L3_Agent = LlmAgent(

    model=MODEL_NAME,

    name="L3_Agent",

    description="Assistant to create L3 tables in ELT process",

    instruction=f"""
    Follow the following steps:
    1) Gather requirement from user for logic of L3 table creation. Ask clarification questions if needed.
    2) See how requirement can be met from the AVAILABLE TABLES OF L2 given below:
    3) If confused, ask for clarification which tables and fields to use.
    4) For L3 table creation first discuss the table logic in natural language in a crisp way with the user.
    5) For tables to be used get the columns and their description from get_approved_tables_schema tool only. Do NOT suggest generic field names but the ones from the actual tables.
    6) If the user confirms the table logic in natural language then create sql code for it and give it in a different part in output..
    7) Once the user confirms the sql code validate it by giving it a run in snowflake using validate_sql tool.
    8) If the validation of sql fails then handle it accordingly. Always confirm the results of validation to user.
    9) Use the tools from the {TOOL_REGISTRY} only.
    10) There is a sql code editor on the right. So if some message in a part comes like `Manually updated latest SQL:` then it means that you have to take this updated code as it has been manually changed by the user in the editor.
    11) If there is some change requested by the user in the sql code then always generate the full updated sql and not just the edited part.
    12) If the user wants to write the sql into dbt as a model then ask them to press the button on the right.
    13) User has very less time. Always keep the output text concise but always be complete in information.

    Response Formatting Rules:
    
    1. When asking questions or confirming table logic → respond in plain text.
    2. When showing SQL code → wrap the SQL inside triple backticks, like:\n
    "```sql\nSELECT ...\n```\n"
    3. When confirming validation results → always write a short natural-language summary (e.g. : SQL is valid or  Validation failed).\n"
    4. While making a tool call, always add a short line in natural language explaining the act.

    Output Instructions:
    Give sql code in a different Part wrappend in ```sql (code) ``` .

    AVAILABLE TABLES OF L2:
    {L2_TABLES_DESCRIPTION}

    """,

    tools=[
       get_approved_tables_schema,
        validate_sql
    ],

    generate_content_config=types.GenerateContentConfig(
        temperature=0,
        max_output_tokens=8192
     
    ),

    # Define structured output schema :

    # output_schema=L3RequirementOutput,
    output_key="l3_requirement"
)

print(f"Agent '{L3_Agent.name}' created using model '{MODEL_NAME}'.")

In [None]:
session_service = InMemorySessionService()

In [None]:
runner_l3 = Runner(agent=L3_Agent, app_name=APP_NAME, session_service=session_service)

In [None]:
#  Chat Session Manager 

class ChatSession:
    def __init__(self):
        self.code_state = ""   # stores SQL code
        self.history = []     # stores chat history

    def update_code_state(self, code):
        self.code_state = code.strip()
        return f"SQL Code state updated ({len(code)} characters)"

chat_session = ChatSession()

In [None]:
# Session has to be created in the running cell as it is the part of the event loop and awaits on the agent.

await session_service.create_session(
    app_name=APP_NAME,
    user_id=USER_ID,
    session_id=SESSION_ID
)

async def chat_with_agent(message, history, session_id, current_code_state=None):
    parts = []
    parts.append(types.Part(text=message))

    if current_code_state.strip() != chat_session.code_state:
        parts.append(types.Part(text=f"Manually updated latest SQL:\n{current_code_state.strip()}"))
        chat_session.code_state = current_code_state.strip()

    next_message = types.Content(
        role="user",
        parts=parts
    )

    response_text = ""
    raw_response_text = ""
    sql_generated = None
    made_tool_call = False

    async for event in runner_l3.run_async(
        user_id=USER_ID,
        session_id=session_id,
        new_message=next_message
    ):
        # Agent text output:

        if getattr(event, "content", None) and getattr(event.content, "parts", None):
            
            for part in event.content.parts:
                if hasattr(part, "text") and part.text:
                    # Collect normal agent text
                    response_text += part.text.strip() + "\n"

                    # Capture SQL blocks if present:
                    
                    if '```sql' in part.text:
                        sql_generated = part.text.split('```sql')[1].split('```')[0].strip()

                elif hasattr(part, "function_call") and part.function_call:
                    made_tool_call = True
                    func_name = getattr(part.function_call, "name", None)
                    if func_name:
                        print(f" Tool called: {func_name}")
                    else:
                        print(" Tool call part received (no name yet)")
            
            
        

        # if getattr(event, "content", None) and getattr(event.content, "parts", None):
        #     text_parts = [
        #         p.text for p in event.content.parts
        #         if hasattr(p, "text") and p.text is not None
        #     ]
        #     for agent_text in text_parts:

        #       raw_response_text += agent_text
        #       if '``sql' in agent_text:
        #         sql_generated = agent_text.split('`sql')[1].split('``')[0]
        #       else:
        #         response_text += agent_text + "\n"
                    
                  
    # If response text is empty:   # change
    
    if not response_text.strip():
        if made_tool_call:
            response_text = "(Agent made a tool call — waiting for results...)"
        else:
            response_text = "(No response generated. Please try rephrasing or check your SQL editor.)"
        
    

    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": response_text.strip()})

    # Save generated SQL as base_code
    if sql_generated:
        chat_session.code_state = sql_generated.strip()

    return history, chat_session.code_state

async def write_to_dbt(current_code_state, history):
    """
    Writes the current SQL code to a dbt L3 model and commits it to GitHub.
    """

    # Ensure a message is returned even if SQL is empty:
    
    if not current_code_state or not current_code_state.strip():
        message = {"role": "assistant", "content": "Cannot write empty SQL."}
        history.append(message)
        return history

    #  Update chat session code state:
    
    chat_session.code_state = current_code_state.strip()

    raw_sql = chat_session.code_state

    prompt = f"""
    Convert the following SQL query into dbt-compatible jinja sql model syntax.

    Rules:
    - Replace FROM and JOIN table references with the dbt ref() function.
      Example: FROM project.dataset.table -> FROM {{{{ref('table')}}}}
    - Keep indentation and SQL formatting clean.
    - Do not change SQL code, keep it as is.
    - Return only SQL code (no explanations).

    SQL:
    {raw_sql}

    OUTPUT_FORMAT:
    {{
      "sqlx": ...(sqlx code here)...
    }}
    """

    print("Sending SQL to Gemini for transformation...")

    try:
        response = genai_client.models.generate_content(
        model="gemini-2.5-pro",
        contents=prompt,
        config=types.GenerateContentConfig(
            temperature=0,
            max_output_tokens=8192,
            response_mime_type="application/json",
        )
    )

        raw_sql_json = response.text
        parsed = parse_model_json(raw_sql_json)
        sql = parsed["sqlx"]
        # filename = parsed["filename"]

        try:
            formatted_sql = sqlparse.format(sql, reindent=True, keyword_case="upper")
            print(formatted_sql)
        except Exception:
            formatted_sql = sql

        transformed_sql = formatted_sql
        print("Gemini transformed SQL successfully.")

    except Exception as e:
        transformed_sql = raw_sql
        history.append({
            "role": "assistant",
            "content": f"Gemini transformation failed, using raw SQL. Error: {e}"
        })

    #  Attach dbt config block on top:
    
    sql_content = CONFIG_TEMPLATE.strip() + "\n\n" + transformed_sql
    sqlx_content = sql_content.strip()

    if sqlx_content[-1] == ';':
      sqlx_content = sqlx_content[:-1]
        
    # Define folder and filenames
    folder = "models/gold"
    os.makedirs(folder, exist_ok=True)
    file_name = "Balance_Sheet_Summary"
    sql_path = f"{folder}/{file_name}.sql"
    yaml_path = f"{folder}/{file_name}.yml"

    # Build YAML metadata (optional, basic structure)
    yaml_content = textwrap.dedent(f"""\
    version: 2

    models:
      - name: {file_name}
        description: "Auto-generated L3 dbt model from agent"
        columns: []
    """)

    #  GitHub commit details :
    
    github_repo = GITHUB_REPO
    github_token = GITHUB_TOKEN
    branch = BRANCH

    if not github_token:
        error_text = "GitHub token not found. Please set GITHUB_TOKEN environment variable."
        history.append({"role": "assistant", "content": error_text})
        return history

    headers = {"Authorization": f"token {github_token}"}
    api_base = f"https://api.github.com/repos/{github_repo}"

    #  Get the latest branch SHA:
    
    branch_info = requests.get(f"{api_base}/git/refs/heads/{branch}", headers=headers).json()
    base_sha = branch_info["object"]["sha"]

    #  Create a new Git tree:
    
    tree = [
        {
            "path": sql_path,
            "mode": "100644",
            "type": "blob",
            "content": sql_content,
        },
        {
            "path": yaml_path,
            "mode": "100644",
            "type": "blob",
            "content": yaml_content,
        }
    ]

    tree_resp = requests.post(
        f"{api_base}/git/trees",
        headers=headers,
        json={"base_tree": base_sha, "tree": tree},
    ).json()
    new_tree_sha = tree_resp["sha"]

    # Create a commit:
    
    commit_message = f"Auto-generated L3 model: {file_name}.sql"
    commit_resp = requests.post(
        f"{api_base}/git/commits",
        headers=headers,
        json={"message": commit_message, "tree": new_tree_sha, "parents": [base_sha]},
    ).json()
    new_commit_sha = commit_resp["sha"]

    #  Update the branch reference:
    
    requests.patch(
        f"{api_base}/git/refs/heads/{branch}",
        headers=headers,
        json={"sha": new_commit_sha},
    )

    response_text = f"Successfully wrote L3 model `{file_name}.sql` to dbt repo `{github_repo}` on branch `{branch}`."
    history.append({"role": "assistant", "content": response_text})
    return history
    
# Gradio UI:

with gr.Blocks(title="L3 Agent App") as demo:

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("# L3 Layer Agent")
            chatbot = gr.Chatbot(label="L3 Agent", type="messages")
            user_input = gr.Textbox(placeholder="Type your requirement...", label="Message")

        with gr.Column(scale=1):
            gr.Markdown("# SQL Generator (dbt)")
            sql_editor = gr.Code(language="sql", label="SQL Editor", value="", interactive=True)
            write_btn = gr.Button("Write to dbt Repo")

    #  Respond handler:
    
    async def respond(message, history, current_code_state):
        print(message, history, current_code_state)
        history, code_state = await chat_with_agent(message, history, SESSION_ID, current_code_state)
        return history, code_state, ""

    #  Gradio Event Bindings :
    
    user_input.submit(
        respond,
        [user_input, chatbot, sql_editor],
        [chatbot, sql_editor, user_input]
    )

    write_btn.click(
        fn=write_to_dbt,
        inputs=[sql_editor, chatbot],
        outputs=[chatbot]
    )



### Execution:

In [None]:
#  Launch:

demo.launch(share=True, debug=True)