Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement insert from select query with constant #136

Merged
merged 1 commit into from
Feb 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 39 additions & 11 deletions router/pkg/qrouter/proxy_routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ type RoutingMetadataContext struct {
InsertStmtCols []string
InsertStmtRel string

// For
// INSERT INTO x (...) SELECT ...
TargetList []*pgquery.Node

// TODO: include client ops and metadata here
}

Expand Down Expand Up @@ -252,6 +256,7 @@ func (qr *ProxyQrouter) DeparseSelectStmt(ctx context.Context, selectStmt *pgque

switch q := selectStmt.Node.(type) {
case *pgquery.Node_SelectStmt:
meta.TargetList = q.SelectStmt.TargetList
if clause := q.SelectStmt.FromClause; clause != nil {
// route `insert into rel select from` stmt
spqrlog.Logger.Printf(spqrlog.DEBUG5, "deparsing select from clause, %+v", clause)
Expand Down Expand Up @@ -414,6 +419,7 @@ func (qr *ProxyQrouter) CheckTableIsRoutable(ctx context.Context, node *pgquery.
}

func (qr *ProxyQrouter) Route(ctx context.Context, parsedStmt *pgquery.ParseResult) (RoutingState, error) {
var insert_err error
if parsedStmt == nil {
return nil, ComplexQuery
}
Expand Down Expand Up @@ -475,7 +481,7 @@ func (qr *ProxyQrouter) Route(ctx context.Context, parsedStmt *pgquery.ParseResu
return MultiMatchState{}, nil
}
}
return nil, err
insert_err = err
}
default:
// SELECT, UPDATE and/or DELETE stmts, which
Expand Down Expand Up @@ -513,11 +519,12 @@ func (qr *ProxyQrouter) Route(ctx context.Context, parsedStmt *pgquery.ParseResu
}

spqrlog.Logger.Printf(spqrlog.DEBUG4, "deparsed values list %+v, insertStmtCols %+v", meta.ValuesLists, meta.InsertStmtCols)
if meta.ValuesLists != nil && len(meta.InsertStmtCols) != 0 {
if len(meta.InsertStmtCols) != 0 {
if rule, err := ops.MatchShardingRule(ctx, qr.qdb, meta.InsertStmtRel, meta.InsertStmtCols); err != nil {
// compute matched sharding rule offsets
offsets := make([]int, 0)
j := 0
// TODO: check mapping by rules with multiple columns
for i, s := range meta.InsertStmtCols {
if j == len(rule.Entries) {
break
Expand All @@ -528,17 +535,38 @@ func (qr *ProxyQrouter) Route(ctx context.Context, parsedStmt *pgquery.ParseResu
}

meta.offsets = offsets
routed := false
if insert_err != nil {
if len(meta.offsets) != 0 && len(meta.TargetList) > meta.offsets[0] {
currroute, err := qr.RouteKeyWithRanges(ctx, meta.TargetList[meta.offsets[0]].GetResTarget().Val, meta)
if err != nil {
return nil, err
}

// only firt value from value list
currroute, err := qr.RouteKeyWithRanges(ctx, meta.ValuesLists[0], meta)
if err != nil {
return nil, err
spqrlog.Logger.Printf(spqrlog.DEBUG4, "deparsed route from %+v", currroute)
routed = true
if route == nil {
route = currroute
} else {
route = combine(route, currroute)
}
} else {
return nil, insert_err
}
}
spqrlog.Logger.Printf(spqrlog.DEBUG4, "deparsed route from %+v", currroute)
if route == nil {
route = currroute
} else {
route = combine(route, currroute)

if !routed && meta.ValuesLists != nil {
// only first value from value list
currroute, err := qr.RouteKeyWithRanges(ctx, meta.ValuesLists[0], meta)
if err != nil {
return nil, err
}
spqrlog.Logger.Printf(spqrlog.DEBUG4, "deparsed route from %+v", currroute)
if route == nil {
route = currroute
} else {
route = combine(route, currroute)
}
}
}
}
Expand Down
21 changes: 21 additions & 0 deletions test/regress/tests/router/expected/shard_routing.out
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,27 @@ NOTICE: send query to shard(s) : sh1
20
(2 rows)

-- check that `INSERT FROM SELECT` with constant works
INSERT INTO xx (w_id) SELECT 20;
NOTICE: send query to shard(s) : sh1
SELECT * FROM xx WHERE w_id >= 20;
NOTICE: send query to shard(s) : sh1
w_id
------
20
20
20
(3 rows)

INSERT INTO xxtt1 (j, w_id) SELECT a, 20 from unnest(ARRAY[10]) a;
NOTICE: send query to shard(s) : sh1
SELECT * FROM xxtt1 WHERE w_id = 20;
NOTICE: send query to shard(s) : sh1
i | j | w_id
---+----+------
| 10 | 20
(1 row)

DROP TABLE xx;
NOTICE: send query to shard(s) : sh1,sh2
DROP TABLE xxtt1;
Expand Down
6 changes: 6 additions & 0 deletions test/regress/tests/router/sql/shard_routing.sql
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ SELECT * FROM xxtt1 a WHERE a.w_id >= 21;
INSERT INTO xx SELECT * FROM xx a WHERE a.w_id = 20;
SELECT * FROM xx WHERE w_id >= 20;

-- check that `INSERT FROM SELECT` with constant works
INSERT INTO xx (w_id) SELECT 20;
SELECT * FROM xx WHERE w_id >= 20;
INSERT INTO xxtt1 (j, w_id) SELECT a, 20 from unnest(ARRAY[10]) a;
SELECT * FROM xxtt1 WHERE w_id = 20;

DROP TABLE xx;
DROP TABLE xxtt1;