diff --git a/internal/db/db.go b/internal/db/db.go index 93dc0d1..635f605 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -42,10 +42,11 @@ type StatementResult struct { ColumnNames []string RowCh chan rowResult Err error + Query string } -func newStatementResult(columnNames []string, rowCh chan rowResult) *StatementResult { - return &StatementResult{ColumnNames: columnNames, RowCh: rowCh} +func newStatementResult(columnNames []string, rowCh chan rowResult, query string) *StatementResult { + return &StatementResult{ColumnNames: columnNames, RowCh: rowCh, Query: query} } func newStatementResultWithError(err error) *StatementResult { @@ -172,7 +173,7 @@ func (db *Db) executeQuery(query string, statementResultCh chan StatementResult) defer rows.Close() - return readQueryResults(rows, statementResultCh) + return readQueryResults(rows, statementResultCh, query) } func (db *Db) prepareStatementsIntoQueries(statementsString string) []string { @@ -220,14 +221,18 @@ func getColumnTypes(rows *sql.Rows) ([]reflect.Type, error) { return types, nil } -func readQueryResults(queryRows *sql.Rows, statementResultCh chan StatementResult) (shouldContinue bool) { +func readQueryResults(queryRows *sql.Rows, statementResultCh chan StatementResult, query string) (shouldContinue bool) { + queries, _ := sqliteparserutils.SplitStatement(query) + queryIndex := 0 hasResultSetToRead := true for hasResultSetToRead { - if shouldContinue := readQueryResultSet(queryRows, statementResultCh); !shouldContinue { + query := queries[queryIndex] + if shouldContinue := readQueryResultSet(queryRows, statementResultCh, query); !shouldContinue { return false } hasResultSetToRead = queryRows.NextResultSet() + queryIndex++ } if err := queryRows.Err(); err != nil { @@ -238,7 +243,7 @@ func readQueryResults(queryRows *sql.Rows, statementResultCh chan StatementResul return true } -func readQueryResultSet(queryRows *sql.Rows, statementResultCh chan StatementResult) (shouldContinue bool) { +func readQueryResultSet(queryRows *sql.Rows, statementResultCh chan StatementResult, query string) (shouldContinue bool) { columnNames, err := getColumnNames(queryRows) if err != nil { statementResultCh <- *newStatementResultWithError(err) @@ -264,7 +269,7 @@ func readQueryResultSet(queryRows *sql.Rows, statementResultCh chan StatementRes rowCh := make(chan rowResult) defer close(rowCh) - statementResultCh <- *newStatementResult(columnNames, rowCh) + statementResultCh <- *newStatementResult(columnNames, rowCh, query) for queryRows.Next() { err = queryRows.Scan(columnPointers...) diff --git a/internal/db/explainTreeBuilder.go b/internal/db/explainTreeBuilder.go new file mode 100644 index 0000000..2585d64 --- /dev/null +++ b/internal/db/explainTreeBuilder.go @@ -0,0 +1,54 @@ +package db + +import ( + "fmt" +) + +type QueryPlanNode struct { + ID string + ParentID string + NotUsed string + Detail string + Children []*QueryPlanNode +} + +func BuildQueryPlanTree(rows [][]string) (*QueryPlanNode, error) { + var nodes []*QueryPlanNode + nodeMap := make(map[string]*QueryPlanNode) + + for _, row := range rows { + id := row[0] + parentId := row[1] + notUsed := row[2] + detail := row[3] + + node := &QueryPlanNode{ + ID: id, + ParentID: parentId, + NotUsed: notUsed, + Detail: detail, + } + + nodes = append(nodes, node) + nodeMap[id] = node + } + + root := &QueryPlanNode{} + for _, node := range nodes { + if node.ParentID == "0" { + root = node + } else { + parent := nodeMap[node.ParentID] + parent.Children = append(parent.Children, node) + } + } + + return root, nil +} + +func PrintQueryPlanTree(node *QueryPlanNode, indent string) { + fmt.Printf("%s%s\n", indent, node.Detail) + for _, child := range node.Children { + PrintQueryPlanTree(child, indent+" ") + } +} diff --git a/internal/db/output.go b/internal/db/output.go index 5c5150d..eae8ead 100644 --- a/internal/db/output.go +++ b/internal/db/output.go @@ -14,6 +14,26 @@ type Printer interface { print(statementResult StatementResult, outF io.Writer) error } +type ExplainQueryPrinter struct{} + +func (eqp ExplainQueryPrinter) print(statementResult StatementResult, outF io.Writer) error { + data := [][]string{} + + tableData, err := appendData(statementResult, data, TABLE) + if err != nil { + return err + } + + root, err := BuildQueryPlanTree(tableData) + if err != nil { + return err + } + println("QUERY PLAN") + PrintQueryPlanTree(root, "") + + return nil +} + type TablePrinter struct { withoutHeader bool } @@ -32,6 +52,7 @@ func (t TablePrinter) print(statementResult StatementResult, outF io.Writer) err table.AppendBulk(tableData) table.Render() + return nil } @@ -100,10 +121,14 @@ func appendData(statementResult StatementResult, data [][]string, mode FormatTyp } data = append(data, formattedRow) } + return data, nil } -func getPrinter(mode enums.PrintMode, withoutHeader bool) (Printer, error) { +func getPrinter(mode enums.PrintMode, withoutHeader bool, isExplainQueryPlan bool) (Printer, error) { + if isExplainQueryPlan { + return &ExplainQueryPrinter{}, nil + } switch mode { case enums.TABLE_MODE: return &TablePrinter{ @@ -143,7 +168,8 @@ func PrintStatementResult(statementResult StatementResult, outF io.Writer, witho return &UnableToPrintStatementResult{} } - printer, err := getPrinter(mode, withoutHeader) + isExplainQueryPlan := IsResultComingFromExplainQueryPlan(statementResult) + printer, err := getPrinter(mode, withoutHeader, isExplainQueryPlan) if err != nil { return err } diff --git a/internal/db/utils.go b/internal/db/utils.go index b9d04a5..9f16e43 100644 --- a/internal/db/utils.go +++ b/internal/db/utils.go @@ -2,6 +2,7 @@ package db import ( "net/url" + "reflect" "strings" "unicode" ) @@ -45,3 +46,25 @@ func NeedsEscaping(name string) bool { } return false } + +var explainQueryPlanStatement = "EXPLAIN QUERY PLAN" +var explainQueryPlanColumnNames = []string{"id", "parent", "notused", "detail"} + +func queryContainsExplainQueryPlanStatement(query string) bool { + return strings.HasPrefix( + strings.ToLower(query), + strings.ToLower(explainQueryPlanStatement), + ) +} + +func columnNamesMatchExplainQueryPlan(colNames []string) bool { + return reflect.DeepEqual(colNames, explainQueryPlanColumnNames) +} + +// "query" can be a string containing multiple queries separated by ";" or a single query +func IsResultComingFromExplainQueryPlan(statementResult StatementResult) bool { + query := statementResult.Query + columnNames := statementResult.ColumnNames + return queryContainsExplainQueryPlanStatement(query) && + columnNamesMatchExplainQueryPlan(columnNames) +} diff --git a/test/db_root_command_shell_test.go b/test/db_root_command_shell_test.go index 6a0c00e..62eba42 100644 --- a/test/db_root_command_shell_test.go +++ b/test/db_root_command_shell_test.go @@ -271,6 +271,19 @@ func (s *DBRootCommandShellSuite) Test_GivenATableNameWithSpecialCharacters_When s.tc.AssertSqlEquals(outS, expected) } +func (s *DBRootCommandShellSuite) Test_GivenATableNameWithTheSameSignatureAsExpainQueryPlan_WhenQueryingIt_ExpectNotToBeTreatedAsExplainQueryPlan() { + _, _, err := s.tc.Execute("CREATE TABLE fake_explain (ID INTEGER PRIMARY KEY, PARENT INTEGER, NOTUSED INTEGER, DETAIL TEXT);") + s.tc.Assert(err, qt.IsNil) + + outS, errS, err := s.tc.ExecuteShell([]string{"SELECT * FROM fake_explain;"}) + s.tc.Assert(err, qt.IsNil) + s.tc.Assert(errS, qt.Equals, "") + + expected := "id parent notused detail" + + s.tc.AssertSqlEquals(outS, expected) +} + func (s *DBRootCommandShellSuite) Test_GivenATableWithRecordsWithSingleQuote_WhenCalllSelectAllFromTable_ExpectSingleQuoteScape() { s.tc.CreateEmptySimpleTable("t") _, errS, err := s.tc.Execute("INSERT INTO t VALUES (0, \"x'x\", 0)")