In [None]:
import torch
import gradio as gr 
import re 
from PIL import Image
from transformers import AutoTokenizer, ViTImageProcessor, VisionEncoderDecoderModel

In [None]:
# Run inference on the CPU.
device = 'cpu'

# One Hugging Face checkpoint is used for the image processor, the tokenizer,
# and the combined encoder-decoder model; the three names are kept separate so
# each component can be pointed at a different checkpoint later.
model_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
encoder_checkpoint = model_checkpoint
decoder_checkpoint = model_checkpoint

# Image processor: turns a PIL image into the tensors the encoder consumes.
feature_extractor = ViTImageProcessor.from_pretrained(encoder_checkpoint)

# Tokenizer: maps generated token ids back to text.
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)

# Captioning model, loaded from the checkpoint and then moved onto `device`.
model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint)
model = model.to(device)

In [None]:
def predict(image, max_length=64, num_beams=4):
    """
    Generates a textual caption for the given image using a pre-trained image captioning model.

    Parameters:
    image (PIL.Image or None): The input image. None is accepted and yields an empty caption.
    max_length (int): The maximum length of the generated caption.
    num_beams (int): The number of beams to use in beam search decoding.

    Returns:
    str: The generated textual caption (first line only, special tokens and
         surrounding whitespace removed), or "" when no image is supplied.
    """
    # No image to caption (e.g. the form was submitted without an upload):
    # return an empty caption instead of failing on `None.convert(...)`.
    if image is None:
        return ""

    # Convert the input image to RGB and turn it into model-ready pixel tensors on the target device.
    rgb_image = image.convert('RGB')
    pixel_values = feature_extractor(rgb_image, return_tensors="pt").pixel_values.to(device)

    # Generate the caption token ids with beam search; keep the first (only) sequence of the batch.
    caption_ids = model.generate(pixel_values, max_length=max_length, num_beams=num_beams)[0]

    # Decode with skip_special_tokens=True so tokenizer markers do not leak into the
    # caption (the previous `replace('', '')` cleanup was a no-op and removed nothing).
    decoded = tokenizer.decode(caption_ids, skip_special_tokens=True)

    # Keep only the first line and trim the whitespace left around the text.
    return decoded.split('\n')[0].strip()


# def set_example_image(example: list) -> dict:
#     return gr.Image.update(value=example[0])

In [None]:
# Define the user interface using the `gr.Interface` function
interface = gr.Interface(
    fn=predict,                                      # The function to run when the user inputs an image
    inputs=gr.Image(type="pil"),                     # The input widget; hands the upload to `predict` as a PIL image
    # The first positional argument of `gr.Textbox` is its initial *value*, so the
    # caption heading must be passed as `label`; otherwise the output box starts
    # pre-filled with the text "Caption Text".
    outputs=gr.Textbox(label="Caption Text")         # The output widget, which displays the generated caption
)

# Launch the user interface
interface.launch()
