### Global imports

In [None]:
import os
import torch
from packages.video_utils import H264Extractor, Video
from packages.constants import GOP_SIZE, FRAME_HEIGHT, FRAME_WIDTH, DATASET_ROOT, N_GOPS_FROM_DIFFERENT_DEVICE, N_GOPS_FROM_SAME_DEVICE
from packages.dataset import VisionGOPDataset, GopPairDataset
from packages.common import create_custom_logger
from packages.network import H4vdmNet

In [None]:
# Fail fast when the dataset directory is missing: every later cell reads
# from DATASET_ROOT. FileNotFoundError is a subclass of Exception, so any
# existing `except Exception` handling keeps working.
if not os.path.exists(DATASET_ROOT):
    raise FileNotFoundError(f'Dataset root does not exist: {DATASET_ROOT}')

# Logger used by the setup cells of this notebook.
log = create_custom_logger('h4vdm.ipynb')

Remember to delete `dataset.json` if you want to add new devices or videos.

In [None]:
# The extractor binary is looked up under ./h264-extractor/bin of the current
# working directory (abspath resolves a relative path against the cwd).
bin_path = os.path.abspath(os.path.join('h264-extractor', 'bin'))
h264_ext_bin = os.path.join(bin_path, 'h264dec_ext_info')

# One extractor instance, with DATASET_ROOT as its cache directory, is
# registered on the Video class.
h264_extractor = H264Extractor(
    cache_dir=DATASET_ROOT,
    bin_filename=h264_ext_bin,
)
Video.set_h264_extractor(h264_extractor)

# Describe the GOP dataset; nothing is built or downloaded at construction
# time because the corresponding *_on_init flags are switched off.
dataset = VisionGOPDataset(
    root_path=DATASET_ROOT,
    # Selection filters.
    # NOTE(review): an empty `devices` list presumably means "no device
    # filter" - confirm against VisionGOPDataset.
    devices=[],
    media_types=['videos'],
    properties=['flat'],
    extensions=['mp4'],
    # GOP geometry and sampling.
    gop_size=GOP_SIZE,
    frame_width=FRAME_WIDTH,
    frame_height=FRAME_HEIGHT,
    gops_per_video=4,
    shuffle=False,
    # Build / download behaviour.
    build_on_init=False,
    force_rebuild=False,
    download_on_init=False,
    ignore_local_dataset=False,
)

# Ask the dataset to load itself; the truthiness of the result tells us
# whether loading succeeded.
is_loaded = dataset.load()
if is_loaded:
    log.info('Dataset was loaded.')
else:
    # NOTE(review): the message announces a build, but no build call follows
    # in this cell - confirm whether load() builds on failure or a build step
    # is missing here.
    log.info('Dataset was not loaded. Building...')

print(f'Dataset length: {len(dataset)}')

In [None]:
# Pair view over the GOP dataset: indexing it yields (gop1, gop2, label),
# as consumed by the training loop below.
pair_dataset = GopPairDataset(
    dataset,
    N_GOPS_FROM_SAME_DEVICE,
    N_GOPS_FROM_DIFFERENT_DEVICE,
    shuffle=True,
)

Build all GOPs so that the cache can be cleaned.

In [None]:
# Disabled one-off pass, kept for reference. When uncommented it walks every
# video of every device, requests its GOPs, then cleans the H.264 extractor
# cache and drops the references to the video and its GOPs.
# for device in dataset.get_devices():
#     for video_metadata in dataset.dataset[device]:
#         video = dataset._get_video_from_metadata(video_metadata)
#         gops = video.get_gops()

#         Video.h264_extractor.clean_cache()
#         video = None
#         gops = None

### Network

In [None]:
# Logger for the network cells; it is created with the same name
# ('h4vdm.ipynb') as the `log` logger from the setup cell.
logger = create_custom_logger('h4vdm.ipynb')

def compute_similarity(gop1_features, gop2_features):
    """Score how close two GOP feature tensors are.

    The score is ``1 - tanh(d)``, where ``d`` is the 2-norm of the
    element-wise difference taken over all elements. Identical features
    give 1, and the score decays towards 0 as the distance grows.

    Args:
        gop1_features: Feature tensor of the first GOP.
        gop2_features: Feature tensor of the second GOP; must be
            broadcastable against ``gop1_features`` for subtraction.

    Returns:
        A 0-dim tensor holding the similarity score.
    """
    distance = torch.norm(gop1_features - gop2_features, p=2)
    return 1 - torch.tanh(distance)

# Binary cross-entropy criterion. torch.nn.BCELoss expects its input to be
# probabilities in [0, 1] (not logits) and a target of the same shape.
compute_loss = torch.nn.BCELoss()

In [None]:
# Train H4vdmNet on GOP pairs, one pair per optimisation step.
net = H4vdmNet()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

tot = len(pair_dataset)
print(f'Training dataset length: {tot}')
for i in range(tot):
    gop1, gop2, label = pair_dataset[i]
    print(f'Iteration {i}/{tot} - Gop1: {gop1.video_name} Gop2: {gop2.video_name} Label: {label}')

    # Reset gradients at every step; otherwise the gradients of all previous
    # pairs keep accumulating into the current update.
    optimizer.zero_grad()

    gop1_features = net(gop1, debug=False)
    gop2_features = net(gop2, debug=False)

    similarity = compute_similarity(gop1_features, gop2_features)

    # Build the two-element prediction with torch.stack so it stays attached
    # to the autograd graph. Re-wrapping the values with torch.tensor(...)
    # creates a new leaf tensor, which stops backward() from ever reaching
    # the network parameters.
    similarity = torch.stack([similarity, 1 - similarity])
    # The target is constant data: no gradient is needed, and it must match
    # the prediction's dtype and device for BCELoss.
    label = torch.tensor(
        [label, 1 - label],
        dtype=similarity.dtype,
        device=similarity.device,
    )
    loss = compute_loss(similarity, label)
    loss.backward()

    optimizer.step()

    print(f'Loss: {loss.item()}')
    print('\n')

print('Done')
