-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
executable file
·121 lines (104 loc) · 3.98 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#!/usr/bin/env python3
from os import path
from subprocess import check_call, check_output
import sys
import glob
from multiprocessing import Pool
from functools import lru_cache
from json import loads, dumps
# Absolute directory that holds this script; the bundled dependencies are
# looked up relative to it rather than to the current working directory.
curdir = path.dirname(path.abspath(__file__))
dist_site = path.join(curdir, 'dist', 'lib', 'python3.8', 'site-packages')
# Appended (not prepended), so any already-importable copy of a package
# still takes precedence over the bundled one.
sys.path.append(dist_site)
import numpy as np
import cv2
# argv[1]: folder that is searched for frames.
folder = sys.argv[1]
# argv[2] (optional): JSON object overriding the tuning parameters below.
# Entries from the command line win over the defaults listed first.
args = {
    'ratio_test': 0.75,
    'n_features': 360,
    'velocity_factor': 100,
    **loads(sys.argv[2] if len(sys.argv) > 2 else '{}'),
}
def largest_indices(ary, n):
    """Return the indices of the ``n`` largest values of a numpy array.

    Args:
        ary: numpy array of any shape.
        n: how many entries to select. Values larger than ``ary.size`` are
            clamped to ``ary.size``; values <= 0 select nothing.

    Returns:
        A tuple with one index array per dimension of ``ary`` (the format
        produced by ``np.unravel_index``), ordered from the largest value
        to the smallest.
    """
    # https://stackoverflow.com/questions/6910641/how-do-i-get-indices-of-n-maximum-values-in-a-numpy-array
    flat = ary.ravel()
    n = min(int(n), flat.size)
    if n <= 0:
        # Without this guard ``flat[-0:]`` would select *every* element.
        return tuple(np.empty(0, dtype=np.intp) for _ in ary.shape)
    # argpartition only guarantees that the last n slots hold the n largest
    # values; they still have to be ordered among themselves.
    indices = np.argpartition(flat, -n)[-n:]
    indices = indices[np.argsort(-flat[indices])]
    return np.unravel_index(indices, ary.shape)
# Memoised (keypoints, descriptors) results, keyed by the caller's cache_key.
_cache = {}


def extract_features(img, cache_key):
    """Return ``(keypoints, descriptors)`` for ``img``, memoised by key.

    Args:
        img: image array indexed as ``[row, column, channel]``; it is
            converted with ``cv2.COLOR_BGR2GRAY``.
        cache_key: hashable identifier for the image. The image itself is
            deliberately not used as the key, since storing it would defeat
            the purpose of the cache.

    Returns:
        The ``(keypoints, descriptors)`` pair produced by the BRIEF
        extractor for the strongest Harris-corner responses.
    """
    try:
        return _cache[cache_key]
    except KeyError:
        pass

    # Drop the bottom 19 pixel rows; per the original author this exactly
    # removes the 'google' text.
    cropped = img[:-19, :, :]
    gray = np.float32(cv2.cvtColor(cropped, cv2.COLOR_BGR2GRAY))
    response = cv2.cornerHarris(gray, 2, 3, 0.04)

    # largest_indices yields (row, column) index arrays, while KeyPoint
    # takes (x, y), i.e. column first.
    rows, cols = largest_indices(response, args['n_features'])
    candidates = [
        cv2.KeyPoint(float(col), float(row), 1) for row, col in zip(rows, cols)
    ]

    extractor = cv2.xfeatures2d.BriefDescriptorExtractor_create()
    keypoints, descriptors = extractor.compute(cropped, candidates)

    _cache[cache_key] = (keypoints, descriptors)
    return keypoints, descriptors
def get_matching_cost(frame1, frame2, idx1, idx2):
    """Return the cost of going from ``frame1`` directly to ``frame2``.

    Features of both frames are matched (ratio test controlled by
    ``args['ratio_test']``), a RANSAC homography frame2 -> frame1 is fitted
    and the cost is derived from how far that homography moves the image
    centre and how well it reprojects the matched points. Lower is better.

    Args:
        frame1: first image, a 3-dimensional array (height, width, channels).
        frame2: second image, passed on to ``extract_features``.
        idx1: cache key for ``frame1`` in ``extract_features``.
        idx2: cache key for ``frame2`` in ``extract_features``.

    Returns:
        ``min(centre shift, reprojection error)`` when the reprojection
        error is below half the image diagonal, otherwise half the image
        diagonal. Half the diagonal is also returned whenever no homography
        can be estimated (no descriptors, fewer than 4 good matches, or
        OpenCV failing to fit or apply one).
    """
    h, w, _ = frame1.shape
    diag = np.linalg.norm([[0, 0], [w, h]])
    fallback = 0.5 * diag

    kp1, des1 = extract_features(frame1, idx1)
    kp2, des2 = extract_features(frame2, idx2)
    # No descriptors on either side means there is nothing to match.
    if des1 is None or des2 is None or len(des1) == 0 or len(des2) == 0:
        return fallback

    # NOTE(review): BFMatcher() defaults to the L2 norm, whereas OpenCV's
    # documentation recommends NORM_HAMMING for binary descriptors such as
    # BRIEF. Left as is because switching changes every cost value and the
    # default tuning parameters were presumably chosen against it -- TODO
    # evaluate.
    bf = cv2.BFMatcher()
    # A query can come back with fewer than 2 neighbours (e.g. when frame2
    # has a single descriptor); such pairs cannot pass a ratio test.
    matches = [
        pair[0]
        for pair in bf.knnMatch(des1, des2, k=2)
        if len(pair) == 2
        and pair[0].distance < args['ratio_test'] * pair[1].distance
    ]
    # findHomography needs at least 4 point correspondences.
    if len(matches) < 4:
        return fallback

    frame1_pts = np.float32(
        [kp1[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)
    frame2_pts = np.float32(
        [kp2[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)

    # In streetview, frame 2 should be a 'zoomed in' version of frame 1,
    # meaning the homography 2 -> 1 should be in bounds.
    center = np.float32([[w // 2, h // 2]]).reshape(-1, 1, 2)
    try:
        M, _ = cv2.findHomography(frame2_pts, frame1_pts, cv2.RANSAC)
        if M is None:
            # RANSAC could not find a consistent model.
            return fallback
        center_to_frame1 = cv2.perspectiveTransform(center, M)
        frame2_to_frame1 = cv2.perspectiveTransform(frame2_pts, M)
    except cv2.error:
        return fallback

    cR = np.linalg.norm(frame1_pts - frame2_to_frame1)
    c0 = np.linalg.norm(center - center_to_frame1)
    # A NaN reprojection error (degenerate homography) fails this
    # comparison and therefore also yields the fallback cost.
    if cR < fallback:
        return min(c0, cR)
    return fallback
@lru_cache(maxsize=16)
def cached_read(path):
    """Load the image at ``path`` with ``cv2.imread``, memoising the result.

    The 16 most recently used paths are kept, so frames that are revisited
    shortly after being read are not decoded again.

    NOTE(review): cv2.imread is documented to return None instead of raising
    when a file cannot be read -- callers do not check for that; confirm.
    """
    image = cv2.imread(path)
    return image
def _frame_number(img_path):
    """Return the integer spelled by the digits of the file name (0 if none)."""
    # Only the base name is inspected: digits in the folder name would
    # otherwise be glued in front and can mis-order zero-padded file names.
    digits = ''.join(c for c in path.basename(img_path) if c.isdigit())
    return int(digits) if digits else 0


def compute_opt_path(folder, window_size=4):
    """Select the cheapest chain of frames from the first to the last image.

    The ``*.jpg`` files in ``folder`` are ordered by the number embedded in
    their file name. From frame ``i`` the path may continue with any of the
    next ``window_size`` frames ``j``; such a step costs
    ``get_matching_cost(i, j) + args['velocity_factor'] * (j - i - 1) ** 2``,
    so skipping frames is penalised quadratically. The minimum-cost chain is
    found by dynamic programming.

    Args:
        folder: directory containing the frames.
        window_size: maximum number of frames a single step may advance.

    Returns:
        Increasing list of indices into the sorted frame list, starting at 0
        and ending at the last frame. ``[]`` when the folder holds no
        ``*.jpg`` file, ``[0]`` when it holds exactly one.

    Raises:
        ValueError: if ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError('window_size must be at least 1')

    img_paths = sorted(
        glob.glob(path.join(folder, '*.jpg')),
        key=lambda p: (_frame_number(p), p))
    n = len(img_paths)
    if n == 0:
        return []

    # cost[j]: cheapest known cost of reaching frame j from frame 0.
    # prevs[j]: predecessor of frame j on that cheapest chain.
    cost = [float('inf')] * n
    prevs = [0] * n
    cost[0] = 0
    for i in range(n - 1):
        frame1 = cached_read(img_paths[i])
        for j in range(i + 1, min(n, i + window_size + 1)):
            frame2 = cached_read(img_paths[j])
            match_cost = get_matching_cost(frame1, frame2, i, j)
            velocity_cost = args['velocity_factor'] * (j - i - 1) ** 2
            c = match_cost + velocity_cost + cost[i]
            if c < cost[j]:
                cost[j] = c
                prevs[j] = i

    # Walk the predecessor links back from the last frame to frame 0.
    min_path = [n - 1]
    while min_path[-1] != 0:
        min_path.append(prevs[min_path[-1]])
    min_path.reverse()
    return min_path
# Script entry: select the frames for the folder given on the command line
# and emit their indices as one JSON array line on stdout.
min_path = compute_opt_path(folder)
sys.stdout.write(dumps(min_path) + '\n')