# Imports

In [34]:
# from kfp import dsl
# from kfp.v2 import compiler
from google.cloud import aiplatform
from google.cloud import dlp_v2

from typing import List, Dict
import boto3
from botocore.exceptions import ClientError
import json, os, ast, re
import logging
import time
from datetime import datetime, timedelta
import pandas as pd, numpy as np
from scipy.special import softmax
from pydantic import BaseModel, Field, ValidationError
import pytz

import threading
from concurrent.futures import ThreadPoolExecutor, as_completed


import snowflake.connector as sc
from snowflake.connector.pandas_tools import write_pandas

import vertexai
import vertexai.preview.generative_models as generative_models
from vertexai.generative_models import GenerativeModel, GenerationConfig, Part

# Sentiments
from transformers import pipeline
from transformers import AutoTokenizer, AutoConfig
from transformers import AutoModelForSequenceClassification

# Variables

In [3]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Temporary secrets manager
with open("../secrets/configs.json", 'r') as secrets_file:
    configs = json.load(secrets_file)

loc_logs = configs.get("loc_logs")
excel_path = configs.get("excel_path")

aws_access_key = configs.get("aws_access_key")
aws_secret_key = configs.get("aws_secret_key")

# AWS
s3_source_bucket = configs.get('s3_source_bucket')
s3_transcripts_location = configs.get('s3_transcripts_location')

# GCP
gcp_project_id=configs.get('gcp_project_id')
gcp_prjct_location=configs.get('gcp_prjct_location')

# Snowflake
private_key_file = configs.get('snowflakegcp_rsa_key')
private_key_file_pwd = configs.get('snf_ssh_key_pass')


catsubcat_conn_params = {
    'account': configs.get('snf_account'),
    'user': configs.get('snf_user'),
    'private_key_file': configs.get('snf_private_key_file'),
    'private_key_file_pwd':configs.get('snf_private_key_pwd'),
    'warehouse': configs.get('snf_warehouse'),
    'database': 'POSIGEN_DEV',
    'schema': configs.get('snf_catsubcat_schema')
}

# Define the Snowflake View containing category mappings
snf_catsubcat_view = configs.get('snf_catsubcat_view')

conn_params = {
    'account': configs.get('snf_account'),
    'user': configs.get('snf_user'),
    'private_key_file': configs.get('snf_private_key_file'),
    'private_key_file_pwd':configs.get('snf_private_key_pwd'),
    'warehouse': configs.get('snf_warehouse'),
    'database': 'POSIGEN',
    'schema': configs.get('snf_schema')
}

# # Sentiment Scores
MODEL = f"cardiffnlp/twitter-roberta-base-sentiment-latest"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
config = AutoConfig.from_pretrained(MODEL)
model_sentiment = AutoModelForSequenceClassification.from_pretrained(MODEL)

Some weights of the model checkpoint at cardiffnlp/twitter-roberta-base-sentiment-latest were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# Util Functions

## Misc Utils

### Initiate Master Inter and Intra Dataframes

In [4]:
def initiate_master_dataframes():
    if os.path.isfile("df_intra_calls_data.csv"):
        logger.info("df_intra_calls_data.csv exists.") 
        df_intra_calls_data = pd.read_csv("df_intra_calls_data.csv")
        df_intra_calls_data.CONTACT_ID = df_intra_calls_data.CONTACT_ID.astype('string')
    else:
        logger.info("df_intra_calls_data.csv does not exists.")
        df_intra_calls_data = pd.DataFrame()

    if os.path.isfile("df_inter_calls_data.csv"):
        logger.info("df_inter_calls_data.csv exists.")
        df_inter_calls_data = pd.read_csv("df_inter_calls_data.csv")
        df_inter_calls_data.CONTACT_ID = df_inter_calls_data.CONTACT_ID.astype('string')
    else:
        logger.info("df_inter_calls_data.csv does not exists.")
        df_inter_calls_data = pd.DataFrame()

    return df_intra_calls_data, df_inter_calls_data

## Function: Listing Transcripts

In [5]:
def list_new_transcripts_from_folderlist(
    aws_access_key,
    aws_secret_key,
    source_bucket,
    custom_location,
    folderlist
):
    """
    Fetch audio file metadata from S3 folders using pagination.
    """

    s3_client = boto3.client(
        's3',
        aws_access_key_id=aws_access_key,
        aws_secret_access_key=aws_secret_key
    )

    # logger.info("Fetching New Transcripts to process...")
    all_files = []

    # Fetch files from each folder in the list
    for folder in folderlist:
        try:
            paginator = s3_client.get_paginator('list_objects_v2')
            pages = paginator.paginate(Bucket=source_bucket, Prefix=folder)

            for page in pages:
                for obj in page.get('Contents', []):
                    file_path = obj['Key']
                    s3_ts = obj['LastModified']

                    # Skip non-JSON files
                    if not file_path.endswith('.json'):
                        continue

                    call_id = file_path.split('/')[-1].split("_analysis_")[0]
                    call_timestamp = pd.to_datetime(file_path.split('analysis_')[-1].split('.')[0].replace('Z', ""))

                    all_files.append({
                        'File': file_path,
                        'ID': call_id,
                        'File_Timestamp': call_timestamp,
                        'File_Date': call_timestamp.date().strftime('%Y-%m-%d'),
                        'File_Time': call_timestamp.time().strftime('%H:%M:%S'),
                        'S3_Timestamp': s3_ts,
                        'S3_Date': s3_ts.strftime('%Y-%m-%d'),
                        'S3_Time': s3_ts.strftime('%H:%M:%S')
                    })

        except ClientError as e:
            logger.error(f"Error accessing S3: {e}")
            continue

    if all_files:
        df_calls_list = pd.DataFrame(all_files).sort_values(['File_Timestamp'], ascending=False)
        df_calls_list['Time_Bin'] = df_calls_list['File_Timestamp'].dt.floor('2h')
    else:
        df_calls_list = pd.DataFrame()

    return df_calls_list

## Function: Read Transcripts

In [6]:
def fetch_transcript_from_s3(aws_access_key: str, aws_secret_key: str, s3_source_bucket: str, file_key):
    """
    Read Transcript JSON content from a specific file in S3.
    
    :param bucket_name: Name of the S3 bucket
    :param file_key: Full path/key of the JSON file
    :return: Parsed JSON content
    """
    s3_client = boto3.client(
            's3',
            aws_access_key_id=aws_access_key,
            aws_secret_access_key=aws_secret_key
        )
    
    try:
        # Download the file
        response = s3_client.get_object(Bucket=s3_source_bucket, Key=file_key)
        
        # Read the content
        json_content = response['Body'].read().decode('utf-8')
        
        # Parse JSON
        return json.loads(json_content)
    
    except Exception as e:
        logger.info
        logger.error(f"Error reading Transcript JSON file {file_key}: {e}")
        logger.info("")
        pass

## Bulk Download Transcripts

In [40]:
def download_transcripts_to_gcs(
        file,
        s3_client,
        s3_analysis_bucket
    ):
        """Download transcript from S3 and upload to GCS."""

        local_file_path = f"transcripts"  # Temporary local storage
        
        try:
            # Download file from S3
            s3_client.download_file(s3_analysis_bucket, file, local_file_path)

        except Exception as e:
            logger.error(f"Error: Failed to process {file} -> {str(e)}")
            # handle_exception(file, vai_gcs_bucket, pipeline_run_name, gcs_errored_folder, str(e))
            # return None, file

## Fetching Category, Sub-Category Mapping

In [7]:
def fetch_category_mapping_from_snowflake(catsubcat_conn_params):
    """Fetch Category-Subcategory mapping from a Snowflake View and return as DataFrame."""
    try:
        conn = sc.connect(**catsubcat_conn_params)
        query = f"SELECT CATEGORY, SUBCATEGORY FROM {snf_catsubcat_view}"
        df = pd.read_sql(query, conn)
        conn.close()
        return df
    except Exception as e:
        raise RuntimeError(f"Error fetching category mapping from Snowflake: {e}")


# Create Intra-call Dataframe

In [8]:
def millis_to_hhmmss(millis):
    """Convert milliseconds to mm:ss format"""
    total_seconds = int(millis / 1000)
    hours = total_seconds // 3600
    minutes = total_seconds // 60
    seconds = total_seconds % 60
    return f"{minutes:02d}:{seconds:02d}"

def convert_to_seconds(time_str):
    try:
        # Parse time string using datetime
        time_obj = datetime.strptime(time_str, '%H:%M:%S')
        # Convert to timedelta and extract total seconds
        total_seconds = time_obj.minute * 60 + time_obj.second
        return total_seconds
    except ValueError:
        pass

def process_transcript(
    transcript_data: dict,
    contact_id: str
):
    """
    Pre-process the transcript loaded from S3 Buckets:
    1. Load the transcript as Pandas Dataframe.
    2. Select only the necessary columns ['BeginOffsetMillis', 'EndOffsetMillis', 'ParticipantId', 'Content', 'Sentiment', 'LoudnessScore'].
    3. Format the time in minutes and seconds.
    4. Rename the columns for better understanding.
    """
    # Load the Transcript as Pandas Dataframe
    transcript_df = pd.json_normalize(transcript_data['Transcript'])

    # Select the relevant Columns
    columns_to_select = [
        'BeginOffsetMillis',
        'EndOffsetMillis',
        'ParticipantId',
        'Content'
    ]
    formatted_df = transcript_df[columns_to_select].copy()
    
    # Optionally rename columns to reflect their new format
    formatted_df = formatted_df.rename(columns={
        'BeginOffsetMillis': 'Begin_Offset',
        'EndOffsetMillis': 'End_Offset',
        'Content': 'caption',
        'Sentiment': 'sentiment_label',
        'ParticipantId': 'speaker_tag'
    })

    # Inserting the Call ID:
    formatted_df.insert(loc=0, column='contact_id', value=contact_id)
    formatted_df['call_language'] = transcript_data['LanguageCode']

    return formatted_df

def get_sentiment_label(row):
    # Check conditions in order of priority (Positive > Negative > Neutral)
    if row['positive'] > row['negative'] and row['positive'] > row['neutral']:
        return 'Positive'
    elif row['negative'] > row['positive'] and row['negative'] > row['neutral']:
        return 'Negative'
    else:
        return 'Neutral'

def get_sentiment_scores(text_list):
    dict_sentiments = []
    for text in text_list:
        encoded_input = tokenizer(text, return_tensors='pt')
        output = model_sentiment(**encoded_input)
        scores = output[0][0].detach().numpy()
        scores = np.round(np.multiply(softmax(scores), 100), 2)
        merged_dict = dict(zip(list(config.id2label.values()), list(scores)))
        dict_sentiments.append(merged_dict)

    df_dict_sentiments = pd.DataFrame(dict_sentiments)
    df_dict_sentiments['sentiment_lable'] = df_dict_sentiments[['positive','negative','neutral']].apply(get_sentiment_label, axis=1)
    
    return df_dict_sentiments

def get_different_times(intra_call):
    # Apply formatting to both time columns
    intra_call['start_time_second'] = (intra_call['Begin_Offset'] / 1000).astype(int)
    # intra_call['Begin_Offset'] = intra_call['Begin_Offset'].apply(millis_to_hhmmss)
    intra_call['end_time_second'] = (intra_call['End_Offset'] / 1000).astype(int)
    # intra_call['End_Offset'] = intra_call['End_Offset'].apply(millis_to_hhmmss)
    intra_call['time_spoken_second'] = intra_call['end_time_second'] - intra_call['start_time_second']
    intra_call['time_spoken_second'] = intra_call['time_spoken_second'].where(intra_call['time_spoken_second'] >= 0, 0)
    intra_call['time_spoken_second'] = intra_call['time_spoken_second'].fillna(0).astype(int)
    intra_call['time_silence_second'] = intra_call['start_time_second'].shift(-1) - intra_call['end_time_second']
    intra_call['time_silence_second'] = intra_call['time_silence_second'].where(intra_call['time_silence_second'] >= 0, 0)
    intra_call['time_silence_second'] = intra_call['time_silence_second'].fillna(0).astype(int)
    intra_call['load_date'] = datetime.now()

    # Dropping time formatted columns
    intra_call = intra_call.drop(['Begin_Offset', 'End_Offset'], axis=1)

    return intra_call

def mask_pii_in_captions(contact_id, df, project_id):
    """
    Masks PII data in the 'caption' column of a pandas DataFrame using Google Cloud DLP API.
    
    Args:
        contact_id: Identifier for logging purposes
        df (pandas.DataFrame): DataFrame with a 'caption' column to process
        project_id (str): Your Google Cloud project ID
        
    Returns:
        pandas.DataFrame: DataFrame with masked PII in the 'caption' column
    """
    logger.info(f"{contact_id}: Masking PII Data")

    # Create a copy of the DataFrame to avoid modifying the original
    masked_df = df.copy()
    
    # Add unique markers to each caption to identify them after processing
    masked_df['original_index'] = masked_df.index
    masked_df['marked_caption'] = masked_df.index.astype(str) + "|||SEPARATOR|||" + masked_df['caption'].astype(str)
    
    # Concatenate all captions for bulk processing
    all_captions = "\n===RECORD_BOUNDARY===\n".join(masked_df['marked_caption'])
    
    # Initialize DLP client
    dlp_client = dlp_v2.DlpServiceClient()
    
    # Specify the parent resource name
    parent = f"projects/{project_id}/locations/global"
    
    # Custom dictionary detector for PosiGen
    posigen_dictionary = {
        "info_type": {"name": "CUSTOM_DICTIONARY_POSIGEN"},
        "dictionary": {
            "word_list": {
                "words": ["posigen", "Posigen", "PosiGen", "POSIGEN"]
            }
        }
    }
    
    # Configure inspection config with rule set for exclusions
    inspect_config = {
        "info_types": [
            {"name": "CREDIT_CARD_NUMBER"},
            {"name": "CREDIT_CARD_EXPIRATION_DATE"},
            {"name": "STREET_ADDRESS"},
            {"name": "IP_ADDRESS"},
            {"name": "DATE_OF_BIRTH"}
        ],
        "min_likelihood": dlp_v2.Likelihood.POSSIBLE,
        "custom_info_types": [posigen_dictionary],  # ✅ Custom info types should be a list
        "rule_set": [
            {
                "info_types": [{"name": "CUSTOM_DICTIONARY_POSIGEN"}],  # ✅ Specify info_type for rule
                "rules": [
                    {
                        "exclusion_rule": {
                            "matching_type": dlp_v2.MatchingType.MATCHING_TYPE_FULL_MATCH,
                            "dictionary": {
                                "word_list": {
                                    "words": ["posigen", "Posigen", "PosiGen", "POSIGEN"]
                                }
                            }
                        }
                    }
                ]
            }
        ]
    }
    
    # Configure deidentification to use "[REDACTED]" instead of asterisks
    deidentify_config = {
        "info_type_transformations": {
            "transformations": [
                {
                    "info_types": [
                        {"name": "CREDIT_CARD_NUMBER"},
                        {"name": "CREDIT_CARD_EXPIRATION_DATE"},
                        {"name": "STREET_ADDRESS"},
                        {"name": "IP_ADDRESS"},
                        {"name": "DATE_OF_BIRTH"}
                    ],
                    "primitive_transformation": {
                        "replace_config": {
                            "new_value": {"string_value": "[REDACTED]"}
                        }
                    }
                }
            ]
        }
    }
    
    # Create deidentify request
    item = {"value": all_captions}
    
    # Call the DLP API
    try:
        response = dlp_client.deidentify_content(
            request={
                "parent": parent,
                "deidentify_config": deidentify_config,
                "inspect_config": inspect_config,
                "item": item,
            }
        )
    except Exception as e:
        logger.error(f"{contact_id}: Error in DLP API call: {e}")
        return df  # Return original DataFrame if masking fails

    # Get processed content and split by record boundaries
    processed_content = response.item.value
    processed_records = processed_content.split("\n===RECORD_BOUNDARY===\n")
    
    # Create mapping from original indices to processed captions
    processed_dict = {}
    for record in processed_records:
        parts = record.split("|||SEPARATOR|||", 1)
        if len(parts) == 2:
            idx, content = parts
            processed_dict[int(idx)] = content
    
    # Update the DataFrame with redacted content
    masked_df['caption'] = masked_df.apply(
        lambda row: processed_dict.get(row['original_index'], row['caption']), 
        axis=1
    )
    
    # Additional processing to mask all digits with asterisks
    def mask_digits(text):
        """Replaces digits with asterisks while preserving '[REDACTED]' markers."""
        if not isinstance(text, str):
            return text
        parts = text.split("[REDACTED]")
        for i in range(len(parts)):
            parts[i] = re.sub(r'\d', '*', parts[i])
        return "[REDACTED]".join(parts)
    
    # Apply the digit masking function to each processed caption
    masked_df['caption'] = masked_df['caption'].apply(mask_digits)
    
    # Drop temporary columns
    masked_df.drop(['original_index', 'marked_caption'], axis=1, inplace=True)
    
    logger.info(f"{contact_id}: Completed Masking PII Data")
    return masked_df
    
    
def create_intra_call_df(aws_access_key: str, aws_secret_key: str, transcript_data: dict, contact_id: str, gcp_project_id: str):
    intra_call = process_transcript(transcript_data, contact_id)        
    df_sentiment_scores = get_sentiment_scores(intra_call.caption.to_list())
    intra_call = pd.concat([intra_call, df_sentiment_scores], axis=1)    
    intra_call = get_different_times(intra_call)
    intra_call = mask_pii_in_captions(contact_id, intra_call, gcp_project_id)
    
    return intra_call

# Create Inter-call Dataframe

In [9]:
class CategoryValidator:
    def __init__(self, cat_subcat_mapping):
        """
        Initialize with category mapping from a Snowflake View.
        :param snowflake_conn_params: Dictionary containing Snowflake connection details.
        :param view_name: Name of the Snowflake View containing category mappings.
        """
        self.category_mapping = cat_subcat_mapping
        self.valid_categories = set(self.category_mapping['CATEGORY'].unique())
        self.category_subcategory_map = self._create_category_mapping()

    def _create_category_mapping(self):
        """Create category to subcategory mapping."""
        mapping = {}
        for _, row in self.category_mapping.iterrows():
            if row['CATEGORY'] not in mapping:
                mapping[row['CATEGORY']] = set()
            mapping[row['CATEGORY']].add(row['SUBCATEGORY'])
        return mapping

    def validate_category(self, category: str) -> bool:
        """Check if category is valid."""
        return category in self.valid_categories

    def validate_subcategory(self, category: str, subcategory: str) -> bool:
        """Check if subcategory is valid for given category."""
        return category in self.category_subcategory_map and subcategory in self.category_subcategory_map[category]

    def get_valid_subcategories(self, category: str):
        """Get valid subcategories for a category."""
        return self.category_subcategory_map.get(category, set())

class CallSummary(BaseModel):
    summary: str = Field(..., max_length=500)
    # key_points: List[str] = Field(..., max_items=5)
    # outcome: str = Field(..., max_length=200)
    # follow_up_recommendations: List[str] = Field(..., max_items=3)

class CallTopic(BaseModel):
    primary_topic: str = Field(..., max_length=100)
    category: str = Field(..., max_length=100)
    sub_category: str = Field(..., max_length=100)

    def validate_category_mapping(self, category_validator: CategoryValidator) -> bool:
        """Validate category and subcategory against mapping"""
        if not category_validator.validate_category(self.category):
            logger.error(f"Invalid category: {self.category}")
        if not category_validator.validate_subcategory(self.category, self.sub_category):
            logger.error(f"Invalid subcategory '{self.sub_category}' for category '{self.category}'")

class AgentCoaching(BaseModel):
    strengths: List[str] = Field(..., max_items=3)
    improvement_areas: List[str] = Field(..., max_items=3)
    specific_recommendations: List[str] = Field(..., max_items=4)
    skill_development_focus: List[str] = Field(..., max_items=3)

class TranscriptAnalysis(BaseModel):
    call_summary: CallSummary
    call_topic: CallTopic
    agent_coaching: AgentCoaching

class KPIExtractor:
    def __init__(self, project_id: str, location: str, excel_path: str):
        vertexai.init(project=project_id, location=location)
        self.model = GenerativeModel("gemini-1.5-flash-002")
        self.category_validator = CategoryValidator(cat_subcat_mapping)
        
        self.generation_config = {
            "temperature": 0.3,
            "max_output_tokens": 1024,
            "top_p": 0.8,
            "top_k": 40,
            "response_format": "json"
        }
        
        self.safety_settings = {
            generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
        }

    
    def get_categories_prompt(self) -> str:
        """Create prompt section for valid categories and subcategories, handling null values"""
        categories_prompt = []
        
        for category, subcategories in self.category_validator.category_subcategory_map.items():
            if category is None:  # Skip if category is None
                continue
            
            # Ensure subcategories are valid (remove None values)
            valid_subcategories = [subcat for subcat in subcategories if subcat is not None]
            
            if valid_subcategories:
                subcats = ', '.join(sorted(valid_subcategories))
            else:
                subcats = "No defined subcategories"
            
            categories_prompt.append(f"Category '{category}' can have subcategories: {subcats}")
        
        return '\n'.join(categories_prompt)

    def create_prompt(self, transcript: str) -> str:
        """Create structured prompt with category guidance"""
        categories_guidance = self.get_categories_prompt()
        
        return f"""
        Analyze this call transcript and provide a structured analysis in the exact JSON format specified below.
        Keep responses concise, specific, and actionable.
    
        Guidelines:
        - Call summary should be factual and highlight key interactions
        - Topics and categories MUST match the following valid mappings:
        {categories_guidance}
        - Coaching points should be specific and actionable
        - All responses must follow the exact structure specified
        - Ensure all lists have the specified maximum number of items
        - All text fields must be clear, professional, and free of fluff
    
        Transcript:
        {transcript}
    
        Required Output Structure:
        {{
            "call_summary": {{
                "summary": "3-4 line overview of the call"
            }},
            "call_topic": {{
                "primary_topic": "Main topic of discussion",
                "category": "MUST BE ONE OF THE VALID CATEGORIES LISTED ABOVE",
                "sub_category": "MUST BE A VALID SUB-CATEGORY FOR THE CHOSEN CATEGORY"
            }},
            "agent_coaching": {{
                "strengths": ["Strength 1", "Strength 2", "Strength 3"],
                "improvement_areas": ["Area 1", "Area 2", "Area 3"],
                "specific_recommendations": ["Rec 1", "Rec 2", "Rec 3", "Rec 4"],
                "skill_development_focus": ["Skill 1", "Skill 2", "Skill 3"]
            }}
        }}
    
        Rules:
        1. Maintain exact JSON structure
        2. No additional fields or comments
        3. No markdown formatting
        4. Ensure all arrays have the exact number of items specified
        5. Keep all text concise and professional
        6. Do not mention any PII information such as Customer Name etc.
        7. STRICTLY use only the categories and subcategories from the provided mapping
        """
    
    def extract_json(self, response: str) -> Dict:
        """Extract valid JSON from response"""
        match = re.search(r'```json\s*([\s\S]*?)\s*```', response)
        if match:
            json_str = match.group(1)
        else:
            json_str = response.strip()
        
        try:
           return json.loads(json_str)
        except json.JSONDecodeError:
           logger.error("Invalid JSON response")
           pass

    def validate_response(self, response_json: Dict, contact_id: str = None) -> TranscriptAnalysis:
        """Validate response using Pydantic models and category mapping"""
        try:
            # First validate basic structure with Pydantic
            analysis = TranscriptAnalysis(**response_json)
            
            # Then validate category mapping
            analysis.call_topic.validate_category_mapping(self.category_validator)
            
            return analysis
        except ValidationError as e:
            logger.error(f"{contact_id if contact_id else ''}: Pydantic validation error - {e}")
            pass
        except ValueError as e:
            logger.error(f"{contact_id if contact_id else ''}: Category validation error - {e}")
            pass


    def extract_genai_kpis(
       self,
       transcript: str,
       contact_id: str = None
    ):
        """Extract KPIs from transcript with validation and retries"""
        max_retries = 3
        attempt = 0
    
        while attempt < max_retries:
            try:
                # Generate prompt
                prompt = self.create_prompt(transcript)
    
                # Get response from Gemini
                response = self.model.generate_content(prompt)
    
                # Parse JSON response
                response_json = self.extract_json(response.text)
    
                # If response is empty, retry
                if not response_json or "NA" in response_json.values():
                    logger.warning(f"Attempt {attempt + 1}: Gemini returned NA or empty response. Retrying...")
                    attempt += 1
                    time.sleep(2)  # Wait before retrying
                    continue
    
                # Validate response
                validated_response = self.validate_response(response_json, contact_id)
    
                if validated_response:
                    return validated_response.model_dump()
    
                logger.warning(f"Attempt {attempt + 1}: Invalid response structure. Retrying...")
                attempt += 1
                time.sleep(2)  # Wait before retrying
    
            except Exception as e:
                logger.error(f"Attempt {attempt + 1}: Error extracting KPIs: {str(e)}")
                attempt += 1
                time.sleep(2)  # Wait before retrying
    
        logger.error(f"Failed to extract valid KPIs after {max_retries} attempts.")
        return {"error": "Failed to extract KPIs after multiple attempts"}

def dict_to_newline_string(data: dict) -> str:
    """Converts a dictionary into a new-line formatted string."""
    formatted_str = ""
    for key, value in data.items():
        formatted_str += f"{key}:\n"
        for item in value:
            formatted_str += f"  - {item}\n"
    return formatted_str.strip()

def create_inter_call_df(
    gcp_project_id: str,
    gcp_prjct_location: str,
    df_intra_call: pd.DataFrame,
    transcript_data: dict,
    ac_last_modified_date: datetime,
    cat_subcat_mapping: pd.DataFrame
):
    try:
        contact_id = df_intra_call.contact_id.unique        
    
        # logger.info(f"{contact_id}: Extracting KPIs from Gemini")      
        extractor = KPIExtractor(gcp_project_id, gcp_prjct_location, excel_path)
        transcript = " ".join(df_intra_call.caption)
        call_gen_kpis = extractor.extract_genai_kpis(transcript)
        # logger.info(f"{contact_id}: Completed Extracting KPIs from Gemini") 
    
        # logger.info(f"{contact_id}: Creating Inter Call df")
        inter_call_dict = {}
        inter_call_dict['contact_id'] = str(df_intra_call['contact_id'][0])
        inter_call_dict['call_text'] = " ".join(df_intra_call.caption)
        inter_call_dict['call_summary'] = call_gen_kpis['call_summary']['summary']
        inter_call_dict['topic'] = call_gen_kpis['call_topic']['primary_topic']
        inter_call_dict['category'] = call_gen_kpis['call_topic']['category']
        # inter_call_dict['category_generated'] = call_gen_kpis['call_topic']['category']
        inter_call_dict['sub_category'] = call_gen_kpis['call_topic']['sub_category']
        # inter_call_dict['sub_category_generated'] = call_gen_kpis['call_topic']['sub_category']
        inter_call_dict['agent_coaching'] = dict_to_newline_string(call_gen_kpis['agent_coaching'])
    
        df_inter_call = pd.DataFrame(pd.Series(inter_call_dict)).T
    
        # Add metadata from AWS
        # df_inter_call['account_id'] = transcript_data['AccountId']
        df_inter_call['agent_speech_speed'] = transcript_data['ConversationCharacteristics']['TalkSpeed']['DetailsByParticipant']['AGENT']['AverageWordsPerMinute']
        df_inter_call['customer_speech_speed'] = transcript_data['ConversationCharacteristics']['TalkSpeed']['DetailsByParticipant']['CUSTOMER']['AverageWordsPerMinute']
        df_inter_call['total_talktime_agent_second'] = int(transcript_data['ConversationCharacteristics']['TalkTime']['DetailsByParticipant']['AGENT']['TotalTimeMillis']/1000)
        df_inter_call['total_talktime_customer_second'] = int(transcript_data['ConversationCharacteristics']['TalkTime']['DetailsByParticipant']['CUSTOMER']['TotalTimeMillis']/1000)
        df_inter_call['total_talktime_call_second'] = int(transcript_data['ConversationCharacteristics']['TalkTime']['TotalTimeMillis']/1000)
        df_inter_call['total_duration_call_second'] = int(transcript_data['ConversationCharacteristics']['TotalConversationDurationMillis']/1000)
        df_inter_call['total_dead_air_call_second'] = df_inter_call['total_duration_call_second'] - df_inter_call['total_talktime_call_second']
        # df_inter_call['customer_instance_id'] = transcript_data['CustomerMetadata']['InstanceId']
        # df_inter_call['call_job_status'] = transcript_data['JobStatus']
        df_inter_call['call_language'] = transcript_data['LanguageCode']
        df_inter_call['call_s3_uri'] = transcript_data['CustomerMetadata']['InputS3Uri']
        df_inter_call['ac_last_modified_date'] = ac_last_modified_date
        # df_inter_call['load_date'] = datetime.now()
        
        return df_inter_call

    except Exception as e:
        logger.info
        logger.error(f"{contact_id}: Error Creating Intra Call df: {e}")
        logger.info("")
        pass

# Writing Dataframe to Snowflake

In [10]:
def insert_new_records(conn, table_name, df):
    """
    Inserts only new records (based on ID) into Snowflake table with UTC load timestamp.

    Steps:
    1. Fetches existing IDs from table.
    2. Filters out rows with existing IDs from DataFrame.
    3. Adds 'LOAD_DATE_UTC' column with current UTC timestamp.
    4. Inserts only new records.

    Args:
        conn: Snowflake connection object.
        table_name (str): Name of the target table.
        df (pd.DataFrame): DataFrame containing the data (must have 'CONTACT_ID' column).

    Returns:
        int: Number of inserted records.
    """    
    cursor = conn.cursor()

    # Step 1: Get existing IDs from Snowflake table
    cursor.execute(f"SELECT DISTINCT(CONTACT_ID) FROM {table_name}")
    existing_ids = {row[0] for row in cursor.fetchall()}
    
    # Step 2: Filter DataFrame to keep only new records
    new_records_df = df[~df['CONTACT_ID'].isin(existing_ids)]
    
    if new_records_df.empty:
        logger.info("No new records to insert")
        return 0
    
    # Step 3: Add UTC timestamp column
    utc_now = datetime.now(pytz.utc).strftime('%Y-%m-%d %H:%M:%S')
    new_records_df = new_records_df.copy()  # Avoid modifying original df
    new_records_df["LOAD_DATE"] = utc_now  # Add new column

    # Step 4: Insert new records into Snowflake
    success, nchunks, nrows, _ = write_pandas(conn, new_records_df, table_name)
    
    logger.info(f"Inserted {nrows} new records with UTC load date")
    logger.info(f"Skipped {len(df) - len(new_records_df)} existing records")
    
    cursor.close()
    return nrows

### Handel Duplicates

In [11]:
def handle_duplicates(data_frame, columns_to_check):
    """
    Dedupes the final Dataframes to be written to the Snowflake Tables
    """
    # Remove duplicate records based on the specified columns, keeping the first occurrence
    df_cleaned = data_frame.drop_duplicates(subset=columns_to_check, keep="first")

    # (Optional) Reset index after removing duplicates
    df_cleaned = df_cleaned.reset_index(drop=True)

    return df_cleaned

# Exception Handling

# Logging Handling

In [12]:
def setup_logger(log_file):
    """
    Sets up a logger that writes to both file and console with timestamp.

    Args:
        log_file (str): Name of the log file to write to.

    Returns:
        logger: Configured logger instance.
    """
    logger = logging.getLogger('voice_ai_logger')

    # Reset handlers if already exist
    if logger.hasHandlers():
        logger.handlers.clear()

    # Set log level
    logger.setLevel(logging.DEBUG)

    # Create formatter
    formatter = logging.Formatter(
        '%(asctime)s [%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Create file handler
    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Create a stream handler (console)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    return logger


def reset_logging():
    """Removes all logging handlers and resets the logger."""
    # Get all loggers
    loggers = list(logging.root.manager.loggerDict.values())
    
    # Include root logger explicitly
    loggers.append(logging.getLogger())

    for logger in loggers:
        if isinstance(logger, logging.Logger):  # Ensure it's a valid logger instance
            for handler in logger.handlers[:]:  # Copy list to avoid modification issues
                logger.removeHandler(handler)
                handler.close()

            logger.setLevel(logging.NOTSET)
            logger.propagate = False  # Avoid duplicate logs from propagation

# Main Function

In [14]:
# Get current date
current_date = datetime.now()

# Get dates for last 2 days
num_days = 4
dates_to_check = []
for i in range(num_days):
    check_date = current_date - timedelta(days=i)
    dates_to_check.append(check_date)

# Process each date
folderlist = []
for date in dates_to_check:
    year = str(date.year)
    month = f"{date.month:02d}"
    day = f"{date.day:02d}"
    
    # Construct the prefix for S3 listing
    prefix = f"{s3_transcripts_location}/{year}/{month}/{day}/"
    folderlist.append(prefix)
folderlist

['Analysis/Voice/2025/03/10/',
 'Analysis/Voice/2025/03/09/',
 'Analysis/Voice/2025/03/08/',
 'Analysis/Voice/2025/03/07/']

In [17]:
reset_logging()

# Setup logger
log_file='voice_ai_runtime.logs'
logger = setup_logger(log_file)

logger.info("Fetching Category, Sub-category Mapping.")
cat_subcat_mapping = fetch_category_mapping_from_snowflake(catsubcat_conn_params)

# Get the transcripts in to_process_folder
df_list_transcripts = list_new_transcripts_from_folderlist(aws_access_key, aws_secret_key, s3_source_bucket, s3_transcripts_location, folderlist)
df_list_transcripts.to_csv("df_list_transcripts.csv")

logger.info("")
logger.info(f"Transcripts to process: {len(df_list_transcripts)}")
logger.info(df_list_transcripts.groupby(['File_Date']).size().reset_index(name='frequency'))
# logger.info(df_list_transcripts[:3000].groupby(['File_Date', 'Time_Bin']).size().reset_index(name='Count'))
logger.info("")

2025-03-10 05:21:18 [INFO]: Fetching Category, Sub-category Mapping.
  df = pd.read_sql(query, conn)
2025-03-10 05:21:20 [INFO]: 
2025-03-10 05:21:20 [INFO]: Transcripts to process: 2073
2025-03-10 05:21:20 [INFO]:     File_Date  frequency
0  2025-03-07       1916
1  2025-03-08        157
2025-03-10 05:21:20 [INFO]: 


In [32]:
f"{s3_source_bucket}/{folderlist[0]}"

'amazon-connect-39f6aa5d9242/Analysis/Voice/2025/03/10/'

In [36]:
files_list[0]

'Analysis/Voice/2025/03/08/e630c206-987c-47c2-8b23-2329d0b13ec6_analysis_2025-03-08T23:58:16Z.json'

In [41]:
files_list = df_list_transcripts.File.to_list()

s3_client = boto3.client(
    "s3", 
    aws_access_key_id=aws_access_key, 
    aws_secret_access_key=aws_secret_key
)

if files_list:
    success_downloads = []
    failed_downloads = []
    
    for folder in folderlist:
        with ThreadPoolExecutor(max_workers=5) as executor:
            logger.info(f"Started: bulk download to GCS transcripts#: {len(files_list)}")

            future_to_file = {
                executor.submit(
                    download_transcripts_to_gcs,
                    file,
                    s3_client,
                    s3_source_bucket
                ): file for file in files_list[:2]
            }

            for future in as_completed(future_to_file):
                try:
                    success, failed = future.result()  # Get results

                    if success:
                        success_downloads.append(success)
                    if failed:
                        failed_downloads.append(failed)

                except Exception as e:
                    logger.error(f"Unexpected Error: {str(e)}")
                    handle_exception("N/A", vai_gcs_bucket, pipeline_run_name, gcs_errored_folder, str(e))

2025-03-10 05:41:54 [INFO]: Started: bulk download to GCS transcripts#: 2073
2025-03-10 05:41:55 [ERROR]: Error: Failed to process Analysis/Voice/2025/03/08/ce79464e-86fd-4384-802b-ae13e09ba949_analysis_2025-03-08T23:22:44Z.json -> [Errno 21] Is a directory: 'transcripts.788a1211' -> 'transcripts'
2025-03-10 05:41:55 [ERROR]: Unexpected Error: cannot unpack non-iterable NoneType object
2025-03-10 05:41:55 [ERROR]: Error: Failed to process Analysis/Voice/2025/03/08/e630c206-987c-47c2-8b23-2329d0b13ec6_analysis_2025-03-08T23:58:16Z.json -> [Errno 21] Is a directory: 'transcripts.79278e6b' -> 'transcripts'


NameError: name 'handle_exception' is not defined

In [18]:
df_list_transcripts[:2]

Unnamed: 0,File,ID,File_Timestamp,File_Date,File_Time,S3_Timestamp,S3_Date,S3_Time,Time_Bin
141,Analysis/Voice/2025/03/08/e630c206-987c-47c2-8...,e630c206-987c-47c2-8b23-2329d0b13ec6,2025-03-08 23:58:16,2025-03-08,23:58:16,2025-03-09 00:04:04+00:00,2025-03-09,00:04:04,2025-03-08 22:00:00
129,Analysis/Voice/2025/03/08/ce79464e-86fd-4384-8...,ce79464e-86fd-4384-802b-ae13e09ba949,2025-03-08 23:22:44,2025-03-08,23:22:44,2025-03-08 23:30:13+00:00,2025-03-08,23:30:13,2025-03-08 22:00:00


In [19]:
# Initiating Master DataFrames
max_threads = 5
print("Called: To make sure previous calls are loaded into memory again.")
logger.info("Called: Initiate Master Dataframes")
df_intra_calls_data, df_inter_calls_data = initiate_master_dataframes()
print(len(df_inter_calls_data))
print(len(df_intra_calls_data))
print(f"Max Threads: {max_threads}")

2025-03-10 05:22:27 [INFO]: Called: Initiate Master Dataframes
2025-03-10 05:22:27 [INFO]: df_intra_calls_data.csv does not exists.
2025-03-10 05:22:27 [INFO]: df_inter_calls_data.csv does not exists.


Called: To make sure previous calls are loaded into memory again.
0
0
Max Threads: 5


In [22]:
def process_single_call(transcript):
    # for transcript in df_list_transcripts.File.to_list():
    logger.info("--------------------------")
    start_time = time.time()
    try:
        # get the call ID
        contact_id = transcript.split('/')[-1].split('.')[0].split('analysis')[0].strip('_')
        ac_last_modified_date = datetime.strptime(
                                        transcript.split('analysis_')[-1].split('.')[0].replace('Z', ""), 
                                        "%Y-%m-%dT%H:%M:%S"
                                    ).strftime('%Y-%m-%d %H:%M:%S')

        # Check if Call Already Processed
        if (len(df_intra_calls_data) > 0 and contact_id in df_intra_calls_data.CONTACT_ID.unique()) and (len(df_inter_calls_data) > 0 and contact_id in df_inter_calls_data.CONTACT_ID.unique()):
            logger.info(f"{contact_id}: Call already Processed.")
            # break

        else:
            # get the audio transcript file name
            logger.info(f"{contact_id}: Processing") 

            # Get the Transcript file from S3 Bucket
            logger.info(f"{contact_id}: Fetching Transcript from S3")
            transcript_data = fetch_transcript_from_s3(aws_access_key, aws_secret_key, s3_source_bucket, transcript)
            logger.info(f"{contact_id}: Successfully fetched the Transcript from S3")

            # Create the Inter Call KPIs
            logger.info(f"{contact_id}: Creating df_intra_call ")
            df_intra_call = create_intra_call_df(aws_access_key, aws_secret_key, transcript_data, contact_id, gcp_project_id)
            logger.info(f"{contact_id}: Successfully created df_intra_call ")

            # Create the Intra Call KPIs
            logger.info(f"{contact_id}: Creating df_inter_call ")
            df_inter_call = create_inter_call_df(gcp_project_id, gcp_prjct_location, df_intra_call, transcript_data, ac_last_modified_date, cat_subcat_mapping)
            logger.info(f"{contact_id}: Successfully created df_inter_call ")


            # ###============================================================###
            # Save DataFrames only when both are having data
            if not df_intra_call.empty and not df_inter_call.empty:
                # Appending to Intra-calls Master DataFrame
                df_intra_call.columns = df_intra_call.columns.str.upper()  # Capitalising Column names for Snowflake
                df_intra_calls_data = pd.concat([df_intra_calls_data, df_intra_call], ignore_index=True)
                df_intra_calls_data.to_csv("df_intra_calls_data.csv", index=False)
                logger.info(f"{contact_id}: Persisted df_intra_calls_data to CSV.")

                # Appending to Inter-calls Master DataFrame
                df_inter_call.columns = df_inter_call.columns.str.upper()  # Capitalising Column names for Snowflake
                df_inter_calls_data = pd.concat([df_inter_calls_data, df_inter_call], ignore_index=True)
                df_inter_calls_data.to_csv("df_inter_calls_data.csv", index=False)
                logger.info(f"{contact_id}: Persisted df_intra_calls_data to CSV.")
                # logger.info(f"{contact_id}: Processing Complete")

            else:
                if df_intra_call.empty:
                    logger.error("Intra Call DataFrame was not created successfully.")
                if df_inter_call.empty:
                    logger.error("Inter Call DataFrame was not created successfully.")

        if len(df_inter_calls_data)%20 == 0:
            logger.info("--------------------------")
            logger.info(f"Processed {len(df_inter_calls_data)} files")
            logger.info("--------------------------")

    except Exception as e:
        logger.error("--------------------------")
        logger.error(f"{contact_id}: Exception - {e}")
        logger.error("--------------------------")


    end_time = time.time()
    elapsed_time = end_time - start_time  # Time taken for this iteration
    minutes, seconds = divmod(elapsed_time, 60)
    logger.info(f"{contact_id}: Processed Call #{len(df_inter_calls_data)} in {int(minutes)} min {seconds:.2f} sec elapsed")
    logger.info("--------------------------")
    logger.info("")
    logger.info("")

In [25]:
calls = df_list_transcripts.File.to_list()

In [None]:
# If there are transcripts to be processed
if len(df_list_transcripts) == 0:
    logger.info("No Transcripts to Process")
    logger.info("")

else:
    results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
        future_to_file = {executor.submit(process_single_call, call): call for call in calls}
        
        for future in concurrent.futures.as_completed(future_to_file):
            file = future_to_file[future]
            try:
                result = future.result()  # This will re-raise any exceptions
                if result:
                    results.append(result)
            except Exception as e:
                logging.error(f"Unhandled exception processing {file}: {e}")
    
    
        
    logger.info(f"Removing duplicates in df_inter_calls_data.")
    columns_to_check = ["CONTACT_ID"]
    df_inter_calls_data = handle_duplicates(df_inter_calls_data, columns_to_check)
    logger.info(f"Removing duplicates in df_intra_calls_data.")
    columns_to_check = ["CONTACT_ID", "SPEAKER_TAG", "CAPTION", "START_TIME_SECOND", "END_TIME_SECOND"]
    df_intra_calls_data = handle_duplicates(df_intra_calls_data, columns_to_check)

    logger.info(f"Completed processing {len(df_list_transcripts)} Calls")

# Load Data to Snowflake

In [None]:
# conn_params

In [None]:
# start_time = time.time()
# logger.info(f"Writing Dataframe to Snowflake.")
# conn = sc.connect(**conn_params)

# table_name ='SRC_GCP_INTER_CALLS'    
# logger.info(f"Writing data to table: {conn_params['database']}.{table_name}")
# insert_new_records(conn, table_name, df_inter_calls_data)
# end_time = time.time()
# elapsed_time = end_time - start_time  # Time taken for this iteration
# minutes, seconds = divmod(elapsed_time, 60)
# logger.info(f"SRC_GCP_INTER_CALLS: Inserted records #{len(df_inter_calls_data)} in {int(minutes)} min {seconds:.2f} sec elapsed")
        
# start_time = time.time()
# logger.info(f"Writing data to table: {conn_params['database']}.{table_name}")
# table_name ='SRC_GCP_INTRA_CALLS'
# insert_new_records(conn, table_name, df_intra_calls_data)
# end_time = time.time()
# elapsed_time = end_time - start_time  # Time taken for this iteration
# minutes, seconds = divmod(elapsed_time, 60)
# logger.info(f"SRC_GCP_INTRA_CALLS: Inserted records #{len(df_intra_calls_data)} in {int(minutes)} min {seconds:.2f} sec elapsed")

# conn.close()

# Validate

In [None]:
len(df_inter_calls_data)