In [None]:
## Please refer to the README.md for any difficulties running this notebook.
import gc
import os
import time

import librosa
import numpy as np
import requests
import torch
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    BitsAndBytesConfig,
    GenerationConfig,
    Qwen2AudioForConditionalGeneration,
)

In [None]:
import io
import re
from io import BytesIO

from PIL import Image, ImageDraw

In [None]:
import panel as pn

# Load Panel's front-end extension for the "filedropper" widget so the
# pn.widgets.FileDropper created further down in this notebook can render.
pn.extension("filedropper")

In [None]:
# Shared application state:
#   "Loaded"    - whether a model/processor pair is ready for inference
#   "History"   - generated token ids appended by the chat callback
#   "Model"     - the currently loaded model (None when nothing is loaded)
#   "Processor" - the matching processor (None when nothing is loaded)
# (The original literal listed "Loaded" twice; the duplicate key was dropped.)
model_store = {
    "Loaded": False,
    "History": [],
    "Model": None,
    "Processor": None,
}

In [None]:
## Update with references to your downloaded model paths, or override them
## without editing the notebook by setting the MOLMO_MODEL_PATH,
## ARIA_MODEL_PATH and QWEN_AUDIO_MODEL_PATH environment variables.
model_path = os.environ.get(
    "MOLMO_MODEL_PATH", "/home/example/common/molmo_aria/Molmo-7B-D-0924"
)
aria_model_path = os.environ.get("ARIA_MODEL_PATH", "/shared/example/models/Aria")
qwen_audio_model_path = os.environ.get(
    "QWEN_AUDIO_MODEL_PATH", "/home/example/common/LLMs/Qwen2-Audio-7B-Instruct"
)

In [None]:
def _load_molmo_processor():
    """Return the Molmo-7B-D processor (shared by the bf16 and 4-bit variants)."""
    return AutoProcessor.from_pretrained(
        model_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )


def _load_molmo():
    """Return (processor, model) for Molmo-7B-D in bfloat16."""
    processor = _load_molmo_processor()
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    return processor, model


def _load_molmo_4bit():
    """Return (processor, model) for Molmo-7B-D quantized to 4-bit with bitsandbytes."""
    processor = _load_molmo_processor()
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="fp4",  # or nf4
        bnb_4bit_use_double_quant=False,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype="auto",
        trust_remote_code=True,
        quantization_config=quantization_config,
    )
    return processor, model


def _load_aria():
    """Return (processor, model) for Aria in bfloat16."""
    processor = AutoProcessor.from_pretrained(aria_model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        aria_model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    return processor, model


def _load_qwen2_audio():
    """Return (processor, model) for Qwen2-Audio-7B-Instruct."""
    processor = AutoProcessor.from_pretrained(qwen_audio_model_path)
    model = Qwen2AudioForConditionalGeneration.from_pretrained(
        qwen_audio_model_path, device_map="auto"
    )
    return processor, model


def load_model(event):
    """Load the model selected in ``toggle_group`` into ``model_store``.

    Click handler for the "Load Model" button; it is also called directly
    with ``event=None``. The event object itself is ignored.

    Any previously loaded model is released first via ``model_cleanup``.
    On success ``model_store`` holds the new processor/model, ``"Loaded"`` is
    True, ``"Name"`` records which selection was loaded, and the status pane
    says so. If loading raises, the store is left empty, the pane reports the
    failure, and the exception propagates to the caller.

    An unrecognized (or empty) selection loads nothing.
    """
    global model_store
    global model_info_pane
    if model_store["Model"] is not None:
        model_cleanup()

    selection = toggle_group.value
    loaders = {
        "Molmo-7B-D-0924": _load_molmo,
        "Molmo-7B-D-0924-4bit": _load_molmo_4bit,
        "Aria": _load_aria,
        "Qwen2-Audio": _load_qwen2_audio,
    }
    loader = loaders.get(selection)
    if loader is None:
        return

    model_info_pane.object = f"<p>Loading {selection}...</p>"
    try:
        processor, model = loader()
    except Exception:
        # Don't leave the UI claiming a load is still in progress, and make
        # sure the store does not advertise a half-loaded model.
        model_store["Processor"] = None
        model_store["Model"] = None
        model_store["Loaded"] = False
        model_store.pop("Name", None)
        model_info_pane.object = f"<p><b>Failed to load {selection}.</b></p>"
        raise

    model_store["Processor"] = processor
    model_store["Model"] = model
    model_store["Name"] = selection
    model_store["Loaded"] = True
    model_info_pane.object = f"<p>{selection} loaded.</p>"

In [None]:
def model_cleanup():
    """Release the loaded model/processor and return GPU memory to the driver.

    Resets ``model_store`` to its empty state and sets the status pane back
    to "No Model Loaded". It also empties the ``"History"`` list of generated
    token tensors, which may live on the model's device. Does nothing if
    neither a model nor a processor is loaded.
    """
    global model_store
    global model_info_pane
    if model_store["Model"] is None and model_store["Processor"] is None:
        return

    model_info_pane.object = "<p><b>No Model Loaded</b></p>"
    # Drop every reference first so gc can reclaim the weights, then ask the
    # CUDA caching allocator to release the now-unused blocks.
    model_store["Model"] = None
    model_store["Processor"] = None
    model_store["History"].clear()
    model_store.pop("Name", None)
    model_store["Loaded"] = False
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
def parse_points(points_str):
    """Extract Molmo-style pointing annotations from generated text.

    Recognizes two tag forms, in the order they appear in the text:
      * multi-point: ``<points x1="..." y1="..." x2="..." y2="..." ...>label</points>``
      * single-point: ``<point x="..." y="..." ...>label</point>``

    Args:
        points_str: raw text produced by the model.

    Returns:
        A list of ``{"label": str, "coordinates": [(x, y), ...]}`` dicts, one
        per tag, with coordinates as floats. Returns None when the text has
        no point tags. Coordinate pairs that are not valid numbers are
        skipped instead of raising, because model output is not guaranteed
        to be well formed.
    """
    # The backreference makes sure <point> closes with </point> and <points>
    # closes with </points>. DOTALL lets a label span lines.
    tags = re.findall(r"<(points?) ([^>]*)>(.*?)</\1>", points_str, flags=re.DOTALL)
    if not tags:
        return None

    numbered_pair = re.compile(r'x\d+="([^"]*)" y\d+="([^"]*)"')
    plain_pair = re.compile(r'x="([^"]*)" y="([^"]*)"')

    parsed_points = []
    for _tag_name, attributes, label in tags:
        # Prefer numbered x1/y1... pairs; otherwise take the first plain x/y.
        pairs = numbered_pair.findall(attributes) or plain_pair.findall(attributes)[:1]
        coordinates = []
        for x_text, y_text in pairs:
            try:
                coordinates.append((float(x_text), float(y_text)))
            except ValueError:
                continue
        parsed_points.append({"label": label, "coordinates": coordinates})
    return parsed_points

In [None]:
def overlay_points(points_data):
    """Draw point annotations onto the dropped image and display the result.

    Args:
        points_data: list of dicts shaped like ``parse_points`` output
            (``{"label": ..., "coordinates": [(x, y), ...]}``). Each
            coordinate is a percentage (0-100) of the image width / height.

    Does nothing when ``file_dropper`` is empty. Otherwise it re-opens the
    dropped file, draws a filled blue dot for every coordinate, and puts the
    annotated image into ``image_pane``.
    """
    global file_dropper
    if not file_dropper.value:
        return
    _, file_content = next(iter(file_dropper.value.items()))
    image = Image.open(io.BytesIO(file_content))

    draw = ImageDraw.Draw(image)
    width, height = image.size
    # Dot size scales with image height. Clamp to 1 px: images shorter than
    # 55 px would otherwise get a radius of 0.
    radius = max(1, int(height / 55))

    for point_data in points_data:
        for x_percent, y_percent in point_data["coordinates"]:
            x = x_percent / 100 * width
            y = y_percent / 100 * height
            draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill="blue")

    image_pane.object = image

In [None]:
def display_image(event):
    """Preview whatever file is currently held by ``file_dropper``.

    Registered as a watcher on ``file_dropper.value``; ``event`` is unused.

    * Image files are shown in ``image_pane``.
    * Audio files are decoded with ``librosa.load`` at librosa's default
      sample rate and passed to ``audio_pane`` as 16-bit PCM.
    * If the dropper is empty, or the file is neither image nor audio, both
      previews are cleared.
    """
    media_kind = None
    if file_dropper.value:
        mime = next(iter(file_dropper.mime_type.values()), "")
        media_kind = mime.split("/")[0]
        _, file_content = next(iter(file_dropper.value.items()))

    if media_kind == "image":
        audio_pane.object = None
        audio_pane.visible = False
        image_preview_html.object = "<p>Scaled Image Preview:</p>"
        image_pane.object = Image.open(io.BytesIO(file_content))
    elif media_kind == "audio":
        image_pane.object = None
        image_preview_html.object = "<p>Audio Track:</p>"
        samples, sample_rate = librosa.load(io.BytesIO(file_content))
        audio_pane.sample_rate = sample_rate
        # Samples are nominally in [-1, 1], but resampling can overshoot
        # slightly. Clip so the int16 conversion cannot wrap around.
        clipped = np.clip(np.asarray(samples, dtype=np.float32), -1.0, 1.0)
        audio_pane.object = (clipped * 32767).astype(np.int16)
        audio_pane.visible = True
    else:
        image_preview_html.object = "<p></p>"
        image_pane.object = None
        audio_pane.object = None
        audio_pane.visible = False

In [None]:
def build_chat_history(chat_interface):
    """Flatten a chat widget's messages into plain role/content dicts.

    Each message's ``user`` becomes ``"role"`` and its ``object`` becomes
    ``"content"``. Order follows ``chat_interface.objects``.
    """
    history = []
    for message in chat_interface.objects:
        history.append({"role": message.user, "content": message.object})
    return history


def compile_prompt_gguf(
    history,
    user_name,
    assistant_name,
    system_prompt="You are an unbiased, helpful assistant.",
):
    """Convert chat history into structured chat-template messages (used for Aria).

    Args:
        history: list of ``{"role": ..., "content": ...}`` dicts.
        user_name: role value that marks user turns.
        assistant_name: role value that marks assistant turns.
        system_prompt: accepted for signature compatibility; currently unused.

    Returns:
        A list of ``{"role": "user"|"assistant", "content": [{"text", "type"}]}``
        messages. Entries with any other role are dropped. If the last
        message is from the user, an image placeholder (``{"text": None,
        "type": "image"}``) is appended to its content. Returns ``[]`` when no
        user/assistant entries exist; previously this raised IndexError.
    """
    messages = []
    for entry in history:
        if entry["role"] == user_name:
            role = "user"
        elif entry["role"] == assistant_name:
            role = "assistant"
        else:
            continue
        messages.append(
            {"role": role, "content": [{"text": entry["content"], "type": "text"}]}
        )

    if messages and messages[-1]["role"] == "user":
        messages[-1]["content"].append({"text": None, "type": "image"})
    return messages


def compile_prompt(
    history,
    user_name,
    assistant_name,
    system_prompt="You are an unbiased, helpful assistant.",
):
    """Render chat history as a single Molmo-style ``USER:/ASSISTANT:`` prompt.

    Args:
        history: list of ``{"role": ..., "content": ...}`` dicts.
        user_name: role value that marks user turns.
        assistant_name: role value that marks assistant turns.
        system_prompt: accepted for signature compatibility; currently unused.

    Returns:
        A prompt string in which each user turn becomes
        ``<|startoftext|>USER: {content}\\nASSISTANT:`` and each assistant
        turn is terminated by exactly one ``<|endoftext|>\\n``. Existing
        terminators on the content are reused rather than duplicated. Other
        roles are skipped.
    """
    end_of_text = "<|endoftext|>"
    parts = []
    for entry in history:
        role = entry["role"]
        if role == user_name:
            parts.append(f'<|startoftext|>USER: {entry["content"]}\nASSISTANT:')
        elif role == assistant_name:
            content = str(entry["content"])
            # endswith() avoids the old slice-length mismatch: the previous
            # check compared a 15-char slice against the 14-char
            # "<|endoftext|>\n", so it never matched and the marker was
            # appended a second time.
            if content.endswith(end_of_text + "\n"):
                parts.append(content)
            elif content.endswith(end_of_text):
                parts.append(f"{content}\n")
            else:
                parts.append(f"{content}{end_of_text}\n")
    return "".join(parts)

In [None]:
# Model picker (single selection), a button that loads the chosen model,
# and the status pane that the load/cleanup functions write into.
toggle_group = pn.widgets.ToggleGroup(
    behavior="radio",
    name="Model Select",
    options=["Molmo-7B-D-0924", "Molmo-7B-D-0924-4bit", "Aria", "Qwen2-Audio"],
)
load_button = pn.widgets.Button(button_type="primary", name="Load Model")
load_button.on_click(load_model)
model_info_pane = pn.pane.HTML("<p><b>No Model Loaded</b></p>")

In [None]:
# Single-file drop zone for images or audio, plus the panes that preview
# the dropped file; display_image refreshes them whenever the value changes.
file_dropper = pn.widgets.FileDropper(
    width=300,
    height=95,
    multiple=False,
    max_file_size="10MB",
    accepted_filetypes=["image/*", "audio/*"],
)

image_preview_html = pn.pane.HTML("<p></p>")
audio_pane = pn.pane.Audio(visible=False, max_width=550, sizing_mode="scale_width")
image_pane = pn.pane.Image(max_width=550, sizing_mode="scale_width")
file_dropper.param.watch(display_image, ["value"])

image_load = pn.Column(
    file_dropper,
    pn.Column(image_preview_html, audio_pane, image_pane),
)

In [None]:
# Sidebar: model picker, load button with status, then the file preview.
model_controls = pn.Row(load_button, model_info_pane)
left_bar = pn.Column(
    toggle_group,
    model_controls,
    image_load,
    height=800,
    width=600,
)

In [None]:
def _dropped_file(media_kind):
    """Return the dropped file's bytes if its MIME major type is ``media_kind``.

    ``media_kind`` is e.g. "image" or "audio". Returns None when the dropper
    is empty or holds a file of a different kind.
    """
    if not file_dropper.value:
        return None
    mime = next(iter(file_dropper.mime_type.values()), "")
    if mime.split("/")[0] != media_kind:
        return None
    _, file_content = next(iter(file_dropper.value.items()))
    return file_content


def _respond_molmo(instance):
    """Answer the chat with Molmo (bf16 or 4-bit) over the dropped image."""
    file_content = _dropped_file("image")
    if file_content is None:
        return "Please upload an image using the file dropper in order to talk over that image."
    image = Image.open(io.BytesIO(file_content))

    processor = model_store["Processor"]
    model = model_store["Model"]
    prompt_full = compile_prompt(build_chat_history(instance), "User", "Assistant")

    inputs = processor.process(images=[image], text=prompt_full)
    # NOTE(review): processor.process appears to return unbatched tensors;
    # unsqueeze(0) adds the batch dimension generate_from_batch expects.
    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}

    with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
        output = model.generate_from_batch(
            inputs,
            GenerationConfig(max_new_tokens=1250, stop_strings="<|endoftext|>"),
            tokenizer=processor.tokenizer,
        )

    # Keep only the newly generated tokens (drop the echoed prompt).
    generated_tokens = output[0, inputs["input_ids"].size(1) :]
    # Store a CPU copy so the running history does not pin device memory.
    model_store["History"].append(generated_tokens.cpu())
    generated_text = processor.tokenizer.decode(
        generated_tokens, skip_special_tokens=True
    )

    points_data = parse_points(generated_text)
    if points_data:
        overlay_points(points_data)
    time.sleep(0.1)
    return generated_text


def _respond_aria(instance):
    """Answer the chat with Aria over the dropped image."""
    file_content = _dropped_file("image")
    if file_content is None:
        return "Please upload an image using the file dropper in order to talk over that image."
    image = Image.open(io.BytesIO(file_content))

    processor = model_store["Processor"]
    model = model_store["Model"]
    messages = compile_prompt_gguf(build_chat_history(instance), "User", "Assistant")
    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=text, images=image, return_tensors="pt")
    inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # torch.autocast(device_type="cuda") replaces the deprecated
    # torch.cuda.amp.autocast with identical semantics.
    with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        output = model.generate(
            **inputs,
            max_new_tokens=500,
            stop_strings=["<|im_end|>"],
            tokenizer=processor.tokenizer,
            do_sample=True,
            temperature=0.7,
        )
        output_ids = output[0][inputs["input_ids"].shape[1] :]
        result = processor.decode(output_ids, skip_special_tokens=True)
    result = result.replace("<|im_end|>", "")
    time.sleep(0.1)
    return result


def _respond_qwen_audio(instance):
    """Answer the latest user message with Qwen2-Audio over the dropped audio."""
    audio_bytes = _dropped_file("audio")
    if audio_bytes is None:
        return "Please attach an audio sample of the appropriate file format"

    history = build_chat_history(instance)
    if not history or history[-1]["role"] != "User":
        return "Error handling input content - please restart application and try again."
    text_input = history[-1]["content"]

    processor = model_store["Processor"]
    model = model_store["Model"]
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": [
                # The URL is only a placeholder for the chat template; the
                # actual waveform is passed to the processor separately.
                {"type": "audio", "audio_url": "Filler.wav"},
                {"type": "text", "text": text_input},
            ],
        },
    ]
    text = processor.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=False
    )

    try:
        waveform, _ = librosa.load(
            io.BytesIO(audio_bytes),
            sr=processor.feature_extractor.sampling_rate,
        )
    except Exception:
        return "Error loading audio file, please change file dropper content to appropriate file format"

    inputs = processor(text=text, audios=[waveform], return_tensors="pt", padding=True)
    # Move every tensor (ids, masks, audio features), not just input_ids.
    inputs = inputs.to(model.device)

    # max_new_tokens bounds only the completion. max_length also counts the
    # prompt, so a long prompt could leave little or no room to generate.
    generate_ids = model.generate(**inputs, max_new_tokens=256)
    generate_ids = generate_ids[:, inputs["input_ids"].size(1) :]

    response = processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    time.sleep(0.1)
    return response


def callback_vlm(contents: str, user: str, instance: pn.chat.ChatInterface):
    """ChatInterface callback that routes the conversation to the selected model.

    The model is loaded on demand. If ``model_store`` names a loaded model
    that differs from the current toggle selection, it is reloaded. The call
    is then dispatched on ``toggle_group.value``.

    Returns the model's reply, a user-facing hint when no (or the wrong kind
    of) file is attached, or None for an unrecognized selection.
    ``contents`` and ``user`` are unused; the full history is read from
    ``instance``.
    """
    selection = toggle_group.value
    # A missing "Name" defaults to the current selection, so its absence
    # alone never forces a reload.
    loaded_name = model_store.get("Name", selection)
    if not model_store["Loaded"] or loaded_name != selection:
        instance.send(
            "Loading model; one moment please...",
            user="System",
            respond=False,
        )
        load_model(None)
        # Remove the transient "Loading model" notice from the chat log.
        instance.objects.pop()

    responders = {
        "Molmo-7B-D-0924": _respond_molmo,
        "Molmo-7B-D-0924-4bit": _respond_molmo,
        "Aria": _respond_aria,
        "Qwen2-Audio": _respond_qwen_audio,
    }
    responder = responders.get(selection)
    if responder is None:
        return None
    return responder(instance)

In [None]:
# Page header markup, read as UTF-8 explicitly so decoding does not depend
# on the platform's default locale encoding. Newlines are stripped before
# the markup is embedded in the HTML pane.
with open("main.html", "r", encoding="utf-8") as f:
    header_html = f.read().replace("\n", "")

In [None]:
# Full-width header rendered from main.html.
header_pane = pn.pane.HTML(
    header_html,
    sizing_mode="stretch_width",
    width_policy="max",
)

In [None]:
# Chat widget; every user message is answered by callback_vlm, and callback
# exceptions are shown in full in the chat.
chat_interface = pn.chat.ChatInterface(callback_exception="verbose", callback=callback_vlm)

In [None]:
# Header on top, sidebar beside the chat below. servable() registers the
# layout for `panel serve`; as the cell's last expression it also renders inline.
pn.Column(
    header_pane,
    pn.Row(left_bar, chat_interface),
).servable()