In [1]:
import subprocess
from pathlib import Path

import torch
import torchaudio
import torchvision
from omegaconf import OmegaConf

from dataset.dataset_utils import get_video_and_audio
from dataset.transforms import make_class_grid, quantize_offset
from utils.utils import check_if_file_exists_else_download, which_ffmpeg
from scripts.train_utils import get_model, get_transforms, prepare_inputs


def reencode_video(path, vfps=25, afps=16000, in_size=256):
    """Re-encodes a video with ffmpeg and dumps its audio track into a mono 16-bit PCM `.wav`.

    The re-encoded video is written to `<cwd>/vis/<stem>_<vfps>fps_<in_size>side_<afps>hz.mp4` and the
    audio to a file with the same name but the `.wav` extension. Existing files are overwritten (`-y`).

    Args:
        path: path to the input video.
        vfps: target video frame rate.
        afps: target audio sampling rate.
        in_size: target size of the shorter side (the aspect ratio is kept by the `scale` filter).

    Returns:
        str: path to the re-encoded `.mp4` file.

    Raises:
        AssertionError: if `which_ffmpeg()` returns an empty string.
    """
    ffmpeg = which_ffmpeg()
    assert ffmpeg != '', 'Is ffmpeg installed? Check if the conda environment is activated.'
    new_path = Path.cwd() / 'vis' / f'{Path(path).stem}_{vfps}fps_{in_size}side_{afps}hz.mp4'
    new_path.parent.mkdir(exist_ok=True)
    # swap only the final extension (a plain `str.replace` would also touch '.mp4' inside folder names)
    wav_path = new_path.with_suffix('.wav')

    # no info/error printing
    quiet = ['-hide_banner', '-loglevel', 'panic']
    # 1) change fps, 2) resize: min(H,W)=MIN_SIDE (vertical vids are supported),
    # 3) crop both sides down to an even number of pixels
    vfilter = (
        f"fps={vfps},"
        f"scale=iw*{in_size}/'min(iw,ih)':ih*{in_size}/'min(iw,ih)',"
        f"crop='trunc(iw/2)'*2:'trunc(ih/2)'*2"
    )

    # The commands are built as argument lists (not a string that is split on whitespace) so that
    # paths containing spaces reach ffmpeg as a single argument.
    reencode_cmd = [ffmpeg, *quiet, '-y', '-i', str(path), '-vf', vfilter, '-ar', str(afps), str(new_path)]
    extract_cmd = [ffmpeg, *quiet, '-y', '-i', str(new_path), '-acodec', 'pcm_s16le', '-ac', '1', str(wav_path)]

    for target, cmd in ((new_path, reencode_cmd), (wav_path, extract_cmd)):
        # ffmpeg itself is silenced above, so a non-zero exit code is the only sign of a failure
        return_code = subprocess.call(cmd)
        if return_code != 0:
            print(f'WARNING: ffmpeg exited with code {return_code} while writing {target}')

    return str(new_path)


def decode_single_video_prediction(off_logits, grid, item):
    """Prints the ground truth offset and the (up to) 5 highest-scoring offset classes of one video.

    Args:
        off_logits: offset logits with a leading batch dimension of size 1.
        grid: offset class grid; it is indexed with the predicted class indices.
        item: transformed item; `item['targets']['offset_label']` is read as the ground truth offset.

    Returns:
        The softmax probabilities over the offset classes with the batch dimension removed.
    """
    gt_offset = item['targets']['offset_label'].item()
    gt_class = quantize_offset(grid, gt_offset)[-1].item()
    print('Ground Truth offset (sec):', f'{gt_offset:.2f} ({gt_class})')
    print('Prediction Results:')

    off_probs = torch.softmax(off_logits, dim=-1)
    # show at most 5 classes (fewer if the head has fewer outputs)
    num_shown = min(off_probs.shape[-1], 5)
    best_logits, best_classes = torch.topk(off_logits, num_shown)
    assert len(best_logits) == 1, 'batch is larger than 1'

    # drop the batch dimension
    sample_logits, sample_probs = off_logits[0], off_probs[0]
    for cls_idx in best_classes[0]:
        prob, logit, offset = sample_probs[cls_idx], sample_logits[cls_idx], grid[cls_idx]
        print(
            f'p={prob:.4f} ({logit:.4f}), "{offset:.2f}" ({cls_idx})')
    return sample_probs


def patch_config(cfg):
    """Adapts a loaded experiment config for inference from a full model checkpoint.

    The config is modified in place and returned:
      * `ckpt_path` of both the audio and the visual feature extractor is set to `None`;
      * '.modules.feature_selector.' is replaced with '.sync_model.' in the transformer `target`.
    """
    model_params = cfg.model.params

    # the feature extractor weights are already a part of the model checkpoint
    for extractor in (model_params.afeat_extractor, model_params.vfeat_extractor):
        extractor.params.ckpt_path = None

    # old checkpoints refer to the transformer by its former module path
    transformer = model_params.transformer
    transformer.target = transformer.target.replace('.modules.feature_selector.', '.sync_model.')
    return cfg


In [2]:
# Expected properties of the input videos: a video that does not match them is re-encoded
# with `reencode_video` before inference (see the loop over `to_process`).
# video frame rate
vfps = 25
# audio sampling rate
afps = 16000
# size of the shorter side of a frame: min(H, W)
in_size = 256
# name of the experiment; the config and checkpoint paths are derived from it
exp_name = '24-01-04T16-39-21'

In [3]:
# use the GPU if CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load the model
# both the config and the checkpoint are looked up in `./logs/sync_models/<exp_name>/`
cfg_path = f'./logs/sync_models/{exp_name}/cfg-{exp_name}.yaml'
ckpt_path = f'./logs/sync_models/{exp_name}/{exp_name}.pt'

# if the model does not exist try to download it from the server
check_if_file_exists_else_download(cfg_path)
check_if_file_exists_else_download(ckpt_path)

# load config
cfg = OmegaConf.load(cfg_path)

# patch config (resets the feature extractor `ckpt_path`s and renames the transformer target)
cfg = patch_config(cfg)

# only the second value returned by `get_model` (the model) is needed here
_, model = get_model(cfg, device)
# NOTE(review): `torch.load` unpickles the file and the checkpoint may have just been downloaded;
# load checkpoints only from a source you trust.
# The checkpoint is read to the CPU memory first; the weights are stored under the 'model' key.
ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
model.load_state_dict(ckpt['model'])
# switch the model to the evaluation mode
model.eval()
print('Model loaded.')

Model loaded.


In [4]:
# list of items to process. Mind the order: (video_path, offset_sec, v_start_i_sec)
#   video_path: path to the video file
#   offset_sec: by how many seconds the start of the audio track is shifted (the sign matters)
#   v_start_i_sec: the second at which the clip cropped from the video starts
to_process = [
    ('./data/vggsound/h264_video_25fps_256side_16000hz_aac/3qesirWAGt4_20000_30000.mp4', 1.6, 0.0),
    ('./data/vggsound/h264_video_25fps_256side_16000hz_aac/ZYc410CE4Rg_0_10000.mp4', -2.0, 4.0),
]

In [5]:
# making the offset class grid similar to the one used in transforms.
# It depends only on the config, so it is built once for all videos.
max_off_sec = cfg.data.max_off_sec
num_cls = cfg.model.params.transformer.params.off_head_cfg.params.out_features
grid = make_class_grid(-max_off_sec, max_off_sec, num_cls)

for vid_path, offset_sec, v_start_i_sec in to_process:
    # (optional) checking if the provided video has the correct frame rates
    print(f'Using video: {vid_path}')
    v, _, info = torchvision.io.read_video(vid_path, pts_unit='sec')
    _, H, W, _ = v.shape

    # per the torchvision docs `info` *can* contain `video_fps` and `audio_fps`, i.e. they may be
    # absent; fail with a readable message instead of a bare KeyError
    missing_keys = [key for key in ('video_fps', 'audio_fps') if key not in info]
    if missing_keys:
        raise ValueError(
            f'Could not read {missing_keys} from {vid_path}. Does it have both a video and an audio stream?')

    if info['video_fps'] != vfps or info['audio_fps'] != afps or min(H, W) != in_size:
        print(f'Reencoding. vfps: {info["video_fps"]} -> {vfps};', end=' ')
        print(f'afps: {info["audio_fps"]} -> {afps};', end=' ')
        print(f'{(H, W)} -> min(H, W)={in_size}')
        vid_path = reencode_video(vid_path, vfps, afps, in_size)
    else:
        print(
            f'No need to reencode: vfps: {info["video_fps"]}; afps: {info["audio_fps"]}; min(H, W)={in_size}')

    # load visual and audio streams
    # rgb: (Tv, 3, H, W) in [0, 255], audio: (Ta,) in [-1, 1]
    # NOTE(review): the upper bound used to read 225, presumably a typo -- confirm
    rgb, audio, meta = get_video_and_audio(vid_path, get_meta=True)

    # making an item (dict) to apply transformations
    # NOTE: here is how it works:
    # For instance, if the model is trained on 5sec clips, the provided video is 9sec, and `v_start_i_sec=1.3`
    # the transform will crop out a 5sec-clip from 1.3 to 6.3 seconds and shift the start of the audio
    # track by `offset_sec` seconds. It means that if `offset_sec` > 0, the audio will
    # start by `offset_sec` earlier than the rgb track.
    # It is a good idea to use something in [-`max_off_sec`, `max_off_sec`] (-2, +2) seconds (see `grid`)
    item = dict(
        video=rgb, audio=audio, meta=meta, path=vid_path, split='test',
        targets={'v_start_i_sec': v_start_i_sec, 'offset_sec': offset_sec, },
    )

    if not (min(grid) <= item['targets']['offset_sec'] <= max(grid)):
        print(f'WARNING: offset_sec={item["targets"]["offset_sec"]} is outside the trained grid: {grid}')

    # applying the test-time transform
    # (built anew for every video, as before, in case the transform keeps an internal state)
    item = get_transforms(cfg, ['test'])['test'](item)

    # prepare inputs for inference
    batch = torch.utils.data.default_collate([item])
    aud, vid, targets = prepare_inputs(batch, device)

    # TODO:
    # sanity check: we will take the input to the `model` and reconstruct a video from it.
    # Use this check to make sure the input makes sense (audio should be ok but shifted as you specified)
    # reconstruct_video_from_input(aud, vid, batch['meta'], vid_path, v_start_i_sec, offset_sec,
    #                              vfps, afps)

    # forward pass (no gradients are needed for inference)
    with torch.no_grad():
        with torch.autocast('cuda', enabled=cfg.training.use_half_precision):
            _, logits = model(vid, aud)

    # simply prints the results of the prediction
    decode_single_video_prediction(logits, grid, item)
    print()

Using video: ./data/vggsound/h264_video_25fps_256side_16000hz_aac/3qesirWAGt4_20000_30000.mp4
No need to reencode: vfps: 25.0; afps: 16000; min(H, W)=256
Ground Truth offset (sec): 1.60 (18)
Prediction Results:
p=0.8076 (11.5469), "1.60" (18)
p=0.1760 (10.0234), "1.80" (19)
p=0.0067 (6.7617), "-0.40" (8)
p=0.0042 (6.2891), "1.40" (17)
p=0.0033 (6.0586), "2.00" (20)

Using video: ./data/vggsound/h264_video_25fps_256side_16000hz_aac/ZYc410CE4Rg_0_10000.mp4
No need to reencode: vfps: 25.0; afps: 16000; min(H, W)=256
Ground Truth offset (sec): -2.00 (0)
Prediction Results:
p=0.8291 (12.7734), "-2.00" (0)
p=0.1194 (10.8359), "-1.80" (1)
p=0.0419 (9.7891), "-1.60" (2)
p=0.0072 (8.0234), "-1.40" (3)
p=0.0013 (6.2969), "-1.20" (4)

