/
triton_inference_test.py
59 lines (47 loc) · 1.66 KB
/
triton_inference_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import sys
sys.path.append('src/')
import argparse
from requests import get
from read_params import read_params
from dask_client import dask_client
from load_data import load_data
from feature_engg import feature_engg
import tritonclient.grpc as triton_grpc
def triton_inference(df, config_path, sample_frac=0.01):
    """Run a sample of ``df`` through the Triton 'fil' model over gRPC.

    Parameters
    ----------
    df : dask DataFrame (cuDF-backed; must support ``.compute().to_pandas()``)
        Feature-engineered data still containing the configured target column.
    config_path : str
        Path to the YAML params file read via ``read_params``.
    sample_frac : float, optional
        Fraction of rows sampled for this smoke test. Defaults to 0.01,
        the previously hard-coded value, so existing callers are unaffected.

    Returns
    -------
    numpy.ndarray
        Model outputs ('output__0') for the sampled rows.
    """
    config = read_params(config_path)
    # Cast to the configured dtype before sampling/dropping columns.
    # NOTE(review): the InferInput below declares 'FP32' regardless of this
    # configured dtype -- confirm config["triton"]["dtype"] is float32.
    df = df.astype(config["triton"]["dtype"])
    df = df.sample(frac=sample_frac)
    df = df.drop([config["base"]["target_col"]], axis=1)
    # Materialize the lazy dask frame into a host numpy array.
    df = df.compute().to_pandas().values
    # Look up this machine's public IP; assumes the Triton server is
    # reachable at it -- NOTE(review): confirm this matches the deployment.
    ip = get('https://api.ipify.org').content.decode('utf8')
    # f-string stringifies the port: YAML parses a bare port number as an
    # int, and the original `ip + ':' + port` concatenation would then
    # raise TypeError.
    grpc_client = triton_grpc.InferenceServerClient(
        url=f"{ip}:{config['triton']['grpc_port']}",
        verbose=False
    )
    # Set up Triton input and output objects for GRPC
    triton_input_grpc = triton_grpc.InferInput(
        'input__0',
        (df.shape[0], df.shape[1]),
        'FP32'
    )
    triton_input_grpc.set_data_from_numpy(df)
    triton_output_grpc = triton_grpc.InferRequestedOutput('output__0')
    request_grpc = grpc_client.infer(
        'fil',
        model_version='1',
        inputs=[triton_input_grpc],
        outputs=[triton_output_grpc]
    )
    # Get results as numpy arrays
    predictions = request_grpc.as_numpy('output__0')
    print("Triton is Working!")
    return predictions
if __name__ == "__main__":
    # CLI entry point: load test data, feature-engineer it, and smoke-test
    # the Triton server with a sample of rows.
    args = argparse.ArgumentParser()
    args.add_argument("--config", default="params.yaml")
    parsed_args = args.parse_args()
    client = dask_client(config_path=parsed_args.config)
    # try/finally guarantees the Dask client is closed even when loading,
    # feature engineering, or inference raises (the original leaked it).
    try:
        df = load_data(config_path=parsed_args.config, test=True)
        df = feature_engg(df)
        triton_inference(df, config_path=parsed_args.config)
    finally:
        client.close()