# Imports

In [1]:
import numpy as np
import h5py
import random
import torch

# Reader

In [2]:
class Reader:
    """Read-only accessor for the matching data stored in a dataset's HDF5 file.

    The file is opened in ``__init__`` and the per-camera groups
    (``detector``, ``matcher``, ``filter``, ``matches``) are exposed as
    attributes. Release the file with :meth:`close`, or use the instance as
    a context manager::

        with Reader() as reader:
            pair_name = reader.get_random_pair()

    Attributes:
        cam: Camera group being read, ``'cam0'`` or ``'cam1'``.
        original_image_shape: ``[width, height]`` of the full image.
        crop_image_shape: ``[width, height]`` of the centre-cropped image.
    """

    _VALID_CAMS = ('cam0', 'cam1')

    def __init__(
        self,
        dataset='MOO07_mapping_easy',
        cam='cam0',
        root='/storage/local/pranav/datasets/basalt/monado_slam',
    ):
        """Open ``<root>/<dataset>/data.hdf5`` read-only and bind its groups.

        Args:
            dataset: Name of the dataset folder under ``root``.
            cam: Which camera group to read; ``'cam0'`` or ``'cam1'``.
            root: Directory containing the dataset folders. The default is
                a machine-specific absolute path kept for backward
                compatibility; pass your own location when running elsewhere.

        Raises:
            ValueError: If ``cam`` is not one of the supported cameras.
            OSError: If the HDF5 file cannot be opened (raised by h5py).
            KeyError: If an expected group is missing from the file.
        """
        # Validate before touching the filesystem so a bad argument never
        # leaves an open file handle behind.
        if cam not in self._VALID_CAMS:
            raise ValueError(
                f"cam must be one of {list(self._VALID_CAMS)}, got {cam!r}"
            )
        self.cam = cam

        # [width, height]
        self.original_image_shape = [640, 480]
        self.crop_image_shape = [630, 476]

        folder_path = f'{root}/{dataset}'
        filename = 'data.hdf5'  # do not change
        filepath = f'{folder_path}/{filename}'
        self._file = h5py.File(filepath, 'r')

        try:
            self._init_groups_read_mode()
        except Exception:
            # Do not leak the handle when the file lacks an expected group.
            self._file.close()
            raise

    def __enter__(self):
        """Return the reader itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the HDF5 file on leaving the ``with`` block; never suppresses errors."""
        self.close()

    def _init_groups_read_mode(self):
        """Look up the groups of the selected camera and expose them as attributes."""
        self._detector = self._file[f'{self.cam}/detector']
        self._matcher = self._file[f'{self.cam}/matcher']
        self._filter = self._file[f'{self.cam}/filter']
        self._matches = self._file[f'{self.cam}/matches']

        self.detector_normalised = self._detector['normalised']
        self.detector_confidences = self._detector['confidences']

        self.matcher_warp = self._matcher['warp']  # only one you will need
        self.matcher_certainty = self._matcher['certainty']

        self.filter_normalised = self._filter['normalised']
        self.filter_confidences = self._filter['confidences']

        self.cropped_image_reference_coords = self._matches['crop/reference_coords']
        self.cropped_image_target_coords = self._matches['crop/target_coords']

    def get_random_pair(self):
        """Return the name of one randomly chosen pair.

        Names are the keys of the ``matches/crop/reference_coords`` group.
        """
        keys = list(self.cropped_image_reference_coords.keys())
        return random.choice(keys)

    def close(self):
        """Close the underlying HDF5 file."""
        self._file.close()

    def _warp_to_pixel_coords(self, warp):
        """Map a warp from the normalised range [-1, 1] to pixel units.

        This function is adapted from a RoMa utils file.

        Args:
            warp: Tensor whose last dimension holds two (x, y) pairs, i.e.
                ``(x1, y1, x2, y2)``, each in normalised coordinates.

        Returns:
            Tensor of the same shape with x scaled by the crop width and y
            scaled by the crop height, so -1 maps to 0 and +1 maps to the
            full width/height.
        """
        # Both images of a pair share the crop size, so take it from the
        # single source of truth instead of repeating the literals here.
        crop_w, crop_h = self.crop_image_shape

        def _to_pixels(xy):
            return torch.stack(
                (
                    crop_w * (xy[..., 0] + 1) / 2,
                    crop_h * (xy[..., 1] + 1) / 2,
                ),
                dim=-1,
            )

        return torch.cat(
            (_to_pixels(warp[..., :2]), _to_pixels(warp[..., 2:])),
            dim=-1,
        )

    def load_warp(self, pair_name):
        """Load the warp and certainty stored for ``pair_name``.

        Args:
            pair_name: Key of the pair, e.g. as returned by
                :meth:`get_random_pair`.

        Returns:
            ``(pixel_coords, certainty)``: the warp converted to pixel units
            as a ``torch.Tensor``, and the certainty exactly as read from
            the file (not converted to a tensor).
        """
        warp = self.matcher_warp[pair_name][()]
        warp = torch.from_numpy(warp)

        pixel_coords = self._warp_to_pixel_coords(warp)
        certainty = self.matcher_certainty[pair_name][()]

        return pixel_coords, certainty

    def get_target_keypoint(self, pixel_coords, reference_keypoint):
        """Look up the target-image match of a reference keypoint.

        Args:
            pixel_coords: Warp in pixel units as returned by
                :meth:`load_warp`; indexed ``[row, column]``, with the last
                two channels holding the matched (x, y).
            reference_keypoint: ``(x, y)`` in cropped-image pixel
                coordinates. Values are truncated to integers.

        Returns:
            ``(x_b, y_b)``: integer coordinates of the match, shifted from
            the centre crop into the original image frame.

        Raises:
            IndexError: If the reference keypoint lies outside the grid
                covered by ``pixel_coords``.
        """
        x_a, y_a = reference_keypoint
        x_a, y_a = int(x_a), int(y_a)

        # Reject out-of-range input explicitly: a negative index would
        # otherwise wrap around and silently return a match for a pixel on
        # the opposite edge of the image.
        height, width = pixel_coords.shape[:2]
        if not (0 <= x_a < width and 0 <= y_a < height):
            raise IndexError(
                f"reference keypoint ({x_a}, {y_a}) is outside the cropped "
                f"image of size {width}x{height} (width x height)"
            )

        _, _, x_b, y_b = pixel_coords[y_a, x_a]
        x_b, y_b = int(x_b.item()), int(y_b.item())

        original_w, original_h = self.original_image_shape
        crop_w, crop_h = self.crop_image_shape

        left_padding = (original_w - crop_w) // 2
        top_padding = (original_h - crop_h) // 2

        # making sure output is for original image size
        x_b, y_b = x_b + left_padding, y_b + top_padding

        return x_b, y_b


In [3]:
# Opens the dataset's HDF5 file read-only; call reader.close() when finished.
reader = Reader()

In [4]:
# Show the underlying h5py file handle to confirm it is open in read mode.
reader._file

<HDF5 file "data.hdf5" (mode r)>

# Get Random Pair Name

`Reader.get_random_pair()` returns the name of a randomly chosen pair of consecutive frames, in the form `<first>_<second>`. It was useful for testing.

In [5]:
# The result changes on every run because no random seed is set.
reader.get_random_pair()

'8669603115290_8669636405390'

# Groups inside File

In [6]:
def print_hdf5_structure(reader):
    """Print the path of every group in the reader's HDF5 file.

    The whole file is walked with ``visititems``; one ``Group: <path>``
    line is printed per group and every other kind of object is skipped.

    Args:
        reader: Object whose ``_file`` attribute is the open HDF5 file.
    """
    def _report_group(path, node):
        # Guard clause: anything that is not a group is ignored.
        if not isinstance(node, h5py.Group):
            return
        print(f"Group: {path}")

    reader._file.visititems(_report_group)

- The `detector` and `matcher` groups store the output of DeDoDe and RoMa. 
- The `filter` group stores the keypoints that remain after removing those with low RoMa confidence. It is used for training.
- The `matches` group is used only for training.

In [7]:
# Lists only the groups in the file; datasets are not printed.
print_hdf5_structure(reader)

Group: cam0
Group: cam0/detector
Group: cam0/detector/confidences
Group: cam0/detector/normalised
Group: cam0/filter
Group: cam0/filter/confidences
Group: cam0/filter/normalised
Group: cam0/matcher
Group: cam0/matcher/certainty
Group: cam0/matcher/warp
Group: cam0/matches
Group: cam0/matches/crop
Group: cam0/matches/crop/reference_coords
Group: cam0/matches/crop/target_coords
Group: cam1
Group: cam1/detector
Group: cam1/detector/confidences
Group: cam1/detector/normalised
Group: cam1/filter
Group: cam1/filter/confidences
Group: cam1/filter/normalised
Group: cam1/matcher
Group: cam1/matcher/certainty
Group: cam1/matcher/warp
Group: cam1/matches
Group: cam1/matches/crop
Group: cam1/matches/crop/reference_coords
Group: cam1/matches/crop/target_coords


# Get Target from Reference

In [8]:
# Pick a random pair to use in the cells below (differs on every run).
pair_name = reader.get_random_pair()
pair_name

'8671500609690_8671533899390'

In [9]:
# pixel_coords holds the per-pixel matches (the stored warp converted to pixel
# units); certainty is the matcher certainty stored for the same pair.
pixel_coords, certainty = reader.load_warp(pair_name)

In [10]:
# Note the shape of pixel_coords: (crop height, crop width, 4), i.e. one entry
# per pixel of the cropped image. The last two channels are the matched (x, y).
pixel_coords.shape  

torch.Size([476, 630, 4])

In [11]:
# certainty has one value per pixel of the cropped image: (crop height, crop width).
certainty.shape  

(476, 630)

In [12]:
# Reference keypoint as (x, y) = (coordinate along width, coordinate along height),
# in cropped-image pixels; it must lie inside the 630x476 crop.
reference_keypoint = [625, 400] 

# The returned target keypoint is expressed in original (640x480) image coordinates.
target_keypoint = reader.get_target_keypoint(pixel_coords, reference_keypoint)
target_keypoint

(635, 381)