In [None]:
import sys
from pathlib import Path

import numpy as np
import torch
from dotenv import load_dotenv


# Make the parent directory importable, presumably so that `training_toolkit` (imported in the
# next cell) resolves when the notebook is run from its own folder.
# Guarded so that re-running this cell does not append the same entry to sys.path again.
_parent_dir = Path("..").resolve().as_posix()
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

# Load variables from a local .env file into the process environment; the returned flag is
# assigned to `_` so it is not echoed as cell output.
_ = load_dotenv()

In [None]:
from training_toolkit import build_trainer, llava_next_video_preset, local_video_preset

In [None]:
# Build the trainer from two presets, each expanded into keyword arguments for `build_trainer`.
# A key supplied by both presets raises TypeError (duplicate keyword argument) rather than one
# preset silently overriding the other.
video_preset_kwargs = local_video_preset.as_kwargs()
llava_preset_kwargs = llava_next_video_preset.as_kwargs()
trainer = build_trainer(**video_preset_kwargs, **llava_preset_kwargs)

In [None]:
# Run fine-tuning. The next section reloads a saved checkpoint ("checkpoint-40") from disk,
# so this presumably writes checkpoints under the trainer's output directory — confirm.
trainer.train()

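## Check out the result

Reload the fine-tuned checkpoint, preview a test clip, and compare the captions generated by the fine-tuned and the base model.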

In [None]:
# Reload the fine-tuned weights from a saved training checkpoint in half precision;
# `device_map="auto"` leaves device placement to transformers.
# NOTE(review): `LlavaNextVideoForConditionalGeneration` (transformers) and `OUTPUT_DIR` are not
# defined anywhere in this notebook — presumably from a removed cell; confirm before Run All.
# NOTE(review): step 40 is hard-coded; presumably the last checkpoint of the run above — confirm.
finetuned_checkpoint = f"{OUTPUT_DIR}/checkpoint-40"
load_options = {"torch_dtype": torch.float16, "device_map": "auto"}
model = LlavaNextVideoForConditionalGeneration.from_pretrained(finetuned_checkpoint, **load_options)

In [None]:
# Load the dataset saved to disk (it already holds processed `pixel_values_videos` / `input_ids`)
# and expose both splits as torch tensors.
# NOTE(review): `load_from_disk` (datasets library) is never imported in this notebook — confirm.
dataset = load_from_disk("msrvtt_1000.hf")
train_dataset, test_dataset = (dataset[split].with_format("torch") for split in ("train", "test"))

In [None]:
from matplotlib import pyplot as plt
from matplotlib import animation
from IPython.display import HTML

# Display utilities for previewing video clips as HTML5 animations.
# The preview itself and the caption decode are done further down: they need `processor`
# and the test `example`, which are only defined in later cells.

In [None]:
# Load the processor used below to decode token ids and, inside `run_inference`, to build the
# chat prompt and decode generated text (training itself was handled by `build_trainer`).
# NOTE(review): `AutoProcessor` (transformers) and `MODEL_ID` are not defined in this notebook —
# presumably from a removed cell; confirm.
processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=False)
# This processor is only used for inference here. Decoder-only generation expects LEFT padding
# so new tokens follow the prompt directly; right padding (the training convention) would put
# pad tokens between prompt and generated text when prompts are batched.
processor.tokenizer.padding_side = "left"

In [None]:
example = test_dataset[0]

# Preview the first test clip as an HTML5 animation.
# `pixel_values_videos` is indexed with [0] and the remaining axes are permuted
# (frames, C, H, W) -> (frames, H, W, C), so it is presumably (1, frames, C, H, W) — TODO confirm.
# NOTE(review): if the processor normalized pixels with a mean/std, multiplying by 255 does not
# invert that and colors will look off — confirm how this dataset was preprocessed.
frames = (
    (example["pixel_values_videos"][0] * 255)
    .permute(0, 2, 3, 1)
    .clamp(0, 255)
    .byte()  # uint8 cast truncates toward zero, same as numpy's astype(uint8); no numpy needed
    .numpy()
)

fig, ax = plt.subplots()
ax.axis("off")
image = ax.imshow(frames[0])
plt.close(fig)  # stops the static first frame from also being shown as a separate figure


def show_frame(index):
    image.set_data(frames[index])
    return [image]


anim = animation.FuncAnimation(fig, show_frame, frames=len(frames), interval=100)  # 100 ms per frame
HTML(anim.to_html5_video())  # uses matplotlib's movie writer (ffmpeg by default)

In [None]:
# Decode the stored token ids of this example back to text — presumably the training prompt
# together with its caption. NOTE(review): if `input_ids` is 1-D, batch_decode returns one
# string per token rather than one sequence; confirm the stored shape.
example_token_ids = example["input_ids"]
processor.batch_decode(example_token_ids)

In [None]:
def run_inference(video_clip, model, **generate_kwargs):
    """Caption an already-processed video clip with `model` and return the decoded text.

    Uses the notebook-level `processor` to build the chat prompt and to decode the output.

    Args:
        video_clip: processed `pixel_values_videos` tensor as stored in the dataset. It is passed
            straight to `model.generate` instead of being re-processed. Presumably shaped
            (1, frames, channels, height, width) — TODO confirm.
        model: LLaVA-NeXT-Video model (fine-tuned or base) exposing `.device`, `.dtype` and
            `.generate`.
        **generate_kwargs: extra options forwarded to `model.generate`; they override the
            defaults `do_sample=True` and `max_length=MAX_LENGTH`. Pass `do_sample=False` for
            reproducible captions (useful when comparing models), or `max_new_tokens=...` to
            bound only the generated part (the global `MAX_LENGTH` is then not needed).

    Returns:
        list[str]: decoded sequences with special tokens removed. For decoder-only generation
        the output contains the prompt tokens followed by the new ones, so each string
        includes the prompt as well as the caption.
    """
    # Chat-formatted prompt, this time without the caption: the model has to produce it.
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Provide a detailed caption for this video."},
                {"type": "video"},
            ],
        },
    ]

    # add_generation_prompt=True appends the assistant turn marker (e.g. "ASSISTANT: ")
    # so that generation continues with the answer.
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

    batch = processor(
        text=prompt,
        videos=None,  # the clip is already processed; passing it to the processor again causes errors
        return_tensors="pt",
    ).to(model.device)
    # Match the model's device and dtype (the models in this notebook are loaded in float16).
    video_clip = video_clip.to(device=model.device, dtype=model.dtype)

    generation_kwargs = {"do_sample": True, **generate_kwargs}
    # Fall back to the notebook-wide MAX_LENGTH only when the caller did not bound the length.
    if "max_length" not in generation_kwargs and "max_new_tokens" not in generation_kwargs:
        generation_kwargs["max_length"] = MAX_LENGTH

    out = model.generate(**batch, pixel_values_videos=video_clip, **generation_kwargs)
    return processor.batch_decode(out, skip_special_tokens=True)

In [None]:
# Caption the previewed test clip with the fine-tuned model. `run_inference` samples
# (do_sample=True), so the text can differ from run to run.
run_inference(video_clip=example["pixel_values_videos"], model=model)

In [None]:
# Load the model identified by MODEL_ID (presumably the base checkpoint before fine-tuning)
# with the same dtype and device placement as the fine-tuned model, for a side-by-side comparison.
# NOTE(review): both models stay loaded at the same time; free `model` first if memory is tight.
base_load_options = {"torch_dtype": torch.float16, "device_map": "auto"}
old_model = LlavaNextVideoForConditionalGeneration.from_pretrained(MODEL_ID, **base_load_options)

In [None]:
# Same clip, base model: compare against the fine-tuned caption. Sampling is enabled inside
# `run_inference`, so part of any difference may be run-to-run noise.
run_inference(video_clip=example["pixel_values_videos"], model=old_model)