diff --git a/internal/cmd/branch/vtctld/move_tables.go b/internal/cmd/branch/vtctld/move_tables.go index ac53bbb2..2cedf6a5 100644 --- a/internal/cmd/branch/vtctld/move_tables.go +++ b/internal/cmd/branch/vtctld/move_tables.go @@ -314,6 +314,7 @@ func MoveTablesSwitchTrafficCmd(ch *cmdutil.Helper) *cobra.Command { cmd.Flags().Int64Var(&flags.maxReplicationLagAllowed, "max-replication-lag-allowed", 0, "Maximum replication lag allowed in seconds") cmd.MarkFlagRequired("workflow") // nolint:errcheck cmd.MarkFlagRequired("target-keyspace") // nolint:errcheck + cmd.MarkFlagRequired("tablet-types") // nolint:errcheck return cmd } diff --git a/internal/cmd/branch/vtctld/move_tables_test.go b/internal/cmd/branch/vtctld/move_tables_test.go index f9df4fb9..c1ad301e 100644 --- a/internal/cmd/branch/vtctld/move_tables_test.go +++ b/internal/cmd/branch/vtctld/move_tables_test.go @@ -284,6 +284,30 @@ func TestMoveTablesSwitchTrafficWithMaxLag(t *testing.T) { c.Assert(buf.String(), qt.JSONEquals, map[string]string{"summary": "switched"}) } +func TestMoveTablesSwitchTrafficRequiresTabletTypes(t *testing.T) { + c := qt.New(t) + + org := "my-org" + db := "my-db" + branch := "my-branch" + + svc := &mock.MoveTablesService{} + var buf bytes.Buffer + ch := moveTablesTestHelper(org, svc, &mock.VtctldService{}, &buf) + + cmd := MoveTablesCmd(ch) + cmd.SetArgs([]string{"switch-traffic", db, branch, + "--workflow", "my-workflow", + "--target-keyspace", "target-ks", + }) + err := cmd.Execute() + + c.Assert(err, qt.IsNotNil) + c.Assert(err.Error(), qt.Contains, "required flag") + c.Assert(svc.SwitchTrafficFnInvoked, qt.IsFalse) + c.Assert(buf.String(), qt.Equals, "") +} + func TestMoveTablesReverseTrafficWithFlags(t *testing.T) { c := qt.New(t) setMoveTablesPollInterval(t, 0) @@ -426,6 +450,7 @@ func TestMoveTablesSwitchTrafficOperationFailure(t *testing.T) { cmd.SetArgs([]string{"switch-traffic", db, branch, "--workflow", "my-workflow", "--target-keyspace", "target-ks", + "--tablet-types", "PRIMARY", }) err := cmd.Execute() @@ -464,6 +489,7 @@ func TestMoveTablesSwitchTrafficOperationTimeout(t *testing.T) { cmd.SetArgs([]string{"switch-traffic", db, branch, "--workflow", "my-workflow", "--target-keyspace", "target-ks", + "--tablet-types", "PRIMARY", }) err := cmd.Execute()