Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add sqlc.embed to allow model re-use #1615

Merged
merged 5 commits into from Apr 7, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions internal/cmd/shim.go
Expand Up @@ -264,6 +264,14 @@ func pluginQueryColumn(c *compiler.Column) *plugin.Column {
}
}

if c.EmbedTable != nil {
out.EmbedTable = &plugin.Identifier{
Catalog: c.EmbedTable.Catalog,
Schema: c.EmbedTable.Schema,
Name: c.EmbedTable.Name,
}
}

return out
}

Expand Down
2 changes: 2 additions & 0 deletions internal/codegen/golang/field.go
Expand Up @@ -15,6 +15,8 @@ type Field struct {
Tags map[string]string
Comment string
Column *plugin.Column
// EmbedFields contains the embedded fields that reuqire scanning.
EmbedFields []string
}

func (gf Field) Tag() string {
Expand Down
9 changes: 9 additions & 0 deletions internal/codegen/golang/query.go
Expand Up @@ -154,6 +154,15 @@ func (v QueryValue) Scan() string {
}
} else {
for _, f := range v.Struct.Fields {

// append any embedded fields
if len(f.EmbedFields) > 0 {
for _, embed := range f.EmbedFields {
out = append(out, "&"+v.Name+"."+f.Name+"."+embed)
}
continue
}

if strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() {
out = append(out, "pq.Array(&"+v.Name+"."+f.Name+")")
} else {
Expand Down
63 changes: 59 additions & 4 deletions internal/codegen/golang/result.go
Expand Up @@ -103,6 +103,46 @@ func buildStructs(req *plugin.CodeGenRequest) []Struct {
type goColumn struct {
id int
*plugin.Column
embed *goEmbed
}

type goEmbed struct {
modelType string
modelName string
fields []string
}

// look through all the structs and attempt to find a matching one to embed
// We need the name of the struct and its field names.
func newGoEmbed(embed *plugin.Identifier, structs []Struct) *goEmbed {
if embed == nil {
return nil
}

for _, s := range structs {
embedSchema := "public"
if embed.Schema != "" {
embedSchema = embed.Schema
}

// compare the other attributes
if embed.Catalog != s.Table.Catalog || embed.Name != s.Table.Name || embedSchema != s.Table.Schema {
continue
}

fields := make([]string, len(s.Fields))
for i, f := range s.Fields {
fields[i] = f.Name
}

return &goEmbed{
modelType: s.Name,
modelName: s.Name,
fields: fields,
}
}

return nil
}

func columnName(c *plugin.Column, pos int) string {
Expand Down Expand Up @@ -192,7 +232,7 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
}
}

if len(query.Columns) == 1 {
if len(query.Columns) == 1 && query.Columns[0].EmbedTable == nil {
c := query.Columns[0]
name := columnName(c, 0)
if c.IsFuncCall {
Expand Down Expand Up @@ -234,6 +274,7 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
columns = append(columns, goColumn{
id: i,
Column: c,
embed: newGoEmbed(c.EmbedTable, structs),
})
}
var err error
Expand Down Expand Up @@ -287,6 +328,13 @@ func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []goColumn
for i, c := range columns {
colName := columnName(c.Column, i)
tagName := colName

// overide col/tag with expected model name
if c.embed != nil {
colName = c.embed.modelName
tagName = SetCaseStyle(colName, "snake")
}

fieldName := StructName(colName, req.Settings)
baseFieldName := fieldName
// Track suffixes by the ID of the column, so that columns referring to the same numbered parameter can be
Expand All @@ -309,13 +357,20 @@ func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []goColumn
if req.Settings.Go.EmitJsonTags {
tags["json"] = JSONTagName(tagName, req.Settings)
}
gs.Fields = append(gs.Fields, Field{
f := Field{
Name: fieldName,
DBName: colName,
Type: goType(req, c.Column),
Tags: tags,
Column: c.Column,
})
}
if c.embed == nil {
f.Type = goType(req, c.Column)
} else {
f.Type = c.embed.modelType
f.EmbedFields = c.embed.fields
}

gs.Fields = append(gs.Fields, f)
if _, found := seen[baseFieldName]; !found {
seen[baseFieldName] = []int{i}
} else {
Expand Down
9 changes: 8 additions & 1 deletion internal/compiler/expand.go
Expand Up @@ -132,9 +132,16 @@ func (c *Compiler) expandStmt(qc *QueryCatalog, raw *ast.RawStmt, node ast.Node)
for _, p := range parts {
old = append(old, c.quoteIdent(p))
}
oldString := strings.Join(old, ".")

// use the sqlc.embed string instead
if embed, ok := qc.embeds.Find(ref); ok {
oldString = embed.Orig()
}

edits = append(edits, source.Edit{
Location: res.Location - raw.StmtLocation,
Old: strings.Join(old, "."),
Old: oldString,
New: strings.Join(cols, ", "),
})
}
Expand Down
13 changes: 12 additions & 1 deletion internal/compiler/output_columns.go
Expand Up @@ -14,7 +14,7 @@ import (

// OutputColumns determines which columns a statement will output
func (c *Compiler) OutputColumns(stmt ast.Node) ([]*catalog.Column, error) {
qc, err := buildQueryCatalog(c.catalog, stmt)
qc, err := buildQueryCatalog(c.catalog, stmt, nil)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -201,6 +201,16 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {

case *ast.ColumnRef:
if hasStarRef(n) {

// add a column with a reference to an embedded table
if embed, ok := qc.embeds.Find(n); ok {
cols = append(cols, &Column{
Name: embed.Table.Name,
EmbedTable: embed.Table,
})
continue
}

// TODO: This code is copied in func expand()
for _, t := range tables {
scope := astutils.Join(n.Fields, ".")
Expand Down Expand Up @@ -520,6 +530,7 @@ func outputColumnRefs(res *ast.ResTarget, tables []*Table, node *ast.ColumnRef)
NotNull: c.NotNull,
IsArray: c.IsArray,
Length: c.Length,
EmbedTable: c.EmbedTable,
})
}
}
Expand Down
6 changes: 4 additions & 2 deletions internal/compiler/parse.go
Expand Up @@ -86,12 +86,14 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
} else {
sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number })
}
qc, err := buildQueryCatalog(c.catalog, raw.Stmt)

raw, embeds := rewrite.Embeds(raw)
qc, err := buildQueryCatalog(c.catalog, raw.Stmt, embeds)
if err != nil {
return nil, err
}

params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams)
params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams, embeds)
if err != nil {
return nil, err
}
Expand Down
1 change: 1 addition & 0 deletions internal/compiler/query.go
Expand Up @@ -29,6 +29,7 @@ type Column struct {
Table *ast.TableName
TableAlias string
Type *ast.TypeName
EmbedTable *ast.TableName

IsSqlcSlice bool // is this sqlc.slice()

Expand Down
6 changes: 4 additions & 2 deletions internal/compiler/query_catalog.go
Expand Up @@ -5,14 +5,16 @@ import (

"github.com/kyleconroy/sqlc/internal/sql/ast"
"github.com/kyleconroy/sqlc/internal/sql/catalog"
"github.com/kyleconroy/sqlc/internal/sql/rewrite"
)

type QueryCatalog struct {
catalog *catalog.Catalog
ctes map[string]*Table
embeds rewrite.EmbedSet
}

func buildQueryCatalog(c *catalog.Catalog, node ast.Node) (*QueryCatalog, error) {
func buildQueryCatalog(c *catalog.Catalog, node ast.Node, embeds rewrite.EmbedSet) (*QueryCatalog, error) {
var with *ast.WithClause
switch n := node.(type) {
case *ast.DeleteStmt:
Expand All @@ -26,7 +28,7 @@ func buildQueryCatalog(c *catalog.Catalog, node ast.Node) (*QueryCatalog, error)
default:
with = nil
}
qc := &QueryCatalog{catalog: c, ctes: map[string]*Table{}}
qc := &QueryCatalog{catalog: c, ctes: map[string]*Table{}, embeds: embeds}
if with != nil {
for _, item := range with.Ctes.Items {
if cte, ok := item.(*ast.CommonTableExpr); ok {
Expand Down
19 changes: 18 additions & 1 deletion internal/compiler/resolve.go
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/kyleconroy/sqlc/internal/sql/astutils"
"github.com/kyleconroy/sqlc/internal/sql/catalog"
"github.com/kyleconroy/sqlc/internal/sql/named"
"github.com/kyleconroy/sqlc/internal/sql/rewrite"
"github.com/kyleconroy/sqlc/internal/sql/sqlerr"
)

Expand All @@ -19,7 +20,7 @@ func dataType(n *ast.TypeName) string {
}
}

func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet) ([]Parameter, error) {
func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet) ([]Parameter, error) {
c := comp.catalog

aliasMap := map[string]*ast.TableName{}
Expand Down Expand Up @@ -76,6 +77,22 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
}
}

// resolve a table for an embed
for _, embed := range embeds {
table, err := c.GetTable(embed.Table)
if err == nil {
embed.Table = table.Rel
continue
}

if alias, ok := aliasMap[embed.Table.Name]; ok {
embed.Table = alias
continue
}

return nil, fmt.Errorf("unable to resolve table with %q: %w", embed.Orig(), err)
}

var a []Parameter
for _, ref := range args {
switch n := ref.parent.(type) {
Expand Down