-
Notifications
You must be signed in to change notification settings - Fork 15
/
infer_tensorrt.py
37 lines (30 loc) · 891 Bytes
/
infer_tensorrt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import argparse
from pathlib import Path

import cv2

from siamese.siamese_network_trt import SiameseNetworkTRT
def _parse_args(argv=None):
    """Build the command-line parser and parse *argv*.

    Returns:
        A ``(parser, args)`` tuple. The parser is returned so callers can
        report input errors with ``parser.error`` (usage message on stderr,
        exit status 2), consistent with argparse's own argument errors.
    """
    parser = argparse.ArgumentParser(
        description="Compute the similarity of an image pair with a "
                    "TensorRT Siamese network."
    )
    parser.add_argument(
        '--image1',
        type=str,
        help="Path to first image of the pair.",
        required=True,
    )
    parser.add_argument(
        '--image2',
        type=str,
        help="Path to second image of the pair.",
        required=True,
    )
    parser.add_argument(
        '--engine',
        type=str,
        help="Path to tensorrt engine generated by 'onnx_to_trt.py'.",
        required=True,
    )
    return parser, parser.parse_args(argv)


def _require_file(parser, flag, path):
    """Exit via ``parser.error`` unless *path* names an existing regular file."""
    if not Path(path).is_file():
        parser.error(f"{flag}: file not found: {path}")


def _read_image(parser, flag, path):
    """Load the image at *path* with OpenCV, exiting with a CLI error on failure.

    ``cv2.imread`` does not raise for a missing or undecodable file; it
    returns ``None``. Without this check that ``None`` would only surface
    later as an obscure error inside ``model.predict``.
    """
    _require_file(parser, flag, path)
    image = cv2.imread(path)
    if image is None:
        parser.error(f"{flag}: could not decode image: {path}")
    return image


def main(argv=None):
    """Run Siamese-network similarity inference on a pair of images.

    Args:
        argv: Command-line arguments without the program name. ``None``
            makes argparse fall back to ``sys.argv[1:]``.

    Exits with status 2 (via ``parser.error``) when the engine file or
    either image is missing, or when an image cannot be decoded.
    """
    parser, args = _parse_args(argv)
    _require_file(parser, "--engine", args.engine)

    # Validate both inputs before deserializing the engine, so bad paths fail
    # fast rather than after a (potentially slow) engine load.
    image1 = _read_image(parser, "--image1", args.image1)
    image2 = _read_image(parser, "--image2", args.image2)

    model = SiameseNetworkTRT()
    model.load_model(args.engine)
    similarity = model.predict(image1, image2)
    print(f"Similarity between the two images = {round(similarity, 2)}")


if __name__ == "__main__":
    main()