# Video QA with BLIP2

Load the clip for an arbitrary narrated action from Ego4D.

In [None]:
import json

from pytorchvideo.data.video import VideoPathHandler

# Read the annotation file as UTF-8 explicitly so that parsing does not depend
# on the platform's default text encoding.
with open("../../ego4d/v2/annotations/fho_main.json", encoding="utf-8") as f:
    fho_main = json.load(f)

# Drill down to one arbitrary narrated action:
# video entry -> annotated interval -> narrated action.
# The metadata dict gets its own name so that it is not confused with the
# decoded video object bound to `video` further down.
video_info = fho_main["videos"][20]
interval = video_info["annotated_intervals"][2]
action = interval["narrated_actions"][4]

print(f'video_uid: {video_info["video_uid"]}')
print(f'start_sec: {action["start_sec"]}')
print(f'end_sec: {action["end_sec"]}')
print(f'clip_uid: {interval["clip_uid"]}')
print(f'clip_start_sec: {action["clip_start_sec"]}')
print(f'clip_end_sec: {action["clip_end_sec"]}')
print(f'narration_text: {action["narration_text"]}')

# Open the clip file named after the interval's clip_uid and cut out the span
# of the selected action.
# NOTE(review): clip_start_sec / clip_end_sec are presumably offsets within
# the clip file (as opposed to start_sec / end_sec) - confirm against the
# Ego4D annotation schema.
video_path_handler = VideoPathHandler()
video = video_path_handler.video_from_path(
    f"../../ego4d/v2/clips/{interval['clip_uid']}.mp4"
)
clip = video.get_clip(action["clip_start_sec"], action["clip_end_sec"])

Load `Salesforce/blip2-opt-2.7b`.

In [None]:
import torch
from transformers import Blip2ForConditionalGeneration, Blip2Processor

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The processor and the model are loaded from the same pretrained checkpoint.
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
model = model.to(device)

Try to perform VideoQA using the middle frame.

In [None]:
prompt = "Question: what is the camera wearer doing? Answer:"

# Take the frame halfway along dim 1 of the clip tensor and use it as the
# single image for the question.
frames = clip["video"]
middle_index = frames.size(1) // 2
middle_frame = frames[:, middle_index]

inputs = processor(images=middle_frame, text=prompt, return_tensors="pt")
inputs = inputs.to(device)
print(f"inputs: {({k: v.size() for k, v in inputs.items()})}")

# Generate an answer and decode the first (only) sequence of the batch.
generated_ids = model.generate(**inputs)
decoded_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
generated_text = decoded_texts[0].strip()
print(f"generated_text: {generated_text}")

Try to perform VideoQA using the mean of the per-frame image features over time.

In [None]:
# hacky monkeypatching to replace the original vision model
# with a simple mean vision model
import torch.nn as nn


class MeanVisionModel(nn.Module):
    """Vision-model wrapper that averages per-frame features over time.

    The wrapped model is run on every frame of a video independently, and the
    resulting last hidden states are averaged across the time dimension, so
    the output has one feature sequence per video instead of one per frame.
    Only the last hidden state is averaged; every other element of the
    wrapped model's output is passed through unchanged.
    """

    def __init__(self, blip2_vision_model):
        """Store the vision model that will be applied to each frame."""
        super().__init__()
        self.blip2_vision_model = blip2_vision_model

    def forward(
        self,
        pixel_values=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """Encode a batch of videos and average the frame features over time.

        Args:
            pixel_values: video tensor of shape (B, C, T, H, W).
            output_attentions: forwarded to the wrapped model.
            output_hidden_states: forwarded to the wrapped model.
            return_dict: forwarded to the wrapped model.
            **kwargs: any further keyword arguments, forwarded unchanged to
                the wrapped model (e.g. ``interpolate_pos_encoding``, which
                some transformers releases pass to the vision model).

        Returns:
            The wrapped model's output, in whatever container it returned
            (tuple or dict-style output), with the last hidden state replaced
            by its mean over time, shape (B, seq_len, hidden_size).

        Raises:
            ValueError: if ``pixel_values`` is missing or is not 5-D.
        """
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")
        if pixel_values.dim() != 5:
            raise ValueError(
                "pixel_values must have shape (B, C, T, H, W), "
                f"got {tuple(pixel_values.size())}"
            )

        # pixel_values: (B, C, T, H, W)
        batch_size, _, time, _, _ = pixel_values.size()
        # Fold time into the batch dimension: (B, C, T, H, W) -> (B*T, C, H, W)
        frames = pixel_values.permute(0, 2, 1, 3, 4).flatten(end_dim=1)
        vision_outputs = self.blip2_vision_model(
            pixel_values=frames,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )

        # Decide by what the wrapped model actually returned rather than by
        # the flag: with return_dict=None the wrapped model applies its own
        # default, and the output container must be preserved either way.
        is_tuple_output = isinstance(vision_outputs, tuple)
        if is_tuple_output:
            last_hidden_state = vision_outputs[0]
        else:
            last_hidden_state = vision_outputs.last_hidden_state

        # last_hidden_state: (B*T, seq_len, hidden_size)
        # -> (B, T, seq_len, hidden_size) -> mean over T
        last_hidden_state = last_hidden_state.reshape(
            batch_size, time, *last_hidden_state.shape[1:]
        ).mean(dim=1)

        if is_tuple_output:
            return (last_hidden_state,) + vision_outputs[1:]
        vision_outputs["last_hidden_state"] = last_hidden_state
        return vision_outputs


# Install the wrapper idempotently: if this cell is re-run, model.vision_model
# is already a wrapper (it exposes the wrapped model as `blip2_vision_model`),
# so unwrap it first instead of wrapping the wrapper a second time.
original_vision_model = getattr(
    model.vision_model, "blip2_vision_model", model.vision_model
)
model.vision_model = MeanVisionModel(original_vision_model)

In [None]:
# Hand the processor the clip with its first two dims swapped, so that every
# frame is preprocessed as a separate image.
frames = clip["video"].transpose(0, 1)
inputs = processor(images=frames, text=prompt, return_tensors="pt")
inputs = inputs.to(device)

# Swap the first two dims back and add a leading batch dim of size 1, giving
# the 5-D layout that the patched vision model expects.
video_pixel_values = inputs["pixel_values"].transpose(0, 1)
inputs["pixel_values"] = video_pixel_values[None]
print(f"inputs: {({k: v.size() for k, v in inputs.items()})}")

# Generate an answer and decode the first (only) sequence of the batch.
generated_ids = model.generate(**inputs)
decoded_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
generated_text = decoded_texts[0].strip()
print(f"generated_text: {generated_text}")