In [None]:
import json
from pathlib import Path
from typing import Dict, Generator, Optional

import numpy as np
import pandas as pd
import torch
from datasets import Dataset, concatenate_datasets
from pydantic import BaseModel, Field
from transformers import AutoProcessor

In [None]:
# Hugging Face Hub checkpoint whose processor is loaded further down.
MODEL_ID = "llava-hf/LLaVa-NeXT-Video-7b-hf"
# Number of frames sampled uniformly from each clip (default for read_video_decord).
NUM_FRAMES = 8
# Token budget passed to the processor as max_length; longer prompts are truncated.
MAX_LENGTH = 256

In [None]:
from decord import VideoReader, gpu, cpu

def read_video_decord(video_path, num_frames=NUM_FRAMES):
    '''
    Decode a video with the Decord decoder, sampling frames uniformly.

    Args:
        video_path (str): Path to the video file.
        num_frames (int): Number of frames to sample uniformly. Defaults to NUM_FRAMES.

    Returns:
        np.ndarray: Decoded frames of shape (num_frames, height, width, 3).

    Raises:
        ValueError: If ``num_frames`` is not positive or the video contains no frames.
    '''
    if num_frames <= 0:
        raise ValueError(f"num_frames must be positive, got {num_frames}")

    vr = VideoReader(uri=video_path, ctx=cpu(0))  # a GPU ctx requires building decord from source
    total_frames = len(vr)
    if total_frames == 0:
        raise ValueError(f"video has no frames: {video_path}")

    # Evenly spaced indices starting at frame 0 with stride total_frames / num_frames.
    # Scaling an integer range (instead of np.arange with a float step) guarantees exactly
    # num_frames indices; a float-step arange can produce one extra element from rounding,
    # which would break the fixed frame count per clip that later processing relies on.
    indices = (np.arange(num_frames) * total_frames / num_frames).astype(int)
    frames = vr.get_batch(indices).asnumpy()
    return frames

In [None]:
def load_qa(qa_path: Path) -> pd.DataFrame:
    """Load question/answer annotations from disk into a flat DataFrame.

    The JSON file is a list of entries, each with a ``"video"`` filename and a
    ``"conversations"`` list. Every conversation becomes one row; its first two turns
    must come from ``"human"`` and ``"gpt"`` respectively.

    Args:
        qa_path: Path to the annotation JSON file (a ``str`` is also accepted). Read as UTF-8.

    Returns:
        DataFrame with columns ``video``, ``messages`` (the raw conversation list) and
        ``question_id`` (``"<video name without extension>_<conversation index>"``).
        An empty annotation list yields an empty frame with those same columns.

    Raises:
        ValueError: If a conversation does not start with a ``human`` turn followed by a
            ``gpt`` turn (including conversations with fewer than two turns).
    """
    # JSON text is UTF-8; don't depend on the platform's default locale encoding.
    with Path(qa_path).open("r", encoding="utf-8") as f:
        qa = json.load(f)

    samples = []
    for entry in qa:
        video = entry["video"]
        # Strip the actual file extension (whatever its length) instead of assuming 4 chars.
        suffix = Path(video).suffix
        video_stem = video[: -len(suffix)] if suffix else video

        for i, conversation in enumerate(entry["conversations"]):
            # Explicit validation: `assert` would be skipped when Python runs with -O.
            speakers = [turn.get("from") for turn in conversation[:2]]
            if speakers != ["human", "gpt"]:
                raise ValueError(
                    f"{video}: conversation {i} must start with a 'human' turn followed by "
                    f"a 'gpt' turn, got {speakers}"
                )

            samples.append(
                {
                    "video": video,
                    "messages": conversation,
                    "question_id": f"{video_stem}_{i}",
                }
            )

    return pd.DataFrame(samples, columns=["video", "messages", "question_id"])


def prepare_batches(
    video_path: Path,
    qa_path: Path,
    max_samples: Optional[int] = 1000,
) -> Generator[Dict, None, None]:
    """Yield one flat sample dict per QA pair, for use with ``Dataset.from_generator``.

    Args:
        video_path: Directory containing the video files named in the annotations.
        qa_path: Path to the QA annotation JSON (parsed by ``load_qa``).
        max_samples: Maximum number of samples to yield (non-negative); ``None`` yields all.

    Yields:
        Dicts with keys ``video_path`` (resolved POSIX path), ``text_prompt`` (first turn's
        text), ``target_answer`` (second turn's text) and ``question_id``.

    Raises:
        ValueError: If ``max_samples`` is negative.
    """
    if max_samples is not None and max_samples < 0:
        raise ValueError(f"max_samples must be non-negative or None, got {max_samples}")

    df = load_qa(qa_path)
    if max_samples is not None:
        # Cap before iterating so exactly `max_samples` rows are produced.
        df = df.head(max_samples)

    video_dir = Path(video_path)  # loop-invariant
    for row in df.itertuples(index=False):
        yield {
            "video_path": video_dir.joinpath(row.video).resolve().as_posix(),
            "text_prompt": row.messages[0]["value"],
            "target_answer": row.messages[1]["value"],
            "question_id": row.question_id,
        }

In [None]:
# Preprocess each sample into tensors once, up front, to speed up data loading.
# Storing the raw decoded clip (array) alongside its text (string) would slow down iteration,
# because the unprocessed clip keeps its original, higher resolution and takes more memory.
# The processed clip, by contrast, is always 336x336 with a fixed number of frames per clip.
# Note: despite its name, `collate_fn` below is applied per example through `Dataset.map`,
# not passed to a DataLoader as a collate function.
# see: https://discuss.huggingface.co/t/slow-iteration-speed-with-and-without-keep-in-memory-true/33587


def collate_fn(sample, processor):
    """Convert one raw sample into processor outputs (token ids, mask, video pixels).

    Args:
        sample: Dict with ``video_path``, ``text_prompt`` and ``target_answer`` keys.
        processor: Multimodal processor exposing ``apply_chat_template`` and callable on
            ``text=`` / ``videos=``.

    Returns:
        Whatever ``processor(...)`` returns for this single example with
        ``return_tensors="pt"``.
    """
    # Swap in a different video decoder here if desired.
    clip = read_video_decord(sample["video_path"])

    user_turn = {
        "role": "user",
        "content": [
            {"type": "text", "text": sample["text_prompt"]},
            {"type": "video"},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": sample["target_answer"]}],
    }

    # The answer is already part of the conversation, so no generation prompt is appended.
    prompt_text = processor.apply_chat_template(
        [user_turn, assistant_turn], add_generation_prompt=False
    )

    # NOTE(review): truncation caps the sequence at MAX_LENGTH tokens; confirm that the
    # video placeholder tokens survive truncation with the installed transformers version.
    return processor(
        text=prompt_text,
        videos=clip,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
    )

In [None]:
# Root of the MSRVTT-QA data: the video files and the QA annotation JSON live under it.
data_root = Path("/data/msrvtt-qa")
video_path = data_root / "videos"
qa_path = data_root / "qa.json"

# Build a Hugging Face Dataset by streaming samples from the `prepare_batches` generator.
ds = Dataset.from_generator(
    prepare_batches,
    gen_kwargs=dict(video_path=video_path, qa_path=qa_path),
)

In [None]:
# Load the processor: collate_fn needs it to tokenize prompts and preprocess video frames.
processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=False)
# Right padding for training: pad tokens are appended after each sequence.
processor.tokenizer.padding_side = "right"

# Preprocess every sample into tensors once, using parallel worker processes.
NUM_PROC = 8
dataset = ds.map(collate_fn, batched=False, fn_kwargs={"processor": processor}, num_proc=NUM_PROC)

In [None]:
# Shuffle deterministically, then hold out 20% of the samples for evaluation.
# train_test_split shuffles again by default, so it needs its own seed for the split
# to be reproducible across runs.
# NOTE: this rebinds `dataset` to the train/test split; re-run the map cell before re-running this one.
SPLIT_SEED = 42
dataset = dataset.shuffle(seed=SPLIT_SEED)
dataset = dataset.train_test_split(test_size=0.2, seed=SPLIT_SEED)

In [None]:
# Expose both splits so that indexing returns torch tensors.
train_dataset = dataset["train"].with_format("torch")
test_dataset = dataset["test"].with_format("torch")

In [None]:
class LlavaNextVideoDataCollatorWithPadding:
    """Pad a list of preprocessed examples into one training batch for LLaVA-NeXT-Video.

    Each feature is one row produced by ``collate_fn`` and read back with
    ``with_format("torch")``. Its ``input_ids`` / ``attention_mask`` carry a leading batch
    dimension of 1. ``pixel_values_videos`` is concatenated along dim 0; it presumably also
    has a leading batch dimension of 1 (TODO confirm against the processor output).
    """

    def __init__(self, processor):
        # Only `processor.tokenizer` (its `pad` method) is used.
        self.processor = processor

    def __call__(self, features):
        """Pad the text, build labels, and concatenate the video tensors.

        Args:
            features: List of per-example dicts with ``input_ids``, ``attention_mask`` and
                ``pixel_values_videos`` tensors.

        Returns:
            The tokenizer's padded batch with extra ``labels`` and ``pixel_values_videos``
            entries. ``labels`` equals ``input_ids`` with padded positions set to -100.
        """
        padded_inputs = self.processor.tokenizer.pad(
            {
                # Each stored example holds a single sequence, so take row 0.
                "input_ids": [feat["input_ids"][0] for feat in features],
                "attention_mask": [feat["attention_mask"][0] for feat in features],
            },
            padding=True,
            return_tensors="pt",
        )

        # Mask padding by position (attention_mask == 0) rather than by token id. If the pad
        # id also occurs as a real token (e.g. pad shares its id with eos/unk), id-based
        # masking would silently drop those real tokens from the loss.
        padded_inputs["labels"] = padded_inputs["input_ids"].masked_fill(
            padded_inputs["attention_mask"] == 0, -100
        )
        padded_inputs["pixel_values_videos"] = torch.cat(
            [feat["pixel_values_videos"] for feat in features], dim=0
        )

        return padded_inputs

In [None]:
# First training example (torch-formatted); used below to visualize its clip and decode its text.
example = train_dataset[0]

In [None]:
from matplotlib import pyplot as plt
from matplotlib import animation
from IPython.display import HTML

# Convert the processed tensor back into displayable uint8 frames.
# pixel_values_videos[0] is treated as (frames, channels, height, width) by the permute below.
clip = example["pixel_values_videos"][0]

# The processor may normalize pixels with a per-channel mean/std; undo that when the
# processor exposes those statistics, otherwise fall back to a plain *255 rescale.
# NOTE(review): attribute names (`video_processor`, `do_normalize`, `image_mean`,
# `image_std`) vary across transformers versions — confirm for the installed version.
_video_proc = getattr(processor, "video_processor", None) or getattr(processor, "image_processor", None)
_mean = getattr(_video_proc, "image_mean", None)
_std = getattr(_video_proc, "image_std", None)
if getattr(_video_proc, "do_normalize", False) and _mean is not None and _std is not None:
    clip = (
        clip * torch.as_tensor(_std, dtype=clip.dtype).view(1, -1, 1, 1)
        + torch.as_tensor(_mean, dtype=clip.dtype).view(1, -1, 1, 1)
    )
clip = (clip * 255).permute(0, 2, 3, 1).clamp(0, 255)

# np array with shape (frames, height, width, channels)
video = np.asarray(clip).astype(np.uint8)

fig, ax = plt.subplots()
ax.axis("off")
im = ax.imshow(video[0])

plt.close(fig)  # keep the static figure from rendering; only the animation below is shown

def init():
    im.set_data(video[0])
    return (im,)

def animate(i):
    im.set_data(video[i])
    return (im,)

FRAME_INTERVAL_MS = 100  # delay between frames in the rendered animation
anim = animation.FuncAnimation(fig, animate, init_func=init, frames=video.shape[0],
                               interval=FRAME_INTERVAL_MS)
# to_html5_video encodes with matplotlib's movie writer (typically ffmpeg), which must be installed.
HTML(anim.to_html5_video())

In [None]:
# Decode the stored token ids back to text: this is the full chat-formatted prompt
# (user question plus target answer) that was tokenized together with the clip above.
processor.batch_decode(example["input_ids"])