# Imports

In [1]:
# from kfp import dsl
# from kfp.v2 import compiler
from google.cloud import aiplatform
from google.cloud import dlp_v2

from typing import List, Dict
import boto3
from botocore.exceptions import ClientError
import json, os, ast, re
import logging
import time
from datetime import datetime, timedelta
import pandas as pd, numpy as np
from scipy.special import softmax
from pydantic import BaseModel, Field, ValidationError
import pytz

import snowflake.connector as sc
from snowflake.connector.pandas_tools import write_pandas

import vertexai
import vertexai.preview.generative_models as generative_models
from vertexai.generative_models import GenerativeModel, GenerationConfig, Part

# Sentiments
from transformers import pipeline
from transformers import AutoTokenizer, AutoConfig
from transformers import AutoModelForSequenceClassification

# Variables

In [2]:
# NOTE(review): presumably set to stop the Hugging Face tokenizers library from
# using (and warning about) parallelism before the tokenizer below is loaded — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Temporary secrets manager: every setting below, including the AWS credentials,
# is read from a local JSON file addressed relative to the working directory.
with open("../../sun/secrets/configs.json", 'r') as secrets_file:
    configs = json.load(secrets_file)

# `.get` is used throughout, so a missing key presumably yields None instead of
# raising (assumes the JSON root is an object — TODO confirm).
loc_logs = configs.get("loc_logs")
excel_path = configs.get("excel_path")

aws_access_key = configs.get("aws_access_key")
aws_secret_key = configs.get("aws_secret_key")

# AWS: bucket and key location of the transcripts
s3_source_bucket = configs.get('s3_source_bucket')
s3_transcripts_location = configs.get('s3_transcripts_location')

# GCP: project and location settings
gcp_project_id=configs.get('gcp_project_id')
gcp_prjct_location=configs.get('gcp_prjct_location')

# Snowflake: private key file and its passphrase
private_key_file = configs.get('snowflakegcp_rsa_key')
private_key_file_pwd = configs.get('snf_ssh_key_pass')

# Define the Snowflake View containing category mappings
# (interpolated into the query in fetch_category_mapping_from_snowflake)
snf_catsubcat_view = configs.get('snf_catsubcat_view')

# Sentiment scores: tokenizer, config and classification model are all loaded
# from the same checkpoint name and used by get_sentiment_scores.
MODEL = f"cardiffnlp/twitter-roberta-base-sentiment-latest"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
config = AutoConfig.from_pretrained(MODEL)
model_sentiment = AutoModelForSequenceClassification.from_pretrained(MODEL)

Some weights of the model checkpoint at cardiffnlp/twitter-roberta-base-sentiment-latest were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# Util Functions

## Misc Utils

### Initiate Master Inter and Intra Dataframes

In [3]:
def initiate_master_dataframes():
    """
    Load the persisted intra-call and inter-call master tables.

    Each table is looked up as a CSV file in the current working directory.
    When the file is present it is read and its CONTACT_ID column is cast to
    the pandas 'string' dtype; when it is absent an empty DataFrame is used.
    Which of the two happened is logged for each file.

    :return: Tuple ``(df_intra_calls_data, df_inter_calls_data)``.
    """
    def _load_or_empty(csv_name):
        # Fall back to an empty frame when nothing has been persisted yet.
        if not os.path.isfile(csv_name):
            logger.info(f"{csv_name} does not exists.")
            return pd.DataFrame()

        logger.info(f"{csv_name} exists.")
        frame = pd.read_csv(csv_name)
        frame.CONTACT_ID = frame.CONTACT_ID.astype('string')
        return frame

    # Intra first, then inter, so the log lines keep their order.
    df_intra_calls_data = _load_or_empty("df_intra_calls_data.csv")
    df_inter_calls_data = _load_or_empty("df_inter_calls_data.csv")

    return df_intra_calls_data, df_inter_calls_data

## Function: Listing Transcripts

In [4]:
def list_new_transcripts_from_folderlist(
    aws_access_key,
    aws_secret_key,
    source_bucket,
    custom_location,
    folderlist
):
    """
    List the transcript JSON objects stored under a set of S3 prefixes.

    Every prefix in ``folderlist`` is walked with the ``list_objects_v2``
    paginator. Keys that do not end in '.json' are ignored. For the others the
    call ID (file name text before '_analysis_') and the call timestamp (text
    after 'analysis_' up to the first '.', with 'Z' removed) are parsed out of
    the key. A ClientError while listing a prefix is logged and that prefix is
    abandoned; the remaining prefixes are still processed.

    :param aws_access_key: AWS access key ID passed to boto3.
    :param aws_secret_key: AWS secret access key passed to boto3.
    :param source_bucket: Name of the bucket to list.
    :param custom_location: Not read by this function.
    :param folderlist: Iterable of key prefixes to scan.
    :return: DataFrame with one row per transcript file, sorted by
        'File_Timestamp' descending, with a 'Time_Bin' column holding the
        timestamp floored to 2 hours; an empty DataFrame when no file was found.
    """
    client = boto3.client(
        's3',
        aws_access_key_id=aws_access_key,
        aws_secret_access_key=aws_secret_key
    )

    records = []

    for prefix in folderlist:
        try:
            page_iterator = client.get_paginator('list_objects_v2').paginate(
                Bucket=source_bucket, Prefix=prefix
            )

            for page in page_iterator:
                for entry in page.get('Contents', []):
                    key = entry['Key']
                    modified_at = entry['LastModified']

                    # Only transcript JSON files are of interest
                    if not key.endswith('.json'):
                        continue

                    file_name = key.split('/')[-1]
                    stamp_text = key.split('analysis_')[-1].split('.')[0].replace('Z', "")
                    call_timestamp = pd.to_datetime(stamp_text)

                    records.append({
                        'File': key,
                        'ID': file_name.split("_analysis_")[0],
                        'File_Timestamp': call_timestamp,
                        'File_Date': call_timestamp.date().strftime('%Y-%m-%d'),
                        'File_Time': call_timestamp.time().strftime('%H:%M:%S'),
                        'S3_Timestamp': modified_at,
                        'S3_Date': modified_at.strftime('%Y-%m-%d'),
                        'S3_Time': modified_at.strftime('%H:%M:%S')
                    })

        except ClientError as e:
            logger.error(f"Error accessing S3: {e}")
            continue

    if not records:
        return pd.DataFrame()

    df_calls_list = pd.DataFrame(records).sort_values(['File_Timestamp'], ascending=False)
    df_calls_list['Time_Bin'] = df_calls_list['File_Timestamp'].dt.floor('2h')
    return df_calls_list

## Function: Read Transcripts

In [5]:
def fetch_transcript_from_s3(aws_access_key: str, aws_secret_key: str, s3_source_bucket: str, file_key):
    """
    Read Transcript JSON content from a specific file in S3.

    :param aws_access_key: AWS access key ID passed to boto3.
    :param aws_secret_key: AWS secret access key passed to boto3.
    :param s3_source_bucket: Name of the S3 bucket
    :param file_key: Full path/key of the JSON file
    :return: Parsed JSON content, or None when the object could not be
        downloaded, decoded or parsed (the error is logged, not raised)
    """
    s3_client = boto3.client(
            's3',
            aws_access_key_id=aws_access_key,
            aws_secret_access_key=aws_secret_key
        )

    try:
        # Download the file
        response = s3_client.get_object(Bucket=s3_source_bucket, Key=file_key)

        # Read the content
        json_content = response['Body'].read().decode('utf-8')

        # Parse JSON
        return json.loads(json_content)

    except Exception as e:
        # Best effort by design: one unreadable transcript must not abort the
        # run, so the failure is logged and reported to the caller as None.
        logger.error(f"Error reading Transcript JSON file {file_key}: {e}")
        logger.info("")  # blank spacer line kept so the log layout is unchanged
        return None

## Fetching Category, Sub-Category Mapping

In [6]:
def fetch_category_mapping_from_snowflake(catsubcat_conn_params):
    """
    Fetch Category-Subcategory mapping from a Snowflake View and return as DataFrame.

    :param catsubcat_conn_params: Keyword arguments forwarded to
        ``snowflake.connector.connect``.
    :return: DataFrame produced by selecting CATEGORY and SUBCATEGORY from the
        view named by the module-level ``snf_catsubcat_view``.
    :raises RuntimeError: When connecting, querying or closing fails; the
        original exception is attached as the cause.
    """
    try:
        conn = sc.connect(**catsubcat_conn_params)
        try:
            # The view name comes from the notebook configuration, not from user
            # input; identifiers cannot be bound as query parameters.
            query = f"SELECT CATEGORY, SUBCATEGORY FROM {snf_catsubcat_view}"
            return pd.read_sql(query, conn)
        finally:
            # Release the connection even when the query itself fails.
            conn.close()
    except Exception as e:
        raise RuntimeError(f"Error fetching category mapping from Snowflake: {e}") from e


# Create Intra-call Dataframe

In [7]:
def millis_to_hhmmss(millis):
    """
    Convert a millisecond offset to a zero-padded 'HH:MM:SS' string.

    Fractions of a second are truncated. The output uses the same
    '%H:%M:%S' layout that convert_to_seconds parses.

    :param millis: Offset in milliseconds (int or float).
    :return: String such as '01:01:01' for 3661000.
    """
    total_seconds = int(millis / 1000)
    # Split into hours / minutes / seconds so minutes never exceed 59.
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d}"

def convert_to_seconds(time_str):
    """
    Convert an 'HH:MM:SS' string to a total number of seconds.

    :param time_str: Time string in '%H:%M:%S' format.
    :return: Total seconds including the hour component, or None when the
        string does not match the expected format.
    """
    try:
        # Parse time string using datetime
        time_obj = datetime.strptime(time_str, '%H:%M:%S')
    except ValueError:
        # Unparseable values are reported as None rather than raised.
        return None

    # Hours must be counted too, otherwise anything past 59:59 is wrong.
    return time_obj.hour * 3600 + time_obj.minute * 60 + time_obj.second

def process_transcript(
    transcript_data: dict,
    contact_id: str
):
    """
    Flatten a transcript payload loaded from S3 into a per-utterance DataFrame:
    1. Normalise ``transcript_data['Transcript']`` into a Pandas Dataframe.
    2. Keep only ['BeginOffsetMillis', 'EndOffsetMillis', 'ParticipantId', 'Content'].
    3. Rename them to 'Begin_Offset', 'End_Offset', 'speaker_tag' and 'caption'.
    4. Put 'contact_id' in front and add 'call_language' taken from
       ``transcript_data['LanguageCode']``.

    :param transcript_data: Parsed transcript JSON.
    :param contact_id: Identifier written to every row.
    :return: The formatted DataFrame.
    """
    utterances = pd.json_normalize(transcript_data['Transcript'])

    kept_columns = [
        'BeginOffsetMillis',
        'EndOffsetMillis',
        'ParticipantId',
        'Content'
    ]
    # 'Sentiment' is not among the kept columns, so its alias is never applied.
    column_aliases = {
        'BeginOffsetMillis': 'Begin_Offset',
        'EndOffsetMillis': 'End_Offset',
        'Content': 'caption',
        'Sentiment': 'sentiment_label',
        'ParticipantId': 'speaker_tag'
    }

    formatted_df = utterances[kept_columns].copy().rename(columns=column_aliases)

    # Call ID goes first, language code last
    formatted_df.insert(loc=0, column='contact_id', value=contact_id)
    formatted_df['call_language'] = transcript_data['LanguageCode']

    return formatted_df

def get_sentiment_label(row):
    """
    Pick the sentiment label with the strictly highest score.

    :param row: Mapping (or Series) with 'positive', 'negative' and 'neutral' scores.
    :return: 'Positive' or 'Negative' when that score strictly exceeds both
        others, otherwise 'Neutral' (which therefore also covers ties).
    """
    positive, negative, neutral = row['positive'], row['negative'], row['neutral']

    if positive > negative and positive > neutral:
        return 'Positive'
    if negative > positive and negative > neutral:
        return 'Negative'
    return 'Neutral'

def get_sentiment_scores(text_list):
    """
    Score every text with the module-level sentiment model.

    For each text the model output is passed through softmax, scaled to
    percentages and rounded to 2 decimals, then keyed by the label names from
    ``config.id2label``.

    :param text_list: Iterable of texts to score.
    :return: DataFrame with one row per text, one column per label, plus a
        'sentiment_lable' column (spelling as used by this notebook) computed
        by get_sentiment_label from the 'positive', 'negative' and 'neutral'
        columns.
    """
    label_names = list(config.id2label.values())
    per_text_scores = []

    for text in text_list:
        model_output = model_sentiment(**tokenizer(text, return_tensors='pt'))
        raw_scores = model_output[0][0].detach().numpy()
        percentages = np.round(np.multiply(softmax(raw_scores), 100), 2)
        per_text_scores.append(dict(zip(label_names, list(percentages))))

    df_dict_sentiments = pd.DataFrame(per_text_scores)
    score_columns = df_dict_sentiments[['positive','negative','neutral']]
    df_dict_sentiments['sentiment_lable'] = score_columns.apply(get_sentiment_label, axis=1)

    return df_dict_sentiments

def get_different_times(intra_call):
    """
    Derive per-utterance timing columns from the millisecond offsets.

    Columns added to ``intra_call`` (the passed frame itself is extended):
    - start_time_second / end_time_second: offsets divided by 1000, cast to int
    - time_spoken_second: end minus start, negative values replaced by 0
    - time_silence_second: next row's start minus this row's end, negative or
      missing values replaced by 0
    - load_date: the current local time

    :param intra_call: DataFrame with 'Begin_Offset' and 'End_Offset' columns.
    :return: A new DataFrame without 'Begin_Offset' and 'End_Offset'.
    """
    start_seconds = (intra_call['Begin_Offset'] / 1000).astype(int)
    end_seconds = (intra_call['End_Offset'] / 1000).astype(int)
    intra_call['start_time_second'] = start_seconds
    intra_call['end_time_second'] = end_seconds

    spoken = end_seconds - start_seconds
    intra_call['time_spoken_second'] = spoken.where(spoken >= 0, 0).fillna(0).astype(int)

    # Gap until the next utterance starts; the last row has no successor.
    silence = start_seconds.shift(-1) - end_seconds
    intra_call['time_silence_second'] = silence.where(silence >= 0, 0).fillna(0).astype(int)

    intra_call['load_date'] = datetime.now()

    # The raw millisecond offsets are not part of the result
    return intra_call.drop(['Begin_Offset', 'End_Offset'], axis=1)
# Module-level logger referenced by the helper functions defined in this notebook.
logger = logging.getLogger(__name__)

def mask_cvv_contextually(df):
    """
    Masks CVVs with contextual awareness.

    A caption mentioning a CVV-related phrase opens a context. Each following
    caption that contains digits is remembered; as soon as the digits of one
    caption, taken on their own, form exactly 3 or 4 digits, every remembered
    caption is replaced entirely with '[CVV_REDACTED]' and the context closes.
    The context also closes once more than ``max_wait`` follow-up captions
    have been seen without such a hit.

    :param df: DataFrame with a 'caption' column of strings; updated in place.
    :return: The same DataFrame with its 'caption' column rewritten.
    """
    keyword_re = re.compile(
        r'\b(?:cvv|security code|digits on the back|card verification|3 digits at the back of the card|the 3 digit code)\b',
        re.IGNORECASE,
    )
    code_shapes = (
        r'\b\d{3}\b',
        r'\b\d{4}\b',
    )
    max_wait = 3  # Number of lines to wait before resetting context

    redacted = df['caption'].tolist()
    listening = False
    waited = 0
    pending_rows = []

    for row_pos, text in enumerate(df['caption']):
        if keyword_re.search(text):
            # A (new) request for the code restarts the context from scratch.
            listening, waited, pending_rows = True, 0, []
            continue

        if not listening:
            pending_rows = []
            continue

        waited += 1
        digits_only = re.sub(r'[^0-9]', '', text)  # Extract numbers
        if digits_only:
            pending_rows.append(row_pos)
            if any(re.search(shape, digits_only) for shape in code_shapes):
                for pending in pending_rows:
                    redacted[pending] = "[CVV_REDACTED]"
                listening, pending_rows = False, []
                continue

        if waited > max_wait:
            listening, pending_rows = False, []

    df['caption'] = redacted
    return df

def mask_expiration_date_contextually(df):
    """
    Masks expiration dates with contextual awareness, redacting only the pattern.

    A caption mentioning an expiry phrase is searched for a date-like pattern;
    a hit is replaced with '[EXPIRY_DATE_REDACTED]'. Without a hit a context is
    opened and the following captions are searched the same way until one
    matches or the wait counter exceeds ``max_wait``. Patterns are tried in
    list order and only the first matching pattern is applied per caption.

    :param df: DataFrame with a 'caption' column of strings; updated in place.
    :return: The same DataFrame with its 'caption' column rewritten.
    """
    placeholder = "[EXPIRY_DATE_REDACTED]"
    max_wait = 4  # Number of lines to wait before resetting context

    exp_patterns = [
        r'\b(0[1-9]|1[0-2])\s*/\s*(\d{2}|\d{4})\b',  # MM/YY, MM/YYYY
        r'\b([1-9])\s*/\s*(\d{2}|\d{4})\b',  # M/YY, M/YYYY
        r'\b(0[1-9]|1[0-2])(\d{2}|\d{4})\b',  # MMYY, MMYYYY
        r'\b([1-9])(\d{2}|\d{4})\b',  # MYY, MYYYY
        r'\b(0[1-9]|1[0-2])\s+(\d{2}|\d{4})\b',  # MM YY, MM YYYY
        r'\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:st|nd|rd|th)?,?\s*(?:\d{2,4})?\b', # Month name followed by day number and optional year
        r'\b(0[1-9]|1[0-2]):(\d{2}|\d{4})\b', #MM:YY, MM:YYYY
        r'\b(0[1-9]|1[0-2])[-/](\d{2}|\d{4})\b|\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:st|nd|rd|th)?,?\s*\d{2,4}\b' #DLP regex
    ]

    def _first_date_text(text):
        # Text matched by the first pattern (in list order) that hits, else None.
        for pattern in exp_patterns:
            hit = re.search(pattern, text)
            if hit:
                return hit.group(0)
        return None

    redacted = df['caption'].tolist()
    listening = False
    waited = 0

    for row_pos, text in enumerate(df['caption']):
        if re.search(r'\b(?:expiration date|exp date|expiry date|Expiration is)\b', text, re.IGNORECASE):
            waited = 0
            found = _first_date_text(text)
            if found is None:
                listening = True
            else:
                redacted[row_pos] = re.sub(re.escape(found), placeholder, redacted[row_pos])
                listening = False

        elif listening:
            if waited > max_wait:
                listening = False
                waited = 0
                continue

            found = _first_date_text(text)
            if found is None:
                waited += 1
            else:
                redacted[row_pos] = re.sub(re.escape(found), placeholder, redacted[row_pos])
                listening = False
                waited = 0

        else:
            waited = 0

    df['caption'] = redacted
    return df

def mask_card_numbers_contextually(df):
    """
    Masks card numbers across multiple lines, redacting only the numbers.

    A caption containing a trigger phrase ('card number', 'credit card',
    'new card number', 'card details', 'that is') opens a context. Digits of
    the following captions are concatenated (digits on the trigger caption
    itself are not counted). Once 13-19 digits have been collected, every run
    of digits in the trigger caption and in the contributing captions is
    replaced with '[CARD_NUMBER_REDACTED]'. The context is dropped when more
    than ``context_timeout`` follow-up captions pass without reaching that
    length.

    :param df: DataFrame with a 'caption' column of strings; updated in place.
        Any index is accepted, rows are addressed by position.
    :return: The same DataFrame with its 'caption' column rewritten.
    """
    card_number_context = False
    concatenated_number = ""
    indices_to_mask = []
    # Plain lists keep enumerate() positions valid whatever labels the
    # DataFrame index carries (label lookup by position fails otherwise).
    captions = df['caption'].tolist()
    masked_captions = list(captions)
    context_timeout = 5  # Number of lines to wait before resetting context
    context_counter = 0

    for i, caption in enumerate(captions):
        cleaned_caption = re.sub(r'[^0-9]', '', caption)  # Extract only digits

        # Detect context where card number is mentioned
        if re.search(r'\b(?:card number|credit card|new card number|card details|that is)\b', caption, re.IGNORECASE):
            card_number_context = True
            concatenated_number = ""
            indices_to_mask = [i]
            context_counter = 0

        elif card_number_context:
            if cleaned_caption:  # If line contains numbers, capture them
                concatenated_number += cleaned_caption
                indices_to_mask.append(i)

            context_counter += 1

        # If total digits collected suggest a credit card number, redact them
        if 13 <= len(concatenated_number) <= 19:
            for idx in indices_to_mask:
                # Replace whole digit runs in one pass; substituting each number
                # as a substring could cut longer runs apart and leave digits.
                masked_captions[idx] = re.sub(r'\d+', "[CARD_NUMBER_REDACTED]", masked_captions[idx])
            card_number_context = False  # Reset context
            concatenated_number = ""
            indices_to_mask = []
        elif context_counter > context_timeout:
            card_number_context = False
            concatenated_number = ""
            indices_to_mask = []

    df['caption'] = masked_captions
    return df





def mask_emails_contextually(df):
    """
    Masks email addresses spoken across multiple lines with contextual awareness.

    Per caption, whitespace around '@' and '.' is collapsed first. If the
    collapsed caption contains a complete email address, the caption becomes
    that collapsed text with the address replaced by '[EMAIL_REDACTED]'.
    Otherwise a caption mentioning an email phrase opens a context in which
    tokens are collected from the following captions, starting with the first
    token that contains '@'. When the collected tokens join into a complete
    address, every contributing caption is replaced entirely with
    '[EMAIL_REDACTED]'. The context is abandoned once more than
    ``context_timeout`` follow-up captions have been seen without completing
    an address. All captions are always scanned.

    :param df: DataFrame with a 'caption' column of strings; updated in place.
    :return: The same DataFrame with its 'caption' column rewritten.
    """
    email_context = False
    email_parts = []
    indices_to_mask = []
    masked_captions = df['caption'].tolist()
    context_timeout = 5
    context_counter = 0

    # Matches emails including optional spaces after '@' and before the final '.'
    email_pattern = r'\b[A-Za-z0-9._%+-]+@\s*[A-Za-z0-9.-]+\s*\.[A-Za-z]{2,}\b'

    for i, caption in enumerate(df['caption']):
        # Remove spaces around '@' and '.' to catch incorrectly spaced emails
        cleaned_caption = re.sub(r'\s*@\s*', '@', caption)
        cleaned_caption = re.sub(r'\s*\.\s*', '.', cleaned_caption)

        # Immediate masking if full email is found in one line
        if re.search(email_pattern, cleaned_caption):
            masked_captions[i] = re.sub(email_pattern, "[EMAIL_REDACTED]", cleaned_caption)
            email_context = False
            continue

        # Detect email context
        if re.search(r'\b(?:email|email address|send to|mail to)\b', cleaned_caption, re.IGNORECASE):
            email_context = True
            email_parts = []
            indices_to_mask = []
            context_counter = 0

        elif email_context:
            # Capture words containing '@' or adjacent to it
            potential_parts = re.findall(r'\b[A-Za-z0-9._%+-]+(?:@|(?:@[A-Za-z0-9.-]+))?\b', cleaned_caption)
            valid_parts = [part for part in potential_parts if '@' in part or len(email_parts) > 0]

            context_counter += 1
            address_completed = False
            if valid_parts:
                indices_to_mask.append(i)
                email_parts.extend(valid_parts)
                address_completed = bool(re.search(email_pattern, "".join(email_parts)))

            if address_completed:
                for idx in indices_to_mask:
                    masked_captions[idx] = "[EMAIL_REDACTED]"

            # Close the context on success or on timeout, but keep scanning the
            # remaining captions: leaving the loop here would skip later emails.
            if address_completed or context_counter > context_timeout:
                email_context = False
                email_parts = []
                indices_to_mask = []

        else:
            email_context = False
            email_parts = []
            indices_to_mask = []

    df['caption'] = masked_captions
    return df
 

def mask_check_number(df):
    """
    Masks check numbers with contextual awareness.

    A caption mentioning a check-number phrase opens a context. Standalone
    4-8 digit numbers in that caption and in the following captions are
    replaced with '[CHECK_NUMBER_REDACTED]'. The context closes once the
    follow-up counter exceeds ``context_timeout``.

    :param df: DataFrame with a 'caption' column of strings; updated in place.
    :return: The same DataFrame with its 'caption' column rewritten.
    """
    check_context = False
    masked_captions = df['caption'].tolist()
    context_timeout = 3  # Adjust as needed
    context_counter = 0

    # Regex to capture check numbers (4-8 digits is common)
    check_number_regex = r'\b\d{4,8}\b'
    placeholder = "[CHECK_NUMBER_REDACTED]"

    for i, caption in enumerate(list(masked_captions)):
        if re.search(r'\b(?:check number|cheque number|check #|cheque #)\b', caption, re.IGNORECASE):
            check_context = True
            context_counter = 0
            # Substitute by pattern so only standalone numbers are replaced;
            # substring replacement also rewrote digits inside longer numbers.
            masked_captions[i] = re.sub(check_number_regex, placeholder, caption)

        elif check_context:
            masked_captions[i] = re.sub(check_number_regex, placeholder, caption)

            context_counter += 1
            if context_counter > context_timeout:
                check_context = False

    df['caption'] = masked_captions
    return df


def mask_routing_number(df):
    """
    Masks routing numbers with contextual awareness.

    A caption mentioning a routing phrase opens a context. Standalone 9-digit
    numbers in that caption and in the following captions are replaced with
    '[ROUTING_NUMBER_REDACTED]'. The context closes once the follow-up counter
    exceeds ``context_timeout``.

    :param df: DataFrame with a 'caption' column of strings; updated in place.
    :return: The same DataFrame with its 'caption' column rewritten.
    """
    routing_context = False
    masked_captions = df['caption'].tolist()
    context_timeout = 3
    context_counter = 0

    routing_number_regex = r'\b\d{9}\b'  # Standalone 9-digit numbers
    placeholder = "[ROUTING_NUMBER_REDACTED]"

    for i, caption in enumerate(df['caption']):
        if re.search(r'\b(?:routing number|ABA number|bank routing|bank details)\b', caption, re.IGNORECASE):
            routing_context = True
            context_counter = 0
            # Substitute by pattern: substring replacement also rewrote the
            # same digits where they occur inside a longer number.
            masked_captions[i] = re.sub(routing_number_regex, placeholder, masked_captions[i])

        elif routing_context:
            masked_captions[i] = re.sub(routing_number_regex, placeholder, masked_captions[i])

            context_counter += 1
            if context_counter > context_timeout:
                routing_context = False

    df['caption'] = masked_captions
    return df

def mask_account_number(df):
    """
    Masks account numbers with contextual awareness.

    A caption mentioning an account phrase opens a context. Standalone 6-18
    digit numbers in that caption and in the following captions are replaced
    with '[ACCOUNT_NUMBER_REDACTED]'. The context closes once the follow-up
    counter exceeds ``context_timeout``.

    :param df: DataFrame with a 'caption' column of strings; updated in place.
    :return: The same DataFrame with its 'caption' column rewritten.
    """
    account_context = False
    masked_captions = df['caption'].tolist()
    context_timeout = 5
    context_counter = 0

    account_number_regex = r'\b\d{6,18}\b'  # Standalone 6-18 digit numbers
    placeholder = "[ACCOUNT_NUMBER_REDACTED]"

    for i, caption in enumerate(df['caption']):
        if re.search(r'\b(?:account number|bank account|checking account|savings account|bank details)\b', caption, re.IGNORECASE):
            account_context = True
            context_counter = 0
            # Substitute by pattern: substring replacement let a short number
            # cut a longer one apart and leave some of its digits visible.
            masked_captions[i] = re.sub(account_number_regex, placeholder, masked_captions[i])

        elif account_context:
            masked_captions[i] = re.sub(account_number_regex, placeholder, masked_captions[i])

            context_counter += 1
            if context_counter > context_timeout:
                account_context = False

    df['caption'] = masked_captions
    return df


def mask_card_ending(df):
    """
    Masks card endings with proximity-aware date exclusion.

    Captions are handled independently, in three stages; the first stage that
    applies decides the outcome:
    1. 'card ending in' directly followed by 4-6 digits
    2. 'on the card' directly followed by 4-6 digits
    3. the first of 'ending in' / 'ending with' / 'ends in' found in the
       caption, in which case every standalone 4-6 digit number is a candidate

    Candidates are replaced with '[CARD_ENDING_REDACTED]' unless the caption
    contains an address keyword anywhere, or a month name within
    ``proximity_window`` characters of the trigger phrase.

    :param df: DataFrame with a 'caption' column; values are cast to str and
        the column is updated in place.
    :return: The same DataFrame with its 'caption' column rewritten.
    """
    placeholder = "[CARD_ENDING_REDACTED]"
    address_keywords = ["billing address", "address", "zip code", "postal code", "street", "city", "state"]
    date_keywords = ["january", "february", "march", "april", "may", "june", "july", "august", "september", "october", "november", "december"]
    proximity_window = 50  # Adjust as needed

    # Phrases that must be directly followed by the digits
    anchored_contexts = (
        ("card ending in", r'card ending in\s*(\d{4,6})\b'),
        ("on the card", r'on the card\s*(\d{4,6})\b'),
    )
    # Looser phrases: any 4-6 digit number in the caption is a candidate
    card_ending_keywords = ["ending in", "ending with", "ends in"]

    def _may_redact(lowered, anchor):
        # No address keyword anywhere and no month name close to the anchor.
        if any(word in lowered for word in address_keywords):
            return False
        nearby = lowered[max(0, anchor - proximity_window):min(len(lowered), anchor + proximity_window)]
        return not any(month in nearby for month in date_keywords)

    masked_captions = df['caption'].astype(str).tolist()

    for i, caption in enumerate(list(masked_captions)):
        caption_lower = caption.lower()

        handled = False
        for phrase, phrase_regex in anchored_contexts:
            if phrase not in caption_lower:
                continue
            hit = re.search(phrase_regex, caption_lower)
            if hit is None:
                continue
            if _may_redact(caption_lower, caption_lower.find(hit.group(0))):
                masked_captions[i] = re.sub(re.escape(hit.group(1)), placeholder, masked_captions[i])
            handled = True
            break
        if handled:
            continue

        # Only the first keyword present in the caption is considered
        keyword = next((kw for kw in card_ending_keywords if kw in caption_lower), None)
        if keyword is None:
            continue
        if _may_redact(caption_lower, caption_lower.find(keyword)):
            for num in re.findall(r'\b\d{4,6}\b', caption):
                masked_captions[i] = re.sub(re.escape(num), placeholder, masked_captions[i])

    df['caption'] = masked_captions
    return df



def mask_account_ending(df):
    """
    Masks account endings with context and non-digit character check.

    A caption is examined for 'account ending in', 'account ending with' or
    'account ends in' (case-insensitive) followed by one whitespace character
    and 4-6 digits. For the first such hit, every occurrence of those digits
    in the caption is replaced with '[ACCOUNT_ENDING_REDACTED]', provided the
    text around the digits is empty on one side or contains one of the
    phrases. Errors on a caption are logged and that caption is left as is.

    :param df: DataFrame with a 'caption' column; values are cast to str and
        the column is updated in place.
    :return: The same DataFrame with its 'caption' column rewritten.
    """
    context_regex = r'(?i)(account ending in|account ending with|account ends in)'
    ending_regex = r'(?i)(account ending in|account ending with|account ends in)\s(\d{4,6})\b'
    placeholder = "[ACCOUNT_ENDING_REDACTED]"

    masked_captions = df['caption'].astype(str).tolist()

    for i, caption in enumerate(list(masked_captions)):
        try:
            hit = re.search(ending_regex, caption)
            if hit is None:
                continue

            ending = hit.group(2)
            digits_at = hit.start(2)

            # Text on either side of the ending digits
            before = caption[:digits_at].strip()
            after = caption[digits_at + len(ending):].strip()

            # Mask when one side is empty, or when a side carries the context phrase.
            should_mask = (
                not before
                or not after
                or re.search(context_regex, before)
                or re.search(context_regex, after)
            )
            if should_mask:
                masked_captions[i] = re.sub(re.escape(ending), placeholder, masked_captions[i])

        except Exception as e:
            logger.error(f"Error processing line {i}: {e}")

    df['caption'] = masked_captions
    return df


"""=================Adress Masking==========="""

def mask_address(df):
    """
    Redacts addresses from a DataFrame's 'caption' column.

    Every caption is first tidied (single uppercase letters separated by
    whitespace are joined, whitespace runs are collapsed, the ends are
    stripped). Then a fixed sequence of patterns is replaced with
    '[ADDRESS_REDACTED]': city/state/zip, zip code, US state name, street
    address, house number, and finally text lying between two redactions or
    following a redaction plus 'in'. Matching is case-sensitive.

    :param df: DataFrame with a 'caption' column of strings; updated in place.
    :return: The same DataFrame with its 'caption' column rewritten.
    """
    placeholder = '[ADDRESS_REDACTED]'

    state_names = (
        'Alabama', 'Alaska', 'Arizona', 'Arkansas', 'California', 'Colorado', 'Connecticut', 'Delaware',
        'Florida', 'Georgia', 'Hawaii', 'Idaho', 'Illinois', 'Indiana', 'Iowa', 'Kansas', 'Kentucky',
        'Louisiana', 'Maine', 'Maryland', 'Massachusetts', 'Michigan', 'Minnesota', 'Mississippi',
        'Missouri', 'Montana', 'Nebraska', 'Nevada', 'New Hampshire', 'New Jersey', 'New Mexico',
        'New York', 'North Carolina', 'North Dakota', 'Ohio', 'Oklahoma', 'Oregon', 'Pennsylvania',
        'Rhode Island', 'South Carolina', 'South Dakota', 'Tennessee', 'Texas', 'Utah', 'Vermont',
        'Virginia', 'Washington', 'West Virginia', 'Wisconsin', 'Wyoming',
    )
    state_alternation = '|'.join(state_names)

    city_state_zip_pattern = r'\b[A-Za-z]+(?:\s[A-Za-z]+)*,?\s(?:' + state_alternation + r')\s*\d{5}(?:-\d{4})?\b'
    zipcode_pattern = r'(?<!\d{3}-\d{3}-)\b\d{5}(?:-\d{4})?\b(?!-\d{3})'
    state_pattern = r'\b(?:' + state_alternation + r')\b'
    street_suffix_pattern = r'\b\d{1,5}\s+[A-Za-z\s]+(?:Street|St|Avenue|Ave|Road|Rd|Boulevard|Blvd|Lane|Ln|Drive|Dr|Court|Ct)\b'
    house_number_pattern = r'\b\d{1,5}(?=\s+[A-Za-z]+\s+(?:Street|St|Avenue|Ave|Road|Rd|Boulevard|Blvd|Lane|Ln|Drive|Dr|Court|Ct)\b)'
    sandwiched_pattern = r'\[ADDRESS_REDACTED\](?:,)?\s+(.*?)\s+\[ADDRESS_REDACTED\]|\[ADDRESS_REDACTED\]\s+in\s+([A-Za-z\s]+)\b'

    # Order matters: the last pattern relies on placeholders inserted earlier.
    ordered_patterns = (
        city_state_zip_pattern,
        zipcode_pattern,
        state_pattern,
        street_suffix_pattern,
        house_number_pattern,
        sandwiched_pattern,
    )

    def _tidy(text):
        # Join spelled-out capitals ("A B C" -> "ABC"), then normalise spaces.
        joined = re.sub(r'\b([A-Z])(?:\s([A-Z]))+\b', lambda m: ''.join(m.group().split()), text)
        return re.sub(r'\s+', ' ', joined).strip()

    redacted = []
    for caption in df['caption'].tolist():
        text = _tidy(caption)
        for pattern in ordered_patterns:
            text = re.sub(pattern, placeholder, text)
        redacted.append(text)

    df['caption'] = redacted
    return df


def mask_pii_in_captions(contact_id, df, project_id):
    """Mask PII in the 'caption' column with regex helpers followed by Cloud DLP.

    The captions are first passed through the notebook's regex/context based
    masking helpers, then all captions are sent to the DLP ``deidentify_content``
    API in a single request (records are joined with a boundary marker and
    prefixed with their row index so they can be mapped back afterwards).

    Args:
        contact_id: Identifier used only as a prefix in log messages.
        df: DataFrame with a 'caption' column. It is not modified; a copy is used.
        project_id: GCP project id used to build the DLP ``parent`` resource.

    Returns:
        A copy of ``df`` with masked captions. If the DLP call fails, the
        regex-masked captions are still returned (never the unmasked input).
    """
    logger.info(f"{contact_id}: Masking PII Data")

    record_boundary = "\n===RECORD_BOUNDARY===\n"
    index_separator = "|||SEPARATOR|||"
    helper_columns = ['original_index', 'marked_caption', 'previous_caption']

    masked_df = df.copy()
    masked_df['original_index'] = masked_df.index
    masked_df['previous_caption'] = masked_df['caption'].shift(1)

    # Apply the regex / contextual redaction FIRST, in this fixed order
    for mask_step in (
        mask_card_numbers_contextually,
        mask_cvv_contextually,
        mask_expiration_date_contextually,
        mask_emails_contextually,
        mask_routing_number,
        mask_account_number,
        mask_card_ending,
        mask_account_ending,
        mask_address,
        mask_check_number,
    ):
        masked_df = mask_step(masked_df)

    # Prefix every caption with its row index so DLP output can be mapped back
    masked_df['marked_caption'] = masked_df.index.astype(str) + index_separator + masked_df['caption'].astype(str)
    all_captions = record_boundary.join(masked_df['marked_caption'])

    dlp_client = dlp_v2.DlpServiceClient()
    parent = f"projects/{project_id}/locations/global"

    inspect_config = {
        "info_types": [
            {"name": "CREDIT_CARD_NUMBER"},
            {"name": "STREET_ADDRESS"},
            {"name": "IP_ADDRESS"},
            {"name": "DATE_OF_BIRTH"},
            {"name": "PHONE_NUMBER"}, # add phone number to inspect
            {"name": "EMAIL_ADDRESS"} # added email address
        ],
        "custom_info_types": [
            {
                "info_type": {"name": "CREDIT_CARD_EXPIRATION_DATE"},
                "regex": {"pattern": r'\b(0[1-9]|1[0-2])[-/](\d{2}|\d{4})\b|\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:st|nd|rd|th)?,?\s*\d{2,4}\b'},
                "likelihood": dlp_v2.Likelihood.POSSIBLE
            }
        ],
        "min_likelihood": dlp_v2.Likelihood.POSSIBLE
    }

    # NOTE(review): PHONE_NUMBER is inspected above but is not listed in the
    # transformation below, so phone findings are not replaced — presumably
    # on purpose (see the original docstring); confirm.
    deidentify_config = {
        "info_type_transformations": {
            "transformations": [
                {
                    "info_types": [
                        {"name": "CREDIT_CARD_NUMBER"},
                        {"name": "CREDIT_CARD_EXPIRATION_DATE"},
                        {"name": "STREET_ADDRESS"},
                        {"name": "IP_ADDRESS"},
                        {"name": "DATE_OF_BIRTH"},
                        {"name": "EMAIL_ADDRESS"} # added email address
                    ],
                    "primitive_transformation": {
                        "replace_config": {"new_value": {"string_value": "[REDACTED]"}}
                    }
                }
            ]
        }
    }

    try:
        response = dlp_client.deidentify_content(
            request={
                "parent": parent,
                "deidentify_config": deidentify_config,
                "inspect_config": inspect_config,
                "item": {"value": all_captions}
            }
        )
    except Exception as e:
        logger.error(f"{contact_id}: Error in DLP API call: {e}")
        # Fall back to the regex-masked captions. Returning the untouched input
        # here would hand completely unmasked text to downstream consumers.
        return masked_df.drop(columns=helper_columns)

    # Map each processed record back to its row via the numeric prefix
    processed_dict = {}
    for record in response.item.value.split(record_boundary):
        prefix, separator, processed_text = record.partition(index_separator)
        if not separator:
            continue
        try:
            processed_dict[int(prefix)] = processed_text
        except ValueError:
            # The prefix was altered (e.g. redacted); keep the regex-masked caption for that row
            logger.warning(f"{contact_id}: Could not map a DLP record back to its row; keeping regex-masked caption")

    masked_df['caption'] = [
        processed_dict.get(row_index, caption)
        for row_index, caption in zip(masked_df['original_index'], masked_df['caption'])
    ]

    masked_df = masked_df.drop(columns=helper_columns)

    logger.info(f"{contact_id}: Completed Masking PII Data")

    return masked_df

    
def create_intra_call_df(aws_access_key: str, aws_secret_key: str, transcript_data: dict, contact_id: str, gcp_project_id: str):
    """Build the per-utterance (intra-call) DataFrame for one transcript.

    Steps: parse the transcript, append sentiment scores column-wise, derive the
    timing columns and finally mask PII in the captions. The two AWS key
    arguments are accepted but not used inside this function.
    """
    utterances = process_transcript(transcript_data, contact_id)
    sentiment_scores = get_sentiment_scores(utterances.caption.to_list())
    scored_utterances = pd.concat([utterances, sentiment_scores], axis=1)
    timed_utterances = get_different_times(scored_utterances)
    return mask_pii_in_captions(contact_id, timed_utterances, gcp_project_id)

# Create Inter-call Dataframe

In [8]:
class CategoryValidator:
    """Looks up valid categories and sub-categories from a mapping table."""

    def __init__(self, cat_subcat_mapping):
        """
        Build the lookup structures from an already-loaded mapping table.
        :param cat_subcat_mapping: DataFrame with 'CATEGORY' and 'SUBCATEGORY' columns.
        """
        self.category_mapping = cat_subcat_mapping
        self.valid_categories = set(self.category_mapping['CATEGORY'].unique())
        self.category_subcategory_map = self._create_category_mapping()

    def _create_category_mapping(self):
        """Group the sub-categories of every category into a set, keyed by category."""
        grouped = {}
        for _, record in self.category_mapping.iterrows():
            grouped.setdefault(record['CATEGORY'], set()).add(record['SUBCATEGORY'])
        return grouped

    def validate_category(self, category: str) -> bool:
        """Return True when the category appears in the mapping table."""
        return category in self.valid_categories

    def validate_subcategory(self, category: str, subcategory: str) -> bool:
        """Return True when the sub-category is listed under the given category."""
        return subcategory in self.category_subcategory_map.get(category, ())

    def get_valid_subcategories(self, category: str):
        """Return the set of sub-categories for a category (empty set if unknown)."""
        try:
            return self.category_subcategory_map[category]
        except KeyError:
            return set()

class CallSummary(BaseModel):
    """Free-text overview of a call; the summary is limited to 500 characters."""
    summary: str = Field(..., max_length=500)

class CallTopic(BaseModel):
    """Topic classification of a call; every field is limited to 100 characters."""
    primary_topic: str = Field(..., max_length=100)
    category: str = Field(..., max_length=100)
    sub_category: str = Field(..., max_length=100)

    def validate_category_mapping(self, category_validator: CategoryValidator) -> bool:
        """Check category and sub-category against the mapping.

        Both checks always run so that each problem is logged. Nothing is raised;
        the result is reported through the return value.

        Returns:
            True when the category and its sub-category are both valid, else False.
        """
        is_valid = True
        if not category_validator.validate_category(self.category):
            logger.error(f"Invalid category: {self.category}")
            is_valid = False
        if not category_validator.validate_subcategory(self.category, self.sub_category):
            logger.error(f"Invalid subcategory '{self.sub_category}' for category '{self.category}'")
            is_valid = False
        return is_valid

class AgentCoaching(BaseModel):
    """Coaching feedback for the agent, as short capped lists of strings."""
    # `max_length` is the Pydantic v2 spelling of the list-size limit; the
    # former `max_items` keyword is deprecated and emits a warning on import.
    strengths: List[str] = Field(..., max_length=3)
    improvement_areas: List[str] = Field(..., max_length=3)
    specific_recommendations: List[str] = Field(..., max_length=4)
    skill_development_focus: List[str] = Field(..., max_length=3)

class TranscriptAnalysis(BaseModel):
    """Top-level schema for one analysed call: summary, topic and coaching sections."""
    call_summary: CallSummary
    call_topic: CallTopic
    agent_coaching: AgentCoaching

class KPIExtractor:
    """Extracts call-level KPIs (summary, topic, agent coaching) from a transcript using Gemini."""

    def __init__(self, project_id: str, location: str, excel_path: str, category_mapping=None):
        """
        Initialise Vertex AI, the Gemini model and the category validator.

        :param project_id: GCP project passed to ``vertexai.init``.
        :param location: GCP location passed to ``vertexai.init``.
        :param excel_path: Not used by this class; kept so existing callers keep working.
        :param category_mapping: Optional category/sub-category mapping table. When omitted,
            the notebook-level ``cat_subcat_mapping`` variable is used (previous behaviour).
        """
        vertexai.init(project=project_id, location=location)
        self.model = GenerativeModel("gemini-1.5-flash-002")
        if category_mapping is None:
            # Backward-compatible fallback to the notebook global
            category_mapping = cat_subcat_mapping
        self.category_validator = CategoryValidator(category_mapping)

        # NOTE(review): generation_config and safety_settings are built here but are
        # never passed to generate_content(), so the model currently runs with its
        # defaults. Also, "response_format" does not look like a GenerationConfig
        # field (the SDK exposes response_mime_type) — confirm before wiring these in.
        self.generation_config = {
            "temperature": 0.3,
            "max_output_tokens": 1024,
            "top_p": 0.8,
            "top_k": 40,
            "response_format": "json"
        }

        self.safety_settings = {
            generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
        }

    def get_categories_prompt(self) -> str:
        """Create prompt section for valid categories and subcategories, handling null values"""
        prompt_lines = []

        for category, subcategories in self.category_validator.category_subcategory_map.items():
            if category is None:  # Skip rows without a category
                continue

            # Drop missing sub-categories before listing them
            valid_subcategories = [subcat for subcat in subcategories if subcat is not None]
            subcats = ', '.join(sorted(valid_subcategories)) if valid_subcategories else "No defined subcategories"

            prompt_lines.append(f"Category '{category}' can have subcategories: {subcats}")

        return '\n'.join(prompt_lines)

    def create_prompt(self, transcript: str) -> str:
        """Create structured prompt with category guidance"""
        categories_guidance = self.get_categories_prompt()

        return f"""
        Analyze this call transcript and provide a structured analysis in the exact JSON format specified below.
        Keep responses concise, specific, and actionable.
    
        Guidelines:
        - Call summary should be factual and highlight key interactions
        - Topics and categories MUST match the following valid mappings:
        {categories_guidance}
        - If the calls are received, forwarded or reached to a voicemail then always map:
            Category="Unsuccessful Contact" and Sub-Category="Voicemail"
        - Coaching points should be specific and actionable
        - All responses must follow the exact structure specified
        - Ensure all lists have the specified maximum number of items
        - All text fields must be clear, professional, and free of fluff
    
        Transcript:
        {transcript}
    
        Required Output Structure:
        {{
            "call_summary": {{
                "summary": "3-4 line overview of the call"
            }},
            "call_topic": {{
                "primary_topic": "Main topic of discussion",
                "category": "MUST BE ONE OF THE VALID CATEGORIES LISTED ABOVE",
                "sub_category": "MUST BE A VALID SUB-CATEGORY FOR THE CHOSEN CATEGORY"
            }},
            "agent_coaching": {{
                "strengths": ["Strength 1", "Strength 2", "Strength 3"],
                "improvement_areas": ["Area 1", "Area 2", "Area 3"],
                "specific_recommendations": ["Rec 1", "Rec 2", "Rec 3", "Rec 4"],
                "skill_development_focus": ["Skill 1", "Skill 2", "Skill 3"]
            }}
        }}
    
        Rules:
        1. Maintain exact JSON structure
        2. No additional fields or comments
        3. No markdown formatting
        4. Ensure all arrays have the exact number of items specified
        5. Keep all text concise and professional
        6. Do not mention any PII information such as Customer Name etc.
        7. STRICTLY use only the categories and subcategories from the provided mapping
        """

    def extract_json(self, response: str) -> Dict:
        """Extract valid JSON from response.

        Accepts raw JSON as well as JSON wrapped in a markdown code fence, with or
        without a ``json`` language tag. Returns None (after logging) when the
        payload cannot be parsed.
        """
        fenced = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', response, re.IGNORECASE)
        json_str = fenced.group(1) if fenced else response.strip()

        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            logger.error("Invalid JSON response")
            return None

    def validate_response(self, response_json: Dict, contact_id: str = None) -> "TranscriptAnalysis":
        """Validate response using Pydantic models and category mapping.

        Returns the validated TranscriptAnalysis, or None (after logging) when
        validation fails. Category mismatches are only logged by the topic model.
        """
        try:
            # First validate basic structure with Pydantic
            analysis = TranscriptAnalysis(**response_json)

            # Then check the category mapping (problems are logged, not raised)
            analysis.call_topic.validate_category_mapping(self.category_validator)

            return analysis
        except ValidationError as e:
            logger.error(f"{contact_id if contact_id else ''}: Pydantic validation error - {e}")
        except ValueError as e:
            logger.error(f"{contact_id if contact_id else ''}: Category validation error - {e}")
        return None

    def extract_genai_kpis(
       self,
       transcript: str,
       contact_id: str = None
    ):
        """Extract KPIs from transcript with validation and retries.

        Makes up to three attempts, pausing two seconds between attempts.

        Returns:
            The validated analysis as a plain dict, or
            ``{"error": "Failed to extract KPIs after multiple attempts"}`` when
            every attempt fails. This method does not raise.
        """
        max_retries = 3
        failure = {"error": "Failed to extract KPIs after multiple attempts"}

        # The prompt does not change between attempts, so build it once. A failure
        # here is deterministic and retrying would not help.
        try:
            prompt = self.create_prompt(transcript)
        except Exception as e:
            logger.error(f"Attempt 1: Error extracting KPIs: {str(e)}")
            return failure

        for attempt in range(1, max_retries + 1):
            try:
                # Get response from Gemini and parse it
                response = self.model.generate_content(prompt)
                response_json = self.extract_json(response.text)

                # Only a non-empty JSON object without "NA" placeholders is usable
                if not isinstance(response_json, dict) or not response_json or "NA" in response_json.values():
                    logger.warning(f"Attempt {attempt}: Gemini returned NA or empty response. Retrying...")
                else:
                    validated_response = self.validate_response(response_json, contact_id)
                    if validated_response is not None:
                        return validated_response.model_dump()
                    logger.warning(f"Attempt {attempt}: Invalid response structure. Retrying...")

            except Exception as e:
                logger.error(f"Attempt {attempt}: Error extracting KPIs: {str(e)}")

            # Wait before retrying, but not after the final attempt
            if attempt < max_retries:
                time.sleep(2)

        logger.error(f"Failed to extract valid KPIs after {max_retries} attempts.")
        return failure

def dict_to_newline_string(data: dict) -> str:
    """Render a dict of iterables as text: one 'key:' line followed by '  - item' lines."""
    pieces = []
    for heading, entries in data.items():
        pieces.append(f"{heading}:\n")
        pieces.extend(f"  - {entry}\n" for entry in entries)
    # Surrounding whitespace (including the final newline) is trimmed
    return "".join(pieces).strip()

def create_inter_call_df(
    gcp_project_id: str,
    gcp_prjct_location: str,
    df_intra_call: pd.DataFrame,
    transcript_data: dict,
    ac_last_modified_date: datetime,
    cat_subcat_mapping: pd.DataFrame
):
    """Build the one-row call-level (inter-call) DataFrame for a transcript.

    Combines the Gemini-generated KPIs (summary, topic, category, coaching) with
    the conversation metrics read from ``transcript_data``.

    Args:
        gcp_project_id: GCP project used by the KPI extractor.
        gcp_prjct_location: GCP location used by the KPI extractor.
        df_intra_call: Utterance-level frame; needs 'contact_id' and 'caption' columns.
        transcript_data: Parsed transcript JSON (reads 'ConversationCharacteristics',
            'LanguageCode' and 'CustomerMetadata').
        ac_last_modified_date: Value stored as-is in 'ac_last_modified_date'.
        cat_subcat_mapping: Mapping table; its 'CATEGORY' column defines allowed categories.

    Returns:
        A one-row DataFrame, or an EMPTY DataFrame when anything fails, so that
        callers can test ``.empty`` without a None check.
    """
    # Bound before the try block so the error handler can always reference it
    contact_id = None
    try:
        # Positional access: the frame's index is not guaranteed to start at 0
        contact_id = str(df_intra_call['contact_id'].iloc[0])

        extractor = KPIExtractor(gcp_project_id, gcp_prjct_location, excel_path)
        transcript = " ".join(df_intra_call.caption)
        call_gen_kpis = extractor.extract_genai_kpis(transcript, contact_id)

        # The extractor reports failure through an "error" key instead of raising
        if "error" in call_gen_kpis:
            logger.error(f"{contact_id}: {call_gen_kpis['error']}")
            return pd.DataFrame()

        inter_call_dict = {
            'contact_id': contact_id,
            'call_text': transcript,
            'call_summary': call_gen_kpis['call_summary']['summary'],
            'topic': call_gen_kpis['call_topic']['primary_topic'],
            'category': call_gen_kpis['call_topic']['category'],
            'sub_category': call_gen_kpis['call_topic']['sub_category'],
            'agent_coaching': dict_to_newline_string(call_gen_kpis['agent_coaching']),
        }

        df_inter_call = pd.DataFrame(pd.Series(inter_call_dict)).T

        # Replace values where Categories are not in allowed list
        allowed_categories = cat_subcat_mapping['CATEGORY'].drop_duplicates().to_list()
        df_inter_call.loc[
            ~df_inter_call['category'].isin(allowed_categories) | df_inter_call['category'].isna(),
            ['category', 'sub_category']
        ] = 'Unspecified'

        # Add metadata from AWS
        characteristics = transcript_data['ConversationCharacteristics']
        speed_by_participant = characteristics['TalkSpeed']['DetailsByParticipant']
        talk_time = characteristics['TalkTime']
        time_by_participant = talk_time['DetailsByParticipant']

        df_inter_call['agent_speech_speed'] = speed_by_participant['AGENT']['AverageWordsPerMinute']
        df_inter_call['customer_speech_speed'] = speed_by_participant['CUSTOMER']['AverageWordsPerMinute']
        # Durations arrive in milliseconds and are stored as whole seconds
        df_inter_call['total_talktime_agent_second'] = int(time_by_participant['AGENT']['TotalTimeMillis']/1000)
        df_inter_call['total_talktime_customer_second'] = int(time_by_participant['CUSTOMER']['TotalTimeMillis']/1000)
        df_inter_call['total_talktime_call_second'] = int(talk_time['TotalTimeMillis']/1000)
        df_inter_call['total_duration_call_second'] = int(characteristics['TotalConversationDurationMillis']/1000)
        df_inter_call['total_dead_air_call_second'] = df_inter_call['total_duration_call_second'] - df_inter_call['total_talktime_call_second']
        df_inter_call['call_language'] = transcript_data['LanguageCode']
        df_inter_call['call_s3_uri'] = transcript_data['CustomerMetadata']['InputS3Uri']
        df_inter_call['ac_last_modified_date'] = ac_last_modified_date

        return df_inter_call

    except Exception as e:
        logger.error(f"{contact_id}: Error Creating Inter Call df: {e}")
        logger.info("")
        return pd.DataFrame()

# Writing Dataframe to Snowflake

In [9]:
def insert_new_records(conn, table_name, df):
    """
    Inserts only new records (based on CONTACT_ID) into a Snowflake table with a UTC load timestamp.

    Steps:
    1. Fetches existing CONTACT_IDs from the table.
    2. Filters out rows whose CONTACT_ID already exists.
    3. Adds a 'LOAD_DATE' column holding the current UTC timestamp.
    4. Inserts only the new records.

    Args:
        conn: Snowflake connection object.
        table_name (str): Name of the target table. It is interpolated into the
            query text, so it must come from trusted code, never from user input.
        df (pd.DataFrame): DataFrame containing the data (must have 'CONTACT_ID' column).

    Returns:
        int: Number of inserted records (0 when there is nothing new to insert).
    """
    # Nothing to do for an empty frame (which may not even have a CONTACT_ID column)
    if df is None or df.empty:
        logger.info("No new records to insert")
        return 0

    # Step 1: Get existing IDs from Snowflake table; always release the cursor
    cursor = conn.cursor()
    try:
        cursor.execute(f"SELECT DISTINCT(CONTACT_ID) FROM {table_name}")
        existing_ids = {row[0] for row in cursor.fetchall()}
    finally:
        cursor.close()

    # Step 2: Filter DataFrame to keep only new records
    new_records_df = df[~df['CONTACT_ID'].isin(existing_ids)]

    if new_records_df.empty:
        logger.info("No new records to insert")
        return 0

    # Step 3: Add UTC timestamp column on a new frame so the caller's df is untouched
    utc_now = datetime.now(pytz.utc).strftime('%Y-%m-%d %H:%M:%S')
    new_records_df = new_records_df.assign(LOAD_DATE=utc_now)

    # Step 4: Insert new records into Snowflake
    success, _, nrows, _ = write_pandas(conn, new_records_df, table_name)
    if not success:
        logger.error(f"write_pandas reported an unsuccessful load into {table_name}")

    logger.info(f"Inserted {nrows} new records with UTC load date")
    logger.info(f"Skipped {len(df) - len(new_records_df)} existing records")

    return nrows

### Handle Duplicates

In [10]:
def handle_duplicates(data_frame, columns_to_check):
    """
    Return a de-duplicated copy of the frame, ready to be written to Snowflake.

    Rows are compared on ``columns_to_check``; the first occurrence of each
    combination is kept and the result gets a fresh 0..n-1 index.
    """
    return (
        data_frame
        .drop_duplicates(subset=columns_to_check, keep="first")
        .reset_index(drop=True)
    )

# Exception Handling

# Logging Handling

In [11]:
def setup_logger(log_file):
    """
    Sets up the 'voice_ai_logger' logger so that it writes timestamped lines to a file.

    Calling this again replaces the previous handlers; the old handlers are
    closed so their log files are not left open. Console output is currently
    disabled (attach a logging.StreamHandler here to re-enable it).

    Args:
        log_file (str): Path of the log file to write to.

    Returns:
        logger: Configured logger instance (level DEBUG).
    """
    logger = logging.getLogger('voice_ai_logger')

    # Detach AND close handlers left over from an earlier call; merely clearing
    # the list would leak the open file behind every old FileHandler.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    # Set log level
    logger.setLevel(logging.DEBUG)

    # Create formatter
    formatter = logging.Formatter(
        '%(asctime)s [%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Create file handler
    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    return logger


def reset_logging():
    """Strip every handler from all known loggers (root included) and reset them.

    Each affected logger ends up with no handlers, level NOTSET and
    propagation switched off.
    """
    # Registered loggers plus the root logger, which is not part of loggerDict
    candidates = [*logging.root.manager.loggerDict.values(), logging.getLogger()]

    for candidate in candidates:
        # loggerDict can hold entries that are not Logger instances; skip those
        if not isinstance(candidate, logging.Logger):
            continue

        # Iterate over a snapshot because removeHandler mutates the list
        for attached in list(candidate.handlers):
            candidate.removeHandler(attached)
            attached.close()

        candidate.setLevel(logging.NOTSET)
        candidate.propagate = False  # Avoid duplicate logs from propagation

# Main Function

In [12]:
# Get current date
current_date = datetime.now()

# Get data for last 2 days (today first, then going backwards)
num_days = 2
dates_to_check = [current_date - timedelta(days=offset) for offset in range(num_days)]

# Get data for the last 12 hrs:
# (derived from current_date; the previously referenced `current_datetime`
# was never defined and raised NameError on a fresh kernel)
start_datetime = current_date - timedelta(hours=12)

# Build one S3 listing prefix per date: <location>/YYYY/MM/DD/
folderlist = [
    f"{s3_transcripts_location}/{check_date.year}/{check_date.month:02d}/{check_date.day:02d}/"
    for check_date in dates_to_check
]
folderlist[:num_days]

['Analysis/Voice/2025/04/09/', 'Analysis/Voice/2025/04/08/']

In [13]:
# Remove the handlers of every existing logger before configuring the run logger
reset_logging()

# Setup logger (file-based; `logger` is used as a global by the functions above)
log_file='voice_ai_runtime_Masking.logs'
logger = setup_logger(log_file)

In [14]:
# Snowflake connection settings for reading the category / sub-category mapping
catsubcat_conn_params = {
    'account': configs.get('snf_account'),
    'user': configs.get('snf_user'),
    'private_key_file': configs.get('snf_private_key_file'),
    'private_key_file_pwd':configs.get('snf_private_key_pwd'),
    'warehouse': configs.get('snf_warehouse'),
    'database': 'POSIGEN_DEV',
    'schema': configs.get('snf_catsubcat_schema')
}

# Display a redacted copy only: echoing the dict itself would write the
# private-key passphrase into the saved notebook output.
{key: ('***REDACTED***' if key.endswith('_pwd') else value) for key, value in catsubcat_conn_params.items()}

{'account': 'XV37144.us-central1.gcp',
 'user': 'GCP_INTEGRATION',
 'private_key_file': '../../sun/secrets/snowflakegcp_rsa_key.p8',
 'private_key_file_pwd': '***REDACTED***',
 'warehouse': 'DATAPLATR',
 'database': 'POSIGEN_DEV',
 'schema': 'SIGMA_CX'}

In [15]:
# Load the category / sub-category mapping from Snowflake; the result is kept in
# the notebook-level `cat_subcat_mapping`, which KPIExtractor reads as a global.
logger.info("Fetching Category, Sub-category Mapping.")
cat_subcat_mapping = fetch_category_mapping_from_snowflake(catsubcat_conn_params)

  df = pd.read_sql(query, conn)


In [16]:
# List the transcripts found under the date prefixes in folderlist
df_list_transcripts = list_new_transcripts_from_folderlist(aws_access_key, aws_secret_key, s3_source_bucket, s3_transcripts_location, folderlist[:num_days])
# Keep a CSV snapshot of the listing next to the notebook
df_list_transcripts.to_csv("df_list_transcripts.csv")

# Blank log lines are used as visual separators in the log file
logger.info("")
logger.info(f"Transcripts to process: {len(df_list_transcripts)}")
# logger.info(df_list_transcripts.groupby(['File_Date']).size().reset_index(name='frequency'))
# logger.info(df_list_transcripts[:3000].groupby(['File_Date', 'Time_Bin']).size().reset_index(name='Count'))
logger.info("")

In [17]:
# Initiating Master DataFrames
# (the two frames accumulate the intra-call and inter-call rows across the processing loop)
print("Called: To make sure previous calls are loaded into memory again.")
logger.info("Called: Initiate Master Dataframes")
df_intra_calls_data, df_inter_calls_data = initiate_master_dataframes()
# Row counts of the two master frames at start-up
print(len(df_inter_calls_data))
print(len(df_intra_calls_data))

Called: To make sure previous calls are loaded into memory again.
0
0


In [18]:
# If there are transcripts to be processed
if len(df_list_transcripts) == 0:
    logger.info("No Transcripts to Process")
    logger.info("")

else:
    for transcript in df_list_transcripts.File.to_list():
        logger.info("--------------------------")
        start_time = time.time()
        # Bind contact_id before the try block so the except handler never logs a
        # stale id from the previous iteration (or hits an unbound name).
        contact_id = transcript
        try:
            # Derive the call ID and the timestamp embedded in the file name
            contact_id = transcript.split('/')[-1].split('.')[0].split('analysis')[0].strip('_')
            ac_last_modified_date = datetime.strptime(
                                            transcript.split('analysis_')[-1].split('.')[0].replace('Z', ""),
                                            "%Y-%m-%dT%H:%M:%S"
                                        ).strftime('%Y-%m-%d %H:%M:%S')

            # A call counts as processed only when it is present in BOTH master frames
            in_intra = len(df_intra_calls_data) > 0 and contact_id in df_intra_calls_data.CONTACT_ID.unique()
            in_inter = len(df_inter_calls_data) > 0 and contact_id in df_inter_calls_data.CONTACT_ID.unique()

            if in_intra and in_inter:
                logger.info(f"{contact_id}: Call already Processed.")

            else:
                logger.info(f"{contact_id}: Processing")

                # Get the Transcript file from S3 Bucket
                logger.info(f"{contact_id}: Fetching Transcript from S3")
                transcript_data = fetch_transcript_from_s3(aws_access_key, aws_secret_key, s3_source_bucket, transcript)

                if transcript_data:
                    # Only report success (and take len) once we know data came back
                    logger.info(f"{contact_id}: Successfully fetched the Transcript from S3 {len(transcript_data)}")

                    # Create the Intra Call (utterance-level) frame
                    logger.info(f"{contact_id}: Creating df_intra_call ")
                    df_intra_call = create_intra_call_df(aws_access_key, aws_secret_key, transcript_data, contact_id, gcp_project_id)
                    logger.info(f"{contact_id}: Successfully created df_intra_call ")

                    # Create the Inter Call (call-level) frame
                    logger.info(f"{contact_id}: Creating df_inter_call ")
                    df_inter_call = create_inter_call_df(gcp_project_id, gcp_prjct_location, df_intra_call, transcript_data, ac_last_modified_date, cat_subcat_mapping)
                    logger.info(f"{contact_id}: Successfully created df_inter_call ")

                    # A builder may hand back None or an empty frame on failure
                    intra_ok = df_intra_call is not None and not df_intra_call.empty
                    inter_ok = df_inter_call is not None and not df_inter_call.empty

                    # ###============================================================###
                    # Save DataFrames only when both are having data
                    if intra_ok and inter_ok:
                        # Appending to Intra-calls Master DataFrame
                        df_intra_call.columns = df_intra_call.columns.str.upper()  # Capitalising Column names for Snowflake
                        df_intra_calls_data = pd.concat([df_intra_calls_data, df_intra_call], ignore_index=True)
                        df_intra_calls_data.to_csv("df_intra_calls_data.csv", index=False)
                        logger.info(f"{contact_id}: Persisted df_intra_calls_data to CSV.")

                        # Appending to Inter-calls Master DataFrame
                        df_inter_call.columns = df_inter_call.columns.str.upper()  # Capitalising Column names for Snowflake
                        df_inter_calls_data = pd.concat([df_inter_calls_data, df_inter_call], ignore_index=True)
                        df_inter_calls_data.to_csv("df_inter_calls_data.csv", index=False)
                        logger.info(f"{contact_id}: Persisted df_inter_calls_data to CSV.")

                    else:
                        if not intra_ok:
                            logger.error("Intra Call DataFrame was not created successfully.")
                        if not inter_ok:
                            logger.error("Inter Call DataFrame was not created successfully.")

                else:
                    logger.error(f"{contact_id}: No transcript data fetched from S3.")

            # Progress marker every 20 accumulated calls
            if len(df_inter_calls_data)%20 == 0:
                logger.info("--------------------------")
                logger.info(f"Processed {len(df_inter_calls_data)} files")
                logger.info("--------------------------")

        except Exception as e:
            # Best effort: log the failure and move on to the next transcript
            logger.error("--------------------------")
            logger.error(f"{contact_id}: Exception - {e}")
            logger.error("--------------------------")
            continue

        elapsed_time = time.time() - start_time  # Time taken for this iteration
        minutes, seconds = divmod(elapsed_time, 60)
        logger.info(f"{contact_id}: Processed Call #{len(df_inter_calls_data)} in {int(minutes)} min {seconds:.2f} sec elapsed")
        logger.info("--------------------------")
        logger.info("")
        logger.info("")

    # De-duplicate only non-empty frames: an empty master frame may not have the
    # key columns yet, in which case drop_duplicates would raise KeyError.
    logger.info(f"Removing duplicates in df_inter_calls_data.")
    columns_to_check = ["CONTACT_ID"]
    if len(df_inter_calls_data) > 0:
        df_inter_calls_data = handle_duplicates(df_inter_calls_data, columns_to_check)
    logger.info(f"Removing duplicates in df_intra_calls_data.")
    columns_to_check = ["CONTACT_ID", "SPEAKER_TAG", "CAPTION", "START_TIME_SECOND", "END_TIME_SECOND"]
    if len(df_intra_calls_data) > 0:
        df_intra_calls_data = handle_duplicates(df_intra_calls_data, columns_to_check)

    logger.info(f"Completed processing {len(df_list_transcripts)} Calls")

# Load Data to Snowflake

In [19]:
# Snowflake connection settings for writing the call tables
conn_params = {
    'account': configs.get('snf_account'),
    'user': configs.get('snf_user'),
    'private_key_file': configs.get('snf_private_key_file'),
    'private_key_file_pwd':configs.get('snf_private_key_pwd'),
    'warehouse': configs.get('snf_warehouse'),
    'database': 'POSIGEN_DEV',
    'schema': configs.get('snf_schema')
}

# Display a redacted copy only: echoing the dict itself would write the
# private-key passphrase into the saved notebook output.
{key: ('***REDACTED***' if key.endswith('_pwd') else value) for key, value in conn_params.items()}

{'account': 'XV37144.us-central1.gcp',
 'user': 'GCP_INTEGRATION',
 'private_key_file': '../../sun/secrets/snowflakegcp_rsa_key.p8',
 'private_key_file_pwd': '***REDACTED***',
 'warehouse': 'DATAPLATR',
 'database': 'POSIGEN_DEV',
 'schema': 'SRC_GCP'}

In [20]:
logger.info(f"Writing Dataframe to Snowflake.")
conn = sc.connect(**conn_params)

# (log label, target table, frame to load) — handled identically for both tables
load_plan = (
    ('SRC_GCP_INTER_CALLS', 'SRC_GCP_INTER_CALLS_TEMP', df_inter_calls_data),
    ('SRC_GCP_INTRA_CALLS', 'SRC_GCP_INTRA_CALLS_TEMP', df_intra_calls_data),
)

try:
    for table_label, table_name, frame_to_load in load_plan:
        start_time = time.time()
        # Log the table that is actually about to be written
        logger.info(f"Writing data to table: {conn_params['database']}.{table_name}")
        inserted_rows = insert_new_records(conn, table_name, frame_to_load)
        elapsed_time = time.time() - start_time  # Time taken for this table
        minutes, seconds = divmod(elapsed_time, 60)
        # Report the number of rows really inserted, not the size of the frame
        logger.info(f"{table_label}: Inserted records #{inserted_rows} in {int(minutes)} min {seconds:.2f} sec elapsed")
finally:
    # Always release the connection, even when a load fails
    conn.close()

# Validate

In [25]:
# Sanity check: number of processed calls per category
df_inter_calls_data.CATEGORY.value_counts()

CATEGORY
Unsuccessful Contact          1037
Customer Inquiry               245
Billing                        219
Production                     119
Transfer                        42
Damage                          31
Customer Requested Removal      18
Monitoring Portal               15
Performance                     15
Pre Board                        7
Unavoidable Casualty             7
Monitoring                       6
Admin                            4
Performance Guaranteee           1
Energy Efficiency                1
Name: count, dtype: int64