-
Notifications
You must be signed in to change notification settings - Fork 21
Expand file tree
/
Copy pathds_trt_4.py
More file actions
20 lines (16 loc) · 703 Bytes
/
ds_trt_4.py
File metadata and controls
20 lines (16 loc) · 703 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import sys
import torch
class TensorRTPart(torch.nn.Module):
    """Run an SSD module's preprocessing and detector as a single ``nn.Module``.

    Judging by the class name, this is the portion of the pipeline meant to be
    exported and run under TensorRT (NOTE(review): confirm against the export
    script that consumes it).

    Args:
        ssd_module: Object exposing ``preprocess(images)`` and
            ``detector(batch)``; ``detector`` must return a ``(locs, labels)``
            pair.
        creates_dummy_dim: When True, ``forward`` prepends a unit dimension to
            both outputs so that a consumer which strips the leading batch
            dimension (DeepStream does this to post-process image by image)
            still receives the full batch. Defaults to False, in which case
            the detector outputs are returned unchanged.
    """

    def __init__(self, ssd_module, *, creates_dummy_dim=False):
        super().__init__()
        self.ssd_module = ssd_module
        # Public attribute; forward() reads it to decide whether to add the
        # extra leading dimension to its outputs.
        self.creates_dummy_dim = creates_dummy_dim

    def forward(self, image_nchw):
        """Preprocess ``image_nchw`` and run the detector on the result.

        Args:
            image_nchw: Image batch handed straight to
                ``ssd_module.preprocess`` (NCHW layout, as the name suggests).

        Returns:
            tuple: ``(locs, labels)`` as produced by ``ssd_module.detector``,
            each with an extra leading unit dimension when
            ``self.creates_dummy_dim`` is True.
        """
        image_batch = self.ssd_module.preprocess(image_nchw)
        locs, labels = self.ssd_module.detector(image_batch)
        if self.creates_dummy_dim:
            # DeepStream strips off the batch dimension to feed outputs to
            # post-processing image by image; a leading unit dimension gives it
            # something to strip while the real batch dimension survives.
            locs, labels = locs.unsqueeze(0), labels.unsqueeze(0)
        # Always return a tuple so the output structure stays stable.
        return (locs, labels)
if __name__ == '__main__':
    # Emit the model's output names, comma-separated, on a single line.
    print(*('locs', 'labels'), sep=',')