# Imports

In [None]:
import io, logging, json
import pytz
import pandas as pd
from datetime import datetime, timedelta, timezone
from google.cloud import secretmanager
from google.cloud import storage
import snowflake.connector as sc
from snowflake.connector.pandas_tools import write_pandas
from cryptography.hazmat.primitives import serialization

# Function: Fetch Secrets

In [None]:
def fetch_secrets(
        project_id,
        secret_id,
        version_id="latest"
    ):
    """
    Access a secret from Google Secret Manager and parse it as JSON.

    Args:
        project_id: Your Google Cloud project ID
        secret_id: The ID of the secret to access
        version_id: The version of the secret (default: "latest")

    Returns:
        The secret payload decoded as UTF-8 and parsed with ``json.loads``.

    Raises:
        ValueError: If the secret payload is not valid JSON. The original
            ``json.JSONDecodeError`` is attached as the cause.
    """
    # Create the Secret Manager client
    client = secretmanager.SecretManagerServiceClient()

    # Build the resource name of the secret version
    name = f"projects/{project_id}/secrets/{secret_id}/versions/{version_id}"

    # Access the secret version
    response = client.access_secret_version(request={"name": name})

    # Decode the raw bytes before parsing
    secret_payload = response.payload.data.decode("UTF-8")

    try:
        return json.loads(secret_payload)
    except json.JSONDecodeError as err:
        # Chain the cause so the position of the parse failure is not lost.
        raise ValueError("The secret payload is not a valid JSON") from err

# Util Functions

### Function: Setup Logger

In [None]:
def setup_logger(
    log_file
):
    """
    Sets up a logger that writes to a log file and to the console.

    The logger is looked up by name (``log_file``), so calling this function
    again with the same path returns the same logger object. Handlers from a
    previous call are removed AND closed before new ones are attached, which
    prevents both duplicate log lines and leaked file handles.

    Args:
        log_file (str): Path of the log file.

    Returns:
        logger: Configured logger instance, or None if setup failed
        (the failure reason is printed).
    """
    try:
        logger = logging.getLogger(log_file)
        logger.setLevel(logging.INFO)
        logger.propagate = False  # Prevent duplicate logs via ancestor loggers

        # Detach and close handlers left over from an earlier call. Iterate
        # over a copy because removeHandler mutates logger.handlers.
        for stale_handler in list(logger.handlers):
            logger.removeHandler(stale_handler)
            stale_handler.close()

        formatter = logging.Formatter(
            '%(asctime)s [%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S'
        )

        # File Handler
        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

        # Console Handler
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

        return logger

    except Exception as e:
        # Best effort: callers receive None and must handle it.
        print(f"Failed to initialize logger: {e}")
        return None

### Function: Handle Exceptions

In [None]:
def handle_exception(
    file_id,
    vai_gcs_bucket,
    run_folder,
    error_folder,
    error_message,
    logger
):
    """
    Logs the error and appends the file_id to the error tracking CSV in GCS.

    The tracking file lives at ``{error_folder}/{run_folder}_errors.csv`` in
    ``vai_gcs_bucket``. It is read if it already exists, a row is appended,
    and the whole file is written back.

    This function never raises on tracking-file failures: they are logged
    and swallowed so that error reporting cannot mask the original error.

    Args:
        file_id: Identifier recorded in the ``File_ID`` column.
        vai_gcs_bucket: Name of the GCS bucket holding the tracking file.
        run_folder: Used as the prefix of the tracking file name.
        error_folder: Folder (object prefix) containing the tracking file.
        error_message: Text recorded in the ``Error_Message`` column.
        logger: Logger to report through. If None, a standard
            ``logging`` logger is used so reporting still works.
    """
    # A missing logger must not make the error handler itself crash.
    if logger is None:
        logger = logging.getLogger(__name__)

    try:
        error_df_path = f"{error_folder}/{run_folder}_errors.csv"

        logger.error(f"Error processing file {file_id}: {error_message}")

        gcs_client = storage.Client()
        bucket = gcs_client.bucket(vai_gcs_bucket)
        blob = bucket.blob(error_df_path)

        if blob.exists():
            error_df = pd.read_csv(f"gs://{vai_gcs_bucket}/{error_df_path}")
        else:
            error_df = pd.DataFrame(columns=["File_ID", "Error_Message"])

        new_row = pd.DataFrame([{"File_ID": file_id, "Error_Message": error_message}])
        error_df = pd.concat([error_df, new_row], ignore_index=True)
        error_df.to_csv(f"gs://{vai_gcs_bucket}/{error_df_path}", index=False)
        logger.info(f"Logged error for file {file_id} in {error_df_path}")

    except Exception as e:
        logger.error(f"Failed to write to error tracking file: {e}")
# Functions

### Function: Read CSV from GCS Bucket

In [None]:
def read_gcs_csv(file_path):
    """
    Download a CSV object from GCS and parse it into a DataFrame.

    Note: this reads the module-level ``bucket`` object, which must be
    assigned before the function is called.

    Args:
        file_path (str): Object path of the CSV inside the bucket.

    Returns:
        pd.DataFrame: Parsed contents of the CSV.
    """
    text_buffer = io.StringIO(bucket.blob(file_path).download_as_text())
    return pd.read_csv(text_buffer)

### Function: Insert New Records

In [None]:
def insert_new_records(
    pipeline_run_name,
    vai_gcs_bucket,
    gcs_stagging_folder,
    gcs_errored_folder,
    snf_account,
    snf_user,
    snf_private_key,
    snf_private_key_pwd,
    snf_warehouse,
    snf_databse,
    snf_schema,
    table_name,
    df
):
    """
    Inserts only new records (based on CONTACT_ID) into a Snowflake table.

    Steps:
    1. Loads the PEM private key and connects to Snowflake with key-pair auth.
    2. Fetches CONTACT_IDs already in the table with LOAD_DATE on or after
       yesterday.
    3. Filters rows with those IDs out of the DataFrame.
    4. Adds a 'LOAD_DATE' column holding the current UTC timestamp.
    5. Inserts the remaining rows with ``write_pandas``.

    The cursor and connection are always closed, including on the early
    "nothing to insert" return and when an error occurs.

    Relies on a module-level ``logger`` for progress messages.

    Args:
        pipeline_run_name (str): Run name; used to locate the error tracking file.
        vai_gcs_bucket (str): GCS bucket that holds the error tracking file.
        gcs_stagging_folder (str): Currently unused.
        gcs_errored_folder (str): Currently unused (the error folder is
            derived from ``pipeline_run_name``).
        snf_account (str): Snowflake account.
        snf_user (str): Snowflake user.
        snf_private_key (str): PEM-encoded private key text.
        snf_private_key_pwd (str | None): Passphrase of the private key;
            None or empty for an unencrypted key.
        snf_warehouse (str): Snowflake warehouse.
        snf_databse (str): Snowflake database.
        snf_schema (str): Snowflake schema.
        table_name (str): Name of the target table.
        df (pd.DataFrame): Data to insert (must have a 'CONTACT_ID' column).

    Returns:
        int | None: Number of inserted rows (0 when nothing was new), or
        None when an error occurred and was passed to ``handle_exception``.
    """
    try:
        # Load & decrypt the private key. An unencrypted key must be loaded
        # with password=None, so do not call .encode() on a missing passphrase.
        key_password = snf_private_key_pwd.encode() if snf_private_key_pwd else None
        private_key_obj = serialization.load_pem_private_key(
            snf_private_key.encode(),
            password=key_password,
            backend=None  # Default backend
        )

        # Convert to the unencrypted DER/PKCS8 bytes used for key-pair auth
        pkey_bytes = private_key_obj.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )

        conn_params = {
            'account': snf_account,
            'user': snf_user,
            'private_key': pkey_bytes,  # DER bytes, not the key object
            'warehouse': snf_warehouse,
            'database': snf_databse,
            'schema': snf_schema
        }

        conn = sc.connect(**conn_params)
        try:
            cursor = conn.cursor()
            try:
                # Get IDs already loaded yesterday or today
                cursor.execute(f"""
                    SELECT DISTINCT(CONTACT_ID)
                    FROM {table_name}
                    WHERE LOAD_DATE >= DATEADD(DAY, -1, CURRENT_DATE)
                """)
                existing_ids = {row[0] for row in cursor.fetchall()}
            finally:
                cursor.close()

            # Keep only records whose ID is not already present.
            # .copy() so adding a column does not touch the caller's df.
            new_records_df = df[~df['CONTACT_ID'].isin(existing_ids)].copy()

            if new_records_df.empty:
                logger.info("No new records to insert")
                return 0

            # Stamp every new row with the current UTC time
            utc_now = datetime.now(pytz.utc).strftime('%Y-%m-%d %H:%M:%S')
            new_records_df["LOAD_DATE"] = utc_now

            # Insert new records into Snowflake
            _success, _nchunks, nrows, _ = write_pandas(conn, new_records_df, table_name)

            logger.info(f"Inserted {nrows} new records with UTC load date")
            logger.info(f"Skipped {len(df) - len(new_records_df)} existing records")
            return nrows
        finally:
            # Runs on success, early return and error alike.
            conn.close()

    except Exception as e:
        handle_exception("N/A", vai_gcs_bucket, pipeline_run_name, f"{pipeline_run_name}/Errored", str(e), logger)
        return None

# Main Function

In [None]:
# Pre-bind the names used by the except handler so that a failure early in
# the try block (e.g. while fetching secrets) is reported instead of raising
# a NameError from inside the handler.
vai_gcs_bucket = None
logger = None

try:
    # project_id, secret_id, version_id and pipeline_run_name are expected to
    # be defined before this section runs (they are not set in this section).
    configs = fetch_secrets(
        project_id,
        secret_id,
        version_id
    )

    # GCP Configuration
    gcp_project_id = configs.get("VAI_GCP_PROJECT_ID")
    gcp_project_location = configs.get("GCP_PROJECT_LOCATION")
    vai_gcs_bucket = configs.get("VAI_GCP_PIPELINE_BUCKET")

    # Pipeline Configuration
    gcs_stagging_folder = f"{pipeline_run_name}/Stagging"
    gcs_errored_folder = f"{pipeline_run_name}/Errored"
    gcs_logs_folder = f"{pipeline_run_name}/Logs"
    gcs_transcripts_folder = f"{pipeline_run_name}/Transcripts"
    gcs_intra_call_dfs_folder = f"{pipeline_run_name}/Stagging/IntraCallDFs"
    gcs_inter_call_dfs_folder = f"{pipeline_run_name}/Stagging/InterCallDFs"

    # Snowflake Configuration
    snf_account = configs.get("VAI_SNF_ACCOUNT")
    snf_user = configs.get("VAI_SNF_USER")
    snf_private_key = configs.get("private_key")
    snf_private_key_pwd = configs.get("VAI_SNF_PRIVATE_KEY_PWD")
    snf_warehouse = configs.get("VAI_SNF_WAREHOUSE")
    snf_database = configs.get("VAI_SNF_DATABASE")
    snf_schema = configs.get("VAI_SNF_SCHEMA")

    # Download Master Log File from GCS so new log lines are appended to it
    log_file = f"{pipeline_run_name}.logs"
    client = storage.Client()
    bucket = client.bucket(vai_gcs_bucket)
    blob = bucket.blob(f"{gcs_logs_folder}/{log_file}")
    blob.download_to_filename(log_file)

    logger = setup_logger(log_file)
    if logger is None:
        # setup_logger reports failure by returning None; stop here with a
        # clear message instead of failing on the first logger call.
        raise RuntimeError("Failed to initialize logger")

    logger.info("")
    logger.info("")
    logger.info("============================================================================")
    logger.info("COMPONENT: Write Data to Snowflake.")
    logger.info("============================================================================")
    logger.info("Fetched Master Log File from GCS bucket.")

    # Read Inter & Intra Call DataFrames
    inter_call_df = read_gcs_csv(f"{gcs_stagging_folder}/master_inter_call_df.csv")
    inter_call_df.columns = inter_call_df.columns.str.upper() # For snowflake Schema matching
    intra_call_df = read_gcs_csv(f"{gcs_stagging_folder}/master_intra_call_df.csv")
    intra_call_df.columns = intra_call_df.columns.str.upper() # For snowflake Schema matching

    logger.info("Started: writing data to snowflake.")

    # Same insert procedure for both tables. table_name is bound before the
    # "Writing data" message so the log always names the table being written.
    for table_name, call_df in (
        ('SRC_GCP_INTER_CALLS', inter_call_df),
        ('SRC_GCP_INTRA_CALLS', intra_call_df),
    ):
        logger.info(f"Writing data to table: {snf_database}.{table_name}")
        inserted = insert_new_records(
            pipeline_run_name,
            vai_gcs_bucket,
            gcs_stagging_folder,
            gcs_errored_folder,
            snf_account,
            snf_user,
            snf_private_key,
            snf_private_key_pwd,
            snf_warehouse,
            snf_database,
            snf_schema,
            table_name,
            call_df
        )
        if inserted is None:
            # insert_new_records already recorded the underlying error.
            logger.error(f"{table_name}: Insert failed, no records were written")
        else:
            # Report what was actually inserted, not the input size.
            logger.info(f"{table_name}: Inserted records #{inserted}")

    logger.info("Completed: writing data to snowflake.")

except Exception as e:
    handle_exception(
        "N/A",
        vai_gcs_bucket,
        pipeline_run_name,
        f"{pipeline_run_name}/Errored",
        str(e),
        logger if logger is not None else logging.getLogger(__name__),
    )