In [1]:
from collections import Counter
import os
import json
import torch
from PIL import Image
from pathlib import Path

In [2]:
home_dir = "/root/CVPDL/hw3"
os.chdir(home_dir)

### Data Preprocess

In [3]:
def get_selected_annotation(image_id, annotations):
    selected_annotations = []
    for annotation in annotations:
        if annotation["image_id"] == image_id:
            selected_annotations.append(annotation)
    
    return selected_annotations

def read_image(input_metadata_path, data, data_dir):
    categories = data["categories"]
    images = data["images"]
    annotations = data["annotations"]
    all_images, all_metadata = [], []
    for image in images:
        metadata = {}
        raw_image = Image.open(data_dir / image["file_name"]).convert("RGB")
        all_images.append(raw_image)
        metadata["file_name"] = image["file_name"]
        metadata["height"] = image["height"]
        metadata["width"] = image["width"]
        selected_annotations = get_selected_annotation(image["id"], annotations)
        if len(selected_annotations) == 0:
            continue
        metadata["creature"] = [categories[annotation["category_id"]]["name"] for annotation in selected_annotations]
        metadata["bbox"] = [annotation["bbox"] for annotation in selected_annotations]
        metadata["n_box"] = len(selected_annotations)
        counter = Counter(metadata["creature"])
        metadata["main_creature"] = counter.most_common(1)[0][0]
        all_metadata.append(metadata)

    return all_images, all_metadata

In [4]:
# load images
input_metadata_path = Path("dataset/cvpdl/annotations/train.json")

with input_metadata_path.open("r") as f:
    data = json.load(f)

all_images, all_metadata = read_image(input_metadata_path, data, data_dir=Path("dataset/cvpdl/train"))

### Load Model

In [11]:
from PIL import Image
import requests
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch

device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
model_name = "Salesforce/blip2-opt-2.7b"
model_name = "Salesforce/blip2-opt-6.7b-coco"
#model_name = "Salesforce/blip2-flan-t5-xl"
processor = Blip2Processor.from_pretrained(model_name)
model = Blip2ForConditionalGeneration.from_pretrained(
    model_name, torch_dtype=torch.float16
)
model.to(device)

preprocessor_config.json:   0%|          | 0.00/432 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/904 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/548 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/6.95k [00:00<?, ?B/s]

pytorch_model.bin.index.json:   0%|          | 0.00/122k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

pytorch_model-00001-of-00004.bin:   0%|          | 0.00/9.80G [00:00<?, ?B/s]

pytorch_model-00002-of-00004.bin:   0%|          | 0.00/9.93G [00:00<?, ?B/s]

pytorch_model-00003-of-00004.bin:   0%|          | 0.00/9.93G [00:00<?, ?B/s]

pytorch_model-00004-of-00004.bin:   0%|          | 0.00/2.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Blip2ForConditionalGeneration(
  (vision_model): Blip2VisionModel(
    (embeddings): Blip2VisionEmbeddings(
      (patch_embedding): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
    )
    (encoder): Blip2Encoder(
      (layers): ModuleList(
        (0-38): 39 x Blip2EncoderLayer(
          (self_attn): Blip2Attention(
            (dropout): Dropout(p=0.0, inplace=False)
            (qkv): Linear(in_features=1408, out_features=4224, bias=True)
            (projection): Linear(in_features=1408, out_features=1408, bias=True)
          )
          (layer_norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
          (mlp): Blip2MLP(
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=1408, out_features=6144, bias=True)
            (fc2): Linear(in_features=6144, out_features=1408, bias=True)
          )
          (layer_norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
        )
      )
    )
    (post_layernorm): LayerNorm((

In [12]:
def inference_one_image(image):
    inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
    generated_ids = model.generate(**inputs)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return generated_text

In [13]:
from tqdm import tqdm

id = 0
for image, metadata in tqdm(zip(all_images, all_metadata), total=len(all_metadata)):
    generated_text = inference_one_image(image)
    metadata["gerated_caption"] = generated_text
    metadata["prompt_wtemp1"] = generated_text + ", creature: {}".format(metadata["main_creature"])
    metadata["prompt_wtemp2"] = generated_text + ", creature: {}, height: {}, width: {}".format(metadata["main_creature"], metadata["height"], metadata["width"])
    all_metadata[id] = metadata
    id += 1

100%|██████████| 447/447 [01:44<00:00,  4.26it/s]


### Output File

In [15]:
if model_name == "Salesforce/blip2-flan-t5-xl":
    model_type = "blip2_flan"
elif model_name == "Salesforce/blip2-opt-2.7b":
    model_type = "blip2_opt_2.7b"
elif model_name == "Salesforce/blip2-opt-6.7b-coco":
    model_type = "blip2_opt_6.7b_coco"
else:
    raise ValueError("model name not found")

o_path = Path("lib/config/generated_caption_{}.json".format(model_type))
with o_path.open("w") as f:
    json.dump(all_metadata, f, indent=4)