In [1]:
import os
import io

from models.hawkeye_it import HawkEye_it
from utils.easydict import EasyDict
import torch

from transformers import StoppingCriteria, StoppingCriteriaList

from PIL import Image
import numpy as np
import numpy as np
from decord import VideoReader, cpu
import torchvision.transforms as T
from dataset.video_transforms import (
    GroupNormalize, GroupScale, GroupCenterCrop, 
    Stack, ToTorchFormatTensor
)
from torchvision.transforms.functional import InterpolationMode

from torchvision import transforms

import matplotlib.pyplot as plt

from IPython.display import Video, HTML

from peft import get_peft_model, LoraConfig, TaskType
import copy

In [13]:
# load stage2 model
from utils.config import Config
import os

folder_name, ckpt = 'outputs/Hawkeye-stage3', 1000000

config_file = os.path.join(folder_name, "config.json")
cfg = Config.from_file(config_file)


use_lora = cfg.model.use_lora
cfg.model.use_lora = False      # do not add lora in __init__, we will add it later
cfg.model.vision_encoder.num_frames = 4

model = HawkEye_it(config=cfg.model)
model.set_device_ids([cfg.device])

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [14]:
# add lora to load stage3 ckpt
if use_lora:
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM, inference_mode=False, 
        r=16, lora_alpha=32, lora_dropout=0.
    )
    model.llama_model = get_peft_model(model.llama_model, peft_config)
    model.use_lora = True

state_dict = torch.load(os.path.join(folder_name, 'ckpt_%d.pth'))
if 'model' in state_dict.keys():
    msg = model.load_state_dict(state_dict['model'], strict=False)
else:
    msg = model.load_state_dict(state_dict, strict=False)
print(msg)

model = model.eval()

_IncompatibleKeys(missing_keys=['vision_encoder.encoder.patch_embed.proj.weight', 'vision_encoder.encoder.patch_embed.proj.bias', 'vision_encoder.encoder.blocks.0.norm1.weight', 'vision_encoder.encoder.blocks.0.norm1.bias', 'vision_encoder.encoder.blocks.0.attn.q_bias', 'vision_encoder.encoder.blocks.0.attn.v_bias', 'vision_encoder.encoder.blocks.0.attn.qkv.weight', 'vision_encoder.encoder.blocks.0.attn.proj.weight', 'vision_encoder.encoder.blocks.0.attn.proj.bias', 'vision_encoder.encoder.blocks.0.norm2.weight', 'vision_encoder.encoder.blocks.0.norm2.bias', 'vision_encoder.encoder.blocks.0.mlp.fc1.weight', 'vision_encoder.encoder.blocks.0.mlp.fc1.bias', 'vision_encoder.encoder.blocks.0.mlp.fc2.weight', 'vision_encoder.encoder.blocks.0.mlp.fc2.bias', 'vision_encoder.encoder.blocks.1.norm1.weight', 'vision_encoder.encoder.blocks.1.norm1.bias', 'vision_encoder.encoder.blocks.1.attn.q_bias', 'vision_encoder.encoder.blocks.1.attn.v_bias', 'vision_encoder.encoder.blocks.1.attn.qkv.weight', 

In [4]:
def get_prompt(conv):
    ret = conv.system + conv.sep
    for role, message in conv.messages:
        if message:
            ret += role + ": " + message + conv.sep
        else:
            ret += role + ":"
    return ret


def get_prompt2(conv):
    ret = conv.system + conv.sep
    count = 0
    for role, message in conv.messages:
        count += 1
        if count == len(conv.messages):
            ret += role + ": " + message
        else:
            if message:
                ret += role + ": " + message + conv.sep
            else:
                ret += role + ":"
    return ret


def get_context_emb(conv, model, img_list, answer_prompt=None, print_res=False):
    if answer_prompt:
        prompt = get_prompt2(conv)
    else:
        prompt = get_prompt(conv)
    if print_res:
        print(prompt)
    if '<VideoHere>' in prompt:
        prompt_segs = prompt.split('<VideoHere>')
    else:
        prompt_segs = prompt.split('<ImageHere>')
    assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
    with torch.no_grad():
        seg_tokens = [
            model.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to("cuda:0").input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [model.llama_model.base_model.model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
    mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
    mixed_embs = torch.cat(mixed_embs, dim=1)
    return mixed_embs


def ask(text, conv):
    conv.messages.append([conv.roles[0], text + '\n'])
        

class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True
        return False
    
    
def answer(conv, model, img_list, do_sample=True, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1.0, answer_prompt=None, print_res=False):
    stop_words_ids = [
        torch.tensor([835]).to("cuda:0"),
        torch.tensor([2277, 29937]).to("cuda:0")]  # '###' can be encoded in two different ways.
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
    
    conv.messages.append([conv.roles[1], answer_prompt])
    embs = get_context_emb(conv, model, img_list, answer_prompt=answer_prompt, print_res=print_res)
    with torch.no_grad():
        outputs = model.llama_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=stopping_criteria,
            num_beams=num_beams,
            do_sample=do_sample,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
    output_token = outputs[0]
    if output_token[0] == 0:  # the model might output a unknow token <unk> at the beginning. remove it
            output_token = output_token[1:]
    if output_token[0] == 1:  # some users find that there is a start token <s> at the beginning. remove it
            output_token = output_token[1:]
    output_text = model.llama_tokenizer.decode(output_token, add_special_tokens=False)
    output_text = output_text.split('###')[0]  # remove the stop sign '###'
    output_text = output_text.split('Assistant:')[-1].strip()
    conv.messages[-1][1] = output_text + '\n'
    return output_text, output_token.cpu().numpy()

In [5]:
def get_index(num_frames, num_segments):
    seg_size = float(num_frames - 1) / num_segments
    start = int(seg_size / 2)
    offsets = np.array([
        start + int(np.round(seg_size * idx)) for idx in range(num_segments)
    ])
    return offsets


def load_video(video_path, num_segments=8, return_msg=False, resolution=224):
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    num_frames = len(vr)
    frame_indices = get_index(num_frames, num_segments)

    # transform
    crop_size = resolution
    scale_size = resolution
    input_mean = [0.48145466, 0.4578275, 0.40821073]
    input_std = [0.26862954, 0.26130258, 0.27577711]

    transform = T.Compose([
        GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC),
        GroupCenterCrop(crop_size),
        Stack(),
        ToTorchFormatTensor(),
        GroupNormalize(input_mean, input_std) 
    ])

    images_group = list()
    for frame_index in frame_indices:
        img = Image.fromarray(vr[frame_index].numpy())
        images_group.append(img)
    torch_imgs = transform(images_group)
    if return_msg:
        fps = float(vr.get_avg_fps())
        sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
        # " " should be added in the start and end
        msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
        return torch_imgs, msg
    else:
        return torch_imgs

In [6]:
def get_sinusoid_encoding_table(n_position=784, d_hid=1024, cur_frame=8, ckpt_num_frame=4, pre_n_position=784): 
    ''' Sinusoid position encoding table ''' 
    # TODO: make it with torch instead of numpy 
    def get_position_angle_vec(position): 
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 
    
    # generate checkpoint position embedding
    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(pre_n_position)]) 
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 
    sinusoid_table = torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)
    
    print(f"n_position: {n_position}")
    print(f"pre_n_position: {pre_n_position}")
    
    if n_position != pre_n_position:
        T = ckpt_num_frame # checkpoint frame
        P = 14 # checkpoint size
        C = d_hid
        new_P = int((n_position // cur_frame) ** 0.5) # testing size
        if new_P != 14:
            print(f'Pretraining uses 14x14, but current version is {new_P}x{new_P}')
            print(f'Interpolate the position embedding')
            sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
            sinusoid_table = sinusoid_table.reshape(-1, P, P, C).permute(0, 3, 1, 2)
            sinusoid_table = torch.nn.functional.interpolate(
                sinusoid_table, size=(new_P, new_P), mode='bicubic', align_corners=False)
            # BT, C, H, W -> BT, H, W, C ->  B, T, H, W, C
            sinusoid_table = sinusoid_table.permute(0, 2, 3, 1).reshape(-1, T, new_P, new_P, C)
            sinusoid_table = sinusoid_table.flatten(1, 3)  # B, THW, C
    
    if cur_frame != ckpt_num_frame:
        print(f'Pretraining uses 4 frames, but current frame is {cur_frame}')
        print(f'Interpolate the position embedding')
        T = ckpt_num_frame # checkpoint frame
        new_T = cur_frame # testing frame
        # interpolate
        P = int((n_position // cur_frame) ** 0.5) # testing size
        C = d_hid
        sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
        sinusoid_table = sinusoid_table.permute(0, 2, 3, 4, 1).reshape(-1, C, T)  # BHW, C, T
        sinusoid_table = torch.nn.functional.interpolate(sinusoid_table, size=new_T, mode='linear')
        sinusoid_table = sinusoid_table.reshape(1, P, P, C, new_T).permute(0, 4, 1, 2, 3) # B, T, H, W, C
        sinusoid_table = sinusoid_table.flatten(1, 3)  # B, THW, C
        
    return sinusoid_table

In [7]:
def encode_video(model, vid_path, instruction="Watch the video and answer the question.", num_frame=12):
    num_frame = num_frame
    resolution = 224
    new_pos_emb = get_sinusoid_encoding_table(n_position=(resolution//16)**2*num_frame, cur_frame=num_frame)
    model.vision_encoder.encoder.pos_embed = new_pos_emb
    vid, msg = load_video(vid_path, num_segments=num_frame, return_msg=True, resolution=resolution)
    print(msg)
        
    # The model expects inputs of shape: T x C x H x W
    TC, H, W = vid.shape
    video = vid.reshape(1, TC//3, 3, H, W).to("cuda:0")

    img_list = []
    with torch.no_grad():
        image_emb, _ = model.encode_img(video, instruction)
    img_list.append(image_emb)
    return img_list, msg


def answer_video_and_question(model, vid_path, question, instruction="Watch the video and answer the question."):
    img_list, msg = encode_video(model, vid_path, instruction)

    chat = EasyDict({
        "system": "",
        "roles": ("Human", "Assistant"),
        "messages": [],
        "sep": "###"
    })

    chat.messages.append(["", instruction])
    chat.messages.append([chat.roles[0], f"<Video><VideoHere></Video> %s\n" % msg])
    ask(question, chat)

    llm_message = answer(conv=chat, model=model, do_sample=False, img_list=img_list, max_new_tokens=256, print_res=True)[0]
    return(llm_message)


In [115]:
vid_path = "data/videos/charades/00NN7.mp4"
instruction = "Watch the video and choose the most fitting option based on the observed content."
img_list, msg = encode_video(model, vid_path, instruction)

chat = EasyDict({
    "system": "",
    "roles": ("Human", "Assistant"),
    "messages": [],
    "sep": "###"
})

chat.messages.append(["", instruction])
chat.messages.append([chat.roles[0], f"<Video><VideoHere></Video> %s\n" % msg])

ask("What is the woman doing in the middle of the video?", chat)
llm_message = answer(conv=chat, model=model, do_sample=False, img_list=img_list, max_new_tokens=256, print_res=True)[0]
print(llm_message)

ask("And at the end of the video?", chat)
llm_message = answer(conv=chat, model=model, do_sample=False, img_list=img_list, max_new_tokens=256, print_res=True)[0]
print(llm_message)

ask("And at the beginning of the video?", chat)
llm_message = answer(conv=chat, model=model, do_sample=False, img_list=img_list, max_new_tokens=256, print_res=True)[0]
print(llm_message)


n_position: 2352
pre_n_position: 784
Pretraining uses 4 frames, but current frame is 12
Interpolate the position embedding
The video contains 12 frames sampled at 1.3, 3.8, 6.4, 8.9, 11.5, 14.0, 16.6, 19.2, 21.7, 24.2, 26.8, 29.4 seconds.
###: Watch the video and choose the most fitting option based on the observed content.###Human: <Video><VideoHere></Video> The video contains 12 frames sampled at 1.3, 3.8, 6.4, 8.9, 11.5, 14.0, 16.6, 19.2, 21.7, 24.2, 26.8, 29.4 seconds.
###Human: What is the woman doing in the middle of the video?
###Assistant:
The woman in the middle of the video is using a laptop computer.
###: Watch the video and choose the most fitting option based on the observed content.###Human: <Video><VideoHere></Video> The video contains 12 frames sampled at 1.3, 3.8, 6.4, 8.9, 11.5, 14.0, 16.6, 19.2, 21.7, 24.2, 26.8, 29.4 seconds.
###Human: What is the woman doing in the middle of the video?
###Assistant: The woman in the middle of the video is using a laptop computer.
#

In [132]:
vid_path = 'data/videos/qvhighlights/_-5vXZfppKE_60.0_210.0.mp4'
instruction = "Watch the video and choose the most fitting option based on the observed content."
img_list, msg = encode_video(model, vid_path, instruction)

chat = EasyDict({
    "system": "",
    "roles": ("Human", "Assistant"),
    "messages": [],
    "sep": "###"
})

chat.messages.append(["", instruction])
chat.messages.append([chat.roles[0], f"<Video><VideoHere></Video> %s\n" % msg])

ask("What is this video about?", chat)
llm_message = answer(conv=chat, model=model, do_sample=False, img_list=img_list, max_new_tokens=256, print_res=True)[0]
print(llm_message)

ask("What will the weather be like according to the video?", chat)
llm_message = answer(conv=chat, model=model, do_sample=False, img_list=img_list, max_new_tokens=256, print_res=True)[0]
print(llm_message)

ask("At which segment of the video is the temperature shown, and how will it be?", chat)
llm_message = answer(conv=chat, model=model, do_sample=False, img_list=img_list, max_new_tokens=256, print_res=True)[0]
print(llm_message)


n_position: 2352
pre_n_position: 784
Pretraining uses 4 frames, but current frame is 12
Interpolate the position embedding
The video contains 12 frames sampled at 5.4, 16.3, 27.3, 38.2, 49.1, 60.0, 70.9, 81.8, 92.7, 103.6, 114.5, 125.4 seconds.
###: Watch the video and choose the most fitting option based on the observed content.###Human: <Video><VideoHere></Video> The video contains 12 frames sampled at 5.4, 16.3, 27.3, 38.2, 49.1, 60.0, 70.9, 81.8, 92.7, 103.6, 114.5, 125.4 seconds.
###Human: What is this video about?
###Assistant:


The video is about a weather forecast for the next few days.
###: Watch the video and choose the most fitting option based on the observed content.###Human: <Video><VideoHere></Video> The video contains 12 frames sampled at 5.4, 16.3, 27.3, 38.2, 49.1, 60.0, 70.9, 81.8, 92.7, 103.6, 114.5, 125.4 seconds.
###Human: What is this video about?
###Assistant: The video is about a weather forecast for the next few days.
###Human: What will the weather be like according to the video?
###Assistant:
The video shows a forecast for the next few days, including a tropical storm and a heavy rain event.
###: Watch the video and choose the most fitting option based on the observed content.###Human: <Video><VideoHere></Video> The video contains 12 frames sampled at 5.4, 16.3, 27.3, 38.2, 49.1, 60.0, 70.9, 81.8, 92.7, 103.6, 114.5, 125.4 seconds.
###Human: What is this video about?
###Assistant: The video is about a weather forecast for the next few days.
###Human: What will the weather be like accordin

In [21]:
vid_path = 'data/videos/qvhighlights/_2mgEMfnYzw_60.0_210.0.mp4'
instruction = "Examine the video and make a decision based on the content presented in the footage."
img_list, msg = encode_video(model, vid_path, instruction, num_frame=20)

chat = EasyDict({
    "system": "",
    "roles": ("Human", "Assistant"),
    "messages": [],
    "sep": "###"
})

chat.messages.append(["", instruction])
chat.messages.append([chat.roles[0], f"<Video><VideoHere></Video> %s\n" % msg])

ask("What is this video about?", chat)
llm_message = answer(conv=chat, model=model, do_sample=False, img_list=img_list, max_new_tokens=256, print_res=True)[0]
print(llm_message)

ask("What happens at the beginning of the video?", chat)
llm_message = answer(conv=chat, model=model, do_sample=False, img_list=img_list, max_new_tokens=256, print_res=True)[0]
print(llm_message)

ask("In which part of the video does a pot of noodles appear?", chat)
llm_message = answer(conv=chat, model=model, do_sample=False, img_list=img_list, max_new_tokens=256, print_res=True)[0]
print(llm_message)


n_position: 3920
pre_n_position: 784
Pretraining uses 4 frames, but current frame is 20
Interpolate the position embedding
The video contains 20 frames sampled at 3.7, 11.2, 18.7, 26.2, 33.7, 41.2, 48.7, 56.2, 63.7, 71.2, 78.7, 86.2, 93.7, 101.2, 108.7, 116.2, 123.7, 131.2, 138.7, 146.2 seconds.
###: Examine the video and make a decision based on the content presented in the footage.###Human: <Video><VideoHere></Video> The video contains 20 frames sampled at 3.7, 11.2, 18.7, 26.2, 33.7, 41.2, 48.7, 56.2, 63.7, 71.2, 78.7, 86.2, 93.7, 101.2, 108.7, 116.2, 123.7, 131.2, 138.7, 146.2 seconds.
###Human: What is this video about?
###Assistant:


The video is about a girl who is using a laptop and a phone to make a video call.
###: Examine the video and make a decision based on the content presented in the footage.###Human: <Video><VideoHere></Video> The video contains 20 frames sampled at 3.7, 11.2, 18.7, 26.2, 33.7, 41.2, 48.7, 56.2, 63.7, 71.2, 78.7, 86.2, 93.7, 101.2, 108.7, 116.2, 123.7, 131.2, 138.7, 146.2 seconds.
###Human: What is this video about?
###Assistant: The video is about a girl who is using a laptop and a phone to make a video call.
###Human: What happens at the beginning of the video?
###Assistant:
At the beginning of the video, a girl is shown with a cat on her phone.
###: Examine the video and make a decision based on the content presented in the footage.###Human: <Video><VideoHere></Video> The video contains 20 frames sampled at 3.7, 11.2, 18.7, 26.2, 33.7, 41.2, 48.7, 56.2, 63.7, 71.2, 78.7, 86.2, 93.7, 101.2, 108.7, 116.2, 123.7, 131.2, 138.7, 146.2 seconds.
###Human: What is this video about?
###Assista

In [162]:
vid_path = 'data/videos/qvhighlights/_4tpo4pUN3k_510.0_660.0.mp4'
instruction = "Examine the video and make a decision based on the content presented in the footage."
img_list, msg = encode_video(model, vid_path, instruction, num_frame=20)

chat = EasyDict({
    "system": "",
    "roles": ("Human", "Assistant"),
    "messages": [],
    "sep": "###"
})

chat.messages.append(["", instruction])
chat.messages.append([chat.roles[0], f"<Video><VideoHere></Video> %s\n" % msg])

ask("During which part of the video does 'the woman shows different dishes' occur, at the beginning, the middle or the end of the video?", chat)
llm_message = answer(conv=chat, model=model, do_sample=False, img_list=img_list, max_new_tokens=256, print_res=True)[0]
print(llm_message)

ask("What happens at the end of the video?", chat)
llm_message = answer(conv=chat, model=model, do_sample=False, img_list=img_list, max_new_tokens=256, print_res=True)[0]
print(llm_message)


n_position: 3920
pre_n_position: 784
Pretraining uses 4 frames, but current frame is 20
Interpolate the position embedding
The video contains 20 frames sampled at 3.7, 11.2, 18.7, 26.2, 33.7, 41.2, 48.7, 56.2, 63.7, 71.2, 78.7, 86.2, 93.7, 101.2, 108.7, 116.2, 123.7, 131.2, 138.7, 146.2 seconds.
###: Examine the video and make a decision based on the content presented in the footage.###Human: <Video><VideoHere></Video> The video contains 20 frames sampled at 3.7, 11.2, 18.7, 26.2, 33.7, 41.2, 48.7, 56.2, 63.7, 71.2, 78.7, 86.2, 93.7, 101.2, 108.7, 116.2, 123.7, 131.2, 138.7, 146.2 seconds.
###Human: During which part of the video does 'the woman shows different dishes' occur, at the beginning, the middle or the end of the video?
###Assistant:
The woman shows different dishes at the beginning of the video.
###: Examine the video and make a decision based on the content presented in the footage.###Human: <Video><VideoHere></Video> The video contains 20 frames sampled at 3.7, 11.2, 18.7, 