In [0]:
%run ./helpers

In [0]:
import json
import yaml
from datacontract.data_contract import DataContract

In [0]:
def create_tables(yaml_file_path):
    """
    Create the tables described by a data contract by running its exported SQL DDL.

    The contract is exported with ``DataContract.export("sql")``, the result is
    split into individual statements on ``;`` and each statement is executed
    through ``spark.sql``. When a statement fails with a ``[SCHEMA_NOT_FOUND]``
    error, the schema part of the target table name is created with
    ``CREATE SCHEMA IF NOT EXISTS`` and the statement is retried once.

    Any other failure is printed and the loop moves on to the next statement
    (best effort); errors raised while creating the schema or during the retry
    propagate to the caller.

    NOTE(review): relies on a module-level ``spark`` session, presumably
    provided by the notebook runtime or by ``./helpers`` -- confirm.

    Args:
        yaml_file_path (str): Path to the data contract YAML file.

    Returns:
        None
    """
    import re

    data_contract = DataContract(data_contract_file=yaml_file_path, spark=spark)
    exported_sql = data_contract.export("sql")

    # Split on ';' and drop blank fragments instead of blindly chopping the
    # last character, so a missing or trailing terminator cannot corrupt or
    # add an empty statement.
    queries_ddl_list = [
        statement.strip() for statement in exported_sql.split(";") if statement.strip()
    ]
    print(queries_ddl_list)

    for query in queries_ddl_list:
        try:
            spark.sql(query)
        except Exception as e:
            print(f"{e}\n")
            if "[SCHEMA_NOT_FOUND]" not in str(e):
                # Do not report a failed statement as completed.
                print(f"FAILED RUNNING DDL QUERY:\n{query}\n")
                continue

            # Pull the (possibly qualified) table name that follows TABLE,
            # e.g. "CREATE OR REPLACE TABLE catalog.schema.table (".
            target = re.search(
                r"\bTABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?([^\s(]+)", query, re.IGNORECASE
            )
            name_parts = target.group(1).split(".") if target else []
            # Everything but the last part is the schema, optionally prefixed
            # by the catalog.
            schema_name = ".".join(name_parts[:-1])
            if not schema_name:
                print(f"FAILED RUNNING DDL QUERY (could not determine schema):\n{query}\n")
                continue

            spark.sql(f"CREATE SCHEMA IF NOT EXISTS {schema_name}")
            print(f"created schema {schema_name} successfully\n")
            spark.sql(query)
        print(f"COMPLETED RUNNING DDL QUERY:\n{query}\n")

In [0]:
def get_general_data_quality_rules(table_name, columns=None):
    """
    Build the generic SQL data quality rules applied to every table.

    The result always contains a row-count rule (the table must hold at least
    one row). When ``columns`` yields a non-empty column list, a second rule is
    added that counts rows sharing the same values in those columns and
    requires that count to be zero.

    Args:
        table_name (str): Name of the table the rules are written against.
        columns (list, optional): Column names used in the PARTITION BY of the
            duplicate check. If None or empty, the duplicate rule is skipped.

    Returns:
        list: Rule dictionaries whose ``query`` values are single-line SQL strings.
    """
    dedup_key = ", ".join(columns) if columns else ""

    row_count_rule = {
        "type": "sql",
        "description": f"Ensures '{table_name}' table has data",
        "query": f"SELECT COUNT(*) FROM {table_name}",
        "mustBeGreaterThan": 0,
    }
    rules = [row_count_rule]

    if dedup_key:
        duplicate_sql = f"""
            SELECT COUNT(*)
            FROM (
                SELECT *, COUNT(*) OVER (PARTITION BY {dedup_key}) AS row_count
                FROM {table_name}
            ) AS subquery
            WHERE row_count > 1
        """
        duplicate_rule = {
            "type": "sql",
            "description": f"Ensure '{table_name}' table has no duplicate rows across all columns",
            "query": duplicate_sql,
            "mustBe": 0,
        }
        rules.append(duplicate_rule)

    # Collapse every run of whitespace so each query is a one-line string.
    for rule in rules:
        rule["query"] = " ".join(rule["query"].split())
    return rules

In [0]:
def get_custom_data_quality_rules(table_name, custom_dq_rules_input):
    """
    Look up the custom data quality rules supplied for one table.

    Every entry of ``custom_dq_rules_input`` is merged into a single mapping
    keyed by table name; when the same table appears in several entries, the
    last one wins. The ``"quality"`` value of the requested table is returned.

    Args:
        table_name (str): The name of the table whose custom rules are wanted.
        custom_dq_rules_input (list): Mappings of table name to a dictionary
            holding a ``"quality"`` entry.

    Returns:
        The ``"quality"`` value stored for ``table_name``.

    Raises:
        KeyError: If no custom rules are defined for the specified table.
    """
    rules_by_table = {}
    for rule_group in custom_dq_rules_input:
        rules_by_table.update(rule_group)

    if table_name in rules_by_table:
        return rules_by_table[table_name]["quality"]
    raise KeyError(f"No custom data quality rules defined for table: {table_name}")

In [0]:
# Function to extract properties from a table
def get_table_properties(table):
    """
    Convert the fields of a table definition into a list of property dictionaries.

    Each field is copied key by key with these adjustments:
    - ``type`` is emitted as both ``logicalType`` and ``physicalType``;
    - ``tags`` becomes a list of ``"key:value"`` strings (empty when falsy);
    - ``unique`` and ``required`` are copied unchanged;
    - any other key keeps its value, with falsy values replaced by ``""``.

    Args:
        table (dict): The table definition; its ``"fields"`` entry is read.

    Returns:
        list: One property dictionary per field, in field order.
    """
    converted = []
    for field in table.get("fields", []):
        entry = {}
        for name, raw in field.items():
            if name == "type":
                # The same source type feeds both ODCS type attributes.
                entry["logicalType"] = raw
                entry["physicalType"] = raw
            elif name == "tags":
                entry["tags"] = [f"{k}:{v}" for k, v in raw.items()] if raw else []
            elif name in ("unique", "required"):
                # Flags are kept verbatim so False is not turned into "".
                entry[name] = raw
            else:
                entry[name] = raw or ""
        converted.append(entry)
    return converted

In [0]:
def get_table_data_quality_rules(table):
    """
    Return the data quality rules declared on a table definition.

    Args:
        table (dict): The table definition, optionally holding a ``"quality"`` entry.

    Returns:
        The value stored under ``"quality"``, or an empty list when the key is absent.
    """
    # Only a missing key falls back to []; a stored value is returned as-is.
    return table["quality"] if "quality" in table else []

In [0]:
def build_odcs_schema(schema, odcs_schema):
    """
    Append one ODCS table object per table of ``schema`` to ``odcs_schema``.

    For every table the properties come from ``get_table_properties`` and the
    quality rules are the generic rules from ``get_general_data_quality_rules``
    followed by the rules declared on the table itself.

    NOTE(review): the generic rules are requested without a column list, so
    only the row-count rule is produced here -- confirm this is intended.

    Args:
        schema (dict): The input schema; its ``"tables"`` entry is read.
        odcs_schema (list): Output list, extended in place.

    Returns:
        None
    """
    for table_def in schema.get("tables", []):
        name = table_def.get("name", "")
        tag_map = table_def.get("tags", {})

        columns = get_table_properties(table_def)
        declared_rules = get_table_data_quality_rules(table_def)
        baseline_rules = get_general_data_quality_rules(name)

        # Tags are flattened to "key:value" strings; a missing or empty tag
        # mapping yields an empty list.
        if tag_map:
            tag_list = [f"{k}:{v}" for k, v in tag_map.items()]
        else:
            tag_list = []

        odcs_schema.append(
            {
                "logicalType": table_def.get("logicalType", "object"),
                "name": name,
                # The physical name falls back to the logical table name.
                "physicalName": table_def.get("physicalName", name),
                "physicalType": table_def.get("physicalType", "object"),
                "tags": tag_list,
                "properties": columns,
                # Generic rules first, then the table's own rules.
                "quality": baseline_rules + declared_rules,
            }
        )

In [0]:
def convert_json_to_odcs(json_file, odcs_template_file, host, environment, type="databricks"):
    """
    Converts a JSON schema description to ODCS (Open Data Contract Standard) format.

    The JSON file is expected to hold a ``"models"`` list. For every model a
    server entry is created from its catalog and schema name, its schema tags
    are collected as ``"key:value"`` strings and its tables are converted with
    ``build_odcs_schema``. The results are written into the ``schema``,
    ``tags`` and ``servers`` keys of the loaded template.

    Args:
        json_file (str): The path to the JSON file containing the schema.
        odcs_template_file (str): The path to the ODCS template YAML file.
        host (str): The host URL.
        environment (str): The environment name, stored as each server's ``server`` value.
        type (str): Server type stored on each server entry. The name shadows
            the builtin but is kept so keyword callers keep working.

    Returns:
        dict: The ODCS template updated with the built schema, tags and servers.
    """
    # Read both files as UTF-8 explicitly so the result does not depend on the
    # platform's locale encoding.
    with open(json_file, "r", encoding="utf-8") as file:
        schema_json = json.load(file)

    with open(odcs_template_file, "r", encoding="utf-8") as file:
        # An empty template file loads as None; fall back to an empty mapping
        # so the assignments below still work.
        odcs_template = yaml.safe_load(file) or {}

    servers = []
    odcs_schema = []
    stags = []

    for model in schema_json.get("models", []):
        # Treat a missing or null schema as empty.
        schema = model.get("schema") or {}
        schema_tags = schema.get("tags")

        # Schema tags of all models are gathered into one flat list.
        if schema_tags:
            stags.extend(f"{k}:{v}" for k, v in schema_tags.items())

        servers.append(
            {
                "catalog": model.get("catalog", ""),
                "host": host,
                "type": type,
                "schema": schema.get("name", ""),
                "server": environment,
            }
        )

        # Appends this model's tables to odcs_schema in place.
        build_odcs_schema(schema, odcs_schema)

    odcs_template["schema"] = odcs_schema
    odcs_template["tags"] = stags
    odcs_template["servers"] = servers

    return odcs_template