In [1]:
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
import requests
from io import BytesIO

# Prefer the GPU; fall back to CPU so the notebook still starts (very slowly)
# on machines without CUDA instead of failing inside `.to("cuda")`.
device = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_ID = "THUDM/glm-4v-9b"

# trust_remote_code=True executes Python code downloaded from the Hub repo;
# review that code before running this cell in a sensitive environment.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True
).to(device).eval()

# Generation settings passed to model.generate():
#   max_length – cap on the TOTAL sequence length (prompt tokens + new tokens).
#   do_sample=True with top_k=1 – samples from only the single most likely
#     token, i.e. effectively greedy (deterministic) decoding.
gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]



In [2]:
import requests
from PIL import Image
from io import BytesIO
import torch
import textwrap
from IPython.display import display
import ipywidgets as widgets
import os

# Function to pretty print text with specified width
def pretty_print(text, width=70):
    """Append ``text``, word-wrapped to ``width`` columns, to the text output area's stdout.

    Args:
        text: The string to wrap.
        width: Maximum line length in characters.
    """
    lines = textwrap.wrap(text, width=width)
    text_output_area.append_stdout("\n".join(lines) + "\n")

# Function to load image from URL or path
def load_image(image_path_or_url, timeout=30):
    """Load an image from a local file path or an HTTP(S) URL as an RGB PIL image.

    Args:
        image_path_or_url: Path to an existing local file; anything else is
            treated as a URL and fetched with ``requests.get``.
        timeout: Seconds to wait for the HTTP request, so an unreachable host
            cannot hang the widget callback indefinitely.

    Returns:
        A PIL image converted to RGB mode.

    Raises:
        requests.HTTPError: if the server responds with an error status.
        requests.RequestException: on network errors, timeouts or malformed URLs.
        OSError (incl. PIL.UnidentifiedImageError): if the data is not a readable image.
    """
    if os.path.isfile(image_path_or_url):
        source = image_path_or_url
    else:
        response = requests.get(image_path_or_url, timeout=timeout)
        # Surface HTTP errors directly instead of handing an error page to PIL.
        response.raise_for_status()
        source = BytesIO(response.content)
    # convert() returns a fully loaded copy, so the opened file can be closed here.
    with Image.open(source) as image:
        return image.convert('RGB')

# Function to process the query and generate a response
def process_query(query, image_path_or_url=None):
    """Answer ``query`` with the loaded model, optionally grounded on an image, and render it.

    Clears both output areas, shows the image (or an error note) in
    ``image_output_area``, generates a reply with the module-level ``model``,
    ``tokenizer``, ``device`` and ``gen_kwargs``, and displays it in
    ``text_output_area``.

    Args:
        query: Prompt text for the user turn.
        image_path_or_url: Optional local path or URL of an image. If it cannot
            be loaded, the query is sent without an image.
    """
    image_output_area.clear_output()
    text_output_area.clear_output()

    image = None
    if image_path_or_url:
        try:
            image = load_image(image_path_or_url)
            with image_output_area:
                display(image)
        except Exception:
            # Best effort: a bad path/URL degrades to a text-only query.
            with image_output_area:
                display(widgets.HTML("<b>Image not found or invalid URL. Displaying query response only.</b>"))

    # Generate inside the output context: ipywidgets shows an exception raised in a
    # `with output:` block inside that widget, whereas an error escaping a button
    # callback would not appear in the notebook cell at all.
    with text_output_area:
        # Chat mode: the image (possibly None) travels on the user message.
        inputs = tokenizer.apply_chat_template(
            [{"role": "user", "image": image, "content": query}],
            add_generation_prompt=True, tokenize=True, return_tensors="pt",
            return_dict=True,
        ).to(device)

        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)
        # generate() returns prompt + continuation; keep only the new tokens.
        new_tokens = outputs[:, inputs['input_ids'].shape[1]:]
        output = tokenizer.decode(new_tokens[0])
        # Drop the trailing end-of-sequence marker that decode() leaves in the text.
        eos = tokenizer.eos_token
        if eos and output.endswith(eos):
            output = output[:-len(eos)]

        display(widgets.HTML(_format_response_html(output)))


def _format_response_html(text, width=70):
    """Return ``text`` wrapped to ``width`` columns, HTML-escaped, inside ``<pre>``.

    Lines are wrapped one at a time so the answer's own line breaks (lists,
    paragraphs) survive; ``textwrap.fill`` on the whole string would merge them.
    Escaping keeps ``<``, ``>`` and ``&`` in model output from being parsed as
    markup by ``widgets.HTML``.
    """
    import html  # local import: not among the notebook's top-level imports

    wrapped = "\n".join(textwrap.fill(line, width=width) for line in text.splitlines())
    return f"<pre>{html.escape(wrapped)}</pre>"

# Function to display image when URL is updated
def display_image(image_path_or_url):
    """Refresh the image preview for ``image_path_or_url``.

    Always clears ``image_output_area``. An empty/None value leaves it blank;
    otherwise the image is loaded and shown, or an HTML notice is shown if
    loading (or displaying) fails.
    """
    image_output_area.clear_output()
    if not image_path_or_url:
        return

    failure_notice = "<b>Image not found or invalid URL. Displaying query response only.</b>"
    try:
        preview = load_image(image_path_or_url)
        with image_output_area:
            display(preview)
    except Exception:
        with image_output_area:
            display(widgets.HTML(failure_notice))

# Create an input Textarea widget for the query
query_input = widgets.Textarea(
    value='What can you tell me about the human pose?',
    placeholder='Type your query here',
    description='Query:',
    disabled=False,
    layout=widgets.Layout(width='75%')
)

# Create an input Textarea widget for the image URL or path.
# continuous_update=False: the value (and therefore the preview observer, which
# downloads/opens the image) only updates when the user leaves the box, instead
# of firing a request for every partial URL while it is being typed.
image_input = widgets.Textarea(
    value='https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg',
    placeholder='Enter image URL or path',
    description='Image URL:',
    disabled=False,
    continuous_update=False,
    layout=widgets.Layout(width='50%')
)

# Create a FileUpload widget for image upload
upload_button = widgets.FileUpload(
    accept='image/*',  # Accept image files only
    multiple=False,  # Single file upload
    layout=widgets.Layout(width='25%')
)

# Create a button widget to submit the query
submit_button = widgets.Button(description="Submit", layout=widgets.Layout(width='75%'))

# Create separate output areas to display image and text results
image_output_area = widgets.Output(layout=widgets.Layout(display='flex', justify_content='center'))
text_output_area = widgets.Output(layout=widgets.Layout(display='flex', justify_content='center'))

# Define the button click event handler for query submission
def on_button_click(b):
    """Submit handler: forward the query text and the trimmed image path/URL to ``process_query``.

    Args:
        b: The clicked ``Button`` (unused).
    """
    process_query(query_input.value, image_input.value.strip())

# Define the handler for image URL input change
def on_image_url_change(change):
    """Observer for ``image_input.value``: refresh the preview for the new, trimmed path/URL.

    Args:
        change: traitlets change dict; ``change['new']`` is the new text.
    """
    display_image(change['new'].strip())

# Define the handler for file upload
def on_file_upload(change):
    """Save an uploaded image to the working directory and make it the active image.

    Accepts both ``FileUpload.value`` layouts: ipywidgets 7 (dict keyed by
    filename, each entry holding ``'content'``) and ipywidgets 8 (tuple of dicts
    with ``'name'`` and ``'content'``). The saved path is written into
    ``image_input`` so that Submit queries the uploaded file rather than the
    previously entered URL.

    Args:
        change: traitlets change dict; ``change['new']`` is the upload widget's value.
    """
    uploads = change['new']
    if not uploads:
        return  # widget was reset/cleared; nothing to save

    if isinstance(uploads, dict):  # ipywidgets 7.x layout
        file_name, uploaded_file = next(iter(uploads.items()))
    else:  # ipywidgets 8.x layout
        uploaded_file = uploads[0]
        file_name = uploaded_file['name']

    # Keep only the base name so a client-supplied name cannot write outside the CWD.
    image_path = os.path.basename(file_name)
    with open(image_path, 'wb') as f:
        f.write(uploaded_file['content'])

    if image_input.value == image_path:
        # Unchanged value fires no change event, so refresh the preview explicitly.
        display_image(image_path)
    else:
        # The observer bound to image_input.value redraws the preview.
        image_input.value = image_path

# Bind the button click event to the handler
# Wire each widget event to its handler.
submit_button.on_click(on_button_click)
image_input.observe(on_image_url_change, names='value')
upload_button.observe(on_file_upload, names='value')

# Layout: the image URL box and upload button side by side, stacked between the
# query box (top) and the submit button (bottom).
image_input_widgets = widgets.HBox(
    [image_input, upload_button],
    layout=widgets.Layout(align_items='center', width='75%'),
)
input_widgets = widgets.VBox(
    [query_input, image_input_widgets, submit_button],
    layout=widgets.Layout(align_items='center'),
)

# Render the controls followed by the image and text output areas,
# then pre-render the preview for the default image URL.
display(input_widgets, image_output_area, text_output_area)
display_image(image_input.value)

VBox(children=(Textarea(value='What can you tell me about the human pose?', description='Query:', layout=Layou…

Output(layout=Layout(display='flex', justify_content='center'))

Output(layout=Layout(display='flex', justify_content='center'))