<a href="https://colab.research.google.com/github/tewarikhush/agent-image-generation/blob/main/Agentic_Image_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount the user's Google Drive at /content/drive (only works inside Google Colab).
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Data Loading

In [None]:
import json
import pandas as pd
from typing import Dict, Any, List
from google.colab import auth, userdata
import gspread
import google.generativeai as genai
from google.auth import default

# ============================================================================
# GOOGLE SHEETS DATA LOADING
# ============================================================================

def load_data_from_sheets(spreadsheet_name: str = 'SampleDataCorpusKey',
                          column: str = 'FinalText') -> List[str]:
    """
    Read one or two text columns from the first worksheet of a Google Sheet.

    The first row of the sheet is used as the header row. Every failure
    (unknown sheet, missing column, invalid ``column`` value, any other
    exception while reading) is reported with ``print`` and results in an
    empty list rather than an exception.

    Args:
        spreadsheet_name: Name of the Google Sheet to open
        column: 'FinalText', 'Prompt', or 'both'

    Returns:
        List of strings for a single column, or a list of
        (Prompt, FinalText) tuples when column='both'
    """
    # Colab user authentication, then a gspread client on those credentials
    auth.authenticate_user()
    credentials, _ = default()
    client = gspread.authorize(credentials)

    try:
        all_values = client.open(spreadsheet_name).sheet1.get_all_values()
        frame = pd.DataFrame(all_values[1:], columns=all_values[0])

        # Reject unsupported modes up front
        if column not in ('both', 'Prompt', 'FinalText'):
            print(f"Error: Invalid column parameter '{column}'. Use 'FinalText', 'Prompt', or 'both'")
            return []

        # Paired mode: both columns must be present
        if column == 'both':
            if not all(name in frame.columns for name in ('Prompt', 'FinalText')):
                print("Error: Both 'Prompt' and 'FinalText' columns required for 'both' mode")
                return []
            pairs = list(zip(frame['Prompt'].tolist(), frame['FinalText'].tolist()))
            print(f"Successfully extracted {len(pairs)} rows (Prompt + FinalText)")
            return pairs

        # Single-column mode ('Prompt' or 'FinalText')
        if column not in frame.columns:
            print(f"Error: '{column}' column not found")
            return []
        values = frame[column].tolist()
        print(f"Successfully extracted {len(values)} rows from '{column}' column")
        return values

    except gspread.SpreadsheetNotFound:
        print(f"Error: Google Sheet named '{spreadsheet_name}' not found")
        return []
    except Exception as e:
        print(f"An error occurred: {e}")
        return []

Tool Definitions

In [None]:
# ============================================================================
# TOOL DEFINITIONS
# ============================================================================

def generate_diagram(components: List[Dict], relationships: List[Dict], callouts: List[Dict]) -> Dict:
    """Bundle components, relationships and callouts into a 'diagram' spec."""
    return dict(
        type="diagram",
        components=components,
        relationships=relationships,
        callouts=callouts,
    )

def generate_chart(chart_type: str, data: List[Dict], title: str = "") -> Dict:
    """Bundle a chart type, its data points and an optional title into a 'chart' spec."""
    return dict(type="chart", chart_type=chart_type, data=data, title=title)

def generate_table(headers: List[str], rows: List[List[str]], caption: str = "") -> Dict:
    """Bundle column headers, rows and an optional caption into a 'table' spec."""
    table_spec = {"type": "table"}
    table_spec["headers"] = headers
    table_spec["rows"] = rows
    table_spec["caption"] = caption
    return table_spec

def generate_flowchart(steps: List[Dict]) -> Dict:
    """Wrap the given steps in a 'flowchart' spec."""
    return dict(type="flowchart", steps=steps)

def no_visual_needed(reason: str) -> Dict:
    """Build the 'none' spec that records why no visual was produced."""
    return dict(type="none", reason=reason)


Function Schemas for LLM

In [None]:
# ============================================================================
# FUNCTION SCHEMAS FOR LLM
# ============================================================================

from google.generativeai.types import content_types

# Function declarations as dictionaries (intermediate representation)

# Schema for the `generate_diagram` tool. Its three properties match the
# parameters of the generate_diagram() function: each is an array of objects
# and all three are required. The "max 6" on callouts is only stated in the
# description text; nothing in the schema itself limits the array length.
generate_diagram_dict = {
    "name": "generate_diagram",
    "description": "Generate a diagram for explaining concepts, mechanisms, or structures. Use when the text describes components and their relationships.",
    "parameters": {
        "type": "object",
        "properties": {
            "components": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "name": {"type": "string", "description": "Component name"},
                        "description": {"type": "string", "description": "Component description"}
                    },
                    "required": ["name", "description"]
                },
                "description": "List of components/elements in the diagram"
            },
            "relationships": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "from": {"type": "string", "description": "Source component"},
                        "to": {"type": "string", "description": "Target component"},
                        "relationship_type": {"type": "string", "description": "Type of relationship"}
                    },
                    "required": ["from", "to", "relationship_type"]
                },
                "description": "How components connect or relate"
            },
            "callouts": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "label": {"type": "string", "description": "Callout label"},
                        "explanation": {"type": "string", "description": "Brief explanation"}
                    },
                    "required": ["label", "explanation"]
                },
                "description": "Key annotations (max 6)"
            }
        },
        "required": ["components", "relationships", "callouts"]
    }
}

# Schema for the `generate_chart` tool. `chart_type` and `data` are required,
# `title` is optional (generate_chart() defaults it to ""). The allowed chart
# types (bar, line, pie) are stated only in the description text, not as a
# schema-level constraint.
generate_chart_dict = {
    "name": "generate_chart",
    "description": "Generate a chart (bar/line/pie) for quantitative data. Use when the text contains specific numbers and comparisons.",
    "parameters": {
        "type": "object",
        "properties": {
            "chart_type": {
                "type": "string",
                "description": "Type of chart to generate. Must be one of: bar, line, pie"
            },
            "data": {
                "type": "array",
                "description": "Data points extracted from the text",
                "items": {
                    "type": "object",
                    "properties": {
                        "label": {"type": "string", "description": "Data label"},
                        "value": {"type": "number", "description": "Numeric value"}
                    },
                    "required": ["label", "value"]
                }
            },
            "title": {"type": "string", "description": "Chart title (optional)"}
        },
        "required": ["chart_type", "data"]
    }
}

# Schema for the `generate_table` tool. `headers` is an array of strings and
# `rows` is an array of string arrays; both are required. `caption` is
# optional (generate_table() defaults it to ""). The "3-10 rows" guidance is
# description text only.
generate_table_dict = {
    "name": "generate_table",
    "description": "Generate a table for structured qualitative comparisons or attributes. Use when text contains categorical comparisons.",
    "parameters": {
        "type": "object",
        "properties": {
            "headers": {
                "type": "array",
                "description": "Column headers",
                "items": {"type": "string"}
            },
            "rows": {
                "type": "array",
                "description": "Table rows (3-10 rows recommended)",
                "items": {
                    "type": "array",
                    "items": {"type": "string"}
                }
            },
            "caption": {"type": "string", "description": "Table caption (optional)"}
        },
        "required": ["headers", "rows"]
    }
}

# Schema for the `generate_flowchart` tool. The single required property
# `steps` is an array of step objects; each step must carry `id`, `text` and
# `step_type`, while `next` (IDs of following steps) is optional.
# NOTE: the per-step kind is stored under the key "step_type", not "type";
# any code that reads the generated steps has to use that key.
generate_flowchart_dict = {
    "name": "generate_flowchart",
    "description": "Generate a flowchart for processes, steps, or sequences. Use when text describes a procedure or workflow.",
    "parameters": {
        "type": "object",
        "properties": {
            "steps": {
                "type": "array",
                "description": "Ordered steps in the flowchart",
                "items": {
                    "type": "object",
                    "properties": {
                        "id": {"type": "string", "description": "Unique step identifier"},
                        "text": {"type": "string", "description": "Step description (max 12 words)"},
                        "step_type": { # Renamed 'type' to 'step_type' to avoid conflict with schema 'type'
                            "type": "string",
                            "description": "Type of step. Must be one of: step, decision"
                        },
                        "next": {
                            "type": "array",
                            "description": "IDs of next steps",
                            "items": {"type": "string"}
                        }
                    },
                    "required": ["id", "text", "step_type"]
                }
            }
        },
        "required": ["steps"]
    }
}

# Schema for the `no_visual_needed` tool: the opt-out choice, taking a single
# required `reason` string (passed straight to no_visual_needed()).
no_visual_dict = {
    "name": "no_visual_needed",
    "description": "Call this when a visual would add no value to the text. Use when text is purely narrative or abstract.",
    "parameters": {
        "type": "object",
        "properties": {
            "reason": {
                "type": "string",
                "description": "Brief explanation (≤30 words) why no visual is needed"
            }
        },
        "required": ["reason"]
    }
}

# Wrap every schema dictionary in a FunctionDeclaration, in declaration order
(
    generate_diagram_func,
    generate_chart_func,
    generate_table_func,
    generate_flowchart_func,
    no_visual_func,
) = [
    content_types.FunctionDeclaration(**schema)
    for schema in (
        generate_diagram_dict,
        generate_chart_dict,
        generate_table_dict,
        generate_flowchart_dict,
        no_visual_dict,
    )
]


# One Tool object exposes all five declarations to the model
visual_tool = content_types.Tool(function_declarations=[
    generate_diagram_func,
    generate_chart_func,
    generate_table_func,
    generate_flowchart_func,
    no_visual_func,
])

# TOOLS is the list handed to the model; it holds just that single tool
TOOLS = [visual_tool]

System Prompt

In [None]:
# ============================================================================
# SYSTEM PROMPT
# ============================================================================

# Instruction text for the model. VisualGenerationAgent.process() prepends it
# to every request as part of the prompt string (it is not passed as a
# separate system instruction). The five "visual type" bullets correspond to
# the five declared tools.
SYSTEM_PROMPT = """You are an expert visual content analyzer. Your job is to read text and decide which single visual would be most helpful for readers.

**Your task:**
1. Analyze the provided text excerpt carefully
2. Choose exactly ONE visual type that would maximize clarity
3. Extract all necessary information from the text to create that visual
4. Call the appropriate function with the extracted data

**Rules:**
- Base decisions strictly on the text provided - do NOT invent data
- If data is incomplete or ambiguous, choose 'no_visual_needed' instead
- Prefer the simplest visual that provides maximum clarity
- Extract data exactly as stated in the text
- Keep all text descriptions concise (≤12 words for steps, ≤30 words for explanations)

**Visual type guidelines:**
- **Diagram**: For concepts, mechanisms, structures with components and relationships
- **Chart**: For quantitative data with specific numbers (bar/line/pie)
- **Table**: For structured qualitative comparisons or categorical attributes
- **Flowchart**: For processes, procedures, sequences, or decision trees
- **None**: When a visual would add no meaningful value"""

Orchestration

In [None]:
# ============================================================================
# AGENT ORCHESTRATOR
# ============================================================================

class VisualGenerationAgent:
    """Pick and build one visual spec per text excerpt via Gemini function calling.

    The Gemini model is created with the module-level ``TOOLS``. ``process``
    sends one excerpt, reads the function call out of the reply and dispatches
    it to the matching local tool function; ``process_batch`` does the same
    for a list of excerpts and prints a progress log.
    """

    def __init__(self, api_key: str = None):
        """
        Initialize the agent with Gemini

        Args:
            api_key: Optional API key. If not provided, it is read from the
                Colab secret named 'GEMINI_API_KEY'.

        Raises:
            ValueError: If no key was passed and the Colab secret lookup fails.
        """
        if api_key is None:
            try:
                api_key = userdata.get('GEMINI_API_KEY')
            except Exception as e:
                # Keep the original failure attached so the real cause
                # (missing secret, access not granted, ...) stays visible.
                raise ValueError(
                    "API key not found. Either pass it as argument or store it in Colab secrets as 'GEMINI_API_KEY'"
                ) from e
            print("✓ API key loaded from Colab secrets")

        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(
            'gemini-2.5-pro',
            tools=TOOLS
        )

        # Maps the tool name chosen by the model to the local implementation.
        self.tools_map = {
            "generate_diagram": generate_diagram,
            "generate_chart": generate_chart,
            "generate_table": generate_table,
            "generate_flowchart": generate_flowchart,
            "no_visual_needed": no_visual_needed
        }

    def process(self, text_excerpt: str) -> Dict[str, Any]:
        """
        Main processing function - analyzes text and generates appropriate visual

        Args:
            text_excerpt: The text to analyze

        Returns:
            On success: ``{"success": True, "visual": <tool result>,
            "function_called": <tool name>, "raw_args": <converted args>}``.
            On any failure (including exceptions, which are caught here):
            ``{"success": False, "error": <message>}``.
        """
        try:
            # Step 1: Create the prompt with system instruction
            prompt = f"{SYSTEM_PROMPT}\n\nAnalyze this text and generate the most helpful visual:\n\n{text_excerpt}"

            # Step 2: Generate content with function calling
            response = self.model.generate_content(prompt)

            # Step 3: Check if function was called. An empty candidate list
            # is reported as "no response" instead of an opaque IndexError.
            candidates = response.candidates
            if not candidates or not candidates[0].content.parts:
                return {
                    "success": False,
                    "error": "No response from model"
                }

            # Use the first part that carries a (non-empty) function call
            function_call = None
            for part in candidates[0].content.parts:
                part_call = getattr(part, 'function_call', None)
                if part_call:
                    function_call = part_call
                    break

            if not function_call:
                return {
                    "success": False,
                    "error": "No tool was called by the LLM"
                }

            function_name = function_call.name

            # Step 4: Convert the call arguments to plain Python containers
            function_args = {
                key: self._proto_to_python(value)
                for key, value in function_call.args.items()
            }

            # Step 5: Execute the tool
            if function_name not in self.tools_map:
                return {
                    "success": False,
                    "error": f"Unknown function: {function_name}"
                }

            result = self.tools_map[function_name](**function_args)

            # Step 6: Return structured output
            return {
                "success": True,
                "visual": result,
                "function_called": function_name,
                "raw_args": function_args
            }

        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }

    def _proto_to_python(self, value):
        """Recursively convert a function-call argument into plain Python data.

        Scalars are returned unchanged, any mapping becomes a ``dict`` and any
        other sequence becomes a ``list``. Mappings and sequences are matched
        through the ``collections.abc`` ABCs rather than ``dict``/``list``
        because the list and map wrappers coming back from the SDK are not
        ``list``/``dict`` instances; treating them as unknown objects is what
        previously turned every list into a dict of its method names
        ("append", "clear", "count", ...).
        """
        from collections.abc import Mapping, Sequence

        if value is None or isinstance(value, (str, bytes, int, float, bool)):
            return value

        if isinstance(value, Mapping):
            return {k: self._proto_to_python(v) for k, v in value.items()}
        if isinstance(value, Sequence):
            return [self._proto_to_python(item) for item in value]

        # Raw protobuf Value (in case one arrives unwrapped): every oneof
        # member exists as an attribute, so hasattr() cannot tell which one
        # is set; WhichOneof('kind') can.
        which_oneof = getattr(value, 'WhichOneof', None)
        if callable(which_oneof):
            try:
                kind = which_oneof('kind')
            except ValueError:
                kind = ''  # message has no 'kind' oneof; try the checks below
            if kind is None or kind == 'null_value':
                return None
            if kind in ('string_value', 'number_value', 'bool_value'):
                return getattr(value, kind)
            if kind in ('struct_value', 'list_value'):
                return self._proto_to_python(getattr(value, kind))

        # Raw protobuf Struct / ListValue: unwrap their single container field
        fields = getattr(value, 'fields', None)
        if fields is not None and hasattr(fields, 'items'):
            return {k: self._proto_to_python(v) for k, v in fields.items()}
        values = getattr(value, 'values', None)
        if values is not None and not callable(values):
            return [self._proto_to_python(item) for item in values]

        # Unknown object: a readable string beats reflecting over dir()
        return str(value)

    @staticmethod
    def _head(items, limit):
        """Return at most *limit* leading entries of *items* as a list."""
        return list(items or [])[:limit]

    def _print_visual_details(self, visual):
        """Print a short, type-specific preview of one generated visual."""
        visual_type = visual['type']

        print(f"\n--- Visual Details ---")
        if visual_type == "chart":
            chart_data = visual.get('data', [])
            print(f"Chart Type: {visual.get('chart_type', 'N/A')}")
            print(f"Data Points: {len(chart_data)}")
            print("Data:")
            for point in self._head(chart_data, 5):
                if isinstance(point, dict):
                    print(f"  - {point.get('label')}: {point.get('value')}")
                else:
                    print(f"  - {point}")
            if len(chart_data) > 5:
                print(f"  ... and {len(chart_data) - 5} more")

        elif visual_type == "table":
            table_rows = visual.get('rows', [])
            print(f"Columns: {visual.get('headers', [])}")
            print(f"Rows: {len(table_rows)}")
            print("Preview (first 3 rows):")
            for row in self._head(table_rows, 3):
                print(f"  {row}")

        elif visual_type == "diagram":
            components = visual.get('components', [])
            relationships = visual.get('relationships', [])
            callouts = visual.get('callouts', [])

            print(f"Components: {len(components)}")
            for comp in self._head(components, 3):
                if isinstance(comp, dict):
                    print(f"  - {comp.get('name')}: {comp.get('description')}")
                else:
                    print(f"  - {comp}")

            print(f"Relationships: {len(relationships)}")
            for rel in self._head(relationships, 3):
                if isinstance(rel, dict):
                    # The schema names this key 'relationship_type';
                    # 'type' is accepted as a fallback only.
                    rel_kind = rel.get('relationship_type', rel.get('type'))
                    print(f"  - {rel.get('from')} → {rel.get('to')} ({rel_kind})")
                else:
                    print(f"  - {rel}")

            if callouts:
                print(f"Callouts: {len(callouts)}")
                for callout in self._head(callouts, 3):
                    if isinstance(callout, dict):
                        print(f"  - {callout.get('label')}: {callout.get('explanation')}")
                    else:
                        print(f"  - {callout}")

        elif visual_type == "flowchart":
            steps = visual.get('steps', [])
            print(f"Steps: {len(steps)}")
            for step in self._head(steps, 5):
                # Handle both dict and string formats
                if isinstance(step, dict):
                    step_id = step.get('id', 'N/A')
                    step_text = step.get('text', 'N/A')
                    # The schema names this key 'step_type', not 'type'
                    step_type = step.get('step_type', step.get('type', 'N/A'))
                    print(f"  {step_id}: {step_text} [{step_type}]")
                else:
                    print(f"  {step}")

        elif visual_type == "none":
            print(f"Reason: {visual.get('reason', 'N/A')}")

        print(f"--- End Visual Details ---\n")

    def process_batch(self, data_list: List[Any]) -> List[Dict[str, Any]]:
        """
        Process multiple texts from Google Sheets

        Args:
            data_list: List of text strings OR list of tuples (Prompt, FinalText)

        Returns:
            One result dict per input, in input order. Each carries "index";
            processed rows also carry "original_text" (first 200 characters)
            plus everything returned by ``process``. Empty rows are reported
            as ``{"index": i, "success": False, "error": "Empty text"}``
            without calling the model.
        """
        results = []
        total = len(data_list)

        for idx, item in enumerate(data_list):
            # Handle both single strings and tuples
            if isinstance(item, tuple):
                text_to_process = item[1] if item[1] else item[0]  # Prefer second element (FinalText)
            else:
                text_to_process = item

            # A truthy non-string cell must not abort the whole batch on .strip()
            if text_to_process and not isinstance(text_to_process, str):
                text_to_process = str(text_to_process)

            if not text_to_process or not text_to_process.strip():
                results.append({
                    "index": idx,
                    "success": False,
                    "error": "Empty text"
                })
                continue

            print(f"\n{'='*60}")
            print(f"Processing row {idx + 1}/{total}")
            print(f"{'='*60}")
            print(f"Text preview: {text_to_process[:100]}...")

            result = self.process(text_to_process)
            result["index"] = idx
            result["original_text"] = text_to_process[:200]  # Store preview

            if result["success"]:
                print(f"✓ Generated: {result['visual']['type']}")
                self._print_visual_details(result['visual'])
            else:
                print(f"✗ Error: {result.get('error', 'Unknown error')}")

            results.append(result)

        return results

Main Function

In [None]:
# ============================================================================
# MAIN EXECUTION
# ============================================================================

def main():
    """Load the sheet, run the agent over every row and print a summary.

    Returns:
        The list of per-row results from ``process_batch``, or ``None`` when
        no data could be loaded.
    """
    divider = "=" * 60

    # Step 1: Load data from Google Sheets
    print("Loading data from Google Sheets...")

    # Default: FinalText column only. Other modes of load_data_from_sheets:
    #   column='Prompt' -> Prompt column, column='both' -> (Prompt, FinalText) tuples
    data = load_data_from_sheets('SampleDataCorpusKey', column='FinalText')

    if not data:
        print("No data loaded. Exiting.")
        return

    print(f"\nPreview of first 3 rows:")
    for number, entry in enumerate(data[:3], start=1):
        if isinstance(entry, tuple):
            print(f"{number}. Prompt: {entry[0][:50]}...")
            print(f"   FinalText: {entry[1][:50]}...\n")
        else:
            print(f"{number}. Text: {entry[:100]}...\n")

    # Step 2: Initialize agent; with no argument the key comes from the
    # Colab secret 'GEMINI_API_KEY'. Pass api_key="..." to supply it directly.
    agent = VisualGenerationAgent()

    # Step 3: Process all rows
    print("\n" + divider)
    print("Starting batch processing...")
    print(divider)

    results = agent.process_batch(data)

    # Step 4: Summary
    print("\n" + divider)
    print("PROCESSING SUMMARY")
    print(divider)

    successful = len([outcome for outcome in results if outcome["success"]])
    print(f"Total processed: {len(results)}")
    print(f"Successful: {successful}")
    print(f"Failed: {len(results) - successful}")

    # Tally successful rows per visual type, in order of first appearance
    type_counts = {}
    for outcome in results:
        if not outcome["success"]:
            continue
        kind = outcome["visual"]["type"]
        type_counts[kind] = type_counts.get(kind, 0) + 1

    print("\nVisual type breakdown:")
    for kind in type_counts:
        print(f"  {kind}: {type_counts[kind]}")

    # Step 5: Show sample results
    print("\n" + divider)
    print("SAMPLE RESULTS (first 2 successful)")
    print(divider)

    samples = [outcome for outcome in results if outcome["success"]][:2]
    for outcome in samples:
        print(f"\nRow {outcome['index'] + 1}:")
        print(json.dumps(outcome["visual"], indent=2))

    return results

# Run the pipeline only when this file is executed directly. The value kept
# in `results` is whatever main() returns: the per-row result list, or None
# when no data was loaded.
if __name__ == "__main__":
    results = main()

Loading data from Google Sheets...
Successfully extracted 17 rows from 'FinalText' column

Preview of first 3 rows:
1. Text: Water on Earth moves continuously through the processes of evaporation, condensation, precipitation,...

2. Text: Plants capture energy from sunlight through chlorophyll in their leaves, absorbing carbon dioxide (C...

3. Text: A simple electric circuit consists of a battery, wires, and a bulb. Electricity flows from the batte...

✓ API key loaded from Colab secrets

Starting batch processing...

Processing row 1/17
Text preview: Water on Earth moves continuously through the processes of evaporation, condensation, precipitation,...
✓ Generated: diagram

--- Visual Details ---
Components: 11
  - append
  - clear
  - count
Relationships: 11
  - append
  - clear
  - count
Callouts: 11
  - append
  - clear
  - count
--- End Visual Details ---


Processing row 2/17
Text preview: Plants capture energy from sunlight through chlorophyll in their leaves, absorbing carbon 