In [32]:
%%capture
!pip install langchain transformers torch torchvision pillow
!pip install google-generativeai langchain-google-genai
!pip install tiktoken verovio

In [41]:
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import AutoModel, AutoTokenizer

from PIL import Image
from typing import Type
from pydantic import BaseModel

# Define the input schema for the tool
class ImageCaptioningInput(BaseModel):
    image_path: str  # Path to the image file

# Define the custom tool
class ImageCaptioningTool():
    name: str = "image_captioning"
    description: str = "Generate a caption for an image using a pre-trained BLIP model."
    args_schema: Type[BaseModel] = ImageCaptioningInput

    def __init__(self):
        super().__init__()
        # Set up device (CUDA if available, else CPU)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Load BLIP model and processor
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(self.device)

    def run(self, image_path: str) -> str:
        """
        Generate a caption for an image given its file path.

        Args:
            image_path (str): Path to the image file.

        Returns:
            str: Generated caption for the image.
        """
        try:
            # Load and process the image
            image = Image.open(image_path).convert("RGB")
            inputs = self.processor(image, return_tensors="pt").to(self.device)
            outputs = self.model.generate(**inputs)
            caption = self.processor.decode(outputs[0], skip_special_tokens=True)
            return caption
        except Exception as e:
            return f"Error processing the image: {str(e)}"

    async def _arun(self, *args, **kwargs):
        raise NotImplementedError("Async version not implemented for this tool.")


# Define the input schema for the tool
class OCRInput(BaseModel):
    image_path: str  # Path to the image file

# Define the custom OCR tool
class OCRTool():
    name: str = "ocr"
    description: str = "Extract text from an image using the GOT-OCR2_0 model."
    args_schema: Type[BaseModel] = OCRInput

    def __init__(self):
        super().__init__()
        # Set up device (CUDA if available, else CPU)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Load GOT-OCR2_0 model and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
        self.model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=self.tokenizer.eos_token_id)
        self.model = self.model.eval().to(self.device)

    def run(self, image_path: str) -> str:
        """
        Extract text from an image using the GOT-OCR2_0 model.

        Args:
            image_path (str): Path to the image file.

        Returns:
            str: Extracted text from the image.
        """
        try:
            # Input the image file to the model
            res = self.model.chat(self.tokenizer, image_path, ocr_type='ocr')  # We use 'ocr' for plain text extraction
            return res

        except Exception as e:
            return f"Error processing the image: {str(e)}"

    async def _arun(self, *args, **kwargs):
        raise NotImplementedError("Async version not implemented for this tool.")

In [45]:
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
import google.generativeai as genai
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.agents import Tool
from langchain.agents import initialize_agent

# Define the input schema for the chain (both tools need to be used together)
class KeywordGenerationInput(BaseModel):
    image_path: str  # Path to the image file

# Define the prompt for keyword generation
prompt_template = """
You are a powerful AI assistant skilled in generating keywords for search engines.
Given an image caption and OCR-extracted text, your task is to generate a list of highly relevant keywords.
Make sure the keywords are related to the content in the image and cover different aspects that might be searched for.
Here's the image caption and the extracted text:

Image Caption: {caption}
OCR Extracted Text: {ocr_text}

Please generate a list of keywords (comma-separated).
"""

# Create the prompt using the template
prompt = PromptTemplate(input_variables=["caption", "ocr_text"], template=prompt_template)

# Create the LLMChain for keyword generation 
GOOGLE_API_KEY = ""
llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=GOOGLE_API_KEY, max_output_tokens=100)
chain = LLMChain(prompt=prompt, llm=llm)

# Define a combined tool that runs the captioning and OCR tools and passes their results to Google Gemini
class KeywordGenerationTool():
    name: str = "keyword_generation"
    description: str = "Generate relevant keywords from an image using both image captioning and OCR text extraction."
    args_schema: Type[BaseModel] = KeywordGenerationInput

    def __init__(self, captioning_tool, ocr_tool, chain):
        super().__init__()
        self.captioning_tool = captioning_tool
        self.ocr_tool = ocr_tool
        self.chain = chain

    def run(self, image_path: str) -> str:
        """
        Run the image captioning tool, the OCR tool, and generate keywords via the LLM.
        Args:
            image_path (str): Path to the image file.
        Returns:
            str: Generated list of keywords.
        """
        try:
            # Step 1: Generate the caption using the image captioning tool
            caption = self.captioning_tool.run(image_path)
            # print(f"Caption: {caption}")

            # Step 2: Extract text using the OCR tool
            ocr_text = self.ocr_tool.run(image_path)
            # print(f"OCR Text: {ocr_text}")

            # Step 3: Use Google Gemini (via LLMChain) to generate keywords based on both caption and OCR text
            keywords = self.chain.run(caption=caption, ocr_text=ocr_text)
            return [keyword.strip() for keyword in keywords.split(',')]
        except Exception as e:
            return f"Error generating keywords: {str(e)}"

    async def _arun(self, *args, **kwargs):
        raise NotImplementedError("Async version not implemented for this tool.")

In [43]:
# Create instances of the captioning tool and OCR tool
captioning_tool = ImageCaptioningTool()
ocr_tool = OCRTool()

# Create an instance of the keyword generation tool with the chain
keyword_tool = KeywordGenerationTool(captioning_tool, ocr_tool, chain)

Using device: cuda
Using device: cuda


In [44]:
# Path to your image file
image_path = "/content/img.png"

# Generate keywords using the tool
keywords = keyword_tool.run(image_path)
print(f"Generated Keywords: {keywords}")

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


Caption: a diagram of a system with a number of codes
OCR Text: uses 1..* ServiceOperator 1..* uses <<System>> uses 0.* <<CO>> ManagementApplication UserAgent 1..* 1..* 1..* controlslifecycleof 1..1 1..1 1..1 1..1 uses <<CO>> 1..1 <<CO>> controlslifecycle ServiceTemplateHandler SubscriberManager 1..1 1..1 1..1 controls controls controls 0..* 0..* 0..* <<IO>> <<IO>> <<IO>> ServiceTemplate Subscriber UserGroup <<IO>> SubscriptionContract uses uses 0..* 0..* 0..* 1..1 controls <<CO>> notifies <<CO>> SubscriptionRegistrar SubscriptionAgent 1..1 0..* 0..*
Generated Keywords: system diagram, UML diagram, class diagram, software architecture, service operator, management application, user agent, service template handler, subscriber manager, subscription registrar, subscription agent, service template, subscriber, user group, subscription contract,  CO (control object), IO (input/output), controls lifecycle, uses relationship, cardinality, 1..*, 0..*, 1..1, software design, system design,  mod