Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix spqrdump & add feature tests #526

Merged
merged 13 commits into from
Feb 28, 2024
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ spqr-worldmock
spqr-balancer
spqr-mover
spqr-workloadreplay
spqrdump
y.output
*.swp
*.swo
Expand Down
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ build_worldmock:
build_workloadreplay:
go build -pgo=auto -o spqr-workloadreplay ./cmd/workloadreplay

build: build_balancer build_coordinator build_coorctl build_router build_mover build_worldmock build_workloadreplay
build_spqrdump:
go build -pgo=auto -o spqrdump ./cmd/spqrdump

build: build_balancer build_coordinator build_coorctl build_router build_mover build_worldmock build_workloadreplay build_spqrdump

build_images:
docker compose build spqr-base-image
Expand Down
184 changes: 100 additions & 84 deletions cmd/spqrdump/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/tls"
"fmt"
"net"
"strings"

"github.com/jackc/pgx/v5/pgproto3"
"github.com/spf13/cobra"
Expand All @@ -24,7 +25,7 @@ func Dial(addr string) (*grpc.ClientConn, error) {
}

var rootCmd = &cobra.Command{
Use: "coorctl -e localhost:7003",
Use: "spqrdump -e localhost:7003",
CompletionOptions: cobra.CompletionOptions{
DisableDefaultCmd: true,
},
Expand All @@ -35,27 +36,7 @@ var rootCmd = &cobra.Command{
var endpoint string
var proto string
var passwd string

// TODO : unit tests
func DumpRules() error {
cc, err := Dial(endpoint)
if err != nil {
return err
}

rCl := protos.NewShardingRulesServiceClient(cc)
if rules, err := rCl.ListShardingRules(context.Background(), &protos.ListShardingRuleRequest{}); err != nil {
spqrlog.Zero.Error().
Err(err).
Msg("failed to dump endpoint rules")
} else {
for _, rule := range rules.Rules {
fmt.Printf("%s;\n", decode.DecodeRule(rule))
}
}

return nil
}
var logLevel string

// TODO : unit tests
func waitRFQ(fr *pgproto3.Frontend) error {
Expand Down Expand Up @@ -133,14 +114,27 @@ func getconn() (*pgproto3.Frontend, error) {
}

// TODO : unit tests
func DumpRulesPSQL() error {
func DumpKeyRangesPsql() error {
return dumpPsql("SHOW key_ranges;", func(v *pgproto3.DataRow) (string, error) {
l := string(v.Values[2])
r := string(v.Values[3])
id := string(v.Values[0])
shard := string(v.Values[1])

return decode.KeyRange(
&protos.KeyRangeInfo{
KeyRange: &protos.KeyRange{LowerBound: l, UpperBound: r},
ShardId: shard, Krid: id}), nil
})
}

func dumpPsql(query string, rowToStr func(v *pgproto3.DataRow) (string, error)) error {
frontend, err := getconn()
if err != nil {
return err
}
frontend.Send(&pgproto3.Query{
String: "SHOW key_ranges;",
String: query,
})
if err := frontend.Flush(); err != nil {
return err
Expand All @@ -156,16 +150,11 @@ func DumpRulesPSQL() error {

switch v := msg.(type) {
case *pgproto3.DataRow:
l := string(v.Values[2])
r := string(v.Values[3])
id := string(v.Values[0])
shard := string(v.Values[1])

fmt.Printf("%s;\n",
decode.DecodeKeyRange(
&protos.KeyRangeInfo{
KeyRange: &protos.KeyRange{LowerBound: l, UpperBound: r},
ShardId: shard, Krid: id}))
s, err := rowToStr(v)
if err != nil {
return err
}
fmt.Println(s)
case *pgproto3.ErrorResponse:
return fmt.Errorf("failed to wait for RQF: %s", v.Message)
case *pgproto3.ReadyForQuery:
Expand All @@ -176,96 +165,121 @@ func DumpRulesPSQL() error {
}

// TODO : unit tests
func DumpKeyRangesPSQL() error {

frontend, err := getconn()
func DumpKeyRanges() error {
cc, err := Dial(endpoint)
if err != nil {
return err
}
frontend.Send(&pgproto3.Query{
String: "SHOW sharding_rules;",
})
if err := frontend.Flush(); err != nil {
return err
}

for {
if msg, err := frontend.Receive(); err != nil {
return err
} else {
spqrlog.Zero.Debug().
Interface("message", msg).
Msg("received message")
switch v := msg.(type) {
case *pgproto3.DataRow:
col := string(v.Values[2])
id := string(v.Values[0])
tablename := string(v.Values[1])

fmt.Printf("%s;\n",
decode.DecodeRule(
&protos.ShardingRule{
Id: id,
TableName: tablename,
ShardingRuleEntry: []*protos.ShardingRuleEntry{
{
Column: col,
},
},
}),
)
case *pgproto3.ErrorResponse:
return fmt.Errorf("failed to wait for RQF: %s", v.Message)
case *pgproto3.ReadyForQuery:
return nil
}
rCl := protos.NewKeyRangeServiceClient(cc)
if keys, err := rCl.ListAllKeyRanges(context.Background(), &protos.ListAllKeyRangesRequest{}); err != nil {
spqrlog.Zero.Error().
Err(err).
Msg("failed to dump endpoint rules")
} else {
for _, krg := range keys.KeyRangesInfo {
fmt.Println(decode.KeyRange(krg))
}
}

return nil
}

// DumpDistributions dump info about distributions & attached relations via GRPC
// TODO : unit tests
func DumpKeyRanges() error {
func DumpDistributions() error {
cc, err := Dial(endpoint)
if err != nil {
return err
}

rCl := protos.NewKeyRangeServiceClient(cc)
if keys, err := rCl.ListKeyRange(context.Background(), &protos.ListKeyRangeRequest{}); err != nil {
rCl := protos.NewDistributionServiceClient(cc)
if dss, err := rCl.ListDistributions(context.Background(), &protos.ListDistributionsRequest{}); err != nil {
spqrlog.Zero.Error().
Err(err).
Msg("failed to dump endpoint rules")
Msg("failed to dump endpoint distributions")
} else {
for _, krg := range keys.KeyRangesInfo {
fmt.Printf("%s;\n", decode.DecodeKeyRange(krg))
for _, ds := range dss.Distributions {
fmt.Println(decode.Distribution(ds))
for _, rel := range ds.Relations {
fmt.Println(decode.DistributedRelation(rel, ds.Id))
}
}
}

return nil
}

// DumpDistributionsPsql dump info about distributions via psql
// TODO : unit tests
func DumpDistributionsPsql() error {
return dumpPsql("SHOW distributions;", func(v *pgproto3.DataRow) (string, error) {
id := string(v.Values[0])
types := string(v.Values[1])

return decode.Distribution(
&protos.Distribution{
Id: id,
ColumnTypes: strings.Split(types, ","),
}), nil
})
}

// DumpRelationsPsql dump info about distributed relations via psql
// TODO : unit tests
func DumpRelationsPsql() error {
return dumpPsql("SHOW relations;", func(v *pgproto3.DataRow) (string, error) {
name := string(v.Values[0])
ds := string(v.Values[1])
dsKeyStr := strings.Split(string(v.Values[2]), ",")
dsKey := make([]*protos.DistributionKeyEntry, len(dsKeyStr))
for i, elem := range dsKeyStr {
elems := strings.Split(strings.Trim(elem, "()"), ",")
if len(elems) != 2 {
return "", fmt.Errorf("incorrect distribution key entry: \"%s\"", elem)
}
dsKey[i] = &protos.DistributionKeyEntry{
Column: strings.Trim(elems[0], "\""),
HashFunction: elems[1],
}
}

return decode.DistributedRelation(
&protos.DistributedRelation{
Name: name,
DistributionKey: dsKey,
}, ds), nil
})
}

var dump = &cobra.Command{
Use: "dump",
Short: "list running routers in current topology",
Short: "dump current sharding configuration",
RunE: func(cmd *cobra.Command, args []string) error {
if err := spqrlog.UpdateZeroLogLevel(logLevel); err != nil {
return err
}
spqrlog.Zero.Debug().
Str("endpoint", endpoint).
Msg("dialing spqrdump on")

switch proto {
case "grpc":
if err := DumpRules(); err != nil {
if err := DumpDistributions(); err != nil {
return err
}
if err := DumpKeyRanges(); err != nil {
return err
}
return nil
case "psql":
if err := DumpRulesPSQL(); err != nil {
if err := DumpDistributionsPsql(); err != nil {
return err
}
if err := DumpRelationsPsql(); err != nil {
return err
}
if err := DumpKeyRangesPSQL(); err != nil {
if err := DumpKeyRangesPsql(); err != nil {
return err
}
return nil
Expand All @@ -278,12 +292,14 @@ var dump = &cobra.Command{
}

func init() {
rootCmd.PersistentFlags().StringVarP(&endpoint, "endpoint", "e", "localhost:7003", "endpoint for dump metadata")
rootCmd.PersistentFlags().StringVarP(&endpoint, "endpoint", "e", "localhost:7000", "endpoint for dump metadata")

rootCmd.PersistentFlags().StringVarP(&proto, "proto", "t", "grpc", "protocol to use for communication")

rootCmd.PersistentFlags().StringVarP(&passwd, "passwd", "p", "", "password to use for communication")

rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "error", "log level")

rootCmd.AddCommand(dump)
}

Expand Down
17 changes: 16 additions & 1 deletion coordinator/provider/keyranges.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (c *CoordinatorService) SplitKeyRange(ctx context.Context, request *protos.
// TODO : unit tests
func (c *CoordinatorService) ListKeyRange(ctx context.Context, request *protos.ListKeyRangeRequest) (*protos.KeyRangeReply, error) {

krsqb, err := c.impl.ListAllKeyRanges(ctx)
krsqb, err := c.impl.ListKeyRanges(ctx, request.Distribution)
if err != nil {
return nil, err
}
Expand All @@ -83,6 +83,21 @@ func (c *CoordinatorService) ListKeyRange(ctx context.Context, request *protos.L
}, nil
}

func (c *CoordinatorService) ListAllKeyRanges(ctx context.Context, _ *protos.ListAllKeyRangesRequest) (*protos.KeyRangeReply, error) {
krsDb, err := c.impl.ListAllKeyRanges(ctx)
if err != nil {
return nil, err
}

krs := make([]*protos.KeyRangeInfo, len(krsDb))

for i, krg := range krsDb {
krs[i] = krg.ToProto()
}

return &protos.KeyRangeReply{KeyRangesInfo: krs}, nil
}

// TODO : unit tests
func (c *CoordinatorService) MoveKeyRange(ctx context.Context, request *protos.MoveKeyRangeRequest) (*protos.ModifyReply, error) {
if err := c.impl.Move(ctx, &kr.MoveKeyRange{
Expand Down
2 changes: 1 addition & 1 deletion pkg/coord/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (a *Adapter) ListKeyRanges(ctx context.Context, distribution string) ([]*kr
// TODO : unit tests
func (a *Adapter) ListAllKeyRanges(ctx context.Context) ([]*kr.KeyRange, error) {
c := proto.NewKeyRangeServiceClient(a.conn)
reply, err := c.ListKeyRange(ctx, &proto.ListKeyRangeRequest{})
reply, err := c.ListAllKeyRanges(ctx, &proto.ListAllKeyRangesRequest{})
if err != nil {
return nil, err
}
Expand Down
30 changes: 20 additions & 10 deletions pkg/decode/spqrql.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,32 @@ package decode

import (
"fmt"
"strings"

protos "github.com/pg-sharding/spqr/pkg/protos"
)

// TODO : unit tests
func DecodeRule(rule *protos.ShardingRule) string {
// KeyRange returns query to create given key range
func KeyRange(krg *protos.KeyRangeInfo) string {
/* TODO: composite key support */
if rule.TableName != "" {
return fmt.Sprintf("CREATE SHARDING RULE %s TABLE %s COLUMN %s", rule.Id, rule.TableName, rule.ShardingRuleEntry[0].Column)
}
return fmt.Sprintf("CREATE SHARDING RULE %s COLUMN %s", rule.Id, rule.ShardingRuleEntry[0].Column)
return fmt.Sprintf("CREATE KEY RANGE %s FROM %s ROUTE TO %s FOR DISTRIBUTION %s;", krg.Krid, krg.KeyRange.LowerBound, krg.ShardId, krg.DistributionId)
}

// TODO : unit tests
func DecodeKeyRange(krg *protos.KeyRangeInfo) string {
/* TODO: composite key support */
// Distribution returns query to create given distribution
func Distribution(ds *protos.Distribution) string {
return fmt.Sprintf("CREATE DISTRIBUTION %s COLUMN TYPES %s;", ds.Id, strings.Join(ds.ColumnTypes, ", "))
}

return fmt.Sprintf("CREATE KEY RANGE %s FROM %s TO %s ROUTE TO %s", krg.Krid, krg.KeyRange.LowerBound, krg.KeyRange.UpperBound, krg.ShardId)
// DistributedRelation return query to attach relation to distribution
func DistributedRelation(rel *protos.DistributedRelation, ds string) string {
elems := make([]string, len(rel.DistributionKey))
for j, el := range rel.DistributionKey {
if el.HashFunction != "" {
elems[j] = fmt.Sprintf("%s HASH FUNCTION %s", el.Column, el.HashFunction)
} else {
elems[j] = el.Column
}

}
return fmt.Sprintf("ALTER DISTRIBUTION %s ATTACH RELATION %s DISTRIBUTION KEY %s;", ds, rel.Name, strings.Join(elems, ", "))
}