This is a GUI demo built with [gradio](https://gradio.app/), identical to our [HuggingFace Demo 🤗](https://huggingface.co/spaces/xinyu1205/Recognize_Anything-Tag2Text).

By running through this notebook, you can deploy the demo on your own machine.

In [None]:
%pip install -r requirements.txt
%pip install gradio

In [None]:
import torch
import torchvision.transforms as transforms
from models.tag2text import ram, tag2text_caption

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# init image transforms: resize to a fixed image_size x image_size square,
# convert to a tensor, then normalize each of the three channels.
# NOTE(review): mean/std look like the usual ImageNet statistics — confirm.
image_size = 384
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# load RAM Model: built from the checkpoint path below (relative to the
# working directory), switched to eval mode and moved to `device`.
model_ram = ram(
    pretrained='pretrained/ram_swin_large_14m.pth',
    image_size=image_size,
    vit='swin_l'
).eval().to(device)

# load Tag2Text Model: same pattern as above with its own checkpoint.
# NOTE(review): `threshold` is presumably the tag confidence cut-off —
# confirm against models.tag2text.
model_tag2text = tag2text_caption(
    pretrained='pretrained/tag2text_swin_14m.pth',
    image_size=image_size,
    vit='swin_b',
    threshold=0.68
).eval().to(device)

In [None]:
def inference_with_ram(img):
    """Tag an image with the Recognize Anything Model.

    Args:
        img: Image accepted by the module-level ``transform`` (the GUI
            passes a PIL image), or ``None`` when no image was provided.

    Returns:
        A ``(tags, tags_chinese)`` pair holding the first entry of each
        value returned by ``model_ram.generate_tag``, or ``("", "")`` when
        ``img`` is ``None``.
    """
    # The GUI passes None when "Run" is clicked on an empty image panel;
    # return blank outputs instead of crashing inside the transform.
    if img is None:
        return "", ""

    with torch.no_grad():
        # unsqueeze(0) adds the leading batch dimension (batch of one image)
        batch = transform(img).unsqueeze(0).to(device)
        tags, tags_chinese = model_ram.generate_tag(batch)
    return tags[0], tags_chinese[0]


def inference_with_t2t(img, input_tags):
    """Caption an image with the Tag2Text model, optionally guided by tags.

    Args:
        img: Image accepted by the module-level ``transform`` (the GUI
            passes a PIL image), or ``None`` when no image was provided.
        input_tags: Comma-separated user tags. ``None``, an empty or
            whitespace-only string, the word ``none`` (any case), or a
            string containing only separators all mean "no user tags".

    Returns:
        A ``(output_tags, caption)`` pair. ``output_tags`` is always the
        model's own tag prediction (obtained with ``tag_input=None``);
        ``caption`` is generated with the user tags when any were given.
        Returns ``("", "")`` when ``img`` is ``None``.
    """
    # The GUI passes None when "Run" is clicked on an empty image panel;
    # return blank outputs instead of crashing inside the transform.
    if img is None:
        return "", ""

    # Normalize the user tags: trim whitespace and drop empty entries so
    # stray commas or spaces never reach the model as empty tags.
    raw_tags = (input_tags or "").strip()
    if raw_tags.lower() == 'none':
        raw_tags = ""
    user_tags = [tag.strip() for tag in raw_tags.split(',')]
    user_tags = [tag for tag in user_tags if tag]
    # The model receives one ' | '-separated string per image in the batch.
    tag_input = [' | '.join(user_tags)] if user_tags else None

    with torch.no_grad():
        # unsqueeze(0) adds the leading batch dimension (batch of one image)
        batch = transform(img).unsqueeze(0).to(device)
        caption, tag_predict = model_tag2text.generate(
            batch,
            tag_input=tag_input,
            max_length=50,
            return_tag_predict=True
        )
        if tag_input is None:
            output_tags = tag_predict
        else:
            # re-inference with tag_input=None to get model output tags
            _, output_tags = model_tag2text.generate(
                batch,
                tag_input=None,
                max_length=50,
                return_tag_predict=True
            )

    return output_tags[0], caption[0]

In [None]:
import gradio as gr


def build_gui():
    """Build the Gradio interface for the RAM and Tag2Text demo.

    The interface has one tab per model. Each tab has an image input, "Run"
    and "Clear" buttons, output text boxes and a set of example images. The
    "Run" buttons call ``inference_with_ram`` / ``inference_with_t2t``;
    either "Clear" button resets the inputs and outputs of both tabs.

    Returns:
        The assembled ``gr.Blocks`` object. It is not launched here.
    """

    description = """
        <center><strong><font size='10'>Recognize Anything Model</font></strong></center>
        <br>
        Welcome to the Recognize Anything Model (RAM) and Tag2Text Model demo! <br><br>
        <li>
            <b>Recognize Anything Model:</b> Upload your image to get the <b>English and Chinese outputs of the image tags</b>!
        </li>
        <li>
            <b>Tag2Text Model:</b> Upload your image to get the <b>tags</b> and <b>caption</b> of the image.
            Optional: You can also input specified tags to get the corresponding caption.
        </li>
    """  # noqa

    # Footer links. The Tag2Text link previously had a duplicated
    # "https://" scheme, which made it unreachable.
    article = """
        <p style='text-align: center'>
            RAM and Tag2Text is training on open-source datasets, and we are persisting in refining and iterating upon it.<br/>
            <a href='https://recognize-anything.github.io/' target='_blank'>Recognize Anything: A Strong Image Tagging Model</a>
            |
            <a href='https://tag2text.github.io/' target='_blank'>Tag2Text: Guiding Language-Image Model via Image Tagging</a>
            |
            <a href='https://github.com/xinyu1205/Tag2Text' target='_blank'>Github Repo</a>
        </p>
    """  # noqa

    with gr.Blocks(title="Recognize Anything Model") as demo:
        # components
        gr.HTML(description)

        with gr.Tab(label="Recognize Anything Model"):
            with gr.Row():
                with gr.Column():
                    ram_in_img = gr.Image(type="pil")
                    with gr.Row():
                        ram_btn_run = gr.Button(value="Run")
                        ram_btn_clear = gr.Button(value="Clear")
                with gr.Column():
                    ram_out_tag = gr.Textbox(label="Tags")
                    ram_out_biaoqian = gr.Textbox(label="标签")
            # change examples as you like
            gr.Examples(
                examples=[
                    ["images/1641173_2291260800.jpg"],
                    ["images/openset_example.jpg"]
                ],
                fn=inference_with_ram,
                inputs=[ram_in_img],
                outputs=[ram_out_tag, ram_out_biaoqian],
                cache_examples=False
            )

        with gr.Tab(label="Tag2Text Model"):
            with gr.Row():
                with gr.Column():
                    t2t_in_img = gr.Image(type="pil")
                    t2t_in_tag = gr.Textbox(
                        label="User Specified Tags (Optional, separated by comma)")
                    with gr.Row():
                        t2t_btn_run = gr.Button(value="Run")
                        t2t_btn_clear = gr.Button(value="Clear")
                with gr.Column():
                    t2t_out_tag = gr.Textbox(label="Tags")
                    t2t_out_cap = gr.Textbox(label="Caption")
            # change examples as you like
            gr.Examples(
                examples=[
                    ["images/1641173_2291260800.jpg", ""],
                    ["images/openset_example.jpg", ""]
                ],
                fn=inference_with_t2t,
                inputs=[t2t_in_img, t2t_in_tag],
                outputs=[t2t_out_tag, t2t_out_cap],
                cache_examples=False
            )

        gr.HTML(article)

        # events
        # run inference
        ram_btn_run.click(
            fn=inference_with_ram,
            inputs=[ram_in_img],
            outputs=[ram_out_tag, ram_out_biaoqian]
        )
        t2t_btn_run.click(
            fn=inference_with_t2t,
            inputs=[t2t_in_img, t2t_in_tag],
            outputs=[t2t_out_tag, t2t_out_cap]
        )

        # Optional image syncing between the two tabs (disabled).
        # Keep it commented out when deploying on huggingface due to
        # internet latency.
        # images of two image panels should keep the same
        # and clear old outputs when image changes
        # def sync_img(v):
        #     return [gr.update(value=v)] + [gr.update(value="")] * 4

        # ram_in_img.upload(fn=sync_img, inputs=[ram_in_img], outputs=[
        #     t2t_in_img, ram_out_tag, ram_out_biaoqian, t2t_out_tag, t2t_out_cap
        # ])
        # ram_in_img.clear(fn=sync_img, inputs=[ram_in_img], outputs=[
        #     t2t_in_img, ram_out_tag, ram_out_biaoqian, t2t_out_tag, t2t_out_cap
        # ])
        # t2t_in_img.clear(fn=sync_img, inputs=[t2t_in_img], outputs=[
        #     ram_in_img, ram_out_tag, ram_out_biaoqian, t2t_out_tag, t2t_out_cap
        # ])
        # t2t_in_img.upload(fn=sync_img, inputs=[t2t_in_img], outputs=[
        #     ram_in_img, ram_out_tag, ram_out_biaoqian, t2t_out_tag, t2t_out_cap
        # ])

        # clear all
        def clear_all():
            """Return updates that reset both images and all five text boxes."""
            return [gr.update(value=None)] * 2 + [gr.update(value="")] * 5

        # Order must match clear_all(): the two image panels first, then the
        # five text boxes.
        clear_targets = [
            ram_in_img, t2t_in_img,
            ram_out_tag, ram_out_biaoqian, t2t_in_tag, t2t_out_tag, t2t_out_cap
        ]
        # Either tab's "Clear" button resets both tabs.
        ram_btn_clear.click(fn=clear_all, inputs=[], outputs=clear_targets)
        t2t_btn_clear.click(fn=clear_all, inputs=[], outputs=clear_targets)

    return demo

# Build the interface once; it is started separately via demo.launch(...).
demo = build_gui()

In [None]:
# Launch the demo built by build_gui().
demo.launch(
    server_name="127.0.0.1",  # localhost. use "0.0.0.0" to open to LAN
    share=False  # use True to acquire a temporary public domain for sharing
)