<a href="https://colab.research.google.com/github/softmurata/colab_notebooks/blob/main/apps/gradiochatapp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers accelerate

In [None]:
!pip install safetensors

In [None]:
#@title chat
# Large models cannot be run on Google Colab; the runtime crashes.

In [None]:
import requests
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Checkpoint used for the chat / visual question answering demo.
CHAT_CHECKPOINT = "Salesforce/blip2-flan-t5-xxl"

# The processor prepares image + text inputs and decodes generated ids.
processor = Blip2Processor.from_pretrained(CHAT_CHECKPOINT)
# device_map="auto" leaves weight placement to the library.
model = Blip2ForConditionalGeneration.from_pretrained(
    CHAT_CHECKPOINT,
    device_map="auto",
)

In [None]:
# Download the demo image and normalise it to RGB.
img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
response = requests.get(img_url, stream=True, timeout=30)
# Fail here on an HTTP error instead of later with an opaque PIL decoding error.
response.raise_for_status()
raw_image = Image.open(response.raw).convert('RGB')

question = "how many dogs are in the picture?"
# Send the inputs to the device the model actually lives on rather than a
# hard-coded "cuda", so the cell does not break on a runtime without a GPU.
inputs = processor(raw_image, question, return_tensors="pt").to(model.device)

# Generate the answer tokens and print the decoded text.
out = model.generate(**inputs)
print(processor.decode(out[0], skip_special_tokens=True))

In [4]:
#@title image captioning
from PIL import Image
import requests
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch

In [None]:
# Use the GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Checkpoint used for the image captioning demo.
CAPTION_CHECKPOINT = "Salesforce/blip2-opt-2.7b"

processor = Blip2Processor.from_pretrained(CAPTION_CHECKPOINT)
# Weights are loaded in half precision.
model = Blip2ForConditionalGeneration.from_pretrained(
    CAPTION_CHECKPOINT,
    torch_dtype=torch.float16,
)
model.to(device)

In [None]:
# Download the sample image to caption.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
response = requests.get(url, stream=True, timeout=30)
# Fail here on an HTTP error instead of later with an opaque PIL decoding error.
response.raise_for_status()
image = Image.open(response.raw)

# Cast the inputs to the model's own dtype so the two cannot drift apart if the
# loading cell is changed.
inputs = processor(images=image, return_tensors="pt").to(device, model.dtype)

# No text prompt is given, so the model produces a free-form caption.
generated_ids = model.generate(**inputs)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)

In [None]:
!pip install gradio

In [3]:
from io import BytesIO

import string
import gradio as gr
import requests

In [4]:
# Static HTML/Markdown snippets passed to gr.Markdown at the top of the demo page.

# Page heading.
title = """<h1 align="center">chatZumen</h1>"""
# Short usage text and disclaimer.
description = """Gradio demo for chatZumen, image-to-text generation from soccer R&D. To use it, simply upload your image, or click one of the examples to load them.
<br> <strong>Disclaimer</strong>: This is a research prototype and is not intended for production use. No data including but not restricted to text and images is collected."""
# Links to the BLIP-2 paper, code, docs and project page.
article = """<strong>Paper</strong>: <a href='https://arxiv.org/abs/2301.12597' target='_blank'>BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models</a>
<br> <strong>Code</strong>: BLIP2 is now integrated into GitHub repo: <a href='https://github.com/salesforce/LAVIS' target='_blank'>LAVIS: a One-stop Library for Language and Vision</a>
<br> <strong>🤗 `transformers` integration</strong>: You can now use `transformers` to use our BLIP-2 models! Check out the <a href='https://huggingface.co/docs/transformers/main/en/model_doc/blip-2' target='_blank'> official docs </a>
<p> <strong>Project Page</strong>: <a href='https://github.com/salesforce/LAVIS/tree/main/projects/blip2' target='_blank'> BLIP2 on LAVIS</a>
<br> <strong>Description</strong>: Captioning results from <strong>BLIP2_OPT_6.7B</strong>. Chat results from <strong>BLIP2_FlanT5xxl</strong>.
"""

In [8]:
def postprocess_output(output):
    """Make sure the first generated string ends with a punctuation mark.

    Args:
        output: List of generated strings. Only the first entry is inspected,
            and the list is modified in place.

    Returns:
        The same list object. When its first entry is non-empty and does not
        end with a character from ``string.punctuation``, a full stop is
        appended to that entry. An empty list, or a list whose first entry is
        empty, is returned unchanged.
    """
    # Nothing to punctuate: indexing [0] on an empty list or [-1] on an empty
    # string would raise IndexError.
    if not output or not output[0]:
        return output

    if output[0][-1] not in string.punctuation:
        output[0] += "."

    return output

def inference_chat(
    image,
    text_input,
    decoding_method,
    temperature,
    length_penalty,
    repetition_penalty,
    history=None,
):
    """Record a user message plus a placeholder reply and return the UI update.

    Args:
        image: Image from the UI; not used by the placeholder implementation.
        text_input: The new user message; appended to ``history``.
        decoding_method: Not used by the placeholder implementation.
        temperature: Not used by the placeholder implementation.
        length_penalty: Not used by the placeholder implementation.
        repetition_penalty: Not used by the placeholder implementation.
        history: Flat list of messages (user, reply, user, reply, ...). When a
            list is passed it is extended in place. ``None`` (the default)
            starts a new, empty conversation.

    Returns:
        A dict keyed by the module-level ``chatbot`` and ``state`` components:
        ``chatbot`` maps to a list of ``(user, reply)`` tuples and ``state``
        maps to the updated flat history.
    """
    # A ``history=[]`` default would be created once and shared by every call
    # that omits the argument, leaking messages between conversations.
    if history is None:
        history = []

    history.append(text_input)

    # Whole conversation flattened into one prompt; kept for the model call
    # sketched in the TODO below.
    prompt = " ".join(history)

    # TODO: GPT_index
    # output = query_chat_api(
    #    image, prompt, decoding_method, temperature, length_penalty, repetition_penalty
    # )
    # output = postprocess_output(output)
    output = ["hello world"]
    history.extend(output)

    # Pair consecutive entries as (user message, reply).
    chat = [
        (history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)
    ]

    return {chatbot: chat, state: history}

def inference_caption(
    image,
    decoding_method,
    temperature,
    length_penalty,
    repetition_penalty,
):
    """Return a caption for ``image`` (placeholder implementation).

    Every argument is currently ignored and a fixed caption string is returned.
    """
    # TODO: yolov8 detection, then replace the placeholder with something like:
    #   output = query_caption_api(
    #       image, decoding_method, temperature, length_penalty, repetition_penalty
    #   )
    #   return output[0]
    placeholder_caption = "great cat"
    return placeholder_caption

In [9]:
# Build the demo page: generation controls on the left, caption and chat on the right.
# NOTE(review): the `#component-21` CSS selector presumably depends on the order
# in which components are created below; keep that order stable or verify it.
with gr.Blocks(
    css="""
    .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
    #component-21 > div.wrap.svelte-w6rprc {height: 600px;}
    """
) as iface:
    # Flat chat history, passed to and returned from `inference_chat`.
    state = gr.State([])

    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(article)
    with gr.Row():
        # Left column: image upload and text generation settings.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil")

            sampling = gr.Radio(
                choices=["Beam search", "Nucleus sampling"],
                value="Beam search",
                label="Text Decoding Method",
                interactive=True,
            )

            temperature = gr.Slider(
                minimum=0.5,
                maximum=1.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Temperature (used with nucleus sampling)",
            )

            len_penalty = gr.Slider(
                minimum=-1.0,
                maximum=2.0,
                value=1.0,
                step=0.2,
                interactive=True,
                label="Length Penalty (set to larger for longer sequence, used with beam search)",
            )

            rep_penalty = gr.Slider(
                minimum=1.0,
                maximum=5.0,
                value=1.5,
                step=0.5,
                interactive=True,
                label="Repeat Penalty (larger value prevents repetition)",
            )

        # Right column: caption output on top, chat underneath.
        with gr.Column(scale=1.8):

            with gr.Column():
                caption_output = gr.Textbox(lines=1, label="Caption Output")
                caption_button = gr.Button(
                    value="Caption it!", interactive=True, variant="primary"
                )
                caption_button.click(
                    inference_caption,
                    [
                        image_input,
                        sampling,
                        temperature,
                        len_penalty,
                        rep_penalty,
                    ],
                    [caption_output],
                )

            gr.Markdown("""Trying prompting your input for chat; e.g. example prompt for QA, \"Question: {} Answer:\" Use proper punctuation (e.g., question mark).""")
            with gr.Row():
                with gr.Column(scale=1.5):
                    chatbot = gr.Chatbot(
                        label="Chat Output (from FlanT5)",
                    )

                with gr.Column(scale=1):
                    chat_input = gr.Textbox(lines=1, label="Chat Input")

                    # Pressing Enter in the textbox and clicking "Submit" send
                    # the same message, so both events share one wiring.
                    chat_inputs = [
                        image_input,
                        chat_input,
                        sampling,
                        temperature,
                        len_penalty,
                        rep_penalty,
                        state,
                    ]
                    chat_outputs = [chatbot, state]

                    chat_input.submit(inference_chat, chat_inputs, chat_outputs)

                    with gr.Row():
                        clear_button = gr.Button(value="Clear", interactive=True)
                        # Empty the textbox, the visible chat and the stored history.
                        clear_button.click(
                            lambda: ("", [], []),
                            [],
                            [chat_input, chatbot, state],
                            queue=False,
                        )

                        submit_button = gr.Button(
                            value="Submit", interactive=True, variant="primary"
                        )
                        submit_button.click(inference_chat, chat_inputs, chat_outputs)

            # A new image starts a new conversation: reset the chat (an empty
            # list, matching what the Clear button sends), the caption and the
            # stored history.
            image_input.change(
                lambda: ([], "", []),
                [],
                [chatbot, caption_output, state],
                queue=False,
            )

# Queue incoming requests (at most 10 waiting, queue API not exposed), then start the app.
# NOTE(review): `concurrency_count=1` and `launch(enable_queue=True)` were dropped.
# A single worker is believed to be the queue default, and calling `queue()`
# already enables queuing. Later Gradio releases reject both keywords; confirm
# against the installed version.
iface.queue(api_open=False, max_size=10)
iface.launch()

Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://1121abc5a93b34a137.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces


