# Imports

In [1]:
from google.cloud import secretmanager
from datetime import datetime, timezone, UTC
import kfp
from kfp import dsl, compiler, components
import json, logging
from google.cloud import aiplatform
from google.cloud.aiplatform import pipeline_jobs
from google.cloud import logging as cloud_logging

import warnings
warnings.filterwarnings("ignore", message="Skipping checksum validation")

# Component: Listing new Transcripts

In [2]:
@dsl.component(
    base_image=f"us-central1-docker.pkg.dev/dev-posigen/dev-voiceai/dev-voice-ai-docker-image:dev-4"
)
def list_download_calls_s3_to_gcs(
    pipeline_run_name: str,
    project_id: str,
    secret_id: str,
    version_id: str
):
    import boto3
    import pandas as pd
    import logging, json
    from google.cloud import secretmanager
    from google.cloud import storage
    from google.cloud import logging as cloud_logging
    from datetime import datetime, timedelta, timezone, UTC
    from concurrent.futures import ThreadPoolExecutor, as_completed

    import warnings
    warnings.filterwarnings("ignore", message="Skipping checksum validation")
    warnings.filterwarnings("ignore", category=UserWarning)

    """
    ========================================================
    Function Definitions
    ========================================================
    """
    def fetch_secrets(project_id, secret_id, version_id="latest"):
        """
        Read one secret version from Google Secret Manager and parse it as JSON.

        Args:
            project_id: Google Cloud project that owns the secret.
            secret_id: Identifier of the secret to read.
            version_id: Version of the secret (default: "latest").

        Returns:
            The object produced by ``json.loads`` on the UTF-8 decoded payload.

        Raises:
            ValueError: If the payload cannot be parsed as JSON.
        """
        sm_client = secretmanager.SecretManagerServiceClient()

        # Fully qualified resource name of the requested secret version.
        secret_version_path = (
            f"projects/{project_id}/secrets/{secret_id}/versions/{version_id}"
        )
        secret_version = sm_client.access_secret_version(
            request={"name": secret_version_path}
        )
        raw_text = secret_version.payload.data.decode("UTF-8")

        try:
            parsed_payload = json.loads(raw_text)
        except json.JSONDecodeError:
            raise ValueError("The secret payload is not a valid JSON")
        return parsed_payload

    def setup_logger(log_file):
        """
        Build (or reuse) the shared "vertex_pipeline_logger" logger.

        On first use a file handler writing to ``log_file`` and a console
        (stream) handler are attached, both at INFO level with the same
        format. Later calls return the same logger without adding handlers.

        Args:
            log_file (str): Path of the log file.

        Returns:
            The configured logger, or None if configuration raised.
        """
        try:
            pipeline_logger = logging.getLogger("vertex_pipeline_logger")
            pipeline_logger.setLevel(logging.INFO)
            # Do not hand records to ancestor loggers, so each line is emitted once.
            pipeline_logger.propagate = False

            if pipeline_logger.handlers:
                # Handlers were attached by an earlier call; nothing to add.
                return pipeline_logger

            log_format = logging.Formatter(
                '%(asctime)s [%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S'
            )
            # File handler first, then console handler, each built right before it is attached.
            handler_factories = (
                lambda: logging.FileHandler(log_file),
                logging.StreamHandler,
            )
            for make_handler in handler_factories:
                handler = make_handler()
                handler.setLevel(logging.INFO)
                handler.setFormatter(log_format)
                pipeline_logger.addHandler(handler)

            return pipeline_logger

        except Exception as e:
            print(f"Failed to initialize logger: {e}")
            return None

    def handle_exception(
        file_id,
        vai_gcs_bucket,
        run_folder,
        error_folder,
        error_message
    ):
        """
        Log a processing failure and append it to the run's error CSV on GCS.

        The CSV lives at ``{error_folder}/{run_folder}_errors.csv`` inside
        ``vai_gcs_bucket``. It is read when it already exists, one
        (File_ID, Error_Message) row is appended, and the whole file is
        written back. Any failure while doing so is logged and swallowed, so
        this function itself does not raise.
        """
        try:
            error_df_path = f"{error_folder}/{run_folder}_errors.csv"
            logger.error(f"Error processing file {file_id}: {error_message}")

            csv_blob = storage.Client().bucket(vai_gcs_bucket).blob(error_df_path)

            # Start from the existing error table when there is one, otherwise from an empty table.
            if not csv_blob.exists():
                known_errors = pd.DataFrame(columns=["File_ID", "Error_Message"])
            else:
                known_errors = pd.read_csv(f"gs://{vai_gcs_bucket}/{error_df_path}")

            new_row = pd.DataFrame([{"File_ID": file_id, "Error_Message": error_message}])
            updated_errors = pd.concat([known_errors, new_row], ignore_index=True)
            updated_errors.to_csv(f"gs://{vai_gcs_bucket}/{error_df_path}", index=False)
            logger.info(f"Logged error for file {file_id} in {error_df_path}")

        except Exception as e:
            logger.error(f"Failed to write to error tracking file: {e}")


    def generate_gcs_folders(
        pipeline_run_name,
        vai_gcs_bucket
    ):
        """
        Derive the per-run GCS folder layout and create a placeholder object for each folder.

        Args:
            pipeline_run_name: Run identifier; used as the root folder name.
            vai_gcs_bucket: Name of the GCS bucket that holds the run folders.

        Returns:
            dict mapping logical folder keys (e.g. 'gcs_staging_folder') to
            bucket-relative paths, or None if creation failed (the failure is
            recorded through handle_exception).
        """
        try:
            # Use the configured pipeline logger: the root `logging` module has no
            # INFO-level handler here, so messages sent to it were silently dropped.
            logger.info("Started: generating GCS pipeline folders.")

            # NOTE: the "Stagging" spelling is part of the stored paths; it is kept as is.
            gcs_folders = {
                'gcs_staging_folder': f"{pipeline_run_name}/Stagging",
                'gcs_intra_call_dfs_folder': f"{pipeline_run_name}/Stagging/IntraCallDFs",
                'gcs_inter_call_dfs_folder': f"{pipeline_run_name}/Stagging/InterCallDFs",
                'gcs_transcripts_folder': f"{pipeline_run_name}/Transcripts",
                'gcs_errored_folder': f"{pipeline_run_name}/Errored",
                'gcs_logs_folder': f"{pipeline_run_name}/Logs",
            }

            bucket = storage.Client().bucket(vai_gcs_bucket)

            # A zero-byte object whose name ends in "/" acts as the folder placeholder.
            for folder in gcs_folders.values():
                blob = bucket.blob(f"{folder}/")
                blob.upload_from_string("", content_type="application/x-www-form-urlencoded")
                logger.info(f"Created folder: {folder}")

            logger.info("Completed: generating GCS pipeline folders.")
            return gcs_folders

        except Exception as e:
            handle_exception("N/A", vai_gcs_bucket, pipeline_run_name, f"{pipeline_run_name}/Errored", str(e))
            return None


    def generate_s3_folder_prefix(
        pipeline_run_name,
        vai_gcs_bucket,
        gcs_errored_folder
    ):
        """
        Build the "YYYY/MM/DD/" prefix of the S3 day folder to list.

        When the current hour is 0 the previous day's folder is used;
        otherwise the current day's folder is used.

        Args:
            pipeline_run_name: Run identifier, used only for error reporting.
            vai_gcs_bucket: Bucket that receives the error CSV on failure.
            gcs_errored_folder: Folder of the error CSV on failure.

        Returns:
            str: The prefix, or None if an exception was handled.
        """
        try:
            logger.info("Started: generating S3 folder prefix.")

            # NOTE(review): this is naive local time, while the listing step uses UTC —
            # confirm the container clock is UTC.
            current_datetime = datetime.now()

            # Runs during the first hour of the day pick up the previous day's folder.
            if current_datetime.hour == 0:
                adjusted_datetime = current_datetime - timedelta(days=1)
            else:
                adjusted_datetime = current_datetime

            prefix = (
                f"{adjusted_datetime.year}/"
                f"{adjusted_datetime.month:02d}/"
                f"{adjusted_datetime.day:02d}/"
            )
            # The original message lacked the f-prefix and logged the literal "{prefix}".
            logger.info(f"Completed: generating S3 folder prefix {prefix}.")

            return prefix

        except Exception as e:
            handle_exception("N/A", vai_gcs_bucket, pipeline_run_name, gcs_errored_folder, str(e))
            return None


    def get_list_calls_to_process(
        pipeline_run_name,
        vai_gcs_bucket,
        gcs_staging_folder,
        gcs_errored_folder,
        aws_access_key,
        aws_secret_key,
        s3_analysis_bucket,
        s3_transcripts_location,
        s3_prefix,
    ):
        """
        List transcript JSON files under the S3 day prefix and select the batch to process.

        File names are expected to look like ``<call_id>_analysis_<timestamp>.json``.
        Files that pass the time-of-day filter are grouped into 2-hour bins and only
        the most recent bin is kept. The selection is also written as a CSV to the
        run's staging folder on GCS.

        Args:
            pipeline_run_name: Run identifier (used in the CSV name and error reporting).
            vai_gcs_bucket: GCS bucket for the CSV and the error file.
            gcs_staging_folder: Bucket-relative folder that receives the CSV.
            gcs_errored_folder: Bucket-relative folder of the error CSV.
            aws_access_key / aws_secret_key: Credentials for the S3 client.
            s3_analysis_bucket: S3 bucket to list.
            s3_transcripts_location: Key prefix in front of the day prefix.
            s3_prefix: Day prefix ("YYYY/MM/DD/").

        Returns:
            pandas.DataFrame with the selected files (an empty frame when nothing
            matched), or None if listing failed and the error was handled.
        """
        try:
            logger.info(f"Started: listing calls from: {s3_transcripts_location}/{s3_prefix}")
            s3_client = boto3.client(
                's3',
                aws_access_key_id=aws_access_key,
                aws_secret_access_key=aws_secret_key
            )

            paginator = s3_client.get_paginator('list_objects_v2')
            pages = paginator.paginate(Bucket=s3_analysis_bucket, Prefix=f"{s3_transcripts_location}/{s3_prefix}")

            # Current UTC time (timezone-aware) and the look-back threshold.
            # `time_interval` (hours) is not a parameter; it is read from the enclosing scope.
            current_time = datetime.now(timezone.utc)
            time_threshold = current_time - timedelta(hours=time_interval)
            logger.info(f"Fetching Calls between: {time_threshold.time()} and {current_time.time()}")

            all_files = []

            for page in pages:
                for obj in page.get('Contents', []):
                    file_path = obj['Key']
                    s3_ts = obj['LastModified']

                    # Only JSON transcripts are considered.
                    if not file_path.endswith('.json'):
                        continue

                    try:
                        # Call id and timestamp are both encoded in the file name.
                        call_id = file_path.split('/')[-1].split("_analysis_")[0]
                        call_timestamp = pd.to_datetime(file_path.split('analysis_')[-1].split('.')[0].replace('Z', ""), utc=True)

                        # NOTE(review): only the time-of-day parts are compared, and the filter keeps
                        # files at or before the threshold — confirm this is intended, especially
                        # for runs whose look-back window crosses midnight.
                        if call_timestamp.time() <= time_threshold.time():
                            all_files.append({
                                'File': file_path,
                                'Call_ID': call_id,
                                'File_Timestamp': call_timestamp,
                                'File_Date': call_timestamp.date().strftime('%Y-%m-%d'),
                                'File_Time': call_timestamp.time().strftime('%H:%M:%S'),
                                'S3_Timestamp': s3_ts,
                                'S3_Date': s3_ts.strftime('%Y-%m-%d'),
                                'S3_Time': s3_ts.strftime('%H:%M:%S')
                            })
                    except Exception as e:
                        logger.warning(f"Skipping file {file_path} due to timestamp parsing error: {e}")
                        continue

            if not all_files:
                logger.info(f"0 Files fetched.")
                return pd.DataFrame()

            df_calls_list = pd.DataFrame(all_files).sort_values(['File_Timestamp'], ascending=False)
            df_calls_list['Time_Bin'] = df_calls_list['File_Timestamp'].dt.floor('2h')
            # Keep only the files that fall in the most recent 2-hour bin.
            df_calls_list = df_calls_list[df_calls_list['Time_Bin'] == df_calls_list['Time_Bin'].max()]
            logger.info(f"Files to process for the last 2 hours: {len(df_calls_list)}")

            # Write the selection to the staging folder passed in by the caller
            # (previously the outer `gcs_folders` dict was read and the parameter ignored).
            csv_path = f"gs://{vai_gcs_bucket}/{gcs_staging_folder}/{pipeline_run_name}_transcripts_to_process.csv"
            df_calls_list.to_csv(csv_path, index=False)
            logger.info(f"Written Transcripts list to GCS: {csv_path}")
            logger.info(f"Completed: listing calls to process Calls#: {len(df_calls_list)}")

            return df_calls_list

        except Exception as e:
            handle_exception("N/A", vai_gcs_bucket, pipeline_run_name, gcs_errored_folder, str(e))
            return None

    def download_transcripts_to_gcs(
        file,
        pipeline_run_name,
        vai_gcs_bucket,
        gcs_staging_folder,
        gcs_errored_folder,
        gcs_transcripts_folder,
        s3_client,
        s3_analysis_bucket
    ):
        """
        Copy one transcript from S3 to the run's transcripts folder on GCS.

        The object is downloaded to a temporary file under /tmp, uploaded to
        ``{gcs_transcripts_folder}/<file name>`` and the temporary file is
        removed afterwards, whether or not the copy succeeded.

        Args:
            file: S3 key of the transcript.
            pipeline_run_name: Run identifier, used for error reporting.
            vai_gcs_bucket: Destination GCS bucket.
            gcs_staging_folder: Unused here; kept for interface compatibility.
            gcs_errored_folder: Folder of the error CSV.
            gcs_transcripts_folder: Destination folder inside the bucket.
            s3_client: boto3 S3 client used for the download.
            s3_analysis_bucket: Source S3 bucket.

        Returns:
            tuple: ``(file, None)`` on success, ``(None, file)`` on failure.
        """
        import os  # local import: the component's import block does not bring in `os`

        file_name = file.split('/')[-1]
        local_file_path = f"/tmp/{file_name}"  # Temporary local storage
        gcs_blob_path = f"{gcs_transcripts_folder}/{file_name}"

        try:
            # Creating the client inside the try attributes a client failure to this file.
            gcs_bucket = storage.Client().bucket(vai_gcs_bucket)

            # Download file from S3
            s3_client.download_file(s3_analysis_bucket, file, local_file_path)

            # Upload to GCS
            blob = gcs_bucket.blob(gcs_blob_path)
            blob.upload_from_filename(local_file_path, checksum=None)

            return file, None

        except Exception as e:
            logger.error(f"Error: Failed to process {file} -> {str(e)}")
            handle_exception(file, vai_gcs_bucket, pipeline_run_name, gcs_errored_folder, str(e))
            return None, file

        finally:
            # Remove the temporary copy so /tmp does not fill up over a large batch.
            if os.path.exists(local_file_path):
                os.remove(local_file_path)


    """
    ========================================================
    Variables
    ========================================================
    """
    # Pipeline configuration (bucket names, AWS credentials, S3 locations) comes from Secret Manager.
    configs = fetch_secrets(
        project_id,
        secret_id,
        version_id
    )

    log_file = f"{pipeline_run_name}.logs"
    vai_gcs_bucket = configs.get("VAI_GCP_PIPELINE_BUCKET")
    aws_access_key = configs.get("VAI_AWS_ACCESS_KEY")
    aws_secret_key = configs.get("VAI_AWS_SECRET_KEY")
    s3_analysis_bucket = configs.get("VAI_S3_ANALYSIS_BUCKET")
    s3_transcripts_location = configs.get("VAI_S3_TRANSCRIPTS_LOCATION")
    # Look-back window in hours; get_list_calls_to_process reads this name from the enclosing scope.
    time_interval = 3

    """
    ========================================================
    Function Calling
    ========================================================
    """
    logger = setup_logger(log_file)

    # `gcs_folders` is also read by nested helpers through the enclosing scope; keep the name.
    gcs_folders = generate_gcs_folders(
        pipeline_run_name,
        vai_gcs_bucket
    )

    gcs_staging_folder = gcs_folders['gcs_staging_folder']
    gcs_transcripts_folder = gcs_folders['gcs_transcripts_folder']
    gcs_errored_folder = gcs_folders['gcs_errored_folder']
    gcs_logs_folder = gcs_folders['gcs_logs_folder']

    s3_prefix = generate_s3_folder_prefix(
        pipeline_run_name,
        vai_gcs_bucket,
        gcs_errored_folder
    )

    try:
        df_calls_list = get_list_calls_to_process(
            pipeline_run_name,
            vai_gcs_bucket,
            gcs_staging_folder,
            gcs_errored_folder,
            aws_access_key,
            aws_secret_key,
            s3_analysis_bucket,
            s3_transcripts_location,
            s3_prefix
        )

        if len(df_calls_list) > 0:
            files_list = df_calls_list['File'].to_list()
            s3_client = boto3.client(
                "s3",
                aws_access_key_id=aws_access_key,
                aws_secret_access_key=aws_secret_key
            )

            success_downloads = []
            failed_downloads = []

            # Copy transcripts S3 -> GCS with a small worker pool; each worker returns
            # (file, None) on success or (None, file) on failure.
            with ThreadPoolExecutor(max_workers=5) as executor:
                logger.info(f"Started: bulk download to GCS transcripts#: {len(files_list)}")

                future_to_file = {
                    executor.submit(
                        download_transcripts_to_gcs,
                        file,
                        pipeline_run_name,
                        vai_gcs_bucket,
                        gcs_staging_folder,
                        gcs_errored_folder,
                        gcs_transcripts_folder,
                        s3_client,
                        s3_analysis_bucket
                    ): file for file in files_list
                }

                for future in as_completed(future_to_file):
                    try:
                        success, failed = future.result()  # Get results

                        if success:
                            success_downloads.append(success)
                        if failed:
                            failed_downloads.append(failed)

                    except Exception as e:
                        logger.error(f"Unexpected Error: {str(e)}")
                        handle_exception("N/A", vai_gcs_bucket, pipeline_run_name, gcs_errored_folder, str(e))

            # Log the summary BEFORE the log file is uploaded so it is part of the uploaded log.
            logger.info(f"Completed: bulk download to GCS transcripts, Success#: {len(success_downloads)}, Failed#: {len(failed_downloads)}")

        else:
            logger.info("No Calls to Process.")

    finally:
        # Upload the run log exactly once, on every path (including a crash above),
        # instead of duplicating the upload in each branch.
        gcs_bucket = storage.Client().bucket(vai_gcs_bucket)
        blob = gcs_bucket.blob(f"{gcs_logs_folder}/{log_file}")
        blob.upload_from_filename(log_file, checksum=None)

# Component: Parallel Process Call Transcripts

In [5]:
# @dsl.component(
#     base_image=f"us-central1-docker.pkg.dev/dev-posigen/dev-voiceai/dev-voice-ai-docker-image:dev-4"
# )
# def process_transcripts(
#     pipeline_run_name: str,
#     project_id: str,
#     secret_id: str,
#     version_id: str
# ):
from concurrent.futures import ThreadPoolExecutor
import threading
import concurrent.futures

import pandas as pd
import numpy as np
from scipy.special import softmax
import logging
import re, os, json, io
from datetime import datetime, timezone, UTC
from typing import List, Dict

from pydantic import BaseModel, Field, ValidationError

import snowflake.connector as sc
from cryptography.hazmat.primitives import serialization

from google.cloud import secretmanager
from google.cloud import storage
from google.cloud import dlp_v2
from google.cloud import logging as cloud_logging

import vertexai
import vertexai.preview.generative_models as generative_models
from vertexai.generative_models import GenerativeModel, GenerationConfig, Part

# Sentiments
from transformers import pipeline
from transformers import AutoTokenizer, AutoConfig
from transformers import AutoModelForSequenceClassification

# Set before any tokenizer is created below.
# NOTE(review): presumably silences the Hugging Face tokenizers fork/parallelism warning
# when used together with the thread pool — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Sentiment model: tokenizer, config and classification weights are all loaded
# once at module level from the same checkpoint name.
MODEL = f"cardiffnlp/twitter-roberta-base-sentiment-latest"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
config = AutoConfig.from_pretrained(MODEL)
model_sentiment = AutoModelForSequenceClassification.from_pretrained(MODEL)

# Initialize Google Cloud Logging client
# NOTE(review): setup_logging() is expected to attach a Cloud Logging handler to the
# Python root logger; loggers that set propagate = False will not reach it — confirm.
cloud_logging_client = cloud_logging.Client()
cloud_logging_client.setup_logging()

"""
========================================================
Function: Exception handling mechanism
========================================================
"""
def fetch_secrets(
    project_id,
    secret_id,
    version_id="latest"
):
    """
    Fetch a secret version from Google Secret Manager and decode it as JSON.

    Args:
        project_id: Google Cloud project that owns the secret.
        secret_id: Identifier of the secret to read.
        version_id: Version of the secret (default: "latest").

    Returns:
        The object produced by ``json.loads`` on the UTF-8 decoded payload.

    Raises:
        ValueError: If the payload is not valid JSON.
    """
    sm_client = secretmanager.SecretManagerServiceClient()

    # Resource name of the secret version to access.
    resource_name = "projects/{}/secrets/{}/versions/{}".format(
        project_id, secret_id, version_id
    )
    secret_version = sm_client.access_secret_version(request={"name": resource_name})
    payload_text = secret_version.payload.data.decode("UTF-8")

    try:
        decoded = json.loads(payload_text)
    except json.JSONDecodeError:
        raise ValueError("The secret payload is not a valid JSON")
    return decoded
        
        
def setup_logger(
    log_file
):
    """
    Return the shared "vertex_pipeline_logger", configuring it on first use.

    The first call attaches a file handler (writing to ``log_file``) and a
    console handler, both at INFO level with a shared format. Subsequent
    calls reuse the existing handlers.

    Args:
        log_file (str): Path of the log file.

    Returns:
        The configured logger, or None if configuration raised.
    """
    try:
        master = logging.getLogger("vertex_pipeline_logger")
        master.setLevel(logging.INFO)
        # Records stay on this logger only, so each line is emitted once.
        master.propagate = False

        already_configured = bool(master.handlers)
        if not already_configured:
            line_format = logging.Formatter(
                '%(asctime)s [%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S'
            )

            def attach(handler):
                # Same level and format for every handler on this logger.
                handler.setLevel(logging.INFO)
                handler.setFormatter(line_format)
                master.addHandler(handler)

            attach(logging.FileHandler(log_file))
            attach(logging.StreamHandler())

        return master

    except Exception as e:
        print(f"Failed to initialize logger: {e}")
        return None
    
# Function to create thread-specific log files
def setup_thread_logger(
    contact_id
):
    """
    Create a dedicated file logger for one transcript.

    The log file is named ``{contact_id}_{timestamp}.log`` and is placed in
    the module-level ``temp_log_folder``. The logger is registered under the
    file name and ends up with exactly one file handler of its own.

    Args:
        contact_id: Identifier used as the log file name prefix.

    Returns:
        tuple: ``(logger, log file path)``.
    """
    created_at = datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
    log_filename = f"{contact_id}_{created_at}.log"
    log_filepath = os.path.join(temp_log_folder, log_filename)

    thread_logger = logging.getLogger(log_filename)
    thread_logger.setLevel(logging.INFO)

    # Drop handlers left on this logger so lines are not written twice.
    if thread_logger.hasHandlers():
        thread_logger.handlers.clear()

    file_handler = logging.FileHandler(log_filepath)
    file_handler.setFormatter(
        logging.Formatter(
            '%(asctime)s [%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S'
        )
    )
    thread_logger.addHandler(file_handler)

    return thread_logger, log_filepath


def merge_logs(
    log_files,
    master_log_file,
    master_logger
):
    """
    Append the contents of every per-thread log file to the master log file.

    Files are processed in ascending order of their path strings, and a
    newline is written after each file's contents.

    Args:
        log_files: Iterable of per-thread log file paths.
        master_log_file: Path of the master log file (opened in append mode).
        master_logger: Logger used to report completion.
    """
    ordered_paths = sorted(log_files)

    with open(master_log_file, "a") as master_log:
        for thread_log_path in ordered_paths:
            with open(thread_log_path, "r") as thread_log:
                contents = thread_log.read()
            master_log.write(contents + "\n")

    master_logger.info(f"All thread logs merged into: {master_log_file}")


def handle_exception(
    file_id,
    vai_gcs_bucket,
    run_folder,
    error_folder,
    error_message,
    logger
):
    """
    Log a processing failure and append it to the run's error CSV on GCS.

    The CSV lives at ``{error_folder}/{run_folder}_errors.csv`` inside
    ``vai_gcs_bucket``. It is read when it already exists, one
    (File_ID, Error_Message) row is appended, and the whole file is written
    back. Failures while doing so are logged through ``logger`` and swallowed.
    """
    try:
        error_df_path = f"{error_folder}/{run_folder}_errors.csv"
        logger.error(f"Error processing file {file_id}: {error_message}")

        csv_blob = storage.Client().bucket(vai_gcs_bucket).blob(error_df_path)

        # Start from the existing error table when there is one, otherwise from an empty table.
        if not csv_blob.exists():
            known_errors = pd.DataFrame(columns=["File_ID", "Error_Message"])
        else:
            known_errors = pd.read_csv(f"gs://{vai_gcs_bucket}/{error_df_path}")

        new_row = pd.DataFrame([{"File_ID": file_id, "Error_Message": error_message}])
        updated_errors = pd.concat([known_errors, new_row], ignore_index=True)
        updated_errors.to_csv(f"gs://{vai_gcs_bucket}/{error_df_path}", index=False)
        logger.info(f"Logged error for file {file_id} in {error_df_path}")

    except Exception as e:
        logger.error(f"Failed to write to error tracking file: {e}")

def fetch_transcripts_from_gcs(
    pipeline_run_name,
    vai_gcs_bucket,
    gcs_stagging_folder,
    gcs_errored_folder,
    gcs_transcripts_folder,
    master_logger
):
    """
    List every object stored under the run's transcripts folder on GCS.

    Args:
        pipeline_run_name: Run identifier, used for error reporting.
        vai_gcs_bucket: Name of the GCS bucket to list.
        gcs_stagging_folder: Not used by this function.
        gcs_errored_folder: Folder of the error CSV on failure.
        gcs_transcripts_folder: Prefix whose objects are listed.
        master_logger: Logger for progress and error messages.

    Returns:
        list of full blob names, excluding names that end with "/";
        None if an exception was handled.
    """
    try:
        master_logger.info(f"Fetching Transcripts from GCS: {gcs_transcripts_folder}")
        bucket = storage.Client().bucket(vai_gcs_bucket)
        listing = bucket.list_blobs(prefix=gcs_transcripts_folder)

        # Walk the listing page by page and skip folder placeholder objects.
        transcripts_list = [
            blob.name
            for page in listing.pages
            for blob in page
            if not blob.name.endswith("/")
        ]
        master_logger.info(f"Completed: Fetching from GCS Transcripts List#: {len(transcripts_list)}")
        return transcripts_list

    except Exception as e:
        handle_exception("N/A", vai_gcs_bucket, pipeline_run_name, gcs_errored_folder, str(e), master_logger)


def fetch_category_mapping_from_snowflake(
    pipeline_run_name,
    vai_gcs_bucket,
    gcs_stagging_folder,
    gcs_errored_folder,
    snf_account,
    snf_user,
    snf_private_key,
    snf_private_key_pwd,
    snf_warehouse,
    snf_catsubcat_databse,
    snf_catsubcat_schema,
    snf_catsubcat_view,
    master_logger
):
    """
    Fetch the Category/Subcategory mapping from Snowflake using key-pair authentication.

    Args:
        pipeline_run_name: Run identifier, used for error reporting.
        vai_gcs_bucket: Bucket that receives the error CSV on failure.
        gcs_stagging_folder: Not used by this function.
        gcs_errored_folder: Folder of the error CSV on failure.
        snf_account: Snowflake account identifier.
        snf_user: Snowflake user name.
        snf_private_key: PEM-encoded private key text.
        snf_private_key_pwd: Passphrase of the private key; None or "" for an
            unencrypted key.
        snf_warehouse: Warehouse to use.
        snf_catsubcat_databse: Database holding the mapping view.
        snf_catsubcat_schema: Schema holding the mapping view.
        snf_catsubcat_view: Name of the view to query.
        master_logger: Logger for progress and error messages.

    Returns:
        pandas.DataFrame with CATEGORY and SUBCATEGORY columns, or None if an
        exception was handled.
    """
    try:
        # Step 1: Load & decrypt the private key. A missing passphrase is passed
        # as None instead of failing on `.encode()`.
        key_password = snf_private_key_pwd.encode() if snf_private_key_pwd else None
        private_key_obj = serialization.load_pem_private_key(
            snf_private_key.encode(),
            password=key_password,
            backend=None  # Default backend
        )

        # Step 2: Convert to the unencrypted DER / PKCS8 bytes the connector expects.
        pkey_bytes = private_key_obj.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )

        # Step 3: Connect to Snowflake. The DER bytes from step 2 are passed;
        # previously they were computed and the key object was passed instead.
        catsubcat_conn_params = {
            'account': snf_account,
            'user': snf_user,
            'private_key': pkey_bytes,
            'warehouse': snf_warehouse,
            'database': snf_catsubcat_databse,
            'schema': snf_catsubcat_schema
        }
        conn = sc.connect(**catsubcat_conn_params)

        try:
            # The view name comes from pipeline configuration; identifiers cannot be
            # bound as query parameters, so it must never carry untrusted input.
            query = f"SELECT CATEGORY, SUBCATEGORY FROM {snf_catsubcat_view}"
            df = pd.read_sql(query, conn)
        finally:
            # Always release the connection, including when the query fails.
            conn.close()

        master_logger.info("Completed: Fetching Category, Sub-Category Mapping.")
        return df

    except Exception as e:
        handle_exception("N/A", vai_gcs_bucket, pipeline_run_name, gcs_errored_folder, str(e), master_logger)
        return None

"""
========================================================
Function: Create Dataframe: Intra Call 
========================================================
"""            
def mask_pii_in_captions(
    contact_id,
    df,
    project_id,
    thread_logger
):
    """
    Redact PII in the 'caption' column through one bulk Google Cloud DLP request.

    Every caption is prefixed with its index label and a separator, all
    captions are joined with a record boundary and sent in a single
    de-identification call, and the response is split back per record.
    Afterwards every digit outside a "[REDACTED]" marker is replaced by "*".

    Args:
        contact_id: Identifier used in log messages.
        df (pandas.DataFrame): Frame with a 'caption' column; it is not modified.
        project_id (str): Google Cloud project used for the DLP request.
        thread_logger: Logger for progress and error messages.

    Returns:
        pandas.DataFrame: A copy of ``df`` with masked captions. If the DLP
        call itself fails, the ORIGINAL ``df`` is returned unchanged (no
        masking at all is applied in that case).

    Raises:
        RuntimeError: If any other step fails.
    """
    try:
        thread_logger.info(f"{contact_id}: Masking PII Data")

        # Work on a copy; the helper columns added here are dropped before returning.
        masked_df = df.copy()
        masked_df['original_index'] = masked_df.index
        masked_df['marked_caption'] = masked_df.index.astype(str) + "|||SEPARATOR|||" + masked_df['caption'].astype(str)

        # One text blob for the whole transcript keeps this to a single API call.
        all_captions = "\n===RECORD_BOUNDARY===\n".join(masked_df['marked_caption'])

        dlp_client = dlp_v2.DlpServiceClient()
        parent = f"projects/{project_id}/locations/global"

        def builtin_info_types():
            # Built-in detectors; used for both inspection and transformation.
            names = (
                "CREDIT_CARD_NUMBER",
                "CREDIT_CARD_EXPIRATION_DATE",
                "STREET_ADDRESS",
                "IP_ADDRESS",
                "DATE_OF_BIRTH",
            )
            return [{"name": name} for name in names]

        def posigen_word_list():
            # Spellings of the company name shared by the custom detector and the exclusion rule.
            return {"word_list": {"words": ["posigen", "Posigen", "PosiGen", "POSIGEN"]}}

        # Custom dictionary detector for PosiGen.
        # NOTE(review): this info type is not listed in the transformations below — confirm intent.
        posigen_dictionary = {
            "info_type": {"name": "CUSTOM_DICTIONARY_POSIGEN"},
            "dictionary": posigen_word_list(),
        }

        exclusion_rule = {
            "exclusion_rule": {
                "matching_type": dlp_v2.MatchingType.MATCHING_TYPE_FULL_MATCH,
                "dictionary": posigen_word_list(),
            }
        }

        inspect_config = {
            "info_types": builtin_info_types(),
            "min_likelihood": dlp_v2.Likelihood.POSSIBLE,
            "custom_info_types": [posigen_dictionary],
            "rule_set": [
                {
                    "info_types": [{"name": "CUSTOM_DICTIONARY_POSIGEN"}],
                    "rules": [exclusion_rule],
                }
            ],
        }

        # Findings are replaced by the literal "[REDACTED]".
        deidentify_config = {
            "info_type_transformations": {
                "transformations": [
                    {
                        "info_types": builtin_info_types(),
                        "primitive_transformation": {
                            "replace_config": {
                                "new_value": {"string_value": "[REDACTED]"}
                            }
                        },
                    }
                ]
            }
        }

        dlp_request = {
            "parent": parent,
            "deidentify_config": deidentify_config,
            "inspect_config": inspect_config,
            "item": {"value": all_captions},
        }

        try:
            response = dlp_client.deidentify_content(request=dlp_request)
        except Exception as e:
            thread_logger.error(f"{contact_id}: Error in DLP API call: {e}")
            return df  # Return original DataFrame if masking fails

        # Split the response back into records and map index label -> redacted caption.
        split_records = (
            record.split("|||SEPARATOR|||", 1)
            for record in response.item.value.split("\n===RECORD_BOUNDARY===\n")
        )
        processed_dict = {
            int(pieces[0]): pieces[1]
            for pieces in split_records
            if len(pieces) == 2
        }

        # Rows without a processed record keep their current caption.
        masked_df['caption'] = masked_df.apply(
            lambda row: processed_dict.get(row['original_index'], row['caption']),
            axis=1
        )

        def mask_digits(text):
            """Replaces digits with asterisks while preserving '[REDACTED]' markers."""
            if not isinstance(text, str):
                return text
            return "[REDACTED]".join(
                re.sub(r'\d', '*', piece) for piece in text.split("[REDACTED]")
            )

        masked_df['caption'] = masked_df['caption'].apply(mask_digits)

        # Remove the helper columns.
        masked_df = masked_df.drop(columns=['original_index', 'marked_caption'])

        thread_logger.info(f"{contact_id}: Completed Masking PII Data")
        return masked_df

    except Exception as e:
        raise RuntimeError(f"mask_pii_in_captions() failed: {str(e)}")


def get_sentiment_label(row):
    """
    Pick the dominant sentiment label for one row of sentiment scores.

    Args:
        row: Mapping-like object (a dict or a pandas Series) exposing the keys
            'positive', 'negative' and 'neutral' with comparable values.

    Returns:
        'Positive' when the positive score is strictly greater than both others,
        'Negative' when the negative score is strictly greater than both others,
        'Neutral' in every remaining case (including all ties).

    Raises:
        RuntimeError: If a score cannot be read or compared.
    """
    try:
        # Comparisons are strict, so any tie falls through to 'Neutral'.
        positive_dominates = row['positive'] > row['negative'] and row['positive'] > row['neutral']
        if positive_dominates:
            return 'Positive'

        negative_dominates = row['negative'] > row['positive'] and row['negative'] > row['neutral']
        return 'Negative' if negative_dominates else 'Neutral'

    except Exception as e:
        raise RuntimeError(f"get_sentiment_label() failed: {str(e)}")

def get_different_times(
    intra_call,
    thread_logger
):
    """
    Derive per-utterance timing columns (whole seconds) from millisecond offsets.

    The derived columns are written onto the passed-in frame; the returned frame
    is the result of dropping 'Begin_Offset' and 'End_Offset' from it.

    Columns added, in this order:
        start_time_second   - Begin_Offset / 1000, cast to int
        end_time_second     - End_Offset / 1000, cast to int
        time_spoken_second  - end minus start, negatives replaced by 0
        time_silence_second - next row's start minus this row's end, negatives
                              and the last row (no next row) replaced by 0
        load_date           - datetime.now() at call time

    Args:
        intra_call: DataFrame with 'Begin_Offset' and 'End_Offset' columns.
        thread_logger: Accepted for signature parity; not used here.

    Returns:
        DataFrame with the timing columns and without the two offset columns.

    Raises:
        RuntimeError: If any step fails (e.g. an offset column is missing).
    """
    try:
        # Millisecond offsets -> whole seconds.
        intra_call['start_time_second'] = (intra_call['Begin_Offset'] / 1000).astype(int)
        intra_call['end_time_second'] = (intra_call['End_Offset'] / 1000).astype(int)

        # Duration of the utterance itself, floored at zero.
        spoken = intra_call['end_time_second'] - intra_call['start_time_second']
        intra_call['time_spoken_second'] = spoken.clip(lower=0).fillna(0).astype(int)

        # Gap until the next utterance starts; the final row has no successor
        # (NaN after the shift) and ends up as 0.
        gap_to_next = intra_call['start_time_second'].shift(-1) - intra_call['end_time_second']
        intra_call['time_silence_second'] = gap_to_next.clip(lower=0).fillna(0).astype(int)

        intra_call['load_date'] = datetime.now()

        # The raw millisecond offsets are no longer needed downstream.
        return intra_call.drop(columns=['Begin_Offset', 'End_Offset'])

    except Exception as e:
        raise RuntimeError(f"get_different_times() failed: {str(e)}")

def get_sentiment_scores(
    contact_id,
    text_list,
    thread_logger,
    max_tokens=512
):
    """
    Score each caption with the sentiment model and label its dominant class.

    Uses names defined elsewhere in this notebook: `tokenizer`,
    `model_sentiment`, `config`, `softmax`, `np`, `pd` and
    `get_sentiment_label`.

    Args:
        contact_id: Identifier used as a prefix in log messages.
        text_list: Iterable of caption strings, one per utterance.
        thread_logger: Logger receiving progress messages.
        max_tokens: Upper bound on tokens fed to the model per caption. Longer
            captions are truncated instead of failing the whole transcript.
            NOTE(review): 512 is assumed to match the loaded checkpoint's
            position limit - confirm against the model in use.

    Returns:
        DataFrame with one row per caption: one percentage column per label in
        `config.id2label` (rounded to 2 decimals) plus 'sentiment_lable'
        (column name spelling kept as-is for downstream consumers).

    Raises:
        RuntimeError: If tokenisation, inference or labelling fails.
    """
    try:
        thread_logger.info(f"{contact_id}: Calculating Caption Sentiments.")

        # Label names are loop-invariant; resolve them once.
        label_names = list(config.id2label.values())

        sentiment_rows = []
        for text in text_list:
            # truncation guards against captions longer than the model accepts.
            encoded_input = tokenizer(
                text,
                return_tensors='pt',
                truncation=True,
                max_length=max_tokens
            )
            output = model_sentiment(**encoded_input)
            scores = output[0][0].detach().numpy()
            # Softmax -> percentages rounded to two decimals.
            scores = np.round(np.multiply(softmax(scores), 100), 2)
            sentiment_rows.append(dict(zip(label_names, list(scores))))

        df_dict_sentiments = pd.DataFrame(sentiment_rows)
        df_dict_sentiments['sentiment_lable'] = df_dict_sentiments[['positive','negative','neutral']].apply(get_sentiment_label, axis=1)
        thread_logger.info(f"{contact_id}: Completed calculating Caption Sentiments.")

        return df_dict_sentiments

    except Exception as e:
        raise RuntimeError(f"get_sentiment_scores() failed: {str(e)}")

def process_transcript(
    contact_id,
    transcript_data,
    tokenizer,
    thread_logger
):
    """
    Flatten the transcript JSON into a per-utterance DataFrame.

    Steps:
    1. Normalise `transcript_data['Transcript']` into a DataFrame.
    2. Keep only 'BeginOffsetMillis', 'EndOffsetMillis', 'ParticipantId', 'Content'.
    3. Rename them to 'Begin_Offset', 'End_Offset', 'speaker_tag', 'caption'.
    4. Prepend a 'contact_id' column and append 'call_language'
       (taken from `transcript_data['LanguageCode']`).

    Args:
        contact_id: Identifier written to every row and used in log messages.
        transcript_data: Parsed transcript dict with 'Transcript' and 'LanguageCode'.
        tokenizer: Accepted for signature parity; not used here.
        thread_logger: Logger receiving progress messages.

    Returns:
        DataFrame with columns contact_id, Begin_Offset, End_Offset,
        speaker_tag, caption, call_language.

    Raises:
        RuntimeError: If a required key or column is missing.
    """
    try:
        thread_logger.info(f"{contact_id}: Loading the Transcript as Pandas Dataframe.")
        utterances = pd.json_normalize(transcript_data['Transcript'])

        # Source name -> output name. 'Sentiment' is not among the selected
        # columns, so that entry never matches anything.
        column_names = {
            'BeginOffsetMillis': 'Begin_Offset',
            'EndOffsetMillis': 'End_Offset',
            'Content': 'caption',
            'Sentiment': 'sentiment_label',
            'ParticipantId': 'speaker_tag'
        }
        wanted_columns = ['BeginOffsetMillis', 'EndOffsetMillis', 'ParticipantId', 'Content']
        formatted_df = utterances[wanted_columns].copy().rename(columns=column_names)

        # Call-level attributes repeated on every utterance row.
        formatted_df.insert(loc=0, column='contact_id', value=contact_id)
        formatted_df['call_language'] = transcript_data['LanguageCode']

        thread_logger.info(f"{contact_id}: Returning formated DataFrame.")
        return formatted_df

    except Exception as e:
        raise RuntimeError(f"process_transcript() failed: {str(e)}")


def create_intra_call_df(
    contact_id,
    gcp_project_id,
    vai_gcs_bucket,
    pipeline_run_name,
    transcript_data,
    tokenizer,
    thread_logger
):
    """
    Build the per-utterance ("intra call") DataFrame for one transcript.

    Pipeline, in order: process_transcript -> get_sentiment_scores ->
    column-wise concat -> get_different_times -> mask_pii_in_captions.

    Args:
        contact_id: Identifier of the call; used in rows and log messages.
        gcp_project_id: Forwarded to mask_pii_in_captions.
        vai_gcs_bucket: Not used in this function.
        pipeline_run_name: Not used in this function.
        transcript_data: Parsed transcript dict.
        tokenizer: Forwarded to process_transcript.
        thread_logger: Logger receiving progress messages.

    Returns:
        The DataFrame returned by mask_pii_in_captions.

    Raises:
        RuntimeError: If any stage fails.
    """
    try:
        thread_logger.info(f"{contact_id}: Creating df_intra_call ")

        utterances = process_transcript(contact_id, transcript_data, tokenizer, thread_logger)
        sentiments = get_sentiment_scores(contact_id, utterances.caption.to_list(), thread_logger)

        # Side-by-side join: sentiment columns are appended to the utterance rows.
        scored = pd.concat([utterances, sentiments], axis=1)
        timed = get_different_times(scored, thread_logger)
        masked = mask_pii_in_captions(contact_id, timed, gcp_project_id, thread_logger)

        thread_logger.info(f"{contact_id}: Successfully created df_intra_call ")
        return masked

    except Exception as e:
        raise RuntimeError(f"create_intra_call_df() failed: {str(e)}")

"""
========================================================
Function: Create Dataframe: Inter Call 
========================================================
"""
def dict_to_newline_string(data):
    """
    Render a mapping of heading -> iterable of items as indented bullet text.

    Each key becomes a "key:" line followed by one "  - item" line per item.
    Leading and trailing whitespace of the whole result is stripped.

    Args:
        data: Mapping whose values are iterables.

    Returns:
        The formatted multi-line string.

    Raises:
        RuntimeError: If `data` is not a mapping or a value is not iterable.
    """
    try:
        pieces = []
        for heading, entries in data.items():
            pieces.append(f"{heading}:\n")
            pieces.extend(f"  - {entry}\n" for entry in entries)
        return "".join(pieces).strip()

    except Exception as e:
        raise RuntimeError(f"dict_to_newline_string() failed: {str(e)}")


class CategoryValidator:
    """Validates categories and subcategories against a CATEGORY/SUBCATEGORY mapping frame."""

    def __init__(self, df_cat_subcat_mapping):
        """
        Initialize with category mapping from a Pandas DataFrame.
        :param df_cat_subcat_mapping: Pandas DataFrame containing 'CATEGORY' and 'SUBCATEGORY' columns.
        """
        self.df_cat_subcat_mapping = df_cat_subcat_mapping  # Ensure only the correct DataFrame is used
        self.valid_categories = set(df_cat_subcat_mapping['CATEGORY'].dropna().unique())
        self.category_subcategory_map = self._create_category_mapping()

    def _create_category_mapping(self):
        """
        Create the category -> set-of-subcategories mapping.

        Only rows with a null CATEGORY are skipped. A category whose
        SUBCATEGORY is null (or empty) is still registered, with an empty set,
        so the mapping covers the same categories as `valid_categories`.
        Nulls in unrelated columns no longer cause a row to be dropped.

        :return: dict mapping each category to a set of its subcategories.
        :raises RuntimeError: if the expected columns are missing.
        """
        try:
            mapping = {}
            rows = self.df_cat_subcat_mapping.dropna(subset=['CATEGORY'])
            has_subcategory = rows['SUBCATEGORY'].notna()

            for category, subcategory, present in zip(
                rows['CATEGORY'], rows['SUBCATEGORY'], has_subcategory
            ):
                subcategories = mapping.setdefault(category, set())
                # NaN is truthy, so null-ness must be tested explicitly.
                if present and subcategory:
                    subcategories.add(subcategory)

            return mapping

        except Exception as e:
            raise RuntimeError(f"_create_category_mapping() failed: {str(e)}")

    def validate_category(self, category: str) -> bool:
        """Check if category is valid."""
        try:
            return category in self.valid_categories

        except Exception as e:
            raise RuntimeError(f"validate_category() failed: {str(e)}")

    def validate_subcategory(self, category: str, subcategory: str) -> bool:
        """Check if subcategory is valid for the given category."""
        try:
            return category in self.category_subcategory_map and subcategory in self.category_subcategory_map[category]

        except Exception as e:
            raise RuntimeError(f"validate_subcategory() failed: {str(e)}")

    def get_valid_subcategories(self, category: str) -> set:
        """Get valid subcategories for a category (empty set when the category is unknown)."""
        try:
            return self.category_subcategory_map.get(category, set())

        except Exception as e:
            raise RuntimeError(f"get_valid_subcategories() failed: {str(e)}")


class CallSummary(BaseModel):
    """Schema for the 'call_summary' section of the model response: one free-text summary."""
    # Required summary text, limited to 500 characters.
    summary: str = Field(..., max_length=500)

class CallTopic(BaseModel):
    """Schema for the 'call_topic' section: primary topic, category and sub-category (100 chars each)."""
    primary_topic: str = Field(..., max_length=100)
    category: str = Field(..., max_length=100)
    sub_category: str = Field(..., max_length=100)

    def validate_category_mapping(
        self,
        category_validator: CategoryValidator,
        thread_logger
    ):
        """
        Check category and sub-category against the validator, in place.

        An unknown category resets both fields to 'Unspecified'; a known
        category with an unknown sub-category resets only the sub-category.
        A warning is logged for each replacement.

        :param category_validator: CategoryValidator used for the lookups.
        :param thread_logger: Logger receiving the warnings.
        :raises RuntimeError: if validation itself fails.
        """
        try:
            category_ok = category_validator.validate_category(self.category)
            if not category_ok:
                thread_logger.warning(f"Invalid category: {self.category}. Replacing with 'Unspecified'.")
                self.category = "Unspecified"
                self.sub_category = "Unspecified"
                return

            if category_validator.validate_subcategory(self.category, self.sub_category):
                return

            thread_logger.warning(f"Invalid subcategory '{self.sub_category}' for category '{self.category}'. Replacing subcategory with 'Unspecified'.")
            self.sub_category = "Unspecified"

        except Exception as e:
            raise RuntimeError(f"validate_category_mapping() failed: {str(e)}")

class AgentCoaching(BaseModel):
    """
    Schema for the 'agent_coaching' section of the model response.

    Each field is a required list of strings with an upper bound on its length.
    `max_length` is used for the list bounds because the file already relies on
    the Pydantic v2 API (`model_dump()`), where `max_items` is deprecated.
    """
    # At most 3 strengths.
    strengths: List[str] = Field(..., max_length=3)
    # At most 3 improvement areas.
    improvement_areas: List[str] = Field(..., max_length=3)
    # At most 4 recommendations.
    specific_recommendations: List[str] = Field(..., max_length=4)
    # At most 3 skill-development focus points.
    skill_development_focus: List[str] = Field(..., max_length=3)

class TranscriptAnalysis(BaseModel):
    """Top-level schema of the model response; built from the parsed JSON in KPIExtractor.validate_response."""
    # Nested 'call_summary' object.
    call_summary: CallSummary
    # Nested 'call_topic' object (category fields are checked separately after parsing).
    call_topic: CallTopic
    # Nested 'agent_coaching' object.
    agent_coaching: AgentCoaching

class KPIExtractor:
    """
    Extracts structured call KPIs (summary, topic/category, agent coaching) from a
    transcript with a generative model and validates the response.
    """

    def __init__(
        self,
        project_id: str,
        location: str,
        df_cat_subcat_mapping,
        thread_logger
    ):
        """
        Initialize the KPIExtractor with Vertex AI model and category validator.
        :param project_id: GCP Project ID
        :param location: GCP Region
        :param df_cat_subcat_mapping: Pandas DataFrame with 'CATEGORY' and 'SUBCATEGORY'
        :param thread_logger: Logger stored on the instance and used when a
            category/sub-category has to be replaced during validation.
        """
        vertexai.init(project=project_id, location=location)
        self.model = GenerativeModel("gemini-1.5-flash-002")
        self.category_validator = CategoryValidator(df_cat_subcat_mapping)
        # Kept so extract_genai_kpis() can hand a real logger to validate_response().
        self.thread_logger = thread_logger

        self.generation_config = {
            "temperature": 0.3,
            "max_output_tokens": 1024,
            "top_p": 0.8,
            "top_k": 40
        }

        self.safety_settings = {
            generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
        }


    def get_categories_prompt(self) -> str:
        """
        Create prompt section for valid categories and subcategories, handling null values.

        :return: One line per category, listing its subcategories in sorted
            order or "No defined subcategories" when it has none.
        :raises RuntimeError: if the mapping cannot be rendered.
        """
        try:
            categories_prompt = []

            for category, subcategories in self.category_validator.category_subcategory_map.items():
                if category is None:  # Skip if category is None
                    continue

                # Ensure subcategories are valid (remove None values)
                valid_subcategories = [subcat for subcat in subcategories if subcat is not None]

                if valid_subcategories:
                    subcats = ', '.join(sorted(valid_subcategories))
                else:
                    subcats = "No defined subcategories"

                categories_prompt.append(f"Category '{category}' can have subcategories: {subcats}")

            return '\n'.join(categories_prompt)

        except Exception as e:
            raise RuntimeError(f"get_categories_prompt() failed: {str(e)}")


    def create_prompt(self, transcript):
        """
        Create structured prompt with category guidance.

        :param transcript: Transcript text embedded verbatim into the prompt.
        :return: The full prompt string, including the valid category mapping
            and the required JSON output structure.
        """
        categories_guidance = self.get_categories_prompt()

        return f"""
        Analyze this call transcript and provide a structured analysis in the exact JSON format specified below.
        Keep responses concise, specific, and actionable.

        Guidelines:
        - Call summary should be factual and highlight key interactions
        - Topics and categories MUST match the following valid mappings:
        {categories_guidance}
        - Coaching points should be specific and actionable
        - All responses must follow the exact structure specified
        - Ensure all lists have the specified maximum number of items
        - All text fields must be clear, professional, and free of fluff

        Transcript:
        {transcript}

        Required Output Structure:
        {{
            "call_summary": {{
                "summary": "3-4 line overview of the call"
            }},
            "call_topic": {{
                "primary_topic": "Main topic of discussion",
                "category": "MUST BE ONE OF THE VALID CATEGORIES LISTED ABOVE",
                "sub_category": "MUST BE A VALID SUB-CATEGORY FOR THE CHOSEN CATEGORY"
            }},
            "agent_coaching": {{
                "strengths": ["Strength 1", "Strength 2", "Strength 3"],
                "improvement_areas": ["Area 1", "Area 2", "Area 3"],
                "specific_recommendations": ["Rec 1", "Rec 2", "Rec 3", "Rec 4"],
                "skill_development_focus": ["Skill 1", "Skill 2", "Skill 3"]
            }}
        }}

        Rules:
        1. Maintain exact JSON structure
        2. No additional fields or comments
        3. No markdown formatting
        4. Ensure all arrays have the exact number of items specified
        5. Keep all text concise and professional
        6. Do not mention any PII information such as Customer Name etc.
        7. STRICTLY use only the categories and subcategories from the provided mapping
        """

    def extract_json(self, response):
        """
        Extract valid JSON from response.

        :param response: Raw response text; either bare JSON or JSON wrapped
            in a ```json fenced block.
        :return: The parsed JSON value.
        :raises RuntimeError: if the text cannot be parsed as JSON.
        """
        try:
            match = re.search(r'```json\s*([\s\S]*?)\s*```', response)
            if match:
                json_str = match.group(1)
            else:
                json_str = response.strip()
            return json.loads(json_str)

        except Exception as e:
            raise RuntimeError(f"extract_json() failed: {str(e)}")


    def validate_response(
        self,
        response_json,
        thread_logger,
        contact_id = None        
    ):
        """
        Validate response using Pydantic models and category mapping.

        :param response_json: Parsed response dict.
        :param thread_logger: Logger used when a category or sub-category is replaced.
        :param contact_id: Optional identifier; currently not used in the body.
        :return: The validated TranscriptAnalysis instance.
        :raises RuntimeError: if structural or category validation fails.
        """
        try:
            # First validate basic structure with Pydantic
            analysis = TranscriptAnalysis(**response_json)

            # Then validate category mapping
            analysis.call_topic.validate_category_mapping(self.category_validator, thread_logger)

            return analysis

        except Exception as e:
            raise RuntimeError(f"validate_response() failed: {str(e)}")


    def extract_genai_kpis(self, transcript, contact_id = None):
        """
        Extract KPIs from transcript with validation.

        :param transcript: Transcript text to analyse.
        :param contact_id: Optional identifier forwarded to validate_response.
        :return: The validated analysis as a plain dict (`model_dump()`).
        :raises RuntimeError: if generation, parsing or validation fails.
        """
        try:
            # Generate prompt
            prompt = self.create_prompt(transcript)

            # Get response from Gemini
            response = self.model.generate_content(
                prompt,
                generation_config=self.generation_config,
                safety_settings=self.safety_settings
            )

            # Parse JSON response
            response_json = self.extract_json(response.text)

            # Validate response structure and categories. The logger must go in
            # the second positional slot; passing contact_id there left the
            # validator with no logger when a category had to be replaced.
            validated_response = self.validate_response(
                response_json,
                self.thread_logger,
                contact_id
            )

            return validated_response.model_dump()

        except Exception as e:
            raise RuntimeError(f"extract_genai_kpis() failed: {str(e)}")

            
"""
========================================================
Function: Create Dataframe Inter Call
========================================================
"""
def create_inter_call_df(
    contact_id,
    vai_gcs_bucket,
    gcs_stagging_folder,
    pipeline_run_name,
    transcript_data,
    ac_last_modified_date,
    df_intra_call,
    gcp_project_id,
    gcp_project_location,
    df_cat_subcat_mapping,
    thread_logger
):
    """
    Build the one-row call-level ("inter call") DataFrame for a transcript.

    The row combines the KPIs returned by KPIExtractor (summary, topic,
    category, sub-category, coaching text) with talk-speed / talk-time figures
    read from `transcript_data['ConversationCharacteristics']`, the language
    code, the input S3 URI and the supplied last-modified date.

    Args:
        contact_id: Identifier used in log messages.
        vai_gcs_bucket: Not used in this function.
        gcs_stagging_folder: Not used in this function.
        pipeline_run_name: Not used in this function.
        transcript_data: Parsed transcript dict.
        ac_last_modified_date: Value stored in 'ac_last_modified_date'.
        df_intra_call: Per-utterance frame; its captions form the call text and
            its first 'contact_id' value becomes the row's contact id.
        gcp_project_id: Forwarded to KPIExtractor.
        gcp_project_location: Forwarded to KPIExtractor.
        df_cat_subcat_mapping: Frame with a 'CATEGORY' column; also forwarded
            to KPIExtractor.
        thread_logger: Logger receiving progress messages.

    Returns:
        Single-row DataFrame describing the call.

    Raises:
        RuntimeError: If KPI extraction fails or an expected key is missing.
    """
    try:
        thread_logger.info(f"{contact_id}: Creating df_inter_call ")
        thread_logger.info(f"{contact_id}: Extracting KPIs from Gemini")
        kpi_extractor = KPIExtractor(
            gcp_project_id,
            gcp_project_location,
            df_cat_subcat_mapping,
            thread_logger
        )
        full_text = " ".join(df_intra_call.caption)
        call_gen_kpis = kpi_extractor.extract_genai_kpis(full_text)
        thread_logger.info(f"{contact_id}: Completed Extracting KPIs from Gemini")

        inter_call_dict = {
            'contact_id': str(df_intra_call['contact_id'][0]),
            'call_text': full_text,
            'call_summary': call_gen_kpis['call_summary']['summary'],
            'topic': call_gen_kpis['call_topic']['primary_topic'],
            'category': call_gen_kpis['call_topic']['category'],
            'sub_category': call_gen_kpis['call_topic']['sub_category'],
            'agent_coaching': dict_to_newline_string(call_gen_kpis['agent_coaching']),
        }
        df_inter_call = pd.DataFrame(pd.Series(inter_call_dict)).T

        # Any category outside the allowed list (or missing) resets both
        # category fields to 'Unspecified'.
        allowed_categories = df_cat_subcat_mapping['CATEGORY'].drop_duplicates().to_list()
        category_column = df_inter_call['category']
        needs_reset = ~category_column.isin(allowed_categories) | category_column.isna()
        df_inter_call.loc[needs_reset, ['category', 'sub_category']] = 'Unspecified'

        # Conversation-level metrics; millisecond totals are stored as whole seconds.
        characteristics = transcript_data['ConversationCharacteristics']
        speed_by_participant = characteristics['TalkSpeed']['DetailsByParticipant']
        df_inter_call['agent_speech_speed'] = speed_by_participant['AGENT']['AverageWordsPerMinute']
        df_inter_call['customer_speech_speed'] = speed_by_participant['CUSTOMER']['AverageWordsPerMinute']

        talk_time = characteristics['TalkTime']
        df_inter_call['total_talktime_agent_second'] = int(talk_time['DetailsByParticipant']['AGENT']['TotalTimeMillis']/1000)
        df_inter_call['total_talktime_customer_second'] = int(talk_time['DetailsByParticipant']['CUSTOMER']['TotalTimeMillis']/1000)
        df_inter_call['total_talktime_call_second'] = int(talk_time['TotalTimeMillis']/1000)
        df_inter_call['total_duration_call_second'] = int(characteristics['TotalConversationDurationMillis']/1000)

        # Dead air = total duration minus time somebody was talking.
        df_inter_call['total_dead_air_call_second'] = (
            df_inter_call['total_duration_call_second'] - df_inter_call['total_talktime_call_second']
        )
        # df_inter_call['customer_instance_id'] = transcript_data['CustomerMetadata']['InstanceId']
        # df_inter_call['call_job_status'] = transcript_data['JobStatus']
        df_inter_call['call_language'] = transcript_data['LanguageCode']
        df_inter_call['call_s3_uri'] = transcript_data['CustomerMetadata']['InputS3Uri']
        df_inter_call['ac_last_modified_date'] = ac_last_modified_date
        thread_logger.info(f"{contact_id}: Successfully created df_inter_call ")

        return df_inter_call

    except Exception as e:
        raise RuntimeError(f"create_inter_call_df() failed: {str(e)}")


"""
========================================================
Function: Process Single Transcript
========================================================
"""
def process_single_transcript(
    pipeline_run_name,
    gcp_project_id,
    vai_gcs_bucket,
    gcs_stagging_folder,
    gcs_errored_folder,
    gcs_logs_folder,
    gcs_intra_call_dfs_folder,
    gcs_inter_call_dfs_folder,
    transcript_path,
    tokenizer,
    gcp_project_location,
    df_cat_subcat_mapping
):
    """
    Process one transcript end to end and persist its two result frames to GCS.

    The contact id and the last-modified timestamp are both derived from the
    file name in `transcript_path`. The transcript JSON is downloaded from
    `vai_gcs_bucket`, turned into the intra-call and inter-call frames, and
    both are written as CSV under their staging folders when neither is empty.

    Any failure after the thread logger is set up, including a file name whose
    timestamp does not parse, is routed to handle_exception so that one bad
    file does not stop the other workers.

    Args:
        pipeline_run_name: Run identifier forwarded to helpers.
        gcp_project_id: Forwarded to the frame builders.
        vai_gcs_bucket: Bucket holding the transcript and receiving the CSVs.
        gcs_stagging_folder: Forwarded to create_inter_call_df.
        gcs_errored_folder: Forwarded to handle_exception.
        gcs_logs_folder: Not used in this function.
        gcs_intra_call_dfs_folder: Destination folder for the intra-call CSV.
        gcs_inter_call_dfs_folder: Destination folder for the inter-call CSV.
        transcript_path: Object path of the transcript inside the bucket.
        tokenizer: Forwarded to create_intra_call_df.
        gcp_project_location: Forwarded to create_inter_call_df.
        df_cat_subcat_mapping: Forwarded to create_inter_call_df.

    Returns:
        The thread log file path on success, or None when processing failed.
    """
    file_name = transcript_path.split('/')[-1]
    contact_id = file_name.split('analysis')[0].strip('_')

    thread_logger, log_filepath = setup_thread_logger(contact_id)

    try:
        # Parsed inside the try block: an unexpected file name must fail only
        # this transcript rather than raise out of the worker thread.
        ac_last_modified_date = datetime.strptime(
            file_name.split('analysis_')[-1].split('.')[0].replace('_', ':'),
            '%Y-%m-%dT%H:%M:%SZ'
        )

        client = storage.Client()
        bucket = client.bucket(vai_gcs_bucket)
        blob = bucket.blob(transcript_path)
        transcript_data = json.loads(blob.download_as_text())

        thread_logger.info(f"{contact_id}: started processing")

        df_intra_call = create_intra_call_df(
            contact_id,
            gcp_project_id,
            vai_gcs_bucket,
            pipeline_run_name,
            transcript_data,
            tokenizer,
            thread_logger
        )

        df_inter_call = create_inter_call_df(
            contact_id,
            vai_gcs_bucket,
            gcs_stagging_folder,
            pipeline_run_name,
            transcript_data,
            ac_last_modified_date,
            df_intra_call,
            gcp_project_id,
            gcp_project_location,
            df_cat_subcat_mapping,
            thread_logger
        )

        if not df_intra_call.empty and not df_inter_call.empty:
            csv_path_df_intra_call = f"gs://{vai_gcs_bucket}/{gcs_intra_call_dfs_folder}/{contact_id}_df_intra_call.csv"
            df_intra_call.to_csv(csv_path_df_intra_call, index=False)
            thread_logger.info(f"{contact_id}: Persisted: {contact_id}_df_intra_call.csv")

            csv_path_df_inter_call = f"gs://{vai_gcs_bucket}/{gcs_inter_call_dfs_folder}/{contact_id}_df_inter_call.csv"
            df_inter_call.to_csv(csv_path_df_inter_call, index=False)
            thread_logger.info(f"{contact_id}: Persisted: {contact_id}_df_inter_call.csv")

            thread_logger.info(f"{contact_id}: Processing Complete")
            thread_logger.info("")
            thread_logger.info("")
        else:
            # Previously a silent no-op; leave a trace so the skip is visible.
            thread_logger.warning(f"{contact_id}: Nothing persisted because a result DataFrame is empty")

    except Exception as e:
        handle_exception(contact_id, vai_gcs_bucket, pipeline_run_name, gcs_errored_folder, str(e), thread_logger)
        return None # Continue processing other files

    return log_filepath

def merge_and_save_transcripts(
    bucket_name,
    input_folder,
    output_folder,
    output_file,
    logger=None
):
    """
    Reads, merges all files in a GCS folder, and saves the master DataFrame as CSV.

    Every object under `input_folder` whose name ends in '.csv' or '.parquet'
    is loaded, the frames are concatenated with a fresh index, and the result
    is uploaded as CSV to `<output_folder>/<output_file>` in the same bucket.
    Failures are logged and swallowed (best effort); nothing is raised.

    Args:
        bucket_name: Name of the GCS bucket to read from and write to.
        input_folder: Prefix under which the input files are listed.
        output_folder: Folder (prefix) receiving the merged CSV.
        output_file: File name of the merged CSV.
        logger: Optional logger. When omitted, a module-level `logger` is used
            if one exists, otherwise `logging.getLogger(__name__)`, so the
            function no longer depends on an undefined global.

    Returns:
        None
    """
    import logging

    if logger is None:
        logger = globals().get("logger") or logging.getLogger(__name__)

    try:
        client = storage.Client()
        bucket = client.bucket(bucket_name)

        dfs = []
        for blob in bucket.list_blobs(prefix=input_folder):
            # Readers are opened in `with` blocks so they are always closed.
            if blob.name.endswith(".parquet"):
                with bucket.blob(blob.name).open("rb") as handle:
                    dfs.append(pd.read_parquet(handle))
            elif blob.name.endswith(".csv"):
                with bucket.blob(blob.name).open("r") as handle:
                    dfs.append(pd.read_csv(handle))

        if dfs:
            master_df = pd.concat(dfs, ignore_index=True)

            # Convert DataFrame to CSV in-memory
            csv_buffer = io.StringIO()
            master_df.to_csv(csv_buffer, index=False)

            # Upload CSV to GCS
            bucket.blob(f"{output_folder}/{output_file}").upload_from_string(
                csv_buffer.getvalue(), content_type="text/csv"
            )
            logger.info(f"Completed: merging and writing {output_file} to {output_folder}")
        else:
            logger.warning(f"No .csv or .parquet files found under {input_folder}; nothing written")

    except Exception as e:
        logger.error(f"Error processing {input_folder}: {str(e)}")


"""
========================================================
Variables
========================================================
"""
# Driver for this cell: load configuration from Secret Manager, fetch the
# category mapping and the transcript list, process transcripts in a thread
# pool, then merge the per-thread log files into the master log.
project_id = "dev-posigen"
secret_id = "dev-cx-voiceai"
version_id= "latest"
# NOTE(review): pinned to one specific historical run - presumably for
# interactive testing; confirm before reusing this cell.
pipeline_run_name = "cx-voiceai-process-calls-2025-03-25-18-26-56"

configs = fetch_secrets(
    project_id,
    secret_id,
    version_id
)

# GCP Configuration
gcp_project_id = configs.get("VAI_GCP_PROJECT_ID")
gcp_project_location = configs.get("GCP_PROJECT_LOCATION")
vai_gcs_bucket = configs.get("VAI_GCP_PIPELINE_BUCKET")

# Pipeline Configuration: every folder lives under the run name
gcs_stagging_folder = f"{pipeline_run_name}/Stagging"
gcs_errored_folder = f"{pipeline_run_name}/Errored"
gcs_logs_folder = f"{pipeline_run_name}/Logs"
gcs_transcripts_folder = f"{pipeline_run_name}/Transcripts"
gcs_intra_call_dfs_folder = f"{pipeline_run_name}/Stagging/IntraCallDFs"
gcs_inter_call_dfs_folder = f"{pipeline_run_name}/Stagging/InterCallDFs"

# Snowflake Configuration
snf_account = configs.get("VAI_SNF_ACCOUNT")
snf_user = configs.get("VAI_SNF_USER")
snf_private_key = configs.get("private_key")
snf_private_key_pwd = configs.get("VAI_SNF_PRIVATE_KEY_PWD")
snf_warehouse = configs.get("VAI_SNF_WAREHOUSE")
# Variable name spelling ("databse") is what the call below passes on.
snf_catsubcat_databse = configs.get("VAI_SNF_CATSUBCAT_DATABASE")
snf_catsubcat_schema = configs.get("VAI_SNF_CATSUBCAT_SCHEMA")
snf_catsubcat_view = configs.get("VAI_SNF_CATSUBCAT_VIEW")

# Max parallelism for multi-threading (thread-pool worker count)
max_parallelism = 10

# Step 2: Download Master Log File from GCS
master_log_file = f"{pipeline_run_name}.logs"
client = storage.Client()
bucket = client.bucket(vai_gcs_bucket)
blob = bucket.blob(f"{gcs_logs_folder}/{master_log_file}")
# Download master log file into the current working directory
blob.download_to_filename(master_log_file)

master_logger = setup_logger(master_log_file)

# Local folder created up front; presumably where per-thread logs are written - TODO confirm
temp_log_folder = "temp_logs"
os.makedirs(temp_log_folder, exist_ok=True)

# Category / sub-category reference data used to validate the model output
df_cat_subcat_mapping = fetch_category_mapping_from_snowflake(
    pipeline_run_name,
    vai_gcs_bucket,
    gcs_stagging_folder,
    gcs_errored_folder,
    snf_account,
    snf_user,
    snf_private_key,
    snf_private_key_pwd,
    snf_warehouse,
    snf_catsubcat_databse,
    snf_catsubcat_schema,
    snf_catsubcat_view,
    master_logger
)

# Transcript object paths found for this run
transcripts_list = fetch_transcripts_from_gcs(
    pipeline_run_name,
    vai_gcs_bucket,
    gcs_stagging_folder,
    gcs_errored_folder,
    gcs_transcripts_folder,
    master_logger
)

# Collects the value returned by each worker (a log file path, or None on failure)
threads_log_files = []  # Store generated log files

# Multi-threaded execution
# NOTE(review): only transcripts_list[38:50] is submitted - this looks like a
# debugging slice; confirm before running against the full list.
with concurrent.futures.ThreadPoolExecutor(max_workers=max_parallelism) as executor:
    futures = [
        executor.submit(
            process_single_transcript,
            pipeline_run_name,
            gcp_project_id, 
            vai_gcs_bucket,
            gcs_stagging_folder,
            gcs_errored_folder,
            gcs_logs_folder,
            gcs_intra_call_dfs_folder,
            gcs_inter_call_dfs_folder,
            transcript_path,
            tokenizer,
            gcp_project_location,
            df_cat_subcat_mapping
        ) for transcript_path in transcripts_list[38:50]
    ]

    # future.result() re-raises any exception that escaped the worker.
    for future in concurrent.futures.as_completed(futures):
        threads_log_files.append(future.result())

# Merge all per-thread logs; failed workers returned None, so keep only ".log" paths
threads_log_files = [file for file in threads_log_files if isinstance(file, str) and file.endswith(".log")]
merge_logs(
    threads_log_files,
    master_log_file,
    master_logger
)

# # Step 3: Merge all outputs into master files after processing
# merge_and_save_transcripts(
#     vai_gcs_bucket,
#     gcs_intra_call_dfs_folder,
#     gcs_stagging_folder,
#     "master_intra_call_df.csv"
# )

# merge_and_save_transcripts(
#     vai_gcs_bucket,
#     gcs_inter_call_dfs_folder,
#     gcs_stagging_folder,
#     "master_inter_call_df.csv"
# )

# TODO: Upload the master log file back into GCS Bucket (no upload code exists in this cell yet)

Some weights of the model checkpoint at cardiffnlp/twitter-roberta-base-sentiment-latest were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  df = pd.read_sql(query, conn)
2025-03-30 09:21:08 [INFO]: Completed: Fetching Category, Sub-Category Mapping.
2025-03-30 09:21:08 [INFO]: Fetching Transcripts from GCS: cx-voiceai-process-calls-2025-03-25-18-26-56/Transcripts
2025-03-30 09:21:08 [INFO]: Completed: Fetching from GCS 

NameError: name 'logger' is not defined

In [61]:
logger.info("testing this log entry")

2025-03-30 09:07:52 [INFO]: testing this log entry


# Component: Write data to Snowflake

In [None]:
# pipeline_run_name = VAI_GCP_PIPELINE_RUN_NAME
# print(pipeline_run_name)
# gcp_project_id = VAI_GCP_PROJECT_ID
# print(gcp_project_id)
# gcp_project_location = VAI_GCP_PROJECT_LOCATION
# print(gcp_project_location)
# vai_gcs_bucket = VAI_GCP_PIPELINE_BUCKET
# print(vai_gcs_bucket)
# gcs_stagging_folder =f"{pipeline_run_name}/Stagging"
# print(gcs_stagging_folder)
# gcs_errored_folder =f"{pipeline_run_name}/Errored"
# print(gcs_errored_folder)
# snf_account = VAI_SNF_ACCOUNT
# print(snf_account)
# snf_user = VAI_SNF_USER
# print(snf_user)
# snf_private_key = VAI_SNF_PRIVATE_KEY
# # print(snf_private_key)
# snf_private_key_pwd = VAI_SNF_PRIVATE_KEY_PWD
# print(snf_private_key_pwd)
# snf_warehouse = VAI_SNF_WAREHOUSE
# print(snf_warehouse)
# snf_schema = VAI_SNF_SCHEMA
# print(snf_schema)
# snf_database = VAI_SNF_DATABASE
# print(snf_database)

In [None]:
@dsl.component(
    # base_image=f"us-central1-docker.pkg.dev/dev-posigen/dev-voice-ai/voice-ai-docker-image:latest"
    base_image="us-central1-docker.pkg.dev/dev-posigen/dev-voiceai/dev-voice-ai-docker-image:dev-4"
)
def write_data_to_snowflake(
    pipeline_run_name: str,
    gcp_project_id: str,
    gcp_project_location: str,
    vai_gcs_bucket: str,
    snf_account: str,
    snf_user: str,
    snf_private_key: str,
    snf_private_key_pwd: str,
    snf_warehouse: str,
    snf_database: str,
    snf_schema: str
):
    """
    Load the staged inter-call and intra-call CSVs from GCS into Snowflake.

    Reads `master_inter_call_df.csv` and `master_intra_call_df.csv` from
    `<pipeline_run_name>/Stagging` in `vai_gcs_bucket`, upper-cases their column
    names, and appends rows whose CONTACT_ID is not yet present to
    SRC_GCP_INTER_CALLS and SRC_GCP_INTRA_CALLS respectively.

    Args:
        pipeline_run_name: Run identifier; also the GCS prefix of this run's folders.
        gcp_project_id: Accepted for interface parity; not read by this component.
        gcp_project_location: Accepted for interface parity; not read by this component.
        vai_gcs_bucket: Name of the GCS bucket holding the staged CSVs.
        snf_account: Snowflake account identifier.
        snf_user: Snowflake user name.
        snf_private_key: PEM-encoded private key text used for key-pair authentication.
        snf_private_key_pwd: Passphrase of the PEM key (empty when the key is unencrypted).
        snf_warehouse: Snowflake warehouse to use.
        snf_database: Snowflake database containing the target tables.
        snf_schema: Snowflake schema containing the target tables.
    """
    import io, logging
    import pandas as pd
    from datetime import datetime, timezone
    from google.cloud import storage
    import snowflake.connector as sc
    from snowflake.connector.pandas_tools import write_pandas
    from cryptography.hazmat.primitives import serialization

    # The component body runs on its own inside the container image, so the logger
    # has to be created here rather than inherited from the notebook session.
    # basicConfig is a no-op when the root logger already has handlers.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s]: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    logger = logging.getLogger("write_data_to_snowflake")

    # NOTE(review): `handle_exception` is not defined anywhere in this component.
    # It must be made available inside the component's scope for the error paths
    # below to work — TODO confirm where it is expected to come from.

    def insert_new_records(
        pipeline_run_name,
        vai_gcs_bucket,
        gcs_stagging_folder,
        gcs_errored_folder,
        snf_account,
        snf_user,
        snf_private_key,
        snf_private_key_pwd,
        snf_warehouse,
        snf_database,
        snf_schema,
        table_name,
        df
    ):
        """
        Insert only the rows of `df` whose CONTACT_ID is not already in `table_name`.

        Steps:
        1. Decrypt the PEM key and open a Snowflake connection with it.
        2. Fetch the CONTACT_IDs already present in the table.
        3. Drop rows of `df` whose CONTACT_ID already exists.
        4. Add a 'LOAD_DATE' column holding the current UTC timestamp.
        5. Append the remaining rows with write_pandas.

        Args:
            pipeline_run_name (str): Run identifier forwarded to the error handler.
            vai_gcs_bucket (str): Bucket name forwarded to the error handler.
            gcs_stagging_folder (str): Staging prefix of this run (not read here).
            gcs_errored_folder (str): Error prefix forwarded to the error handler.
            snf_account, snf_user, snf_warehouse, snf_database, snf_schema (str):
                Snowflake connection settings.
            snf_private_key (str): PEM-encoded private key text.
            snf_private_key_pwd (str): Passphrase of the PEM key; empty means unencrypted.
            table_name (str): Name of the target table.
            df (pd.DataFrame): Data to load (must have a 'CONTACT_ID' column).

        Returns:
            int | None: Number of inserted rows (0 when nothing is new), or None
            when an error was passed to `handle_exception`.
        """
        conn = None
        cursor = None
        try:
            # Load & decrypt the private key. An empty passphrase is mapped to None
            # because the loader rejects a password for an unencrypted key.
            private_key_obj = serialization.load_pem_private_key(
                snf_private_key.encode(),
                password=snf_private_key_pwd.encode() if snf_private_key_pwd else None,
                backend=None  # Default backend
            )

            # Convert to the DER/PKCS8 bytes the Snowflake connector expects.
            pkey_bytes = private_key_obj.private_bytes(
                encoding=serialization.Encoding.DER,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption(),
            )

            conn = sc.connect(
                account=snf_account,
                user=snf_user,
                private_key=pkey_bytes,
                warehouse=snf_warehouse,
                database=snf_database,
                schema=snf_schema,
            )
            cursor = conn.cursor()

            # Existing IDs in the Snowflake table (table_name is an internal constant).
            cursor.execute(f"SELECT DISTINCT(CONTACT_ID) FROM {table_name}")
            existing_ids = {row[0] for row in cursor.fetchall()}

            # Keep only rows that are not in the table yet.
            new_records_df = df[~df['CONTACT_ID'].isin(existing_ids)]

            if new_records_df.empty:
                logger.info("No new records to insert")
                return 0

            # Stamp the load time (UTC) on a copy so the caller's frame is untouched.
            utc_now = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')
            new_records_df = new_records_df.copy()
            new_records_df["LOAD_DATE"] = utc_now

            success, nchunks, nrows, _ = write_pandas(conn, new_records_df, table_name)

            logger.info(f"Inserted {nrows} new records with UTC load date")
            logger.info(f"Skipped {len(df) - len(new_records_df)} existing records")
            return nrows

        except Exception as e:
            handle_exception("N/A", vai_gcs_bucket, pipeline_run_name, gcs_errored_folder, str(e))
            return None

        finally:
            # Release Snowflake resources on every path (success, early return, error).
            if cursor is not None:
                cursor.close()
            if conn is not None:
                conn.close()

    def read_gcs_csv(file_path):
        """Download `file_path` from the run's bucket and parse it as CSV.

        Uses the `bucket` handle created in the enclosing function body.
        """
        blob = bucket.blob(file_path)
        csv_data = blob.download_as_text()
        return pd.read_csv(io.StringIO(csv_data))

    gcs_stagging_folder = f"{pipeline_run_name}/Stagging"
    gcs_errored_folder = f"{pipeline_run_name}/Errored"

    try:
        client = storage.Client()
        bucket = client.bucket(vai_gcs_bucket)

        # Read both frames before writing anything, so a missing CSV aborts the
        # step before either table is modified.
        inter_call_df = read_gcs_csv(f"{gcs_stagging_folder}/master_inter_call_df.csv")
        inter_call_df.columns = inter_call_df.columns.str.upper()  # For snowflake Schema matching
        intra_call_df = read_gcs_csv(f"{gcs_stagging_folder}/master_intra_call_df.csv")
        intra_call_df.columns = intra_call_df.columns.str.upper()  # For snowflake Schema matching

        logger.info(f"Started: writing data to snowflake.")
        load_plan = (
            ('SRC_GCP_INTER_CALLS', inter_call_df),
            ('SRC_GCP_INTRA_CALLS', intra_call_df),
        )
        for table_name, call_df in load_plan:
            logger.info(f"Writing data to table: {snf_database}.{table_name}")
            inserted = insert_new_records(
                pipeline_run_name,
                vai_gcs_bucket,
                gcs_stagging_folder,
                gcs_errored_folder,
                snf_account,
                snf_user,
                snf_private_key,
                snf_private_key_pwd,
                snf_warehouse,
                snf_database,
                snf_schema,
                table_name,
                call_df
            )
            # Report the rows actually written, not the size of the staged frame;
            # None means the failure was already handed to handle_exception.
            if inserted is not None:
                logger.info(f"{table_name}: Inserted records #{inserted}")
        logger.info(f"Completed: writing data to snowflake.")

    except Exception as e:
        handle_exception("N/A", vai_gcs_bucket, pipeline_run_name, gcs_errored_folder, str(e))

# Define the Pipeline

In [3]:
@dsl.pipeline(
    name="VAI Audio to KPI Pipeline",
    description="Process Amazon Audio Transcripts into KPIs"
)
def vai_audio_to_kpi_pipeline(
    pipeline_run_name: str,
    project_id: str,
    secret_id: str,
    version_id: str
):
    """
    Pipeline definition for the VAI audio-to-KPI flow.

    Currently wired steps:
    1. List calls from S3 and download them to GCS.

    The transcript-processing and Snowflake-write steps that follow this
    function's live code are commented out, so they are not part of the
    compiled pipeline.

    Args:
        pipeline_run_name: Run identifier forwarded to the download component.
        project_id: Forwarded to the download component (presumably the GCP
            project holding the secret — confirm against the component).
        secret_id: Forwarded to the download component (presumably the Secret
            Manager secret ID — confirm against the component).
        version_id: Forwarded to the download component (presumably the secret
            version — confirm against the component).
    """

    # Step 1: List and Download Calls from S3 to GCS
    # The task handle is kept because the commented-out steps below order
    # themselves after it via `.after(get_calls_to_process)`.
    get_calls_to_process = list_download_calls_s3_to_gcs(
        pipeline_run_name=pipeline_run_name,
        project_id=project_id,
        secret_id=secret_id,
        version_id=version_id
    )

    # # Step 2: Parallel Process Transcripts (linked to Step 1)
    # process_calls = process_transcripts(
    #     log_file=log_file,
    #     pipeline_run_name=pipeline_run_name,
    #     gcp_project_id=gcp_project_id,
    #     gcp_project_location=gcp_project_location,
    #     vai_gcs_bucket=vai_gcs_bucket,
    #     snf_account=snf_account,
    #     snf_user=snf_user,
    #     snf_private_key=snf_private_key,
    #     snf_private_key_pwd=snf_private_key_pwd,
    #     snf_warehouse=snf_warehouse,
    #     snf_catsubcat_databse=snf_catsubcat_databse,
    #     snf_catsubcat_schema=snf_catsubcat_schema,
    #     snf_catsubcat_view=snf_catsubcat_view,
    #     max_parallelism=max_parallelism
    # )
    
    # Enforce sequential execution
    # process_calls.after(get_calls_to_process)
    
# Step 3: Write the Data to Snowflake
#     persist_to_snowflake = write_data_to_snowflake(
#         pipeline_run_name=pipeline_run_name,
#         gcp_project_id=gcp_project_id,
#         gcp_project_location=gcp_project_location,
#         vai_gcs_bucket=vai_gcs_bucket,
#         snf_account=snf_account,
#         snf_user=snf_user,
#         snf_private_key=snf_private_key,
#         snf_private_key_pwd=snf_private_key_pwd,
#         snf_warehouse=snf_warehouse,
#         snf_database=snf_database,
#         snf_schema=snf_schema
#     )
    
#     # Enforce sequential execution
#     persist_to_snowflake.after(process_calls)

# Compile the Pipeline

In [4]:
# Compile the pipeline function into a YAML package in the working directory.
compiler.Compiler().compile(
    pipeline_func=vai_audio_to_kpi_pipeline,
    package_path='cx-voiceai-process-calls.yaml',
)

# Build Pipeline

In [5]:
# A single UTC timestamp identifies both the Vertex AI job and the pipeline run.
timestamp = datetime.now(UTC).strftime("%Y-%m-%d-%H-%M-%S")
TIMESTAMP = timestamp

# The same lower-cased name serves as display name and job ID.
vai_job_name = f"vai-pipeline-run-{TIMESTAMP}".lower()

# Point the Vertex AI SDK at the target project and region.
aiplatform.init(project="dev-posigen", location="us-central1")

# Build (but do not yet submit) the pipeline job from the compiled YAML.
job = pipeline_jobs.PipelineJob(
    display_name=vai_job_name,
    job_id=vai_job_name,
    template_path="cx-voiceai-process-calls.yaml",
    pipeline_root="gs://dev-aws-connect-audio",
    project="dev-posigen",
    location="us-central1",
    enable_caching=False,
    parameter_values={
        "pipeline_run_name": f"cx-voiceai-process-calls-{TIMESTAMP}",
        "project_id": "dev-posigen",
        "secret_id": "dev-cx-voiceai",
        "version_id": "1",
    },
)

# Run the Pipeline 

### Run on GCP

In [6]:
# Submit the pipeline job built above to Vertex AI.
# NOTE(review): the captured output shows repeated "current state" polling, so this
# call appears to block until the run finishes — confirm against the SDK docs.
job.run()

Creating PipelineJob
PipelineJob created. Resource name: projects/275963620760/locations/us-central1/pipelineJobs/vai-pipeline-run-2025-03-26-14-16-45
To use this PipelineJob in another session:
pipeline_job = aiplatform.PipelineJob.get('projects/275963620760/locations/us-central1/pipelineJobs/vai-pipeline-run-2025-03-26-14-16-45')
View Pipeline Job:
https://console.cloud.google.com/vertex-ai/locations/us-central1/pipelines/runs/vai-pipeline-run-2025-03-26-14-16-45?project=275963620760
PipelineJob projects/275963620760/locations/us-central1/pipelineJobs/vai-pipeline-run-2025-03-26-14-16-45 current state:
3
PipelineJob projects/275963620760/locations/us-central1/pipelineJobs/vai-pipeline-run-2025-03-26-14-16-45 current state:
3
PipelineJob projects/275963620760/locations/us-central1/pipelineJobs/vai-pipeline-run-2025-03-26-14-16-45 current state:
3
PipelineJob projects/275963620760/locations/us-central1/pipelineJobs/vai-pipeline-run-2025-03-26-14-16-45 current state:
3
PipelineJob proje

### Run Locally

In [None]:
# client = kfp.Client()