In [21]:
import atexit
import json
import logging
import logging.config
import os
import re
from pathlib import Path
from typing import Optional

import vertexai
import yaml
from dfcx_scrapi.core.flows import Flows
from dfcx_scrapi.core.intents import Intents
from dfcx_scrapi.core.pages import Pages
from google.auth import default
from google.cloud.dialogflowcx_v3beta1 import types
from google.protobuf.json_format import MessageToDict
from vertexai.generative_models import GenerativeModel, HarmCategory, HarmBlockThreshold

In [253]:
# Intents client for the sandbox agent, used by the next cell to inspect training phrases.
SANDBOX_AGENT_ID = "projects/ai-ml-team-sandbox/locations/us-east1/agents/fac86b77-f640-41f6-937b-22f037d6cc22"
intents = Intents(agent_id=SANDBOX_AGENT_ID)

In [None]:
# Map each intent's display name to its training phrases, rebuilding each
# phrase's text by concatenating the text of its parts.
all_intents = {
    intent.display_name: [
        "".join(part.text for part in phrase.parts)
        for phrase in intent.training_phrases
    ]
    for intent in intents.list_intents()
}
all_intents
# TODO: Sample 10 each

In [27]:
def setup_logging(config_path: str | Path = "config.yaml") -> None:
    """Configure logging from a YAML ``logging.config.dictConfig`` file.

    If the configuration defines a handler named ``queue_handler`` that carries a
    ``listener``, the listener is started and registered to stop at interpreter
    exit.

    Args:
        config_path: Path to the YAML logging configuration. Defaults to
            ``config.yaml`` relative to the current working directory.
    """
    config_file = Path(config_path)
    with open(config_file, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    # Requires ``import logging.config``: ``import logging`` alone does not load the submodule.
    logging.config.dictConfig(config=config)

    # ``logging.getHandlerByName`` only exists on Python 3.12+; on older versions
    # fall back to searching the root logger's handlers by name.
    get_handler_by_name = getattr(logging, "getHandlerByName", None)
    if get_handler_by_name is not None:
        queue_handler = get_handler_by_name("queue_handler")
    else:
        queue_handler = next(
            (h for h in logging.getLogger().handlers if h.get_name() == "queue_handler"),
            None,
        )

    # dictConfig only attaches a ``listener`` to QueueHandlers on Python 3.12+.
    listener = getattr(queue_handler, "listener", None)
    if listener is not None:
        listener.start()
        atexit.register(listener.stop)

setup_logging()

In [28]:
logger = logging.getLogger("dialogflow_annotator")

class DialogflowAnnotator:
    """Uses a Gemini model to write descriptions for Dialogflow CX pages and routes.

    On construction the annotator loads the agent's flow and page maps via
    dfcx_scrapi. ``annotate_flows`` then asks the model to describe every page
    of each requested flow (optionally including each page's transition routes)
    and writes the descriptions back with ``Pages.update_page``.
    """

    DEFAULT_GENERATION_CONFIG = {
        "max_output_tokens": 8192,
        "temperature": 0,
        "top_p": 0.95,
    }

    DEFAULT_SAFETY_SETTINGS = {
        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
    }

    SYSTEM_INSTRUCTION = "You are an expert on Google Dialogflow CX structure and describing how Dialogflow agents work."

    # JSON schema embedded in the prompt to describe the expected response shape.
    # Property names here must match the keys read back in ``_annotate_flow`` and
    # ``_build_updated_page`` ("descriptions", "name", "description", "transition_routes").
    PAGE_DESCRIPTION_SCHEMA = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "title": "Flow Description Format",
        "description": "Schema describing the response format to describe a page",
        "type": "object",
        "properties": {
            "descriptions": {
                "type": "array",
                "description": "List of page descriptions",
                "items": {
                    "type": "object",
                    "properties": {
                        "name": {
                            "type": "string",
                            "description": "The ID of the page",
                            "example": "39575310-df14-4271-bbf5-7c9e60219391"
                        },
                        "display_name": {
                            "type": "string",
                            "description": "The display name of the page",
                            "example": "say_Welcome"
                        },
                        "description": {
                            "type": "string",
                            "description": "Your description of the page's functionality"
                        },
                        "transition_routes": {
                            "type": "array",
                            "description": "List of transition routes from this page",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "name": {
                                        "type": "string",
                                        "description": "The ID of the transition route",
                                        "example": "39575310-df14-4271-bbf5-7c9e60219391"
                                    },
                                    "description": {
                                        "type": "string",
                                        "description": "Your description of the route"
                                    }
                                },
                                "required": [
                                    "name",
                                    "description"
                                ]
                            }
                        }
                    },
                    "required": [
                        "name",
                        "description",
                        "transition_routes"
                    ]
                }
            }
        },
        "required": ["descriptions"],
    }

    def __init__(
        self,
        agent_id: str,
        model: str | GenerativeModel = "gemini-1.5-flash-001",
        generation_config: Optional[dict[str, str | int | float]] = None,
        safety_settings: Optional[dict[HarmCategory, HarmBlockThreshold]] = None,
    ) -> None:
        """Load agent metadata and initialise the Vertex AI model.

        This performs network calls: it resolves default credentials and fetches
        the agent's flow map and flow/page map.

        Args:
            agent_id: Agent resource name of the form
                ``projects/<project>/locations/<location>/agents/<agent>``.
            model: A model name (wrapped in a ``GenerativeModel`` with
                ``SYSTEM_INSTRUCTION``) or a preconfigured ``GenerativeModel``.
            generation_config: Generation config; falls back to a copy of
                ``DEFAULT_GENERATION_CONFIG`` when falsy.
            safety_settings: Safety settings; falls back to a copy of
                ``DEFAULT_SAFETY_SETTINGS`` when falsy.
        """
        self.agent_id = agent_id
        agent_parts = self.agent_id.split("/")
        self.project_id = agent_parts[1]
        self.location = agent_parts[3]
        # TODO: Make creds more flexible
        self.creds = default()[0]

        self.flows_client = Flows(creds=self.creds, agent_id=self.agent_id)
        self.pages_client = Pages(creds=self.creds)
        self.intents_client = Intents(creds=self.creds, agent_id=self.agent_id)
        # Backward-compatible alias for the original (misspelled) attribute name.
        self.intents_intents = self.intents_client
        self.flows_map = self.flows_client.get_flows_map()
        self.flow_page_map = self.flows_client.get_flow_page_map(agent_id=self.agent_id, rate_limit=0.5)

        # Agents in the Dialogflow "global" location use us-central1 for Vertex AI.
        vertex_location = "us-central1" if self.location == "global" else self.location
        vertexai.init(project=self.project_id, location=vertex_location)

        if isinstance(model, str):
            self.model = GenerativeModel(model, system_instruction=[self.SYSTEM_INSTRUCTION])
        else:
            self.model = model

        # Copy the class-level defaults so mutating one instance's config cannot
        # silently change the defaults shared by every other instance.
        self.generation_config = generation_config or dict(self.DEFAULT_GENERATION_CONFIG)
        self.safety_settings = safety_settings or dict(self.DEFAULT_SAFETY_SETTINGS)

    def _generate_annotations(self, prompt: str) -> Optional[str]:
        """Send ``prompt`` to the model and return the first candidate's text.

        Returns None (and logs a warning) when the response has no candidates or
        the first candidate's content yields no single text part (presumably a
        blocked or empty response).
        """
        response = self.model.generate_content(
            [prompt],
            generation_config=self.generation_config,
            safety_settings=self.safety_settings,
        )

        if not response.candidates:
            logger.warning("Model returned no candidates for the prompt")
            return None

        try:
            return response.candidates[0].content.text
        except (ValueError, AttributeError) as e:
            logger.warning(f"Model response has no usable text: {e}")
            return None

    def _parse_json_from_gemini_string(self, json_string: str):
        """Parse a JSON value from model output, tolerating a Markdown code fence.

        Args:
            json_string: Model output containing JSON, either bare or wrapped in a
                fenced block such as ```json ... ``` or ``` ... ```.

        Returns:
            The parsed JSON value (normally a dict), or None if parsing fails
            (including when ``json_string`` is None).
        """
        try:
            json_string = json_string.strip()

            # The model may wrap JSON in a code fence despite the prompt; the
            # language tag after the opening fence is optional.
            fence_match = re.search(
                r"```(?:json)?\s*(.*?)\s*```", json_string, re.DOTALL | re.IGNORECASE
            )
            if fence_match:
                json_string = fence_match.group(1)

            return json.loads(json_string)
        except (json.JSONDecodeError, AttributeError, TypeError) as e:
            logger.error(f"Failed to parse JSON from string: {e}")
            return None

    def annotate_flows(
        self,
        flows: str | list[str],
        annotate_routes: bool = True,
        overwrite: bool = False,
    ) -> list[Optional[list[types.Page]]]:
        """Annotate the pages of one or more flows.

        Args:
            flows: A flow display name or full flow resource name
                (``projects/.../locations/.../agents/.../flows/...``), or a list of them.
            annotate_routes: Also describe each page's transition routes.
            overwrite: Replace existing descriptions instead of keeping them.

        Returns:
            One entry per input flow, in order: the list of updated pages, or None
            when the model output for that flow could not be parsed.
        """
        if isinstance(flows, str):
            flows = [flows]

        annotated_flows = []
        for i, flow in enumerate(flows, start=1):
            logger.info(f"Annotating flow {i}/{len(flows)}")
            # A full flow resource name has 8 path segments; anything else is treated
            # as a display name.
            if flow.startswith("projects") and len(flow.split("/")) == 8:
                flow_kwargs = {"flow_id": flow}
            else:
                flow_kwargs = {"flow_name": flow}
            annotated_flows.append(
                self._annotate_flow(**flow_kwargs, annotate_routes=annotate_routes, overwrite=overwrite)
            )

        return annotated_flows

    def annotate_intents(
        self
    ) -> list[types.Intent]:
        """Annotate the agent's intents. Not implemented yet."""
        raise NotImplementedError("Intent annotation is not implemented yet")

    def _annotate_flow(
        self,
        flow_name: Optional[str] = None,
        flow_id: Optional[str] = None,
        annotate_routes: bool = True,
        overwrite: bool = False,
    ) -> Optional[list[types.Page]]:
        """Generate and apply descriptions for every page in a single flow.

        Args:
            flow_name: Flow display name; either this or ``flow_id`` is required.
            flow_id: Flow resource name; either this or ``flow_name`` is required.
            annotate_routes: Also describe each page's transition routes.
            overwrite: Replace existing descriptions instead of keeping them.

        Returns:
            The pages returned by ``update_page`` for each successful update, or
            None when the model output could not be parsed.

        Raises:
            ValueError: If neither identifier is given, or they refer to different flows.
        """
        flow_name, flow_id = self._resolve_flow(flow_name, flow_id)
        logger.info(f"Annotating flow {flow_name}: {flow_id}")

        pages_to_annotate = [
            MessageToDict(p._pb, preserving_proto_field_name=True)
            for p in self.pages_client.list_pages(flow_id=flow_id)
        ]

        prompt = self._build_flow_prompt(pages_to_annotate)
        logger.info(f"Generating annotations for flow {flow_name}")
        generated_results = self._generate_annotations(prompt)
        logger.info(f"Parsing generated output for flow {flow_name}")
        parsed_generated_results = self._parse_json_from_gemini_string(generated_results)

        descriptions = self._extract_descriptions(parsed_generated_results)
        if descriptions is None:
            logger.warning(f"Failed to parse generated results for flow {flow_name}")
            return None

        pages_by_id = {p["name"]: p for p in pages_to_annotate if "name" in p}
        annotated_pages = []

        for annotation in descriptions:
            page_id = annotation.get("name") if isinstance(annotation, dict) else None
            existing_page = pages_by_id.get(page_id)
            if existing_page is None:
                # The model may omit the page ID or return one not in this flow.
                logger.warning(f"Skipping annotation for unknown page {page_id!r} in flow {flow_name}")
                continue

            updated_kwargs = self._build_updated_page(
                existing_page=existing_page,
                annotation=annotation,
                annotate_routes=annotate_routes,
                overwrite=overwrite
            )
            if not updated_kwargs:
                logger.info(f"Nothing to update for page {existing_page.get('display_name')}")
                continue

            # TODO: Figure out why ujet payloads break the update_page call.
            try:
                annotated_page = self.pages_client.update_page(
                    page_id=page_id,
                    **updated_kwargs
                )

                logger.info(f"Successfully described page {existing_page.get('display_name')}")
                annotated_pages.append(annotated_page)
            except ValueError as e:
                logger.error(f"Failed to describe page {existing_page.get('display_name')}: {e}")

        return annotated_pages

    def _resolve_flow(
        self,
        flow_name: Optional[str],
        flow_id: Optional[str],
    ) -> tuple[str, str]:
        """Return a consistent ``(flow_name, flow_id)`` pair from either or both inputs."""
        if flow_name is None and flow_id is None:
            raise ValueError("Must provide either flow_name or flow_id")

        flow_id = flow_id or self.flow_page_map[flow_name]["id"]
        flow_name = flow_name or self.flows_map[flow_id]

        if flow_name != self.flows_map[flow_id]:
            raise ValueError(f"Flow name {flow_name} does not match flow ID {flow_id}")

        return flow_name, flow_id

    def _build_flow_prompt(self, pages_to_annotate: list[dict]) -> str:
        """Build the page/route description prompt for one flow's pages (all embedded as JSON)."""
        # Map every page ID in the agent to its display name so the model can name
        # route destinations. Assumes dfcx_scrapi's flow_page_map layout
        # {flow_name: {"id": flow_id, "pages": {page_name: page_id}}} — verify if scrapi changes.
        page_map = {
            page_id: page_name
            for flow_data in self.flow_page_map.values()
            for page_name, page_id in flow_data.get("pages", {}).items()
        }

        return f"""Below is an array of Dialogflow CX pages represented as JSON data. For each page in the array:
- First, write a detailed description of each route.
- Then, use the route descriptions along with the JSON data to write a detailed description of what the page does overall. Be concise and use no more than 50 words to describe each route and 150 words to describe the page's functionality.
- When writing route descriptions, use the flow map or the page map to identify the destination page or flow of each route.
- Copy each page's and each route's "name" value exactly as given so descriptions can be matched back.
- Respond with a JSON object with no other text. DO NOT USE MARKDOWN.

## FLOW MAP
{json.dumps(self.flows_map, ensure_ascii=False)}

## PAGE MAP
{json.dumps(page_map, ensure_ascii=False)}

## PAGES TO DESCRIBE
{json.dumps(pages_to_annotate, ensure_ascii=False)}

## RESPONSE FORMAT
Use the following JSON schema to structure your response:
{json.dumps(self.PAGE_DESCRIPTION_SCHEMA, indent=2)}

YOUR RESPONSE:"""

    @staticmethod
    def _extract_descriptions(parsed_results) -> Optional[list]:
        """Return the list of page annotations from parsed model output, or None if absent."""
        if isinstance(parsed_results, dict):
            descriptions = parsed_results.get("descriptions")
        elif isinstance(parsed_results, list):
            # Tolerate a bare array of page descriptions.
            descriptions = parsed_results
        else:
            return None
        return descriptions if isinstance(descriptions, list) else None

    def _build_updated_page(
        self,
        existing_page: dict[str, object],
        annotation: dict[str, object],
        annotate_routes: bool,
        overwrite: bool,
    ) -> dict[str, object]:
        """Build the ``update_page`` keyword arguments for one page.

        Args:
            existing_page: The page as a dict (``MessageToDict`` output with proto field names).
            annotation: The model's annotation for this page.
            annotate_routes: Whether to add route descriptions.
            overwrite: Whether to replace descriptions that already exist.

        Returns:
            A dict that may contain ``description`` and/or ``transition_routes``;
            empty when there is nothing to update.
        """
        page_label = existing_page.get("display_name")
        updated_kwargs = {}

        logger.info(f"Updating page: {page_label}")

        if "description" in existing_page and not overwrite:
            logger.info(f". Page {page_label} already has a description")
        elif annotation.get("description"):
            logger.info(f". Adding description to page {page_label}")
            updated_kwargs["description"] = annotation["description"]
        else:
            logger.info(f". No description generated for page {page_label}")

        if not annotate_routes:
            return updated_kwargs

        current_routes = existing_page.get("transition_routes", [])
        # Index generated route descriptions by route ID so they are matched by ID,
        # not by position, and so model output can never overwrite a route's name.
        described_routes = {
            route["name"]: route["description"]
            for route in (annotation.get("transition_routes") or [])
            if isinstance(route, dict) and route.get("name") and route.get("description")
        }

        updated_routes = []
        num_updated = 0
        for route in current_routes:
            new_description = described_routes.get(route.get("name"))
            if new_description and (overwrite or "description" not in route):
                # Only the description comes from the model; all other fields are kept.
                updated_routes.append({**route, "description": new_description})
                num_updated += 1
            else:
                updated_routes.append(route)

        if num_updated == 0:
            logger.info(f". No route descriptions to add for page {page_label}")
            return updated_kwargs

        logger.info(f". Adding descriptions to {num_updated} routes for page {page_label}")
        updated_kwargs["transition_routes"] = updated_routes

        return updated_kwargs

In [29]:
# Agent to annotate. Set the DFCX_AGENT_ID environment variable to target a
# different agent (e.g. the ai-ml-team-sandbox agent) without editing this cell.
AGENT_ID = os.environ.get(
    "DFCX_AGENT_ID",
    "projects/missi-six-dev/locations/global/agents/8ea7a506-622d-449f-92e5-b8f3eaac4d71",
)

annotator = DialogflowAnnotator(agent_id=AGENT_ID)

2024-06-28T16:51:50-0400 DEBUG    Checking None for explicit credentials as part of auth process...
2024-06-28T16:51:50-0400 DEBUG    Checking Cloud SDK credentials as part of auth process...
2024-06-28T16:51:50-0400 DEBUG    Making request: POST https://oauth2.googleapis.com/token
2024-06-28T16:51:50-0400 DEBUG    Starting new HTTPS connection (1): oauth2.googleapis.com:443
2024-06-28T16:51:50-0400 DEBUG    https://oauth2.googleapis.com:443 "POST /token HTTP/11" 200 None
2024-06-28T16:51:50-0400 DEBUG    Making request: POST https://oauth2.googleapis.com/token
2024-06-28T16:51:50-0400 DEBUG    Starting new HTTPS connection (1): oauth2.googleapis.com:443
2024-06-28T16:51:51-0400 DEBUG    https://oauth2.googleapis.com:443 "POST /token HTTP/11" 200 None
2024-06-28T16:51:51-0400 DEBUG    Making request: POST https://oauth2.googleapis.com/token
2024-06-28T16:51:51-0400 DEBUG    Starting new HTTPS connection (1): oauth2.googleapis.com:443
2024-06-28T16:51:51-0400 DEBUG    https://oauth2.goo

In [30]:
# Annotate every flow except the first 10, in flows_map order (keys are flow IDs).
flow_ids_to_annotate = list(annotator.flows_map)[10:]
annotator.annotate_flows(flows=flow_ids_to_annotate)

KeyboardInterrupt: 