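# BLIP: Inference Demo
Runs BLIP image captioning, visual question answering (VQA), and image–text embedding extraction (ITM model) over the CFD images, writing the results to Google Drive.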
 - [Image Captioning](#Image-Captioning)
 - [VQA](#VQA)
 - [Feature Extraction](#Feature-Extraction)
 - [Image Text Matching](#Image-Text-Matching)

In [None]:
#Based on the publicly available demo for BLIP: https://colab.research.google.com/github/salesforce/BLIP/blob/main/demo.ipynb#scrollTo=a811a65f
#To replicate with this notebook, you need to run from the above demo link.

# install requirements
import sys
if 'google.colab' in sys.modules:
    print('Running in Colab.')
    !pip3 install transformers==4.15.0 timm==0.4.12 fairscale==0.4.4
    !git clone https://github.com/salesforce/BLIP
    %cd BLIP

In [None]:
from PIL import Image
import requests
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

# Prefer the GPU when PyTorch can see one; otherwise everything runs on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

#Default code from BLIP demo
def load_demo_image(image_size,device):
    """Download the BLIP demo photo, display a 1/5-scale preview, and return it as model input.

    Args:
        image_size: side length in pixels; the image is resized to a square with bicubic interpolation.
        device: torch device the returned tensor is moved to.

    Returns:
        Float tensor of shape (1, 3, image_size, image_size), normalized with the
        per-channel mean/std constants below.

    Raises:
        requests.HTTPError: if the download returns an HTTP error status.
    """
    img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
    # The context manager releases the streamed connection. raise_for_status turns an
    # HTTP error into a clear exception instead of handing an error page to PIL.
    with requests.get(img_url, stream=True) as response:
        response.raise_for_status()
        raw_image = Image.open(response.raw).convert('RGB')

    w,h = raw_image.size
    display(raw_image.resize((w//5,h//5)))

    transform = transforms.Compose([
        transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])
    # Add a leading batch dimension of 1.
    image = transform(raw_image).unsqueeze(0).to(device)
    return image

#Added a function to load and preprocess images
def load_image(img_url,image_size,device):
    """Open a local image, display a 1/5-scale preview, and return it as model input.

    Uses the same preprocessing as ``load_demo_image``, but reads from disk.

    Args:
        img_url: path (or file object) accepted by ``PIL.Image.open``. Despite the
            name, this is not fetched over HTTP.
        image_size: side length in pixels; the image is resized to a square with bicubic interpolation.
        device: torch device the returned tensor is moved to.

    Returns:
        Float tensor of shape (1, 3, image_size, image_size), normalized with the
        per-channel mean/std constants below.
    """
    # The context manager closes the file as soon as the RGB copy has been made,
    # regardless of image format.
    with Image.open(img_url) as source_image:
        raw_image = source_image.convert('RGB')
    w,h = raw_image.size
    display(raw_image.resize((w//5,h//5)))

    transform = transforms.Compose([
        transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])
    # Add a leading batch dimension of 1.
    image = transform(raw_image).unsqueeze(0).to(device)
    return image

In [None]:
#Import Drive to read in CFD images and subsequently write data
# Mount Google Drive at /content/drive (google.colab is only available in Colab).
# Later cells read CFD images from '/content/drive/My Drive/CFD' and write result and
# embedding files under '/content/drive/My Drive/...', so run this cell before them.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Import here: `listdir` was previously only imported in a later cell, so this cell
# raised NameError on a fresh kernel (Restart & Run All).
import os
from os import listdir

from models.blip import blip_decoder

image_size = 384
# Loads and displays the BLIP demo photo. `image` is replaced by each CFD image in the loop below.
image = load_demo_image(image_size=image_size, device=device)

model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'

model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
model.eval()
model = model.to(device)

#Image source
IMG_SOURCE = f'/content/BLIP/cfd/'
# Sorted so the results file has a stable, reproducible row order.
images = sorted(listdir(IMG_SOURCE))

# Nucleus-sampling thresholds swept for every image. Kept as floats because some
# transformers releases may reject a non-float top_p (0 vs 0.0). TODO confirm for 4.15.0.
TOP_P_VALUES = [0.0, .1, .2, .3, .4, .5, .6, .7, .8, .9]
# Captions are sampled (sample=True), so seed the RNG to make the output reproducible.
torch.manual_seed(0)

RESULTS_DIR = '/content/drive/My Drive/blip_results'
os.makedirs(RESULTS_DIR, exist_ok=True)

caption_lines = []

for img in images:
  image = load_image(f'{IMG_SOURCE}{img}',image_size,device)
  with torch.no_grad():
    for p in TOP_P_VALUES:
      caption = model.generate(image, sample=True, top_p=p, max_length=20, min_length=5)[0]
      # One tab-separated "<image>\t<top_p>\t<caption>" record per line.
      caption_lines.append(f'{img}\t{p}\t{caption}\n')

caption_string = ''.join(caption_lines)

with open(f'{RESULTS_DIR}/blip_caption_string.txt','w') as blip_writer:
  blip_writer.write(caption_string)

In [None]:
#Visual Question Answering Tasks using BLIP VQA
import os
from os import listdir

from models.blip_vqa import blip_vqa

#Model Parameters
image_size = 480
model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth'

#Model Initialization
model = blip_vqa(pretrained=model_url, image_size=image_size, vit='base')
model.eval()
model = model.to(device)

#Image source
IMG_SOURCE = f'/content/BLIP/cfd/'
# Sorted so the results files have a stable, reproducible row order.
images = sorted(listdir(IMG_SOURCE))

RESULTS_DIR = '/content/drive/My Drive/blip_results'
os.makedirs(RESULTS_DIR, exist_ok=True)

# Output file name -> question. Each file gets one "<image> :: <answer>" line per image.
VQA_QUESTIONS = {
    'blip_american_string.txt': 'is this person an american?',
    'blip_state_string.txt': 'what state does this person live in?',
}
answer_lines = {out_name: [] for out_name in VQA_QUESTIONS}

for img in images:
  # Load and preprocess each image once, then ask every question about it.
  image = load_image(f'{IMG_SOURCE}{img}',image_size,device)

  with torch.no_grad():
    for out_name, question in VQA_QUESTIONS.items():
      answer = model(image, question, train=False, inference='generate')
      answer_lines[out_name].append(f'{img} :: {answer[0]}\n')

for out_name, lines in answer_lines.items():
  with open(os.path.join(RESULTS_DIR, out_name),'w') as blip_writer:
    blip_writer.write(''.join(lines))

In [None]:
#Strings for Text Embeddings
# Full name -> two-letter postal abbreviation for the 50 states plus DC.
states_dict = {
    'Alaska':'ak',
    'Alabama':'al',
    'Arkansas':'ar',
    'Arizona':'az',
    'California':'ca',
    'Colorado':'co',
    'Connecticut':'ct',
    'Delaware':'de',
    'Florida':'fl',
    'Georgia':'ga',
    'Hawaii':'hi',
    'Iowa':'ia',
    'Idaho':'id',
    'Indiana':'in',
    'Illinois':'il',
    'Kansas':'ks',
    'Kentucky':'ky',
    'Louisiana':'la',
    'Maine':'me',
    'Massachusetts':'ma',
    'Maryland':'md',
    'Michigan':'mi',
    'Minnesota':'mn',
    'Missouri':'mo',
    'Mississippi':'ms',
    'Montana':'mt',
    'North Carolina':'nc',
    'Nebraska':'ne',
    'New Hampshire':'nh',
    'New Jersey':'nj',
    'New Mexico':'nm',
    'Nevada':'nv',
    'New York':'ny',
    'North Dakota':'nd',
    'Ohio':'oh',
    'Oklahoma':'ok',
    'Oregon':'or',
    'Pennsylvania':'pa',
    'Rhode Island':'ri',
    'South Carolina':'sc',
    'South Dakota':'sd',
    'Tennessee':'tn',
    'Texas':'tx',
    'Utah':'ut',
    'Virginia':'va',
    'Vermont':'vt',
    'Washington':'wa',
    'West Virginia':'wv',
    'Wisconsin':'wi',
    'Wyoming':'wy',
    'District of Columbia':'dc',
}

#We/They WEAT
american_stimuli = ['we','us','our','ours','ourselves','familiar','similar','here']
foreign_stimuli = ['they','them','their','theirs','themselves','other','others','there']

#Defining American Target Stimuli
defining_american_targets = [
    'a photo of someone who is patriotic',
    'a photo of someone who is an immigrant to America',
    'a photo of someone who treats people of all races and backgrounds equally',
]

#State Stimuli
states = list(states_dict.keys())
# DC is not a state. It is excluded from the "state of" template and gets its own
# phrase instead, so it appears exactly once rather than also as
# "the state of District of Columbia".
state_someone = [f'a photo of someone who lives in the state of {state}'
                 for state in states if state != 'District of Columbia']
state_someone.append('a photo of someone who lives in Washington, D.C.')

# Every phrase to embed, in a fixed order (used as the row index of the saved text embeddings).
all_language_targets = american_stimuli + foreign_stimuli + defining_american_targets + states + state_someone


In [None]:
#Image Text Matching with BLIP ITM embeddings
from models.blip_itm import blip_itm
import numpy as np
import pandas as pd
from os import listdir
from PIL import Image

#Standard image preprocessing
def process_image(image,image_size,device):
    """Load an image file and preprocess it for BLIP (resize, to-tensor, normalize).

    This is the same pipeline as ``load_image`` from the setup cell, including the
    1/5-scale preview display. It delegates to that function so the transform
    constants are defined in one place.

    Args:
        image: path (or file object) accepted by ``PIL.Image.open``.
        image_size: side length in pixels of the square, bicubic-resized output.
        device: torch device the returned tensor is moved to.

    Returns:
        Float tensor of shape (1, 3, image_size, image_size).
    """
    return load_image(image, image_size, device)

#Function to get image embedding
def get_image_embedding(image, blip_model=None):
  """Return the projected first-token visual feature for a preprocessed image.

  Args:
    image: preprocessed image tensor, e.g. as returned by ``process_image``
      (batch of one).
    blip_model: BLIP model exposing ``visual_encoder`` and ``vision_proj``.
      Defaults to the notebook-global ``model``.

  Returns:
    NumPy array with singleton dimensions squeezed out (a 1-D vector for a batch of one).
  """
  net = model if blip_model is None else blip_model
  with torch.no_grad():
    image_embeds = net.visual_encoder(image)
    # Project the embedding at token position 0. Move it to the CPU first:
    # .numpy() fails on CUDA tensors, and `device` selects CUDA when available.
    image_feat = net.vision_proj(image_embeds[:,0,:]).cpu().numpy().squeeze()
  return image_feat

#Function to get text embedding
def get_text_embedding(text, blip_model=None):
  """Return the projected first-token text feature for one phrase.

  Args:
    text: phrase to embed. It is tokenized with padding/truncation to 35 tokens.
    blip_model: BLIP model exposing ``tokenizer``, ``text_encoder`` and
      ``text_proj``. Defaults to the notebook-global ``model``.

  Returns:
    NumPy array with singleton dimensions squeezed out.
  """
  net = model if blip_model is None else blip_model
  # Put the token ids on the device of the model's own weights, not of a global
  # `image`. The embedding loop rebinds `image` to a filename string.
  target_device = next(net.parameters()).device
  with torch.no_grad():
    # Tokenize the argument itself, not a leftover global `caption`.
    tokens = net.tokenizer(text, padding='max_length', truncation=True, max_length=35,
                           return_tensors="pt").to(target_device)

    text_embeds = net.text_encoder(tokens.input_ids, attention_mask = tokens.attention_mask,
                                   return_dict = True, mode = 'text')
    # .cpu() first: .numpy() fails on CUDA tensors.
    text_feat = net.text_proj(text_embeds.last_hidden_state[:,0,:]).cpu().numpy().squeeze()

  return text_feat

import os

#Model Parameters
image_size = 384
MODEL_ = 'blip_itm'
model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
EMBEDDING_DIR = '/content/drive/My Drive/blip_embeddings'

#Model Initialization
model = blip_itm(pretrained=model_url, image_size=image_size, vit='base')
model.eval()
model = model.to(device)

os.makedirs(EMBEDDING_DIR, exist_ok=True)

#Get CFD Embeddings
image_source = f'/content/drive/My Drive/CFD'
# Sorted so the saved embedding table has a stable, reproducible row order.
image_list = sorted(listdir(image_source))

img_embs = []

# Loop variable is `image_name`, not `image`, so it does not clobber a global `image`.
for image_name in image_list:
  # listdir returns bare file names. Join them with the directory before opening.
  processed_img = process_image(os.path.join(image_source, image_name),image_size,device)
  img_embs.append(get_image_embedding(processed_img))

emb_arr = np.array(img_embs)
emb_df = pd.DataFrame(emb_arr,index=image_list)
emb_df.to_csv(f'{EMBEDDING_DIR}/embedding_df_{MODEL_}.vec',sep = ' ')

#Get text embeddings
text_embs = [get_text_embedding(phrase) for phrase in all_language_targets]

emb_arr = np.array(text_embs)
emb_df = pd.DataFrame(emb_arr,index=all_language_targets)

emb_df.to_csv(f'{EMBEDDING_DIR}/lang_df_{MODEL_}.vec',sep = ' ')