# Inference Tutorial

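In this tutorial you'll learn:
- [How to load an OmniMAE model](#Load-Model)
- [Inference with Images](#Inference-with-Images)
- [Inference with Videos](#Inference-with-Videos)
- [Reconstruction of Images](#Reconstruction-of-Images)
- [Reconstruction of Videos](#Reconstruction-of-Videos)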

### Install modules 

We assume that `torch` and `torchvision` have already been installed using the instructions in the [README.md](https://github.com/facebookresearch/omnivore/blob/main/README.md#setup-and-installation).

Please install the other dependencies required for using Omnivore models - `einops`, `pytorchvideo` and `timm`.

For this tutorial, please additionally install `ipywidgets` and `matplotlib`.

### Make data dir

In [None]:
import os

# exist_ok=True keeps this cell re-runnable: a plain `mkdir data` reports an
# error once the directory exists, and it depends on the host shell.
os.makedirs("data", exist_ok=True)

### Import modules

In [None]:
import os 

try:
    from omnivore.transforms import SpatialCrop, TemporalCrop, DepthNorm
except:
    # need to also make the omnivore transform module available
    !git clone https://github.com/facebookresearch/omnivore.git
    sys.path.append("./omnivore")

    from omnivore.transforms import SpatialCrop, TemporalCrop, DepthNorm

import csv
import json
from typing import List

import numpy as np
import einops
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from pytorchvideo.data.encoded_video import EncodedVideo
from torchvision.transforms._transforms_video import NormalizeVideo

from pytorchvideo.transforms import (
    ApplyTransformToKey,
    ShortSideScale,
    UniformTemporalSubsample,
)

%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from ipywidgets import Video

In [None]:
# Per-channel statistics used to normalize inputs and to undo the
# normalization before display.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def show_image(image, rescale=True, mean=IMAGENET_MEAN, std=IMAGENET_STD, title=""):
    """Draw a single image tensor with matplotlib.

    Args:
        image: 3-D tensor. If its third dimension is not of size 3 it is
            treated as channel-first and permuted to channel-last.
        rescale: when True, ``image * std + mean`` is applied before display
            (i.e. the normalization is undone); when False the values are
            displayed as they are.
        mean: per-channel offset used when ``rescale`` is True.
        std: per-channel scale used when ``rescale`` is True.
        title: text shown above the image.

    In both cases the values are multiplied by 255, clipped to [0, 255] and
    cast to integers before being handed to ``plt.imshow``.
    """
    # imshow needs [H, W, 3]
    channels_last = image.shape[2] == 3
    if not channels_last:
        image = torch.einsum("chw->hwc", image)

    # Sometimes the un-normalization has already been done upstream
    pixels = image * std + mean if rescale else image

    plt.imshow(torch.clip(pixels * 255, 0, 255).int())
    plt.title(title, fontsize=16)
    plt.axis("off")

In [None]:
def pad_imvid(x, time_dim=2):
    """Return ``x`` with an even number of entries along ``time_dim``.

    A single-frame input is repeated once (giving two identical frames).
    Any other input with an odd number of frames gets its first frame
    duplicated. Inputs that already have an even frame count are returned
    unchanged.
    """
    frame_count = x.shape[time_dim]

    if frame_count == 1:
        # Tile only the time axis: 1 frame -> 2 frames
        tiling = [1] * x.dim()
        tiling[time_dim] = 2
        x = x.repeat(tiling)
        frame_count = x.shape[time_dim]

    if frame_count % 2 != 0:
        # Per-frame repetition counts: every frame once, the first one twice
        counts = torch.ones(frame_count, dtype=torch.int64).to(x.device)
        counts[0] = 2
        x = torch.repeat_interleave(x, counts, dim=time_dim)

    return x

def patchify(imgs, patch_shape):
    """Split a batch of images or clips into flattened, non-overlapping patches.

    Adapted from omnivision.losses.mae_loss.MAELoss.patchify().

    Args:
        imgs: tensor of shape (N, C, H, W) when ``patch_shape`` has two
            entries, or (N, C, T, H, W) when it has three. H must equal W and
            every patched dimension must be divisible by its patch size.
        patch_shape: list or tuple ``(q, r)`` or ``(p, q, r)`` giving the patch
            extent along (time,) height and width.

    Returns:
        Tensor of shape (N, t * h * w, p * q * r, C), where t, h, w are the
        number of patches along time, height and width.

    Raises:
        AssertionError: if the spatial dimensions differ, the rank does not
            match ``patch_shape``, or a dimension is not divisible.
    """
    assert imgs.shape[-2] == imgs.shape[-1]  # Spatial dimensions match up

    # Work on a list copy so that tuples are accepted as well
    patch_shape = list(patch_shape)

    # Add a dummy time dimension to 2D patches for consistency.
    # Since it is 1, it will not affect the final number of patches
    if len(patch_shape) == 2:
        patch_shape = [1] + patch_shape
        imgs = imgs.unsqueeze(-3)

    assert imgs.ndim - 2 == len(patch_shape)  # except batch and channel dims
    for i in range(1, len(patch_shape) + 1):
        assert (
            imgs.shape[-i] % patch_shape[-i] == 0
        ), f"image shape {imgs.shape} & patch shape {patch_shape} mismatch at index {i}"

    n, c = imgs.shape[0], imgs.shape[1]  # channel count taken from the input
    p, q, r = patch_shape[-3:]
    t = imgs.shape[-3] // p  # patches along time
    h = imgs.shape[-2] // q  # patches along height
    w = imgs.shape[-1] // r  # patches along width

    # Split every patched axis into (patch index, offset inside the patch),
    # then group the patch indices together and the offsets together.
    x = imgs.reshape(shape=(n, c, t, p, h, q, w, r))
    x = torch.einsum("nctphqwr->nthwpqrc", x)
    x = x.reshape(shape=(n, t * h * w, p * q * r, c))

    return x

def unpatchify(imgs, patch_shape, img_size=224):
    """Reassemble flattened patches into a clip (the reverse of patchify).

    Adapted from https://github.com/facebookresearch/mae/blob/main/models_mae.py#L109.

    Args:
        imgs: tensor of shape (N, t * h * w, p * q * r, C).
        patch_shape: sequence whose last three entries are the patch extent
            ``(p, q, r)`` along time, height and width.
        img_size: spatial size of the reconstructed frames, either an int
            (square frames) or a ``(height, width)`` pair. Defaults to 224,
            the value that used to be hard-coded.

    Returns:
        Tensor of shape (N, C, t * p, h * q, w * r).
    """
    p = patch_shape[-3]  # temporality
    q = patch_shape[-2]  # height
    r = patch_shape[-1]  # width

    if isinstance(img_size, int):
        height = width = img_size
    else:
        height, width = img_size

    h = height // q  # patches along height
    w = width // r  # patches along width
    t = imgs.shape[1] // (h * w)  # remaining factor is the temporal patch count

    n = imgs.shape[0]
    c = imgs.shape[-1]  # channel count taken from the input

    x = imgs.reshape((n, t, h, w, p, q, r, c))
    # Interleave each patch index with its in-patch offset again
    x = torch.einsum("nthwpqrc->nctphqwr", x)
    x = x.reshape((n, c, t * p, h * q, w * r))

    return x

def convert_output(
    pred, mask, img, patch_shape, norm_pix_loss=True, norm_pix_per_channel=True, tfm_mean=IMAGENET_MEAN, 
    tfm_std=IMAGENET_STD
):
    """Convert the output of the pretrained model into viewable frames.

    Adapted from omnivision.losses.mae_loss.MAELoss.compute_mae_loss().

    Args:
        pred: model output whose last dimension holds one flattened patch with
            its channels interleaved last; it is reshaped to
            (..., patch_size, channels).
        mask: boolean tensor, reshaped to (batch, num_patches). Patches where
            it is True are copied from the original image instead of being
            taken from the prediction.
        img: normalized input, (B, C, H, W) or (B, C, T, H, W).
        patch_shape: patch extent (time, height, width).
        norm_pix_loss: if True, ``pred`` is treated as per-patch normalized and
            is rescaled with the mean/variance of the matching target patch.
        norm_pix_per_channel: if True the patch statistics are computed per
            channel, otherwise over all channels of the patch together.
        tfm_mean: per-channel mean used by the input normalization.
        tfm_std: per-channel std used by the input normalization.

    Returns:
        Tensor produced by ``unpatchify``: the un-normalized reconstruction in
        which the patches selected by ``mask`` come from the original image.
        Note that ``unpatchify`` is called with its default frame size.
    """
    # Plain images get their time axis *before* padding; otherwise pad_imvid
    # would look at the height axis instead of the time axis.
    if img.ndim == 4:
        img = img.unsqueeze(2)

    # The model output for a single image covers two frames, so duplicate the
    # frame (and pad odd-length clips) to match (pred is already replicated)
    img = pad_imvid(img)

    # Reverse the global normalization of the input image. The statistics are
    # cast to img's dtype: the numpy defaults are float64 and would otherwise
    # silently promote the whole result to double precision.
    stat_shape = [1, -1] + [1] * (img.ndim - 2)
    img_mean = (
        torch.as_tensor(tfm_mean)
        .to(device=img.device, dtype=img.dtype)
        .reshape(stat_shape)
    )
    img_std = (
        torch.as_tensor(tfm_std)
        .to(device=img.device, dtype=img.dtype)
        .reshape(stat_shape)
    )
    img = img * img_std + img_mean

    # Squeeze back the colour channels from the linear output
    num_channels = img.shape[1]
    pred = pred.reshape((*pred.shape[:-1], pred.shape[-1] // num_channels, num_channels))

    # Ground-truth patches, laid out like pred
    target = patchify(img, patch_shape)

    if norm_pix_loss:
        # Statistics per patch and channel, or per patch over all channels.
        # target keeps its 4-D layout in both cases so that it still lines up
        # with pred below.
        stat_dims = (-2,) if norm_pix_per_channel else (-2, -1)
        mean = target.mean(dim=stat_dims, keepdim=True)
        var = target.var(dim=stat_dims, keepdim=True)
        # NOTE(review): the reference loss presumably normalizes with
        # (var + eps) ** 0.5 - confirm whether an eps belongs here as well.
        pred = (var**0.5) * pred + mean
    else:
        # pred is still a view of the caller's tensor here; copy it so the
        # in-place assignment below does not modify the caller's data.
        pred = pred.clone()

    # Unmasked patches have to be replaced by those from original image
    mask_flatten = mask.reshape(mask.shape[0], -1)
    pred[mask_flatten] = target[mask_flatten]

    # Unpatchify the predicted images
    pred = unpatchify(pred, patch_shape)

    return pred

# Inference with Images

First we'll load an image and use the OmniMAE model to classify it. 

## Load Model

We provide several pretrained OmniMAE models via manual download. Available models are described in [model zoo documentation](https://github.com/facebookresearch/omnivore/tree/main/omnimae).

Here we are selecting the base ViT model, which was pretrained on Something-Something v2 and ImageNet-1K and then fine-tuned on ImageNet-1K.

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Constructor of the ViT-Base model fine-tuned on ImageNet-1K
from omnimae.omni_mae_model import vit_base_mae_finetune_in1k

# Build the model, move it to the chosen device and switch to eval mode
model = vit_base_mae_finetune_in1k().to(device).eval()

### Setup

Download the id to label mapping for the Imagenet1K dataset. This will be used to get the category label names from the predicted class ids.

In [None]:
# Download the ImageNet class-index file; -O writes it to data/imagenet_class_index.json
!wget https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json -O data/imagenet_class_index.json

In [None]:
# Read the class-index file downloaded in the previous cell
with open("data/imagenet_class_index.json", "r") as f:
    imagenet_classnames = json.load(f)

# Map every class id (the JSON key, a string) to the second element of its
# entry, which is used as the label name
imagenet_id_to_classname = {
    class_id: entry[1] for class_id, entry in imagenet_classnames.items()
}

### Load and visualize the image

You can download the test image in the cell below or specify a path to your own image. Before passing the image into the model we need to apply some input transforms. 

In [None]:
# Download the example image file
# Save the example picture as data/library.jpg (the path read by the next cell)
!wget -O data/library.jpg https://upload.wikimedia.org/wikipedia/commons/thumb/c/c5/13-11-02-olb-by-RalfR-03.jpg/800px-13-11-02-olb-by-RalfR-03.jpg

In [None]:
# Open the downloaded file and convert it to RGB mode
image_path = "data/library.jpg"
image = Image.open(image_path).convert("RGB")
# Preview the raw image before any transform is applied
plt.figure(figsize=(6, 6))
plt.imshow(image)

In [None]:
# Preprocessing pipeline: resize, take a 224x224 center crop, convert to a
# tensor and normalize with the same mean/std as IMAGENET_MEAN / IMAGENET_STD
image_transform = T.Compose(
    [
        T.Resize(224),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
# `image` is rebound from the PIL image to the transformed tensor
image = image_transform(image)

# For images, the model expects inputs of shape: B x 3 x T x H x W
print(image.shape)
# Insert a leading batch axis and a time axis after the channel axis
image = image[None, :, None, :, ...]
print(image.shape)

### Run the model 

The transformed image can be passed through the model to get class predictions

In [None]:
# Inference only: no gradients are needed
with torch.no_grad():
    
    prediction = model(image.to(device))
    # Indices of the five highest scores
    pred_classes = prediction.topk(k=5).indices

# The mapping is keyed by strings (JSON keys), hence the str(...) conversion
pred_class_names = [imagenet_id_to_classname[str(i.item())] for i in pred_classes[0]]
print("Top 5 predicted labels: %s" % ", ".join(pred_class_names))

# Inference with Videos

Now we'll see how to use the OmniMAE model to classify a video. 

## Load Model

We provide several pretrained OmniMAE models via manual download. Available models are described in [model zoo documentation](https://github.com/facebookresearch/omnivore/tree/main/omnimae).

Here we are selecting the base ViT model, which was pretrained on Something-Something v2 and ImageNet-1K and then fine-tuned on SSv2 for classification.

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Constructor of the ViT-Base model fine-tuned on SSv2
from omnimae.omni_mae_model import vit_base_mae_finetune_ssv2

# Build the model, move it to the chosen device and switch to eval mode
model = vit_base_mae_finetune_ssv2().to(device).eval()

### Setup 

Download the id-to-label mapping (`labels.json`) for the [Something Something v2 dataset](https://developer.qualcomm.com/software/ai-datasets/something-something) and put it under `data/ssv2_labels`. 

This will be used to get the category label names from the predicted class ids.

In [None]:
# Read the label file placed under data/ssv2_labels
with open("data/ssv2_labels/labels.json", "r") as f:
    ssv2_classnames = json.load(f)

# The file maps label text -> id; invert it to integer id -> label text and
# drop any double quotes from the text
ssv2_id_to_classname = {
    int(class_id): str(label).replace('"', "")
    for label, class_id in ssv2_classnames.items()
}

### Define the transformations for the input required by the model

Before passing the video into the model we need to apply some input transforms and sample a clip of the correct duration.

**Remark**: These are the transformations from the Omnivore model. They could be (and actually are) different for OmniMAE. However, for the purpose of this demo, they are sufficient.

In [None]:
# Number of frames kept by the temporal subsampling step below
num_frames = 160
# NOTE: sampling_rate and frames_per_second are defined but not used in this cell
sampling_rate = 2
frames_per_second = 30

# Transform applied to the "video" entry of the clip dictionary
video_transform = ApplyTransformToKey(
    key="video",
    transform=T.Compose(
        [
            UniformTemporalSubsample(num_frames), 
            # Scale pixel values by 1/255 before normalization
            T.Lambda(lambda x: x / 255.0),  
            ShortSideScale(size=224),
            NormalizeVideo(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            # Omnivore transforms; presumably these yield 32-frame clips and
            # three 224x224 spatial crops each - confirm in omnivore.transforms
            TemporalCrop(frames_per_clip=32, stride=40),
            SpatialCrop(crop_size=224, num_crops=3),
        ]
    ),
)

### Load and visualize an example video

We can test the classification on an example video, such as this [archery video](https://www.youtube.com/watch?v=3and4vWkW4s) from the Kinetics validation set.

Alternatively, you can download videos from the SSv2 dataset and select one (the cells below use `74225.webm`).

In [None]:
# Download the example video file
# !wget https://dl.fbaipublicfiles.com/omnivore/example_data/dance.mp4 -O data/dance.mp4

In [None]:
# Load the example video
# Alternative input (requires the download cell above to be enabled):
# video_path = "data/dance.mp4"
# SSv2 clip used for classification; it must already exist under data/
video_path = "data/74225.webm"

# Show the video in the notebook
Video.from_file(video_path, width=500)

In [None]:
# We crop the video to a smaller resolution and duration to save RAM
# Alternative for the dance video:
# !ffmpeg -y -ss 12 -i data/dance.mp4 -filter:v scale=224:-1 -t 1 -v 0 data/dance_cropped.mp4
# Keep the first 2 seconds (-ss 0 -t 2) and scale the width to 224 to save RAM;
# -y overwrites an existing output file
!ffmpeg -y -ss 0 -i data/74225.webm -filter:v scale=224:-1 -t 2 -v 0 data/74225_cropped.webm

# video_path = "data/dance_cropped.mp4"
# From here on, work with the cropped file
video_path = "data/74225_cropped.webm"

Video.from_file(video_path, width=500)

In [None]:
# Initialize an EncodedVideo helper class
video = EncodedVideo.from_path(video_path)

# Load the desired clip and specify the start and end duration.
# The start_sec should correspond to where the action occurs in the video
video_data = video.get_clip(start_sec=0, end_sec=2.0)

# Apply a transform to normalize the video input
video_data = video_transform(video_data)

# Move the inputs to the desired device
video_inputs = video_data["video"]

# Take the first clip 
# The model expects inputs of shape: B x C x T x H x W
print(video_inputs[0].shape)
video_input = video_inputs[0][None, ...]
print(video_input.shape)

### Get model predictions

In [None]:
# Pass the input clip through the model 
with torch.no_grad():
    
    prediction = model(video_input.to(device))

    # Get the predicted classes 
    pred_classes = prediction.topk(k=5).indices

# Map the predicted classes to the label names
pred_class_names = [ssv2_id_to_classname[int(i)] for i in pred_classes[0]]
print("Top 5 predicted labels: %s" % ", ".join(pred_class_names))

# Reconstruction of Images

First we'll load an image and use the OmniMAE model to reconstruct it from a masked version.

## Load Model

We provide several pretrained OmniMAE models via manual download. Available models are described in [model zoo documentation](https://github.com/facebookresearch/omnivore/tree/main/omnimae).

Here we are selecting the base ViT model, which was pretrained on Something-Something v2 and ImageNet-1K.

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Constructor of the pretrained (not fine-tuned) ViT-Base model
from omnimae.omni_mae_model import vit_base_mae_pretraining

# Build the model, move it to the chosen device and switch to eval mode
model = vit_base_mae_pretraining().to(device).eval()

### Load and visualize the image

You can download the test image in the cell below or specify a path to your own image. Before passing the image into the model we need to apply some input transforms. 

In [None]:
# Download the example image file
# Save the example picture as data/library.jpg (the path read by the next cell)
!wget -O data/library.jpg https://upload.wikimedia.org/wikipedia/commons/thumb/c/c5/13-11-02-olb-by-RalfR-03.jpg/800px-13-11-02-olb-by-RalfR-03.jpg

In [None]:
# Open the downloaded file and convert it to RGB mode
image_path = "data/library.jpg"
image = Image.open(image_path).convert("RGB")
# Preview the raw image before any transform is applied
plt.figure(figsize=(6, 6))
plt.imshow(image)

In [None]:
# Preprocessing pipeline: resize, take a 224x224 center crop, convert to a
# tensor and normalize with the same mean/std as IMAGENET_MEAN / IMAGENET_STD
image_transform = T.Compose(
    [
        T.Resize(224),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
# `image` is rebound from the PIL image to the transformed tensor
image = image_transform(image)

# For images, the model expects inputs of shape: B x 3 x T x H x W
print(image.shape)
# Add batch and time axes of size 1; indexing alternative:
# image = image[None, :, None, :, ...]
image = einops.repeat(image, "c h w -> b c t h w", b=1, t=1)
print(image.shape)

# Move to the selected device (GPU if available)
image = image.to(device)

In [None]:
# Create a mask on the image (play with the proportion to see the effect!)
# Fraction of patches hidden from the model (play with it to see the effect!)
proportion_of_mask = 0.9

# Mask is of shape [N, patch_layout] where patches are of shape 2x16x16
# so here the patch layout is 2//2 x 224//16 x 224//16 = 1 x 14 x 14
# Each entry is True with probability 1 - proportion_of_mask; convert_output
# copies the True patches from the original image
mask = torch.empty(1, 1, 14, 14, dtype=torch.bool).bernoulli_(1-proportion_of_mask)

# Move to the selected device (GPU if available)
mask = mask.to(device)

### Run the model 

The transformed image can be passed through the model to get a reconstruction

In [None]:
# Inference only: no gradients are needed
with torch.no_grad():
    
    # The trunk returns a class token and per-patch decoder features
    cls_token, decoder_patch_features = model.trunk(image, mask=mask)
    # The head turns the decoder features into the raw model output
    outcome = model.head(decoder_patch_features)

In [None]:
# Patch extent (time, height, width); matches the layout used for the mask
patch_shape = [2, 16, 16]

# Turn the raw model output into un-normalized frames
pred_imgs = convert_output(outcome, mask, image, patch_shape)

In [None]:
print(pred_imgs.shape)
# First batch item, first frame, all channels
pred_single = pred_imgs[0, :, 0].detach().cpu()

# convert_output already undid the normalization, so no rescaling here
show_image(pred_single, rescale=False)

In [None]:
# First batch item, first frame of the (still normalized) model input
image_single = image[0, :, 0].detach().cpu()

# The input is normalized, so undo the normalization for display
show_image(image_single, rescale=True)

# Reconstruction of Videos

Now, we'll load a video and use the OmniMAE model to reconstruct videos. 

## Load Model

We provide several pretrained OmniMAE models via manual download. Available models are described in [model zoo documentation](https://github.com/facebookresearch/omnivore/tree/main/omnimae).

Here we are selecting the base ViT model, which was pretrained on Something-Something v2 and ImageNet-1K.

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Constructor of the pretrained (not fine-tuned) ViT-Base model
from omnimae.omni_mae_model import vit_base_mae_pretraining

# Build the model, move it to the chosen device and switch to eval mode
model = vit_base_mae_pretraining().to(device).eval()

### Define the transformations for the input required by the model

Before passing the video into the model we need to apply some input transforms and sample a clip of the correct duration.

**Remark**: These are the transformations from the Omnivore model. They could be (and actually are) different for OmniMAE. However, for the purpose of this demo, they are sufficient.

In [None]:
# Number of frames kept by the temporal subsampling step below
num_frames = 160
# NOTE: sampling_rate and frames_per_second are only referenced by the
# commented-out clip_duration formula below
sampling_rate = 2
frames_per_second = 30

# clip_duration = (num_frames * sampling_rate) / frames_per_second

# Transform applied to the "video" entry of the clip dictionary
video_transform = ApplyTransformToKey(
    key="video",
    transform=T.Compose(
        [
            UniformTemporalSubsample(num_frames), 
            # Scale pixel values by 1/255 before normalization
            T.Lambda(lambda x: x / 255.0),  
            ShortSideScale(size=224),
            NormalizeVideo(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            # Omnivore transforms; presumably these yield 32-frame clips and
            # three 224x224 spatial crops each - confirm in omnivore.transforms
            TemporalCrop(frames_per_clip=32, stride=40),
            SpatialCrop(crop_size=224, num_crops=3),
        ]
    ),
)

### Load and visualize an example video
We can test the reconstruction on an example video from the Kinetics validation set, such as this [archery video](https://www.youtube.com/watch?v=3and4vWkW4s).

**Remark**: this is not a video from SSv2, but it allows us to simply test the inference with OmniMAE.

In [None]:
# Download the example video file
# Download the example clip; -O writes it to data/dance.mp4
!wget https://dl.fbaipublicfiles.com/omnivore/example_data/dance.mp4 -O data/dance.mp4

In [None]:
# Load the example video
# Example clip downloaded in the previous cell
video_path = "data/dance.mp4" 
# Alternative: the SSv2 clip from the classification section
# video_path = "data/74225.webm"

# Show the video in the notebook
Video.from_file(video_path, width=500)

In [None]:
# We crop the video to a smaller resolution and duration to save RAM
# Keep 1 second starting at second 12 (-ss 12 -t 1) and scale the width to 224
# to save RAM; -y overwrites an existing output file
!ffmpeg -y -ss 12 -i data/dance.mp4 -filter:v scale=224:-1 -t 1 -v 0 data/dance_cropped.mp4
# Alternative for the SSv2 clip:
# !ffmpeg -y -ss 0 -i data/74225.webm -filter:v scale=224:-1 -t 2 -v 0 data/74225_cropped.webm

# From here on, work with the cropped file
video_path = "data/dance_cropped.mp4" 
# video_path = "data/74225_cropped.webm"

Video.from_file(video_path, width=500)

In [None]:
# Initialize an EncodedVideo helper class
# Initialize an EncodedVideo helper class
video = EncodedVideo.from_path(video_path)

# Load the desired clip and specify the start and end duration.
# The start_sec should correspond to where the action occurs in the video
# NOTE: the cropped file was cut to 1 second above while end_sec is 2.0
video_data = video.get_clip(start_sec=0, end_sec=2.0)

# Apply the transform defined above to the "video" entry
video_data = video_transform(video_data)

# Extract the transformed video (the move to the device happens further down)
video_inputs = video_data["video"]

# Take the first clip and add a leading batch axis
# The model expects inputs of shape: B x C x T x H x W
print(video_inputs[0].shape)
video_input = video_inputs[0][None, ...]
print(video_input.shape)

# Move to the selected device (GPU if available)
video_input = video_input.to(device)

In [None]:
# Create a mask on the image (play with the proportion to see the effect!)
# Fraction of patches hidden from the model (play with it to see the effect!)
proportion_of_mask = 0.9

# Mask is of shape [N, patch_layout] where patches are of shape 2x16x16
# so here the patch layout is 32//2 x 224//16 x 224//16 = 16 x 14 x 14
# Each entry is True with probability 1 - proportion_of_mask; convert_output
# copies the True patches from the original video
mask = torch.empty(1, 16, 14, 14, dtype=torch.bool).bernoulli_(1-proportion_of_mask)

# Move to the selected device (GPU if available)
mask = mask.to(device)

### Run the model 

The transformed video can be passed through the model to get a reconstruction.

In [None]:
# Inference only: no gradients are needed
with torch.no_grad():
    
    # The trunk returns a class token and per-patch decoder features
    cls_token, decoder_patch_features = model.trunk(video_input, mask=mask)
    # The head turns the decoder features into the raw model output
    outcome = model.head(decoder_patch_features)

In [None]:
# Patch extent (time, height, width); matches the layout used for the mask
patch_shape = [2, 16, 16]

# Turn the raw model output into un-normalized frames
# (video_input was already moved to `device` in an earlier cell)
pred_imgs = convert_output(outcome, mask, video_input.to(device), patch_shape)

In [None]:
print(pred_imgs.shape)
# First batch item, last frame, all channels
pred_single = pred_imgs[0, :, -1].detach().cpu()

# convert_output already undid the normalization, so no rescaling here
show_image(pred_single, rescale=False)

In [None]:
# First batch item, last frame of the (still normalized) model input
video_single = video_input[0, :, -1].detach().cpu()

# The input is normalized, so undo the normalization for display
show_image(video_single, rescale=True)

--------------------------------------