In [38]:
%pip install google-generativeai langchain langchain_google_genai

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip available: 22.2.2 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [39]:
import os
import json
from dotenv import load_dotenv
import google.generativeai as genai
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputKeyToolsParser

In [40]:
# Load API key from .env file
load_dotenv()
GEMINI_API = os.getenv("GEMINI_API_KEY")

In [41]:
# Initialize the Gemini model in LangChain
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash-exp", temperature=0.7, api_key=GEMINI_API)

In [98]:
# Example problem statement
problem_statement = """
    Metro station wants to establish a TicketDistributor machine that issues tickets for
passengers travelling on metro rails. Travellers have options of selecting a ticket for a single
trip, round trips or multiple trips. They can also issue a metro pass for regular passengers or
a time card for a day, a week or a month according to their requirements. The discounts on
tickets will be provided to frequent travelling passengers. The machine is also supposed to
read the metro pass and time cards issued by the metro counters or machine. The ticket rates
differ based on whether the traveller is a child or an adult. The machine is also required to
recognize original as well as fake currency notes. The typical transaction consists of a user
using the display interface to select the type and quantity of tickets and then choosing a
payment method of either cash, credit/debit card or smartcard. The tickets are printed and
dispensed to the user. Also, the messaging facilities after every transaction are required on
the registered number. The system can also be operated comfortably by a touch-screen. A
large number of heavy components are to be used. We do not want our system to slow down,
and also the usability of the machine.
The TicketDistributor must be able to handle several exceptions, such as aborting the
transaction for incomplete transactions, the insufficient amount given by the travellers to the
machine, money return in case of an aborted transaction, change return after a successful
transaction, showing insufficient balance in the card, updated information printed on the
tickets e.g. departure time, date, time, price, valid from, valid till, validity duration, ticket
issued from and destination station. In case of exceptions, an error message is to be displayed.
We do not want user feedback after every development stage but after every two stages to
save time. The machine is required to work in a heavy load environment such that in the
morning and evening time on weekdays, and weekends performance and efficiency would
not be affected.
"""

### Class, Attributes and operations Identification

In [104]:
from langchain_core.prompts import PromptTemplate

# Define the prompt template
class_identification_prompt = PromptTemplate.from_template(
    """You are an expert software architect. Your task is to analyze the given problem statement and extract relevant 
    classes, attributes, and operations while adhering to the SOLID principles.

    Problem Statement:
    {problem_statement}

    Instructions:
    1. Identify potential classes based on key concepts in the problem statement.
    2. List attributes for each class that represent its properties.
    3. List operations (methods) for each class that define its behavior.
    4. Ensure that the identified elements follow the SOLID principles:
       - Single Responsibility Principle (SRP)
       - Open/Closed Principle (OCP)
       - Liskov Substitution Principle (LSP)
       - Interface Segregation Principle (ISP)
       - Dependency Inversion Principle (DIP)

    Provide the output in the following JSON format:
    {{
        "classes": [
            {{
                "name": "ClassName",
                "attributes": ["attribute1", "attribute2"],
                "operations": ["operation1()", "operation2()"]
            }}
        ]
    }}
    
    """
)

# Format the prompt with the problem statement
formatted_prompt = class_identification_prompt.format(problem_statement=problem_statement)

In [105]:
# Invoke the Gemini LLM via LangChain
response = llm.invoke(formatted_prompt)

# Extract JSON response from the model output
try:
    extracted_json = json.loads(response.content)
    print(json.dumps(extracted_json, indent=4))  # Pretty print the output
except json.JSONDecodeError:
    print("Failed to parse JSON response. Here is the raw output:\n", response.content)

Failed to parse JSON response. Here is the raw output:
 ```json
{
    "classes": [
        {
            "name": "TicketDistributor",
            "attributes": [
                "display: DisplayInterface",
                "paymentProcessor: PaymentProcessor",
                "ticketPrinter: TicketPrinter",
                "messageService: MessageService",
                "currencyValidator: CurrencyValidator",
                "ticketPricingService: TicketPricingService",
                "transactionLog: TransactionLog"
            ],
            "operations": [
                "selectTicketType(ticketType: TicketType)",
                "selectQuantity(quantity: int)",
                "selectPaymentMethod(paymentMethod: PaymentMethod)",
                "processTransaction()",
                "abortTransaction()",
                "handleException(exception: Exception)",
                "dispenseTicket()",
                "returnChange(amount: double)",
                "updateDisplay(mes

In [None]:

import json

def safe_load_json(response_content):
    """ Safely loads JSON from API response content. """
    try:
        # Ensure response is a valid JSON string
        if isinstance(response_content, bytes):
            response_content = response_content.decode("utf-8")

        return json.loads(response_content)
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON: {e}")
        print(f"Invalid JSON content: {response_content}")
        return None

import json
import re

def clean_and_parse_json(response_content):
    """ Cleans and parses JSON from Gemini API response content. """
    try:
        # Ensure response is a string
        if isinstance(response_content, bytes):
            response_content = response_content.decode("utf-8")

        # Remove triple backticks and "json" label if present
        response_content = re.sub(r"```json\s*", "", response_content)  # Remove opening triple backticks
        response_content = re.sub(r"```$", "", response_content)  # Remove closing triple backticks

        # Replace incorrect double curly braces `{{` with `{`
        response_content = response_content.replace("{{", "{").replace("}}", "}")

        # Parse JSON
        return json.loads(response_content)

    except json.JSONDecodeError as e:
        print(f"Error decoding JSON: {e}")
        print(f"Invalid JSON content: {response_content}")
        return None

# Example: Simulating API Response (Replace with actual API call)
response_content = '''```json
    {{
        "classes": [
            {{
                "name": "ClassName",
                "attributes": ["attribute1", "attribute2"],
                "operations": ["operation1()", "operation2()"]
            }}
        ]
    }}
    ```'''

# Clean and parse JSON
parsed_json = clean_and_parse_json(response_content)

if parsed_json:
    print("✅ Successfully parsed JSON:", parsed_json)
else:
    print("❌ Failed to parse JSON")


In [107]:
# Step 1: Generate classes, attributes, and operations
generated_classes_json = safe_load_json(response.content)
# type(response.content)
# response.content

Error decoding JSON: Expecting value: line 1 column 1 (char 0)
Invalid JSON content: ```json
{
    "classes": [
        {
            "name": "TicketDistributor",
            "attributes": [
                "display: DisplayInterface",
                "paymentProcessor: PaymentProcessor",
                "ticketPrinter: TicketPrinter",
                "messageService: MessageService",
                "currencyValidator: CurrencyValidator",
                "ticketPricingService: TicketPricingService",
                "transactionLog: TransactionLog"
            ],
            "operations": [
                "selectTicketType(ticketType: TicketType)",
                "selectQuantity(quantity: int)",
                "selectPaymentMethod(paymentMethod: PaymentMethod)",
                "processTransaction()",
                "abortTransaction()",
                "handleException(exception: Exception)",
                "dispenseTicket()",
                "returnChange(amount: double)",
     

### Relationship Identification

In [55]:
from langchain_core.prompts import PromptTemplate

# Define the relationship identification prompt template
relationship_identification_prompt = PromptTemplate.from_template(
    """You are an expert software architect. Your task is to analyze the given class structure and identify 
    relationships between the classes while ensuring adherence to the SOLID principles.

    **Goal:** Generate a complete class diagram, including all attributes, operations, and relationships.

    **Given Classes:**
    {generated_classes_json}

    **Instructions:**
    1. **Identify Relationships:** Determine how the given classes interact. Use appropriate relationships such as:
       - Association
       - Aggregation
       - Composition
       - Inheritance (Generalization)
       - Dependency
    2. **Follow SOLID Principles:** Ensure relationships follow:
       - **SRP**: Each class should have a single responsibility.
       - **OCP**: Use inheritance where necessary but ensure extensibility.
       - **LSP**: Subclasses should be substitutable for their base class.
       - **ISP**: Split large interfaces into smaller, specific ones.
       - **DIP**: Depend on abstractions, not concrete implementations.
    3. **Output Format:** Provide a structured JSON format with identified relationships:

    ```json
    {{
        "relationships": [
            {{
                "source": "ClassName1",
                "target": "ClassName2",
                "type": "RelationshipType",
                "description": "Brief explanation of why this relationship exists"
            }}
        ]
    }}
    ```

    """
)

# Format the prompt with previously generated classes
formatted_relationship_prompt = relationship_identification_prompt.format(generated_classes_json=generated_classes_json)

In [58]:
# Invoke the Gemini LLM via LangChain
response = llm.invoke(formatted_relationship_prompt)

# Extract JSON response from the model output
try:
    extracted_json = json.loads(response.content)
    print(json.dumps(extracted_json, indent=4))  # Pretty print the output
except json.JSONDecodeError:
    print("Failed to parse JSON response. Here is the raw output:\n", response.content)

Failed to parse JSON response. Here is the raw output:
 ```json
{
  "relationships": [
    {
      "source": "TicketDistributor",
      "target": "DisplayInterface",
      "type": "Composition",
      "description": "TicketDistributor owns and uses DisplayInterface to interact with the user."
    },
    {
      "source": "TicketDistributor",
      "target": "PaymentProcessor",
      "type": "Composition",
      "description": "TicketDistributor owns and uses PaymentProcessor to handle payment transactions."
    },
    {
      "source": "TicketDistributor",
      "target": "TicketPrinter",
      "type": "Composition",
      "description": "TicketDistributor owns and uses TicketPrinter to print tickets."
    },
    {
      "source": "TicketDistributor",
      "target": "MessagingService",
      "type": "Composition",
      "description": "TicketDistributor owns and uses MessagingService to send messages (e.g., for transaction confirmation)."
    },
    {
      "source": "TicketDistributo

In [69]:
identified_relationships_json = safe_load_json(response.content)

Error decoding JSON: Expecting value: line 1 column 1 (char 0)
Invalid JSON content: ```json
{
    "classes": [
        {
            "name": "TicketDistributor",
            "attributes": [
                "display: DisplayInterface",
                "paymentProcessor: PaymentProcessor",
                "ticketPrinter: TicketPrinter",
                "messageService: MessageService",
                "fareCalculator: FareCalculator",
                "cardReader: CardReader",
                "currencyValidator: CurrencyValidator",
                "transactionManager: TransactionManager"
            ],
            "operations": [
                "selectTicketType(ticketType: TicketType)",
                "selectQuantity(quantity: int)",
                "selectPaymentMethod(paymentMethod: PaymentMethod)",
                "processPayment(amount: double): boolean",
                "dispenseTicket()",
                "sendMessage(message: string)",
                "handleException(exception:

### Combine all information and generate plantUML script for generating Class Diagrams

In [61]:
def generate_plantuml_script(generated_classes, identified_relationships):
    """
    Generates a PlantUML script from the extracted class details and relationships.
    Ensures adherence to SOLID principles.
    
    :param generated_classes: JSON containing classes, attributes, and operations.
    :param identified_relationships: JSON containing relationships between classes.
    :return: A PlantUML script as a string.
    """
    plantuml_script = "@startuml\n"
    plantuml_script += "skinparam classAttributeIconSize 0\n\n"

    # Generate Class Definitions
    for cls in generated_classes["classes"]:
        plantuml_script += f"class {cls['name']} {{\n"
        
        # Add attributes
        for attr in cls["attributes"]:
            plantuml_script += f"  - {attr}\n"
        
        # Add operations
        for op in cls["operations"]:
            plantuml_script += f"  + {op}()\n"

        plantuml_script += "}\n\n"

    # Generate Relationships
    relationship_mappings = {
        "Association": "--",
        "Aggregation": "o--",
        "Composition": "*--",
        "Inheritance": "<|--",
        "Dependency": "..>"
    }

    for relation in identified_relationships["relationships"]:
        source = relation["source"]
        target = relation["target"]
        relation_type = relation["type"]
        relation_desc = relation.get("description", "")

        # Get appropriate PlantUML relation symbol
        plantuml_relation = relationship_mappings.get(relation_type, "--")
        
        # Add relationship line
        plantuml_script += f"{source} {plantuml_relation} {target} : {relation_desc}\n"

    plantuml_script += "\n@enduml"

    return plantuml_script

# Generate PlantUML script
plantuml_code = generate_plantuml_script(generated_classes_json, identified_relationships_json)

# Print the generated PlantUML code
print(plantuml_code)


TypeError: string indices must be integers