#### Inference Demo
This is a demo for loading YTVIS data and performing inference. We have selected multiple trajectories from the YTVIS dataset.

In [None]:
import numpy as np
import torch
from mmengine.config import Config, ConfigDict
from dataset.youtube_loader import VISDataset, vis_collate_fn
from mmtrack.registry import DATASETS
from mmengine import build_from_cfg
from PIL import Image, ImageDraw, ImageSequence, ImageFont
from mmtrack.datasets import VideoSampler, EntireVideoBatchSampler

import json

In [None]:
# Category names for the YTVIS demo data. Order matters: a class id is turned
# into its name by indexing into this list.
classes_vis = [
    'airplane', 'bear', 'bird', 'boat',
    'car', 'cat', 'cow', 'deer',
    'dog', 'duck', 'earless_seal', 'elephant',
    'fish', 'flying_disc', 'fox', 'frog',
    'giant_panda', 'giraffe', 'horse', 'leopard',
    'lizard', 'monkey', 'motorbike', 'mouse',
    'parrot', 'person', 'rabbit', 'shark',
    'skateboard', 'snake', 'snowboard', 'squirrel',
    'surfboard', 'tennis_racket', 'tiger', 'train',
    'truck', 'turtle', 'whale', 'zebra',
]

In [None]:
# Build the validation dataset described in the config file, then wrap it in
# the video sampler / batch sampler pair handed to the DataLoader later on.
cfg = Config.fromfile('dataset/youtube_cfg_480.py')
val_dataset_cfg = cfg.val_dataloader.dataset
val_dataset = build_from_cfg(val_dataset_cfg, DATASETS)

sampler = VideoSampler(val_dataset)
batch_sampler = EntireVideoBatchSampler(sampler)

In [None]:
import cv2
# Cache of raw batches: get_fixed_sample() appends each batch it visualizes
# here, and the inference cell iterates over this list.
all_test_case = []
def get_fixed_sample(num_frames=16, idx=1):
    """Cache the first validation clip and save its ground-truth boxes as a GIF.

    Builds a DataLoader over the module-level ``val_dataset`` and
    ``batch_sampler``, takes the first batch it yields, appends that batch to
    the module-level ``all_test_case`` list and writes
    ``gt_gif_{video_name}_{cnt}.gif`` with every unmasked box drawn on each
    rendered frame.

    Args:
        num_frames: Maximum number of leading frames to render. Clips with
            fewer frames are rendered in full instead of raising IndexError.
        idx: Unused; kept so existing callers keep working.

    Returns:
        None.
    """
    # batch_sampler already decides the batching, so batch_size is left at its
    # default rather than being passed explicitly.
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_sampler=batch_sampler,
        collate_fn=vis_collate_fn,
    )

    # Boxes are multiplied by this fixed canvas size.
    # NOTE(review): assumes the frames are 384x256 - confirm against the dataset config.
    width, height = 384, 256
    duration = 500  # milliseconds per frame, i.e. 2 fps
    # NOTE(review): "blue" appears twice, so box slots 0 and 2 share a color -
    # confirm this is intended.
    color_list = ["blue", "red", "blue", "green"]

    for cnt, batch in enumerate(val_dataloader, start=1):
        video_pixel_values, video_bboxes, _, video_masks, video_name = batch
        all_test_case.append(batch)
        gif_name = f'gt_gif_{video_name}_{cnt}.gif'

        # Move axis 1 to the end once for the whole clip (presumably
        # channels-first -> channels-last, which is what PIL expects).
        frames_hwc = video_pixel_values.permute(0, 2, 3, 1).numpy()
        frame_total = min(num_frames, frames_hwc.shape[0])

        frames = []
        for fid in range(frame_total):
            # Assumes pixel values lie in [-1, 1]; rescale to 0..255.
            img_array = ((frames_hwc[fid] + 1) * 127.5).astype(np.uint8)
            img = Image.fromarray(img_array)
            draw = ImageDraw.Draw(img)

            for box_idx, bbox in enumerate(video_bboxes[fid].numpy()):
                slot = box_idx % len(color_list)
                # A zero mask entry marks this box slot as not to be drawn.
                if video_masks[fid][slot] == 0:
                    continue
                x1, y1, x2, y2 = bbox
                top_left = (x1 * width, y1 * height)
                bottom_right = (x2 * width, y2 * height)
                draw.rectangle([top_left, bottom_right], outline=color_list[slot], width=3)
            frames.append(img)

        frames[0].save(gif_name, save_all=True, append_images=frames[1:], duration=duration, loop=0, dpi=(300, 300))
        # Only the first clip is needed for the demo.
        break
# Run once: caches the first validation batch in all_test_case and writes its
# ground-truth GIF. The function returns nothing, hence the throwaway name.
_ = get_fixed_sample()

In [None]:
# load model
# `os` is imported here because this cell uses os.environ before any other
# cell in the notebook imports it.
import os

# Feature switches for the project modules imported below. They are set before
# those imports on purpose - presumably the modules read them at import time
# (TODO confirm) - so keep this ordering.
os.environ.update({
    # Enable additional instance embedding
    "ADD_INS_EMBED": "true",
    # learnable frame-wise time embedding
    "open_time_embedding": 'learnable_frame',
    # Disable box attention mechanism
    "open_box_attn": 'disable',
    # Disable dynamic size adjustment
    "dynamic_size": 'disable',
    # Do not use encoder only
    "enc_only": 'false',
    # Do not use decoder only
    "dec_only": 'false',
    # Enable injector at decoder fusion stage
    "injector_on": 'dec_fuse',
    # Disable self-attention mechanism
    "self_on": 'false',
    # Disable track query
    "track_query": 'false',
})

from diffusers import DPMSolverMultistepScheduler, EulerDiscreteScheduler, DDIMScheduler
from diffusers.utils import export_to_video
from pipelines.pipeline_text_to_video_synth import TextToVideoSDPipeline
from models.unet_3d_condition_gligen import UNet3DConditionModel

pretrained_model_path = "runwayml/stable-diffusion-v1-5"
# Download from here: https://huggingface.co/pengxiang/trackdiffusion_ytvis/tree/main/modelscope_ft/unet
unet = UNet3DConditionModel.from_pretrained("/path/to/finetuning/unet", torch_dtype=torch.float16,)

# Swap the fine-tuned UNet into the base pipeline and move everything to the GPU.
pipe = TextToVideoSDPipeline.from_pretrained(pretrained_model_path, unet=unet, torch_dtype=torch.float16, variant="fp16", low_cpu_mem_usage=False)
pipe = pipe.to('cuda')

In [None]:
# Load the caption table; the inference cell looks prompts up in it by video name.
with open('/youtubevis_caption.json', mode='r', encoding='utf-8') as caption_file:
    caption_data = json.loads(caption_file.read())

In [None]:
# Generate a video for every cached clip and save it as a GIF with the
# conditioning boxes drawn on top.
import os
from PIL import Image, ImageDraw
import numpy as np

num_frames = 16
width, height = 384, 256
color_list = ["red", "yellow", "blue", "green"]

for batch in all_test_case:
    video_pixel_values, video_bboxes, video_cls, video_masks, video_name = batch

    # Keep only the leading frames that the pipeline is asked to generate.
    video_cls = video_cls[:num_frames]
    video_bboxes = video_bboxes[:num_frames]
    video_masks = video_masks[:num_frames]

    # Flatten the class ids and map each one to its category name. Ids outside
    # the table become "": the lower bound matters because a negative id would
    # otherwise silently index from the end of the list.
    video_cls_str_list = [
        classes_vis[class_id] if 0 <= class_id < len(classes_vis) else ""
        for class_id in video_cls.view(-1).tolist()
    ]

    image = pipe(
        prompt=caption_data[video_name],
        width=width,
        height=height,
        seg_phrases=video_cls_str_list,
        video_masks=video_masks,
        bbox_prompt=video_bboxes,
        num_frames=num_frames,
        num_inference_steps=50,
        guidance_scale=1.2, # You may find this strange here, please take a look at our pipeline.
    ).frames

    frames = []
    for fid in range(num_frames):
        img = Image.fromarray(image[fid])
        draw = ImageDraw.Draw(img)
        for box_idx, bbox in enumerate(video_bboxes[fid].numpy()):
            slot = box_idx % len(color_list)
            # Skip masked-out box slots, matching the ground-truth GIF drawn
            # by get_fixed_sample().
            if video_masks[fid][slot] == 0:
                continue
            x1, y1, x2, y2 = bbox
            top_left = (x1 * width, y1 * height)
            bottom_right = (x2 * width, y2 * height)
            draw.rectangle([top_left, bottom_right], outline=color_list[slot], width=2)
        frames.append(img)

    # save as GIF (duration is milliseconds per frame)
    frames[0].save(f'./output_{video_name}.gif',
                format='GIF',
                append_images=frames[1:],
                save_all=True,
                duration=500,
                loop=0)
