/
main.go
93 lines (83 loc) · 2.9 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
// Copyright 2020 The SQLFlow Authors. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// To run this program:
// go generate .. && go run main.go
//
package main
import (
"flag"
"fmt"
"net"
"os"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/reflection"
"sqlflow.org/sqlflow/go/log"
"sqlflow.org/sqlflow/go/proto"
sf "sqlflow.org/sqlflow/go/sql"
server "sqlflow.org/sqlflow/go/sqlflowserver"
)
// newServer creates the gRPC server used by SQLFlow.
//
// When both caCrt and caKey are non-empty, the server is configured with TLS
// credentials loaded via credentials.NewServerTLSFromFile. Despite the "ca"
// naming, the two paths are passed to that function as the server's own
// certificate file and private key file. When both are empty, the server runs
// in insecure (plaintext) mode.
//
// It returns an error if exactly one of caCrt/caKey is given, or if the
// certificate/key files cannot be loaded. Otherwise, providing only one of
// them would silently downgrade the server to insecure mode.
func newServer(caCrt, caKey string, logger *log.Logger) (*grpc.Server, error) {
	if caCrt == "" && caKey == "" {
		s := grpc.NewServer()
		logger.Info("Launch server with insecure mode.")
		return s, nil
	}
	if caCrt == "" || caKey == "" {
		return nil, fmt.Errorf("both CA crt and key files are required for SSL/TLS, got crt=%q key=%q", caCrt, caKey)
	}
	creds, err := credentials.NewServerTLSFromFile(caCrt, caKey)
	if err != nil {
		// %w keeps the underlying error inspectable via errors.Is/As.
		return nil, fmt.Errorf("failed to load CA crt/key files: %s, %s: %w", caCrt, caKey, err)
	}
	s := grpc.NewServer(grpc.Creds(creds))
	logger.Info("Launch server with SSL/TLS certification.")
	return s, nil
}
// start builds the SQLFlow gRPC server, listens on the given TCP port on all
// interfaces, and serves until Serve returns. Every failure is reported
// through logger.Fatalf.
//
// Parameters:
//   - modelDir: if non-empty, the directory is created (including any missing
//     parents) before serving, and it is passed to server.NewServer.
//   - caCrt, caKey: forwarded to newServer to choose TLS or insecure mode.
//   - port: TCP port to listen on.
//   - isArgoMode: when true, requests are handled by server.SubmitWorkflow;
//     otherwise by sf.RunSQLProgram.
func start(modelDir, caCrt, caKey string, port int, isArgoMode bool) {
	logger := log.GetDefaultLogger()
	s, err := newServer(caCrt, caKey, logger)
	if err != nil {
		logger.Fatalf("failed to create new gRPC Server: %v", err)
	}
	if modelDir != "" {
		// MkdirAll is a no-op when the directory already exists and creates
		// missing parents. Its error is checked, so a bad path (e.g. permission
		// denied, or a regular file at that path) fails fast instead of
		// surfacing later when a model is saved.
		if err := os.MkdirAll(modelDir, os.ModePerm); err != nil {
			logger.Fatalf("failed to create model directory %s: %v", modelDir, err)
		}
	}
	if isArgoMode {
		proto.RegisterSQLFlowServer(s, server.NewServer(server.SubmitWorkflow, modelDir))
	} else {
		proto.RegisterSQLFlowServer(s, server.NewServer(sf.RunSQLProgram, modelDir))
	}
	listenString := fmt.Sprintf(":%d", port)
	lis, err := net.Listen("tcp", listenString)
	if err != nil {
		logger.Fatalf("failed to listen: %v", err)
	}
	// Register the reflection service so tools such as grpcurl can discover
	// the server's services.
	reflection.Register(s)
	logger.Infof("Server Started at %s", listenString)
	if err := s.Serve(lis); err != nil {
		logger.Fatalf("failed to serve: %v", err)
	}
}
// main parses the command-line flags, initializes the logger (writing to
// -log when set), and starts the SQLFlow gRPC server.
//
// Flags:
//   - -model_dir: local directory for saved models (see start).
//   - -log: log file path.
//   - -ca-crt, -ca-key: TLS certificate/key files; both must be set for TLS.
//   - -port: TCP port to listen on (default 50051).
//   - -argo-mode: submit programs as Argo workflows instead of running them
//     directly.
func main() {
	modelDir := flag.String("model_dir", "", "model would be saved on the local dir, otherwise upload to the table.")
	logPath := flag.String("log", "", "path/to/log, e.g.: /var/log/sqlflow.log")
	caCrt := flag.String("ca-crt", "", "CA certificate file.")
	caKey := flag.String("ca-key", "", "CA private key file.")
	port := flag.Int("port", 50051, "TCP port to listen on.")
	// Help text fixed: this flag toggles Argo workflow *mode*.
	isArgoMode := flag.Bool("argo-mode", false, "Enable Argo workflow mode.")
	flag.Parse()

	log.InitLogger(*logPath, log.OrderedTextFormatter)
	start(*modelDir, *caCrt, *caKey, *port, *isArgoMode)
}