<a href="https://colab.research.google.com/github/salesforce/LAVIS/blob/main/projects/img2prompt-vqa/img2prompt_vqa.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Img2Prompt-VQA: Inference Demo

In [None]:
# install requirements
import sys
if 'google.colab' in sys.modules:
    print('Running in Colab.')
    !git clone https://github.com/salesforce/LAVIS
    %cd LAVIS
    !pip install .
    !pip3 install https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz
else:
    !pip install omegaconf
    %cd ../..
    !pip install .
    !pip3 install https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz

%cd projects/img2prompt-vqa

In [None]:
import torch
import requests
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np

from lavis.common.gradcam import getAttMap
from lavis.models import load_model_and_preprocess

### Load the LLM

In [None]:
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM

def load_model(model_selection, **model_kwargs):
    """Load a causal language model and its tokenizer.

    Args:
        model_selection: Hugging Face model identifier or local path,
            e.g. 'facebook/opt-6.7b'.
        **model_kwargs: extra keyword arguments forwarded unchanged to
            ``AutoModelForCausalLM.from_pretrained`` (for example
            ``torch_dtype=torch.float16``), so that dtype or memory options can
            be chosen without editing this helper.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(model_selection, **model_kwargs)
    # The slow (non-fast) tokenizer is requested explicitly, as in the original demo.
    tokenizer = AutoTokenizer.from_pretrained(model_selection, use_fast=False)
    return model, tokenizer

# Choose which OPT checkpoint answers the question.
# Weights for the checkpoints below are downloaded automatically
# (sizes are the approximate FP16 figures quoted by the original demo;
# NOTE(review): actual memory use depends on the dtype the weights are loaded in):
#   'facebook/opt-6.7b'  ~13G
#   'facebook/opt-13b'   ~26G
#   'facebook/opt-30b'   ~60G
#   'facebook/opt-66b'   ~132G
# OPT-175B ('facebook/opt-175b') requires downloading the weights manually:
# https://github.com/facebookresearch/metaseq/tree/main/projects/OPT
llm_checkpoint = 'facebook/opt-6.7b'

print("Loading Large Language Model (LLM)...")
llm_model, tokenizer = load_model(llm_checkpoint)

### Load an example image and question

In [None]:
from pathlib import Path

# Prefer the demo image next to this notebook. If it is not found relative to the current
# working directory, fall back to the hosted PnP-VQA demo image (the alternative source the
# original cell kept in commented-out code; presumably the same picture — confirm).
DEMO_IMAGE_PATH = Path("./demo.png")
DEMO_IMAGE_URL = 'https://storage.googleapis.com/sfr-vision-language-research/LAVIS/projects/pnp-vqa/demo.png'

if DEMO_IMAGE_PATH.exists():
    raw_image = Image.open(DEMO_IMAGE_PATH).convert("RGB")
else:
    response = requests.get(DEMO_IMAGE_URL, stream=True, timeout=30)
    response.raise_for_status()  # fail loudly instead of handing an error page to PIL
    raw_image = Image.open(response.raw).convert("RGB")

# Question text is kept verbatim (the spacing in "item s" is part of the original demo text).
question = "What item s are spinning which can be used to control electric?"
print(question)
display(raw_image.resize((400, 300)))

In [None]:
# Run the vision-language models on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

### Load Img2Prompt-VQA model

In [None]:
# Load the Img2Prompt-VQA model (base variant, eval mode) together with its image and text processors.
model, vis_processors, txt_processors = load_model_and_preprocess(
    name="img2prompt_vqa",
    model_type="base",
    is_eval=True,
    device=device,
)

### Preprocess image and text inputs

In [None]:
# Apply the evaluation-time processors. The image gets a leading batch dimension of one and
# is moved to `device`.
# NOTE: `question` is overwritten with its processed form, so re-running this cell
# re-processes already-processed text.
eval_vis_processor = vis_processors["eval"]
eval_txt_processor = txt_processors["eval"]

image = eval_vis_processor(raw_image).unsqueeze(0).to(device)
question = eval_txt_processor(question)

samples = {"image": image, "text_input": [question]}

### Img2Prompt-VQA utilizes 4 submodels to perform VQA:
#### 1. Image-Question Matching 
Compute a relevancy score for each image patch with respect to the question, using GradCAM.

In [None]:
# Image-question matching. The returned dict carries the GradCAM scores under 'gradcams',
# which the visualisation below reads (and, presumably, the captioning step as well).
samples = model.forward_itm(
    samples=samples,
)

In [None]:
# Gradcam visualisation: overlay the question-conditioned GradCAM on a resized copy of the image.
dst_w = 720        # display width in pixels; the height is scaled to keep the aspect ratio
GRADCAM_GRID = 24  # assumes a 24x24 patch grid, as in the original reshape — TODO confirm for other image sizes

w, h = raw_image.size
scaling_factor = dst_w / w

resized_img = raw_image.resize((int(w * scaling_factor), int(h * scaling_factor)))
norm_img = np.float32(resized_img) / 255

gradcam = samples['gradcams'].reshape(GRADCAM_GRID, GRADCAM_GRID)
# `reshape` may return a view that shares memory with samples['gradcams'], which the
# captioning step reads again. getAttMap may normalize its map argument in place
# (NOTE(review): confirm in lavis.common.gradcam), so hand it a detached CPU copy.
# That way the visualisation cannot alter the model's inputs.
if torch.is_tensor(gradcam):
    gradcam = gradcam.detach().cpu().clone()
else:
    gradcam = np.array(gradcam, copy=True)

avg_gradcam = getAttMap(norm_img, gradcam, blur=True)

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.imshow(avg_gradcam)
ax.set_yticks([])
ax.set_xticks([])
print('Question: {}'.format(question))

#### 2. Image Captioning
Generate question-guided captions based on the relevancy score

In [None]:
# Generate question-guided captions.
#   num_captions: number of captions to generate (presumably per image — confirm)
#   num_patches:  number of image patches used per caption (presumably chosen by the
#                 GradCAM relevancy from the previous step — confirm in forward_cap)
samples = model.forward_cap(
    samples=samples,
    num_captions=50,
    num_patches=20,
)
print('Examples of question-guided captions: ')
first_captions = samples['captions'][0]
first_captions[:5]

#### 3. Question Generation
Generate synthetic question–answer pairs from the captions.

In [None]:
# Generate synthetic question/answer pairs from the captions and preview the first five of each.
samples = model.forward_qa_generation(samples)
sample_questions = samples['questions'][:5]
sample_answers = samples['answers'][:5]
print('Sample Question: {} \nSample Answer: {}'.format(sample_questions, sample_answers))

In [None]:
# Display every synthetic question generated in the previous step (that cell printed only the first five).
samples['questions']

#### 4. Prompt Construction
Construct the LLM prompt from the outputs of the previous steps (stored in `samples`).

In [None]:
# Build the LLM prompt from the accumulated `samples`; the result is tokenized and fed to the LLM below.
Img2Prompt = model.prompts_construction(samples)

#### Predict Answers with the LLM (loaded at the top of the notebook)


In [None]:
# In this notebook, we only use CPU for LLM inference
# To run inference on GPU, see https://github.com/CR-Gjx/Img2Prompt for reference
# NOTE(review): this rebinds the global `device` (chosen earlier for the vision model) to the
# string "cpu". Re-running the model-loading or preprocessing cells afterwards, without first
# re-running the device cell, will place them on the CPU.
device = "cpu"

def postprocess_Answer(text, return_all=False):
    """Turn raw LLM continuations into short, normalized answers.

    Each continuation is truncated at its first '.' or newline, stripped of
    surrounding whitespace, and lower-cased. Lower-casing applies whether or not
    a terminator was found.

    Args:
        text: list of decoded continuation strings (e.g. from ``tokenizer.batch_decode``).
        return_all: if True, return the cleaned answers for every entry of ``text``.

    Returns:
        By default, the cleaned answer for the last entry of ``text`` (the only one
        for a batch of size one), or '' when ``text`` is empty. With
        ``return_all=True``, a list with one cleaned answer per entry.
    """
    answers = []
    for ans in text:
        stops = [pos for pos in (ans.find('.'), ans.find('\n')) if pos != -1]
        end = min(stops) if stops else len(ans)
        answers.append(ans[:end].strip().lower())
    if return_all:
        return answers
    return answers[-1] if answers else ''

MAX_NEW_TOKENS = 20  # upper bound on the number of answer tokens the LLM may generate

# Tokenize the prompt and place it on the same device as the LLM weights, so this cell keeps
# working if the LLM is moved to a GPU (and does not depend on the global `device`).
Img2Prompt_input = tokenizer(Img2Prompt, padding='longest', truncation=True, return_tensors="pt").to(llm_model.device)

prompt_len = Img2Prompt_input.input_ids.shape[1]
# Prompt plus answer must fit in the model's positional range. Read it from the model config,
# falling back to the 2048 the original check hard-coded.
max_positions = getattr(llm_model.config, "max_position_embeddings", 2048)
assert prompt_len + MAX_NEW_TOKENS <= max_positions, (
    f"Prompt is {prompt_len} tokens; with {MAX_NEW_TOKENS} new tokens it exceeds the "
    f"{max_positions} positions supported by the LLM."
)

outputs_list = []
outputs = llm_model.generate(
    input_ids=Img2Prompt_input.input_ids,
    attention_mask=Img2Prompt_input.attention_mask,
    max_length=prompt_len + MAX_NEW_TOKENS,
    return_dict_in_generate=True,  # needed for `outputs.sequences` when decoding
    output_scores=True,
)
outputs_list.append(outputs)


#### 5. Decoding to answers

In [None]:
# Keep only the tokens generated after the prompt and decode them. Special tokens are skipped
# so that markers such as an end-of-sequence token cannot leak into the answer. Then cut and
# normalize the text with postprocess_Answer.
generated_tokens = outputs.sequences[:, Img2Prompt_input.input_ids.shape[1]:]
raw_answers = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
pred_answer = postprocess_Answer(raw_answers)

print({"question": question, "answer": pred_answer})