diff --git a/README.md b/README.md index 2301f84..3539c5e 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,11 @@ sqls aims to provide advanced intelligence for you to edit sql in your own edito - DDL(Data Definition Language) - [ ] CREATE TABLE - [ ] ALTER TABLE + +#### Join completion +If the tables are connected with a foreign key sqls can complete ```JOIN``` statements + +![join_completion](imgs/sqls-fk_joins.gif) #### CodeAction diff --git a/imgs/sqls-fk_joins.gif b/imgs/sqls-fk_joins.gif new file mode 100644 index 0000000..99f7bff Binary files /dev/null and b/imgs/sqls-fk_joins.gif differ diff --git a/internal/completer/candidates.go b/internal/completer/candidates.go index b56af42..547ffb3 100644 --- a/internal/completer/candidates.go +++ b/internal/completer/candidates.go @@ -1,6 +1,7 @@ package completer import ( + "fmt" "strings" "github.com/lighttiger2505/sqls/internal/database" @@ -155,6 +156,169 @@ func (c *Completer) TableCandidates(parent *completionParent, targetTables []*pa return candidates } +func (c *Completer) joinCandidates(lastTable *parseutil.TableInfo, + targetTables, allTables []*parseutil.TableInfo, + joinOn, lowercaseKeywords bool) []lsp.CompletionItem { + var candidates []lsp.CompletionItem + if len(c.DBCache.ForeignKeys) == 0 { + return candidates + } + + tMap := make(map[string]*parseutil.TableInfo) + for _, t := range targetTables { + tMap[t.Name] = t + } + fkMap := make(map[string][][]*database.ForeignKey) + if lastTable == nil { + for t := range tMap { + for k, v := range c.DBCache.ForeignKeys[t] { + fkMap[k] = append(fkMap[k], v) + } + } + } else { + delete(tMap, lastTable.Name) + rTab := []*parseutil.TableInfo{lastTable} + if !joinOn { + rTab = resolveTables(lastTable, c.DBCache) + } + for _, lt := range rTab { + for k, v := range c.DBCache.ForeignKeys[lt.Name] { + if _, ok := tMap[k]; ok { + fkMap[lt.Name] = append(fkMap[lt.Name], v) + } + } + } + + for _, t := range rTab { + if _, ok := tMap[t.Name]; !ok { + tMap[t.Name] = t + } + } + } + + aliases := make(map[string]interface{}) + for _, t := range allTables { + if t.Alias != "" { + aliases[t.Alias] = true + } + } + + for k, v := range fkMap { + for _, fks := range v { + for _, fk := range fks { + candidates = append(candidates, generateForeignKeyCandidate(k, tMap, aliases, + fk, joinOn, lowercaseKeywords)) + } + } + } + return candidates +} + +func resolveTables(t *parseutil.TableInfo, cache *database.DBCache) []*parseutil.TableInfo { + if _, ok := cache.ColumnDescs(t.Name); ok { + return []*parseutil.TableInfo{t} + } + var rv []*parseutil.TableInfo + targetName := strings.ToLower(t.Name) + for _, cond := range cache.SortedTables() { + if strings.Contains(strings.ToLower(cond), targetName) { + rv = append(rv, &parseutil.TableInfo{ + Name: cond, + }) + } + } + return rv +} + +func generateTableAlias(target string, + aliases map[string]interface{}) string { + ch := []rune(target)[0] + i := 1 + var rv string + for { + rv = fmt.Sprintf("%c%d", ch, i) + if _, ok := aliases[rv]; ok { + i++ + continue + } + break + } + return rv +} + +func generateForeignKeyCandidate(target string, + tMap map[string]*parseutil.TableInfo, + aliases map[string]interface{}, + fk *database.ForeignKey, + joinOn, lowercaseKeywords bool) lsp.CompletionItem { + var tAlias string + if joinOn { + tAlias = tMap[target].Alias + if tAlias == "" { + tAlias = tMap[target].Name + } + } else { + tAlias = generateTableAlias(target, aliases) + } + builder := []struct { + sb *strings.Builder + alias string + }{ + { + sb: &strings.Builder{}, + alias: tAlias, + }, + { + sb: &strings.Builder{}, + alias: tAlias, + }, + } + if !joinOn { + builder[1].alias = fmt.Sprintf("${1:%s}", tAlias) + onKw := "ON" + if lowercaseKeywords { + onKw = "on" + } + for _, b := range builder { + b.sb.WriteString(fmt.Sprintf("%s %s %s ", target, b.alias, onKw)) + } + } + andKw := " AND " + if lowercaseKeywords { + andKw = " and " + } + prefix := "" + for _, cur := range *fk { + tIdx, rIdx := 0, 1 + if cur[rIdx].Table == target { + tIdx, rIdx = rIdx, tIdx + } + for _, b := range builder { + b.sb.WriteString(prefix) + } + prefix = andKw + for _, b := range builder { + b.sb.WriteString(strings.Join([]string{b.alias, cur[tIdx].Name}, ".")) + b.sb.WriteString(" = ") + } + rAlias := tMap[cur[rIdx].Table].Alias + if rAlias == "" { + rAlias = cur[rIdx].Table + } + for _, b := range builder { + b.sb.WriteString(strings.Join([]string{rAlias, cur[rIdx].Name}, ".")) + } + } + builder[1].sb.WriteString("$0") + return lsp.CompletionItem{ + Label: builder[0].sb.String(), + Kind: lsp.SnippetCompletion, + Detail: "Join generator for foreign key", + InsertText: builder[1].sb.String(), + InsertTextFormat: lsp.SnippetTextFormat, + } +} + func generateTableCandidates(tables []string, dbCache *database.DBCache) []lsp.CompletionItem { candidates := []lsp.CompletionItem{} for _, tableName := range tables { diff --git a/internal/completer/completer.go b/internal/completer/completer.go index c0ee546..d86ce58 100644 --- a/internal/completer/completer.go +++ b/internal/completer/completer.go @@ -32,6 +32,8 @@ const ( CompletionTypeChange CompletionTypeUser CompletionTypeSchema + CompletionTypeJoin + CompletionTypeJoinOn ) func (ct completionType) String() string { @@ -112,7 +114,7 @@ func (c *Completer) Complete(text string, params lsp.CompletionParams, lowercase lastWord := getLastWord(text, params.Position.Line+1, params.Position.Character) withBackQuote := strings.HasPrefix(lastWord, "`") - items := []lsp.CompletionItem{} + var items []lsp.CompletionItem if c.DBCache != nil { if completionTypeIs(ctx.types, CompletionTypeColumn) { @@ -130,7 +132,11 @@ func (c *Completer) Complete(text string, params lsp.CompletionParams, lowercase items = append(items, candidates...) } if completionTypeIs(ctx.types, CompletionTypeTable) { - candidates := c.TableCandidates(ctx.parent, definedTables) + excl := definedTables + if completionTypeIs(ctx.types, CompletionTypeJoin) { + excl = nil + } + candidates := c.TableCandidates(ctx.parent, excl) if withBackQuote { candidates = toQuotedCandidates(candidates) } @@ -157,6 +163,22 @@ func (c *Completer) Complete(text string, params lsp.CompletionParams, lowercase } items = append(items, candidates...) } + joinOn := completionTypeIs(ctx.types, CompletionTypeJoinOn) + if completionTypeIs(ctx.types, CompletionTypeJoin) || joinOn { + table, err := parseutil.ExtractLastTable(parsed, pos) + if err != nil { + return nil, err + } + tables, err := parseutil.ExtractPrevTables(parsed, pos) + if err != nil { + return nil, err + } + candidates := c.joinCandidates(table, tables, definedTables, joinOn, lowercaseKeywords) + if withBackQuote { + candidates = toQuotedCandidates(candidates) // what to do here? + } + items = append(candidates, items...) + } } if completionTypeIs(ctx.types, CompletionTypeKeyword) { @@ -185,6 +207,8 @@ func populateSortText(items []lsp.CompletionItem) { // This prefix defines the alphabetic priority of each kind. func getSortTextPrefix(kind lsp.CompletionItemKind) string { switch kind { + case lsp.SnippetCompletion: + return "00" case lsp.FieldCompletion: return "0" case lsp.ClassCompletion: @@ -208,7 +232,6 @@ func getSortTextPrefix(kind lsp.CompletionItemKind) string { lsp.OperatorCompletion, lsp.PropertyCompletion, lsp.ReferenceCompletion, - lsp.SnippetCompletion, lsp.StructCompletion, lsp.TextCompletion, lsp.TypeParameterCompletion, @@ -249,7 +272,7 @@ func getCompletionTypes(nw *parseutil.NodeWalker) *CompletionContext { } syntaxPos := parseutil.CheckSyntaxPosition(nw) - t := []completionType{} + var t []completionType p := noneParent switch { case syntaxPos == parseutil.ColName: @@ -349,6 +372,23 @@ func getCompletionTypes(nw *parseutil.NodeWalker) *CompletionContext { CompletionTypeFunction, } } + case syntaxPos == parseutil.JoinClause: + t = []completionType{ + CompletionTypeJoin, + CompletionTypeTable, + CompletionTypeReferencedTable, + CompletionTypeSchema, + CompletionTypeView, + CompletionTypeSubQuery, + } + case syntaxPos == parseutil.JoinOn: + t = []completionType{ + CompletionTypeJoinOn, + CompletionTypeColumn, + CompletionTypeReferencedTable, + CompletionTypeSubQueryColumn, + CompletionTypeSubQuery, + } case syntaxPos == parseutil.InsertColumn: t = []completionType{ CompletionTypeColumn, diff --git a/internal/completer/completer_test.go b/internal/completer/completer_test.go index 774e58f..83e67dd 100644 --- a/internal/completer/completer_test.go +++ b/internal/completer/completer_test.go @@ -165,3 +165,38 @@ func TestComplete(t *testing.T) { }) } } + +func TestGenerateAlias(t *testing.T) { + noMatchesTable := make(map[string]interface{}) + noMatchesTable["XX"] = true + matchesTable := make(map[string]interface{}) + matchesTable["XX"] = true + matchesTable["T1"] = true + + tests := []struct { + name string + table string + tMap map[string]interface{} + want string + }{ + { + "no matches", + "Table", + noMatchesTable, + "T1", + }, + { + "matches", + "Table", + matchesTable, + "T2", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := generateTableAlias(tt.table, tt.tMap); got != tt.want { + t.Errorf("generateAlias() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/database/cache.go b/internal/database/cache.go index d1f18c9..58b2855 100644 --- a/internal/database/cache.go +++ b/internal/database/cache.go @@ -28,7 +28,7 @@ func (u *DBCacheGenerator) GenerateDBCachePrimary(ctx context.Context) (*DBCache return nil, err } dbCache.Schemas = make(map[string]string) - for index,element := range schemas{ + for index, element := range schemas { dbCache.Schemas[strings.ToUpper(index)] = element } @@ -45,7 +45,7 @@ func (u *DBCacheGenerator) GenerateDBCachePrimary(ctx context.Context) (*DBCache return nil, err } dbCache.SchemaTables = make(map[string][]string) - for index,element := range schemaTables{ + for index, element := range schemaTables { dbCache.SchemaTables[strings.ToUpper(index)] = element } @@ -53,6 +53,7 @@ func (u *DBCacheGenerator) GenerateDBCachePrimary(ctx context.Context) (*DBCache if err != nil { return nil, err } + dbCache.ForeignKeys, err = u.genForeignKeysCache(ctx, dbCache.defaultSchema) return dbCache, nil } @@ -88,6 +89,32 @@ func (u *DBCacheGenerator) genColumnCacheAll(ctx context.Context) (map[string][] return genColumnMap(columnDescs), nil } +func (u *DBCacheGenerator) genForeignKeysCache(ctx context.Context, schemaName string) (map[string]map[string][]*ForeignKey, error) { + retVal := make(map[string]map[string][]*ForeignKey) + fk, err := u.repo.DescribeForeignKeysBySchema(ctx, schemaName) + if err != nil { + return nil, err + } + + for _, cur := range fk { + elem := (*cur)[0] + refs, ok := retVal[elem[0].Table] + if !ok { + refs = make(map[string][]*ForeignKey) + } + refs[elem[1].Table] = append(refs[elem[1].Table], cur) + retVal[elem[0].Table] = refs + + refs, ok = retVal[elem[1].Table] + if !ok { + refs = make(map[string][]*ForeignKey) + } + refs[elem[0].Table] = append(refs[elem[0].Table], cur) + retVal[elem[1].Table] = refs + } + return retVal, nil +} + func genColumnMap(columnDescs []*ColumnDesc) map[string][]*ColumnDesc { columnMap := map[string][]*ColumnDesc{} for _, desc := range columnDescs { @@ -107,6 +134,7 @@ type DBCache struct { Schemas map[string]string SchemaTables map[string][]string ColumnsWithParent map[string][]*ColumnDesc + ForeignKeys map[string]map[string][]*ForeignKey } func (dc *DBCache) Database(dbName string) (db string, ok bool) { diff --git a/internal/database/database.go b/internal/database/database.go index 4de55ad..3dd2a36 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -32,6 +32,7 @@ type DBRepository interface { DescribeDatabaseTableBySchema(ctx context.Context, schemaName string) ([]*ColumnDesc, error) Exec(ctx context.Context, query string) (sql.Result, error) Query(ctx context.Context, query string) (*sql.Rows, error) + DescribeForeignKeysBySchema(ctx context.Context, schemaName string) ([]*ForeignKey, error) } type DBOption struct { @@ -39,10 +40,14 @@ type DBOption struct { MaxOpenConns int } +type ColumnBase struct { + Schema string + Table string + Name string +} + type ColumnDesc struct { - Schema string - Table string - Name string + ColumnBase Type string Null string Key string @@ -50,10 +55,21 @@ type ColumnDesc struct { Extra string } +type ForeignKey [][2]*ColumnBase + +type fkItemDesc struct { + fkId string + schema string + table string + column string + refTable string + refColumn string +} + func (cd *ColumnDesc) OnelineDesc() string { items := []string{} if cd.Type != "" { - items = append(items, "`" + cd.Type + "`") + items = append(items, "`"+cd.Type+"`") } if cd.Key == "YES" { items = append(items, "PRIMARY KEY") @@ -163,3 +179,42 @@ func SubqueryColumnDoc(identName string, views []*parseutil.SubQueryView, dbCach } return buf.String() } + +func parseForeignKeys(rows *sql.Rows, schemaName string) ([]*ForeignKey, error) { + var retVal []*ForeignKey + var prevFk string + var cur *ForeignKey + for rows.Next() { + var fkItem fkItemDesc + err := rows.Scan( + &fkItem.fkId, + &fkItem.table, + &fkItem.column, + &fkItem.refTable, + &fkItem.refColumn, + ) + if err != nil { + return nil, err + } + var l, r ColumnBase + l.Schema = schemaName + l.Table = fkItem.table + l.Name = fkItem.column + r.Schema = l.Schema + r.Table = fkItem.refTable + r.Name = fkItem.refColumn + if fkItem.fkId != prevFk { + if cur != nil { + retVal = append(retVal, cur) + } + cur = new(ForeignKey) + } + *cur = append(*cur, [2]*ColumnBase{&l, &r}) + prevFk = fkItem.fkId + } + + if cur != nil { + retVal = append(retVal, cur) + } + return retVal, nil +} diff --git a/internal/database/database_mock.go b/internal/database/database_mock.go index 1382fe4..db0c608 100644 --- a/internal/database/database_mock.go +++ b/internal/database/database_mock.go @@ -17,9 +17,10 @@ type MockDBRepository struct { MockDescribeDatabaseTableBySchema func(context.Context, string) ([]*ColumnDesc, error) MockExec func(context.Context, string) (sql.Result, error) MockQuery func(context.Context, string) (*sql.Rows, error) + MockDescribeForeignKeysBySchema func(context.Context, string) ([]*ForeignKey, error) } -func NewMockDBRepository(conn *sql.DB) DBRepository { +func NewMockDBRepository(_ *sql.DB) DBRepository { return &MockDBRepository{ MockDatabase: func(ctx context.Context) (string, error) { return "world", nil }, MockDatabases: func(ctx context.Context) ([]string, error) { return dummyDatabases, nil }, @@ -37,7 +38,7 @@ func NewMockDBRepository(conn *sql.DB) DBRepository { return nil, nil }, MockDescribeDatabaseTable: func(ctx context.Context) ([]*ColumnDesc, error) { - res := []*ColumnDesc{} + var res []*ColumnDesc res = append(res, dummyCityColumns...) res = append(res, dummyCountryColumns...) res = append(res, dummyCountryLanguageColumns...) @@ -45,7 +46,7 @@ func NewMockDBRepository(conn *sql.DB) DBRepository { }, MockDescribeDatabaseTableBySchema: func(ctx context.Context, schemaName string) ([]*ColumnDesc, error) { - res := []*ColumnDesc{} + var res []*ColumnDesc res = append(res, dummyCityColumns...) res = append(res, dummyCountryColumns...) res = append(res, dummyCountryLanguageColumns...) @@ -61,6 +62,9 @@ func NewMockDBRepository(conn *sql.DB) DBRepository { MockQuery: func(ctx context.Context, query string) (*sql.Rows, error) { return &sql.Rows{}, nil }, + MockDescribeForeignKeysBySchema: func(ctx context.Context, schemaName string) ([]*ForeignKey, error) { + return foreignKeys, nil + }, } } @@ -108,6 +112,10 @@ func (m *MockDBRepository) Query(ctx context.Context, query string) (*sql.Rows, return m.MockQuery(ctx, query) } +func (m *MockDBRepository) DescribeForeignKeysBySchema(ctx context.Context, schemaName string) ([]*ForeignKey, error) { + return m.MockDescribeForeignKeysBySchema(ctx, schemaName) +} + var dummyDatabases = []string{ "information_schema", "mysql", @@ -129,12 +137,14 @@ var dummyTables = []string{ } var dummyCityColumns = []*ColumnDesc{ { - Schema: "world", - Table: "city", - Name: "ID", - Type: "int(11)", - Null: "NO", - Key: "PRI", + ColumnBase: ColumnBase{ + Schema: "world", + Table: "city", + Name: "ID", + }, + Type: "int(11)", + Null: "NO", + Key: "PRI", Default: sql.NullString{ String: "", Valid: false, @@ -142,12 +152,14 @@ var dummyCityColumns = []*ColumnDesc{ Extra: "auto_increment", }, { - Schema: "world", - Table: "city", - Name: "Name", - Type: "char(35)", - Null: "NO", - Key: "", + ColumnBase: ColumnBase{ + Schema: "world", + Table: "city", + Name: "Name", + }, + Type: "char(35)", + Null: "NO", + Key: "", Default: sql.NullString{ String: "", Valid: false, @@ -155,12 +167,14 @@ var dummyCityColumns = []*ColumnDesc{ Extra: "", }, { - Schema: "world", - Table: "city", - Name: "CountryCode", - Type: "char(3)", - Null: "NO", - Key: "MUL", + ColumnBase: ColumnBase{ + Schema: "world", + Table: "city", + Name: "CountryCode", + }, + Type: "char(3)", + Null: "NO", + Key: "MUL", Default: sql.NullString{ String: "", Valid: false, @@ -168,12 +182,14 @@ var dummyCityColumns = []*ColumnDesc{ Extra: "", }, { - Schema: "world", - Table: "city", - Name: "District", - Type: "char(20)", - Null: "NO", - Key: "", + ColumnBase: ColumnBase{ + Schema: "world", + Table: "city", + Name: "District", + }, + Type: "char(20)", + Null: "NO", + Key: "", Default: sql.NullString{ String: "", Valid: false, @@ -181,12 +197,14 @@ var dummyCityColumns = []*ColumnDesc{ Extra: "", }, { - Schema: "world", - Table: "city", - Name: "Population", - Type: "int(11)", - Null: "NO", - Key: "", + ColumnBase: ColumnBase{ + Schema: "world", + Table: "city", + Name: "Population", + }, + Type: "int(11)", + Null: "NO", + Key: "", Default: sql.NullString{ String: "", Valid: false, @@ -196,12 +214,14 @@ var dummyCityColumns = []*ColumnDesc{ } var dummyCountryColumns = []*ColumnDesc{ { - Schema: "world", - Table: "country", - Name: "Code", - Type: "char(3)", - Null: "NO", - Key: "PRI", + ColumnBase: ColumnBase{ + Schema: "world", + Table: "country", + Name: "Code", + }, + Type: "char(3)", + Null: "NO", + Key: "PRI", Default: sql.NullString{ String: "", Valid: false, @@ -209,12 +229,14 @@ var dummyCountryColumns = []*ColumnDesc{ Extra: "auto_increment", }, { - Schema: "world", - Table: "country", - Name: "Name", - Type: "char(52)", - Null: "NO", - Key: "", + ColumnBase: ColumnBase{ + Schema: "world", + Table: "country", + Name: "Name", + }, + Type: "char(52)", + Null: "NO", + Key: "", Default: sql.NullString{ String: "", Valid: false, @@ -222,12 +244,14 @@ var dummyCountryColumns = []*ColumnDesc{ Extra: "", }, { - Schema: "world", - Table: "country", - Name: "CountryCode", - Type: "char(3)", - Null: "NO", - Key: "", + ColumnBase: ColumnBase{ + Schema: "world", + Table: "country", + Name: "CountryCode", + }, + Type: "char(3)", + Null: "NO", + Key: "", Default: sql.NullString{ String: "", Valid: false, @@ -235,12 +259,14 @@ var dummyCountryColumns = []*ColumnDesc{ Extra: "", }, { - Schema: "world", - Table: "country", - Name: "Continent", - Type: "enum('Asia','Europe','North America','Africa','Oceania','Antarctica','South America')", - Null: "NO", - Key: "", + ColumnBase: ColumnBase{ + Schema: "world", + Table: "country", + Name: "Continent", + }, + Type: "enum('Asia','Europe','North America','Africa','Oceania','Antarctica','South America')", + Null: "NO", + Key: "", Default: sql.NullString{ String: "Asia", Valid: false, @@ -248,12 +274,14 @@ var dummyCountryColumns = []*ColumnDesc{ Extra: "", }, { - Schema: "world", - Table: "country", - Name: "Region", - Type: "char(26)", - Null: "NO", - Key: "", + ColumnBase: ColumnBase{ + Schema: "world", + Table: "country", + Name: "Region", + }, + Type: "char(26)", + Null: "NO", + Key: "", Default: sql.NullString{ String: "", Valid: false, @@ -261,12 +289,14 @@ var dummyCountryColumns = []*ColumnDesc{ Extra: "", }, { - Schema: "world", - Table: "country", - Name: "SurfaceArea", - Type: "decimal(10,2)", - Null: "NO", - Key: "", + ColumnBase: ColumnBase{ + Schema: "world", + Table: "country", + Name: "SurfaceArea", + }, + Type: "decimal(10,2)", + Null: "NO", + Key: "", Default: sql.NullString{ String: "0.00", Valid: false, @@ -274,12 +304,14 @@ var dummyCountryColumns = []*ColumnDesc{ Extra: "auto_increment", }, { - Schema: "world", - Table: "country", - Name: "IndepYear", - Type: "smallint(6)", - Null: "YES", - Key: "", + ColumnBase: ColumnBase{ + Schema: "world", + Table: "country", + Name: "IndepYear", + }, + Type: "smallint(6)", + Null: "YES", + Key: "", Default: sql.NullString{ String: "0", Valid: false, @@ -287,12 +319,14 @@ var dummyCountryColumns = []*ColumnDesc{ Extra: "", }, { - Schema: "world", - Table: "country", - Name: "LifeExpectancy", - Type: "decimal(3,1)", - Null: "YES", - Key: "", + ColumnBase: ColumnBase{ + Schema: "world", + Table: "country", + Name: "LifeExpectancy", + }, + Type: "decimal(3,1)", + Null: "YES", + Key: "", Default: sql.NullString{ String: "", Valid: false, @@ -300,12 +334,14 @@ var dummyCountryColumns = []*ColumnDesc{ Extra: "", }, { - Schema: "world", - Table: "country", - Name: "GNP", - Type: "decimal(10,2)", - Null: "YES", - Key: "", + ColumnBase: ColumnBase{ + Schema: "world", + Table: "country", + Name: "GNP", + }, + Type: "decimal(10,2)", + Null: "YES", + Key: "", Default: sql.NullString{ String: "", Valid: false, @@ -313,12 +349,14 @@ var dummyCountryColumns = []*ColumnDesc{ Extra: "", }, { - Schema: "world", - Table: "country", - Name: "GNPOld", - Type: "decimal(10,2)", - Null: "YES", - Key: "", + ColumnBase: ColumnBase{ + Schema: "world", + Table: "country", + Name: "GNPOld", + }, + Type: "decimal(10,2)", + Null: "YES", + Key: "", Default: sql.NullString{ String: "", Valid: false, @@ -326,12 +364,14 @@ var dummyCountryColumns = []*ColumnDesc{ Extra: "", }, { - Schema: "world", - Table: "country", - Name: "LocalName", - Type: "char(45)", - Null: "NO", - Key: "", + ColumnBase: ColumnBase{ + Schema: "world", + Table: "country", + Name: "LocalName", + }, + Type: "char(45)", + Null: "NO", + Key: "", Default: sql.NullString{ String: "", Valid: false, @@ -339,12 +379,14 @@ var dummyCountryColumns = []*ColumnDesc{ Extra: "", }, { - Schema: "world", - Table: "country", - Name: "GovernmentForm", - Type: "char(45)", - Null: "NO", - Key: "", + ColumnBase: ColumnBase{ + Schema: "world", + Table: "country", + Name: "GovernmentForm", + }, + Type: "char(45)", + Null: "NO", + Key: "", Default: sql.NullString{ String: "", Valid: false, @@ -352,12 +394,14 @@ var dummyCountryColumns = []*ColumnDesc{ Extra: "", }, { - Schema: "world", - Table: "country", - Name: "HeadOfState", - Type: "char(60)", - Null: "YES", - Key: "", + ColumnBase: ColumnBase{ + Schema: "world", + Table: "country", + Name: "HeadOfState", + }, + Type: "char(60)", + Null: "YES", + Key: "", Default: sql.NullString{ String: "", Valid: false, @@ -365,12 +409,14 @@ var dummyCountryColumns = []*ColumnDesc{ Extra: "", }, { - Schema: "world", - Table: "country", - Name: "Capital", - Type: "int(11)", - Null: "YES", - Key: "", + ColumnBase: ColumnBase{ + Schema: "world", + Table: "country", + Name: "Capital", + }, + Type: "int(11)", + Null: "YES", + Key: "", Default: sql.NullString{ String: "", Valid: false, @@ -378,12 +424,14 @@ var dummyCountryColumns = []*ColumnDesc{ Extra: "", }, { - Schema: "world", - Table: "country", - Name: "Code2", - Type: "char(2)", - Null: "NO", - Key: "", + ColumnBase: ColumnBase{ + Schema: "world", + Table: "country", + Name: "Code2", + }, + Type: "char(2)", + Null: "NO", + Key: "", Default: sql.NullString{ String: "", Valid: false, @@ -393,12 +441,14 @@ var dummyCountryColumns = []*ColumnDesc{ } var dummyCountryLanguageColumns = []*ColumnDesc{ { - Schema: "world", - Table: "countrylanguage", - Name: "CountryCode", - Type: "char(3)", - Null: "NO", - Key: "PRI", + ColumnBase: ColumnBase{ + Schema: "world", + Table: "countrylanguage", + Name: "CountryCode", + }, + Type: "char(3)", + Null: "NO", + Key: "PRI", Default: sql.NullString{ String: "", Valid: false, @@ -406,12 +456,14 @@ var dummyCountryLanguageColumns = []*ColumnDesc{ Extra: "", }, { - Schema: "world", - Table: "countrylanguage", - Name: "Language", - Type: "char(30)", - Null: "NO", - Key: "PRI", + ColumnBase: ColumnBase{ + Schema: "world", + Table: "countrylanguage", + Name: "Language", + }, + Type: "char(30)", + Null: "NO", + Key: "PRI", Default: sql.NullString{ String: "", Valid: false, @@ -419,12 +471,14 @@ var dummyCountryLanguageColumns = []*ColumnDesc{ Extra: "", }, { - Schema: "world", - Table: "countrylanguage", - Name: "IsOfficial", - Type: "enum('T','F')", - Null: "NO", - Key: "F", + ColumnBase: ColumnBase{ + Schema: "world", + Table: "countrylanguage", + Name: "IsOfficial", + }, + Type: "enum('T','F')", + Null: "NO", + Key: "F", Default: sql.NullString{ String: "", Valid: false, @@ -432,12 +486,14 @@ var dummyCountryLanguageColumns = []*ColumnDesc{ Extra: "", }, { - Schema: "world", - Table: "countrylanguage", - Name: "Percentage", - Type: "decimal(4,1)", - Null: "NO", - Key: "", + ColumnBase: ColumnBase{ + Schema: "world", + Table: "countrylanguage", + Name: "Percentage", + }, + Type: "decimal(4,1)", + Null: "NO", + Key: "", Default: sql.NullString{ String: "0.0", Valid: false, @@ -446,6 +502,37 @@ var dummyCountryLanguageColumns = []*ColumnDesc{ }, } +var foreignKeys = []*ForeignKey{ + { + [2]*ColumnBase{ + { + Schema: "world", + Table: "city", + Name: "CountryCode", + }, + { + Schema: "world", + Table: "country", + Name: "Code", + }, + }, + }, + { + [2]*ColumnBase{ + { + Schema: "world", + Table: "countrylanguage", + Name: "CountryCode", + }, + { + Schema: "world", + Table: "country", + Name: "Code", + }, + }, + }, +} + type MockResult struct { MockLastInsertID func() (int64, error) MockRowsAffected func() (int64, error) diff --git a/internal/database/mssql.go b/internal/database/mssql.go index 380735a..7bf5f57 100644 --- a/internal/database/mssql.go +++ b/internal/database/mssql.go @@ -270,7 +270,7 @@ func (db *MssqlDBRepository) DescribeDatabaseTableBySchema(ctx context.Context, AND tc.TABLE_NAME = c.TABLE_NAME AND tc.CONSTRAINT_NAME = ccu.CONSTRAINT_NAME WHERE - c.TABLE_SCHEMA = $1 + c.TABLE_SCHEMA = @p1 ORDER BY c.TABLE_NAME, c.ORDINAL_POSITION @@ -300,6 +300,37 @@ func (db *MssqlDBRepository) DescribeDatabaseTableBySchema(ctx context.Context, return tableInfos, nil } +func (db *MssqlDBRepository) DescribeForeignKeysBySchema(ctx context.Context, schemaName string) ([]*ForeignKey, error) { + rows, err := db.Conn.QueryContext( + ctx, + ` + SELECT fk.name, + src_tbl.name, + src_col.name, + dst_tbl.name, + dst_col.name + FROM sys.foreign_key_columns fkc + JOIN sys.objects fk on fk.object_id = fkc.constraint_object_id + JOIN sys.tables src_tbl + ON src_tbl.object_id = fkc.parent_object_id + JOIN sys.schemas sch + ON src_tbl.schema_id = sch.schema_id + JOIN sys.columns src_col + ON src_col.column_id = parent_column_id AND src_col.object_id = src_tbl.object_id + JOIN sys.tables dst_tbl + ON dst_tbl.object_id = fkc.referenced_object_id + JOIN sys.columns dst_col + ON dst_col.column_id = referenced_column_id AND dst_col.object_id = dst_tbl.object_id + where sch.name = @p1 + order by fk.name, fkc.constraint_object_id + `, schemaName) + if err != nil { + log.Fatal(err) + } + defer func() { _ = rows.Close() }() + return parseForeignKeys(rows, schemaName) +} + func (db *MssqlDBRepository) Exec(ctx context.Context, query string) (sql.Result, error) { return db.Conn.ExecContext(ctx, query) } diff --git a/internal/database/mysql.go b/internal/database/mysql.go index a4da0f4..9ab6392 100644 --- a/internal/database/mysql.go +++ b/internal/database/mysql.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "log" "net" "strconv" @@ -302,6 +303,31 @@ WHERE information_schema.COLUMNS.TABLE_SCHEMA = ? return tableInfos, nil } +func (db *MySQLDBRepository) DescribeForeignKeysBySchema(ctx context.Context, schemaName string) ([]*ForeignKey, error) { + rows, err := db.Conn.QueryContext( + ctx, + ` + select fks.CONSTRAINT_NAME, + fks.TABLE_NAME, + kcu.COLUMN_NAME, + fks.REFERENCED_TABLE_NAME, + kcu.REFERENCED_COLUMN_NAME + from INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS fks + join INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu + on fks.CONSTRAINT_SCHEMA = kcu.TABLE_SCHEMA + and fks.TABLE_NAME = kcu.TABLE_NAME + and fks.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME + where fks.CONSTRAINT_SCHEMA = ? + order by fks.CONSTRAINT_NAME, + kcu.ORDINAL_POSITION + `, schemaName) + if err != nil { + log.Fatal(err) + } + defer func() { _ = rows.Close() }() + return parseForeignKeys(rows, schemaName) +} + func (db *MySQLDBRepository) Exec(ctx context.Context, query string) (sql.Result, error) { return db.Conn.ExecContext(ctx, query) } diff --git a/internal/database/oracle.go b/internal/database/oracle.go index 7a9335f..04d0258 100644 --- a/internal/database/oracle.go +++ b/internal/database/oracle.go @@ -231,6 +231,33 @@ func (db *OracleDBRepository) DescribeDatabaseTableBySchema(ctx context.Context, return tableInfos, nil } +func (db *OracleDBRepository) DescribeForeignKeysBySchema(ctx context.Context, schemaName string) ([]*ForeignKey, error) { + rows, err := db.Conn.QueryContext( + ctx, + ` + SELECT a.CONSTRAINT_NAME, + a.TABLE_NAME, + a.COLUMN_NAME, + b.TABLE_NAME, + b.COLUMN_NAME + FROM ALL_CONS_COLUMNS a + JOIN ALL_CONSTRAINTS c ON a.OWNER = c.OWNER + AND a.CONSTRAINT_NAME = c.CONSTRAINT_NAME + JOIN ALL_CONSTRAINTS c_pk ON c.R_OWNER = c_pk.OWNER + AND c.R_CONSTRAINT_NAME = c_pk.CONSTRAINT_NAME + JOIN ALL_CONS_COLUMNS b ON b.CONSTRAINT_NAME = c_pk.CONSTRAINT_NAME + AND b.POSITION = a.POSITION + WHERE c.constraint_type = 'R' + AND a.OWNER = :1 + ORDER BY a.CONSTRAINT_NAME, a.POSITION + `, schemaName) + if err != nil { + log.Fatal(err) + } + defer func() { _ = rows.Close() }() + return parseForeignKeys(rows, schemaName) +} + func (db *OracleDBRepository) Exec(ctx context.Context, query string) (sql.Result, error) { return db.Conn.ExecContext(ctx, query) } diff --git a/internal/database/postgresql.go b/internal/database/postgresql.go index db6d8f5..d22a608 100644 --- a/internal/database/postgresql.go +++ b/internal/database/postgresql.go @@ -334,6 +334,73 @@ func (db *PostgreSQLDBRepository) DescribeDatabaseTableBySchema(ctx context.Cont return tableInfos, nil } +func (db *PostgreSQLDBRepository) DescribeForeignKeysBySchema(ctx context.Context, schemaName string) ([]*ForeignKey, error) { + rows, err := db.Conn.QueryContext( + ctx, + ` + select kcu.CONSTRAINT_NAME, + kcu.TABLE_NAME, + kcu.COLUMN_NAME, + rel_kcu.TABLE_NAME, + rel_kcu.COLUMN_NAME + from INFORMATION_SCHEMA.TABLE_CONSTRAINTS tco + join INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu + on tco.CONSTRAINT_SCHEMA = kcu.CONSTRAINT_SCHEMA + and tco.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME + join INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rco + on tco.CONSTRAINT_SCHEMA = rco.CONSTRAINT_SCHEMA + and tco.CONSTRAINT_NAME = rco.CONSTRAINT_NAME + join INFORMATION_SCHEMA.KEY_COLUMN_USAGE rel_kcu + on rco.UNIQUE_CONSTRAINT_SCHEMA = rel_kcu.CONSTRAINT_SCHEMA + and rco.UNIQUE_CONSTRAINT_NAME = rel_kcu.CONSTRAINT_NAME + and kcu.ORDINAL_POSITION = rel_kcu.ORDINAL_POSITION + where tco.CONSTRAINT_TYPE = 'FOREIGN KEY' + and tco.CONSTRAINT_SCHEMA = $1 + order by kcu.CONSTRAINT_NAME, + kcu.ORDINAL_POSITION + `, schemaName) + if err != nil { + log.Fatal(err) + } + defer func() { _ = rows.Close() }() + var retVal []*ForeignKey + var prevFk string + var cur *ForeignKey + for rows.Next() { + var fkItem fkItemDesc + err := rows.Scan( + &fkItem.fkId, + &fkItem.table, + &fkItem.column, + &fkItem.refTable, + &fkItem.refColumn, + ) + if err != nil { + return nil, err + } + var l, r ColumnBase + l.Schema = schemaName + l.Table = fkItem.table + l.Name = fkItem.column + r.Schema = l.Schema + r.Table = fkItem.refTable + r.Name = fkItem.refColumn + if fkItem.fkId != prevFk { + if cur != nil { + retVal = append(retVal, cur) + } + cur = new(ForeignKey) + } + *cur = append(*cur, [2]*ColumnBase{&l, &r}) + prevFk = fkItem.fkId + } + + if cur != nil { + retVal = append(retVal, cur) + } + return retVal, nil +} + func (db *PostgreSQLDBRepository) Exec(ctx context.Context, query string) (sql.Result, error) { return db.Conn.ExecContext(ctx, query) } diff --git a/internal/database/sqlite3.go b/internal/database/sqlite3.go index 94305a1..b3f59eb 100644 --- a/internal/database/sqlite3.go +++ b/internal/database/sqlite3.go @@ -138,10 +138,31 @@ func (db *SQLite3DBRepository) DescribeDatabaseTable(ctx context.Context) ([]*Co return all, nil } -func (db *SQLite3DBRepository) DescribeDatabaseTableBySchema(ctx context.Context, schemaName string) ([]*ColumnDesc, error) { +func (db *SQLite3DBRepository) DescribeDatabaseTableBySchema(ctx context.Context, _ string) ([]*ColumnDesc, error) { return db.DescribeDatabaseTable(ctx) } +func (db *SQLite3DBRepository) DescribeForeignKeysBySchema(ctx context.Context, schemaName string) ([]*ForeignKey, error) { + rows, err := db.Conn.QueryContext( + ctx, + ` + SELECT m.name || p."id", + m.name, + p."from", + p."table", + p."to" + FROM sqlite_master m + JOIN pragma_foreign_key_list(m.name) p ON m.name != p."table" + WHERE m.type = 'table' + ORDER BY 1, p."seq" + `) + if err != nil { + log.Fatal(err) + } + defer func() { _ = rows.Close() }() + return parseForeignKeys(rows, schemaName) +} + func (db *SQLite3DBRepository) Exec(ctx context.Context, query string) (sql.Result, error) { return db.Conn.ExecContext(ctx, query) } diff --git a/internal/handler/completion_test.go b/internal/handler/completion_test.go index 4f3a837..9a150e2 100644 --- a/internal/handler/completion_test.go +++ b/internal/handler/completion_test.go @@ -848,6 +848,233 @@ var subQueryCase = []completionTestCase{ }, }, } +var joinClauseCase = []completionTestCase{ + { + name: "join tables", + input: "select CountryCode from city join ", + line: 0, + col: 34, + want: []string{ + "country c1 ON c1.Code = city.CountryCode", + "city", + "country", + "countrylanguage", + }, + }, + { + name: "join tables with reference", + input: "select c.CountryCode from city c join ", + line: 0, + col: 38, + want: []string{ + "country c1 ON c1.Code = c.CountryCode", + "city", + "country", + "countrylanguage", + }, + }, + { + name: "join filtered tables", + input: "select CountryCode from city join co", + line: 0, + col: 36, + want: []string{ + "country c1 ON c1.Code = city.CountryCode", + "country", + "countrylanguage", + }, + }, + { + name: "join filtered tables with reference", + input: "select c.CountryCode from city c join co", + line: 0, + col: 40, + want: []string{ + "country c1 ON c1.Code = c.CountryCode", + "country", + "countrylanguage", + }, + }, + { + name: "left join tables", + input: "select CountryCode from city left join ", + line: 0, + col: 39, + want: []string{ + "country c1 ON c1.Code = city.CountryCode", + "city", + "country", + "countrylanguage", + }, + }, + { + name: "left join tables with reference", + input: "select c.CountryCode from city c left join ", + line: 0, + col: 43, + want: []string{ + "country c1 ON c1.Code = c.CountryCode", + "city", + "country", + "countrylanguage", + }, + }, + { + name: "left outer join tables", + input: "select CountryCode from city left outer join ", + line: 0, + col: 45, + want: []string{ + "country c1 ON c1.Code = city.CountryCode", + "city", + "country", + "countrylanguage", + }, + }, + { + name: "left outer join tables with reference", + input: "select c.CountryCode from city c left outer join ", + line: 0, + col: 49, + want: []string{ + "country c1 ON c1.Code = c.CountryCode", + "city", + "country", + "countrylanguage", + }, + }, +} + +var joinConditionCase = []completionTestCase{ + { + name: "join on columns", + input: "select * from city left join country on ", + line: 0, + col: 40, + want: []string{ + "country.Code = city.CountryCode", + "Code", + "Name", + "CountryCode", + "Continent", + "Region", + "SurfaceArea", + "IndepYear", + "LifeExpectancy", + "GNP", + "GNPOld", + "LocalName", + "GovernmentForm", + "HeadOfState", + "Capital", + "Code2", + }, + }, + { + name: "join on columns with reference", + input: "select * from city c left join country co on ", + line: 0, + col: 45, + want: []string{ + "co.Code = c.CountryCode", + "Code", + "Name", + "CountryCode", + "Continent", + "Region", + "SurfaceArea", + "IndepYear", + "LifeExpectancy", + "GNP", + "GNPOld", + "LocalName", + "GovernmentForm", + "HeadOfState", + "Capital", + "Code2", + }, + }, +} + +var multiJoin = []completionTestCase{ + { + name: "join tables", + input: `select * from city c + join country c1 on c1.Code = c.CountryCode + join `, + line: 2, + col: 27, + want: []string{ + "countrylanguage c2 ON c2.CountryCode = c1.Code", + "city", + "country", + "countrylanguage", + }, + }, + { + name: "join tables start", + input: `select * from city c + join country c1 on c1.Code = c.CountryCode + join co`, + line: 2, + col: 28, + want: []string{ + "countrylanguage c2 ON c2.CountryCode = c1.Code", + "country", + "countrylanguage", + }, + }, + { + name: "join tables on", + input: `select * from city c + join country c1 on c1.Code = c.CountryCode + join countrylanguage c2 ON `, + line: 2, + col: 48, + want: []string{ + "c2.CountryCode = c1.Code", + }, + }, + { + name: "join tables prev line", + input: `select * from city c + join + join countrylanguage c2 ON c2.CountryCode = c1.Code`, + line: 1, + col: 27, + want: []string{ + "country c1 ON c1.Code = c.CountryCode", + }, + bad: []string{ + "country c1 ON c1.Code = c2.CountryCode", + }, + }, +} + +var joinSnippetCompletionCase = []completionTestCase{ + { + name: "join snippet alias", + input: "select CountryCode from city c join country c1 on c1.Code = c.CountryCode", + line: 0, + col: 44, + bad: []string{ + "country c1 ON c1.Code = city.CountryCode", + "country", + }, + }, + { + name: "join snippet alias multi", + input: `select CountryCode from city c join country c1 on c1.Code = c.CountryCode + join countrylanguage c2 on c2.CountryCode = c1.Code`, + line: 1, + col: 25, + bad: []string{ + "country c1 ON c1.Code = city.CountryCode", + "country", + "countrylanguage", + }, + }, +} func TestCompleteMain(t *testing.T) { tx := newTestContext() @@ -901,6 +1128,56 @@ func TestCompleteMain(t *testing.T) { } } +func TestCompleteJoin(t *testing.T) { + tx := newTestContext() + tx.initServer(t) + defer tx.tearDown() + + cfg := &config.Config{ + Connections: []*database.DBConfig{ + {Driver: "mock"}, + }, + } + tx.addWorkspaceConfig(t, cfg) + + testcaseMap := map[string][]completionTestCase{ + "join clause": joinClauseCase, + "join condition": joinConditionCase, + "multi-join": multiJoin, + "snippet": joinSnippetCompletionCase, + } + + for k, v := range testcaseMap { + for _, tt := range v { + t.Run(k+" "+tt.name, func(t *testing.T) { + tx.textDocumentDidOpen(t, testFileURI, tt.input) + + commpletionParams := lsp.CompletionParams{ + TextDocumentPositionParams: lsp.TextDocumentPositionParams{ + TextDocument: lsp.TextDocumentIdentifier{ + URI: testFileURI, + }, + Position: lsp.Position{ + Line: tt.line, + Character: tt.col, + }, + }, + CompletionContext: lsp.CompletionContext{ + TriggerKind: 0, + TriggerCharacter: nil, + }, + } + + var got []lsp.CompletionItem + if err := tx.conn.Call(tx.ctx, "textDocument/completion", commpletionParams, &got); err != nil { + t.Fatal("conn.Call textDocument/completion:", err) + } + testCompletionItem(t, tt.want, tt.bad, got) + }) + } + } +} + func TestCompleteNoneDBConnection(t *testing.T) { tx := newTestContext() tx.initServer(t) diff --git a/parser/parseutil/parseutil.go b/parser/parseutil/parseutil.go index cf682f4..34bebab 100644 --- a/parser/parseutil/parseutil.go +++ b/parser/parseutil/parseutil.go @@ -184,6 +184,48 @@ func ExtractSubQueryViews(parsed ast.TokenList, pos token.Pos) ([]*SubQueryInfo, } func ExtractTable(parsed ast.TokenList, pos token.Pos) ([]*TableInfo, error) { + return extractTables(parsed, pos, false) +} + +func ExtractPrevTables(parsed ast.TokenList, pos token.Pos) ([]*TableInfo, error) { + return extractTables(parsed, pos, true) +} + +func ExtractLastTable(parsed ast.TokenList, pos token.Pos) (*TableInfo, error) { + nodes := ExtractTableFactor(parsed) + if len(nodes) == 0 { + return nil, nil + } + + var all []*TableInfo + for _, ident := range nodes { + p := ident.Pos() + if token.ComparePos(p, pos) > 0 { + continue + } + + if isFollowedByOn(parsed, p) { + continue + } + + if isSubQueryByNode(ident) { + continue + } + infos, err := parseTableInfo(ident) + if err != nil { + return nil, err + } + all = append(all, infos...) + } + l := len(all) + var res *TableInfo + if l != 0 { + res = all[l-1] + } + return res, nil +} + +func extractTables(parsed ast.TokenList, pos token.Pos, stopOnPos bool) ([]*TableInfo, error) { stmt, err := extractFocusedStatement(parsed, pos) if err != nil { return nil, err @@ -192,7 +234,11 @@ func ExtractTable(parsed ast.TokenList, pos token.Pos) ([]*TableInfo, error) { if encloseIsSubQuery(stmt, pos) { list = extractFocusedSubQuery(stmt, pos) } - tables, err := extractTableIdentifier(list, false) + var stopPos *token.Pos + if stopOnPos { + stopPos = &pos + } + tables, err := extractTableIdentifier(list, false, stopPos) if err != nil { return nil, err } @@ -209,6 +255,29 @@ func ExtractTable(parsed ast.TokenList, pos token.Pos) ([]*TableInfo, error) { return cleanTables, nil } +func isFollowedByOn(parsed ast.TokenList, pos token.Pos) bool { + nw := NewNodeWalker(parsed, pos) + for _, n := range nw.Paths { + if n.PeekNodeIs(true, + astutil.NodeMatcher{ + NodeTypes: []ast.NodeType{ast.TypeAliased}}) { + if !n.NextNode(true) { + continue + } + } + if n.PeekNodeIs(true, genKeywordMatcher([]string{"ON"})) { + if !n.NextNode(true) { + continue + } + if n.PeekNodeIs(true, astutil.NodeMatcher{ + NodeTypes: []ast.NodeType{ast.TypeComparison}}) { + return true + } + } + } + return false +} + var identifierMatcher = astutil.NodeMatcher{ NodeTypes: []ast.NodeType{ ast.TypeIdentifer, @@ -219,7 +288,7 @@ var identifierMatcher = astutil.NodeMatcher{ } func extractSubQueryColumns(selectStmt ast.TokenList) ([]*SubQueryColumn, []*TableInfo, error) { - tables, err := extractTableIdentifier(selectStmt, true) + tables, err := extractAllTableIdentifiers(selectStmt, true) if err != nil { return nil, nil, err } @@ -268,7 +337,7 @@ func extractSubQueryColumns(selectStmt ast.TokenList) ([]*SubQueryColumn, []*Tab return realIdents, tables, nil } -func extractTableIdentifier(list ast.TokenList, isSubQuery bool) ([]*TableInfo, error) { +func extractTableIdentifier(list ast.TokenList, isSubQuery bool, stopPos *token.Pos) ([]*TableInfo, error) { nodes := []ast.Node{} nodes = append(nodes, ExtractTableReferences(list)...) nodes = append(nodes, ExtractTableReference(list)...) @@ -278,6 +347,11 @@ func extractTableIdentifier(list ast.TokenList, isSubQuery bool) ([]*TableInfo, if !isSubQuery && isSubQueryByNode(ident) { continue } + + if stopPos != nil && token.ComparePos(ident.Pos(), *stopPos) > 0 { + continue + } + infos, err := parseTableInfo(ident) if err != nil { return nil, err @@ -287,6 +361,10 @@ func extractTableIdentifier(list ast.TokenList, isSubQuery bool) ([]*TableInfo, return res, nil } +func extractAllTableIdentifiers(list ast.TokenList, isSubQuery bool) ([]*TableInfo, error) { + return extractTableIdentifier(list, isSubQuery, nil) +} + func filterTokenList(reader *astutil.NodeReader, matcher astutil.NodeMatcher) ast.TokenList { var res []ast.Node for reader.NextNode(false) { @@ -366,7 +444,7 @@ func aliasedToTableInfo(aliased *ast.Aliased) (*TableInfo, error) { ti.DatabaseSchema = v.Parent.String() ti.Name = v.GetChild().String() case *ast.Parenthesis: - tables, err := extractTableIdentifier(v.Inner(), true) + tables, err := extractAllTableIdentifiers(v.Inner(), true) if err != nil { panic(err) } diff --git a/parser/parseutil/position.go b/parser/parseutil/position.go index 9e5d51b..14e28cb 100644 --- a/parser/parseutil/position.go +++ b/parser/parseutil/position.go @@ -17,6 +17,8 @@ const ( TableReference SyntaxPosition = "table_reference" InsertColumn SyntaxPosition = "insert_column" InsertValue SyntaxPosition = "insert_value" + JoinClause SyntaxPosition = "join_clause" + JoinOn SyntaxPosition = "join_on" Unknown SyntaxPosition = "unknown" ) @@ -48,8 +50,6 @@ func CheckSyntaxPosition(nw *NodeWalker) SyntaxPosition { // WHERE Clause "WHERE", "HAVING", - // JOIN Clause - "ON", // Operator "AND", "OR", @@ -74,14 +74,7 @@ func CheckSyntaxPosition(nw *NodeWalker) SyntaxPosition { // INSERT Statement "INSERT INTO", // JOIN Clause - "JOIN", - "INNER JOIN", "CROSS JOIN", - "OUTER JOIN", - "LEFT JOIN", - "RIGHT JOIN", - "LEFT OUTER JOIN", - "RIGHT OUTER JOIN", // DESCRIBE Statement "DESCRIBE", "DESC", @@ -89,6 +82,20 @@ func CheckSyntaxPosition(nw *NodeWalker) SyntaxPosition { "TRUNCATE", })): res = TableReference + case nw.PrevNodesIs(true, genKeywordMatcher([]string{ + "ON", + })): + res = getJoinOnCondition(nw) + case nw.PrevNodesIs(true, genKeywordMatcher([]string{ + "JOIN", + "INNER JOIN", + "OUTER JOIN", + "LEFT JOIN", + "RIGHT JOIN", + "LEFT OUTER JOIN", + "RIGHT OUTER JOIN", + })): + res = getJoinCondition(nw) case isInsertColumns(nw): if isInsertValues(nw) { res = InsertValue @@ -101,6 +108,27 @@ func CheckSyntaxPosition(nw *NodeWalker) SyntaxPosition { return res } +func getJoinCondition(nw *NodeWalker) SyntaxPosition { + for _, n := range nw.Paths { + if n.PeekNodeIs(true, genKeywordMatcher([]string{"ON"})) { + return TableReference + } + } + return JoinClause +} +func getJoinOnCondition(nw *NodeWalker) SyntaxPosition { + switch { + case nw.CurNodeIs(genTokenMatcher([]token.Kind{token.Period})): + return ColName + case nw.CurNodeIs(genTokenMatcher([]token.Kind{token.Whitespace})): + if !nw.PrevNodesIs(true, astutil.NodeMatcher{ + ExpectTokens: []token.Kind{token.Eq}}) { + return JoinOn + } + } + return WhereCondition +} + func genKeywordMatcher(keywords []string) astutil.NodeMatcher { return astutil.NodeMatcher{ ExpectKeyword: keywords, diff --git a/parser/parseutil/position_test.go b/parser/parseutil/position_test.go index 022ab7e..b288b1b 100644 --- a/parser/parseutil/position_test.go +++ b/parser/parseutil/position_test.go @@ -77,6 +77,114 @@ func TestCheckSyntaxPosition(t *testing.T) { }, want: InsertValue, }, + { + name: "join tables", + text: "select CountryCode from city join ", + pos: token.Pos{ + Line: 0, + Col: 34, + }, + want: JoinClause, + }, + { + name: "join filterd tables", + text: "select CountryCode from city join co", + pos: token.Pos{ + Line: 0, + Col: 36, + }, + want: JoinClause, + }, + { + name: "left join tables", + text: "select CountryCode from city left join ", + pos: token.Pos{ + Line: 0, + Col: 39, + }, + want: JoinClause, + }, + { + name: "left outer join tables", + text: "select CountryCode from city left outer join ", + pos: token.Pos{ + Line: 0, + Col: 45, + }, + want: JoinClause, + }, + { + name: "join on columns", + text: "select * from city left join country on ", + pos: token.Pos{ + Line: 0, + Col: 40, + }, + want: JoinOn, + }, + { + name: "join on filterd columns", + text: "select * from city left join country on co", + pos: token.Pos{ + Line: 0, + Col: 42, + }, + want: WhereCondition, + }, + { + name: "join on table", + text: "select * from city left join country on country.", + pos: token.Pos{ + Line: 0, + Col: 48, + }, + want: ColName, + }, + { + name: "join on ", + text: "select * from city left join country on country.Code =", + pos: token.Pos{ + Line: 0, + Col: 54, + }, + want: WhereCondition, + }, + { + name: "join on ", + text: "select * from city left join country on country.Code = ", + pos: token.Pos{ + Line: 0, + Col: 55, + }, + want: WhereCondition, + }, + { + name: "join on ref tables filtered", + text: "select * from city left join country on country.Code = ci", + pos: token.Pos{ + Line: 0, + Col: 57, + }, + want: WhereCondition, + }, + { + name: "join on ref table", + text: "select * from city left join country on country.Code = city.", + pos: token.Pos{ + Line: 0, + Col: 60, + }, + want: ColName, + }, + { + name: "join alias snippet", + text: "select * from city c left join country c1 on c1.Code", + pos: token.Pos{ + Line: 0, + Col: 39, + }, + want: TableReference, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {