In [None]:
!pip install gradio
!pip install Pillow
!pip install datasets
!pip install llm-lens
!pip install torch
!pip install torchvision
!pip install SpeechRecognition
!pip install moviepy
!pip install opencv-python
!pip install salesforce-lavis
!pip install numpy
!pip install ninja
!pip install sentencepiece
!pip install icecream
!pip install transformers==4.28.1
!pip install tqdm
!pip install decord==0.6.0
!pip install timm==0.6.7
!pip install oss2
!pip install markdown2
!pip install hjson
!pip install einops
!pip install wget
!pip install accelerate
!pip install flash-attn --no-build-isolation

#### Gradio LENS

In [None]:
"""
Python file that stores the refactored classes for LENS and BLIP_VQA to make the app.py much cleaner
Author: Aditya Ramanath Poonja
huggingFace : https://huggingface.co/pooadi
GitHub      : https://github.com/pooadi
"""

import re
import torch
from lavis.models import load_model_and_preprocess
from lens import Lens, LensProcessor
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


########################################################################################################################

class LENSInference:
    """
    Class for running LENS inference.

    The LENS vision modules (``lens.Lens``) turn an image + question into a text prompt and an image caption;
    the prompt is then answered by a frozen seq2seq LLM (``google/flan-t5-small`` by default).
    """

    # Regexes that remove the "<pad>" and "</s>" special tokens (and any surrounding whitespace) from the
    # decoded LLM output. Raw strings avoid the invalid "\s" escape-sequence warning.
    _PAD_PATTERN = re.compile(r'(\s*)<pad>(\s*)')
    _EOS_PATTERN = re.compile(r'(\s*)</s>(\s*)')

    def __init__(self, llmName="google/flan-t5-small"):
        """
        Initialisation function of the class

        :param llmName: Hugging Face model id of the frozen seq2seq LLM that answers the LENS prompt
        """
        self._lens = Lens()
        self._processor = LensProcessor()

        self._llmName = llmName
        # Tokenizer and LLM are loaded on first use and then cached, so the weights are not re-read
        # from disk for every single inference.
        self._tokenizer = None
        self._lLMModel = None

        # kept so code that used these attributes directly keeps working
        self._remWord1 = self._PAD_PATTERN
        self._remWord2 = self._EOS_PATTERN

    ####################################################################################################################

    @classmethod
    def _cleanOutput(cls, text):
        """
        Remove the <pad> and </s> special tokens (with surrounding whitespace) from decoded LLM text.

        :param text: decoded LLM output, e.g. "<pad> The person is happy</s>"
        :return: the cleaned text, e.g. "The person is happy"
        """
        text = cls._PAD_PATTERN.sub('', text)
        return cls._EOS_PATTERN.sub('', text)

    ####################################################################################################################

    def _loadLLM(self):
        """
        Load the tokenizer and frozen LLM on the first call and return the cached pair afterwards.

        :return: (tokenizer, lLMModel)
        """
        if self._tokenizer is None or self._lLMModel is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self._llmName, truncation_side='left', padding=True)
            self._lLMModel = AutoModelForSeq2SeqLM.from_pretrained(self._llmName)
        return self._tokenizer, self._lLMModel

    ####################################################################################################################

    def __call__(
            self,
            imageFrame,
            question
    ):
        """
        function to be called to run the inference
        :param imageFrame: List of image to be inferred
        :param question: List of questions to be passed into the network
        :return: [lLMOutput, ImageCaption] : Output of LLM (LENS) + image caption.
                 Only the answer/caption of the FIRST sample is returned.
        """

        # inferring the initial vision models like BLIP, CLIP
        samples = self._processor(imageFrame, question)
        outputInit = self._lens(samples)

        # feeding the output of the vision models to a frozen LLM
        tokenizer, lLMModel = self._loadLLM()
        inputIds = tokenizer(samples["prompts"], return_tensors="pt").input_ids
        outputLLM = lLMModel.generate(inputIds)

        # strip the <pad> prefix and </s> suffix produced by the decoder
        lLMOutput = self._cleanOutput(str(tokenizer.decode(outputLLM[0])))

        return [lLMOutput, outputInit["caption"][0]]


########################################################################################################################

class BLIPVQAInference:
    """
    Wrapper around the LAVIS ``blip_vqa`` model (``vqav2`` weights) that answers one question about one image.
    """

    def __init__(self):
        """
        Load the BLIP VQA model and its image/text processors, on the GPU when one is available, else the CPU.
        """
        deviceName = "cuda" if torch.cuda.is_available() else "cpu"
        self._device = torch.device(deviceName)

        # load_model_and_preprocess returns (model, visual processors, text processors)
        loaded = load_model_and_preprocess(name="blip_vqa", model_type="vqav2",
                                           is_eval=True, device=str(self._device))
        self._modelBLIPVQA, self._visProcessors, self._txtProcessors = loaded

    ################################################################################################################

    def __call__(
            self,
            imageFrame,
            question):
        """
        Answer ``question`` about ``imageFrame``.

        :param imageFrame: image to be inferred (a PIL image in this app), fed to the "eval" visual processor
        :param question: question string, fed to the "eval" text processor
        :return: the first answer returned by ``predict_answers``
        """
        visualEval = self._visProcessors["eval"]
        textEval = self._txtProcessors["eval"]

        # add a batch dimension to the processed image and move it to the model's device
        batch = {
            "image": visualEval(imageFrame).unsqueeze(0).to(self._device),
            "text_input": textEval(question),
        }

        answers = self._modelBLIPVQA.predict_answers(samples=batch, inference_method="generate")
        return answers[0]

########################################################################################################################

In [None]:
"""
This code is written and maintained by Aditya Ramanath Poonja
huggingFace : https://huggingface.co/pooadi
GitHub      : https://github.com/pooadi
"""

import torch
import gradio as gr
import decord
import numpy as np
from PIL import Image
#from . import dirPath


########################################################################################################################
class DemoGUIGradio:
    """
    Gradio demo that runs LENS (model 1) and BLIP VQA (model 2) on an uploaded image or video and shows
    both answers plus the LENS image caption.
    """

    # question used when the user leaves the question box empty
    _DEFAULT_QUESTION = "What is the sentiment expressed in the image?"

    ########################################## - INITIALIZE - ##########################################################

    def __init__(self):
        """
        Initialisation function: loads both inference models.
        """
        # Load the first model
        self._model1 = LENSInference()

        # Loading the second model
        self._model2 = BLIPVQAInference()

    ######################################- GRADIO BACKEND FUNCTIONS - #################################################

    @staticmethod
    def _toPositiveInt(value, default):
        """
        Convert a (Gradio textbox) value to a positive integer.

        :param value: raw value - a str from a gr.Textbox, or an int when called directly
        :param default: value returned when ``value`` is None or a blank string
        :return: the parsed integer (>= 1)
        :raises ValueError: if the value is not an integer or is smaller than 1
        """
        if value is None:
            return default
        if isinstance(value, str):
            value = value.strip()
            if not value:
                return default
        parsed = int(value)
        if parsed < 1:
            # 0 would divide by zero in _getFrameIds or make the video loop never advance
            raise ValueError(f"expected a positive integer, got {parsed}")
        return parsed

    ####################################################################################################################

    @staticmethod
    def _getFrameIds(startFrame,
                     endFrame,
                     numSegments=32,
                     jitter=True):
        """
        Split the frame window [startFrame, endFrame) into ``numSegments`` equal segments and pick one frame id
        per segment.

        :param startFrame: first frame index of the window
        :param endFrame: frame index one past the last frame of the window
        :param numSegments: number of segments, i.e. number of returned frame ids (int or numeric string)
        :param jitter: if True pick a random frame inside each segment, otherwise the segment's middle frame
        :return: list of ``numSegments`` frame ids
        """
        numSegments = int(numSegments)
        segSize = float(endFrame - startFrame - 1) / numSegments
        frameIds = []

        for i in range(numSegments):
            start = int(np.round(segSize * i) + startFrame)
            end = min(int(np.round(segSize * (i + 1)) + startFrame), endFrame)
            if jitter:
                frameId = np.random.randint(low=start, high=(end + 1))
            else:
                frameId = (start + end) // 2
            frameIds.append(frameId)

        return frameIds

    ####################################################################################################################

    def _imageInference(self,
                        imageFrame,
                        question):
        """
        Function to run when you click on the submit button of the image tab

        :param imageFrame: PIL image uploaded by the user
        :param question: question asked by the user (empty -> default question)
        :return: [LENS Answer, BLIP VQA Answer, Image Description]
        :raises gr.Error: if no image was uploaded
        """
        if imageFrame is None:
            raise gr.Error("Please upload an image before submitting.")

        if not question or not question.strip():
            question = self._DEFAULT_QUESTION

        with torch.no_grad():

            # inference of Model 2
            model2Output = self._model2(imageFrame, question)

            # inference of Model 1
            model1Output = self._model1([imageFrame], [question])

        return [str(model1Output[0]), str(model2Output), str(model1Output[1])]

    ######################################- GRADIO BACKEND FUNCTIONS - #################################################

    def _videoInference(self,
                        videoFile,
                        question,
                        numSegments=4,
                        strideSize=16):
        """
        Function to run when you click on the submit button of the video tab.

        The video is processed in consecutive windows of ``numSegments * strideSize`` frames; in every window
        ``numSegments`` evenly spaced frames are answered by both models.

        :param videoFile: Video uploaded by the user (path readable by decord)
        :param question: question asked by the user (empty -> default question)
        :param numSegments: frames sampled per window (int or textbox string; blank -> 4)
        :param strideSize: frames per segment (int or textbox string; blank -> 16)
        :return: [LENS Answer, BLIP VQA Answer, Image Description]
        :raises gr.Error: if no video was uploaded, it has no frames, or the numeric inputs are invalid
        """
        if videoFile is None:
            raise gr.Error("Please upload a video before submitting.")

        if not question or not question.strip():
            question = self._DEFAULT_QUESTION

        # Gradio textboxes deliver strings: convert once so every later use sees ints
        try:
            numSegments = self._toPositiveInt(numSegments, 4)
            strideSize = self._toPositiveInt(strideSize, 16)
        except ValueError as err:
            raise gr.Error(f"Number of Segments and Stride Size must be positive integers ({err}).") from err

        inputVideo = decord.VideoReader(videoFile)
        numFrames = len(inputVideo)
        if numFrames == 0:
            raise gr.Error("The uploaded video contains no frames.")

        frameSampleSize = numSegments * strideSize
        maxStartFrame = numFrames - frameSampleSize
        fps = inputVideo.get_avg_fps()

        model1Lines, model2Lines, descriptionLines = [], [], []
        currFrame = 0

        # always process the first window, even when the video is shorter than one window
        while currFrame == 0 or currFrame < maxStartFrame:

            stopFrame = min(currFrame + frameSampleSize, numFrames)
            currSec, stopSec = currFrame / fps, stopFrame / fps
            frameIds = self._getFrameIds(currFrame, stopFrame, numSegments=numSegments, jitter=False)
            frames = inputVideo.get_batch(frameIds).asnumpy()

            header = f"{'-' * 30} Predictions From: {currSec:2.3f}-{stopSec:2.3f} seconds {'-' * 30}\n"
            model1Lines.append(header)
            model2Lines.append(header)
            descriptionLines.append(header)

            with torch.no_grad():
                for frameId, frame in zip(frameIds, frames):

                    imageFrame = Image.fromarray(frame, "RGB")

                    # inference of model 2
                    model2Answer = self._model2(imageFrame, question)
                    model2Lines.append(f"Frame {frameId}: {model2Answer}\n")

                    # inference of Model 1 -> [answer, caption]
                    model1Answer, caption = self._model1([imageFrame], [question])
                    model1Lines.append(f"Frame {frameId}: {model1Answer}\n")
                    descriptionLines.append(f"Frame {frameId}: {caption}\n")

            currFrame += frameSampleSize

        return ["".join(model1Lines), "".join(model2Lines), "".join(descriptionLines)]

    ######################################- GRADIO FRONTEND FUNCTIONS - ################################################

    def __call__(self):
        """
        Create the Gradio interface (an "Image File" tab and a "Video File" tab) and launch it with a public
        share link in debug mode.
        """

        # Creating a block for the app
        with gr.Blocks(title="Demo for Sentiment Detection using Multimodal LLMs") as interface:

            # title of the demo
            # title = "Demo for Sentiment Detection using Multimodal LLMs"

            # description of the demo
            description = "Gradio initial demo for a proposed Sentiment/Emotion Detection using Multimodal LLM"

            # example Files

            """exampleInputs = [[f"{dirPath}/Examples/Crying1.jpeg", "What is the facial expression of the person in the "
                                                                  "image?"],
                             [f"{dirPath}/Examples/Crying2.jpeg", "Describe the emotion expressed in the image"],
                             [f"{dirPath}/Examples/Happy1.jpeg", "What is the person expressing in the image?"],
                             [f"{dirPath}/Examples/Happy2.jpeg", "What is the sentiment expressed in the image?"],
                             [f"{dirPath}/Examples/SadGirl.jpeg", "What is the person doing in the image?"]]

            exampleVideoInputs = [[f"{dirPath}/Examples/Example1.mp4",
                                   "What is the facial expression of the person in the "
                                                                  "image?"]]"""

            # Setting up Markdown for Title and Description
            # gr.Markdown(value=f"# <p style=\"text-align: center;\"> {title} </p>")
            gr.Markdown(value=f"#### {description}")

            with gr.Tab("Image File"):
                # Image Tab

                with gr.Row():

                    with gr.Column():

                        # The input components list
                        inputs = [gr.Image(type='pil', interactive=True),
                                  gr.Textbox(lines=2, label="Question",
                                             placeholder="Type your question here (Default question:"
                                                         " What is the sentiment expressed in the image?)...")]
                        with gr.Row():

                            # The clear and the submit button objects
                            clearButton = gr.ClearButton()
                            submitButton = gr.Button(value="Submit", variant="primary")

                    with gr.Column():

                        # The output components list
                        outputs = [gr.Textbox(label="Model 1 Answer"),
                                   gr.Textbox(label="Model 2 Answer"),
                                   gr.Textbox(label="Image Description")]

                # Adding components for the clear Button to clear when it is clicked
                clearButton.add(components=inputs + outputs)

                # Adding the details for the submit button click action
                submitButton.click(fn=self._imageInference, inputs=inputs, outputs=outputs)

                # setting up examples
                #examples = gr.Examples(examples=exampleInputs, inputs=inputs, outputs=outputs, fn=self._imageInference,
                #                       cache_examples=True)

            with gr.Tab("Video File"):

                with gr.Row():

                    with gr.Column():

                        # The input components list
                        inputs = [gr.Video(label="Video File"),
                                  gr.Textbox(lines=2, label="Question",
                                             placeholder="Type your question here (Default question:"
                                                         " What is the sentiment expressed in the image?)...")]
                        with gr.Row():

                            # the secondary inputs
                            inputs2 = [gr.Textbox(label="Number of Segments",
                                                  placeholder="Enter an integer value (Default: 4)"),
                                       gr.Textbox(label="Stride Size",
                                                  placeholder="Enter an integer value (Default: 16)")]
                        with gr.Row():

                            # The clear and the submit button objects
                            clearButton = gr.ClearButton()
                            submitButton = gr.Button(value="Submit", variant="primary")

                    with gr.Column():

                        # The output components list
                        outputs = [gr.Textbox(label="Model 1 Answer", max_lines=5),
                                   gr.Textbox(label="Model 2 Answer", max_lines=5),
                                   gr.Textbox(label="Description", max_lines=5)]

                # Adding components for the clear Button to clear when it is clicked (including the
                # segment/stride boxes)
                clearButton.add(components=inputs + inputs2 + outputs)

                # Adding the details for the submit button click action
                submitButton.click(fn=self._videoInference, inputs=inputs + inputs2, outputs=outputs)

                # setting up examples
                #examples = gr.Examples(examples=exampleVideoInputs, inputs=inputs, outputs=outputs,
                #                       fn=self._videoInference,
                #                       cache_examples=True)

        # Launch interface
        interface.launch(share=True, debug=True)

########################################################################################################################

# Constructing the demo loads both inference models (LENS and BLIP VQA).
gradioUI = DemoGUIGradio()

######################################- GRADIO UI DEPLOYMENT - #####################################################

# Calling the instance builds the Blocks interface and launches it.
gradioUI()

########################################################################################################################


#### Gradio Chatbot UI

In [None]:
# @title MPLUG-OWL (a lighter LLaVA for testing purposes)
"""
This code is written and maintained by Aditya Ramanath Poonja
huggingFace : https://huggingface.co/pooadi
GitHub      : https://github.com/pooadi
"""


# Colab setup: print the available GPU(s), mount Google Drive and change into the mPLUG-Owl checkout
# (presumably so that the `serve` package imported below resolves from this directory).
!nvidia-smi
from google.colab import drive

drive.mount('/content/drive')
%cd /content/drive/MyDrive/mPLUG-Owl

import os
import argparse
import datetime
import json
import os
import time
import torch
import gradio as gr
import requests

from serve.conversation import default_conversation
from serve.gradio_css import code_highlight_css
from serve.gradio_patch import Chatbot as grChatbot
from serve.serve_utils import (
    add_text, after_process_image, disable_btn, no_change_btn,
    downvote_last_response, enable_btn, flag_last_response,
    get_window_url_params, init, regenerate, upvote_last_response,
)
from serve.model_worker import mPLUG_Owl_Server
from serve.model_utils import post_process_code

########################################################################################################################
class DemoGUIGradio:
    """
    Gradio chatbot demo that streams answers from the module-level mPLUG-Owl server ``model`` for a text
    prompt plus an optional image or video.
    """

    @staticmethod
    def loadInterface(state, request: gr.Request):
        """
        Page-load handler: start every browser session with a fresh conversation.

        :param state: current ``gr.State`` value (ignored)
        :param request: Gradio request (unused)
        :return: a copy of ``default_conversation``
        """
        return default_conversation.copy()

    ####################################################################################################################

    @staticmethod
    def clearHistory(state=None, request: gr.Request = None):
        """
        Reset the conversation and clear the chatbot, textbox, image and video components.

        The defaults on ``state``/``request`` keep the handler callable however Gradio passes its inputs.

        :param state: current conversation (ignored; a new one is created)
        :param request: Gradio request (unused)
        :return: (state, chatbot history, textbox value, image, video)
        """
        state = default_conversation.copy()
        return (state, state.to_gradio_chatbot(), "", None, None)

    ####################################################################################################################

    @staticmethod
    def addTextHttpBot(state, text, image, video, max_output_tokens, temperature, top_k, top_p,
                       num_beams, no_repeat_ngram_size, length_penalty,
                       do_sample, request: gr.Request):
        """
        Append the user's message to the conversation and stream the model's answer into the chatbot.

        This is a generator: every ``yield`` pushes (state, chatbot history, textbox, image, video) to the UI.

        :param state: current conversation
        :param text: user prompt (may be empty when an image or video is supplied)
        :param image: optional image from the image box
        :param video: optional video from the video box
        :param max_output_tokens, temperature, top_k, top_p, num_beams, no_repeat_ngram_size, length_penalty,
               do_sample: generation settings from the "Parameters" accordion
        :param request: Gradio request (unused)
        :raises gr.Error: if both an image and a video are supplied
        """
        text = text or ""
        hasImage = image is not None
        hasVideo = video is not None

        if not text and not hasImage and not hasVideo:
            # Nothing to send. This function is a generator, so the UI update must be *yielded*;
            # a ``return value`` here would be silently discarded.
            state.skip_next = True
            yield (state, state.to_gradio_chatbot(), "", None, None)
            return

        if hasImage and hasVideo:
            # the message below carries a single media item
            raise gr.Error("Please provide either an image or a video, not both.")

        if hasImage:
            if '<image>' not in text:
                text = text + '\n<image>'
            message = (text, image)
        elif hasVideo:
            # one "<image>" placeholder per video frame; presumably this must match the number of frames the
            # server samples from a video - TODO confirm against serve.model_worker
            numFrames = 4
            if '<image>' not in text:
                text = text + '\n<image>' * numFrames
            message = (text, video)
        else:
            message = text

        state.append_message(state.roles[0], message)
        state.append_message(state.roles[1], None)
        state.skip_next = False

        yield (state, state.to_gradio_chatbot(), "", None, None)

        prompt = after_process_image(state.get_prompt())
        images = state.get_images()

        data = {
            "text_input": prompt,
            "images": images if len(images) > 0 else [],
            "generation_config": {
                "top_k": int(top_k),
                "top_p": float(top_p),
                "num_beams": int(num_beams),
                "no_repeat_ngram_size": int(no_repeat_ngram_size),
                "length_penalty": float(length_penalty),
                "do_sample": bool(do_sample),
                "temperature": float(temperature),
                "max_new_tokens": min(int(max_output_tokens), 1536),
            },
        }

        # "▌" is shown as a typing cursor while the answer streams in
        state.messages[-1][-1] = "▌"
        yield (state, state.to_gradio_chatbot(), "", None, None)

        try:
            for chunk in model.predict(data):
                if not chunk:
                    continue
                # presumably chunk = (text generated so far, still generating?) - TODO confirm against
                # mPLUG_Owl_Server.predict
                output = chunk[0].strip()
                if chunk[1]:
                    state.messages[-1][-1] = post_process_code(output) + "▌"
                    yield (state, state.to_gradio_chatbot(), "", None, None)
                else:
                    state.messages[-1][-1] = output
                    yield (state, state.to_gradio_chatbot(), "", None, None)
                    return
                time.sleep(0.03)

        except requests.exceptions.RequestException:
            state.messages[-1][-1] = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
            yield (state, state.to_gradio_chatbot(), "", None, None)
            return

        # the stream ended without a final chunk: drop the trailing cursor character
        state.messages[-1][-1] = state.messages[-1][-1][:-1]
        yield (state, state.to_gradio_chatbot(), "", None, None)

    ######################################- GRADIO FRONTEND FUNCTIONS - ################################################

    def __call__(self):
        """
        Build the Gradio Blocks chatbot interface and return it (launching is left to the caller).

        :return: the ``gr.Blocks`` interface
        """

        # Creating a block for the app
        with gr.Blocks(title="Demo for Sentiment Detection using multimodal LLMs") as interface:

            state = gr.State()

            # title of the demo
            title = "Demo for Sentiment Detection using Multimodal LLMs"

            # description of the demo
            description = "Gradio initial demo for a proposed Sentiment/Emotion Detection using Multimodal LLM"

            # Setting up Markdown for Title and Description
            gr.Markdown(value=f"# <p style=\"text-align: center;\"> {title} </p>")
            gr.Markdown(value=f"#### {description}")

            with gr.Row():
                with gr.Column():

                    imagebox = gr.Image(type="pil", visible=True)
                    videobox = gr.Video()

                    with gr.Accordion("Parameters", open=True, visible=True):

                        max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True,
                                                      label="Max output tokens")
                        temperature = gr.Slider(minimum=0, maximum=1, value=1, step=0.1, interactive=True,
                                                label="Temperature")
                        top_k = gr.Slider(minimum=1, maximum=5, value=3, step=1, interactive=True, label="Top K")
                        top_p = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.1, interactive=True, label="Top p")
                        length_penalty = gr.Slider(minimum=1, maximum=5, value=1, step=0.1, interactive=True,
                                                   label="length_penalty")
                        num_beams = gr.Slider(minimum=1, maximum=5, value=1, step=1, interactive=True,
                                              label="Beam Size")
                        no_repeat_ngram_size = gr.Slider(minimum=1, maximum=5, value=2, step=1, interactive=True,
                                                         label="no_repeat_ngram_size")
                        do_sample = gr.Checkbox(interactive=True, value=True, label="do_sample")

                with gr.Column():
                    chatbot = grChatbot(elem_id="chatbot", visible=True).style(height=800)
                    with gr.Row():
                        with gr.Column(scale=8):
                            textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER",
                                                 visible=True).style(container=False)
                        with gr.Column(scale=1, min_width=60):
                            submitBtn = gr.Button(value="Submit", visible=True)
                    with gr.Row(visible=True):
                        clearBtn = gr.Button(value="🗑️  Clear history", interactive=True)

            # generation settings, in the order addTextHttpBot expects them after (state, text, image, video)
            parameterList = [max_output_tokens, temperature, top_k,
                             top_p, num_beams, no_repeat_ngram_size, length_penalty, do_sample]
            chatInputs = [state, textbox, imagebox, videobox] + parameterList

            # five outputs, matching the 5-tuples produced by addTextHttpBot and clearHistory
            chatOutputs = [state, chatbot, textbox, imagebox, videobox]

            clearBtn.click(self.clearHistory, [state], chatOutputs)
            textbox.submit(self.addTextHttpBot, chatInputs, chatOutputs)
            submitBtn.click(self.addTextHttpBot, chatInputs, chatOutputs)

            interface.load(self.loadInterface, [state], [state])

        # Launch interface
        #interface.launch(share=True, debug=True)
        return interface

########################################################################################################################

# Run the mPLUG-Owl server on the GPU when one is available, else the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# mPLUG-Owl LLaMA-7B server (load_in_8bit / bf16 options); addTextHttpBot reads this module-level name.
model = mPLUG_Owl_Server(
    base_model='MAGAer13/mplug-owl-llama-7b',
    load_in_8bit=True,
    bf16=True,
    device=device,
)

# deploy UI: build the Blocks app, then serve it through the request queue with a public share link
gradioUI = DemoGUIGradio()
interface = gradioUI()
queuedInterface = interface.queue(concurrency_count=3, status_update_rate=10, api_open=False)
queuedInterface.launch(debug=True, share=True)

########################################################################################################################



In [None]:
# @title LLaVA
"""
This code is written and maintained by Aditya Ramanath Poonja
huggingFace : https://huggingface.co/pooadi
GitHub      : https://github.com/pooadi
"""


# Colab setup: print the available GPU(s), mount Google Drive and change into the LLaVA checkout
# (presumably so that the `llava` package imported below resolves from this directory).
!nvidia-smi
from google.colab import drive

drive.mount('/content/drive')
%cd /content/drive/MyDrive/ChatAI_Project/LLaVA

import argparse
import datetime
import json
import os
import time

import gradio as gr
import requests

from llava.conversation import (default_conversation, conv_templates,
                                   SeparatorStyle)
from llava.constants import LOGDIR
from llava.utils import (build_logger, server_error_msg,
    violates_moderation, moderation_msg)
import hashlib

################################################################################################


# Web-server logger (writes to gradio_web_server.log).
logger = build_logger("gradio_web_server", "gradio_web_server.log")

# User-Agent header for HTTP requests made by this client.
headers = {"User-Agent": "LLaVA Client"}

# Reusable button updates: leave a button as is, or toggle whether it can be clicked.
no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)

# Sort keys used by get_model_list: these models get "aaaaaa…" keys so they sort ahead of
# typical model names (all other models are keyed by their own name).
priority = dict([
    ("vicuna-13b", "aaaaaaa"),
    ("koala-13b", "aaaaaab"),
])


##################################################################################################

def get_conv_log_filename():
    """
    Return the path of today's conversation log: ``LOGDIR/<year>-<MM>-<DD>-conv.json`` (local time).
    """
    now = datetime.datetime.now()
    return os.path.join(LOGDIR, f"{now.year}-{now.month:02d}-{now.day:02d}-conv.json")

##################################################################################################

def get_model_list():
    """
    Ask the controller to refresh its workers, then return the names of the available models.

    Models are sorted by ``priority.get(name, name)``, so the models listed in ``priority`` come first.

    :return: sorted list of model names
    :raises requests.HTTPError: if either controller request returns an error status
    """
    refresh = requests.post(args.controller_url + "/refresh_all_workers")
    # raise_for_status instead of ``assert``: asserts are stripped under ``python -O``
    refresh.raise_for_status()

    listing = requests.post(args.controller_url + "/list_models")
    # the list request was previously unchecked; an error page would fail later with a confusing KeyError
    listing.raise_for_status()

    models = listing.json()["models"]
    models.sort(key=lambda name: priority.get(name, name))
    logger.info(f"Models: {models}")
    return models

####################################################################################################

# JavaScript function source (presumably handed to Gradio as a client-side callback on page load) that
# returns the page's URL query parameters as an object, e.g. "?model=x" -> {model: "x"}; load_demo reads
# the "model" key of that object.
get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    url_params = Object.fromEntries(params);
    console.log(url_params);
    return url_params;
    }
"""

####################################################################################################

def load_demo(url_params, request: gr.Request):
    """
    Page-load handler: start a new conversation and make the UI components visible.

    If the URL carries a ``model`` query parameter naming a known model (global ``models``), that model is
    pre-selected in the dropdown.

    :param url_params: URL query parameters of the page
    :param request: Gradio request (used for logging the client IP)
    :return: (state, dropdown, chatbot, textbox, button, row, accordion) updates
    """
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

    # ``models`` is only consulted when a model was actually requested in the URL
    if "model" in url_params and url_params["model"] in models:
        dropdown_update = gr.Dropdown.update(value=url_params["model"], visible=True)
    else:
        dropdown_update = gr.Dropdown.update(visible=True)

    fresh_state = default_conversation.copy()

    return (
        fresh_state,
        dropdown_update,
        gr.Chatbot.update(visible=True),
        gr.Textbox.update(visible=True),
        gr.Button.update(visible=True),
        gr.Row.update(visible=True),
        gr.Accordion.update(visible=True),
    )

####################################################################################################

def load_demo_refresh_model_list(request: gr.Request):
    """
    Page-load handler that re-queries the controller for models and selects the first one.

    :param request: Gradio request (used for logging the client IP)
    :return: (state, dropdown, chatbot, textbox, button, row, accordion) updates
    """
    logger.info(f"load_demo. ip: {request.client.host}")

    available = get_model_list()
    fresh_state = default_conversation.copy()

    # empty string when the controller reports no models
    first_model = available[0] if available else ""
    dropdown_update = gr.Dropdown.update(choices=available, value=first_model)

    return (
        fresh_state,
        dropdown_update,
        gr.Chatbot.update(visible=True),
        gr.Textbox.update(visible=True),
        gr.Button.update(visible=True),
        gr.Row.update(visible=True),
        gr.Accordion.update(visible=True),
    )

####################################################################################################

def clear_history(request: gr.Request):
    """
    Start a new conversation.

    :param request: Gradio request (used for logging the client IP)
    :return: (new state, its chatbot history, "", None) followed by five ``disable_btn`` updates
    """
    logger.info(f"clear_history. ip: {request.client.host}")
    fresh = default_conversation.copy()
    disabled_buttons = (disable_btn,) * 5
    return (fresh, fresh.to_gradio_chatbot(), "", None, *disabled_buttons)

####################################################################################################

def add_text(state, text, image, image_process_mode, request: gr.Request):
    """
    Append the user's turn (text and optional image) plus an empty assistant turn to the conversation.

    Empty submissions, and flagged text when ``args.moderate`` is set, are not added: ``state.skip_next`` is
    set instead so the following ``http_bot`` call yields without generating.

    The text is cut to 1536 characters, or 1200 when an image is attached. With an image, a "<image>" token is
    appended if missing, and the conversation is restarted if it already contains an image.

    :param state: current conversation
    :param text: user prompt
    :param image: optional image
    :param image_process_mode: stored alongside the image in the message tuple
    :param request: Gradio request (used for logging the client IP)
    :return: (state, chatbot history, textbox value, None) followed by five button updates
    """
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")

    if len(text) <= 0 and image is None:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5

    if args.moderate and violates_moderation(text):
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), moderation_msg, None) + (no_change_btn,) * 5

    # Hard cut-off, stricter when an image is attached
    limit = 1536 if image is None else 1200
    message = text[:limit]

    if image is not None:
        if '<image>' not in message:
            # text = '<Image><image></Image>' + text
            message += '\n<image>'
        message = (message, image, image_process_mode)

        # a new image starts a new conversation
        if len(state.get_images(return_pil=True)) > 0:
            state = default_conversation.copy()

    state.append_message(state.roles[0], message)
    state.append_message(state.roles[1], None)
    state.skip_next = False

    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5

####################################################################################################

def _select_template_name(model_name):
    """Return the ``conv_templates`` key to use for the first round with ``model_name``.

    All substring checks are made on the lower-cased name, so e.g. "MPT"/"mpt" and
    "Llama-2"/"llama-2" select the same template.
    """
    name = model_name.lower()
    if "llava" in name:
        if "llama-2" in name:
            return "llava_llama_2"
        # Names tagged "mmtag", or containing "plain" but not "finetune", use the *_mmtag variants.
        use_mmtag = "mmtag" in name or ("plain" in name and "finetune" not in name)
        if "v1" in name:
            return "v1_mmtag" if use_mmtag else "llava_v1"
        if "mpt" in name:
            return "mpt"
        return "v0_mmtag" if use_mmtag else "llava_v0"
    if "mpt" in name:
        return "mpt_text"
    if "llama-2" in name:
        return "llama_2"
    return "vicuna_v1"


def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
    """Stream the selected model's reply into the last (empty) assistant turn of ``state``.

    Generator used as a Gradio event handler. Every ``yield`` is
    ``(state, chatbot_rendering, *five_button_updates)``.

    Steps:
    - On the first round, rebuild ``state`` from the template that matches the model.
    - Ask the controller for a worker address.
    - Save the conversation's images under ``LOGDIR/serve_images/<date>/<md5>.jpg``.
    - POST the prompt to the worker's streaming endpoint and update the chat as
      NUL-delimited JSON chunks arrive.
    - Append a JSON record of the exchange to the conversation log.

    :param state: conversation object held in the Gradio ``State``.
    :param model_selector: model name selected in the dropdown.
    :param temperature: sampling temperature, forwarded to the worker as ``float``.
    :param top_p: nucleus-sampling threshold, forwarded to the worker as ``float``.
    :param max_new_tokens: generation length, capped at 1536.
    :param request: Gradio request object; the client host is logged.
    """
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector
    # Button updates for error paths: vote/flag disabled, regenerate/clear enabled.
    error_btns = (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    if len(state.messages) == state.offset + 2:
        # First round: recreate the conversation from the model-specific template,
        # carrying over only the user's message.
        new_state = conv_templates[_select_template_name(model_name)].copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Query worker address
    controller_url = args.controller_url
    ret = requests.post(controller_url + "/get_worker_address",
                        json={"model": model_name})
    worker_addr = ret.json()["address"]
    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

    # No available worker
    if worker_addr == "":
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot()) + error_btns
        return

    # Construct prompt
    prompt = state.get_prompt()

    # Save every image in the conversation once, keyed by its MD5 hash.
    all_images = state.get_images(return_pil=True)
    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
    t = datetime.datetime.now()
    image_dir = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}")
    for image, image_hash in zip(all_images, all_image_hash):
        filename = os.path.join(image_dir, f"{image_hash}.jpg")
        if not os.path.isfile(filename):
            os.makedirs(image_dir, exist_ok=True)
            image.save(filename)

    # Make requests
    pload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": min(int(max_new_tokens), 1536),
        "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
        "images": f'List of {len(state.get_images())} images: {all_image_hash}',
    }
    logger.info(f"==== request ====\n{pload}")

    # The logged payload shows only image hashes; the real request carries state.get_images().
    pload['images'] = state.get_images()

    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

    output = ""  # stays empty if the worker streams no chunks
    try:
        # Stream output; the context manager closes the connection on every exit path.
        with requests.post(worker_addr + "/worker_generate_stream",
                           headers=headers, json=pload, stream=True, timeout=10) as response:
            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                if not chunk:
                    continue
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    # Assumes the worker echoes the prompt at the start of data["text"]; keep only the new part.
                    output = data["text"][len(prompt):].strip()
                    state.messages[-1][-1] = output + "▌"
                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                else:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    yield (state, state.to_gradio_chatbot()) + error_btns
                    return
                time.sleep(0.03)
    except requests.exceptions.RequestException as e:
        logger.info(f"http_bot worker request failed: {e}")
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot()) + error_btns
        return

    # Drop the trailing "▌" cursor.
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    finish_tstamp = time.time()
    logger.info(f"{output}")

    with open(get_conv_log_filename(), "a") as fout:
        record = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "images": all_image_hash,
            "ip": request.client.host,
        }
        fout.write(json.dumps(record) + "\n")

####################################################################################################

# Markdown header shown above the UI when build_demo is not in embed mode.
title_markdown = ("""
# 🌋 LLaVA: Large Language and Vision Assistant
[[Project Page]](https://llava-vl.github.io) [[Paper]](https://arxiv.org/abs/2304.08485) [[Code]](https://github.com/haotian-liu/LLaVA) [[Model]](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v0)
""")

# Terms-of-use text rendered below the UI when not in embed mode.
tos_markdown = ("""
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
""")


# License notice rendered after the terms of use when not in embed mode.
learn_more_markdown = ("""
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
""")

####################################################################################################

def build_demo(embed_mode):
    """Build the LLaVA Gradio ``Blocks`` UI and register its event handlers.

    :param embed_mode: when truthy, the title, terms-of-use and license markdown are omitted.
    :return: the constructed (not yet launched) ``gr.Blocks`` demo.
    :raises ValueError: if ``args.model_list_mode`` is neither "once" nor "reload".
    """
    # Created outside the Blocks context so gr.Examples can reference it; rendered later in the chat column.
    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", visible=False, container=False)
    with gr.Blocks(title="LLaVA", theme=gr.themes.Base()) as demo:
        state = gr.State()

        if not embed_mode:
            gr.Markdown(title_markdown)

        with gr.Row():
            with gr.Column(scale=3):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=models,
                        value=models[0] if len(models) > 0 else "",
                        interactive=True,
                        show_label=False,
                        container=False)

                imagebox = gr.Image(type="pil")
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad"],
                    value="Crop",
                    label="Preprocess for non-square image")

                # __file__ is undefined when this code runs in a notebook cell; fall back to the working directory.
                try:
                    cur_dir = os.path.dirname(os.path.abspath(__file__))
                except NameError:
                    cur_dir = os.getcwd()
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
                    [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
                ], inputs=[imagebox, textbox])

                with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
                    max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)

            with gr.Column(scale=6):
                chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA Chatbot", visible=False, height=550)
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=60):
                        submit_btn = gr.Button(value="Submit", visible=False)
                with gr.Row(visible=False) as button_row:
                    upvote_btn = gr.Button(value="👍  Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎  Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️  Flag", interactive=False)
                    regenerate_btn = gr.Button(value="🔄  Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️  Clear history", interactive=False)

        if not embed_mode:
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)
        url_params = gr.JSON(visible=False)

        # Register listeners
        # Order matters: every handler returns its five button updates in this order.
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        bot_inputs = [state, model_selector, temperature, top_p, max_output_tokens]
        bot_outputs = [state, chatbot] + btn_list

        # Feedback buttons share inputs and outputs; only the handler differs.
        for feedback_btn, handler in (
                (upvote_btn, upvote_last_response),
                (downvote_btn, downvote_last_response),
                (flag_btn, flag_last_response)):
            feedback_btn.click(handler,
                [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])

        regenerate_btn.click(regenerate, [state, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list).then(
            http_bot, bot_inputs, bot_outputs)
        clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list)

        # ENTER in the textbox and the Submit button both add the turn, then stream the reply.
        for submit_event in (textbox.submit, submit_btn.click):
            submit_event(add_text, [state, textbox, imagebox, image_process_mode],
                         [state, chatbot, textbox, imagebox] + btn_list
                         ).then(http_bot, bot_inputs, bot_outputs)

        if args.model_list_mode == "once":
            demo.load(load_demo, [url_params], [state, model_selector,
                chatbot, textbox, submit_btn, button_row, parameter_row],
                _js=get_window_url_params)
        elif args.model_list_mode == "reload":
            demo.load(load_demo_refresh_model_list, None, [state, model_selector,
                chatbot, textbox, submit_btn, button_row, parameter_row])
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

    return demo

####################################################################################################

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
    parser.add_argument("--concurrency-count", type=int, default=8)
    parser.add_argument("--model-list-mode", type=str, default="once",
        choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    # parse_known_args instead of parse_args: inside a notebook kernel, sys.argv holds
    # the kernel's own flags (e.g. its connection file), which parse_args would reject
    # by exiting. Unrecognized arguments are logged instead.
    # `args` stays module-global because add_text/http_bot/build_demo read it.
    args, unknown_args = parser.parse_known_args()
    if unknown_args:
        logger.info(f"ignoring unrecognized args: {unknown_args}")
    logger.info(f"args: {args}")

    models = get_model_list()

    demo = build_demo(args.embed)
    demo.queue(concurrency_count=args.concurrency_count, status_update_rate=10,
               api_open=False).launch(
        server_name=args.host, server_port=args.port, share=args.share)

