In [None]:
'''
Notebook adapted from https://github.com/camenduru/LLaVA-colab and modified for ContextVLM
Author: Shounak Sural
Description: This notebook demonstrates how to use the LLaVA model for detecting contexts of interest from AV driving images.
'''

%cd /content
!git clone -b v1.0 https://github.com/camenduru/LLaVA
%cd /content/LLaVA

!pip install -q transformers==4.36.2
!pip install -q gradio .

from transformers import AutoTokenizer, BitsAndBytesConfig
from llava.model import LlavaLlamaForCausalLM
import torch

model_path = "4bit/llava-v1.5-13b-3GB"
kwargs = {"device_map": "auto"}
kwargs['load_in_4bit'] = True
kwargs['quantization_config'] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4'
)
model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
    vision_tower.load_model()
vision_tower.to(device='cuda')
image_processor = vision_tower.image_processor

import os
import requests
from PIL import Image
from io import BytesIO
from llava.conversation import conv_templates, SeparatorStyle
from llava.utils import disable_torch_init
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from transformers import TextStreamer

def caption_image(image_file, prompt, conv_mode="llava_v0", temperature=0.2,
                  max_new_tokens=1024, request_timeout=30.0):
    """Ask the loaded LLaVA model a question about a single image.

    Uses the module-level ``model``, ``tokenizer`` and ``image_processor``
    created in the setup cell.

    Args:
        image_file: Local file path, or an ``http://`` / ``https://`` URL.
        prompt: Question / instruction text sent together with the image.
        conv_mode: Key into ``conv_templates`` selecting the chat template.
        temperature: Sampling temperature passed to ``model.generate``.
        max_new_tokens: Generation length cap passed to ``model.generate``.
        request_timeout: Seconds to wait when downloading a URL image.

    Returns:
        ``(image, output)``: the RGB ``PIL.Image`` sent to the model, and the
        decoded reply with special tokens and the trailing stop string removed.

    Raises:
        ``requests`` exceptions (HTTP error status, timeout) for URL inputs, and
        whatever PIL raises for files it cannot open.
    """
    # Match the scheme explicitly: a bare 'http' prefix would misroute a local
    # file such as 'http_cam0.jpg' to requests.
    if image_file.startswith(('http://', 'https://')):
        response = requests.get(image_file, timeout=request_timeout)
        # Fail loudly on 4xx/5xx instead of letting PIL choke on an error page.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        # `with` closes the underlying file handle; convert() returns an independent copy.
        with Image.open(image_file) as raw_image:
            image = raw_image.convert('RGB')

    conv = conv_templates[conv_mode].copy()
    image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()

    # Only wrap the image token in <im_start>/<im_end> when the model config says
    # the checkpoint expects them (the same rule llava/serve/cli.py applies).
    if getattr(model.config, 'mm_use_im_start_end', False):
        image_tokens = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_tokens = DEFAULT_IMAGE_TOKEN
    # get_prompt() already prefixes each message with its role, so the user text
    # must not carry its own "<role>: " prefix (that duplicated the role).
    conv.append_message(conv.roles[0], image_tokens + '\n' + prompt)
    conv.append_message(conv.roles[1], None)
    raw_prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(raw_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )

    # Decode only the newly generated tokens; skip_special_tokens drops '</s>'.
    output = tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True).strip()
    # The stopping criterion fires after the stop string has been emitted, so trim it.
    if stop_str and output.endswith(stop_str):
        output = output[:-len(stop_str)].strip()
    return image, output

In [None]:
# import locale
# locale.getpreferredencoding = lambda: "UTF-8"
# !mkdir /content/images
# !wget --header 'Authorization: Bearer TOKEN_HERE' https://huggingface.co/camenduru/polaroid/resolve/main/style_name_fix.zip
# !unzip style_name_fix.zip -d /content/images

In [None]:
import google.colab.drive as drive

# Mount Google Drive so the image folder and results file under MyDrive are reachable.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Collect the image file names and process them in a deterministic (lexicographic) order.
file_names = os.listdir('/content/drive/MyDrive/Driving_Contexts/pittsburgh')
sorted_file_names = list(file_names)
sorted_file_names.sort()

# print(sorted_file_names)
# print(len(sorted_file_names))

In [None]:
from tqdm.auto import tqdm

IMAGE_DIR = '/content/drive/MyDrive/Driving_Contexts/pittsburgh'
RESULTS_PATH = '/content/drive/MyDrive/context_results_pittsburgh.txt'

# The prompt is identical for every image, so build it once outside the loop.
# (Fixed: "bridgeIs" previously merged the bridge and parking-lot questions.)
question="Answer the following questions about this image. Generate output (yes/no) in JSON format with no additional text. Are we driving indoors?   Are we driving outdoors?  Is this during daytime?  Is this during nighttime?  Is this during sunny weather?  Is this during cloudy weather?  Is this during rainy weather?  Is this during snowy weather?  Is this during heavy fog?  Is this on a highway? Is this in a city? Does this have skyscrapers like in an urban canyon? Is this a rural area? Is this in a tunnel? Is this on a bridge? Is this in a parking lot? Is some construction work going on here? Is this a paved road? Are there lane markers on the ground? Is there heavy traffic? Is this in a sandstorm? Is this in an area with heavy tree cover? Is this on a dirt road? Is this beneath an underpass? Is this during twilight? "

output_list = {}
# `with` closes the file even if the loop is interrupted. Mode 'a' appends, so
# re-running this cell adds a second copy of every entry. Explicit UTF-8 avoids
# depending on Colab's preferred locale encoding.
# NOTE(review): model replies may span several lines, so the file is not strictly
# one line per image.
with open(RESULTS_PATH, 'a', encoding='utf-8') as results_file:
    for file_name in tqdm(sorted_file_names):
        try:
            image, output = caption_image(f'{IMAGE_DIR}/{file_name}', question)
            print(file_name, output)
            output_list[file_name] = output
            results_file.write(file_name + ": " + output + "\n")
            # Flush per image so finished results survive a runtime disconnect.
            results_file.flush()
        except Exception as e:
            # Best-effort batch: log the failure and move on to the next image.
            print(f"Error processing {file_name}: {str(e)}")

In [None]:
from google.colab import files

# Download the file the results loop actually writes (on Google Drive);
# nothing in this notebook writes a results file under /content/LLaVA.
files.download("/content/drive/MyDrive/context_results_pittsburgh.txt")