Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VReplication: handle escaped identifiers in vschema when initializing sequence tables #16169

Merged
merged 11 commits into from
Jun 18, 2024
6 changes: 3 additions & 3 deletions go/test/endtoend/vreplication/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ create table nopk (name varchar(128), age int unsigned);
],
"auto_increment": {
"column": "cid",
"sequence": "customer_seq"
"sequence": "` + "`customer_seq`" + `"
}
},
"customer_name": {
Expand Down Expand Up @@ -295,7 +295,7 @@ create table nopk (name varchar(128), age int unsigned);
],
"auto_increment": {
"column": "cid",
"sequence": "customer_seq"
"sequence": "` + "`product`.`customer_seq`" + `"
}
},
"orders": {
Expand Down Expand Up @@ -345,7 +345,7 @@ create table nopk (name varchar(128), age int unsigned);
],
"auto_increment": {
"column": "cid",
"sequence": "customer_seq"
"sequence": "` + "`customer_seq`" + `"
}
},
"orders": {
Expand Down
4 changes: 2 additions & 2 deletions go/test/endtoend/vreplication/fk_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ func doReshard(t *testing.T, keyspace, workflowName, sourceShards, targetShards
sourceShards: sourceShards,
targetShards: targetShards,
skipSchemaCopy: true,
}, workflowFlavorVtctl)
}, workflowFlavorVtctld)
rs.Create()
waitForWorkflowState(t, vc, fmt.Sprintf("%s.%s", keyspace, workflowName), binlogdatapb.VReplicationWorkflowState_Running.String())
for _, targetTab := range targetTabs {
Expand Down Expand Up @@ -355,7 +355,7 @@ func doMoveTables(t *testing.T, sourceKeyspace, targetKeyspace, workflowName, ta
},
sourceKeyspace: sourceKeyspace,
atomicCopy: atomicCopy,
}, workflowFlavorRandom)
}, workflowFlavorVtctld)
mt.Create()

waitForWorkflowState(t, vc, fmt.Sprintf("%s.%s", targetKeyspace, workflowName), binlogdatapb.VReplicationWorkflowState_Running.String())
Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/vreplication/fk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import (
binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata"
)

const testWorkflowFlavor = workflowFlavorRandom
const testWorkflowFlavor = workflowFlavorVtctld

// TestFKWorkflow runs a MoveTables workflow with atomic copy for a db with foreign key constraints.
// It inserts initial data, then simulates load. We insert both child rows with foreign keys and those without,
Expand Down
8 changes: 2 additions & 6 deletions go/test/endtoend/vreplication/partial_movetables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func testCancel(t *testing.T) {
sourceKeyspace: sourceKeyspace,
tables: table,
sourceShards: shard,
}, workflowFlavorRandom)
}, workflowFlavorVtctld)
mt.Create()

checkDenyList := func(keyspace string, expected bool) {
Expand Down Expand Up @@ -390,9 +390,5 @@ func testPartialMoveTablesBasic(t *testing.T, flavor workflowFlavor) {
// We test with both the vtctlclient and vtctldclient flavors.
func TestPartialMoveTablesBasic(t *testing.T) {
currentWorkflowType = binlogdatapb.VReplicationWorkflowType_MoveTables
for _, flavor := range workflowFlavors {
t.Run(workflowFlavorNames[flavor], func(t *testing.T) {
testPartialMoveTablesBasic(t, flavor)
})
}
testPartialMoveTablesBasic(t, workflowFlavorVtctld)
}
1 change: 1 addition & 0 deletions go/test/endtoend/vreplication/vdiff2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ func TestVDiff2(t *testing.T) {
// We ONLY add primary tablets in this test.
tks, err := vc.AddKeyspace(t, []*Cell{zone3, zone1, zone2}, targetKs, strings.Join(targetShards, ","), customerVSchema, customerSchema, 0, 0, 200, targetKsOpts)
require.NoError(t, err)
verifyClusterHealth(t, vc)

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
Expand Down
15 changes: 7 additions & 8 deletions go/test/endtoend/vreplication/vreplication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1563,17 +1563,16 @@ func switchWrites(t *testing.T, workflowType, ksWorkflow string, reverse bool) {
}
const SwitchWritesTimeout = "91s" // max: 3 tablet picker 30s waits + 1
ensureCanSwitch(t, workflowType, "", ksWorkflow)
// Use vtctldclient for MoveTables SwitchTraffic ~ 50% of the time.
if workflowType == binlogdatapb.VReplicationWorkflowType_MoveTables.String() && time.Now().Second()%2 == 0 {
parts := strings.Split(ksWorkflow, ".")
require.Equal(t, 2, len(parts))
moveTablesAction(t, command, defaultCellName, parts[1], sourceKs, parts[0], "", "--timeout="+SwitchWritesTimeout, "--tablet-types=primary")
targetKs, workflow, found := strings.Cut(ksWorkflow, ".")
require.True(t, found)
if workflowType == binlogdatapb.VReplicationWorkflowType_MoveTables.String() {
moveTablesAction(t, command, defaultCellName, workflow, sourceKs, targetKs, "", "--timeout="+SwitchWritesTimeout, "--tablet-types=primary")
return
}
output, err := vc.VtctlClient.ExecuteCommandWithOutput(workflowType, "--", "--tablet_types=primary",
"--timeout="+SwitchWritesTimeout, "--initialize-target-sequences", command, ksWorkflow)
output, err := vc.VtctldClient.ExecuteCommandWithOutput(workflowType, "--tablet-types=primary", "--workflow", workflow,
"--target-keyspace", targetKs, command, "--timeout="+SwitchWritesTimeout, "--initialize-target-sequences")
if output != "" {
fmt.Printf("Output of switching writes with vtctlclient for %s:\n++++++\n%s\n--------\n", ksWorkflow, output)
fmt.Printf("Output of switching writes with vtctldclient for %s:\n++++++\n%s\n--------\n", ksWorkflow, output)
}
// printSwitchWritesExtraDebug is useful when debugging failures in Switch writes due to corner cases/races
_ = printSwitchWritesExtraDebug
Expand Down
126 changes: 103 additions & 23 deletions go/vt/vtctl/workflow/traffic_switcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -1386,6 +1386,12 @@ func (ts *trafficSwitcher) getTargetSequenceMetadata(ctx context.Context) (map[s
return nil
}
for tableName, tableDef := range kvs.Tables {
// The table name can be escaped in the vschema definition.
escapedTableName, err := sqlescape.UnescapeID(tableName)
if err != nil {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid table name %s in keyspace %s: %v",
tableName, keyspace, err)
}
select {
case <-sctx.Done():
return sctx.Err()
Expand All @@ -1396,9 +1402,9 @@ func (ts *trafficSwitcher) getTargetSequenceMetadata(ctx context.Context) (map[s
if complete := func() bool {
smMu.Lock() // Prevent concurrent access to the map
defer smMu.Unlock()
sm := sequencesByBackingTable[tableName]
sm := sequencesByBackingTable[escapedTableName]
if tableDef != nil && tableDef.Type == vindexes.TypeSequence &&
sm != nil && tableName == sm.backingTableName {
sm != nil && escapedTableName == sm.backingTableName {
tablesFound++ // This is also protected by the mutex
sm.backingTableKeyspace = keyspace
// Set the default keyspace name. We will later check to
Expand Down Expand Up @@ -1429,18 +1435,22 @@ func (ts *trafficSwitcher) getTargetSequenceMetadata(ctx context.Context) (map[s
searchGroup, gctx := errgroup.WithContext(ctx)
searchCompleted := make(chan struct{})
for _, keyspace := range keyspaces {
keyspace := keyspace // https://golang.org/doc/faq#closures_and_goroutines
// The keyspace name could be escaped so we need to unescape it.
ks, err := sqlescape.UnescapeID(keyspace)
if err != nil { // Should never happen
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid keyspace name %s: %v", keyspace, err)
}
searchGroup.Go(func() error {
return searchKeyspace(gctx, searchCompleted, keyspace)
return searchKeyspace(gctx, searchCompleted, ks)
})
}
if err := searchGroup.Wait(); err != nil {
return nil, err
}

if tablesFound != tableCount {
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to locate all of the backing sequence tables being used; sequence table metadata: %+v",
sequencesByBackingTable)
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to locate all of the backing sequence tables being used: %s",
strings.Join(maps.Keys(sequencesByBackingTable), ","))
}
return sequencesByBackingTable, nil
}
Expand All @@ -1460,34 +1470,73 @@ func (ts *trafficSwitcher) findSequenceUsageInKeyspace(vschema *vschemapb.Keyspa
targetDBName := targets[0].GetPrimary().DbName()
sequencesByBackingTable := make(map[string]*sequenceMetadata)

for _, table := range ts.Tables() {
vs, ok := vschema.Tables[table]
if !ok || vs.GetAutoIncrement().GetSequence() == "" {
for _, table := range ts.tables {
seqTable, ok := vschema.Tables[table]
if !ok || seqTable.GetAutoIncrement().GetSequence() == "" {
continue
}
// Be sure that the table name is unescaped as it can be escaped
// in the vschema.
escapedTable, err := sqlescape.UnescapeID(table)
if err != nil {
return nil, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid table name %s defined in the sequence table %+v: %v",
table, seqTable, err)
}
sm := &sequenceMetadata{
backingTableName: vs.AutoIncrement.Sequence,
usingTableName: table,
usingTableDefinition: vs,
usingTableDBName: targetDBName,
usingTableName: escapedTable,
usingTableDBName: targetDBName,
}
// If the sequence table is fully qualified in the vschema then
// we don't need to find it later.
if strings.Contains(vs.AutoIncrement.Sequence, ".") {
keyspace, tableName, found := strings.Cut(vs.AutoIncrement.Sequence, ".")
if strings.Contains(seqTable.AutoIncrement.Sequence, ".") {
keyspace, tableName, found := strings.Cut(seqTable.AutoIncrement.Sequence, ".")
if !found {
return nil, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid sequence table name %s defined in the %s keyspace",
vs.AutoIncrement.Sequence, ts.targetKeyspace)
seqTable.AutoIncrement.Sequence, ts.targetKeyspace)
}
// Unescape the table name and keyspace name as they may be escaped in the
// vschema definition if they e.g. contain dashes.
if keyspace, err = sqlescape.UnescapeID(keyspace); err != nil {
return nil, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid keyspace in qualified sequence table name %s defined in sequence table %+v: %v",
seqTable.AutoIncrement.Sequence, seqTable, err)
}
if tableName, err = sqlescape.UnescapeID(tableName); err != nil {
return nil, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid qualified sequence table name %s defined in sequence table %+v: %v",
seqTable.AutoIncrement.Sequence, seqTable, err)
}
sm.backingTableName = tableName
sm.backingTableKeyspace = keyspace
sm.backingTableName = tableName
// Update the definition with the unescaped values.
seqTable.AutoIncrement.Sequence = fmt.Sprintf("%s.%s", keyspace, tableName)
// Set the default keyspace name. We will later check to
// see if the tablet we send requests to is using a dbname
// override and use that if it is.
sm.backingTableDBName = "vt_" + keyspace
} else {
sm.backingTableName, err = sqlescape.UnescapeID(seqTable.AutoIncrement.Sequence)
if err != nil {
return nil, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid sequence table name %s defined in sequence table %+v: %v",
seqTable.AutoIncrement.Sequence, seqTable, err)
}
seqTable.AutoIncrement.Sequence = sm.backingTableName
allFullyQualified = false
}
// The column names can be escaped in the vschema definition.
for i := range seqTable.ColumnVindexes {
escapedColumn, err := sqlescape.UnescapeID(seqTable.ColumnVindexes[i].Column)
if err != nil {
return nil, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid sequence column vindex name %s defined in sequence table %+v: %v",
seqTable.ColumnVindexes[i].Column, seqTable, err)
}
seqTable.ColumnVindexes[i].Column = escapedColumn
}
escapedAutoIncCol, err := sqlescape.UnescapeID(seqTable.AutoIncrement.Column)
if err != nil {
return nil, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid auto-increment column name %s defined in sequence table %+v: %v",
seqTable.AutoIncrement.Column, seqTable, err)
}
seqTable.AutoIncrement.Column = escapedAutoIncCol
sm.usingTableDefinition = seqTable
sequencesByBackingTable[sm.backingTableName] = sm
}

Expand Down Expand Up @@ -1516,10 +1565,25 @@ func (ts *trafficSwitcher) initializeTargetSequences(ctx context.Context, sequen
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "no primary tablet found for target shard %s/%s",
ts.targetKeyspace, target.GetShard().ShardName())
}
usingCol, err := sqlescape.EnsureEscaped(sequenceMetadata.usingTableDefinition.AutoIncrement.Column)
if err != nil {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid column name %s specified for sequence in table %s: %v",
sequenceMetadata.usingTableDefinition.AutoIncrement.Column, sequenceMetadata.usingTableName, err)
}
usingDB, err := sqlescape.EnsureEscaped(sequenceMetadata.usingTableDBName)
if err != nil {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid database name %s specified for sequence in table %s: %v",
sequenceMetadata.usingTableDBName, sequenceMetadata.usingTableName, err)
}
usingTable, err := sqlescape.EnsureEscaped(sequenceMetadata.usingTableName)
if err != nil {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid sequence table name specified for sequence in table %s: %v",
sequenceMetadata.usingTableName, err)
}
query := sqlparser.BuildParsedQuery(sqlGetMaxSequenceVal,
sqlescape.EscapeID(sequenceMetadata.usingTableDefinition.AutoIncrement.Column),
sqlescape.EscapeID(sequenceMetadata.usingTableDBName),
sqlescape.EscapeID(sequenceMetadata.usingTableName),
usingCol,
usingDB,
usingTable,
)
qr, terr := ts.ws.tmc.ExecuteFetchAsApp(ictx, primary.Tablet, true, &tabletmanagerdatapb.ExecuteFetchAsAppRequest{
Query: []byte(query.Query),
Expand Down Expand Up @@ -1580,9 +1644,19 @@ func (ts *trafficSwitcher) initializeTargetSequences(ctx context.Context, sequen
if sequenceTablet.DbNameOverride != "" {
sequenceMetadata.backingTableDBName = sequenceTablet.DbNameOverride
}
backingDB, err := sqlescape.EnsureEscaped(sequenceMetadata.backingTableDBName)
if err != nil {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid database name %s in sequence backing table %s: %v",
sequenceMetadata.backingTableDBName, sequenceMetadata.backingTableName, err)
}
backingTable, err := sqlescape.EnsureEscaped(sequenceMetadata.backingTableName)
if err != nil {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid sequence backing table name %s: %v",
sequenceMetadata.backingTableName, err)
}
query := sqlparser.BuildParsedQuery(sqlInitSequenceTable,
sqlescape.EscapeID(sequenceMetadata.backingTableDBName),
sqlescape.EscapeID(sequenceMetadata.backingTableName),
backingDB,
backingTable,
nextVal,
nextVal,
nextVal,
Expand Down Expand Up @@ -1615,7 +1689,13 @@ func (ts *trafficSwitcher) initializeTargetSequences(ctx context.Context, sequen
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to get primary tablet for keyspace %s: %v",
sequenceMetadata.backingTableKeyspace, ierr)
}
ierr = ts.TabletManagerClient().ResetSequences(ictx, ti.Tablet, []string{sequenceMetadata.backingTableName})
// ResetSequences interfaces with the schema engine and the actual
// table identifiers DO NOT contain the backticks. So we have to
// ensure that the table name is unescaped.
if backingTable, err = sqlescape.UnescapeID(backingTable); err != nil {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid table name %s: %v", backingTable, err)
}
ierr = ts.TabletManagerClient().ResetSequences(ictx, ti.Tablet, []string{backingTable})
if ierr != nil {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to reset the sequence cache for backing table %s on shard %s/%s using tablet %s: %v",
sequenceMetadata.backingTableName, sequenceShard.Keyspace(), sequenceShard.ShardName(), sequenceShard.PrimaryAlias, ierr)
Expand Down
Loading
Loading