-
Notifications
You must be signed in to change notification settings - Fork 3
/
kp_server.py
33 lines (29 loc) · 1.02 KB
/
kp_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import pickle
import torch
import zmq
from superpoint import SuperPoint
import numpy as np
def to_numpy_payload(res):
    """Convert the first batch item of the extractor output to NumPy arrays.

    Args:
        res: Mapping with 'keypoints', 'descriptors' and 'raw_descs'
            entries; each entry is indexed with ``[0]`` and must yield a
            torch tensor.

    Returns:
        A dict with the same three keys holding CPU NumPy arrays. The
        'descriptors' array is transposed, and 'raw_descs' has its first
        axis moved to the last position.
    """
    return {
        'keypoints': res['keypoints'][0].cpu().numpy(),
        'descriptors': res['descriptors'][0].cpu().numpy().T,
        'raw_descs': np.moveaxis(res['raw_descs'][0].cpu().numpy(), 0, -1),
    }


if __name__ == '__main__':
    # Request/reply server: each request is a pickled image array, each
    # reply is a pickled dict of NumPy arrays produced by SuperPoint.
    context = zmq.Context()
    socket = context.socket(zmq.REP)
    try:
        socket.bind('tcp://127.0.0.1:5555')

        extractor_model = SuperPoint({
            'descriptor_dim': 256,
            'nms_radius': 3,
            'max_keypoints': 4096,
            'keypoints_threshold': 0.6
        })
        extractor_model.cuda()
        extractor_model.eval()
        extractor_model.load_state_dict(
            torch.load('data/models/superpoint_v1.pth'), strict=True)

        while True:
            # SECURITY: pickle.loads executes arbitrary code embedded in the
            # payload. The socket is bound to loopback only; do not expose
            # this port to untrusted peers without switching wire format.
            gray = pickle.loads(socket.recv())
            print('processing')
            with torch.no_grad():
                # Scale by 1/255 and add two leading singleton axes before
                # moving the input to the GPU.
                # NOTE(review): presumably `gray` is a 2-D uint8 grayscale
                # image, giving a (1, 1, H, W) batch - confirm with client.
                batch = torch.from_numpy(gray / 255.).float()[None, None].cuda()
                res = to_numpy_payload(extractor_model(batch))
            socket.send(pickle.dumps(res))
    finally:
        # Release the bound port deterministically on any exit path.
        # linger=0 drops an undelivered reply so context.term() cannot block.
        socket.close(linger=0)
        context.term()