In [None]:
!pip install -Uqqq pip --progress-bar off
!pip install -qqq torch==2.1 --progress-bar off
!pip install -qqq transformers==4.34.1 --progress-bar off
!pip install -qqq accelerate==0.23.0 --progress-bar off
!pip install -qqq bitsandbytes==0.41.1 --progress-bar off
!pip install -qqq llava-torch==1.1.1 --progress-bar off

In [None]:
import requests
from PIL import Image
from io import BytesIO
import torch
from llava.model.builder import load_pretrained_model
from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import SeparatorStyle, conv_templates
from llava.mm_utils import (
    KeywordsStoppingCriteria,
    get_model_name_from_path,
    process_images,
    tokenizer_image_token,
)

# Load the 4-bit quantized LLaVA v1.5 checkpoint together with its tokenizer
# and image processor.
MODEL = "4bit/llava-v1.5-13b-3GB"
model_name = get_model_name_from_path(MODEL)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=MODEL,
    model_base=None,
    model_name=model_name,
    load_4bit=True,
)

# Set the conversation mode.
# LLaVA v1.5 checkpoints are meant to be prompted with the "llava_v1" template
# (this is what LLaVA's own CLI selects for model names containing "v1");
# "llava_v0" produces the older "###Human/###Assistant" prompt format.
CONV_MODE = "llava_v1"

def load_image(image_file, timeout=30):
    """Load an image from a URL or a local path and return it in RGB mode.

    Args:
        image_file: An ``http(s)://`` URL or a filesystem path.
        timeout: Seconds to wait for the HTTP server before giving up
            (only used for URLs).

    Returns:
        A ``PIL.Image.Image`` converted to RGB.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not answer within ``timeout``.
    """
    if image_file.startswith(("http://", "https://")):
        # A timeout prevents the notebook from hanging forever on a dead host.
        response = requests.get(image_file, timeout=timeout)
        # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")

    # convert() returns an independent image, so the file handle can be
    # released as soon as the conversion is done.
    with Image.open(image_file) as image:
        return image.convert("RGB")

def process_image(image):
    """Preprocess a single PIL image into a float16 tensor on the model device.

    Args:
        image: The PIL image to preprocess.

    Returns:
        The tensor produced by ``process_images`` for the one-image batch,
        moved to ``model.device`` and cast to ``torch.float16``.
    """
    from types import SimpleNamespace

    # llava's process_images reads the setting via
    # getattr(model_cfg, "image_aspect_ratio", None); a plain dict has no such
    # attribute, so the "pad" option would be silently ignored. An
    # attribute-style object makes the padding actually take effect.
    model_cfg = SimpleNamespace(image_aspect_ratio="pad")
    image_tensor = process_images([image], image_processor, model_cfg)
    return image_tensor.to(model.device, dtype=torch.float16)

def create_prompt(prompt: str):
    """Wrap a user question in the configured conversation template.

    The image placeholder token is placed on its own line ahead of the
    question, the result is added as the first role's message, and an empty
    slot is added for the second role so the model continues from there.

    Args:
        prompt: The user's question about the image.

    Returns:
        A ``(prompt_text, conversation)`` tuple: the rendered prompt string and
        the conversation object it was rendered from.
    """
    conversation = conv_templates[CONV_MODE].copy()
    user_role, assistant_role = conversation.roles[0], conversation.roles[1]

    user_message = "\n".join((DEFAULT_IMAGE_TOKEN, prompt))
    conversation.append_message(user_role, user_message)
    # None marks the reply slot that generation is expected to fill.
    conversation.append_message(assistant_role, None)

    return conversation.get_prompt(), conversation

def ask_image(image: Image.Image, prompt: str):
    """Ask the loaded LLaVA model a question about an image.

    Args:
        image: The PIL image to ask about.
        prompt: The user's question or instruction.

    Returns:
        The decoded reply as a string, with surrounding whitespace and any
        trailing stop string removed.
    """
    image_tensor = process_image(image)
    prompt, conv = create_prompt(prompt)
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )

    # Two-separator templates end a turn with sep2; all others use sep.
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria(
        keywords=[stop_str], tokenizer=tokenizer, input_ids=input_ids
    )

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=0.01,
            max_new_tokens=512,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )

    # Decode only the newly generated tokens (everything after the prompt).
    answer = tokenizer.decode(
        output_ids[0, input_ids.shape[1]:], skip_special_tokens=True
    ).strip()

    # Generation halts *after* the stop keyword has been produced, so the
    # keyword itself (e.g. "###") can trail the reply; drop it.
    if stop_str and answer.endswith(stop_str):
        answer = answer[: -len(stop_str)].strip()
    return answer


def main():
    """Interactively ask the model about one image.

    Prompts for an image location and a question on stdin, loads the image,
    queries the model and prints the reply.
    """
    source = input("Enter the image URL or path: ")
    user_question = input("Enter your question about the image: ")

    # Load the image, then hand it to the model together with the question.
    answer = ask_image(load_image(source), user_question)
    print(answer)


if __name__ == "__main__":
    main()

[2023-11-11 11:09:03,662] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


Downloading (…)okenizer_config.json:   0%|          | 0.00/749 [00:00<?, ?B/s]

Downloading tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/438 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.18k [00:00<?, ?B/s]

Downloading (…)model.bin.index.json:   0%|          | 0.00/33.7k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/9 [00:00<?, ?it/s]

Downloading (…)l-00001-of-00009.bin:   0%|          | 0.00/2.97G [00:00<?, ?B/s]

Downloading (…)l-00002-of-00009.bin:   0%|          | 0.00/2.93G [00:00<?, ?B/s]

Downloading (…)l-00003-of-00009.bin:   0%|          | 0.00/2.89G [00:00<?, ?B/s]

Downloading (…)l-00004-of-00009.bin:   0%|          | 0.00/2.96G [00:00<?, ?B/s]

Downloading (…)l-00005-of-00009.bin:   0%|          | 0.00/2.89G [00:00<?, ?B/s]

Downloading (…)l-00006-of-00009.bin:   0%|          | 0.00/2.98G [00:00<?, ?B/s]

Downloading (…)l-00007-of-00009.bin:   0%|          | 0.00/2.87G [00:00<?, ?B/s]

Downloading (…)l-00008-of-00009.bin:   0%|          | 0.00/2.89G [00:00<?, ?B/s]

Downloading (…)l-00009-of-00009.bin:   0%|          | 0.00/2.72G [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/4.76k [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/9 [00:00<?, ?it/s]

Downloading (…)neration_config.json:   0%|          | 0.00/154 [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.71G [00:00<?, ?B/s]

Enter the image URL or path: /content/unnamed.png
Enter your question about the image: Explain the image in detail and create a conecting story
The image is a table displaying a list of significant floods that occurred throughout the 20th century. The table is organized by date and location, providing a comprehensive overview of the floods and their impacts. The table shows that these floods have caused substantial damage, with a total of $1.5 billion in damages reported.

The table also includes a column for the amount of rainfall in each flood, which can be an important factor in determining the severity of the flood. The table is a valuable resource for understanding the history of flooding events and their consequences, as well as for informing future flood management strategies.


In [None]:
import locale
# Replace locale.getpreferredencoding with a function that always reports
# "UTF-8", regardless of the actual system locale.
# NOTE(review): presumably a Colab workaround so that the following `!pip`
# shell command does not fail on a non-UTF-8 locale — confirm it is still needed.
locale.getpreferredencoding = lambda: "UTF-8"

In [None]:
!pip install -q gradio

[0m

In [None]:
import gradio as gr

def interface_function(image, question):
    """Gradio callback: answer ``question`` about ``image``.

    Args:
        image: The uploaded image, or ``None`` if nothing was uploaded.
        question: The text the user typed.

    Returns:
        The model's explanation, or a short hint when no image was provided.
    """
    # Without an upload there is nothing to process; return a readable hint
    # instead of failing deep inside the image preprocessing.
    if image is None:
        return "Please upload an image before asking a question."

    # Process the image and generate the explanation
    return ask_image(image, question)

# Build the web UI: an image upload plus a question box in, plain text out.
# Components are created via gr.Image / gr.Textbox; the older gr.inputs.*
# namespace is deprecated and no longer exists in Gradio 4.
iface = gr.Interface(
    fn=interface_function,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(lines=2, placeholder="Enter your question about the image"),
    ],
    outputs="text",
    title="Image Explanation Generator",
    description="Upload an image and ask a question about it to generate an explanation.",
)

# debug=True keeps the cell running so errors and logs stay visible;
# share=True requests a temporary public URL.
iface.launch(debug=True, share=True)


  super().__init__(
  super().__init__(
  super().__init__(


Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Running on public URL: https://253f52d07178ce4651.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://253f52d07178ce4651.gradio.live


