# 1. Identifying Column Lineage, Joins and Calculated Fields:

In [0]:
import yaml
import pandas as pd
import re
import os
import json
from pathlib import Path
from typing import List, Dict, Any, Optional
import pyspark.sql.functions as F

# --- Configuration ---
# Catalog and schema used below to build the fully qualified output table names.
CATALOG = "dbx_migration_poc"
SCHEMA = "dbx_migration_ts"

# Notebook widget "tml_file" (default: empty string); its value is the liveboard
# TML file name that gets appended to the liveboard folder path further down.
dbutils.widgets.text("tml_file", "")
tml_file = dbutils.widgets.get("tml_file")

# -------------------------------------------------------------------------
# CORE EXTRACTOR CLASS (Auto-Discovery Version)
# -------------------------------------------------------------------------
class TMLAnalyzer:
    """Analyze a ThoughtSpot Liveboard TML file together with its model and tables.

    On construction the liveboard YAML is loaded, the model (or worksheet)
    named by the first visualization is located on disk, and every table listed
    in that model is loaded.  The ``generate_*`` / ``get_*`` methods then turn
    this metadata into pandas DataFrames (lineage, joins, formulas, filters).
    """

    # Placeholder triple returned when a column cannot be traced to a table.
    _UNRESOLVED = ("UNKNOWN", "UNKNOWN", "UNKNOWN")

    # ``[TABLE::COLUMN]`` reference inside a TML expression or join condition.
    _QUALIFIED_REF = re.compile(r'\[([^\]]+)::([^\]]+)\]')

    def __init__(self, liveboard_path: str, tml_base_path: str):
        """
        Initializes the analyzer by reading the Liveboard file and
        automatically discovering related Model and Table files.

        Args:
            liveboard_path: Path of the liveboard ``.tml`` file.
            tml_base_path: Folder expected to contain ``model/``, ``worksheet/``
                and ``table/`` sub-folders.

        Raises:
            ValueError: The first visualization does not name a model.
            FileNotFoundError: No model/worksheet file exists for that name.
        """
        self.liveboard_path = liveboard_path
        self.tml_base_path = tml_base_path

        print(f"1. Loading Liveboard: {Path(liveboard_path).name}")
        self.liveboard_tm = self._load_yaml(liveboard_path)

        # --- Auto-Discover Model ---
        # Only the first visualization's answer block is inspected for the
        # primary model/worksheet name.
        visualizations = self.liveboard_tm.get('liveboard', {}).get('visualizations', [])
        model_name = None
        if visualizations:
            first_viz_tables = visualizations[0].get('answer', {}).get('tables', [])
            if first_viz_tables:
                model_name = first_viz_tables[0].get('name')

        if not model_name:
            raise ValueError(f"Could not identify a Model name from the first visualization in {Path(liveboard_path).name}")

        print(f"   > Found Base Model Reference: {model_name}")

        # --- Load Model File ---
        # Lookup order: <base>/model, <base>/worksheet, then next to the liveboard.
        liveboard_dir = os.path.dirname(liveboard_path)
        model_candidates = [
            f"{tml_base_path}/model/{model_name}.model.tml",
            f"{tml_base_path}/worksheet/{model_name}.worksheet.tml",
            f"{liveboard_dir}/{model_name}.model.tml",
        ]
        model_file_path = next((p for p in model_candidates if os.path.exists(p)), None)
        if model_file_path is None:
            raise FileNotFoundError(f"Model file not found at: {model_candidates[0]}")

        print(f"2. Loading Model: {Path(model_file_path).name}")
        self.model_tm = self._load_yaml(model_file_path)

        # --- Auto-Discover Tables ---
        self.tables_tm = {}
        table_names = self._get_all_model_tables(self.model_tm)

        print(f"3. Discovering {len(table_names)} Tables referenced in Model...")
        for t_name in table_names:
            t_path = f"{tml_base_path}/table/{t_name}.table.tml"

            # Fall back to a table file stored beside the liveboard.
            if not os.path.exists(t_path):
                t_path_flat = f"{liveboard_dir}/{t_name}.table.tml"
                if os.path.exists(t_path_flat):
                    t_path = t_path_flat

            if os.path.exists(t_path):
                data = self._load_yaml(t_path)
                self.tables_tm[t_name] = data.get('table', {})
                print(f"   > Loaded Table: {t_name}")
            else:
                # A missing table is not fatal; its columns resolve to defaults later.
                print(f"   ⚠️ Warning: Table file not found for '{t_name}' at {t_path}")

        # Pre-process Model Columns for fast lookup
        self.model_col_map = self._build_model_column_map()

    def _load_yaml(self, path: str) -> Dict[str, Any]:
        """Parse the YAML file at *path*; an empty document yields ``{}``."""
        with open(path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty file; callers expect a dict.
            return yaml.safe_load(f) or {}

    def _get_all_model_tables(self, model_data: Dict[str, Any]) -> List[str]:
        """Return the names listed under ``model_tables`` of a model/worksheet."""
        root = model_data.get('model') or model_data.get('worksheet') or {}
        return [t['name'] for t in root.get('model_tables', []) if 'name' in t]

    def _build_model_column_map(self) -> Dict[str, Dict]:
        """Index model columns and formulas by name (formulas also by id).

        Each entry has ``type`` ('DIRECT' or 'FORMULA'), ``aggregation`` and
        ``full_physical_id``; DIRECT entries add ``table_name`` and
        ``physical_col_id``, FORMULA entries add ``expr``.
        """
        col_map = {}
        root = self.model_tm.get('model') or self.model_tm.get('worksheet') or {}

        for col in root.get('columns', []):
            name = col.get('name')
            if name is None:
                # Nothing to key the entry on; skip instead of raising KeyError.
                continue
            agg = (col.get('properties') or {}).get('aggregation')

            if 'column_id' in col:
                parts = col['column_id'].split('::')
                # Only "TABLE::COLUMN" ids are mapped; other shapes are ignored.
                if len(parts) == 2:
                    col_map[name] = {
                        'table_name': parts[0],
                        'physical_col_id': parts[1],
                        'full_physical_id': col['column_id'],
                        'type': 'DIRECT',
                        'aggregation': agg
                    }
            elif 'formula_id' in col:
                col_map[name] = {
                    'type': 'FORMULA',
                    'full_physical_id': col['formula_id'],
                    'expr': '',
                    'aggregation': agg
                }

        for form in root.get('formulas', []):
            name = form.get('name')
            expr = form.get('expr', '')
            # Nameless formulas must not share a single ``None`` key, otherwise a
            # later one would overwrite the expression of an earlier one.
            entry = col_map.get(name) if name is not None else None
            if entry is not None:
                entry['expr'] = expr
            else:
                entry = {'type': 'FORMULA', 'expr': expr, 'aggregation': None, 'full_physical_id': form.get('id')}
                if name is not None:
                    col_map[name] = entry
            if 'id' in form:
                # Alias: the formula id points at the same entry object.
                col_map[form['id']] = entry

        return col_map

    def _clean_col_name(self, name: str) -> str:
        """Strip aggregation prefixes ("Total ") and wrappers ("Sum(...)") from *name*."""
        name = re.sub(r'^(Total |Maximum |Minimum |Average |Unique Number of )\s*', '', name, flags=re.IGNORECASE)
        wrapper_keywords = r'Sum|Count|Avg|Min|Max|Unique Count|Monthly|Daily|Weekly|Quarterly|Yearly|Week|Month|Quarter|Year|Day'
        wrapper_pattern = re.compile(r'^(' + wrapper_keywords + r')\s*\((.*)\)$', flags=re.IGNORECASE)
        # Peel nested wrappers one layer at a time, e.g. Monthly(Sum(x)) -> x.
        while True:
            match = wrapper_pattern.match(name)
            if not match:
                break
            name = match.group(2).strip()
        return name

    def _get_physical_info(self, logical_table, col_id):
        """Map a logical table/column pair to ``(full_table, table, column)``.

        Falls back to 'UNK' for db/schema, the logical table name for the table
        and *col_id* for the column when the table definition lacks them.
        """
        t_def = self.tables_tm.get(logical_table, {})
        db = t_def.get('db', 'UNK')
        sch = t_def.get('schema', 'UNK')
        tbl = t_def.get('db_table', logical_table)
        clean_tbl = f"`{tbl}`" if ' ' in tbl or '-' in tbl else tbl
        full_table = f"{db}.{sch}.{clean_tbl}"

        phy_col = col_id
        for c in t_def.get('columns', []):
            # .get: a column entry without 'name' must not abort the lookup.
            if c.get('name') == col_id:
                phy_col = c.get('db_column_name', col_id)
                break
        if ' ' in phy_col:
            phy_col = f"`{phy_col}`"
        return full_table, tbl, phy_col

    def _resolve_expr(self, expr, viz_formulas=None, depth=0):
        """Trace *expr* to the first physical column it references.

        Args:
            expr: TML expression text.
            viz_formulas: Optional ``{formula name: expr}`` of the visualization.
            depth: Recursion depth; resolution gives up beyond 10 levels.

        Returns:
            ``(full_table, table, column)`` or three "UNKNOWN" strings.
        """
        if depth > 10 or not isinstance(expr, str):
            return self._UNRESOLVED

        # A fully qualified [TABLE::COLUMN] reference wins immediately.
        match = self._QUALIFIED_REF.search(expr)
        if match:
            return self._get_physical_info(match.group(1), match.group(2))

        for ref in re.findall(r'\[(.*?)\]', expr):
            if '::' in ref:
                continue
            current_resolution = self._UNRESOLVED
            # Visualization-level formulas take precedence over model columns.
            if viz_formulas and ref in viz_formulas:
                current_resolution = self._resolve_expr(viz_formulas[ref], viz_formulas, depth + 1)
            if current_resolution[0] == "UNKNOWN":
                m_info = self.model_col_map.get(ref)
                if m_info:
                    if m_info['type'] == 'DIRECT':
                        current_resolution = self._get_physical_info(m_info['table_name'], m_info['physical_col_id'])
                    elif m_info['type'] == 'FORMULA':
                        current_resolution = self._resolve_expr(m_info['expr'], viz_formulas, depth + 1)
            if current_resolution[0] != "UNKNOWN":
                # First reference that resolves is reported.
                return current_resolution
        return self._UNRESOLVED

    # -------------------------------------------------------------------------
    # INTEGRATED REQUIREMENT: JOINING LIVEBOARD FILTERS WITH MODEL COLUMN IDs
    # -------------------------------------------------------------------------
    def generate_filter_details(self) -> pd.DataFrame:
        """
        Parses Liveboard TML and joins with Model mapping for physical IDs.

        Returns one row per liveboard filter with the filter column, display
        name, physical id from the model map, operator, formatted values and
        the single-value / mandatory flags.
        """
        rows = []
        filters = self.liveboard_tm.get('liveboard', {}).get('filters', [])

        for fltr in filters:
            # The first entry of 'column' is used; a bare string is accepted too.
            columns = fltr.get('column', [])
            if isinstance(columns, str):
                col_name = columns
            else:
                col_name = columns[0] if columns else "Unknown"

            # Look the alias up in the pre-built model column map.
            m_info = self.model_col_map.get(col_name, {})
            physical_id = m_info.get('full_physical_id', "Not Found in Model")

            # Values are rendered as ('v1', 'v2', ...); each value is stringified.
            raw_values = fltr.get('values', [])
            if raw_values:
                formatted_list = ", ".join([f"'{str(v)}'" for v in raw_values])
                values_str = f"({formatted_list})"
            else:
                values_str = "()"

            rows.append({
                "Filter_Column": col_name,
                "display_name": fltr.get('display_name', "Unknown"),
                "Physical_Column_ID": physical_id,
                "Operator": fltr.get('oper', "Unknown"),
                "Values": values_str,
                "is_single_value": fltr.get('is_single_value', "Unknown"),
                "is_mandatory": fltr.get('is_mandatory', "Unknown")
            })

        return pd.DataFrame(rows)

    def get_calculated_fields(self) -> pd.DataFrame:
        """Return the model's formulas as rows of id, name, expr and is_nested.

        ``is_nested`` is True when the expression text contains "formula"
        (case-insensitive).
        """
        root = self.model_tm.get('model') or self.model_tm.get('worksheet') or {}
        formula_data = []
        for formula in root.get('formulas', []):
            raw_expr = formula.get('expr', '')
            formula_data.append({
                'id': formula.get('id', ''),
                'name': formula.get('name', ''),
                'expr': raw_expr,
                'is_nested': "formula" in raw_expr.lower() if raw_expr else False
            })
        return pd.DataFrame(formula_data)

    def generate_column_lineage(self) -> pd.DataFrame:
        """Return one row per named answer column of every visualization.

        Each row links the liveboard column to its cleaned model column name,
        the model aggregation (model columns only) and the resolved table/column.
        """
        rows = []
        file_name = Path(self.liveboard_path).name
        visualizations = self.liveboard_tm.get('liveboard', {}).get('visualizations', [])
        for viz in visualizations:
            answer = viz.get('answer', {})
            viz_id = viz.get('id')
            viz_name = answer.get('name', 'Unknown')
            # Formulas without a name cannot be referenced, so they are skipped.
            viz_formulas = {
                form['name']: form.get('expr', '')
                for form in answer.get('formulas', [])
                if 'name' in form
            }
            for col in answer.get('answer_columns', []):
                lb_col = col.get('name')
                if not lb_col:
                    continue
                clean_name = self._clean_col_name(lb_col)
                ft, pt, pc = self._UNRESOLVED
                agg_property = None
                if lb_col in viz_formulas:
                    ft, pt, pc = self._resolve_expr(viz_formulas[lb_col], viz_formulas)
                else:
                    m_info = self.model_col_map.get(clean_name)
                    if m_info:
                        agg_property = m_info.get('aggregation')
                        if m_info['type'] == 'DIRECT':
                            ft, pt, pc = self._get_physical_info(m_info['table_name'], m_info['physical_col_id'])
                        elif m_info['type'] == 'FORMULA':
                            ft, pt, pc = self._resolve_expr(m_info['expr'], viz_formulas)
                    else:
                        # Last resort: treat the cleaned name as a bare reference.
                        ft, pt, pc = self._resolve_expr(f"[{clean_name}]", viz_formulas)
                rows.append({
                    "tml_file": file_name,
                    "visualization_id": viz_id,
                    "Visualization": viz_name,
                    "Liveboard_Column": lb_col,
                    "Model_Base_Column": clean_name,
                    "Model_Aggregation": agg_property,
                    "DBX_Full_Table": ft,
                    "Physical_Table": pt,
                    "Physical_DB_Column": pc
                })
        return pd.DataFrame(rows)

    def generate_join_details(self) -> pd.DataFrame:
        """Return one row per ``joins_with`` entry of every loaded table.

        ``[T::c]`` references in the join condition become ``T.c``; parts that
        contain a space are wrapped in backticks.
        """
        def replacer(m):
            tb, cl = m.group(1), m.group(2)
            if ' ' in cl:
                cl = f"`{cl}`"
            if ' ' in tb:
                tb = f"`{tb}`"
            return f"{tb}.{cl}"

        rows = []
        for t_name, t_def in self.tables_tm.items():
            for join in t_def.get('joins_with', []):
                # "or ''" guards against an explicit null 'on' in the TML.
                on = join.get('on', '') or ''
                rows.append({
                    "Table_1__From": t_name,
                    "Table_2__To": join.get('destination', {}).get('name'),
                    "Join_Type": join.get('type', 'INNER'),
                    "Explicit_Condition": self._QUALIFIED_REF.sub(replacer, on),
                    "Relationship_Key": join.get('name')
                })
        return pd.DataFrame(rows)

# -------------------------------------------------------------------------
# EXECUTION
# -------------------------------------------------------------------------
# Root folder holding the exported TML files; the liveboard to analyze is chosen
# through the "tml_file" widget value.
BASE_PATH = "/Volumes/dbx_migration_poc/dbx_migration_ts/lv_dashfiles_ak"
LIVEBOARD_FILE = f"{BASE_PATH}/liveboard/{tml_file}"

# Asset name = file name up to the first '.', with runs of whitespace/hyphens
# collapsed to '_' so it can be embedded in a table name.
raw_name = os.path.basename(LIVEBOARD_FILE).split('.')[0]
asset_name = re.sub(r'[\s\-]+', '_', raw_name)

LINEAGE_TABLE_NAME = f"{CATALOG}.{SCHEMA}.{asset_name}_column_lineage"
JOINS_TABLE_NAME = f"{CATALOG}.{SCHEMA}.{asset_name}_join_details"
EXPR_TABLE_NAME = f"{CATALOG}.{SCHEMA}.{asset_name}_calculated_fields_details"
FILTER_TABLE_NAME = f"{CATALOG}.{SCHEMA}.{asset_name}_filter_details"

try:
    print(f"Starting analysis for: {LIVEBOARD_FILE}")
    analyzer = TMLAnalyzer(LIVEBOARD_FILE, BASE_PATH)

    df_lineage = analyzer.generate_column_lineage()
    df_joins = analyzer.generate_join_details()
    df_calculated_fields = analyzer.get_calculated_fields()
    df_filters = analyzer.generate_filter_details()

    print(f"\nGenerated {len(df_lineage)} lineage rows.")
    print(f"Generated {len(df_joins)} join rows.")
    print(f"Generated {len(df_calculated_fields)} calculated field rows.")
    print(f"Generated {len(df_filters)} filter detail rows.")

    # Save to Delta Tables (Assuming Spark environment).
    # Empty frames are skipped.  overwriteSchema is applied to every table so a
    # rerun whose column set differs from the stored table does not fail.
    outputs = (
        (df_lineage, LINEAGE_TABLE_NAME),
        (df_joins, JOINS_TABLE_NAME),
        (df_calculated_fields, EXPR_TABLE_NAME),
        (df_filters, FILTER_TABLE_NAME),
    )
    for result_pdf, target_table in outputs:
        if not result_pdf.empty:
            (spark.createDataFrame(result_pdf)
                .write.mode("overwrite")
                .option("overwriteSchema", "true")
                .saveAsTable(target_table))

    if not df_filters.empty:
        print(f"Saved: {FILTER_TABLE_NAME}")
        display(df_filters.head(5))

except Exception as e:
    print(f"Error: {e}")
    # Re-raise: the following cells read the DataFrames and tables produced
    # here, so continuing after a failure would run them on missing or stale data.
    raise

## 1.1 Table Availability Check:

In [0]:
# Left-join every physical table named in the lineage output against
# system.information_schema.tables (case-insensitive on the table name).  Rows
# without a match have a NULL Physical_Table because concat() of NULLs is NULL.
query = f'''
    SELECT 
        DISTINCT
        A.Physical_Table AS Model_Table, 
        concat(table_catalog,'.',table_schema,'.',Physical_Table) AS Physical_Table
    FROM {LINEAGE_TABLE_NAME} AS A
    LEFT JOIN system.information_schema.tables AS B 
        ON UPPER(A.Physical_Table) = UPPER(B.table_name) 
''' 
table_availability = spark.sql(query)

# 1. Filter specifically for the missing tables
missing_tables_df = table_availability.filter(F.col("Physical_Table").isNull())

# Collect once and count locally: a separate count() would run the same
# uncached query a second time.  str() keeps the join below safe for NULL names.
missing_list = [str(row.Model_Table) for row in missing_tables_df.select("Model_Table").collect()]
table_count = len(missing_list)

print(f"{table_count} tables are not available in the target")

# 2. Logic to stop execution and return missing names
if table_count > 0:
    missing_str = ", ".join(missing_list)
    error_message = f"FAILURE: The following {table_count} tables are missing in Databricks: {missing_str}"
    print(error_message)
    dbutils.notebook.exit(json.dumps({
        "Error_Message": error_message
    }))
else:
    # Only report success when nothing is missing.
    print("All tables validated successfully.")

## 1.2 Restructure Query

In [0]:
import re

def _iter_structural_chars(text):
    """Yield (index, char) for characters outside [...] references and quoted literals."""
    in_reference = False
    quote_char = None
    for idx, ch in enumerate(text):
        if quote_char is not None:
            if ch == quote_char:
                quote_char = None
        elif in_reference:
            if ch == ']':
                in_reference = False
        elif ch in ('"', "'"):
            quote_char = ch
        elif ch == '[':
            in_reference = True
        else:
            yield idx, ch


def _matching_paren_index(text, open_idx):
    """Return the index of the ')' closing the '(' at *open_idx*, or -1 if none."""
    depth = 0
    for offset, ch in _iter_structural_chars(text[open_idx:]):
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
            if depth == 0:
                return open_idx + offset
    return -1


def _is_safe_left_operand(fragment, operator):
    """Return True when *fragment* can be wrapped in parentheses and moved.

    The fragment must have balanced parentheses, and for '*' it must not
    contain a top-level '+' or '-' (wrapping it would change precedence).
    """
    depth = 0
    for _, ch in _iter_structural_chars(fragment):
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
            if depth < 0:
                return False
        elif depth == 0 and operator == '*' and ch in '+-':
            return False
    return depth == 0


def restructure_if_expression(expr):
    """Move a trailing ``if (...)`` block in front of the arithmetic before it.

    ``<arith> <op> if (...) ...`` becomes ``( if (...) ... ) <op> ( <arith> )``
    (same for the parenthesized form ``<arith> <op> ( if (...) ... )``).

    The rewrite is applied only when it cannot change the expression's value:
    the operator must be commutative ('+' or '*'), the arithmetic part must be
    non-empty with balanced parentheses, '*' must not regroup a top-level
    '+'/'-', and a wrapping '(' must be closed by the final character.
    In every other case, and for non-string or empty input, *expr* is returned
    unchanged.
    """
    if not isinstance(expr, str) or not expr:
        return expr

    # Search for the start of an 'if (' block that is not at the very beginning.
    if_match = re.search(r'\bif\s*\(', expr, flags=re.IGNORECASE)
    if not if_match or if_match.start() == 0:
        return expr
    start_idx = if_match.start()

    # Examine the part before the 'if'.
    prefix_raw = expr[:start_idx].rstrip()

    # Was the 'if' block wrapped in a parenthesis (e.g. "+ ( if ... )")?
    paren_wrap = prefix_raw.endswith('(')
    wrap_open_idx = len(prefix_raw) - 1  # position of that '(' within expr
    if paren_wrap:
        prefix_raw = prefix_raw[:-1].rstrip()

    # Identify the operator (+, -, *, /) immediately preceding the if block.
    op_match = re.search(r'([\+\-\*\/])\s*$', prefix_raw)
    if not op_match:
        return expr
    operator = op_match.group(1)
    if operator not in ('+', '*'):
        # A - B and A / B are not equal to B - A and B / A; leave them alone.
        return expr

    arith_part = prefix_raw[:op_match.start()].strip()
    if not arith_part or not _is_safe_left_operand(arith_part, operator):
        return expr

    if paren_wrap:
        # The wrapping '(' must be closed by the last character; otherwise more
        # of the expression follows the wrapped block and cannot be moved.
        close_idx = _matching_paren_index(expr, wrap_open_idx)
        if close_idx == -1 or close_idx != len(expr.rstrip()) - 1:
            return expr
        if_part = expr[start_idx:close_idx].strip()
        # Reconstruct: ( if (...) ) <operator> ( <arithmetic_part> )
        return f" ( {if_part} )  {operator} ( {arith_part} )"

    # Reconstruct for unwrapped cases.
    if_part = expr[start_idx:]
    return f"( {if_part} ) {operator} ( {arith_part} )"

# Run restructure_if_expression over every formula expression, store the result
# back in the 'expr' column, then overwrite the calculated-fields table with it.
restructured_exprs = df_calculated_fields['expr'].apply(restructure_if_expression)
df_calculated_fields['expr'] = restructured_exprs
calc_fields_sdf = spark.createDataFrame(df_calculated_fields)
calc_fields_sdf.write.mode("overwrite").saveAsTable(f"{EXPR_TABLE_NAME}")

# 2. Conversion of TML Expressions to SQL Syntax:

In [0]:
!pip install sqlglot
import re
import sqlglot
from sqlglot import exp

def convert_tml_to_spark_sql(tml_expression: str) -> Optional[str]:
    """
    Robustly converts ThoughtSpot TML expressions to Spark SQL syntax.

    The conversion runs in phases:
      1. Every ``[column]`` / ``[table::column]`` reference is replaced by a
         placeholder token so its text cannot interfere with later rewriting.
      2. Braces become parentheses, double quotes become single quotes,
         ``unique count(`` becomes ``COUNT_DISTINCT(`` and
         ``unique_count_if(cond, value)`` becomes
         ``COUNT(DISTINCT CASE WHEN cond THEN value END)``.
      2.5 ``if (...) then ... else ...`` is rewritten to ``CASE WHEN``.
      3. The text is parsed with sqlglot, selected functions are renamed and
         the tree is rendered in the Spark dialect.
      4. Placeholder tokens are replaced by the (back-tick quoted) column names.

    Args:
        tml_expression: TML formula text.

    Returns:
        The Spark SQL text, or ``None`` for empty or non-string input.  When
        phase 3 raises, a warning is printed and the text produced by phases
        1-2.5 (with column names restored) is returned instead.
    """
    if not isinstance(tml_expression, str) or not tml_expression:
        return None

    # --- PHASE 1: TOKENIZATION (Hide Column Names) ---
    column_map = {}

    def quote(s):
        # Leave empty and already back-ticked names alone; back-tick anything
        # that is not a plain identifier.
        if not s:
            return s
        if s.startswith('`') and s.endswith('`'):
            return s
        if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', s):
            return f"`{s}`"
        return s

    def token_replacer(match):
        table = match.group(1)
        col = match.group(2)

        if table:
            spark_col = f"{quote(table)}.{quote(col)}"
        else:
            spark_col = quote(col)

        token = f"__TML_COL_{len(column_map)}__"
        column_map[token] = spark_col
        return token

    # The table part must not contain ']': otherwise "[a] + [t::c]" would be
    # read as ONE reference whose table name is "a] + [t".
    clean_expr = re.sub(r'\[(?:([^:\]]+)::)?([^\]]+)\]', token_replacer, tml_expression)

    # --- PHASE 2: SYNTAX NORMALIZATION ---
    clean_expr = clean_expr.replace('{', '(').replace('}', ')')
    clean_expr = clean_expr.replace('"', "'")

    # Convert unique count early
    clean_expr = re.sub(r'unique\s+count\s*\(', 'COUNT_DISTINCT(', clean_expr, flags=re.IGNORECASE)

    # Convert unique_count_if to COUNT(DISTINCT CASE WHEN ... THEN ... END)
    def convert_unique_count_if(text):
        """Convert unique_count_if(condition, value) to Spark SQL syntax"""
        keyword = "unique_count_if"
        pattern = re.compile(r'unique_count_if\s*\(', flags=re.IGNORECASE)

        while True:
            match = pattern.search(text)
            if not match:
                break

            start_pos = match.start()
            paren_start = match.end() - 1

            # Find the first comma directly inside the call (depth 1).
            depth = 1
            i = paren_start + 1
            comma_pos = None

            while i < len(text) and depth > 0:
                if text[i] == '(':
                    depth += 1
                elif text[i] == ')':
                    depth -= 1
                    if depth == 0:
                        break
                elif text[i] == ',' and depth == 1:
                    comma_pos = i
                    break
                i += 1

            if comma_pos is None:
                # No argument separator: mask this occurrence so the loop does
                # not find it again; the marker is restored after the loop.
                text = text[:start_pos] + "SKIP_UNIQUE_COUNT_IF" + text[start_pos + len(keyword):]
                continue

            condition = text[paren_start + 1:comma_pos].strip()

            # Find the closing paren of the call; i ends one past it.
            depth = 1
            i = comma_pos + 1
            while i < len(text) and depth > 0:
                if text[i] == '(':
                    depth += 1
                elif text[i] == ')':
                    depth -= 1
                i += 1

            value = text[comma_pos + 1:i - 1].strip()

            replacement = f"COUNT(DISTINCT CASE WHEN {condition} THEN {value} END)"
            text = text[:start_pos] + replacement + text[i:]

        text = text.replace("SKIP_UNIQUE_COUNT_IF", "unique_count_if")
        return text

    clean_expr = convert_unique_count_if(clean_expr)

    # --- PHASE 2.5: CONVERT IF-THEN-ELSE TO CASE WHEN ---
    def convert_if_to_case(text):
        """
        Convert IF-THEN-ELSE statements to CASE WHEN syntax.

        Each pass rewrites the LAST (right-most) ``if (`` in the text, so inner
        and trailing IFs are converted before the ones that contain them.
        """
        max_iterations = 100
        iteration = 0
        if_pattern = re.compile(r'\bif\s*\(', flags=re.IGNORECASE)

        while iteration < max_iterations:
            iteration += 1
            original = text

            if_matches = list(if_pattern.finditer(text))
            if not if_matches:
                break

            if_match = if_matches[-1]
            start_pos = if_match.start()
            paren_start = if_match.end() - 1

            # Find the ')' matching the condition's '('.
            depth = 1
            i = paren_start + 1
            while i < len(text) and depth > 0:
                if text[i] == '(':
                    depth += 1
                elif text[i] == ')':
                    depth -= 1
                i += 1

            if depth != 0:
                # Unmatched parens: mask this IF ("SKIP_IF" no longer matches
                # the \bif pattern) and move on; restored after the loop.
                text = text[:start_pos] + "SKIP_IF" + text[start_pos + 2:]
                continue

            cond_end = i - 1
            condition = text[paren_start + 1:cond_end].strip()

            # "then" must follow the condition's closing paren directly.
            then_match = re.match(r'\s*then\s+', text[cond_end + 1:], flags=re.IGNORECASE)
            if not then_match:
                text = text[:start_pos] + "SKIP_IF" + text[start_pos + 2:]
                continue

            true_start = cond_end + 1 + then_match.end()

            # Scan for "else" at paren depth 0, or for the ')' that closes an
            # enclosing expression (which ends the IF without an else branch).
            depth = 0
            i = true_start
            else_pos = None
            true_end = len(text)

            while i < len(text):
                if text[i] == '(':
                    depth += 1
                elif text[i] == ')':
                    if depth == 0:
                        true_end = i
                        break
                    depth -= 1
                elif depth == 0:
                    else_match = re.match(r'\s*else\s+', text[i:], flags=re.IGNORECASE)
                    if else_match:
                        else_pos = i + else_match.end()
                        true_end = i
                        break
                i += 1

            true_value = text[true_start:true_end].strip()

            if else_pos:
                # The else value runs to the end of the text or to the ')' that
                # closes an enclosing expression.
                depth = 0
                i = else_pos
                false_end = len(text)

                while i < len(text):
                    if text[i] == '(':
                        depth += 1
                    elif text[i] == ')':
                        if depth == 0:
                            false_end = i
                            break
                        depth -= 1
                    i += 1

                false_value = text[else_pos:false_end].strip()
                end_pos = false_end
            else:
                false_value = "NULL"
                end_pos = true_end

            case_stmt = f"CASE WHEN {condition} THEN {true_value} ELSE {false_value} END"
            text = text[:start_pos] + case_stmt + text[end_pos:]

            # If nothing changed, break to avoid an infinite loop.
            if text == original:
                break

        # Restore any SKIP_IF back to IF (ones we couldn't process)
        text = text.replace("SKIP_IF", "if")
        return text

    clean_expr = convert_if_to_case(clean_expr)

    # --- PHASE 3: PARSING & TRANSFORMATION ---
    try:
        parsed = sqlglot.parse_one(clean_expr)

        def transformer(node):
            def func(name, args):
                return exp.Anonymous(this=name, expressions=args)

            # Only functions sqlglot does not recognise (Anonymous) are remapped.
            if isinstance(node, exp.Anonymous):
                name = node.this.lower()
                args = node.expressions

                # FUNCTION MAPPINGS
                if name == "safe_divide":
                    return func("try_divide", args)
                if name == "diff_days":
                    return func("datediff", args)
                if name == "day_number_of_week":
                    # NOTE(review): confirm both engines number weekdays from
                    # the same starting day before relying on this mapping.
                    return func("dayofweek", args)
                if name == "now":
                    return func("current_timestamp", [])
                if name == "isnull":
                    return func("isnull", args)
                if name == "contains":
                    if len(args) >= 2:
                        return exp.Like(this=args[0], expression=exp.Literal.string(f"%{args[1].name}%"))
                if name == "count_distinct":
                    return exp.Count(this=exp.Distinct(expressions=args))

            return node

        transformed = parsed.transform(transformer)
        spark_sql = transformed.sql(dialect="spark")

    except Exception as e:
        # Best effort: keep the partially converted text rather than failing.
        print(f"Warning: Parse failed ({e}). Returning partially processed string.")
        spark_sql = clean_expr

    # --- PHASE 4: RESTORE COLUMN NAMES ---
    def restore(match):
        token = match.group(0)
        return column_map.get(token, token)

    final_sql = re.sub(r'__TML_COL_\d+__', restore, spark_sql)

    return final_sql


# Convert each TML expression with convert_tml_to_spark_sql and keep the result
# in a new 'spark_sql' column, then overwrite the "_v2" calculated-fields table.
converted_sql = df_calculated_fields['expr'].apply(convert_tml_to_spark_sql)
df_calculated_fields['spark_sql'] = converted_sql

calc_fields_v2_sdf = spark.createDataFrame(df_calculated_fields)
calc_fields_v2_sdf.write.mode("overwrite").saveAsTable(f"{EXPR_TABLE_NAME}_v2")

# Verify
display(df_calculated_fields)

# 3. Model Object SQL Generator:

## 3.1: Hierarchy Builder:

In [0]:
import pandas as pd

def add_hierarchy_column(df):
    """
    Adds a 'hierarchy' column to the dataframe determining the dependency level.
    Level 1: No dependencies.
    Level 2: Depends on Level 1.
    Level N: 1 + Max(dependency levels).

    A formula depends on another one when the other formula's ``id`` occurs as
    a substring of this formula's ``spark_sql`` text.

    Args:
        df: DataFrame with ``id`` and ``spark_sql`` columns; modified in place.

    Returns:
        The same DataFrame object, with the ``hierarchy`` column added.

    Raises:
        ValueError: The formulas reference each other in a cycle.
    """
    # id -> SQL text.  Missing SQL (None/NaN) is treated as empty text so the
    # substring test below cannot fail on a non-string value.
    id_to_sql = {
        formula_id: (sql if isinstance(sql, str) else "")
        for formula_id, sql in df.set_index('id')['spark_sql'].to_dict().items()
    }

    # Candidate ids, longest first, so a longer id is matched (and consumed)
    # before any shorter id that is a substring of it.  Empty or non-string ids
    # are excluded: '' is a substring of every text and would make every
    # formula look dependent on it.
    all_ids = sorted(
        (i for i in df['id'].unique() if isinstance(i, str) and i),
        key=len,
        reverse=True,
    )

    # Memoized levels, and the ids currently on the recursion stack.
    levels = {}
    visiting = set()

    def get_level(current_id):
        if current_id in levels:
            return levels[current_id]

        # Re-entering an id that is still being resolved means a cycle.
        if current_id in visiting:
            raise ValueError(f"Circular dependency detected involving {current_id}")

        visiting.add(current_id)

        # Work on a copy so found ids can be removed ("consumed").
        remaining_sql = id_to_sql.get(current_id, "")
        dependencies = []
        for candidate_id in all_ids:
            # A formula mentioning its own id is not counted as a dependency.
            if candidate_id == current_id:
                continue
            if candidate_id in remaining_sql:
                dependencies.append(candidate_id)
                remaining_sql = remaining_sql.replace(candidate_id, "")

        if dependencies:
            lvl = 1 + max(get_level(dep_id) for dep_id in dependencies)
        else:
            lvl = 1

        visiting.remove(current_id)
        levels[current_id] = lvl
        return lvl

    df['hierarchy'] = df['id'].apply(get_level)

    return df

# Compute the dependency level of every calculated field first.
df_result = add_hierarchy_column(df_calculated_fields)

# Drop every literal "formula_" occurrence from the generated SQL
# (plain substring replacement, not a regular expression).
sql_without_prefix = df_result['spark_sql'].str.replace('formula_', '', regex=False)
df_result['spark_sql'] = sql_without_prefix

# Show the first rows of the columns of interest.
preview_columns = ['id', 'spark_sql', 'hierarchy']
print(df_result[preview_columns].head())

# Overwrite the "_v2" table with the enriched frame; mergeSchema is set on the write.
result_writer = spark.createDataFrame(df_result).write.mode("overwrite").option("mergeSchema","true")
result_writer.saveAsTable(f"{EXPR_TABLE_NAME}_v2")

## 3.2: SQL Generator:

In [0]:
import re

# Table holding the calculated fields together with their `hierarchy` level
# (the "<EXPR_TABLE_NAME>_v2" table).
CALC_FIELDS_TABLE_NAME = f"{EXPR_TABLE_NAME}_v2"
# Table queried below for the `Physical_Column_ID` values of the filter columns.
FILTER_COLUMNS = f"{FILTER_TABLE_NAME}"

# Lower-case names of SQL functions that is_aggregate_expression() treats as
# aggregates. An expression calling one of these is kept out of GROUP BY, so a
# missing name makes the generated query group by an aggregate.
AGGREGATE_FUNCTIONS = {
    # Original entries (kept for backward compatibility).
    "count", "sum", "avg", "mean", "min", "max",
    "first", "last", "collect_list", "collect_set",
    "stddev", "variance", "kurtosis", "skewness",
    "approx_distinct", "corr", "covar_pop", "covar_samp",
    # Further Spark SQL built-in aggregates that were previously missed.
    "approx_count_distinct", "approx_percentile", "percentile",
    "percentile_approx", "median", "count_if", "any_value",
    "first_value", "last_value", "max_by", "min_by",
    "std", "stddev_pop", "stddev_samp", "var_pop", "var_samp",
    "bool_and", "bool_or", "bit_and", "bit_or", "bit_xor",
}

def quote_column_name(column_name):
    """Return ``column_name`` as a SQL identifier, back-quoting it when needed.

    A name is wrapped in backticks when it starts with a digit or contains any
    character other than ASCII letters, digits, underscore or dot (spaces,
    ``#``, ``/``, ``-``, ``%`` ...). Backticks inside such a name are doubled.
    Dots alone do not trigger quoting, so dotted names are returned as before.

    Args:
        column_name: raw column or alias name; empty/None is returned unchanged.

    Returns:
        The name, quoted if required. Names already wrapped in backticks are
        returned untouched.
    """
    if not column_name:
        return column_name
    # Already quoted by the caller: leave it alone.
    if len(column_name) >= 2 and column_name.startswith('`') and column_name.endswith('`'):
        return column_name
    needs_quoting = (
        column_name[0].isdigit()
        or re.search(r'[^A-Za-z0-9_.]', column_name) is not None
    )
    if not needs_quoting:
        return column_name
    # Double embedded backticks so the quoted identifier stays well-formed.
    return '`' + column_name.replace('`', '``') + '`'

def parse_filter_column_id(physical_column_id):
    """Split a ``table::column`` identifier into a ``(table, column)`` pair.

    Only the first ``::`` acts as the separator and both parts are
    whitespace-trimmed. ``(None, None)`` is returned for an empty value or a
    value without ``::``.
    """
    if physical_column_id and '::' in physical_column_id:
        table_part, _, column_part = physical_column_id.partition('::')
        return table_part.strip(), column_part.strip()
    return None, None

def is_aggregate_expression(expression):
    """Tell whether ``expression`` calls any function named in AGGREGATE_FUNCTIONS.

    A call is recognised case-insensitively as the function name starting at a
    word boundary, followed by optional whitespace and an opening parenthesis.
    Empty/None expressions are never aggregates.
    """
    if not expression:
        return False
    return any(
        re.search(r'\b' + re.escape(aggregate_name) + r'\s*\(', expression, re.IGNORECASE)
        for aggregate_name in AGGREGATE_FUNCTIONS
    )

def clean_spark_sql_expr(expr, alias_map=None, column_alias_map=None):
    """Rewrite table qualifiers and column names inside a SQL expression.

    Args:
        expr: SQL expression text. Empty/None values are returned unchanged.
        alias_map: optional mapping of table name (possibly catalog/schema
            qualified) -> alias. When given, ``<table>.`` qualifiers are
            rewritten to ``<alias>.``. When ``None``, every ``<identifier>.``
            qualifier is removed instead.
        column_alias_map: optional mapping of physical column name -> model
            column name. Occurrences of the physical name (back-quoted or as a
            bare word) are replaced by the quoted model name.

    Returns:
        The rewritten expression.
    """
    if not expr:
        return expr
    cleaned_expr = expr
    # A qualifier is "<name>." directly followed by an (optionally back-quoted) column.
    qualifier_tail = r'\.(?=`?[a-zA-Z0-9_#])'
    if alias_map is not None:
        for full_table_name, alias in alias_map.items():
            raw_table_name = full_table_name.split('.')[-1]
            pattern = r'\b' + re.escape(raw_table_name) + qualifier_tail
            # Callable replacement so the alias is inserted literally (no
            # backslash/group-reference interpretation by re.sub).
            cleaned_expr = re.sub(
                pattern, lambda _match, _alias=alias: f"{_alias}.", cleaned_expr
            )
    else:
        # The qualifier must start with a letter or underscore; otherwise the
        # integer part of a decimal literal ("1.5") would be stripped as if it
        # were a table name.
        cleaned_expr = re.sub(
            r'\b[a-zA-Z_][a-zA-Z0-9_]*' + qualifier_tail, '', cleaned_expr
        )
    if column_alias_map:
        replacements = {}
        for physical_col, model_alias in column_alias_map.items():
            if not physical_col or not model_alias:
                continue
            physical_col_clean = physical_col.strip('`')
            if physical_col_clean:
                # First mapping for a name wins, as with sequential substitution.
                replacements.setdefault(physical_col_clean, quote_column_name(model_alias))
        if replacements:
            # Longest names first so a name is preferred over its own prefix.
            alternation = '|'.join(
                re.escape(name) for name in sorted(replacements, key=len, reverse=True)
            )
            # Single pass over the text: inserted aliases are never re-scanned,
            # and other back-quoted identifiers are consumed whole and kept.
            combined = re.compile(
                r'`(' + alternation + r')`'
                r'|`[^`]*`'
                r'|\b(' + alternation + r')\b'
            )

            def _swap(match):
                name = match.group(1) or match.group(2)
                return replacements[name] if name else match.group(0)

            cleaned_expr = combined.sub(_swap, cleaned_expr)
    return cleaned_expr

# Lineage rows joined to information_schema (to qualify the physical table
# name) and to the calculated-fields table (expression and hierarchy level).
lineage_query = f'''
    SELECT 
        A.Visualization, 
        A.Liveboard_Column, 
        A.Model_Base_Column, 
        concat(table_catalog,'.',table_schema,'.',Physical_Table) AS Physical_Table, 
        Physical_DB_Column,
        C.spark_sql,
        COALESCE(C.hierarchy, 1) AS hierarchy
    FROM {LINEAGE_TABLE_NAME} AS A
    LEFT JOIN system.information_schema.tables AS B 
        ON UPPER(A.Physical_Table) = UPPER(B.table_name)
    LEFT JOIN {CALC_FIELDS_TABLE_NAME} AS C 
        ON A.Model_Base_Column = C.name
'''

# Physical column identifiers of the filter columns.
filter_query = f'''
    SELECT 
        Physical_Column_ID
    FROM {FILTER_COLUMNS}
'''

filter_pdf = spark.sql(filter_query).toPandas()

# Keep only identifiers that parse into a non-empty table and column name.
parsed_filter_ids = (
    parse_filter_column_id(physical_id)
    for physical_id in filter_pdf['Physical_Column_ID']
)
filter_columns = [
    {'raw_table': parsed_table, 'column': parsed_column}
    for parsed_table, parsed_column in parsed_filter_ids
    if parsed_table and parsed_column
]

# Every calculated field with its expression and hierarchy level (default 1).
calc_fields_query = f'''
    SELECT 
        name,
        spark_sql,
        COALESCE(hierarchy, 1) AS hierarchy
    FROM {CALC_FIELDS_TABLE_NAME}
'''

calc_fields_pdf = spark.sql(calc_fields_query).toPandas()

def find_referenced_calc_fields(expression, calc_fields_df):
    """Return the names of all calculated fields that ``expression`` depends on.

    A field counts as referenced when its name appears back-quoted or as a
    whole word in the text. When ``calc_fields_df`` has a ``spark_sql`` column,
    the expressions of the fields found are scanned as well, so indirect
    (multi-level) dependencies are included; a field several levels below the
    expression is otherwise never reported.

    Args:
        expression: SQL expression text; None, empty, 'nan' and 'None' give an
            empty result.
        calc_fields_df: pandas DataFrame with a ``name`` column and, optionally,
            a ``spark_sql`` column holding each field's expression.

    Returns:
        A set of calculated-field names.
    """
    def _is_blank(value):
        return not value or str(value) == 'nan' or str(value) == 'None'

    if _is_blank(expression):
        return set()

    # Ignore missing/empty names: they cannot be referenced and would break re.escape.
    name_set = {
        name for name in calc_fields_df['name'].values
        if isinstance(name, str) and name
    }
    # Compile each whole-word pattern once; they are reused for every scanned text.
    word_patterns = {
        name: re.compile(r'\b' + re.escape(name) + r'\b') for name in name_set
    }

    expr_by_name = {}
    if 'spark_sql' in calc_fields_df.columns:
        pairs = zip(calc_fields_df['name'].values, calc_fields_df['spark_sql'].values)
        for name, field_expr in pairs:
            if name in name_set and name not in expr_by_name:
                expr_by_name[name] = field_expr

    referenced = set()
    pending = [str(expression)]
    while pending:
        text = pending.pop()
        direct = {quoted for quoted in re.findall(r'`([^`]+)`', text) if quoted in name_set}
        direct.update(name for name, pattern in word_patterns.items() if pattern.search(text))
        # Only follow names not seen before, which also guards against cycles.
        for name in direct - referenced:
            referenced.add(name)
            dependency_expr = expr_by_name.get(name)
            if not _is_blank(dependency_expr):
                pending.append(str(dependency_expr))
    return referenced

def add_filter_columns_to_select(select_list, group_by_list, selected_cols, filter_columns, 
                                   joined_tables, base_table, alias_map=None, use_table_alias=True):
    """Append the filter columns to the SELECT and GROUP BY lists, in place.

    Args:
        select_list: list of SELECT clauses; extended in place.
        group_by_list: list of GROUP BY expressions; extended in place.
        selected_cols: set of ``(table, column)`` keys already selected;
            updated in place.
        filter_columns: dicts with ``raw_table`` and ``column`` keys.
        joined_tables: mapping of table name -> alias for the tables in the query.
        base_table: unused; kept for interface compatibility.
        alias_map: unused; kept for interface compatibility.
        use_table_alias: prefix the column with its table alias when true.

    A filter column is only added when its table is part of ``joined_tables``.
    The table matches when the names are equal or when the joined name ends
    with ``.<raw_table>``; a bare suffix match is not enough, so ``orders``
    does not pick up ``old_orders``.
    """
    for filter_col in filter_columns:
        raw_table = filter_col['raw_table']
        column_name = filter_col['column']
        qualified_suffix = '.' + raw_table
        matched_table = None
        for full_table in joined_tables:
            # Keys can be missing values when a table name could not be resolved.
            if not isinstance(full_table, str):
                continue
            if full_table == raw_table or full_table.endswith(qualified_suffix):
                matched_table = full_table
                break
        if not matched_table:
            continue
        unique_key = (matched_table, column_name)
        if unique_key in selected_cols:
            continue
        quoted_column = quote_column_name(column_name)
        if use_table_alias:
            table_alias = joined_tables.get(matched_table, 't1')
            clause = f"{table_alias}.{quoted_column}"
        else:
            clause = quoted_column
        # The same text serves as SELECT item and GROUP BY expression.
        if clause not in select_list:
            select_list.append(clause)
            group_by_list.append(clause)
            selected_cols.add(unique_key)

# Assemble `final_sql_query` from the lineage, calculated-field, filter and join metadata.
#   * no join metadata  -> query a single base table (no table aliases)
#   * join metadata     -> base table aliased t1 plus JOIN clauses built from the join rows
# In both cases a maximum hierarchy of 1 yields one flat SELECT, otherwise one CTE
# (cte_level_<n>) is emitted per hierarchy level and the last one is selected.
if df_joins.empty:
    print(f"WARNING: Join Details DataFrame is empty. No Joins available")
    lineage_pdf = spark.sql(lineage_query).toPandas()
    df_select = lineage_pdf
    all_tables = df_select['Physical_Table']
    # Base table = most frequent Physical_Table value; placeholder when there is none.
    base_table = all_tables.value_counts().index[0] if not all_tables.value_counts().empty else "UNKNOWN_TABLE"
    joined_tables = {base_table: 't1'} 
    max_hierarchy = df_select['hierarchy'].max()
    
    if max_hierarchy == 1:
        # --- Flat query: every column/expression is selected directly from the base table ---
        select_list = []
        group_by_list = [] 
        selected_cols = set()
        for index, row in df_select.iterrows():
            table = row['Physical_Table']
            db_col = quote_column_name(row['Physical_DB_Column'])
            alias_col = quote_column_name(row['Model_Base_Column'])
            spark_sql_expr = row['spark_sql'] 
            # De-duplicate on (table, column) or, for calculated fields, (table, expression).
            unique_key = (table, db_col) if not spark_sql_expr else (table, spark_sql_expr)
            if unique_key not in selected_cols:
                if spark_sql_expr and str(spark_sql_expr) != 'nan' and str(spark_sql_expr) != 'None':
                    # Calculated field: strip table qualifiers (no aliases in this branch).
                    clean_expr = clean_spark_sql_expr(spark_sql_expr, alias_map=None, column_alias_map=None)
                    select_clause = f"{clean_expr} AS {alias_col}"
                    # Only non-aggregate expressions take part in GROUP BY.
                    if not is_aggregate_expression(clean_expr):
                        group_by_list.append(clean_expr)
                else:
                    select_clause = f"{db_col} AS {alias_col}"
                    group_by_list.append(db_col)
                select_list.append(select_clause)
                selected_cols.add(unique_key)
        add_filter_columns_to_select(select_list, group_by_list, selected_cols, 
                                       filter_columns, joined_tables, base_table, 
                                       alias_map=None, use_table_alias=False)
        select_statement = ",\n  ".join(select_list)
        group_by_clause = ""
        if group_by_list:
            group_by_clause = "\nGROUP BY\n  " + ",\n  ".join(group_by_list)
        final_sql_query = (
            "SELECT\n"
            f"  {select_statement}\n"
            "FROM\n"
            f"  {base_table}"
            f"{group_by_clause};"
        )
    else:
        # --- Layered query: one CTE per hierarchy level ---
        cte_queries = []
        calc_fields_by_level = {}
        fields_needed = set()
        column_alias_map = {}
        # Map physical column -> model column for every non-calculated lineage row.
        for index, row in df_select.iterrows():
            spark_sql_expr = row['spark_sql']
            if not spark_sql_expr or str(spark_sql_expr) == 'nan' or str(spark_sql_expr) == 'None':
                physical_col = row['Physical_DB_Column']
                model_col = row['Model_Base_Column']
                column_alias_map[physical_col] = model_col
        # Collect the calculated fields used by the liveboard plus the fields they reference.
        for index, row in df_select.iterrows():
            alias_col = row['Model_Base_Column']
            spark_sql_expr = row['spark_sql']
            hierarchy_level = row['hierarchy']
            if spark_sql_expr and str(spark_sql_expr) != 'nan' and str(spark_sql_expr) != 'None':
                fields_needed.add(alias_col)
                referenced = find_referenced_calc_fields(spark_sql_expr, calc_fields_pdf)
                fields_needed.update(referenced)
        # Bucket the needed calculated fields by their hierarchy level.
        for index, row in calc_fields_pdf.iterrows():
            calc_name = row['name']
            calc_expr = row['spark_sql']
            calc_hierarchy = row['hierarchy']
            if calc_name in fields_needed:
                if calc_hierarchy not in calc_fields_by_level:
                    calc_fields_by_level[calc_hierarchy] = []
                calc_fields_by_level[calc_hierarchy].append({
                    'name': calc_name,
                    'expr': calc_expr,
                    'hierarchy': calc_hierarchy
                })
        for level in range(1, int(max_hierarchy) + 1):
            # NOTE(review): level_data is assigned but not read anywhere below.
            level_data = df_select[df_select['hierarchy'] == level]
            select_list = []
            group_by_list = []
            selected_cols = set()
            if level == 1:
                # Level 1 selects the physical columns (aliased to model names) ...
                for index, row in df_select.iterrows():
                    spark_sql_expr = row['spark_sql']
                    if not spark_sql_expr or str(spark_sql_expr) == 'nan' or str(spark_sql_expr) == 'None':
                        table = row['Physical_Table']
                        db_col = quote_column_name(row['Physical_DB_Column'])
                        alias_col = quote_column_name(row['Model_Base_Column'])
                        unique_key = (table, db_col)
                        if unique_key not in selected_cols:
                            select_clause = f"{db_col} AS {alias_col}"
                            group_by_list.append(db_col)
                            select_list.append(select_clause)
                            selected_cols.add(unique_key)
                # ... and the filter columns whose table is part of the query.
                for filter_col in filter_columns:
                    raw_table = filter_col['raw_table']
                    column_name = filter_col['column']
                    quoted_column = quote_column_name(column_name)
                    matched_table = None
                    for full_table in joined_tables.keys():
                        if full_table.endswith(raw_table) or full_table.split('.')[-1] == raw_table:
                            matched_table = full_table
                            break
                    if matched_table:
                        unique_key = (matched_table, column_name)
                        if unique_key not in selected_cols:
                            select_list.append(quoted_column)
                            group_by_list.append(quoted_column)
                            selected_cols.add(unique_key)
            # Calculated fields that belong to the current level.
            if level in calc_fields_by_level:
                for calc_field in calc_fields_by_level[level]:
                    calc_name = quote_column_name(calc_field['name'])
                    calc_expr = calc_field['expr']
                    # Skip names already present as an output alias of this SELECT.
                    if calc_name not in [s.split(' AS ')[-1] for s in select_list]:
                        if level == 1:
                            clean_expr = clean_spark_sql_expr(calc_expr, alias_map=None, column_alias_map=None)
                        else:
                            # Above level 1 the physical columns are only visible under their model names.
                            clean_expr = clean_spark_sql_expr(calc_expr, alias_map=None, column_alias_map=column_alias_map)
                        select_clause = f"{clean_expr} AS {calc_name}"
                        if not is_aggregate_expression(clean_expr):
                            group_by_list.append(clean_expr)
                        select_list.append(select_clause)
            if level > 1:
                # Carry forward everything produced by the lower levels: calculated fields, ...
                for prev_level in range(1, level):
                    if prev_level in calc_fields_by_level:
                        for calc_field in calc_fields_by_level[prev_level]:
                            calc_name = quote_column_name(calc_field['name'])
                            if calc_name not in [s.split(' AS ')[-1] for s in select_list]:
                                select_list.append(calc_name)
                                group_by_list.append(calc_name) 
                # ... base columns (by model name) ...
                for index, row in df_select.iterrows():
                    spark_sql_expr = row['spark_sql']
                    if not spark_sql_expr or str(spark_sql_expr) == 'nan' or str(spark_sql_expr) == 'None':
                        alias_col = quote_column_name(row['Model_Base_Column'])
                        if alias_col not in [s.split(' AS ')[-1] for s in select_list]:
                            select_list.append(alias_col)
                            group_by_list.append(alias_col)
                # ... and filter columns.
                for filter_col in filter_columns:
                    column_name = filter_col['column']
                    quoted_column = quote_column_name(column_name)
                    if quoted_column not in select_list:
                        select_list.append(quoted_column)
                        group_by_list.append(quoted_column)
            select_statement = ",\n    ".join(select_list)
            group_by_clause = ""
            if group_by_list:
                group_by_clause = "\n  GROUP BY\n    " + ",\n    ".join(group_by_list)
            # Level 1 reads the base table; every later level reads the previous CTE.
            from_clause = base_table if level == 1 else f"cte_level_{level-1}"
            cte_query = (
                f"  cte_level_{level} AS (\n"
                f"    SELECT\n"
                f"      {select_statement}\n"
                f"    FROM\n"
                f"      {from_clause}"
                f"{group_by_clause}\n"
                f"  )"
            )
            cte_queries.append(cte_query)
        with_clause = "WITH\n" + ",\n".join(cte_queries)
        final_sql_query = (
            f"{with_clause}\n"
            f"SELECT * FROM cte_level_{int(max_hierarchy)};"
        )
else:
    # Join rows with both table names qualified via information_schema;
    # a Join_Type of 'OUTER' is rewritten to 'FULL OUTER'.
    df_joins_pdf = spark.sql(f'''
    WITH CTE AS (
    SELECT Table_1__From, 
        Table_2__To, 
        CASE WHEN UPPER(Join_Type) == 'OUTER' THEN 'FULL OUTER' ELSE Join_Type END AS Join_Type, 
        Explicit_Condition, 
        Relationship_Key,
        b.table_catalog as tbl1_catalog,
        b.table_schema as tbl1_schema,
        C.table_catalog as tbl2_catalog,
        C.table_schema as tbl2_schema
        FROM {JOINS_TABLE_NAME} AS A
        LEFT JOIN system.information_schema.tables AS B ON UPPER(A.Table_1__From) = UPPER(B.table_name)
        LEFT JOIN system.information_schema.tables AS C ON UPPER(A.Table_2__To) = UPPER(C.table_name))
        SELECT 
        concat(tbl1_catalog,'.',tbl1_schema,'.',Table_1__From) AS Table_1__From,
        concat(tbl2_catalog,'.',tbl2_schema,'.',Table_2__To) AS Table_2__To,
        Join_Type, 
        Explicit_Condition, 
        Relationship_Key
    FROM CTE''').toPandas()
    lineage_pdf = spark.sql(lineage_query).toPandas()
    df_joins = df_joins_pdf
    df_select = lineage_pdf
    # Base table = most frequent "from" table of the join rows.
    all_tables = df_joins['Table_1__From']
    table_counts = all_tables.value_counts()
    base_table = table_counts.index[0] if not table_counts.empty else df_joins.iloc[0]['Table_1__From']
    if not table_counts[table_counts > 1].empty:
        base_table = table_counts[table_counts > 1].index[0]
    joined_tables = {base_table: 't1'}
    alias_counter = 2
    sql_joins = []
    joins_to_process = df_joins.to_dict('records')
    # Attach joins pass by pass: a row is usable once exactly one of its tables is
    # already part of the query. Stop when a pass attaches nothing new.
    while joins_to_process:
        processed_count = 0
        for row in list(joins_to_process):
            table1 = row['Table_1__From']
            table2 = row['Table_2__To']
            join_type = row['Join_Type'].replace('_', ' ')
            condition = row['Explicit_Condition']
            new_table = None
            if table1 in joined_tables and table2 not in joined_tables:
                current_table = table1
                new_table = table2
            elif table2 in joined_tables and table1 not in joined_tables:
                current_table = table2
                new_table = table1
            if new_table:
                new_alias = f't{alias_counter}'
                joined_tables[new_table] = new_alias
                alias_counter += 1
                rawtable1 = table1.split('.')[-1]
                rawtable2 = table2.split('.')[-1]
                # Rewrite "<table>." qualifiers in the join condition to the table aliases.
                # NOTE(review): plain substring replacement - a table name that is the tail of
                # another name (or of the alias just inserted) would be rewritten too; confirm
                # this cannot happen with the real table names.
                aliased_condition = condition.replace(f'{rawtable1}.', f'{joined_tables.get(table1, table1)}.')
                aliased_condition = aliased_condition.replace(f'{rawtable2}.', f'{joined_tables.get(table2, table2)}.')
                join_target = table2 if new_table == table2 else table1
                sql_joins.append(f"{join_type} JOIN {join_target} AS {new_alias} ON {aliased_condition}")
                joins_to_process.remove(row)
                processed_count += 1
            elif table1 in joined_tables and table2 in joined_tables:
                # Both sides are already joined: the row is dropped without emitting a JOIN.
                joins_to_process.remove(row)
        if processed_count == 0 and joins_to_process:
            # Remaining rows are not connected to the tables joined so far.
            break 
    max_hierarchy = df_select['hierarchy'].max()
    
    if max_hierarchy == 1:
        # --- Flat query over the joined tables ---
        select_list = []
        group_by_list = []
        selected_cols = set()
        for index, row in df_select.iterrows():
            table = row['Physical_Table']
            db_col = quote_column_name(row['Physical_DB_Column'])
            alias_col = quote_column_name(row['Model_Base_Column'])
            spark_sql_expr = row['spark_sql'] 
            # Rows whose table is not part of the join graph are skipped.
            if table in joined_tables:
                table_alias = joined_tables[table]
                unique_key = (table, db_col) if not spark_sql_expr else (table, spark_sql_expr)
                if unique_key not in selected_cols:
                    if spark_sql_expr and str(spark_sql_expr) != 'nan' and str(spark_sql_expr) != 'None':
                        # Calculated field: table qualifiers are rewritten to the join aliases.
                        clean_expr = clean_spark_sql_expr(spark_sql_expr, alias_map=joined_tables, column_alias_map=None)
                        select_clause = f"{clean_expr} AS {alias_col}"
                        if not is_aggregate_expression(clean_expr):
                             group_by_list.append(clean_expr)
                    else:
                        select_clause = f"{table_alias}.{db_col} AS {alias_col}"
                        group_by_list.append(f"{table_alias}.{db_col}")
                    select_list.append(select_clause)
                    selected_cols.add(unique_key)
        add_filter_columns_to_select(select_list, group_by_list, selected_cols, 
                                       filter_columns, joined_tables, base_table, 
                                       alias_map=joined_tables, use_table_alias=True)
        select_statement = ",\n  ".join(select_list)
        join_block = "\n  ".join(sql_joins)
        group_by_clause = ""
        if group_by_list:
            group_by_clause = "\nGROUP BY\n  " + ",\n  ".join(group_by_list)
        final_sql_query = (
            "SELECT\n"
            f"  {select_statement}\n"
            "FROM\n"
            f"  {base_table} AS t1\n"
            f"  {join_block}"
            f"{group_by_clause};"
        )
    else:
        # --- Layered query over the joined tables: one CTE per hierarchy level ---
        cte_queries = []
        join_block = "\n    ".join(sql_joins)
        calc_fields_by_level = {}
        fields_needed = set()
        column_alias_map = {}
        # Map physical column -> model column for every non-calculated lineage row.
        for index, row in df_select.iterrows():
            spark_sql_expr = row['spark_sql']
            if not spark_sql_expr or str(spark_sql_expr) == 'nan' or str(spark_sql_expr) == 'None':
                physical_col = row['Physical_DB_Column']
                model_col = row['Model_Base_Column']
                column_alias_map[physical_col] = model_col
        # Collect the calculated fields used by the liveboard plus the fields they reference.
        for index, row in df_select.iterrows():
            alias_col = row['Model_Base_Column']
            spark_sql_expr = row['spark_sql']
            hierarchy_level = row['hierarchy']
            if spark_sql_expr and str(spark_sql_expr) != 'nan' and str(spark_sql_expr) != 'None':
                fields_needed.add(alias_col)
                referenced = find_referenced_calc_fields(spark_sql_expr, calc_fields_pdf)
                fields_needed.update(referenced)
        # Bucket the needed calculated fields by their hierarchy level.
        for index, row in calc_fields_pdf.iterrows():
            calc_name = row['name']
            calc_expr = row['spark_sql']
            calc_hierarchy = row['hierarchy']
            if calc_name in fields_needed:
                if calc_hierarchy not in calc_fields_by_level:
                    calc_fields_by_level[calc_hierarchy] = []
                calc_fields_by_level[calc_hierarchy].append({
                    'name': calc_name,
                    'expr': calc_expr,
                    'hierarchy': calc_hierarchy
                })
        for level in range(1, int(max_hierarchy) + 1):
            select_list = []
            group_by_list = []
            selected_cols = set()
            if level == 1:
                # Level 1 selects the physical columns, qualified by table alias ...
                for index, row in df_select.iterrows():
                    spark_sql_expr = row['spark_sql']
                    if not spark_sql_expr or str(spark_sql_expr) == 'nan' or str(spark_sql_expr) == 'None':
                        table = row['Physical_Table']
                        db_col = quote_column_name(row['Physical_DB_Column'])
                        alias_col = quote_column_name(row['Model_Base_Column'])
                        if table in joined_tables:
                            table_alias = joined_tables[table]
                            unique_key = (table, db_col)
                            if unique_key not in selected_cols:
                                select_clause = f"{table_alias}.{db_col} AS {alias_col}"
                                group_by_list.append(f"{table_alias}.{db_col}")
                                select_list.append(select_clause)
                                selected_cols.add(unique_key)
                # ... and the filter columns whose table is part of the join graph.
                for filter_col in filter_columns:
                    raw_table = filter_col['raw_table']
                    column_name = filter_col['column']
                    quoted_column = quote_column_name(column_name)
                    matched_table = None
                    for full_table in joined_tables.keys():
                        if full_table.endswith(raw_table) or full_table.split('.')[-1] == raw_table:
                            matched_table = full_table
                            break
                    if matched_table:
                        table_alias = joined_tables.get(matched_table, 't1')
                        unique_key = (matched_table, column_name)
                        if unique_key not in selected_cols:
                            select_clause = f"{table_alias}.{quoted_column}"
                            group_by_clause = f"{table_alias}.{quoted_column}"
                            select_list.append(select_clause)
                            group_by_list.append(group_by_clause)
                            selected_cols.add(unique_key)
            # Calculated fields that belong to the current level.
            if level in calc_fields_by_level:
                for calc_field in calc_fields_by_level[level]:
                    calc_name = quote_column_name(calc_field['name'])
                    calc_expr = calc_field['expr']
                    # Skip names already present as an output alias of this SELECT.
                    if calc_name not in [s.split(' AS ')[-1] for s in select_list]:
                        if level == 1:
                            clean_expr = clean_spark_sql_expr(calc_expr, alias_map=joined_tables, column_alias_map=None)
                        else:
                            # Above level 1 the source is a CTE: no table aliases, model column names.
                            clean_expr = clean_spark_sql_expr(calc_expr, alias_map=None, column_alias_map=column_alias_map)
                        select_clause = f"{clean_expr} AS {calc_name}"
                        if not is_aggregate_expression(clean_expr):
                            group_by_list.append(clean_expr)
                        select_list.append(select_clause)
            if level > 1:
                # Carry forward everything produced by the lower levels: calculated fields, ...
                for prev_level in range(1, level):
                    if prev_level in calc_fields_by_level:
                        for calc_field in calc_fields_by_level[prev_level]:
                            calc_name = quote_column_name(calc_field['name'])
                            if calc_name not in [s.split(' AS ')[-1] for s in select_list]:
                                select_list.append(calc_name)
                                group_by_list.append(calc_name)
                # ... base columns (by model name) ...
                for index, row in df_select.iterrows():
                    spark_sql_expr = row['spark_sql']
                    if not spark_sql_expr or str(spark_sql_expr) == 'nan' or str(spark_sql_expr) == 'None':
                        alias_col = quote_column_name(row['Model_Base_Column'])
                        if alias_col not in [s.split(' AS ')[-1] for s in select_list]:
                            select_list.append(alias_col)
                            group_by_list.append(alias_col)
                # ... and filter columns.
                for filter_col in filter_columns:
                    column_name = filter_col['column']
                    quoted_column = quote_column_name(column_name)
                    if quoted_column not in select_list:
                        select_list.append(quoted_column)
                        group_by_list.append(quoted_column)
            select_statement = ",\n    ".join(select_list)
            group_by_clause = ""
            if group_by_list:
                group_by_clause = "\n  GROUP BY\n    " + ",\n    ".join(group_by_list)
            # Level 1 reads the joined base tables; every later level reads the previous CTE.
            if level == 1:
                from_clause = f"{base_table} AS t1\n    {join_block}"
            else:
                from_clause = f"cte_level_{level-1}"
            cte_query = (
                f"  cte_level_{level} AS (\n"
                f"    SELECT\n"
                f"      {select_statement}\n"
                f"    FROM\n"
                f"      {from_clause}"
                f"{group_by_clause}\n"
                f"  )"
            )
            cte_queries.append(cte_query)
        with_clause = "WITH\n" + ",\n".join(cte_queries)
        final_sql_query = (
            f"{with_clause}\n"
            f"SELECT * FROM cte_level_{int(max_hierarchy)};"
        )

# Render a ready-to-paste PySpark snippet that runs the generated query.
# The lines are joined with newlines; the empty first and last entries give
# the leading and trailing newline of the snippet.
pyspark_code = "\n".join([
    "",
    'sql_query = """',
    f"{final_sql_query}",
    '"""',
    "",
    "result_df = spark.sql(sql_query)",
    "result_df.display()",
    "",
])

print(pyspark_code)

# 4. Saving the SQL Query to a File and Creating the View:

In [0]:
# Write the generated SQL to "<TML_BASE_DIR>/SqlOutputs/<asset_name>.sql".
TML_BASE_DIR = '/Volumes/dbx_migration_poc/dbx_migration_ts/lv_dashfiles_ak'
Query_FILE_PATH = f"{TML_BASE_DIR}/SqlOutputs"
file_name = f"{asset_name}.sql"
full_file_path = os.path.join(Query_FILE_PATH, file_name)
try:
    # Create the output directory if it does not exist yet. Kept inside the
    # try so a failure here is reported like any other write error.
    os.makedirs(Query_FILE_PATH, exist_ok=True)
    # Explicit UTF-8 so non-ASCII column names are written the same way
    # regardless of the platform's default encoding.
    with open(full_file_path, 'w', encoding='utf-8') as sql_file:
        sql_file.write(final_sql_query)
except OSError as e:
    # Best effort: report the I/O problem and let the notebook continue.
    print(f"Error writing to Volume: {e}")
else:
    print(f"Successfully saved the query to: {full_file_path}")

In [0]:
# (Re)create the "<asset_name>_View" view in CATALOG.SCHEMA from the generated query.
view_ddl = (
    "\n"
    f"CREATE OR REPLACE VIEW {CATALOG}.{SCHEMA}.{asset_name}_View\n"
    f"AS {final_sql_query}"
)
spark.sql(view_ddl)

# 5. Unified Dataset Creation: 

### 5.1 Dataset SQL creation:

In [0]:
import pandas as pd
import re

# Model and Dataset naming
# Fully qualified "<asset_name>_View" name; the unified query selects FROM it.
model_name = f"{CATALOG}.{SCHEMA}.{asset_name}_View"
# Dataset name stored with every row of the dashboard_queries frame.
unified_dataset_name = f"{asset_name}_Unified_Dataset"

def map_aggregation_to_spark(agg):
  """Translate a model aggregation label into its function keyword.

  Empty or missing (None/NaN) labels give ``None``. The label is upper-cased
  and trimmed before lookup; labels without a translation are returned in
  that normalised form.
  """
  if not agg or pd.isna(agg):
    return None
  label = str(agg).upper().strip()
  translations = dict([
    ("MAXIMUM", "MAX"),
    ("MINIMUM", "MIN"),
    ("AVERAGE", "AVG"),
    ("MEAN", "AVG"),
    ("UNIQUE COUNT", "COUNT_DISTINCT"),
    ("COUNT DISTINCT", "COUNT_DISTINCT"),
    ("SUM", "SUM"),
    ("COUNT", "COUNT"),
    ("STD_DEV", "STDDEV"),
    ("VARIANCE", "VARIANCE"),
  ])
  try:
    return translations[label]
  except KeyError:
    return label

def infer_aggregation_and_clean_name(col_name):
  """Derive an aggregation from a column-name prefix and strip that prefix.

  The name is compared case-insensitively against a fixed, ordered list of
  prefixes; the first match decides. Returns ``(function, remaining name)``
  on a match and ``(None, col_name)`` otherwise.
  """
  # Order matters: "count of " has to be tried before the shorter "count ".
  known_prefixes = (
    ("maximum ", "MAX"),
    ("minimum ", "MIN"),
    ("average ", "AVG"),
    ("avg ", "AVG"),
    ("sum ", "SUM"),
    ("count of ", "COUNT"),
    ("count ", "COUNT"),
    ("unique count of ", "COUNT_DISTINCT"),
    ("unique count ", "COUNT_DISTINCT"),
    ("unique number of ", "COUNT_DISTINCT"),
  )
  lowered = col_name.lower()
  hit = next(
    ((prefix, func) for prefix, func in known_prefixes if lowered.startswith(prefix)),
    None,
  )
  if hit is None:
    return None, col_name
  matched_prefix, matched_func = hit
  return matched_func, col_name[len(matched_prefix):].strip()

def quote_identifier(name):
  """Return ``name`` as a SQL identifier, back-quoting it when needed.

  A name is left bare only when it starts with a letter or underscore and
  contains nothing but letters, digits and underscores. Everything else is
  wrapped in backticks with embedded backticks doubled; this includes names
  that start with a digit, which could otherwise be read as a numeric literal.
  Names already wrapped in backticks, and empty/None values, are returned
  unchanged.
  """
  if not name:
    return name
  if re.fullmatch(r'[A-Za-z_][A-Za-z0-9_]*', name):
    return name
  # Already quoted by the caller: do not quote twice.
  if len(name) >= 2 and name[0] == '`' and name[-1] == '`':
    return name
  return '`' + name.replace('`', '``') + '`'

def format_sql_agg(func, col_expr):
  """Render an aggregate call; ``COUNT_DISTINCT`` becomes ``COUNT(DISTINCT ...)``."""
  call_opening = "COUNT(DISTINCT " if func == "COUNT_DISTINCT" else f"{func}("
  return f"{call_opening}{col_expr})"

def parse_filter_column_id(physical_column_id):
  """Split ``table::column`` at the first ``::`` into trimmed ``(table, column)``.

  Returns ``(None, None)`` when the value is empty or contains no ``::``.
  """
  has_separator = bool(physical_column_id) and '::' in physical_column_id
  if not has_separator:
    return None, None
  table_part, column_part = (piece.strip() for piece in physical_column_id.split('::', 1))
  return table_part, column_part

def load_filter_columns():
  """Load the filter columns from FILTER_TABLE_NAME.

  Returns a list of dicts with the raw ``column`` name and its
  ``quoted_column`` form, one per ``Physical_Column_ID`` that parses into a
  non-empty table and column. Any failure is printed as a warning and an
  empty list is returned.
  """
  try:
    ids_pdf = spark.sql(f"SELECT Physical_Column_ID FROM {FILTER_TABLE_NAME}").toPandas()
    parsed_ids = (
      parse_filter_column_id(physical_id)
      for physical_id in ids_pdf['Physical_Column_ID']
    )
    return [
      {'column': parsed_column, 'quoted_column': quote_identifier(parsed_column)}
      for parsed_table, parsed_column in parsed_ids
      if parsed_table and parsed_column
    ]
  except Exception as e:
    print(f"Warning: Could not load filter columns: {e}")
    return []

def generate_unified_query(df, filter_columns):
  """Build one SELECT over ``model_name`` covering all liveboard and filter columns.

  Args:
    df: pandas DataFrame with ``Liveboard_Column`` and ``Model_Aggregation`` columns.
    filter_columns: dicts with ``column`` (raw name) and ``quoted_column`` keys.

  Returns:
    The SQL text. A GROUP BY clause is appended only when at least one
    aggregated item was produced; its columns appear in the order in which
    they were first encountered, so the text is identical from run to run.
  """
  df_unique = df[['Liveboard_Column', 'Model_Aggregation']].drop_duplicates()
  
  select_items = {}
  # dict used as an insertion-ordered set: a plain set would make the GROUP BY
  # column order (and therefore the generated SQL text) vary between runs.
  group_by_cols = {}
  has_aggregation = False
  
  # Helper to normalize names for deduplication check (removes all non-alphanumeric)
  def normalize(name):
      return re.sub(r'[^a-zA-Z0-9]', '', name).lower()

  existing_cleaned_names = set()

  for _, row in df_unique.iterrows():
    raw_col = row['Liveboard_Column']
    meta_agg = map_aggregation_to_spark(row['Model_Aggregation'])
    inferred_agg, clean_name = infer_aggregation_and_clean_name(raw_col)
    
    final_expr = ""
    source_col_for_alias = ""

    if meta_agg and inferred_agg:
      # Aggregated items: both the metadata and the column name indicate an aggregation
      final_expr = format_sql_agg(inferred_agg, quote_identifier(clean_name))
      source_col_for_alias = clean_name
      has_aggregation = True
    elif meta_agg or inferred_agg:
      # Only one of the two signals is present: selected as a plain, grouped column
      source_col_for_alias = clean_name if inferred_agg else raw_col
      final_expr = quote_identifier(source_col_for_alias)
      group_by_cols[final_expr] = None
    else:
      # Dimensions: Strip wrappers like Month(Date)
      unwrapped_col = re.sub(r'^\w+\s*\(\s*(.+?)\s*\)$', r'\1', raw_col)
      source_col_for_alias = unwrapped_col
      final_expr = quote_identifier(unwrapped_col)
      group_by_cols[final_expr] = None
    
    # Generate Alias from the unwrapped column name
    alias = re.sub(r'[^a-zA-Z0-9]+', '_', source_col_for_alias).strip('_')
    
    # Store in select list (keyed by alias, so a later item with the same alias
    # replaces the earlier one) and track for deduplication
    select_items[alias] = f"{final_expr} AS {alias}"
    existing_cleaned_names.add(normalize(source_col_for_alias))

  # Add filter columns to select if not already present
  for f_col in filter_columns:
    f_raw = f_col['column']
    f_cleaned = normalize(f_raw)
    
    if f_cleaned not in existing_cleaned_names:
      # Generate alias for filter columns (e.g. "Start Date" -> Start_Date)
      f_alias = re.sub(r'[^a-zA-Z0-9]+', '_', f_raw).strip('_')
      select_items[f_alias] = f"{f_col['quoted_column']} AS {f_alias}"
      group_by_cols[f_col['quoted_column']] = None
      existing_cleaned_names.add(f_cleaned)

  # Assemble SQL
  select_clause = "SELECT " + ", ".join(select_items.values())
  sql = f"{select_clause} FROM {model_name}"
  
  if has_aggregation and group_by_cols:
    sql += " GROUP BY " + ", ".join(group_by_cols)
    
  return sql

# --- Main Execution ---
df_lineage = spark.table(LINEAGE_TABLE_NAME).toPandas()
filter_columns = load_filter_columns()

common_query = generate_unified_query(df_lineage, filter_columns)

# One row per distinct (visualization_id, Visualization) pair, each carrying
# the same unified query text and dataset name.
dashboard_queries = (
    df_lineage[['visualization_id', 'Visualization']]
    .drop_duplicates()
    .rename(columns={'Visualization': 'visualization_name'})
    .assign(
        common_sql_query=common_query,
        common_dataset_name=unified_dataset_name,
    )
)

display(dashboard_queries)

### 5.2: Mapping Table creation:

In [0]:
import json
import yaml
import pandas as pd
import re
from pathlib import Path
from typing import Dict, List, Any
from datetime import datetime
from pyspark.sql.types import StructType, StructField, StringType, ArrayType, TimestampType

# --- Configuration ---
# Catalog, schema and volume names used to build the input path and table names.
CATALOG = "dbx_migration_poc"
SCHEMA = "dbx_migration_ts"
TML_VOLUME = "lv_dashfiles_ak"

# Folder that holds the liveboard TML files inside the volume.
TML_INPUT_PATH = "/".join(["/Volumes", CATALOG, SCHEMA, TML_VOLUME, "liveboard"])
# Three-part (catalog.schema.table) names of the two output tables.
MAPPING_TABLE = ".".join([CATALOG, SCHEMA, "tml_dbx_metadata_mapping"])
FAILURE_LOG_TABLE = ".".join([CATALOG, SCHEMA, "tml_dbx_mapping_failures"])


# Setup Functions
def setup_failure_log_table():
    """Create or recreate the failure log table named by FAILURE_LOG_TABLE.

    Ensures the schema exists, attempts to drop any existing table, then
    issues CREATE OR REPLACE TABLE with the four failure-log columns.
    A failed DROP is reported but does not abort the setup, because the
    CREATE OR REPLACE statement that follows replaces the table anyway.
    """
    parts = FAILURE_LOG_TABLE.split('.')
    catalog, schema, table_name = parts[0], parts[1], parts[2]

    spark.sql(f"CREATE SCHEMA IF NOT EXISTS `{catalog}`.`{schema}`")

    try:
        spark.sql(f"DROP TABLE IF EXISTS `{catalog}`.`{schema}`.`{table_name}`")
    except Exception as drop_error:
        # Best effort only: surface the reason instead of hiding it.
        print(f"WARNING: Could not drop failure log table {table_name}: {drop_error}")
    else:
        print(f"Dropped existing failure log table: {table_name}")

    create_sql = f"""
        CREATE OR REPLACE TABLE `{catalog}`.`{schema}`.`{table_name}` (
            tml_file STRING,
            error_type STRING,
            error_message STRING,
            failure_timestamp TIMESTAMP
        ) USING DELTA
    """

    spark.sql(create_sql)
    print(f"Created failure log table: {table_name}")

def setup_mapping_table():
    """Create or recreate the metadata mapping table named by MAPPING_TABLE.

    Ensures the schema exists, attempts to drop any existing table, then
    issues CREATE OR REPLACE TABLE with the mapping columns and their
    column comments. A failed DROP is reported but does not abort the
    setup, because the CREATE OR REPLACE statement that follows replaces
    the table anyway.
    """
    parts = MAPPING_TABLE.split('.')
    catalog, schema, table_name = parts[0], parts[1], parts[2]

    spark.sql(f"CREATE SCHEMA IF NOT EXISTS `{catalog}`.`{schema}`")

    try:
        spark.sql(f"DROP TABLE IF EXISTS `{catalog}`.`{schema}`.`{table_name}`")
    except Exception as drop_error:
        # Best effort only: surface the reason instead of hiding it.
        print(f"WARNING: Could not drop table {table_name}: {drop_error}")
    else:
        print(f"Dropped existing table: {table_name}")

    create_sql = f"""
        CREATE OR REPLACE TABLE `{catalog}`.`{schema}`.`{table_name}` (
            tml_file STRING,
            visualization_id STRING,
            visualization_name STRING,
            chart_type STRING,
            tml_table_name STRING,
            tml_table_id STRING,
            tml_columns_used ARRAY<STRING>,
            tml_columns_raw ARRAY<STRING>,
            databricks_table_name_ToBeFilled STRING COMMENT 'For unique datasets per viz',
            databricks_column_mapping_ToBeFilled STRING COMMENT 'For unique datasets per viz - JSON format',
            common_dataset_name STRING COMMENT 'Shared dataset name for reuse across visualizations',
            common_sql_query STRING COMMENT 'Common SQL query for the shared dataset',
            common_column_mapping STRING COMMENT 'JSON mapping of common columns for shared dataset',
            search_query STRING,
            notes STRING,
            extraction_timestamp TIMESTAMP
        ) USING DELTA
    """

    spark.sql(create_sql)
    print(f"Created mapping table: {table_name}")

def parse_tml_file(file_path):
    """Read a TML file through dbutils and parse it as YAML, else as JSON.

    At most 10 MiB of the file is requested from ``dbutils.fs.head``.
    If YAML parsing raises ``yaml.YAMLError`` the same text is handed to
    ``json.loads``; an error from that call propagates to the caller.
    """
    max_bytes = 10 * 1024 * 1024
    raw_text = dbutils.fs.head(file_path, max_bytes)
    try:
        parsed = yaml.safe_load(raw_text)
    except yaml.YAMLError:
        parsed = json.loads(raw_text)
    return parsed

# Metadata Extraction Functions
def extract_columns_from_answer(answer: Dict) -> List[str]:
    """Collect the column names referenced by one visualization answer.

    Names come first from ``answer['answer_columns'][*]['name']`` and then
    from ``answer['table']['ordered_column_ids']``; ids already collected
    from ``answer_columns`` are not added a second time.

    Keys that are present but null (for example an empty ``table:`` entry
    in the YAML) are treated the same as missing keys.

    Args:
        answer: The ``answer`` mapping of a liveboard visualization.

    Returns:
        Column names in discovery order; empty when none are found.
    """
    columns = []
    for col in answer.get('answer_columns') or []:
        col_name = col.get('name')
        if col_name:
            columns.append(col_name)

    table_cols = (answer.get('table') or {}).get('ordered_column_ids') or []
    # The filter is evaluated against the names gathered above, before extending.
    columns.extend([c for c in table_cols if c and c not in columns])
    return columns

def extract_table_info(answer: Dict) -> tuple:
    """Return ``(name, id)`` of the first entry in ``answer['tables']``.

    Both elements are empty strings when the answer lists no tables, and
    a missing ``name`` or ``id`` key on the first entry also yields ''.
    """
    tables = answer.get('tables', [])
    if not tables:
        return ('', '')
    primary = tables[0]
    return (primary.get('name', ''), primary.get('id', ''))

def _parens_balanced(text: str) -> bool:
    """Return True when every ')' in *text* closes an earlier '(' and none stay open."""
    depth = 0
    for ch in text:
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
            if depth < 0:
                return False
    return depth == 0


def clean_field_name(field_name: str) -> str:
    """Remove aggregate prefixes and aggregate/date wrappers from a field name.

    A leading word prefix such as ``Total `` or ``Unique Number of `` is
    removed once, then wrappers such as ``Sum(...)`` or ``Month(...)`` are
    peeled off repeatedly, outermost first.

    A wrapper is only peeled when its parentheses enclose the whole
    remaining expression. ``Sum(a) + Sum(b)`` is therefore left intact
    instead of being cut down to the unbalanced fragment ``a) + Sum(b``.

    Args:
        field_name: Display name of a column; may be empty or None.

    Returns:
        The cleaned name, or '' when *field_name* is empty or None.
    """
    if not field_name:
        return ""
    name = re.sub(r'^(Total |Maximum |Minimum |Average |Unique Number of )\s*', '', field_name, flags=re.IGNORECASE)
    wrapper_keywords = r'Sum|Count|Avg|Min|Max|Unique Count|Monthly|Daily|Weekly|Quarterly|Yearly|Week|Month|Quarter|Year|Day'
    wrapper_re = re.compile(r'^(' + wrapper_keywords + r')\s*\((.*)\)$', flags=re.IGNORECASE)
    while True:
        match = wrapper_re.match(name)
        # Stop when there is no wrapper, or when the first '(' is not the
        # partner of the final ')' (i.e. the name is a compound expression).
        if not match or not _parens_balanced(match.group(2)):
            break
        name = match.group(2).strip()
    return name.strip()

def extract_base_columns(columns: List[str]) -> List[str]:
    """Return the distinct cleaned column names, in first-seen order.

    Each entry is passed through ``clean_field_name``; names that clean
    down to an empty string are dropped.
    """
    cleaned_names = (clean_field_name(col) for col in columns)
    # dict.fromkeys keeps insertion order, so this is an ordered de-duplication.
    return list(dict.fromkeys(name for name in cleaned_names if name))


# Main Extraction Logic
def _failure_record(tml_file: str, error_type: str, error_message: str) -> Dict[str, Any]:
    """Build one failure-log row, stamped with the current time."""
    return {
        'tml_file': tml_file,
        'error_type': error_type,
        'error_message': error_message,
        'failure_timestamp': datetime.now()
    }


def _visualization_record(filename: str, viz: Dict) -> Dict[str, Any]:
    """Build one mapping-table row from a single liveboard visualization."""
    answer = viz.get('answer', {})
    chart = answer.get('chart', {})

    # Without an explicit chart type, fall back on the answer's display mode.
    display_mode = answer.get('display_mode', '')
    chart_type = chart.get('type', 'TABLE_MODE' if display_mode == 'TABLE_MODE' else 'UNKNOWN')

    table_name, table_id = extract_table_info(answer)
    columns_used_raw = extract_columns_from_answer(answer)
    base_columns = extract_base_columns(columns_used_raw)

    return {
        'tml_file': filename,
        'visualization_id': viz.get('id', 'unknown'),
        'visualization_name': answer.get('name', 'Unnamed'),
        'chart_type': chart_type,
        'tml_table_name': table_name,
        'tml_table_id': table_id,
        'tml_columns_used': base_columns,
        'tml_columns_raw': columns_used_raw,
        'databricks_table_name_ToBeFilled': '',
        'databricks_column_mapping_ToBeFilled': '{}',
        'common_dataset_name': None,
        'common_sql_query': None,
        'common_column_mapping': None,
        'search_query': answer.get('search_query', ''),
        'notes': f"Extracted {len(base_columns)} unique columns",
        'extraction_timestamp': datetime.now()
    }


def extract_tml_metadata():
    """Extract metadata from all TML files for mapping purposes.

    Steps:
      1. Recreate the mapping and failure log tables.
      2. List files under TML_INPUT_PATH ending in .tml, .yaml or .json.
      3. For each file, parse it and emit one mapping row per visualization;
         problems are collected as failure rows instead of aborting the run.
      4. If a non-empty notebook-level ``dashboard_queries`` DataFrame exists,
         left-join its generated ``common_sql_query`` (and
         ``common_dataset_name`` when present) onto the rows by
         (visualization_id, visualization_name).
      5. Overwrite MAPPING_TABLE with the rows and FAILURE_LOG_TABLE with the
         failures.

    The join works on a copy of ``dashboard_queries``, so the notebook-level
    DataFrame is never modified by this function.

    Returns None; results are written to the two tables and progress is printed.
    """
    print("--- Setting up mapping and failure log tables ---")
    setup_mapping_table()
    setup_failure_log_table()

    try:
        tml_files = [f.path for f in dbutils.fs.ls(TML_INPUT_PATH)
                     if f.path.endswith(('.tml', '.yaml', '.json'))]
    except Exception as e:
        print(f"ERROR: Cannot list files in '{TML_INPUT_PATH}'. Error: {e}")
        return

    if not tml_files:
        print(f"No TML files found in {TML_INPUT_PATH}")
        return

    print(f"\nFound {len(tml_files)} TML files to process.")

    metadata_records = []
    failure_records = []

    for tml_file_path in tml_files:
        filename = Path(tml_file_path).name

        try:
            print(f"\n--- Processing: {filename} ---")

            try:
                tml_data = parse_tml_file(tml_file_path)
            except Exception as parse_error:
                print(f"  ERROR: Failed to parse TML file - {parse_error}")
                failure_records.append(
                    _failure_record(filename, 'PARSE_ERROR', str(parse_error)[:1000])
                )
                continue

            liveboard = tml_data.get('liveboard')
            if not liveboard:
                print(f"  WARNING: No 'liveboard' key found in {filename}")
                failure_records.append(
                    _failure_record(filename, 'INVALID_STRUCTURE', "Missing 'liveboard' root key in TML file")
                )
                continue

            visualizations = liveboard.get('visualizations', [])

            if not visualizations:
                print(f"  WARNING: No visualizations found in {filename}")
                failure_records.append(
                    _failure_record(filename, 'NO_VISUALIZATIONS', "No visualizations found in liveboard")
                )
                continue

            print(f"  Found {len(visualizations)} visualizations")

            for viz in visualizations:
                try:
                    record = _visualization_record(filename, viz)
                    metadata_records.append(record)
                    print(f"  - {record['visualization_name']} ({record['chart_type']}): {len(record['tml_columns_used'])} columns from table '{record['tml_table_name']}'")

                except Exception as viz_error:
                    # One bad visualization must not stop the rest of the file.
                    print(f"  ERROR processing visualization '{viz.get('id', 'unknown')}': {viz_error}")
                    failure_records.append(
                        _failure_record(
                            filename,
                            'VISUALIZATION_ERROR',
                            f"Viz ID: {viz.get('id', 'unknown')} - {str(viz_error)[:900]}"
                        )
                    )

        except Exception as e:
            print(f"  ERROR processing {filename}: {e}")
            import traceback
            traceback.print_exc()
            failure_records.append(_failure_record(filename, 'PROCESSING_ERROR', str(e)[:1000]))

    # --- SAVE METADATA & MERGE WITH DASHBOARD_QUERIES ---
    if metadata_records:
        print(f"\n--- Saving {len(metadata_records)} metadata records ---")
        df = pd.DataFrame(metadata_records)

        # Merge with the generated SQL queries when an earlier cell produced them.
        if 'dashboard_queries' in globals() and isinstance(dashboard_queries, pd.DataFrame) and not dashboard_queries.empty:
            print("Merging with generated SQL queries (dashboard_queries DataFrame)...")

            # Join keys first; payload columns are appended only when present.
            cols_to_merge = ['visualization_id', 'visualization_name']

            has_sql = 'common_sql_query' in dashboard_queries.columns
            has_ds_name = 'common_dataset_name' in dashboard_queries.columns

            if has_sql:
                cols_to_merge.append('common_sql_query')
            if has_ds_name:
                cols_to_merge.append('common_dataset_name')

            if has_sql:
                try:
                    # Ensure matching types for merge keys. The casts on the
                    # generated side are applied to a copy so that the
                    # notebook-level DataFrame is left untouched.
                    df['visualization_id'] = df['visualization_id'].astype(str)
                    df['visualization_name'] = df['visualization_name'].astype(str)
                    generated = dashboard_queries[cols_to_merge].copy()
                    generated['visualization_id'] = generated['visualization_id'].astype(str)
                    generated['visualization_name'] = generated['visualization_name'].astype(str)

                    # Left join keeps every extracted row; generated columns get the '_gen' suffix.
                    merged_df = pd.merge(
                        df,
                        generated,
                        on=['visualization_id', 'visualization_name'],
                        how='left',
                        suffixes=('', '_gen')
                    )

                    # Prefer the generated value, fall back to the placeholder.
                    merged_df['common_sql_query'] = merged_df['common_sql_query_gen'].combine_first(merged_df['common_sql_query'])

                    if has_ds_name:
                        merged_df['common_dataset_name'] = merged_df['common_dataset_name_gen'].combine_first(merged_df['common_dataset_name'])
                        merged_df = merged_df.drop(columns=['common_dataset_name_gen'])

                    # Drop temporary merge column
                    merged_df = merged_df.drop(columns=['common_sql_query_gen'])

                    df = merged_df
                    print("Merge successful: common_sql_query updated.")

                except Exception as merge_e:
                    # The extracted rows are still saved, just without generated SQL.
                    print(f"WARNING: Failed to merge generated SQL queries: {merge_e}")
            else:
                print("WARNING: dashboard_queries DataFrame missing 'common_sql_query' column. Skipping merge.")
        else:
            print("No 'dashboard_queries' DataFrame found or it is empty. Skipping SQL merge.")

        df['extraction_timestamp'] = pd.to_datetime(df['extraction_timestamp'])

        schema = StructType([
            StructField("tml_file", StringType(), True),
            StructField("visualization_id", StringType(), True),
            StructField("visualization_name", StringType(), True),
            StructField("chart_type", StringType(), True),
            StructField("tml_table_name", StringType(), True),
            StructField("tml_table_id", StringType(), True),
            StructField("tml_columns_used", ArrayType(StringType()), True),
            StructField("tml_columns_raw", ArrayType(StringType()), True),
            StructField("databricks_table_name_ToBeFilled", StringType(), True),
            StructField("databricks_column_mapping_ToBeFilled", StringType(), True),
            StructField("common_dataset_name", StringType(), True),
            StructField("common_sql_query", StringType(), True),
            StructField("common_column_mapping", StringType(), True),
            StructField("search_query", StringType(), True),
            StructField("notes", StringType(), True),
            StructField("extraction_timestamp", TimestampType(), True)
        ])

        spark_df = spark.createDataFrame(df, schema=schema)
        # NOTE(review): overwriteSchema replaces the schema created by
        # setup_mapping_table(); confirm whether its column comments survive.
        spark_df.write.mode("overwrite").option("overwriteSchema", "true").saveAsTable(MAPPING_TABLE)

        print(f"Successfully saved metadata to {MAPPING_TABLE}")
    else:
        print("\nNo metadata records extracted.")

    if failure_records:
        print(f"\n--- Saving {len(failure_records)} failure records ---")
        fail_df = pd.DataFrame(failure_records)
        fail_df['failure_timestamp'] = pd.to_datetime(fail_df['failure_timestamp'])

        spark_fail_df = spark.createDataFrame(fail_df)
        spark_fail_df.write.mode("overwrite").option("overwriteSchema", "true").saveAsTable(FAILURE_LOG_TABLE)

        print(f"Failed TML files logged to {FAILURE_LOG_TABLE}")
    else:
        print("\nNo failures encountered - all TML files processed successfully!")


# ------------------------------------------------------------------------------------------------------------
# Execute Extraction
# Rebuilds the mapping and failure log tables from the TML files in TML_INPUT_PATH.
extract_tml_metadata()

# View Results
# Reading the table back is best effort: any error is printed, not raised.
try:
    df = spark.table(MAPPING_TABLE)
    # Display results showing the merged SQL query
    print("\n--- Displaying Result with Generated SQL ---")
    # Rows are ordered by tml_file even though that column is not displayed.
    display(df.select("visualization_name", "tml_columns_used", "common_sql_query").orderBy("tml_file"))
except Exception as e:
    print(f"Could not display table. Error: {e}")

### 5.3: Aggregation Expression Builder:

In [0]:
import pandas as pd
import re

# ----------------------------------------------------------------------
# 1. CONFIGURATION
# ----------------------------------------------------------------------
CATALOG = "dbx_migration_poc"
SCHEMA = "dbx_migration_ts"
# The lineage table name is defined in an earlier cell; keep a string copy here.
LINEAGE_TABLE = str(LINEAGE_TABLE_NAME)
# Per-asset output table: <catalog>.<schema>.<asset_name>_support_viz_column_details
OUTPUT_TABLE = ".".join([CATALOG, SCHEMA, f"{asset_name}_support_viz_column_details"])

# ----------------------------------------------------------------------
# 2. HELPER FUNCTIONS
# ----------------------------------------------------------------------
def sanitize_name(name):
    """Turn *name* into an underscore-separated token of ASCII letters and digits.

    Every run of characters other than a-z, A-Z and 0-9 becomes a single
    underscore, and leading/trailing underscores are removed. Falsy input
    (None or '') is returned unchanged.
    """
    if not name:
        return name
    return re.sub(r'[^a-zA-Z0-9]+', '_', name).strip('_')

def normalize_for_check(name):
    """Return the lowercase ASCII letters and digits of *name* for duplicate checks.

    Spaces and every other character are discarded; falsy input gives ''.
    """
    if not name:
        return ""
    kept_runs = re.findall(r'[a-zA-Z0-9]+', name)
    return "".join(kept_runs).lower()

def quote_identifier(name):
    """Backtick-quote *name* unless it is already a plain identifier.

    Surrounding backticks are stripped first. A name made only of letters,
    digits and underscores that does not start with a digit is returned
    bare; anything else is wrapped in backticks. Falsy input is returned
    unchanged.
    """
    if not name:
        return name
    bare = name.strip('`')
    is_plain = re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', bare) is not None
    return bare if is_plain else f"`{bare}`"

def map_aggregation_to_spark(agg):
    """Translate an aggregation label into the function name used by this notebook.

    The label is upper-cased and trimmed before lookup. Known aliases are
    folded onto one name (for example MAXIMUM -> MAX, MEAN -> AVG,
    UNIQUE COUNT -> COUNT_DISTINCT); any other label is returned in its
    upper-cased, trimmed form. Falsy or NaN input gives None.
    """
    if not agg or pd.isna(agg):
        return None

    label = str(agg).upper().strip()
    # Target name -> labels that should be rewritten to it.
    aliases_by_target = {
        "MAX": ("MAXIMUM",),
        "MIN": ("MINIMUM",),
        "AVG": ("AVERAGE", "MEAN"),
        "COUNT_DISTINCT": ("UNIQUE COUNT", "COUNT DISTINCT"),
        "STDDEV": ("STD_DEV",),
    }
    for target, aliases in aliases_by_target.items():
        if label in aliases:
            return target
    # SUM, COUNT, VARIANCE and unknown labels pass through as-is.
    return label

def infer_aggregation_from_name(col_name):
    """Guess an aggregation from a leading word prefix of *col_name*.

    The comparison is case-insensitive and the first matching prefix in
    the table below wins.

    Returns:
        ``(function, remainder)`` where *remainder* is the name with the
        prefix removed and trimmed, or ``(None, col_name)`` when no prefix
        matches.
    """
    lowered = col_name.lower()
    prefix_table = (
        ("maximum ", "MAX"), ("minimum ", "MIN"), ("average ", "AVG"), ("avg ", "AVG"),
        ("sum ", "SUM"), ("count of ", "COUNT"), ("count ", "COUNT"),
        ("unique count of ", "COUNT_DISTINCT"), ("unique count ", "COUNT_DISTINCT"),
        ("unique number of ", "COUNT_DISTINCT"),
    )
    hit = next((entry for entry in prefix_table if lowered.startswith(entry[0])), None)
    if hit is None:
        return None, col_name
    prefix, func = hit
    return func, col_name[len(prefix):].strip()

def format_sql_agg(func, column):
    """Render ``func(column)``, writing COUNT_DISTINCT as ``COUNT(DISTINCT column)``."""
    is_distinct_count = func == "COUNT_DISTINCT"
    return f"COUNT(DISTINCT {column})" if is_distinct_count else f"{func}({column})"

def get_agg_label(func):
    """Return the display label for *func*; COUNT_DISTINCT is shown as COUNT(DISTINCT)."""
    return "COUNT(DISTINCT)" if func == "COUNT_DISTINCT" else func

# ----------------------------------------------------------------------
# 3. GENERATION LOGIC
# ----------------------------------------------------------------------
def generate_column_details(df):
    """Build one detail row per distinct column of each visualization.

    For every visualization in the lineage frame, each liveboard column is
    classified as:
      * aggregated - the model aggregation if one is set, otherwise an
        aggregation inferred from the column name prefix;
      * date-truncated - names shaped like ``Month(Date)`` become
        ``DATE_TRUNC('MONTH', Date)`` with label ``TRUNC_MONTH``;
      * plain - label "NA", expression is the sanitized column itself.

    Columns whose names differ only by case, spaces or special characters
    are emitted once per visualization (first occurrence wins).

    Args:
        df: Lineage DataFrame with the columns visualization_id,
            Visualization, Liveboard_Column, Model_Base_Column and
            Model_Aggregation.

    Returns:
        DataFrame with columns VizID, VizName, ColumnName, ModelBaseColumn,
        Aggregation, Expression and Santized_Column.
    """
    lineage_cols = ['visualization_id', 'Visualization', 'Liveboard_Column', 'Model_Base_Column', 'Model_Aggregation']
    # Drop duplicates to handle lineage grain
    unique_rows = df[lineage_cols].drop_duplicates()

    date_wrapper = re.compile(r'^(Year|Month|Day|Week|Quarter)\s*\(\s*(.+?)\s*\)$', re.IGNORECASE)

    def describe(col_name, meta_agg):
        """Return (aggregation label, SQL expression) for one liveboard column."""
        inferred_agg, clean_col_name = infer_aggregation_from_name(col_name)
        spark_meta_agg = map_aggregation_to_spark(meta_agg)

        # Model metadata takes precedence over the name-based guess.
        chosen_agg = spark_meta_agg or inferred_agg
        if chosen_agg:
            target = quote_identifier(sanitize_name(clean_col_name))
            return get_agg_label(chosen_agg), format_sql_agg(chosen_agg, target)

        wrapped = date_wrapper.search(col_name)
        if wrapped:
            granularity = wrapped.group(1).upper()
            target = quote_identifier(sanitize_name(wrapped.group(2)))
            return f"TRUNC_{granularity}", f"DATE_TRUNC('{granularity}', {target})"

        return "NA", quote_identifier(sanitize_name(col_name))

    detail_rows = []
    for viz_id, viz_rows in unique_rows.groupby('visualization_id'):
        # Normalized names already emitted for this visualization.
        seen = set()

        for _, row in viz_rows.iterrows():
            col_name = row['Liveboard_Column']

            dedupe_key = normalize_for_check(col_name)
            if dedupe_key in seen:
                continue
            seen.add(dedupe_key)

            agg_val, expr_val = describe(col_name, row['Model_Aggregation'])

            detail_rows.append({
                'VizID': viz_id,
                'VizName': row['Visualization'],
                'ColumnName': col_name,
                'ModelBaseColumn': row['Model_Base_Column'],
                'Aggregation': agg_val,
                'Expression': expr_val,
                # Always the full sanitized name, e.g. Month_Date, Start_Date.
                'Santized_Column': sanitize_name(col_name)
            })

    return pd.DataFrame(detail_rows)

# ----------------------------------------------------------------------
# 4. EXECUTION
# ----------------------------------------------------------------------
# Build the per-visualization column details and persist them to OUTPUT_TABLE.
# Any failure in this cell is printed rather than raised.
try:
    # Load the lineage table as a pandas DataFrame.
    df_lineage = spark.table(LINEAGE_TABLE).toPandas()
    df_details = generate_column_details(df_lineage)
    
    # Overwrite the output table, including its schema, with the fresh details.
    spark_df = spark.createDataFrame(df_details)
    spark_df.write.mode("overwrite").option("overwriteSchema", "true").saveAsTable(OUTPUT_TABLE)
    
    print(f"✅ Successfully saved {len(df_details)} rows to {OUTPUT_TABLE}")
    display(spark_df)
    
except Exception as e:
    print(f"❌ Error: {e}")

### 5.4: Widget filters:

In [0]:
import pandas as pd
import re
import yaml

# ----------------------------------------------------------------------
# 1. CONFIGURATION
# ----------------------------------------------------------------------
CATALOG = "dbx_migration_poc"
SCHEMA = "dbx_migration_ts"
# The lineage table name is defined in an earlier cell; keep a string copy here.
LINEAGE_TABLE = str(LINEAGE_TABLE_NAME)

# Output table for this specific metadata
OUTPUT_TABLE = ".".join([CATALOG, SCHEMA, f"{asset_name}_support_viz_filter_metadata"])

# Path to the TML file (same catalog/schema as above, liveboard folder of the volume)
TML_FILE_PATH = f"/Volumes/{CATALOG}/{SCHEMA}/lv_dashfiles_ak/liveboard/{tml_file}"

print(f"Source Lineage: {LINEAGE_TABLE}")
print(f"Source TML: {TML_FILE_PATH}")

# ----------------------------------------------------------------------
# 2. HELPER FUNCTIONS
# ----------------------------------------------------------------------
def sanitize_name(name):
    """
    Joins the runs of ASCII letters/digits in *name* with single underscores.

    Spaces and special characters act as separators, so the result never
    starts or ends with an underscore. Falsy input is returned unchanged.
    """
    if not name:
        return name
    pieces = re.split(r'[^a-zA-Z0-9]', name)
    return '_'.join(piece for piece in pieces if piece)

def extract_viz_search_queries(tml_path):
    """
    Reads TML file and returns a dict: {viz_id: search_query_string}

    The file is read as UTF-8. Visualizations without an ``id`` are
    skipped, and a visualization whose ``answer`` is missing or null
    contributes an empty query instead of discarding all the others.

    Any error (missing file, invalid YAML, unexpected structure) is
    reported with a printed warning and an empty dict is returned.
    """
    try:
        # Explicit encoding: do not depend on the platform default.
        with open(tml_path, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)

        viz_map = {}
        visualizations = (data.get('liveboard') or {}).get('visualizations') or []
        for viz in visualizations:
            viz_id = viz.get('id')
            if viz_id:
                viz_map[viz_id] = (viz.get('answer') or {}).get('search_query', '')
        return viz_map
    except Exception as e:
        print(f"Warning: Could not extract filters from TML: {e}")
        return {}

def get_filters_for_column(search_query, col_name):
    """
    Parses the search query to find filters applied specifically to 'col_name'.
    Captures:
      1. Dot notation: .weekly, .true
      2. Operators: = 'Val', > 10, in {...}

    Two-character operators (>=, <=, !=, <>) are matched before their
    one-character prefixes, quoted values and {...} sets are captured
    whole even when they contain spaces, and the word operators ``in`` /
    ``not in`` must end at a word boundary.

    Returns the conditions joined by ", " (without the column name), or
    None when the column has no filter or either argument is empty.
    """
    if not search_query or not col_name:
        return None

    filters = []

    # Escape column name for regex (e.g., escape parenthesis in "Count(x)")
    # TML queries wrap columns in brackets: [Column Name]
    col_pattern = re.escape(col_name)

    # 1. Dot Notation Filters (e.g., .weekly, .true, .open)
    # Matches: [Column].value
    # Group 1 captures the value after the dot, up to the next '[' or end of query
    dot_regex = rf"\[{col_pattern}\]\.((?:'[^']*')|(?:\S.*?))(?=\s*\[|$)"
    dot_matches = re.findall(dot_regex, search_query, re.IGNORECASE)

    for val in dot_matches:
        # Strip quotes if present in the dot value (e.g. .'value')
        clean_val = val.strip("'")
        # Just append the condition, not the column name
        filters.append(f".{clean_val}")

    # 2. Standard Operators (e.g., =, !=, >, in)
    # Matches: [Column] op Value
    # Regex alternation takes the first branch that lets the match succeed,
    # so longer operators and whole quoted/braced values come first.
    op_regex = (
        rf"\[{col_pattern}\]\s*"
        r"(!=|<>|>=|<=|=|>|<|not\s+in\b|in\b)\s*"
        r"('[^']*'|\{.*?\}|'?[^'\[\]\s]+'?)"
    )
    op_matches = re.findall(op_regex, search_query, re.IGNORECASE)

    for op, val in op_matches:
        # Just append the condition, not the column name
        filters.append(f"{op} {val}")

    return ", ".join(filters) if filters else None

# ----------------------------------------------------------------------
# 3. GENERATION LOGIC
# ----------------------------------------------------------------------
def generate_column_filter_metadata(df, viz_queries):
    """Pair every visualization column with the filters found in its search query.

    Args:
        df: Lineage DataFrame with the columns visualization_id,
            Visualization and Liveboard_Column.
        viz_queries: Mapping of visualization id to its search query text;
            a visualization missing from the mapping is treated as having
            an empty query.

    Returns:
        DataFrame with one row per distinct (visualization, column) and the
        columns viz_id, viz_name, Sanitized_Column and Filter_Details.
    """
    # Get unique columns per visualization from lineage
    unique_cols = df[['visualization_id', 'Visualization', 'Liveboard_Column']].drop_duplicates()

    records = []
    for (viz_id, viz_name), viz_rows in unique_cols.groupby(['visualization_id', 'Visualization']):
        # The full search query is shared by every column of the visualization.
        query_text = viz_queries.get(viz_id, "")

        records.extend(
            {
                'viz_id': viz_id,
                'viz_name': viz_name,
                'Sanitized_Column': sanitize_name(raw_col),
                'Filter_Details': get_filters_for_column(query_text, raw_col)
            }
            for raw_col in viz_rows['Liveboard_Column']
        )

    return pd.DataFrame(records)

# ----------------------------------------------------------------------
# 4. EXECUTION
# ----------------------------------------------------------------------
# 1. Load Lineage
df_lineage = spark.table(LINEAGE_TABLE).toPandas()
print(f"Loaded {len(df_lineage)} lineage rows.")

# 2. Load TML Search Queries (empty dict when the TML file cannot be read)
viz_queries = extract_viz_search_queries(TML_FILE_PATH)
print(f"Loaded queries for {len(viz_queries)} visualizations.")

# 3. Generate Metadata
df_result = generate_column_filter_metadata(df_lineage, viz_queries)

# 4. Save and Display - nothing is written when there are no rows
if df_result.empty:
    print("No data generated.")
else:
    spark_df = spark.createDataFrame(df_result)
    spark_df.write.mode("overwrite").option("overwriteSchema", "true").saveAsTable(OUTPUT_TABLE)
    print(f"\n✅ Successfully saved {len(df_result)} rows to {OUTPUT_TABLE}")
    display(spark_df)

In [0]:
# Hand the generated query and the file path back to the caller as a JSON string.
# Newlines become spaces rather than being deleted, so tokens that sat on
# adjacent lines (e.g. "col\nFROM") are not fused together.
dbutils.notebook.exit(json.dumps({
    "Query": final_sql_query.replace("\n", " "),
    "Filepath": full_file_path
}))