Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support cancel proto msgs #150

Merged
merged 4 commits into from
Jun 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions pkg/conn/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ type DBInstance interface {
Close() error
Status() InstanceStatus
SetStatus(status InstanceStatus)

Cancel(csm *pgproto3.CancelRequest) error

Tls() *tls.Config
}

type PostgreSQLInstance struct {
Expand All @@ -47,6 +51,8 @@ type PostgreSQLInstance struct {

hostname string
status InstanceStatus

tlsconfig *tls.Config
}

func (pgi *PostgreSQLInstance) SetStatus(status InstanceStatus) {
Expand Down Expand Up @@ -80,9 +86,10 @@ func NewInstanceConn(host string, tlsconfig *tls.Config) (DBInstance, error) {
}

instance := &PostgreSQLInstance{
hostname: host,
conn: netconn,
status: NotInitialized,
hostname: host,
conn: netconn,
status: NotInitialized,
tlsconfig: tlsconfig,
}

if tlsconfig != nil {
Expand All @@ -98,6 +105,10 @@ func NewInstanceConn(host string, tlsconfig *tls.Config) (DBInstance, error) {
return instance, nil
}

// Cancel hands the caller-supplied cancel request to pgi.frontend.Send and
// returns the error reported by that call unchanged.
func (pgi *PostgreSQLInstance) Cancel(csm *pgproto3.CancelRequest) error {
	return pgi.frontend.Send(csm)
}

func (pgi *PostgreSQLInstance) CheckRW() (bool, error) {
msg := &pgproto3.Query{
String: "SELECT pg_is_in_recovery()",
Expand Down Expand Up @@ -129,6 +140,10 @@ func (pgi *PostgreSQLInstance) CheckRW() (bool, error) {
}
}

// Tls returns the TLS configuration stored in the instance's tlsconfig field.
// The result is nil when the instance was created without a TLS config.
func (pgi *PostgreSQLInstance) Tls() *tls.Config {
	return pgi.tlsconfig
}

var _ DBInstance = &PostgreSQLInstance{}

func (pgi *PostgreSQLInstance) ReqBackendSsl(tlsconfig *tls.Config) error {
Expand Down Expand Up @@ -160,9 +175,3 @@ func (pgi *PostgreSQLInstance) ReqBackendSsl(tlsconfig *tls.Config) error {
spqrlog.Logger.Printf(spqrlog.DEBUG5, "initaited backend connection with TLS (%p)", pgi)
return nil
}

// Cancel sends a CancelRequest with no fields set (zero ProcessID and
// SecretKey) through the instance's frontend and returns the send error.
func (pgi *PostgreSQLInstance) Cancel() error {
	return pgi.frontend.Send(&pgproto3.CancelRequest{})
}
51 changes: 48 additions & 3 deletions router/pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/tls"
"encoding/binary"
"fmt"
"math/rand"
"net"

"github.com/jackc/pgproto3/v2"
Expand Down Expand Up @@ -57,16 +58,29 @@ type RouterClient interface {
ProcCopy(query pgproto3.FrontendMessage) error
ProcCopyComplete(query *pgproto3.FrontendMessage) error
ReplyParseComplete() error

CancelMsg() *pgproto3.CancelRequest

GetCancelPid() uint32
GetCancelKey() uint32

Close() error
}

type PsqlClient struct {
client.Client

activeParamSet map[string]string
savepointParamSet map[string]map[string]string
savepointParamTxCnt map[string]int
beginTxParamSet map[string]string

/* cancel */
csm *pgproto3.CancelRequest

cancel_pid uint32
cancel_key uint32

txCnt int

rule *config.FrontendRule
Expand All @@ -86,6 +100,14 @@ type PsqlClient struct {
server server.Server
}

// GetCancelPid returns the client's cancel_pid: the process id that is
// advertised to the client in BackendKeyData and used as the lookup key
// when a cancel request arrives.
func (cl *PsqlClient) GetCancelPid() uint32 {
	return cl.cancel_pid
}

// GetCancelKey returns the client's cancel_key: the secret advertised to the
// client in BackendKeyData, which an incoming cancel request must match.
func (cl *PsqlClient) GetCancelKey() uint32 {
	return cl.cancel_key
}

func copymap(params map[string]string) map[string]string {
ret := make(map[string]string)

Expand Down Expand Up @@ -382,7 +404,7 @@ func (cl *PsqlClient) AssignRule(rule *config.FrontendRule) error {
return nil
}

// startup + ssl
// startup + ssl/cancel
func (cl *PsqlClient) Init(tlsconfig *tls.Config) error {
spqrlog.Logger.Printf(spqrlog.LOG, "init client connection with ssl: %t", tlsconfig != nil)

Expand Down Expand Up @@ -455,18 +477,30 @@ func (cl *PsqlClient) Init(tlsconfig *tls.Config) error {
return err
}
case conn.CANCELREQ:
return fmt.Errorf("cancel is not supported")
cl.csm = &pgproto3.CancelRequest{}
if err = cl.csm.Decode(msg); err != nil {
return err
}

return nil
default:
return fmt.Errorf("protocol number %d not supported", protoVer)
}

/* setup client params */

for k, v := range sm.Parameters {
cl.SetParam(k, v)
}

cl.startupMsg = sm
cl.be = backend

cl.cancel_key = rand.Uint32()
cl.cancel_pid = rand.Uint32()

spqrlog.Logger.Printf(spqrlog.DEBUG2, "client %p cancel key/pid: %d %d", cl, cl.cancel_key, cl.cancel_pid)

if tlsconfig != nil && protoVer != conn.SSLREQ {
if err := cl.Send(
&pgproto3.ErrorResponse{
Expand Down Expand Up @@ -522,6 +556,13 @@ func (cl *PsqlClient) Auth(rt *route.Route) error {
}
}

if err := cl.Send(&pgproto3.BackendKeyData{
ProcessID: cl.cancel_pid,
SecretKey: cl.cancel_key,
}); err != nil {
return err
}

for _, msg := range []pgproto3.BackendMessage{
&pgproto3.ReadyForQuery{
TxStatus: byte(conn.TXIDLE),
Expand Down Expand Up @@ -751,7 +792,7 @@ func (cl *PsqlClient) ProcCommand(query pgproto3.FrontendMessage, waitForResp bo
case *pgproto3.ErrorResponse:
return fmt.Errorf(v.Message)
default:
spqrlog.Logger.Printf(spqrlog.DEBUG2, "client %p msg type from server: %T", v)
spqrlog.Logger.Printf(spqrlog.DEBUG2, "client %p msg type from server: %T", cl, v)
if replyCl {
err = cl.Send(msg)
if err != nil {
Expand Down Expand Up @@ -868,6 +909,10 @@ func (cl *PsqlClient) SetTsa(s string) {

var _ RouterClient = &PsqlClient{}

// CancelMsg returns the cancel request stored in cl.csm, i.e. the message
// decoded during Init when the connection opened with a cancel request.
// Callers treat a nil result as "this is not a cancel connection".
func (cl *PsqlClient) CancelMsg() *pgproto3.CancelRequest {
	return cl.csm
}

// FakeClient embeds RouterClient, so it exposes the full RouterClient method
// set by delegating to the embedded value.
// NOTE(review): presumably a stand-in for tests or internal callers — confirm.
type FakeClient struct {
	RouterClient
}
Expand Down
28 changes: 27 additions & 1 deletion router/pkg/datashard/datashard.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ type Shard interface {
ConstructSM() *pgproto3.StartupMessage
Instance() conn.DBInstance

Cancel() error

Params() ParameterSet
Close() error
}
Expand Down Expand Up @@ -75,6 +77,9 @@ type Conn struct {

dedicated conn.DBInstance
ps ParameterSet

backend_key_pid uint32
backend_key_secret uint32
}

func (sh *Conn) Close() error {
Expand All @@ -87,6 +92,23 @@ func (sh *Conn) Instance() conn.DBInstance {
return sh.dedicated
}

// Cancel asks the backend to cancel the work running on this shard connection.
//
// It opens a separate, temporary connection to the same host (no TLS), sends
// a CancelRequest carrying the backend key data recorded for this connection
// (backend_key_pid / backend_key_secret), and closes the temporary connection
// on return. It returns the connection error or the error from sending the
// request.
func (sh *Conn) Cancel() error {
	pgiTmp, err := conn.NewInstanceConn(sh.dedicated.Hostname(), nil /* no tls for cancel */)
	if err != nil {
		return err
	}
	defer pgiTmp.Close()

	msg := &pgproto3.CancelRequest{
		ProcessID: sh.backend_key_pid,
		SecretKey: sh.backend_key_secret,
	}

	// Pass the instance itself to %p: &pgiTmp would print the address of the
	// local interface variable, which does not identify the connection.
	spqrlog.Logger.Printf(spqrlog.DEBUG1, "sending cancel msg %v over %p", msg, pgiTmp)

	return pgiTmp.Cancel(msg)
}

func (sh *Conn) AddTLSConf(tlsconfig *tls.Config) error {
if err := sh.dedicated.ReqBackendSsl(tlsconfig); err != nil {
spqrlog.Logger.Printf(spqrlog.DEBUG3, "failed to init ssl on host %v of datashard %v: %v", sh.dedicated.Hostname(), sh.Name(), err)
Expand Down Expand Up @@ -172,9 +194,13 @@ func (sh *Conn) Auth(sm *pgproto3.StartupMessage) error {
Value: v.Value,
}) {
spqrlog.Logger.Printf(spqrlog.DEBUG1, "ignored parameter status %v %v", v.Name, v.Value)
} else {
spqrlog.Logger.Printf(spqrlog.DEBUG5, "parameter status %v %v", v.Name, v.Value)
}
case *pgproto3.BackendKeyData:
spqrlog.Logger.Printf(spqrlog.DEBUG1, "ignored backend key data %v %v", v.ProcessID, v.SecretKey)
sh.backend_key_pid = v.ProcessID
sh.backend_key_secret = v.SecretKey
spqrlog.Logger.Printf(spqrlog.DEBUG5, "backend key data %v %v", v.ProcessID, v.SecretKey)
default:
spqrlog.Logger.Printf(spqrlog.DEBUG1, "unexpected msg type received %T", v)
}
Expand Down
11 changes: 10 additions & 1 deletion router/pkg/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,26 @@ func (r *InstanceImpl) serv(netconn net.Conn) error {
return err
}

defer netconn.Close()

if routerClient.DB() == "spqr-console" {
return r.AdmConsole.Serve(context.Background(), routerClient)
}

spqrlog.Logger.Printf(spqrlog.DEBUG2, "clint %p: prerouting phase succeeded", routerClient)
if routerClient.CancelMsg() != nil {
return r.RuleRouter.CancelClient(routerClient.CancelMsg())
}

spqrlog.Logger.Printf(spqrlog.DEBUG2, "client %p: prerouting phase succeeded", routerClient)

cmngr, err := rulerouter.MatchConnectionPooler(routerClient, r.RuleRouter.Config())
if err != nil {
return err
}

r.RuleRouter.AddClient(routerClient)
defer r.RuleRouter.ReleaseClient(routerClient)

return Frontend(r.Qrouter, routerClient, cmngr, r.RuleRouter.Config())
}

Expand Down
7 changes: 4 additions & 3 deletions router/pkg/rulerouter/route_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ type RoutePool interface {
}

type RoutePoolImpl struct {
mu sync.Mutex
pool map[route.Key]*route.Route
shardMapping map[string]*config.Shard
mu sync.Mutex
pool map[route.Key]*route.Route
shardMapping map[string]*config.Shard
clientMapping map[int] /*backend id*/ client.Client
}

var _ RoutePool = &RoutePoolImpl{}
Expand Down
46 changes: 45 additions & 1 deletion router/pkg/rulerouter/rulerouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package rulerouter

import (
"crypto/tls"
"fmt"
"log"
"net"
"os"
Expand Down Expand Up @@ -32,6 +33,11 @@ type RuleRouter interface {
AddWorldShard(key qdb.ShardKey) error
AddShardInstance(key qdb.ShardKey, host string)

CancelClient(csm *pgproto3.CancelRequest) error
AddClient(cl rclient.RouterClient)

ReleaseClient(cl rclient.RouterClient)

Config() *config.Router
}

Expand All @@ -44,6 +50,9 @@ type RuleRouterImpl struct {

mu sync.Mutex
rcfg *config.Router

clmu sync.Mutex
clmp map[uint32]rclient.RouterClient
}

func (r *RuleRouterImpl) AddWorldShard(key qdb.ShardKey) error {
Expand Down Expand Up @@ -127,6 +136,7 @@ func NewRouter(tlsconfig *tls.Config, rcfg *config.Router) *RuleRouterImpl {
rmgr: rule.NewMgr(frontendRules, backendRules, defaultFrontendRule, defaultBackendRule),
lg: log.New(os.Stdout, "router", 0),
tlsconfig: tlsconfig,
clmp: map[uint32]rclient.RouterClient{},
}
}

Expand All @@ -137,11 +147,15 @@ func (r *RuleRouterImpl) PreRoute(conn net.Conn) (rclient.RouterClient, error) {
return cl, err
}

if cl.CancelMsg() != nil {
return cl, nil
}

if cl.DB() == "spqr-console" {
return cl, nil
}

// match client frontend rule
// match client to frontend rule
key := *route.NewRouteKey(cl.Usr(), cl.DB())
frRule, err := r.rmgr.MatchKeyFrontend(key)
if err != nil {
Expand Down Expand Up @@ -199,6 +213,7 @@ func (r *RuleRouterImpl) PreRoute(conn net.Conn) (rclient.RouterClient, error) {

func (r *RuleRouterImpl) PreRouteAdm(conn net.Conn) (rclient.RouterClient, error) {
cl := rclient.NewPsqlClient(conn)

if err := cl.Init(r.tlsconfig); err != nil {
return nil, err
}
Expand Down Expand Up @@ -249,4 +264,33 @@ func (r *RuleRouterImpl) Config() *config.Router {
return r.rcfg
}

// AddClient records cl in the router's client map, keyed by the client's
// cancel pid, while holding clmu. An existing entry with the same pid is
// overwritten.
func (r *RuleRouterImpl) AddClient(cl rclient.RouterClient) {
	r.clmu.Lock()
	defer r.clmu.Unlock()

	pid := cl.GetCancelPid()
	r.clmp[pid] = cl
}

// ReleaseClient drops the entry stored under cl's cancel pid from the
// router's client map while holding clmu. Removing a pid that is not present
// is a no-op.
func (r *RuleRouterImpl) ReleaseClient(cl rclient.RouterClient) {
	r.clmu.Lock()
	defer r.clmu.Unlock()

	pid := cl.GetCancelPid()
	delete(r.clmp, pid)
}

func (r *RuleRouterImpl) CancelClient(csm *pgproto3.CancelRequest) error {
r.clmu.Lock()
defer r.clmu.Unlock()

if cl, ok := r.clmp[csm.ProcessID]; ok {
if cl.GetCancelKey() != csm.SecretKey {
return fmt.Errorf("cancel secret does not match")
}
if cl.Server() != nil {
spqrlog.Logger.Printf(spqrlog.DEBUG1, "cancelling client pid %d", csm.ProcessID)
return cl.Server().Cancel()
}
return nil
}
return fmt.Errorf("no client with pid %d", csm.ProcessID)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we should return an error here. Will this error be passed to the client?
It seems that if there is no such process, the client has already been released?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this error will not be passed to the client

}

var _ RuleRouter = &RuleRouterImpl{}
Loading
Loading