### Snowflake Imports:

In [37]:
# --- Snowflake & Snowpark --
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark.exceptions import SnowparkSQLException
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col


### General Imports:

In [None]:
# --- Core Python Libs ---
import json
import re
import traceback
import os
import textwrap
import base64
from datetime import datetime
from pathlib import Path
import threading

# --- Third-Party Libs ---
import pandas as pd
import yaml
import requests
import sqlparse


### UI Imports:

In [None]:
# --- Interactive UI Widgets ---
import ipywidgets as widgets
from IPython.display import display, clear_output

### Global Variables:

In [38]:
# Snowflake database holding both layers: STG_SI contains the L1 views that are read,
# DWH_SI is the schema the generated L2 dbt models are configured to build into.
TARGET_DB = "DATAPLATR_DEMO"
L1_SCHEMA = "STG_SI"
L2_SCHEMA = "DWH_SI"

# Default Cortex model for AI_COMPLETE key suggestions (alternative candidate: "snowflake-arctic").
MODEL = "llama3.1-70b"


### GitHub Configuration:

In [None]:

# Credentials must never be hard-coded in a notebook. The personal access token that used
# to be written here was stored in plain text and should be revoked on GitHub.
# Provide the token through the GITHUB_TOKEN environment variable instead.
git_repo_url = "https://github.com/poojapapney/snf_elt.git"
GITHUB_REPO = "poojapapney/snf_elt"
BRANCH = "PP-Test"
GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN", "")
if not GITHUB_TOKEN:
    print("GITHUB_TOKEN environment variable is not set; the GitHub commit step will fail until it is provided.")

### Snowflake Connection Setup:

In [39]:
# Module-level cache for the Snowpark session shared by the rest of the notebook.

SNOWFLAKE_SESSION = None


def init_snowflake_session(target_db: str, target_schema: str, config_path: str = "configs.json"):
    """
    Initialize (or reuse) the notebook-wide Snowpark session and set its database/schema.

    Resolution order:
      1. A session already cached in ``SNOWFLAKE_SESSION`` is reused.
      2. Otherwise the active Snowflake Notebook session (``get_active_session``) is used.
      3. If that fails, a new session is created from the JSON connection parameters
         stored in ``config_path`` (local development).

    The database/schema context is applied on every call, including when the cached
    session is reused, so callers always get the context they asked for.

    Args:
        target_db: Database to make current.
        target_schema: Schema to make current.
        config_path: JSON file with Snowpark connection parameters; only read when no
            active session is available. Defaults to "configs.json".

    Returns:
        The cached Snowpark session.

    Raises:
        Errors from opening/decoding ``config_path`` (e.g. FileNotFoundError,
        json.JSONDecodeError) or from creating the session propagate to the caller.
    """
    global SNOWFLAKE_SESSION

    if SNOWFLAKE_SESSION is not None:
        print(" Reusing existing Snowflake session.")
    else:
        try:
            # Inside a Snowflake Notebook an authenticated session already exists.
            SNOWFLAKE_SESSION = get_active_session()
            print("  Successfully connected to active Snowflake session.")
        except Exception:  # no active session -> build one for local development
            print("  No active session found. Creating a new session for local development...")

            with open(config_path, "r") as f:
                connection_parameters = json.load(f)

            SNOWFLAKE_SESSION = Session.builder.configs(connection_parameters).create()
            print("  Successfully created a new local Snowflake session.")

    # Applied on the reuse path too: previously a reused session kept whatever context it
    # had, silently ignoring target_db / target_schema.
    SNOWFLAKE_SESSION.use_database(target_db)
    SNOWFLAKE_SESSION.use_schema(target_schema)

    print(f" Session active with DB: {target_db}, Schema: {target_schema}")
    return SNOWFLAKE_SESSION


In [40]:
# Connect (or reuse the cached session) with the L1 staging schema as the current context.
SNOWFLAKE_SESSION = init_snowflake_session(target_db=TARGET_DB, target_schema=L1_SCHEMA)

  Successfully connected to active Snowflake session.
 Session active with DB: DATAPLATR_DEMO, Schema: STG_SI


### Helper Functions

In [41]:
def ensure_schema(session, database_name: str, schema_name: str):
    """
    Make sure ``database_name.schema_name`` exists, creating it if it does not.

    The schema is addressed through a fully-qualified, double-quoted reference, so no
    ``USE DATABASE`` is issued and the session's current database/schema are left as they
    were (the previous version changed the session context as a hidden side effect).

    Args:
        session: Snowpark session used to run the statements.
        database_name: Database owning the schema; used verbatim inside double quotes.
        schema_name: Schema to check/create; used verbatim inside double quotes.

    Raises:
        SnowparkSQLException: for any failure whose message does not say the object
            "does not exist" (e.g. missing privileges), and for failures of the CREATE.
    """

    def quote_ident(name: str) -> str:
        # Double-quoted identifier; embedded double quotes are escaped by doubling them.
        return '"' + name.replace('"', '""') + '"'

    schema_ref = f"{quote_ident(database_name)}.{quote_ident(schema_name)}"
    try:
        session.sql(f"DESC SCHEMA {schema_ref}").collect()
        print(f"Schema {schema_ref} already exists.")
    except SnowparkSQLException as e:
        if "does not exist" not in str(e).lower():
            raise
        print(f"Schema {schema_ref} not found. Creating...")
        session.sql(f"CREATE SCHEMA IF NOT EXISTS {schema_ref}").collect()
        print(f"Schema {schema_ref} created successfully.")

def parse_model_json(raw_text: str):
    """
    Decode JSON returned by an LLM, tolerating markdown fences and surrounding prose.

    Fences such as ```json / ```JSON / bare ``` at the start or end of a line are removed
    (case-insensitively). If the remaining text is still not valid JSON, the outermost
    ``{...}`` span is decoded instead, because models sometimes add a sentence around the
    payload.

    Args:
        raw_text: Raw model output.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: if no valid JSON can be extracted.
    """
    cleaned = re.sub(
        r"^```(?:json)?|```$", "", raw_text.strip(), flags=re.MULTILINE | re.IGNORECASE
    ).strip()
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        start, end = cleaned.find("{"), cleaned.rfind("}")
        if start == -1 or end <= start:
            raise
        return json.loads(cleaned[start:end + 1])
    

def commit_files_to_github(repo: str, branch: str, token: str, files: list, commit_message: str,
                           timeout: float = 30) -> bool:
    """
    Push several files to an existing GitHub branch as a single commit (Git Data API).

    Steps: read the branch ref -> read that commit to get its tree -> create a new tree
    with ``files`` layered on the base tree -> create a commit whose parent is the current
    head -> move the branch ref to the new commit.

    Args:
        repo: Repository as "owner/name".
        branch: Branch to update (must already exist).
        token: GitHub token sent in the Authorization header.
        files: Items {"path": str, "content": str}; each becomes a regular-file (100644) blob.
        commit_message: Message of the new commit.
        timeout: Per-request timeout in seconds, so a stalled HTTP call cannot hang the notebook.

    Returns:
        True when the branch was updated; False when ``files`` is empty (no commit is
        created) or when any step failed. Failures are printed, including GitHub's response
        body when available, rather than raised.
    """
    if not files:
        print("\nNo files to commit; skipping GitHub push.")
        return False

    print("\n🚀 Starting batch commit to GitHub...")
    headers = {"Authorization": f"token {token}", "Accept": "application/vnd.github+json"}
    api_base = f"https://api.github.com/repos/{repo}"
    branch_ref_url = f"{api_base}/git/refs/heads/{branch}"

    try:
        # 1. Current head of the branch and the tree it points at.
        ref_resp = requests.get(branch_ref_url, headers=headers, timeout=timeout)
        ref_resp.raise_for_status()
        commit_sha = ref_resp.json()["object"]["sha"]

        commit_obj = requests.get(f"{api_base}/git/commits/{commit_sha}", headers=headers, timeout=timeout)
        commit_obj.raise_for_status()
        base_tree_sha = commit_obj.json()["tree"]["sha"]

        # 2. New tree = base tree + generated files (content inlined as blobs).
        tree = [{"path": f["path"], "mode": "100644", "type": "blob", "content": f["content"]} for f in files]
        tree_resp = requests.post(
            f"{api_base}/git/trees",
            headers=headers,
            json={"base_tree": base_tree_sha, "tree": tree},
            timeout=timeout,
        )
        tree_resp.raise_for_status()
        new_tree_sha = tree_resp.json()["sha"]

        # 3. Commit on top of the current head, then point the branch at it.
        commit_resp = requests.post(
            f"{api_base}/git/commits",
            headers=headers,
            json={"message": commit_message, "tree": new_tree_sha, "parents": [commit_sha]},
            timeout=timeout,
        )
        commit_resp.raise_for_status()
        new_commit_sha = commit_resp.json()["sha"]

        patch_resp = requests.patch(branch_ref_url, headers=headers, json={"sha": new_commit_sha}, timeout=timeout)
        patch_resp.raise_for_status()

        print(f"Successfully pushed {len(files)} files to {repo}@{branch}")
        return True
    except Exception as e:
        # Best effort: report instead of raising so the notebook run can finish.
        print(f"GitHub batch commit failed: {e}")
        response = getattr(e, "response", None)
        if response is not None:
            print(f"    Response body: {response.text}")
        return False

### Silver L2 Agent: LLM Suggestions & User Interface
These functions ask Snowflake Cortex (`AI_COMPLETE`) to suggest the keys for each L2 model, then display an interactive UI where the user confirms or edits them.

In [42]:
def suggest_l2_keys_snowflake(session: Session, l1_metadata: dict, model_name: str = MODEL) -> dict:
    """
    Ask Snowflake Cortex ``AI_COMPLETE`` to suggest keys for an L2 incremental (SCD Type 1) model.

    Args:
        session: Snowpark session used to run the ``AI_COMPLETE`` query.
        l1_metadata: Metadata of the source L1 view; its "columns" entry (e.g. a list of
            {"name": ..., "description": ...} dicts) is serialized into the prompt.
        model_name: Cortex model passed as the first ``AI_COMPLETE`` argument.

    Returns:
        The suggestions as a dict (expected keys: "unique_key", "non_null_fields",
        "incremental_key", "delete_flag"), or ``{}`` when the response is not a JSON object.
        Errors raised by the SQL call itself are not caught here.
    """

    def sql_string_literal(text: str) -> str:
        # Snowflake single-quoted string constants treat "\" as an escape character and
        # need "'" doubled. Escaping both, exactly once, makes the model receive the text
        # exactly as built here.
        return "'" + text.replace("\\", "\\\\").replace("'", "''") + "'"

    print("Calling Snowflake Cortex to suggest keys...")
    columns_json = json.dumps(l1_metadata.get("columns", []), indent=2)

    # The prompt is escaped once, as a whole, when embedded in the SQL below. (Escaping
    # columns_json separately as well used to show doubled quotes to the model.)
    prompt = f"""
    You are an expert dbt data engineer designing an incremental model (SCD Type 1).
    Based on the following columns from a source view, please suggest the best keys.

    Columns:
    {columns_json}

    Your Task is to suggest:
    1.  **unique_key**: A list of one or more columns that uniquely identify a single record (the primary key). This is mandatory.
    2.  **non_null_fields**: A list of columns that should never be null (e.g., important business fields).
    3.  **incremental_key**: The single best column for incremental loading. This must be a timestamp representing the last modification date.
    4.  **delete_flag**: A boolean or status column that indicates if a record has been soft-deleted (e.g., 'IS_DELETED').

    Respond ONLY in valid JSON format like this, with no extra text or markdown:
    {{
      "unique_key": ["COLUMN_NAME_1"],
      "non_null_fields": ["COLUMN_NAME_1", "ANOTHER_IMPORTANT_COLUMN"],
      "incremental_key": "LAST_MODIFIED_DATE",
      "delete_flag": "IS_DELETED"
    }}
    """

    query = f"""
    SELECT AI_COMPLETE({sql_string_literal(model_name)}, {sql_string_literal(prompt)},
         OBJECT_CONSTRUCT('temperature', 0))
    """

    result = ""
    try:
        result = session.sql(query).collect()[0][0]

        # Strip markdown fences the model may add despite the instructions.
        suggestions = parse_model_json(result) if isinstance(result, str) else result

        # The payload can also arrive as a JSON document encoded inside a JSON string.
        if isinstance(suggestions, str):
            suggestions = parse_model_json(suggestions)

        if not isinstance(suggestions, dict):
            print(f" Cortex did not return a JSON object.\nRaw Response: {result}")
            return {}

        print("Cortex suggestions received.")
        return suggestions

    except (json.JSONDecodeError, IndexError) as e:
        print(f" Cortex did not return valid JSON. Error: {e}\nRaw Response: {result}")
        return {}  # empty suggestions -> the user picks keys manually

def interactive_key_selection_ui(columns: list, suggestions: dict, table_name: str, callback):
    """
    Display a form for confirming or editing the LLM-suggested keys of one table.

    Args:
        columns: Column names offered by every selector.
        suggestions: Suggested defaults keyed by "unique_key", "non_null_fields",
            "incremental_key" and "delete_flag"; may be empty. A single name is accepted
            where a list is expected, and names not in ``columns`` are ignored.
        table_name: Shown in the form title (HTML-escaped).
        callback: Invoked with the confirmed selections
            {"unique_key": [...], "non_null_fields": [...],
             "incremental_key": str | None, "delete_flag": str | None}.

    The submit button disables itself on the first click, so ``callback`` runs at most once
    per form. A repeated click would otherwise call it again, and the orchestrator would
    advance past a table the user never saw.
    """
    import html  # local import: only needed to escape the title text

    columns = list(columns)
    suggestions = suggestions or {}

    def suggested_list(key):
        # Normalize a single name / None to a list, keeping only real columns.
        value = suggestions.get(key) or []
        if isinstance(value, str):
            value = [value]
        return [v for v in value if v in columns]

    def suggested_one(key, default):
        value = suggestions.get(key)
        return value if isinstance(value, str) and value in columns else default

    title_widget = widgets.HTML(f"<h3>Select Keys for: {html.escape(str(table_name))}</h3>")

    unique_keys_widget = widgets.SelectMultiple(
        options=columns, description="Unique Keys:",
        value=suggested_list("unique_key"),
        layout=widgets.Layout(width="90%", height="100px")
    )

    non_null_fields_widget = widgets.SelectMultiple(
        options=columns, description="Non-Null Fields:",
        value=suggested_list("non_null_fields"),
        layout=widgets.Layout(width="90%", height="100px")
    )

    incremental_key_widget = widgets.Dropdown(
        options=columns, description="Incremental Key:",
        value=suggested_one("incremental_key", None),
        layout=widgets.Layout(width="90%")
    )

    # "None" is a sentinel option meaning "no soft-delete column".
    delete_flag_widget = widgets.Dropdown(
        options=["None"] + columns, description="Delete Flag:",
        value=suggested_one("delete_flag", "None"),
        layout=widgets.Layout(width="90%")
    )

    submit_btn = widgets.Button(description="Confirm and Process Next", button_style='success', layout=widgets.Layout(width="50%"))

    ui = widgets.VBox([
        title_widget,
        unique_keys_widget,
        non_null_fields_widget,
        incremental_key_widget,
        delete_flag_widget,
        submit_btn
    ])

    display(ui)

    def on_submit(button):
        if button.disabled:  # already submitted; ignore queued duplicate clicks
            return
        button.disabled = True
        callback({
            "unique_key": list(unique_keys_widget.value),
            "non_null_fields": list(non_null_fields_widget.value),
            "incremental_key": incremental_key_widget.value,
            "delete_flag": delete_flag_widget.value if delete_flag_widget.value != "None" else None
        })

    submit_btn.on_click(on_submit)

In [43]:
def generate_l2_dbt_artifacts(l1_view_name: str, l2_table_name: str, selections: dict, metadata: dict,
                              l2_schema: str = None) -> dict:
    """
    Build the dbt model (.sql) and properties file (.yml) for an L2 incremental model.

    The model is an incremental ``merge`` over ``ref(l1_view_name)`` keyed on the selected
    unique key (SCD Type 1). On incremental runs only rows whose incremental key exceeds the
    current maximum in the target are read, and they are deduplicated to the latest row per
    key. When a delete flag is selected, a post-hook deletes flagged rows from the target.

    Args:
        l1_view_name: dbt model name of the L1 source view (used in ``ref()``).
        l2_table_name: L2 model name; also the .sql/.yml file name.
        selections: "unique_key" (list, required), "incremental_key" (required), and
            optional "delete_flag" / "non_null_fields".
        metadata: Optional "table_description" and "columns"
            (list of {"name": ..., "description": ...}).
        l2_schema: Schema written into ``config(schema=...)``. Defaults to the notebook-level
            ``L2_SCHEMA``.

    Returns:
        {"sql_path", "sql_content", "yaml_path", "yaml_content"}.

    Raises:
        ValueError: if the unique key or the incremental key is missing.
    """
    if l2_schema is None:
        l2_schema = L2_SCHEMA

    # === Extract selections ===
    unique_key = list(selections.get('unique_key') or [])
    incremental_key = selections.get('incremental_key')
    delete_flag = selections.get('delete_flag')
    non_null_fields = selections.get('non_null_fields') or []

    if not unique_key or not incremental_key:
        raise ValueError("Unique Key and Incremental Key are mandatory selections.")

    # === 1. LOGIC FOR SOFT DELETES ===
    post_hook_config = ""
    if delete_flag:
        post_hook_config = f"""
        , post_hook=[
            "DELETE FROM {{{{ this }}}} WHERE {delete_flag} = TRUE"
        ]
        """

    # === 2. LOGIC FOR SOURCE DEDUPLICATION & INCREMENTAL LOAD ===
    partition_by_str = ", ".join(unique_key)

    # --- Generate SQL Content ---
    sql_content = textwrap.dedent(f"""
    {{{{
        config(
            materialized='incremental',
            incremental_strategy='merge',
            merge_exclude_columns = ['DW_INSERT_DATE'],
            schema='{l2_schema}',
            unique_key={unique_key}{post_hook_config}
        )
    }}}}

    -- This model incrementally builds the L2 table (SCD Type 1)
    -- It merges new records and updates existing ones based on the unique key.

    WITH filtered_source AS (
        SELECT
            *, {{{{ generate_audit_column('DW_INSERT_DATE') }}}}
             ,{{{{ generate_audit_column('DW_UPDATE_DATE') }}}}
        FROM {{{{ ref('{l1_view_name}') }}}}

        {{% if is_incremental() %}}
        -- This filter is applied *before* deduplication, just like in GCP
        WHERE {incremental_key} > (SELECT MAX({incremental_key}) FROM {{{{ this }}}})
        {{% endif %}}
    )

    SELECT
        *
    FROM filtered_source

    -- Deduplicate *after* filtering for new records
    QUALIFY ROW_NUMBER() OVER (
      PARTITION BY {partition_by_str}
      ORDER BY {incremental_key} DESC
    ) = 1
    """)

    # --- Generate YAML Content ---

    def yaml_str(value) -> str:
        # JSON string syntax is valid YAML double-quoted scalar syntax: quotes, backslashes
        # and newlines get escaped, so free-text comments cannot break the document.
        # Non-strings (None, or NaN from pandas) become "".
        text = value if isinstance(value, str) else ""
        return json.dumps(text, ensure_ascii=False)

    # Per-column `unique` only makes sense for a single-column key; a composite key is
    # unique only as a combination and is tested at model level below.
    is_composite_key = len(unique_key) > 1
    unique_key_set = set(unique_key)
    non_null_set = set(non_null_fields) | unique_key_set

    column_yaml = []
    for column in metadata.get("columns", []):
        col_name = column['name']
        entry = f"      - name: {col_name}\n        description: {yaml_str(column.get('description'))}"

        tests_to_add = []
        if col_name in unique_key_set and not is_composite_key:
            tests_to_add.append("unique")
        if col_name in non_null_set:
            tests_to_add.append("not_null")

        if tests_to_add:
            entry += "\n        tests:"
            for test_name in tests_to_add:
                entry += f"\n          - {test_name}"

        column_yaml.append(entry)

    yaml_lines = [
        "",
        "version: 2",
        "",
        "models:",
        f"  - name: {l2_table_name}",
        f"    description: {yaml_str(metadata.get('table_description'))}",
    ]
    if is_composite_key:
        # dbt's built-in `unique` test accepts a SQL expression as column_name.
        key_expression = "(" + " || '-' || ".join(unique_key) + ")"
        yaml_lines += [
            "    tests:",
            "      - unique:",
            f"          column_name: {yaml_str(key_expression)}",
        ]
    yaml_lines.append("    columns:")
    yaml_content = "\n".join(yaml_lines) + "\n" + "\n".join(column_yaml)

    return {
        "sql_path": f"models/silver/L2/{l2_table_name}.sql",
        "sql_content": sql_content,
        "yaml_path": f"models/silver/L2/{l2_table_name}.yml",
        "yaml_content": yaml_content
    }

### Main Silver L2 Pipeline:

In [44]:
class L2PipelineOrchestrator:
    """
    Interactive driver that turns L1 staging views into L2 dbt incremental models.

    ``start()`` ensures the L2 schema exists and lists the L1 views named
    ``V_STG_<NAME>_L1`` in ``<db>.<l1_schema>``. For each view, ``_prepare_next_table``
    reads its metadata, asks Cortex for key suggestions and shows the key-selection form.
    The form's callback, ``_handle_ui_submission``, generates the dbt .sql/.yml files for
    ``<NAME>_L2`` and moves on. After the last view, all generated files are committed to
    GitHub in one batch. Log output goes into a single ``widgets.Output`` area.
    """

    def __init__(self, session, db, l1_schema, l2_schema, github_repo, branch, token):
        self.session = session
        self.db = db
        self.l1_schema = l1_schema
        # NOTE(review): only ensure_schema uses this value; the generated dbt config reads
        # the module-level L2_SCHEMA, so keep the two equal.
        self.l2_schema = l2_schema
        self.github_repo = github_repo
        self.branch = branch
        self.token = token
        self.all_generated_files = []   # [{"path": ..., "content": ...}] committed at the end
        self.l1_view_names = []
        self.current_index = 0          # position in l1_view_names of the view being processed
        self.current_l1_metadata = {}   # metadata of that view, reused when generating artifacts
        self.output_area = widgets.Output()

    @staticmethod
    def _l2_table_name(l1_view_name: str) -> str:
        """Map ``V_STG_<NAME>_L1`` to ``<NAME>_L2``, touching only the prefix and suffix."""
        prefix, suffix = "V_STG_", "_L1"
        name = l1_view_name[len(prefix):] if l1_view_name.startswith(prefix) else l1_view_name
        return name[:-len(suffix)] + "_L2" if name.endswith(suffix) else name

    def start(self):
        """Show the log area, collect the L1 views to process, and present the first one."""
        display(self.output_area)

        # Reset run state so start() can be called again on the same instance.
        self.l1_view_names = []
        self.current_index = 0
        self.all_generated_files = []
        self.current_l1_metadata = {}

        with self.output_area:
            ensure_schema(self.session, self.db, self.l2_schema)

            print(f"\nFetching L1 views from schema: {self.db}.{self.l1_schema}...")
            # '_' is a single-character wildcard in LIKE; escape it so only names that
            # literally follow V_STG_<name>_L1 match.
            l1_views_df = self.session.sql(f"""
                SELECT TABLE_NAME
                FROM {self.db}.INFORMATION_SCHEMA.TABLES
                WHERE TABLE_SCHEMA = '{self.l1_schema}'
                  AND TABLE_TYPE = 'VIEW'
                  AND TABLE_NAME LIKE 'V!_STG!_%!_L1' ESCAPE '!'
                ORDER BY TABLE_NAME
            """).to_pandas()

            if l1_views_df.empty:
                print(f" No L1 views found in {self.db}.{self.l1_schema}. Please run the L1 pipeline first.")
                return

            self.l1_view_names = l1_views_df['TABLE_NAME'].tolist()
            print(f"Found {len(self.l1_view_names)} L1 views to process: {self.l1_view_names}")

        self._prepare_next_table()

    def _fetch_l1_metadata(self, l1_view_name: str) -> dict:
        """Read a view's comment and its column names/comments from INFORMATION_SCHEMA."""
        cols_df = self.session.sql(f"""
            SELECT COLUMN_NAME, COMMENT
            FROM {self.db}.INFORMATION_SCHEMA.COLUMNS
            WHERE TABLE_SCHEMA = '{self.l1_schema}' AND TABLE_NAME = '{l1_view_name}'
            ORDER BY ORDINAL_POSITION
        """).to_pandas()

        table_comment_df = self.session.sql(f"""
            SELECT COMMENT
            FROM {self.db}.INFORMATION_SCHEMA.TABLES
            WHERE TABLE_SCHEMA = '{self.l1_schema}' AND TABLE_NAME = '{l1_view_name}'
        """).to_pandas()

        return {
            "table_description": table_comment_df['COMMENT'].iloc[0] if not table_comment_df.empty else "",
            "columns": [{"name": row['COLUMN_NAME'], "description": row['COMMENT']} for _, row in cols_df.iterrows()]
        }

    def _prepare_next_table(self):
        """
        Show the key-selection form for the next L1 view that can be prepared.

        Views whose metadata or suggestion step fails are logged and skipped in a plain loop
        on the calling thread. Previously a ``threading.Timer`` re-entered this method from a
        background thread, where ipywidgets' Output capture is not reliable. The skip messages
        stay visible above the next form. When no views remain, the generated files are
        committed to GitHub.
        """
        with self.output_area:
            clear_output(wait=True)

        while self.current_index < len(self.l1_view_names):
            l1_view_name = self.l1_view_names[self.current_index]
            with self.output_area:
                print(f"\n{'='*20}\nProcessing L1 View ({self.current_index + 1}/{len(self.l1_view_names)}): {l1_view_name}\n{'='*20}")
                try:
                    # 1. Fetch L1 metadata from Snowflake
                    self.current_l1_metadata = self._fetch_l1_metadata(l1_view_name)
                    column_names = [c['name'] for c in self.current_l1_metadata['columns']]

                    # 2. Get LLM suggestions using Snowflake Cortex
                    suggestions = suggest_l2_keys_snowflake(self.session, self.current_l1_metadata)

                    # 3. Display the form; processing resumes in _handle_ui_submission.
                    print(f"Waiting for your input for {l1_view_name}...")
                    interactive_key_selection_ui(column_names, suggestions, l1_view_name, self._handle_ui_submission)
                    return
                except Exception as e:
                    print(f" Error preparing table {l1_view_name}: {e}\n{traceback.format_exc()}")
                    print("Skipping this table and moving to the next one.")
            self.current_index += 1

        with self.output_area:
            print("\n🎉 All tables have been processed locally.")
        self._commit_to_github()

    def _handle_ui_submission(self, selections):
        """Form callback: generate the dbt files for the current view, then move on to the next."""
        if self.current_index >= len(self.l1_view_names):
            # Defensive: a late submission after the run already finished.
            with self.output_area:
                print("Warning: Submission received after processing should have finished.")
            return

        l1_view_name = self.l1_view_names[self.current_index]
        l2_table_name = self._l2_table_name(l1_view_name)

        with self.output_area:
            print(f" Selections confirmed for {l2_table_name}: {selections}")

            try:
                artifacts = generate_l2_dbt_artifacts(l1_view_name, l2_table_name, selections, self.current_l1_metadata)
                self.all_generated_files.append({"path": artifacts["sql_path"], "content": artifacts["sql_content"]})
                self.all_generated_files.append({"path": artifacts["yaml_path"], "content": artifacts["yaml_content"]})
                print(f"   Generated {artifacts['sql_path']}")
                print(f"   Generated {artifacts['yaml_path']}")
            except Exception as e:
                # The table is skipped rather than re-prompted, e.g. when the unique or
                # incremental key was left empty (generate_l2_dbt_artifacts raises ValueError).
                print(f" Error generating artifacts for {l2_table_name}: {e}\n{traceback.format_exc()}")

        self.current_index += 1
        self._prepare_next_table()

    def _commit_to_github(self):
        """Commit every generated file to GitHub in one batch, logging into the output area."""
        with self.output_area:
            if self.all_generated_files:
                print(f"\nAttempting to commit {len(self.all_generated_files)} files to GitHub...")
                commit_message = f"feat: Add/Update L2 incremental models for schema {self.l1_schema}"
                try:
                    commit_files_to_github(self.github_repo, self.branch, self.token, self.all_generated_files, commit_message)
                    print("GitHub commit process finished.")
                except Exception as e:
                    print(f" An error occurred during the GitHub commit process: {e}")
                    print(traceback.format_exc())
            else:
                print("\nNo files were generated or an error occurred. Nothing to commit.")

### Execution Workflow:

In [45]:
# Create an instance of the orchestrator with configuration
orchestrator = L2PipelineOrchestrator(
    session=SNOWFLAKE_SESSION,
    db=TARGET_DB,
    l1_schema=L1_SCHEMA,
    l2_schema=L2_SCHEMA,
    github_repo=GITHUB_REPO,
    branch=BRANCH,
    token=GITHUB_TOKEN
)

# Start the pipeline
orchestrator.start()

Output()