diff --git a/server/sqlflowserver.go b/server/sqlflowserver.go index d068c28617..e0bacf0745 100644 --- a/server/sqlflowserver.go +++ b/server/sqlflowserver.go @@ -62,7 +62,7 @@ func (s *Server) Run(req *pb.Request, stream pb.SQLFlow_RunServer) error { sqlStatements := strings.Split(req.Sql, ";") trimedStatements := []string{} for _, singleSQL := range sqlStatements { - sqlToRun := strings.Trim(singleSQL, "\n") + sqlToRun := strings.TrimSpace(singleSQL) if sqlToRun == "" { continue } diff --git a/server/sqlflowserver_test.go b/server/sqlflowserver_test.go index 219288d824..8f06601c3c 100644 --- a/server/sqlflowserver_test.go +++ b/server/sqlflowserver_test.go @@ -36,10 +36,11 @@ import ( ) const ( - testErrorSQL = "ERROR ..." - testQuerySQL = "SELECT ..." - testExecuteSQL = "INSERT ..." - testExtendedSQL = "SELECT ... TRAIN ..." + testErrorSQL = "ERROR ..." + testQuerySQL = "SELECT ..." + testExecuteSQL = "INSERT ..." + testExtendedSQL = "SELECT ... TRAIN ..." + testExtendedSQLWithSpace = "SELECT ... TRAIN ...; \n\t" ) var testServerAddress string @@ -67,6 +68,8 @@ func mockRun(sql string, db *sf.DB, modelDir string, session *pb.Session) *sf.Pi case testExtendedSQL: wr.Write("log 0") wr.Write("log 1") + default: + wr.Write(fmt.Errorf("unexcepted SQL: %s", sql)) } }() return rd @@ -123,7 +126,9 @@ func TestSQL(t *testing.T) { _, err = stream.Recv() a.Equal(status.Error(codes.Unknown, fmt.Sprintf("run error: %v", testErrorSQL)), err) - for _, s := range []string{testQuerySQL, testExecuteSQL, testExtendedSQL} { + testMultipleSQL := fmt.Sprintf("%s; %s", testQuerySQL, testExtendedSQL) + + for _, s := range []string{testQuerySQL, testExecuteSQL, testExtendedSQL, testExtendedSQLWithSpace, testMultipleSQL} { stream, err := c.Run(ctx, &pb.Request{Sql: s}) a.NoError(err) for {