In [1]:
# Auto-reload imported modules (e.g. the local `eth3d` helpers) before each cell
# runs, so edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Testing PnP on the ETH3D dataset

Install the dependencies:
```
pip install py7zr torchdata
pip install 'portalocker>=2.0.0'
```

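The following cell downloads the selected ETH3D scenes and extracts them into the `data_root` folder (the dataset is about 4.7 GB in size). It also defines `batch_sample`, a data-generation head for the ETH3D samples: for each image it draws `batch_size` random sets of `num_points` visible 2D–3D correspondences (pixel coordinates and their 3D points).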

In [2]:
import torch
from functools import partial
import os
from eth3d import download_pipe, load_pipe
import numpy as np
# Local cache for the downloaded archives and their extracted contents.
data_root = 'data_cache_eth3d_dp'
os.makedirs(data_root, exist_ok=True)

# ETH3D scene archives to download.
scenes = [
    "courtyard_dslr_undistorted.7z",
    "facade_dslr_undistorted.7z",
    "delivery_area_dslr_undistorted.7z",
    "statue_dslr_undistorted.7z",
]

# One test image per scene: files[i] is searched for in scenes[i] (the two
# lists are zipped together when evaluating), so they must stay aligned --
# zip() would otherwise silently drop the unmatched tail.
files = ['DSC_0323.JPG', 'DSC_0350.JPG', 'DSC_0685.JPG', 'DSC_0490.JPG']
assert len(files) == len(scenes), "each scene needs exactly one test image in `files`"

def batch_sample(num_points: int, batch_size: int,
                 pixels: np.ndarray, point_ids: np.ndarray,
                 point_dict: dict, rng: "np.random.Generator | None" = None):
    """Draw random batches of 2D-3D correspondences from one image.

    Only observations with ``point_ids >= 0`` are treated as visible; negative
    ids are skipped. Correspondences are drawn uniformly *with replacement*,
    so a single batch element may contain the same point more than once.

    Args:
        num_points: Number of correspondences per batch element.
        batch_size: Number of independent draws.
        pixels: Per-observation pixel coordinates, indexed in parallel with
            ``point_ids`` (presumably ``(N, 2)`` -- TODO confirm against the
            datapipe).
        point_ids: Per-observation 3D point id; a negative id means the
            observation has no 3D point.
        point_dict: Maps a point id to a record whose element ``[0]`` is the
            3D point (assumed to have 3 components, matching the output shape).
        rng: Optional ``numpy.random.Generator`` for reproducible sampling.
            When ``None`` (default) the legacy global ``np.random`` state is
            used, exactly as before.

    Returns:
        ``(sampled_pixels, sampled_points)`` with shapes
        ``(batch_size, num_points, *pixels.shape[1:])`` and
        ``(batch_size, num_points, 3)``. The point dtype follows the dtype of
        the points stored in ``point_dict``.

    Raises:
        ValueError: If the image has no visible observation to sample from.
    """
    visible = point_ids >= 0
    pixels = pixels[visible]
    point_ids = point_ids[visible]
    if len(point_ids) == 0:
        raise ValueError("batch_sample: image has no visible points (all point_ids < 0)")

    shape = (batch_size, num_points)
    if rng is None:
        draw = np.random.choice(len(point_ids), shape, replace=True)
    else:
        draw = rng.choice(len(point_ids), size=shape, replace=True)
    sampled_point_ids = point_ids[draw]
    sampled_pixels = pixels[draw]

    # Look each distinct id up once instead of once per sample
    # (the naive loop does batch_size * num_points dict lookups).
    unique_ids, inverse = np.unique(sampled_point_ids, return_inverse=True)
    # dtype is consistent with the stored points
    dtype = next(iter(point_dict.values()))[0].dtype
    table = np.zeros((len(unique_ids), 3), dtype=dtype)
    for row, key in enumerate(unique_ids):
        table[row] = point_dict[key][0]
    # reshape: NumPy 1.x returns a flat inverse, NumPy 2.x the input's shape.
    sampled_points = table[inverse.reshape(shape)]

    return sampled_pixels, sampled_points


# Exhaust the pipe once: this cell's job (per the note above it) is to download
# and extract the scenes into `data_root`; the yielded samples are discarded.
img = load_pipe(
    download_pipe(data_root, scenes=scenes)
)
for _ in img:
    continue

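Next, we instantiate an `EPnP` solver and test it on the four selected images (one per scene), recording the runtime and the rotation/translation errors for 10, 100, and 1000 sampled points.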

In [8]:
from time import perf_counter

import matplotlib.pyplot as plt
import pypose as pp

epnp = pp.module.EPnP()

np.random.seed(0)  # make the random correspondence sampling reproducible
batch_size = 1000
nums = [10, 100, 1000]
for scene, file in zip(scenes, files):
    # Load each scene once and keep only the image under test. Sampling happens
    # afterwards, so `batch_sample` runs on that single image rather than on
    # every image of the scene, and the scene is not re-read per `num_points`.
    matches = [sample for sample in load_pipe(download_pipe(data_root, scenes=[scene]))
               if file in sample['jpg_name']]
    assert len(matches) == 1, \
        f"Expected exactly one image matching {file} in {scene}, found {len(matches)}"
    sample = matches[0]

    groundtruth = sample['pose']
    camera = torch.from_numpy(sample['camera'][sample['camera_id']]).to(torch.float64)

    r_error = []
    t_error = []
    timings = []
    for num_points in nums:
        # Same arguments the datapipe map with
        # input_col=('pixels', 'point_ids', 'point') passes positionally.
        pixels, points = batch_sample(num_points, batch_size,
                                      sample['pixels'], sample['point_ids'], sample['point'])
        points = torch.from_numpy(points).to(torch.float64)
        pixels = torch.from_numpy(pixels).to(points)

        start = perf_counter()
        pose = epnp(points, pixels, camera)
        timings.append(perf_counter() - start)

        # RMS over the batch of the Frobenius norm of the rotation difference,
        # and of the Euclidean norm of the translation difference.
        r_error.append(torch.norm(groundtruth.unsqueeze(0).rotation().matrix()
                                  - pose.rotation().matrix(), dim=(-1, -2)).pow(2).mean().sqrt().item())
        t_error.append(torch.norm(groundtruth.unsqueeze(0).translation()
                                  - pose.translation(), dim=-1).pow(2).mean().sqrt().item())

    # One row of 3 panels per image; all axes in log scale.
    fig, axes = plt.subplots(1, 3, figsize=(30, 10))
    panels = (
        (timings, 'time [s]'),
        (r_error, 'rotation error (RMS Frobenius)'),
        (t_error, 'translation error (RMS)'),
    )
    for ax, (values, ylabel) in zip(axes, panels):
        ax.plot(nums, values, 'o-')
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel('number of points')
        ax.set_ylabel(ylabel)
    fig.suptitle(f'EPnP with {file}')
    plt.show()
    plt.close(fig)


KeyboardInterrupt: 