In [None]:
%load_ext autoreload
%autoreload 2

: 

In [52]:
import os

# Work from the pipeline repository root; presumably this is what lets the next
# cell's `from config import config` / `from utils import logger` resolve.
# The default is a machine-specific absolute path, so allow overriding it via
# the KEYPOINT_PIPELINE_ROOT environment variable instead of editing the notebook.
PIPELINE_ROOT = os.environ.get(
    'KEYPOINT_PIPELINE_ROOT', '/home/stud/ath/ath_ws/keypoint_dataset_pipeline'
)
os.chdir(PIPELINE_ROOT)

In [53]:
from config import config
from utils import logger

import numpy as np
import h5py
import random
import torch

In [54]:
def print_hdf5_structure(reader):
    """Print the name of every HDF5 group reachable from ``reader``.

    Args:
        reader: Either an object exposing an open h5py file as ``_file``
            (such as ``StatsReader``), or an ``h5py.Group`` / ``h5py.File``
            passed directly. Only groups are printed (one ``Group: <name>``
            line each, in ``visititems`` order); datasets are skipped.
    """
    # h5py.File subclasses h5py.Group, so this accepts files and groups as-is
    # and falls back to the reader's wrapped file handle otherwise.
    root = reader if isinstance(reader, h5py.Group) else reader._file

    def _print_if_group(name, obj):
        if isinstance(obj, h5py.Group):
            print(f"Group: {name}")
        # Returning None tells visititems() to keep walking the whole tree.

    root.visititems(_print_if_group)

In [55]:
class StatsReader:
    """Read-only view onto the pipeline's ``data.hdf5`` output for one camera.

    Opens ``f'{config.paths[config.task.name].output}/data.hdf5'`` (or an
    explicit ``filepath``) and binds the detector / matcher / filter / matches
    entries under the camera named by ``config.task.cam`` *at construction
    time* as attributes.

    Supports the context-manager protocol so the HDF5 handle is always closed::

        with StatsReader() as reader:
            ...
    """

    def __init__(self, train=False, filepath=None):
        """Open the HDF5 file read-only and bind the per-camera entries.

        Args:
            train: Unused; kept so existing call sites keep working.
            filepath: Optional explicit path to the HDF5 file. Defaults to
                ``f'{config.paths[config.task.name].output}/data.hdf5'``.

        Raises:
            Any error from looking up the expected groups (h5py raises
            ``KeyError`` for missing members). The file is closed before the
            error propagates, so a failed construction does not leak a handle.
        """
        if filepath is None:
            filepath = f'{config.paths[config.task.name].output}/data.hdf5'
        self._file = h5py.File(filepath, 'r')

        try:
            self._init_groups_read_mode()
        except BaseException:
            # Without this, the open h5py.File would be orphaned on failure.
            self._file.close()
            raise

    def _init_groups_read_mode(self):
        """Bind group/dataset handles for ``config.task.cam`` as attributes."""
        self._detector = self._file[f'{config.task.cam}/detector']
        self._matcher = self._file[f'{config.task.cam}/matcher']
        self._filter = self._file[f'{config.task.cam}/filter']
        self._matches = self._file[f'{config.task.cam}/matches']

        self.detector_normalised = self._detector['normalised']
        self.detector_confidences = self._detector['confidences']

        self.matcher_warp = self._matcher['warp'] # only one you will need
        self.matcher_certainty = self._matcher['certainty']

        self.filter_normalised = self._filter['normalised']
        self.filter_confidences = self._filter['confidences']

        self.cropped_image_reference_coords = self._matches['crop/reference_coords']
        self.cropped_image_target_coords = self._matches['crop/target_coords']

    def close(self):
        """Close the underlying HDF5 file."""
        self._file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Do not suppress exceptions raised inside the `with` body.
        return False


In [None]:
# Tally cropped reference/target coordinate counts stored in each track's
# data.hdf5, per camera, for the 'blur_test' frame-filtering setting.
config.task.frame_filtering = 'blur_test'

reference_coords_total_count = 0
target_coords_total_count = 0

# track -> number of reference coords summed over all cams of that track
coords_stats = {}

for track in config.task.tracks:
    # config.task.* is shared mutable state that the output path presumably
    # depends on, so update it before computing the path / opening the reader.
    config.task.track = track
    logger.info(f'Track : {track}')
    config.task.dataset_kind = track[:2]

    filepath = f'{config.paths[config.task.name].output}/data.hdf5'
    print(filepath)

    # NOTE(review): StatsReader binds groups for config.task.cam as it is *now*
    # (the previous track's last cam, or the config default on the first
    # track). Those attributes are not used below, which reads each cam's
    # groups directly from reader._file.
    reader = StatsReader()
    try:
        coords_stats[track] = 0

        for cam in config.task.cams:
            config.task.cam = cam
            logger.info(f'Cam : {cam}')

            n_refs = len(reader._file[f'{cam}/matches/crop/reference_coords'])
            n_tars = len(reader._file[f'{cam}/matches/crop/target_coords'])

            print(f'Number of Reference Coords  : {n_refs}')
            print(f'Number of Target Coords     : {n_tars}')

            coords_stats[track] += n_refs
            reference_coords_total_count += n_refs
            target_coords_total_count += n_tars
    finally:
        # Close the HDF5 handle even if a cam's groups are missing.
        reader.close()

print(f'Total Reference Coords : {reference_coords_total_count}')
print(f'Total Target Coords : {target_coords_total_count}')

# Training data: pair counts per track

In [60]:
# Open the training-data HDF5 file read-only; the handle stays open because
# the cells below walk its groups.
filepath = f'{config.paths[config.task.name].train_data}/train_data.hdf5'
train_data_f = h5py.File(filepath, mode='r')

In [None]:
# Display the open file handle as a quick sanity check (h5py's repr shows the
# file name and open mode).
train_data_f

In [None]:
def collect_train_pair_stats(h5_file):
    """Count reference-coord entries per track in the training-data HDF5 file.

    Walks every object under ``h5_file`` with ``visititems``. For each *group*
    it prints ``Group: <name> = <len>`` and treats the first path component of
    the group name as the track. Groups whose name ends in
    ``'reference_coords'`` add ``len(group)`` (their member count) to both the
    running total and their track's count, and log the running total.

    Named distinctly (rather than redefining ``print_hdf5_structure``) so it
    does not shadow the earlier ``print_hdf5_structure(reader)`` helper, and
    it returns its results instead of mutating globals, so calling it twice
    does not double-count.

    Args:
        h5_file: An open ``h5py.File`` or ``h5py.Group``.

    Returns:
        ``(total_pairs, per_track)``. ``per_track`` maps every top-level group
        name seen to its summed count (0 if it has no reference_coords group),
        in visit order.
    """
    per_track = {}
    running_total = 0

    def _visit(name, obj):
        nonlocal running_total
        if not isinstance(obj, h5py.Group):
            return  # None -> visititems keeps walking

        print(f"Group: {name} = {len(obj)}")
        track = name.split('/')[0]
        per_track.setdefault(track, 0)

        if name.endswith('reference_coords'):
            running_total += len(obj)
            logger.info(f'Length so far {running_total}')
            per_track[track] += len(obj)

    h5_file.visititems(_visit)
    return running_total, per_track


total_pairs, coords_stats = collect_train_pair_stats(train_data_f)
logger.info(f'Length {total_pairs}')

# Order videos (tracks) by pair count and split into validation / training

In [None]:
# Tracks ordered from fewest to most reference-coord entries; sorted() is
# stable, so tracks with equal counts keep their coords_stats insertion order.
sorted_items = sorted(coords_stats.items(), key=lambda kv: kv[1])

for track_name, pair_count in sorted_items:
    print(f"{track_name} : {pair_count}")

In [None]:
# Split tracks into validation and training sets by reference-coord count.
# Greedy: walk tracks from smallest to largest count and put each into
# validation while the running validation total stays within
# VALIDATION_FRACTION of all entries. Because the walk is ascending, once one
# track does not fit no later (larger) track fits either, so validation ends
# up holding the smallest tracks (including any with a count of 0).
VALIDATION_FRACTION = 0.25

# Re-sort from coords_stats here instead of reusing `sorted_items` from an
# earlier cell, so the totals and the iteration order always come from the
# same data (no stale state if the counting cell was re-run).
sorted_items = sorted(coords_stats.items(), key=lambda item: item[1])

total_sum = sum(coords_stats.values())
validation_target = total_sum * VALIDATION_FRACTION

validation = {}
training = {}
current_sum = 0

for key, value in sorted_items:
    if current_sum + value <= validation_target:
        validation[key] = value
        current_sum += value
    else:
        training[key] = value

validation_sum = sum(validation.values())
training_sum = sum(training.values())

# Every track lands in exactly one split.
assert validation_sum + training_sum == total_sum
assert len(validation) + len(training) == len(coords_stats)

print("Validation Set:")
print(f"Count of values (sum): {validation_sum}\n")
for key, value in validation.items():
    print(f"{key}: {value}")

print("\nTraining Set:")
print(f"Count of values (sum): {training_sum}\n")
for key, value in training.items():
    print(f"{key}: {value}")