In [1]:
import yaml
from models.visionllama import VisionLLaMA
from dataset.objectCOCO import COCOObjectDataset


with open("configs/config.yaml", 'r') as file:
        config = yaml.safe_load(file)

config['n_patches'] = 16
if "336" in config['vision_encoder']:
config['n_patches'] = 24
        
config['use_retrieval'] = True
dataset = COCOObjectDataset(config, split="train", n_patches=config['n_patches'], max_examples_per_class = config["examples_per_class"])
model = None
old_config = None


  from .autonotebook import tqdm as notebook_tqdm


KeyError: 'n_patches'

In [None]:
import gradio as gr
import cv2
import time
import numpy as np
import matplotlib.pyplot as plt
import io
import os
import skimage
import math
import torch
from PIL import Image

def _get_ViT_mask(mask, height, width, output_height, output_width):
        
    
    pooled_mask = skimage.measure.block_reduce(mask, block_size=(math.floor(height / output_height), math.floor(width / output_width)), func=np.max)

    result_height, result_width = pooled_mask.shape
    # If the result is smaller than 16x16, pad it with zeros
    if result_height < output_height or result_width < output_width:
        pad_height = output_height - result_height
        pad_width = output_width - result_width
        pooled_mask = np.pad(pooled_mask, ((0, pad_height), (0, pad_width)), mode='constant')

    if result_height > output_height or result_width > output_width:
        pooled_mask = pooled_mask[:output_height, :output_width]

    assert pooled_mask.shape == (output_height,output_width)
    return torch.BoolTensor(np.append(1, pooled_mask.flatten()))

def sleep(im):
    time.sleep(2)
    ret = [im["background"]]
    for layer in im["layers"]:
        ret.append(layer)
    return ret
    #return [im["background"], im["layers"][0], im["layers"][1], im['composite']]

def show_anns(anns, ax):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:,:,3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)

def generate_predictions(question, images, task, backbone, use_retrieval, use_object_encoder_checkpoint, freeze_llm, chat_history):
    global model
    global old_config
    print(images)
    image = Image.open(images[0][0]).convert('RGB')
    segmentations = [Image.open(x[0]).convert('RGB') for x in images[1:]]
    
    config['freeze_llm'] = freeze_llm
    config['llm_model'] = backbone
    config['task'] = task

    if not use_object_encoder_checkpoint:
        config['pretrained_object_encoder_checkpoint'] = "None"
    else:
        config['pretrained_object_encoder_checkpoint'] = "./llama_2_7b_adapter"

    config['use_retrieval'] = use_retrieval

    if "llama" or "gpt2" in backbone:
        if "336" in config["vision_encoder"]:
            output_width, output_height = 24, 24
        else:
            output_width, output_height = 16, 16
        
    elif "llava" in backbone:
        output_width, output_height = 24, 24

    config['n_patches'] = output_width
    print(config['n_patches'])


    if old_config != config:
        if config['use_retrieval']:
            model = VisionLLaMA(config, retrieval_fn = lambda x, y: dataset.retrieve_closest(x, config["retrieval_k"], b_num = y))    

        else:
            model = VisionLLaMA(config)
        model.load()
        model.eval()
        old_config = config.copy()

    seg_width, seg_height = image.size
    
    vit_masks = []

    
    cropped_images = []
    for segmentation in segmentations:
        seg = np.array(segmentation)
        if np.sum(seg, axis = None) == 0:
            continue
        else:
        
            mask = np.any(seg != [0, 0, 0], axis=-1)

            if config["crop_image"]:
                # print(mask.shape)
                # print(img.shape)
                img = np.array(image)
                img[~mask] = np.array([255,255,255])
                
                
                # Find the indices of non-zero elements in the binary mask
                non_zero_indices = np.where(mask)

                # Get the minimum and maximum values along each axis
                min_x, min_y = np.min(non_zero_indices[1]), np.min(non_zero_indices[0])
                max_x, max_y = np.max(non_zero_indices[1]), np.max(non_zero_indices[0])

                #img = img[min_x: max_x, min_y: max_y]
                img = img[min_y: max_y, min_x: max_x]

                cropped_image = Image.fromarray(np.uint8(img)).convert('RGB')
                cropped_images.append(cropped_image)
                print(min_x, min_y, max_x, max_y)
            
            vit_masks.append(_get_ViT_mask(mask, seg_height, seg_width, output_height, output_width))
       
       #print(_get_ViT_mask(mask, seg_height, seg_width).shape)
    if len(vit_masks) > 0:
        vit_masks = torch.stack(vit_masks, axis = 0)
    imgs = [image] * len(vit_masks) if len(vit_masks) > 0 else [image]
 

    prompts = None
    masks = None
    images = None
    if config['use_retrieval']:
        output, prompts, masks, images = model.generate(vit_masks, imgs, [question], return_retrieved_info=True, cropped_images = cropped_images)
        print(prompts)
        chat_history.append((question, output))
        retrieval_images = [Image.open(images[0][x]) for x in range(len(images[0]))]
        #retrieval_images += [cropped_image]
        return chat_history, retrieval_images
        
    else:
        output = model.generate(vit_masks, imgs, [question])
        chat_history.append((question, output))
        return chat_history, None
    
    #print(prompts, masks, images)
    # fig, ax = plt.subplots()
    # ax.imshow(image)
    # show_anns(masks, ax)
    # plt.axis('off')

    # img_buf = io.BytesIO()
    # plt.savefig(img_buf, format='png')

    

    

with gr.Blocks(title="Olive", theme=gr.themes.Base()).queue() as demo:
    
    with gr.Row():
        
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    
                    im = gr.ImageEditor(
                        type="pil"
                        
                    )
                    with gr.Row():
                        gallery = gr.Gallery(
                            label="Segmentations", show_label=False, elem_id="gallery"
                        , columns=[3], rows=[1], object_fit="contain", height=200)


                        
                        # im_out_1 = gr.Image(type="pil")
                        # im_out_2 = gr.Image(type="pil")
                        # im_out_3 = gr.Image(type="pil")
                        # im_out_4 = gr.Image(type="pil")
                with gr.Column():
                    chatbot = gr.Chatbot(elem_id="chatbot", label="OLIVE Chatbot", height=300)
                    with gr.Row():
                        
                        with gr.Column(scale=8):
                            textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
                        with gr.Column(scale=1, min_width=50):
                            submit_btn = gr.Button(value="Send", variant="primary")
                    retrieval_gallery = gr.Gallery(
                            label="Retrieved Images", show_label=True, elem_id="gallery2"
                        , columns=[5], rows=[1], object_fit="contain", height=100)
   
                    #task = gr.Dropdown(["object_classification", "relation_prediction", "image_captioning", "refCOCOg", "GRIT", "OCR", "ALL", "PointQA", "ObjectInstruct"], label="Task",  info="For now object classification/image captioning", value="object_classification")
                    task = gr.Dropdown(["object_classification", "relation_prediction", "image_captioning", "refCOCOg", "GRIT", "OCR", "ALL", "PointQA", "ObjectInstruct"], label="Task",  info="For now object classification/image captioning", value="object_classification")
                    
                    backbone = gr.Dropdown(["llava-hf/llava-1.5-7b-hf", "meta-llama/Llama-2-7b-chat-hf", "gpt2"], label="Decoder Backbone",  info="Backbone Frozen LLM/VLM", value="meta-llama/Llama-2-7b-chat-hf")
                    
                    freeze_llm = gr.Checkbox(label="freeze llm", info="Freeze llm weights", value=True)
                    obj_encoder_checkpoint = gr.Checkbox(label="obj_encoder", info="Use object encoder checkpoint")
                    use_retrieval = gr.Checkbox(label="use retrieval", info="Use retrieval to understand prediction")



                    

        #btn = gr.Button()
        
        im.change(sleep, outputs=[gallery], inputs=im) 

        
            

        #with gr.Column():    
            # image_output = gr.AnnotatedImage(width=600, height=600)
            # output_text = gr.Textbox()
            # gr.Examples(
            #     examples=[[os.path.join("example_pictures", "cat_and_dog.png"), "Image Captioning", 32, True, True],
            #               [os.path.join("example_pictures", "cat_and_dog.png"), "Object Classification", 5, True, True],
            #               [os.path.join("example_pictures", "birds.png"), "Object Classification", 5, True, True]],
            #     inputs=[im_out_1, task, n_layers, obj_encoder_checkpoint, freeze_llm],
            #     outputs=[image_output, output_text],
            #     fn=generate_predictions,
            #     cache_examples=False,
            # )
            # text_output1 = gr.HighlightedText(
            #                     label="Generated Description", 
            #                     combine_adjacent=False,
            #                     show_legend=True,
            #                 ).style(color_map={"box": "red"})


    submit_btn.click(fn=generate_predictions, 
                        inputs=[textbox, gallery, task, backbone, use_retrieval, obj_encoder_checkpoint, freeze_llm, chatbot],  
                        outputs=[chatbot, retrieval_gallery],  
                        show_progress=True, queue=True)

demo.launch(inbrowser=True)


Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.




None


Traceback (most recent call last):
  File "/work/ossowski/anaconda3/envs/llava/lib/python3.10/site-packages/gradio/queueing.py", line 522, in process_events
    response = await route_utils.call_process_api(
  File "/work/ossowski/anaconda3/envs/llava/lib/python3.10/site-packages/gradio/route_utils.py", line 260, in call_process_api
    output = await app.get_blocks().process_api(
  File "/work/ossowski/anaconda3/envs/llava/lib/python3.10/site-packages/gradio/blocks.py", line 1689, in process_api
    result = await self.call_function(
  File "/work/ossowski/anaconda3/envs/llava/lib/python3.10/site-packages/gradio/blocks.py", line 1255, in call_function
    prediction = await anyio.to_thread.run_sync(
  File "/ua/ossowski/.local/lib/python3.10/site-packages/anyio/to_thread.py", line 33, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
  File "/ua/ossowski/.local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 877, in run_sync_in_worker_thread
    re

[('/tmp/gradio/ed72398b3d6752e9c0fd4a098bd59b07df7134ac/image.png', None), ('/tmp/gradio/38c8b5c0119a88b95cc9b6bec238ce0456f7140c/image.png', None)]
24


  return self.fget.__get__(instance, owner)()
Loading checkpoint shards: 100%|██████████| 3/3 [00:08<00:00,  2.99s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Initialized model with llava-hf/llava-1.5-7b-hf LLM backbone and openai/clip-vit-large-patch14-336 Vision Encoder
The save path is: ./checkpoints/llava_finetuned_checkpoints/ObjectInstruct/finetuned_llm_clip_336_24x24_patches


Traceback (most recent call last):
  File "/work/ossowski/anaconda3/envs/llava/lib/python3.10/site-packages/gradio/queueing.py", line 522, in process_events
    response = await route_utils.call_process_api(
  File "/work/ossowski/anaconda3/envs/llava/lib/python3.10/site-packages/gradio/route_utils.py", line 260, in call_process_api
    output = await app.get_blocks().process_api(
  File "/work/ossowski/anaconda3/envs/llava/lib/python3.10/site-packages/gradio/blocks.py", line 1689, in process_api
    result = await self.call_function(
  File "/work/ossowski/anaconda3/envs/llava/lib/python3.10/site-packages/gradio/blocks.py", line 1255, in call_function
    prediction = await anyio.to_thread.run_sync(
  File "/ua/ossowski/.local/lib/python3.10/site-packages/anyio/to_thread.py", line 33, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
  File "/ua/ossowski/.local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 877, in run_sync_in_worker_thread
    re