### Global imports

In [3]:
import sys
import os
import subprocess
import struct
import skimage.transform, skimage.color
import numpy as np

### Local imports

In [4]:
project_path = !pwd
project_path = project_path[0]

# cd to the h264-extractor folder
lib_path = os.path.abspath(os.path.join(project_path, 'h264-extractor', 'openh264', 'info_shipout'))
if lib_path not in sys.path:
    sys.path.append(lib_path)

from slice_pb2 import Slice
from slice_pb2 import SliceType

### Hyperparameters

In [5]:
GOP_SIZE = 10

VIT1_DEPTH = 8
VIT1_PROJECTION_DIMENSION = 256
VIT1_NUM_MSA_HEADS = 8
VIT1_OUTPUT_DIMENSION = 256
VIT1_PATCH_SIZE = 16

VIT2_DEPTH = 4
VIT2_PROJECTION_DIMENSION = 64
VIT2_NUM_MSA_HEADS = 4
VIT2_OUTPUT_DIMENSION = 256 
VIT2_PATCH_SIZE = 16

In [6]:
class H264Extractor():
    def __init__(self, bin_filename, cache_dir):
        if not os.path.exists(bin_filename):
            raise FileNotFoundError(f'cannot locate the binary file "{bin_filename}", was it built?')
        self.bin_filename = bin_filename

        if not os.path.exists(cache_dir):
            os.makedirs(os.path.join(cache_dir))
        self.cache_dir = cache_dir


    def convert_to_h264(self, video_filename):
        """converts a given mp4 video file to an h264 annex b file.
        The h264 file will be saved in the cache folder provided in the constructor

        Args:
            video_filename (str): full path to the mp4 video file

        Raises:
            FileNotFoundError: if the h264 is not generated correctly

        Returns:
            str: full path to the h264 file
        """
        
        video_name = os.path.basename(video_filename).split('.')[0]

        h264_filename = os.path.join(self.cache_dir, video_name + '.h264')

        # extract h264 from the mp4 file using ffmpeg
        cp = subprocess.run(
                ['ffmpeg', '-y', '-i', video_filename, '-vcodec', 'copy', '-an', 
                '-bsf:v', 'h264_mp4toannexb', h264_filename],
                check=True
            )
        
        if not os.path.exists(h264_filename):
            raise FileNotFoundError(f'cannot locate the h264 file "{h264_filename}", it probably hasn\'t been generated')

        return h264_filename
    
    def extract_yuv_and_codes(self, h264_filename):
        # compute the filenames
        video_name = os.path.basename(h264_filename).split('.')[0]
        yuv_filename = os.path.join(self.cache_dir, video_name+ '.yuv') # YUV frames
        coded_data_filename = os.path.join(self.cache_dir, video_name + '.msg') # encoding parameters

        # run the extractor to get the yuv and coded data files
        cp = subprocess.run(
                [self.bin_filename, h264_filename, '--yuv_out', yuv_filename, '--info_out', coded_data_filename, '--n_threads', '0'],
                # for now, only setting threads to 0 is allowed; using other values can result in 
                # unexpected behaviors
                check=True
            )
        
        # remove the h264 file since it's not needed anymore
        # os.remove(h264_filename)
        # if os.path.exists(h264_filename):
        #     print(f'WARNING: could not remove the h264 file "{h264_filename}"')

        # raise exception if the files haven't been generated
        if not os.path.exists(yuv_filename) or not os.path.exists(coded_data_filename):
            raise FileNotFoundError(f'cannot locate the yuv and/or coded data files, they probably haven\'t been generated')
        return (yuv_filename, coded_data_filename)
    
    def clean_cache(self):
        if os.path.exists(self.cache_dir):
            for files in os.listdir(self.cache_dir):
                os.remove(self.cache_dir, files)
            os.rmdir(self.cache_dir)

In [7]:
class VideoHandler():   
    def __init__(self, video_filename, h264_filename, yuv_filename, coded_data_filename):
        if not os.path.exists(video_filename):
            raise FileNotFoundError(f'cannot locate the video file "{video_filename}"')
        # Original file properties
        self.filename = video_filename
        self.path = os.path.dirname(video_filename)
        self.name = os.path.basename(video_filename).split('.')[0]
        self.extension = os.path.basename(video_filename).split('.')[1]

        # Elaborated files properties
        if not os.path.exists(h264_filename):
            raise FileNotFoundError(f'cannot locate the h264 file "{h264_filename}"')
        self.h264_filename = h264_filename

        if not os.path.exists(yuv_filename):
            raise FileNotFoundError(f'cannot locate the yuv file "{yuv_filename}"')
        self.yuv_filename = yuv_filename

        if not os.path.exists(coded_data_filename):
            raise FileNotFoundError(f'cannot locate the coded data file "{coded_data_filename}"')
        self.coded_data_filename = coded_data_filename
    
    def get_rgb_frame(self, frame_number, width, height):
        # In YUV420 format, each pixel of the Y (luma) component is represented by 1 byte, while the U and V (chroma) components are subsampled, so each of them is represented by 0.25 bytes. Hence, the total size is (width * height * 1.5).
        width = 1280
        height = 720
        frame_size = int(width * height * 1.5)
        
        # Color space conversion constants
        U_MAX = 0.436
        V_MAX = 0.615

        with open(self.yuv_filename, 'rb') as yuv_sequence:
            # Read the frame at the specified frame number
            yuv_sequence.seek(frame_number * frame_size)
            
            y = np.frombuffer(yuv_sequence.read(width * height), dtype=np.uint8).reshape((height, width))
            u = np.frombuffer(yuv_sequence.read(width * height // 4), dtype=np.uint8).reshape((height // 2, width // 2))
            v = np.frombuffer(yuv_sequence.read(width * height // 4), dtype=np.uint8).reshape((height // 2, width // 2))

            # Rescale subsampled chroma components to the same size as the luma component
            y = skimage.img_as_float32(y)
            u = skimage.transform.rescale(u, 2.0, 1, anti_aliasing=False)
            v = skimage.transform.rescale(v, 2.0, 1, anti_aliasing=False)

            # Color space conversion
            u = (u * 2 * U_MAX) - U_MAX
            v = (v * 2 * V_MAX) - V_MAX

            # Convert to RGB
            yuv_frame = np.dstack([y, u, v])
            rgb_frame = skimage.color.yuv2rgb(yuv_frame)
        return rgb_frame
    
    def get_gop(self, gop_length: int, width: int = 0, height: int = 0):
        return Gop(self.filename, self.h264_filename, self.yuv_filename, self.coded_data_filename).extract_gop(gop_length, width, height)   

In [8]:
class Gop(VideoHandler):
    def __init__(self, video_filename, h264_filename, yuv_filename, coded_data_filename):
        super().__init__(video_filename, h264_filename, yuv_filename, coded_data_filename)

        self.gop = []
        self.length = 0

        self.intra_frame = None
        self.inter_frames = []
        self.frame_types = []
        self.mb_types = []
        self.luma_qps = []

    def _get_ep_file_iterator(self):
        with open(self.coded_data_filename, 'rb') as file:
            file_size = os.stat(self.coded_data_filename).st_size
            while file.tell() < file_size:
                length_bytes = file.read(4)
                # Interpret data as little-endian unsigned int to convert from C layer to Python
                length = struct.unpack('<I', length_bytes)[0]
                yield file.read(length)

    def _get_slice_iterator(self):
        iterator = self._get_ep_file_iterator()
        for bytes in iterator:
            slice = Slice()
            slice.ParseFromString(bytes)
            yield slice

    def _extract_features(self):
        if self.length == 0 or self.gop is None:
            raise ValueError('GOP not extracted yet')
        
        for slice, frame_number in self.gop:
            self.frame_types.append(slice.type)
            if slice.type == SliceType.I:
                self.intra_frame = self.get_rgb_frame(frame_number, slice.width, slice.height)
                # include difference between I frame and itself (zeros)
                self.inter_frames.append(self.intra_frame - self.intra_frame)
            else:
                self.inter_frames.append(self.get_rgb_frame(frame_number, slice.width, slice.height) - self.intra_frame) # abs()?
            for mb in slice.mbs:
                self.mb_types.append(mb.type)
                self.luma_qps.append(mb.luma_qp)

        return self

    def extract_gop(self, target_length: int, width: int = 0, height: int = 0) -> list:
        # TODO: how to crop?
        slice_iterator = self._get_slice_iterator()

        slice = next(slice_iterator)
        frame_index = 0
        if slice.type != SliceType.I:
            # The first slice is expected to be of type Intra. Find next I slice to start gop
            while(slice.type != SliceType.I):
                try:
                    slice = next(slice_iterator)
                    frame_index += 1
                except StopIteration:
                    raise ValueError('No Intra slice found')

        self.gop.append((slice, frame_index))

        while len(self.gop) < target_length:
            try:
                slice = next(slice_iterator)
                frame_index += 1
                if slice.type == SliceType.I:
                    # GOP is over
                    break
                else:
                    self.gop.append((slice, frame_index))
            except StopIteration:
                if len(self.gop) < target_length:
                    raise ValueError(f'Unable to reach desired GOP length of {target_length}, actual gop length is {len(self.gop)}')
                else:
                    break
                
        self.length = len(self.gop)
        self._extract_features()

        return self

In [9]:
bin_path = os.path.abspath(os.path.join(project_path, 'h264-extractor', 'bin'))
h264_ext_bin = os.path.join(bin_path, 'h264dec_ext_info')

extractor = H264Extractor(h264_ext_bin, os.path.join(project_path, '.cache'))

video_filename = os.path.join(project_path, 'bunny.mp4')

h264_filename = extractor.convert_to_h264(video_filename)
yuv_filename, coded_data_filename = extractor.extract_yuv_and_codes(h264_filename)

ffmpeg version 4.4.2-0ubuntu0.22.04.1 Copyright (c) 2000-2021 the FFmpeg developers
  built with gcc 11 (Ubuntu 11.2.0-19ubuntu1)
  configuration: --prefix=/usr --extra-version=0ubuntu0.22.04.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librabbitmq --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libsrt --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enab

------------------------------------------------------


In [10]:
video_handler = VideoHandler(video_filename, h264_filename, yuv_filename, coded_data_filename)


gop = video_handler.get_gop(GOP_SIZE)

print(f'GOP length: {gop.length}')
print(gop.intra_frame.shape)
for i, inter_frame in enumerate(gop.inter_frames):
    print(f'Frame {i} shape: {inter_frame.shape}')
print(f'Frame types {gop.frame_types}')
print(f'Macroblock types {gop.mb_types}')
print(f'Luma QPs {gop.luma_qps}')

GOP length: 10
(720, 1280, 3)
Frame 0 shape: (720, 1280, 3)
Frame 1 shape: (720, 1280, 3)
Frame 2 shape: (720, 1280, 3)
Frame 3 shape: (720, 1280, 3)
Frame 4 shape: (720, 1280, 3)
Frame 5 shape: (720, 1280, 3)
Frame 6 shape: (720, 1280, 3)
Frame 7 shape: (720, 1280, 3)
Frame 8 shape: (720, 1280, 3)
Frame 9 shape: (720, 1280, 3)
Frame types [0, 1, 2, 2, 2, 1, 2, 2, 2, 1]
Macroblock types [4, 4, 4, 4, 4, 4, 1, 1, 1, 4, 1, 4, 4, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 1, 1, 1, 4, 1, 4, 4, 4, 4, 4, 4, 4, 4, 1, 4, 1, 4, 4, 4, 1, 1, 4, 4, 4, 4, 1, 1, 4, 4, 1, 4, 4, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 1, 4, 4, 4, 1, 4, 4, 4, 1, 4, 4, 4, 4, 4, 4, 1, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 1, 4, 4, 4, 1, 1, 4, 4, 4, 1, 4, 4, 4, 1, 4, 4, 4, 4, 1, 1, 4, 1, 4, 4, 4, 4, 1, 1, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 1, 1, 4, 1, 1, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 4, 1, 4, 1, 1, 4, 4, 4, 4, 4, 4, 4, 

In [11]:
def intra_preprocessing(frame):
    #vit1
    pass

def diff_preprocessing(frame):
    #vit1
    pass

def frame_types_preprocessing(frame_types):
    #embedding
    pass

def mb_types_preprocessing(mb_types):
    #embedding
    #vit2
    pass

def luma_qps_preprocessing(luma_qps):
    #vit2
    pass

In [12]:
import matplotlib.pyplot as plt
plt.rcParams['figure.dpi'] = 120
from matplotlib.patches import Rectangle
from matplotlib.lines import Line2D
import numpy as np

plt.imshow(rgb_frame)
plt.title('Frame {} decoded'.format(FRAME_NUM))

NameError: name 'rgb_frame' is not defined

In [None]:
# now, visualize the encoding parameters

# first we need to get the size of the macroblocks from the enums

import re
from slice_pb2 import MacroblockType

# parse the size of macroblocks from the enum names
mb_size_regex = re.compile('([0-9]{1,2})x([0-9]{1,2})')
mb_size_dict = dict()
# Iterate through all the enum entries to build a dictionary of macroblock sizes
for key, val in MacroblockType.items():
    search_result = mb_size_regex.search(key)
    if search_result is not None:
        mb_size_x = int(search_result.group(1))
        mb_size_y = int(search_result.group(2))
        mb_size_dict[key] = (mb_size_x, mb_size_y)
    else:
        mb_size_dict[key] = (16, 16)

# show some entires
for key in list(mb_size_dict.keys())[:10]:
    print(f'{key}={mb_size_dict[key]}')

In [None]:
# now, begin the visualization
plt.imshow(rgb_frame)
ax = plt.gca()

mb_color_cycle = plt.get_cmap('Set1')
mb_alpha = 0.5
mb_colors = np.asarray([mb_color_cycle(i) for i in range(5)])
mb_colors[:, 3] = mb_alpha
mb_labels = ['INTRA', 'DIRECT_SKIP', 'SKIP', 'DIRECT', 'INTER']
line_color = 'black'
line_width = 0.2

for mb in slice.mbs:
    mb_type = MacroblockType.Name(mb.type)
    mb_size = mb_size_dict[mb_type]

    mb_label_index = None

    # determine the color
    if 'INTRA' in mb_type:
        mb_label_index = 0
    elif 'SKIP' in mb_type and 'DIRECT' in mb_type:
        mb_label_index = 1
    elif 'SKIP' in mb_type:
        mb_label_index = 2
    elif 'DIRECT' in mb_type:
        mb_label_index = 3
    else:
        mb_label_index = 4

    color = mb_colors[mb_label_index]
    
    # compute the lower left corner of the macroblock
    mb_x = mb.x * 16
    mb_y = mb.y * 16

    patches = []

    if mb_size == (16, 16):
        patches.append(Rectangle((mb_x, mb_y), 16, 16, 
            facecolor=color, edgecolor=line_color, linewidth=line_width))
    elif mb_size == (8, 16):
        patches.append(Rectangle((mb_x, mb_y), 8, 16, 
            facecolor=color, edgecolor=line_color, linewidth=line_width))
        patches.append(Rectangle((mb_x + 8, mb_y), 8, 16, 
            facecolor=color, edgecolor=line_color, linewidth=line_width))
    elif mb_size == (16, 8):
        patches.append(Rectangle((mb_x, mb_y), 16, 8, 
            facecolor=color, edgecolor=line_color, linewidth=line_width))
        patches.append(Rectangle((mb_x, mb_y+8), 16, 8, 
            facecolor=color, edgecolor=line_color, linewidth=line_width))
    elif mb_size == (8, 8):
        for i in range(2):
            for j in range(2):
                patches.append(Rectangle((mb_x+i*8, mb_y+j*8), 8, 8, 
                    facecolor=color, edgecolor=line_color, linewidth=line_width))
    elif mb_size == (4,4):
        for i in range(4):
            for j in range(4):
                patches.append(Rectangle((mb_x+i*4, mb_y+j*4), 4, 4, 
                    facecolor=color, edgecolor=line_color, linewidth=line_width))
    else:
        raise ValueError(f'unsupported macroblock size {mb_size}')

    for patch in patches:
        ax.add_patch(patch)

# generate the legend
custom_legends = [Line2D([0], [0], color=x, lw=3) for x in mb_colors]
plt.legend(custom_legends, mb_labels, bbox_to_anchor=(1.1, 1.05))
plt.title('Macroblock Partition and Type')