### Global imports

In [None]:
import sys
import os
from random import randint
import skimage.color
from packages.video_utils import H264Extractor, Video, Gop
from packages.constants import GOP_SIZE, FRAME_HEIGHT, FRAME_WIDTH, DATASET_ROOT, MACROBLOCK_SIZE
from packages.dataset import VisionDataset, VisionGOPDataset
from packages.common import create_custom_logger

In [None]:
# Fail fast if the dataset root imported from packages.constants is missing.
if not os.path.exists(DATASET_ROOT):
    # FileNotFoundError is more specific than a bare Exception, yet still a
    # subclass of Exception, so existing `except Exception` handlers keep working.
    raise FileNotFoundError(f'Dataset root does not exist: {DATASET_ROOT}')

log = create_custom_logger('h4vdm.ipynb')

**Note:** delete `dataset.json` if you want to add new devices or videos.

In [None]:
# Point Video at the h264 extractor binary in ./h264-extractor/bin; the
# extractor's cache directory is set to DATASET_ROOT.
bin_path = os.path.abspath(os.path.join(os.getcwd(), 'h264-extractor', 'bin'))
h264_ext_bin = os.path.join(bin_path, 'h264dec_ext_info')
h264_extractor = H264Extractor(bin_filename=h264_ext_bin, cache_dir=DATASET_ROOT)
Video.set_h264_extractor(h264_extractor)

dataset = VisionGOPDataset(
    root_path=DATASET_ROOT,
    # Media selection. NOTE(review): devices=[] presumably means "no device
    # filter" — confirm against VisionGOPDataset.
    devices=[],
    media_types=['videos'],
    properties=['flat'],
    extensions=['mp4'],
    # GOP geometry and how many GOPs to take from each video.
    gop_size=GOP_SIZE,
    frame_width=FRAME_WIDTH,
    frame_height=FRAME_HEIGHT,
    gops_per_video=4,
    # Lifecycle flags (all disabled).
    build_on_init=False,
    force_rebuild=False,
    download_on_init=False,
    ignore_local_dataset=False,
    shuffle=False,
)

is_loaded = dataset.load()
# NOTE(review): the "Building..." message is only logged — nothing in this cell
# triggers a build (build_on_init=False). Confirm the build happens elsewhere.
log.info('Dataset was loaded.' if is_loaded else 'Dataset was not loaded. Building...')

print(f'Dataset length: {len(dataset)}')

In [None]:
# for i in range(0, len(dataset)):
#     video = dataset[i]
#     print(f'{i} - {video.name}')

In [None]:
# for i in range(0, len(dataset)):
#     try:
#         video = dataset[i]
#         # print(f'Video {i} GOPs: {len(video.get_gops())}')
#     except Exception as e:
#         log.error(f'Error while processing video {i}: {e}')
#         raise e


#     gops = video.get_gops()

#     for gop in gops:
#         print(f'Number of GOPs in video: {len(gops)}')
#         print(f'Frame 0 shape: {gop.get_intra_frame().shape}')
#         for i, inter_frame in enumerate(gop.get_inter_frames()):
#             print(f'Frame {i + 1} shape: {inter_frame.shape}')
#         print(f'Frame types {gop.get_frame_types()}')
#         print(f'Macroblock types shape: {len(gop.get_macroblock_images()[0])} x {gop.get_macroblock_images()[0].shape}')
#         print(f'Luma QPs {len(gop.get_luma_qp_images()[0])}')
# dataset.save('test_dataset_after.json')

In [None]:
# import matplotlib.pyplot as plt
# plt.rcParams['figure.dpi'] = 120

# fig = plt.figure(figsize=(15, 15))

# gop = gops[0]

# for frame_id in range(gop.get_first_frame_number(), gop.get_first_frame_number() + len(gop)):
#     fig.add_subplot(len(gop)//2, len(gop)//2, frame_id + 1)
#     plt.imshow(gop.get_rgb_frame(frame_id))
#     plt.title('Frame {}'.format(frame_id))

Build all GOPs so that the H.264 extractor cache can be cleaned afterwards.

In [None]:
# for device in dataset.get_devices():
#     for video_metadata in dataset.dataset[device]:
#         video = dataset._get_video_from_metadata(video_metadata)
#         gops = video.get_gops()

#         Video.h264_extractor.clean_cache()
#         video = None
#         gops = None

### Network

HAN input size: `embed_dim = 4 * 8 + 5 = 37`
`num_heads = HAN_N_HEADS`

In [None]:
from packages.network import H4vdmNet
import torch
# `log` is deliberately NOT imported from math: `from math import log` would
# rebind the module-level logger `log` created earlier in this notebook, and
# any earlier cell calling `log.info(...)` would then fail on re-run.
from math import tanh, sqrt
from random import random

logger = create_custom_logger('h4vdm.ipynb')

def compute_similarity(gop1_features, gop2_features):
    """Return a similarity score in [0, 1] between two GOP feature tensors.

    The score is ``1 - tanh(||gop1_features - gop2_features||_2)``, with the
    2-norm taken over all elements of the difference. Identical features
    score 1, and the score decays towards 0 as the distance grows.

    ``torch.tanh`` is used instead of ``math.tanh``. ``math.tanh`` converts
    the norm to a Python float, which detaches the result from the autograd
    graph, so no gradient could flow back into the network during training.

    Args:
        gop1_features: feature tensor of the first GOP.
        gop2_features: feature tensor of the second GOP; must be broadcastable
            against ``gop1_features``.

    Returns:
        A 0-d tensor. Use ``float(...)`` or ``.item()`` for a plain number.
    """
    distance = torch.norm(gop1_features - gop2_features, p=2)
    return 1 - torch.tanh(distance)

# Loss for the (currently commented-out) training loop.
# NOTE(review): compute_similarity yields a single score in [0, 1], whereas
# CrossEntropyLoss expects per-class logits; a binary loss (e.g. BCELoss) may
# be what is intended — confirm before enabling training.
compute_loss = torch.nn.CrossEntropyLoss()

# Every entry unpacks as (gop1_path, gop2_path, label).
training_dataset_paths = dataset._build_gop_pair_dataset()
print(f'Training dataset length: {len(training_dataset_paths)}')

# Sanity check: load both GOPs of each pair and log their source videos.
for gop1_path, gop2_path, label in training_dataset_paths:
    gop1, gop2 = Gop.load(gop1_path, None), Gop.load(gop2_path, None)
    logger.info(f'Gop1: {gop1.video_name} Gop2: {gop2.video_name} Label: {label}')


# net = H4vdmNet()
# optimizer = torch.optim.Adam(net.parameters(), lr=0.001)



# for gop1, gop2, label in training_dataset:
#     label = 1 if label else 0

#     gop1_features = net(gop1)
#     gop2_features = net(gop2)

#     optimizer.zero_grad()

#     similarity = compute_similarity(gop1_features, gop2_features)
#     similarity = torch.tensor(similarity)
#     label = torch.tensor(label, dtype=float)
#     loss = compute_loss(similarity, label)
#     loss.bakcward()

#     optimizer.step()

#     print(f'Loss: {loss.item()}')

    

# # net.forward(gop, debug=True)