# Test Ego4dFHOClipDataset

In [None]:
import random

import imageio.v3 as iio
import numpy as np
from IPython.display import Image

# helpful functions


def draw_random_items(dataset, num):
    """Return ``num`` items drawn from ``dataset`` without replacement.

    ``dataset`` only needs to support ``len()`` and integer indexing.

    Raises:
        ValueError: (from ``random.sample``) if ``num`` is negative or larger
            than ``len(dataset)``.
    """
    # random.sample accepts any sequence, including a range, so there is no
    # need to materialize every row index into a list first.
    return [dataset[idx] for idx in random.sample(range(len(dataset)), num)]


def display_gif(video_tensor, gif_file_name):
    """Write a video tensor to a looping GIF file and return it for display.

    The video tensor is expected to have the following shape:
    (num_channels, num_frames, height, width).

    Args:
        video_tensor: torch tensor laid out as (C, T, H, W). Values are
            assumed to already be on a 0-255 pixel scale (TODO confirm for
            the processor's output; normalized values will not render
            meaningfully).
        gif_file_name: path of the GIF file to write.

    Returns:
        An ``IPython.display.Image`` that points at the written file.
    """
    frames = (
        # (C, T, H, W) -> (T, H, W, C): one HWC image per frame for imageio.
        video_tensor.permute(1, 2, 3, 0)
        # .numpy() fails on CUDA tensors and on tensors that require grad.
        .detach()
        .cpu()
        .numpy()
    )
    # Casting out-of-range floats to uint8 is undefined in NumPy, so clamp
    # to the valid pixel range before converting.
    frames = np.clip(frames, 0, 255).astype(np.uint8)
    iio.imwrite(
        gif_file_name,
        frames,
        extension=".gif",
        # infinite loop
        loop=0,
    )
    return Image(gif_file_name)

Load the Ego4D FHO annotations and apply the preprocessing steps shared by all model variants.

In [None]:
import sys

# Make the directory two levels up importable so that `train` can be imported.
sys.path.append("../../")
import re

from train import clean_narration_text, load_ego4d_fho_clip_dataset

C_REGEX = re.compile(r"^\#C C", re.IGNORECASE)

annotation_path = "../../ego4d/v2/annotations/fho_main.json"
dataset = load_ego4d_fho_clip_dataset(annotation_path)


def _is_accepted_c_action(is_rejected, is_valid_action, narration_text):
    """Filter predicate for `dataset.filter`.

    Keeps rows that are not rejected, are valid actions, and whose narration
    text matches C_REGEX (case-insensitive "#C C" prefix). Returns a truthy
    value (the regex match) to keep a row and a falsy value to drop it.
    """
    if is_rejected:
        return False
    if not is_valid_action:
        return is_valid_action
    return C_REGEX.match(narration_text)


# filter out rejected, invalid and non-C actions
dataset = dataset.filter(
    _is_accepted_c_action,
    input_columns=["is_rejected", "is_valid_action", "narration_text"],
)

# columns that the rest of the pipeline does not need
unused_columns = [
    "warnings",
    "uid",
    "start_sec",
    "end_sec",
    "start_frame",
    "end_frame",
    "is_valid_action",
    "is_partial",
    "clip_start_frame",
    "clip_end_frame",
    "narration_timestamp_sec",
    "clip_narration_timestamp_sec",
    "narration_annotation_uid",
    "structured_verb",
    "freeform_verb",
    "state_transition",
    "critical_frames",
    "clip_critical_frames",
    "frames",
    "is_rejected",
    "is_invalid_annotation",
    "reject_reason",
    "stage",
]
dataset = dataset.remove_columns(unused_columns)

# Replace the raw narration text with the output of clean_narration_text.
dataset = dataset.map(
    clean_narration_text,
    input_columns="narration_text",
    remove_columns="narration_text",
)

print(dataset)

Perform LM-specific preprocessing steps.

In [None]:
from functools import partial

from train import add_prompt_column, generate_inputs


def _build_lm_dataset(*, instruct_tuned, use_decoder_only_lm):
    """Map `add_prompt_column` and then `generate_inputs` over `dataset`.

    Args:
        instruct_tuned: forwarded to `add_prompt_column`.
        use_decoder_only_lm: forwarded to `generate_inputs`.

    Returns:
        The mapped dataset with the "prompt" and "cleaned_narration_text"
        columns removed.
    """
    with_prompt = dataset.map(
        partial(add_prompt_column, instruct_tuned=instruct_tuned)
    )
    return with_prompt.map(
        partial(generate_inputs, use_decoder_only_lm=use_decoder_only_lm),
        remove_columns=["prompt", "cleaned_narration_text"],
    )


def _print_random_samples(title, ds):
    """Print `title`, three randomly drawn rows of `ds`, then a separator."""
    print(title)
    for sample in draw_random_items(ds, 3):
        print(sample)
    print("===================================")


# instruction tuned and decoder only LMs
instr_tuned_decoder_only_lm_dataset = _build_lm_dataset(
    instruct_tuned=True, use_decoder_only_lm=True
)
_print_random_samples(
    "instr_tuned_decoder_only_lm_dataset", instr_tuned_decoder_only_lm_dataset
)

# instruction tuned and seq2seq LMs
instr_tuned_seq2seq_lm_dataset = _build_lm_dataset(
    instruct_tuned=True, use_decoder_only_lm=False
)
_print_random_samples(
    "instr_tuned_seq2seq_lm_dataset ", instr_tuned_seq2seq_lm_dataset
)

# non-instruction tuned and decoder only LMs
non_instr_tuned_decoder_only_lm_dataset = _build_lm_dataset(
    instruct_tuned=False, use_decoder_only_lm=True
)
_print_random_samples(
    "non_instr_tuned_decoder_only_lm_dataset",
    non_instr_tuned_decoder_only_lm_dataset,
)

# non-instruction tuned and seq2seq LMs
non_instr_tuned_seq2seq_lm_dataset = _build_lm_dataset(
    instruct_tuned=False, use_decoder_only_lm=False
)
_print_random_samples(
    "non_instr_tuned_seq2seq_lm_dataset", non_instr_tuned_seq2seq_lm_dataset
)

Tokenize text inputs.

In [None]:
from functools import partial

from transformers import Blip2Processor

from train import batch_tokenize

decoder_only_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
seq2seq_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")


def _tokenize_lm_dataset(ds, processor, *, use_decoder_only_lm):
    """Run `batch_tokenize` with `processor.tokenizer` over `ds` in batches.

    The "input" column is removed from the returned dataset.
    """
    tokenize_fn = partial(
        batch_tokenize, processor.tokenizer, use_decoder_only_lm=use_decoder_only_lm
    )
    return ds.map(tokenize_fn, batched=True, remove_columns="input")


def _print_random_samples(title, ds):
    """Print `title`, three randomly drawn rows of `ds`, then a separator."""
    print(title)
    for sample in draw_random_items(ds, 3):
        print(sample)
    print("===================================")


# instruction tuned and decoder only LMs
instr_tuned_decoder_only_lm_dataset = _tokenize_lm_dataset(
    instr_tuned_decoder_only_lm_dataset,
    decoder_only_processor,
    use_decoder_only_lm=True,
)
_print_random_samples(
    "instr_tuned_decoder_only_lm_dataset", instr_tuned_decoder_only_lm_dataset
)

# instruction tuned and seq2seq LMs
instr_tuned_seq2seq_lm_dataset = _tokenize_lm_dataset(
    instr_tuned_seq2seq_lm_dataset,
    seq2seq_processor,
    use_decoder_only_lm=False,
)
_print_random_samples(
    "instr_tuned_seq2seq_lm_dataset ", instr_tuned_seq2seq_lm_dataset
)

# non-instruction tuned and decoder only LMs
non_instr_tuned_decoder_only_lm_dataset = _tokenize_lm_dataset(
    non_instr_tuned_decoder_only_lm_dataset,
    decoder_only_processor,
    use_decoder_only_lm=True,
)
_print_random_samples(
    "non_instr_tuned_decoder_only_lm_dataset",
    non_instr_tuned_decoder_only_lm_dataset,
)

# non-instruction tuned and seq2seq LMs
non_instr_tuned_seq2seq_lm_dataset = _tokenize_lm_dataset(
    non_instr_tuned_seq2seq_lm_dataset,
    seq2seq_processor,
    use_decoder_only_lm=False,
)
_print_random_samples(
    "non_instr_tuned_seq2seq_lm_dataset", non_instr_tuned_seq2seq_lm_dataset
)

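Split each dataset into train and validation sets. 10% of the rows are held out under the `test` key, which serves as the validation split.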

In [None]:
# Without a seed, every `train_test_split(shuffle=True)` call draws a fresh
# random split, so the four variants below would be validated on different
# rows and no run would be reproducible. All four datasets are mapped from the
# same filtered `dataset`. Splitting them with one fixed seed makes the split
# reproducible. Provided the maps keep row count and order (assumed; confirm
# for `batch_tokenize`), it also gives every variant the same validation rows.
SPLIT_SEED = 42
VAL_FRACTION = 0.1


def _split_train_val(ds):
    """Shuffle `ds` with SPLIT_SEED and hold out VAL_FRACTION of its rows.

    Returns a DatasetDict with "train" and "test" keys. The "test" split is
    used as the validation set later in this notebook.
    """
    return ds.train_test_split(
        test_size=VAL_FRACTION, shuffle=True, seed=SPLIT_SEED
    )


# instruction tuned and decoder only LMs
instr_tuned_decoder_only_lm_dataset_train_val = _split_train_val(
    instr_tuned_decoder_only_lm_dataset
)
print(
    "instr_tuned_decoder_only_lm_dataset_train_val: "
    f"{instr_tuned_decoder_only_lm_dataset_train_val}"
)

# instruction tuned and seq2seq LMs
instr_tuned_seq2seq_lm_dataset_train_val = _split_train_val(
    instr_tuned_seq2seq_lm_dataset
)
print(
    "instr_tuned_seq2seq_lm_dataset_train_val: "
    f"{instr_tuned_seq2seq_lm_dataset_train_val}"
)

# non-instruction tuned and decoder only LMs
non_instr_tuned_decoder_only_lm_dataset_train_val = _split_train_val(
    non_instr_tuned_decoder_only_lm_dataset
)
print(
    "non_instr_tuned_decoder_only_lm_dataset_train_val: "
    f"{non_instr_tuned_decoder_only_lm_dataset_train_val}"
)

# non-instruction tuned and seq2seq LMs
non_instr_tuned_seq2seq_lm_dataset_train_val = _split_train_val(
    non_instr_tuned_seq2seq_lm_dataset
)
print(
    "non_instr_tuned_seq2seq_lm_dataset_train_val: "
    f"{non_instr_tuned_seq2seq_lm_dataset_train_val}"
)

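Set split-specific transforms. Train clips are subsampled to 8 frames and then randomly flipped and rotated. Validation clips are only subsampled.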

In [None]:
from pytorchvideo.data.video import VideoPathHandler
from pytorchvideo.transforms import UniformTemporalSubsample
from torchvision.transforms import Compose, RandomHorizontalFlip, RandomRotation

from train import extract_frames

clip_path = "../../ego4d/v2/clips/"
video_path_handler = VideoPathHandler()
# Train clips: subsample 8 frames, then apply random flip and rotation.
train_transform = Compose(
    [
        UniformTemporalSubsample(8),
        RandomHorizontalFlip(),
        RandomRotation((-45, 45)),
    ]
)
# Validation clips: 8-frame subsampling only, no random augmentation.
val_transform = Compose([UniformTemporalSubsample(8)])

# Each train/val DatasetDict is paired with the processor whose image
# processor matches its LM family. Both of its splits ("train" uses the
# augmenting transform, "test" the validation one) get an `extract_frames`
# transform that is applied when items are accessed.
for split_dict, processor in (
    # instruction tuned and decoder only LMs
    (instr_tuned_decoder_only_lm_dataset_train_val, decoder_only_processor),
    # instruction tuned and seq2seq LMs
    (instr_tuned_seq2seq_lm_dataset_train_val, seq2seq_processor),
    # non-instruction tuned and decoder only LMs
    (non_instr_tuned_decoder_only_lm_dataset_train_val, decoder_only_processor),
    # non-instruction tuned and seq2seq LMs
    (non_instr_tuned_seq2seq_lm_dataset_train_val, seq2seq_processor),
):
    for split_name, video_transform in (
        ("train", train_transform),
        ("test", val_transform),
    ):
        split_dict[split_name].set_transform(
            partial(
                extract_frames,
                video_path_handler,
                processor.image_processor,
                clip_path,
                video_transform=video_transform,
            )
        )

Draw three data points from each of the train and validation splits of `non_instr_tuned_decoder_only_lm_dataset` and run them through `VideoBlip2ForConditionalGeneration` (BLIP-2 with OPT-2.7b).

In [None]:
import torch

from video_blip2 import VideoBlip2ForConditionalGeneration

# Prefer the GPU when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# BLIP-2 checkpoint built on the OPT-2.7b (decoder-only) language model.
model = VideoBlip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
model = model.to(device)

In [None]:
from IPython.display import display

from train import PROMPT

# The prompt is the same for every item, so tokenize it and move it to the
# device once instead of once per generated sample.
prompt_input_ids = decoder_only_processor.tokenizer(
    PROMPT, return_tensors="pt"
).input_ids.to(device)

# "test" is the held-out split returned by train_test_split; it is reported
# as "val" here.
for split, split_label in (("train", "train"), ("test", "val")):
    print(f"non_instr_tuned_decoder_only_lm_dataset {split_label}")
    for i, item in enumerate(
        draw_random_items(non_instr_tuned_decoder_only_lm_dataset_train_val[split], 3)
    ):
        print(f"input_ids: {item['input_ids']}")
        display(
            display_gif(
                item["pixel_values"],
                f"non_instr_tuned_decoder_only_lm_dataset_{split_label}_{i}.gif",
            )
        )
        with torch.no_grad():
            generated_ids = model.generate(
                # unsqueeze(0) adds a batch dimension of size 1
                pixel_values=item["pixel_values"].unsqueeze(0).to(device),
                input_ids=prompt_input_ids,
            )
        generated_text = decoder_only_processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0].strip()
        print(f"generated_text: {generated_text}")
print("===================================")
# Drop the model (and the prompt tensor on its device) and release cached
# CUDA memory before the next model is loaded.
del model, prompt_input_ids
torch.cuda.empty_cache()

Draw three data points from each of the train and validation splits of `instr_tuned_seq2seq_lm_dataset` and run them through `VideoBlip2ForConditionalGeneration` (BLIP-2 with Flan-T5-XL).

In [None]:
# BLIP-2 checkpoint built on the Flan-T5-XL (seq2seq) language model, placed
# on the same device as the previous model.
model = VideoBlip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl")
model = model.to(device)

In [None]:
from train import INSTR_PROMPT

# The instruction prompt is the same for every item, so tokenize it and move
# it to the device once instead of once per generated sample.
instr_prompt_input_ids = seq2seq_processor.tokenizer(
    INSTR_PROMPT, return_tensors="pt"
).input_ids.to(device)

# "test" is the held-out split returned by train_test_split; it is reported
# as "val" here.
for split, split_label in (("train", "train"), ("test", "val")):
    print(f"instr_tuned_seq2seq_lm_dataset {split_label}")
    for i, item in enumerate(
        draw_random_items(instr_tuned_seq2seq_lm_dataset_train_val[split], 3)
    ):
        print(f"input_ids: {item['input_ids']}")
        print(f"labels: {item['labels']}")
        display(
            display_gif(
                item["pixel_values"],
                f"instr_tuned_seq2seq_lm_dataset_{split_label}_{i}.gif",
            )
        )
        with torch.no_grad():
            generated_ids = model.generate(
                # unsqueeze(0) adds a batch dimension of size 1
                pixel_values=item["pixel_values"].unsqueeze(0).to(device),
                input_ids=instr_prompt_input_ids,
            )
        generated_text = seq2seq_processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0].strip()
        print(f"generated_text: {generated_text}")
print("===================================")
# Drop the model (and the prompt tensor on its device) and release cached
# CUDA memory.
del model, instr_prompt_input_ids
torch.cuda.empty_cache()