In [None]:
from dotenv import load_dotenv
from pathlib import Path
import sys


# Add the resolved parent directory to the module search path
# (presumably so the local `training_toolkit` package can be imported — confirm).
parent_dir = Path("..").resolve()
sys.path.append(parent_dir.as_posix())

# Load variables from a `.env` file; binding the result to `_` keeps the
# return value out of the cell output.
_ = load_dotenv()

In [None]:
from training_toolkit import VideoQAImporter

In [None]:
# The same name is used for the importer's `name` argument and for the
# on-disk location of the saved dataset.
dataset_name = "msrvtt-qa_1000"

# Run the project's VideoQAImporter over the video directory and the QA
# annotation file.
video_importer = VideoQAImporter()
dataset = video_importer(
    name=dataset_name,
    video_path="/data/msrvtt-qa/videos",
    annotation_path="/data/msrvtt-qa/qa.json",
)

# Persist the imported dataset to a path relative to the working directory.
dataset.save_to_disk(dataset_name)

In [None]:
from training_toolkit import build_trainer, llava_next_video_preset, video_qa_preset

In [None]:
# Point the video-QA data preset at the dataset directory written by the
# import cell.
data_preset = video_qa_preset.with_path("msrvtt-qa_1000")

# Collect the keyword arguments contributed by the data preset and by the
# model preset, then hand both sets to the trainer factory. They are kept as
# two separate expansions so that a key supplied by both presets still raises.
data_kwargs = data_preset.as_kwargs()
model_kwargs = llava_next_video_preset.as_kwargs()
trainer = build_trainer(**data_kwargs, **model_kwargs)

In [None]:
# Start fine-tuning. As the last expression of the cell, the value returned by
# `train()` is displayed. NOTE(review): the model-loading cell below reads from
# `llava_next_video_preset.training_args["output_dir"]`, so this run presumably
# has to leave a loadable model there — confirm.
trainer.train()

In [None]:
# Take the first example of the "test" split; it is reused below for
# visualisation and for inference.
test_split = data_preset.dataset["test"]
sample = test_split[0]

In [None]:
from transformers import (
    AutoModel,
    AutoProcessor,
    LlavaNextVideoForConditionalGeneration,
)
from training_toolkit import animate_video_sample
from IPython.display import HTML

# The processor is loaded from the base checkpoint id of the model preset.
processor = AutoProcessor.from_pretrained(llava_next_video_preset.hf_model_id)

# Load the fine-tuned weights with the generation-capable class:
# `run_inference` calls `model.generate`, which needs the language-modelling
# head that a bare `AutoModel` does not provide.
# NOTE(review): assumes training left a loadable model (or adapter) directly in
# `output_dir` rather than only in a checkpoint sub-directory — confirm.
model = LlavaNextVideoForConditionalGeneration.from_pretrained(
    llava_next_video_preset.training_args["output_dir"]
)

In [None]:
# Build an animation object from the sample's video frames.
animation = animate_video_sample(sample)

# Decode the sample's token ids back to text and print the result alongside
# the clip.
decoded_text = processor.batch_decode(sample["input_ids"])
print(decoded_text)

# Last expression of the cell: embed the animation as an HTML5 video.
video_markup = animation.to_html5_video()
HTML(video_markup)

In [None]:
def run_inference(
    video_clip,
    max_length=256,
    prompt_text="Provide a detailed caption for this video.",
):
    """Generate text for one already-processed video clip.

    Uses the notebook-level ``processor`` and ``model`` objects.

    Args:
        video_clip: Processed video input forwarded to ``model.generate`` as
            ``pixel_values_videos``. It must support ``.to(device)``
            (presumably a tensor produced by the processor — confirm).
        max_length: Value forwarded to ``model.generate`` as ``max_length``.
            Exposed as a parameter so the function no longer depends on a
            notebook-level ``MAX_LENGTH`` constant that is not defined.
            NOTE(review): 256 is a placeholder default — align it with the
            maximum sequence length used during training.
        prompt_text: Instruction placed in the user turn before the video
            placeholder.

    Returns:
        The list of strings returned by ``processor.batch_decode`` with
        special tokens skipped.
    """
    # Format the prompt with the chat template; no reference answer is included.
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt_text},
                {"type": "video"},
            ],
        },
    ]

    # add_generation_prompt=True appends the assistant turn marker to the prompt.
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

    # Only the text is tokenised here: the clip is already processed, and
    # passing it through the processor a second time causes errors.
    batch = processor(
        text=prompt,
        videos=None,
        return_tensors="pt",
    ).to(model.device)
    video_clip = video_clip.to(model.device)

    # Sampling is enabled, so repeated calls may return different text.
    out = model.generate(
        **batch,
        pixel_values_videos=video_clip,
        max_length=max_length,
        do_sample=True,
    )
    return processor.batch_decode(out, skip_special_tokens=True)

In [None]:
# Run generation on the processed video stored with the test sample; as the
# cell's last expression, the returned list of decoded strings is displayed.
run_inference(sample["pixel_values_videos"])