Skip to content

Commit

Permalink
Merge pull request #7287 from planetscale/call-proc
Browse files Browse the repository at this point in the history
Support for CALL procedures
  • Loading branch information
harshit-gangal committed Jan 27, 2021
2 parents f8c0583 + 9438171 commit 68286d9
Show file tree
Hide file tree
Showing 40 changed files with 6,606 additions and 5,911 deletions.
6 changes: 3 additions & 3 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1331,16 +1331,16 @@ func isEOFPacket(data []byte) bool {
//
// Note: This is only valid on actual EOF packets and not on OK packets with the EOF
// type code set, i.e. should not be used if ClientDeprecateEOF is set.
func parseEOFPacket(data []byte) (warnings uint16, more bool, err error) {
func parseEOFPacket(data []byte) (warnings uint16, statusFlags uint16, err error) {
// The warning count is in position 2 & 3
warnings, _, _ = readUint16(data, 1)

// The status flag is in position 4 & 5
statusFlags, _, ok := readUint16(data, 3)
if !ok {
return 0, false, vterrors.Errorf(vtrpc.Code_INTERNAL, "invalid EOF packet statusFlags: %v", data)
return 0, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "invalid EOF packet statusFlags: %v", data)
}
return warnings, (statusFlags & ServerMoreResultsExists) != 0, nil
return warnings, statusFlags, nil
}

// PacketOK contains the ok packet details
Expand Down
3 changes: 3 additions & 0 deletions go/mysql/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,9 @@ func (t testRun) ComQuery(c *Conn, query string, callback func(*sqltypes.Result)
if strings.Contains(query, "panic") {
panic("test panic attack!")
}
if strings.Contains(query, "twice") {
callback(selectRowsResult)
}
callback(selectRowsResult)
return nil
}
Expand Down
1 change: 1 addition & 0 deletions go/mysql/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ const (
ERDuplicatedValueInType = 1291
ERRowIsReferenced2 = 1451
ErNoReferencedRow2 = 1452
ErSPNotVarArg = 1414

// already exists
ERTableExists = 1050
Expand Down
10 changes: 9 additions & 1 deletion go/mysql/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,9 @@ func (c *Conn) ExecuteFetchMulti(query string, maxrows int, wantfields bool) (re
}

res, more, _, err := c.ReadQueryResult(maxrows, wantfields)
if err != nil {
return nil, false, err
}
return res, more, err
}

Expand Down Expand Up @@ -358,6 +361,7 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result,
RowsAffected: packetOk.affectedRows,
InsertID: packetOk.lastInsertID,
SessionStateChanges: packetOk.sessionStateData,
StatusFlags: packetOk.statusFlags,
}, more, warnings, nil
}

Expand Down Expand Up @@ -426,10 +430,13 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result,
// The deprecated EOF packets change means that this is either an
// EOF packet or an OK packet with the EOF type code.
if c.Capabilities&CapabilityClientDeprecateEOF == 0 {
warnings, more, err = parseEOFPacket(data)
var statusFlags uint16
warnings, statusFlags, err = parseEOFPacket(data)
if err != nil {
return nil, false, 0, err
}
more = (statusFlags & ServerMoreResultsExists) != 0
result.StatusFlags = statusFlags
} else {
packetOk, err := c.parseOKPacket(data)
if err != nil {
Expand All @@ -438,6 +445,7 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result,
warnings = packetOk.warnings
more = (packetOk.statusFlags & ServerMoreResultsExists) != 0
result.SessionStateChanges = packetOk.sessionStateData
result.StatusFlags = packetOk.statusFlags
}
return result, more, warnings, nil

Expand Down
29 changes: 29 additions & 0 deletions go/sqltypes/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,27 @@ type Result struct {
InsertID uint64 `json:"insert_id"`
Rows [][]Value `json:"rows"`
SessionStateChanges string `json:"session_state_changes"`
StatusFlags uint16 `json:"status_flags"`
}

//goland:noinspection GoUnusedConst
const (
ServerStatusInTrans = 0x0001
ServerStatusAutocommit = 0x0002
ServerMoreResultsExists = 0x0008
ServerStatusNoGoodIndexUsed = 0x0010
ServerStatusNoIndexUsed = 0x0020
ServerStatusCursorExists = 0x0040
ServerStatusLastRowSent = 0x0080
ServerStatusDbDropped = 0x0100
ServerStatusNoBackslashEscapes = 0x0200
ServerStatusMetadataChanged = 0x0400
ServerQueryWasSlow = 0x0800
ServerPsOutParams = 0x1000
ServerStatusInTransReadonly = 0x2000
ServerSessionStateChanged = 0x4000
)

// ResultStream is an interface for receiving Result. It is used for
// RPC interfaces.
type ResultStream interface {
Expand Down Expand Up @@ -225,3 +244,13 @@ func (result *Result) AppendResult(src *Result) {
func (result *Result) Named() *NamedResult {
return ToNamedResult(result)
}

// IsMoreResultsExists returns true if the status flag has SERVER_MORE_RESULTS_EXISTS set
func (result *Result) IsMoreResultsExists() bool {
return result.StatusFlags&ServerMoreResultsExists == ServerMoreResultsExists
}

// IsInTransaction returns true if the status flag has SERVER_STATUS_IN_TRANS set
func (result *Result) IsInTransaction() bool {
return result.StatusFlags&ServerStatusInTrans == ServerStatusInTrans
}
9 changes: 6 additions & 3 deletions go/test/endtoend/cluster/vtctlclient_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"os/exec"
"strings"

"vitess.io/vitess/go/vt/vterrors"

"vitess.io/vitess/go/vt/log"
)

Expand Down Expand Up @@ -61,9 +63,10 @@ func (vtctlclient *VtctlClientProcess) ApplySchemaWithOutput(Keyspace string, SQ
}

// ApplySchema applies SQL schema to the keyspace
func (vtctlclient *VtctlClientProcess) ApplySchema(Keyspace string, SQL string) (err error) {
_, err = vtctlclient.ApplySchemaWithOutput(Keyspace, SQL, "direct")
return err
func (vtctlclient *VtctlClientProcess) ApplySchema(Keyspace string, SQL string) error {
message, err := vtctlclient.ApplySchemaWithOutput(Keyspace, SQL, "direct")

return vterrors.Wrap(err, message)
}

// ApplyVSchema applies vitess schema (JSON format) to the keyspace
Expand Down
107 changes: 107 additions & 0 deletions go/test/endtoend/vtgate/unsharded/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ import (
"fmt"
"os"
"testing"
"time"

"vitess.io/vitess/go/vt/log"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -91,6 +94,55 @@ CREATE TABLE allDefaults (
}
}
}
`

createProcSQL = `use vt_customer;
CREATE PROCEDURE sp_insert()
BEGIN
insert into allDefaults () values ();
END;
CREATE PROCEDURE sp_delete()
BEGIN
delete from allDefaults;
END;
CREATE PROCEDURE sp_multi_dml()
BEGIN
insert into allDefaults () values ();
delete from allDefaults;
END;
CREATE PROCEDURE sp_variable()
BEGIN
insert into allDefaults () values ();
SELECT min(id) INTO @myvar FROM allDefaults;
DELETE FROM allDefaults WHERE id = @myvar;
END;
CREATE PROCEDURE sp_select()
BEGIN
SELECT * FROM allDefaults;
END;
CREATE PROCEDURE sp_all()
BEGIN
insert into allDefaults () values ();
select * from allDefaults;
delete from allDefaults;
set autocommit = 0;
END;
CREATE PROCEDURE in_parameter(IN val int)
BEGIN
insert into allDefaults(id) values(val);
END;
CREATE PROCEDURE out_parameter(OUT val int)
BEGIN
insert into allDefaults(id) values (128);
select 128 into val from dual;
END;
`
)

Expand All @@ -114,11 +166,19 @@ func TestMain(m *testing.M) {
VSchema: VSchema,
}
if err := clusterInstance.StartUnshardedKeyspace(*Keyspace, 0, false); err != nil {
log.Fatal(err.Error())
return 1
}

// Start vtgate
if err := clusterInstance.StartVtgate(); err != nil {
log.Fatal(err.Error())
return 1
}

masterProcess := clusterInstance.Keyspaces[0].Shards[0].MasterTablet().VttabletProcess
if _, err := masterProcess.QueryTablet(createProcSQL, KeyspaceName, false); err != nil {
log.Fatal(err.Error())
return 1
}

Expand Down Expand Up @@ -215,6 +275,53 @@ func TestDDLUnsharded(t *testing.T) {
assertMatches(t, conn, "show tables", `[[VARCHAR("allDefaults")] [VARCHAR("t1")]]`)
}

func TestCallProcedure(t *testing.T) {
defer cluster.PanicHandler(t)
ctx := context.Background()
vtParams := mysql.ConnParams{
Host: "localhost",
Port: clusterInstance.VtgateMySQLPort,
Flags: mysql.CapabilityClientMultiResults,
DbName: "@master",
}
time.Sleep(5 * time.Second)
conn, err := mysql.Connect(ctx, &vtParams)
require.NoError(t, err)
defer conn.Close()
qr := exec(t, conn, `CALL sp_insert()`)
require.EqualValues(t, 1, qr.RowsAffected)

_, err = conn.ExecuteFetch(`CALL sp_select()`, 1000, true)
require.Error(t, err)
require.Contains(t, err.Error(), "Multi-Resultset not supported in stored procedure")

_, err = conn.ExecuteFetch(`CALL sp_all()`, 1000, true)
require.Error(t, err)
require.Contains(t, err.Error(), "Multi-Resultset not supported in stored procedure")

qr = exec(t, conn, `CALL sp_delete()`)
require.GreaterOrEqual(t, 1, int(qr.RowsAffected))

qr = exec(t, conn, `CALL sp_multi_dml()`)
require.EqualValues(t, 1, qr.RowsAffected)

qr = exec(t, conn, `CALL sp_variable()`)
require.EqualValues(t, 1, qr.RowsAffected)

qr = exec(t, conn, `CALL in_parameter(42)`)
require.EqualValues(t, 1, qr.RowsAffected)

_ = exec(t, conn, `SET @foo = 123`)
qr = exec(t, conn, `CALL in_parameter(@foo)`)
require.EqualValues(t, 1, qr.RowsAffected)
qr = exec(t, conn, "select * from allDefaults where id = 123")
assert.NotEmpty(t, qr.Rows)

_, err = conn.ExecuteFetch(`CALL out_parameter(@foo)`, 100, true)
require.Error(t, err)
require.Contains(t, err.Error(), "OUT and INOUT parameters are not supported")
}

func exec(t *testing.T, conn *mysql.Conn, query string) *sqltypes.Result {
t.Helper()
qr, err := conn.ExecuteFetch(query, 1000, true)
Expand Down
7 changes: 6 additions & 1 deletion go/vt/sqlparser/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ const (
StmtLockTables
StmtUnlockTables
StmtFlush
StmtCallProc
)

//ASTToStatementType returns a StatementType from an AST stmt
Expand Down Expand Up @@ -103,6 +104,8 @@ func ASTToStatementType(stmt Statement) StatementType {
return StmtUnlockTables
case *Flush:
return StmtFlush
case *CallProc:
return StmtCallProc
default:
return StmtUnknown
}
Expand All @@ -111,7 +114,7 @@ func ASTToStatementType(stmt Statement) StatementType {
//CanNormalize takes Statement and returns if the statement can be normalized.
func CanNormalize(stmt Statement) bool {
switch stmt.(type) {
case *Select, *Union, *Insert, *Update, *Delete, *Set:
case *Select, *Union, *Insert, *Update, *Delete, *Set, *CallProc: // TODO: we could merge this logic into ASTrewriter
return true
}
return false
Expand Down Expand Up @@ -262,6 +265,8 @@ func (s StatementType) String() string {
return "UNLOCK_TABLES"
case StmtFlush:
return "FLUSH"
case StmtCallProc:
return "CALL_PROC"
default:
return "UNKNOWN"
}
Expand Down
12 changes: 12 additions & 0 deletions go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,12 @@ type (
// ExplainType is an enum for Explain.Type
ExplainType int8

// CallProc represents a CALL statement
CallProc struct {
Name TableName
Params Exprs
}

// OtherRead represents a DESCRIBE, or EXPLAIN statement.
// It should be used only as an indicator. It does not contain
// the full AST for the statement.
Expand Down Expand Up @@ -585,6 +591,7 @@ func (*DropTable) iStatement() {}
func (*DropView) iStatement() {}
func (*TruncateTable) iStatement() {}
func (*RenameTable) iStatement() {}
func (*CallProc) iStatement() {}

func (*CreateView) iDDLStatement() {}
func (*AlterView) iDDLStatement() {}
Expand Down Expand Up @@ -2490,6 +2497,11 @@ func (node *Explain) Format(buf *TrackedBuffer) {
buf.astPrintf(node, "explain %s%v", format, node.Statement)
}

// Format formats the node.
func (node *CallProc) Format(buf *TrackedBuffer) {
buf.astPrintf(node, "call %v(%v)", node.Name, node.Params)
}

// Format formats the node.
func (node *OtherRead) Format(buf *TrackedBuffer) {
buf.WriteString("otherread")
Expand Down
4 changes: 4 additions & 0 deletions go/vt/sqlparser/ast_rewriting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ func TestRewrites(in *testing.T) {
// SELECT * behaves different depending the join type used, so if that has been used, we won't rewrite
in: "SELECT * FROM A JOIN B USING (id1,id2,id3)",
expected: "SELECT * FROM A JOIN B USING (id1,id2,id3)",
}, {
in: "CALL proc(@foo)",
expected: "CALL proc(:__vtudvfoo)",
udv: 1,
}}

for _, tc := range tests {
Expand Down
8 changes: 8 additions & 0 deletions go/vt/sqlparser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1843,6 +1843,14 @@ var (
input: "release savepoint a",
}, {
input: "release savepoint `@@@;a`",
}, {
input: "call proc()",
}, {
input: "call qualified.proc()",
}, {
input: "call proc(1, 'foo')",
}, {
input: "call proc(@param)",
}}
)

Expand Down

0 comments on commit 68286d9

Please sign in to comment.