diff --git a/pkg/app/piped/executor/ecs/ecs.go b/pkg/app/piped/executor/ecs/ecs.go index ee025b5b21..ceef6b8c42 100644 --- a/pkg/app/piped/executor/ecs/ecs.go +++ b/pkg/app/piped/executor/ecs/ecs.go @@ -451,6 +451,21 @@ func routing(ctx context.Context, in *executor.Input, platformProviderName strin return false } + currListenerRuleArns, err := client.GetListenerRuleArns(ctx, currListenerArns) + if err != nil { + in.LogPersister.Errorf("Failed to get current active listener rule: %v", err) + return false + } + + if len(currListenerRuleArns) > 0 { + if err := client.ModifyRules(ctx, currListenerRuleArns, routingTrafficCfg); err != nil { + in.LogPersister.Errorf("Failed to routing traffic to PRIMARY/CANARY variants: %v", err) + return false + } + + return true + } + if err := client.ModifyListeners(ctx, currListenerArns, routingTrafficCfg); err != nil { in.LogPersister.Errorf("Failed to routing traffic to PRIMARY/CANARY variants: %v", err) return false diff --git a/pkg/app/piped/executor/ecs/rollback.go b/pkg/app/piped/executor/ecs/rollback.go index 87c76306ff..16699f6ded 100644 --- a/pkg/app/piped/executor/ecs/rollback.go +++ b/pkg/app/piped/executor/ecs/rollback.go @@ -158,8 +158,23 @@ func rollback(ctx context.Context, in *executor.Input, platformProviderName stri return false } + currListenerRuleArns, err := client.GetListenerRuleArns(ctx, currListenerArns) + if err != nil { + in.LogPersister.Errorf("Failed to get current active listener rule: %v", err) + return false + } + + if len(currListenerRuleArns) > 0 { + if err := client.ModifyRules(ctx, currListenerRuleArns, routingTrafficCfg); err != nil { + in.LogPersister.Errorf("Failed to routing traffic to PRIMARY/CANARY variants: %v", err) + return false + } + + return true + } + if err := client.ModifyListeners(ctx, currListenerArns, routingTrafficCfg); err != nil { - in.LogPersister.Errorf("Failed to routing traffic to PRIMARY variant: %v", err) + in.LogPersister.Errorf("Failed to routing traffic to PRIMARY/CANARY variants: %v", err) return false } } diff --git a/pkg/app/piped/platformprovider/ecs/client.go b/pkg/app/piped/platformprovider/ecs/client.go index 3b16819f80..c92f4d17e1 100644 --- a/pkg/app/piped/platformprovider/ecs/client.go +++ b/pkg/app/piped/platformprovider/ecs/client.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "strings" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -417,6 +418,30 @@ func (c *client) GetListenerArns(ctx context.Context, targetGroup types.LoadBala return arns, nil } +func (c *client) GetListenerRuleArns(ctx context.Context, listenerArns []string) ([]string, error) { + var ruleArns []string + + // Fetch all rules by listeners + for _, listenerArn := range listenerArns { + input := &elasticloadbalancingv2.DescribeRulesInput{ + ListenerArn: aws.String(listenerArn), + } + output, err := c.elbClient.DescribeRules(ctx, input) + if err != nil { + return nil, err + } + for _, rule := range output.Rules { + ruleArns = append(ruleArns, *rule.RuleArn) + } + } + + if len(ruleArns) == 0 { + return nil, platformprovider.ErrNotFound + } + + return ruleArns, nil +} + func (c *client) getLoadBalancerArn(ctx context.Context, targetGroupArn string) (string, error) { input := &elasticloadbalancingv2.DescribeTargetGroupsInput{ TargetGroupArns: []string{targetGroupArn}, @@ -437,11 +462,21 @@ func (c *client) ModifyListeners(ctx context.Context, listenerArns []string, rou return fmt.Errorf("invalid listener configuration: requires 2 target groups") } - modifyListener := func(ctx context.Context, listenerArn string) error { - input := &elasticloadbalancingv2.ModifyListenerInput{ - ListenerArn: aws.String(listenerArn), - DefaultActions: []elbtypes.Action{ - { + for _, listenerArn := range listenerArns { + // Describe the listener to get the current actions + describeListenersOutput, err := c.elbClient.DescribeListeners(ctx, &elasticloadbalancingv2.DescribeListenersInput{ + ListenerArns: []string{listenerArn}, + }) + if err != nil { + return fmt.Errorf("error describing listener %s: %w", listenerArn, err) + } + + // Prepare the actions to be modified + var modifiedActions []elbtypes.Action + for _, action := range describeListenersOutput.Listeners[0].DefaultActions { + if action.Type == elbtypes.ActionTypeEnumForward { + // Modify only the forward action (new logic) + modifiedAction := elbtypes.Action{ Type: elbtypes.ActionTypeEnumForward, ForwardConfig: &elbtypes.ForwardActionConfig{ TargetGroups: []elbtypes.TargetGroupTuple{ @@ -455,16 +490,74 @@ func (c *client) ModifyListeners(ctx context.Context, listenerArns []string, rou }, }, }, - }, - }, + } + modifiedActions = append(modifiedActions, modifiedAction) + } else { + // Keep other actions unchanged (new logic) + modifiedActions = append(modifiedActions, action) + } } - _, err := c.elbClient.ModifyListener(ctx, input) - return err + + // Modify the listener + _, err = c.elbClient.ModifyListener(ctx, &elasticloadbalancingv2.ModifyListenerInput{ + ListenerArn: aws.String(listenerArn), + DefaultActions: modifiedActions, + }) + if err != nil { + return fmt.Errorf("error modifying listener %s: %w", listenerArn, err) + } + } + return nil +} + +func (c *client) ModifyRules(ctx context.Context, listenerRuleArns []string, routingTrafficCfg RoutingTrafficConfig) error { + if len(routingTrafficCfg) != 2 { + return fmt.Errorf("invalid listener configuration: requires 2 target groups") } - for _, listener := range listenerArns { - if err := modifyListener(ctx, listener); err != nil { - return err + for _, ruleArn := range listenerRuleArns { + // Describe the rule to get current actions + describeRulesOutput, err := c.elbClient.DescribeRules(ctx, &elasticloadbalancingv2.DescribeRulesInput{ + RuleArns: []string{ruleArn}, + }) + if err != nil { + return fmt.Errorf("error describing listener rule %v: %w", strings.Join(listenerRuleArns, ", "), err) + } + + // Prepare the actions to be modified + var modifiedActions []elbtypes.Action + for _, action := range describeRulesOutput.Rules[0].Actions { + if action.Type == elbtypes.ActionTypeEnumForward { + // Modify only the forward action (new logic) + modifiedAction := elbtypes.Action{ + Type: elbtypes.ActionTypeEnumForward, + ForwardConfig: &elbtypes.ForwardActionConfig{ + TargetGroups: []elbtypes.TargetGroupTuple{ + { + TargetGroupArn: aws.String(routingTrafficCfg[0].TargetGroupArn), + Weight: aws.Int32(int32(routingTrafficCfg[0].Weight)), + }, + { + TargetGroupArn: aws.String(routingTrafficCfg[1].TargetGroupArn), + Weight: aws.Int32(int32(routingTrafficCfg[1].Weight)), + }, + }, + }, + } + modifiedActions = append(modifiedActions, modifiedAction) + } else { + // Keep other actions unchanged (new logic) + modifiedActions = append(modifiedActions, action) + } + } + + // Modify the rule with the new actions + _, err = c.elbClient.ModifyRule(ctx, &elasticloadbalancingv2.ModifyRuleInput{ + RuleArn: aws.String(ruleArn), + Actions: modifiedActions, + }) + if err != nil { + return fmt.Errorf("error modifying listener rule %s: %w", ruleArn, err) } } return nil diff --git a/pkg/app/piped/platformprovider/ecs/ecs.go b/pkg/app/piped/platformprovider/ecs/ecs.go index 692f50fb24..c819c58659 100644 --- a/pkg/app/piped/platformprovider/ecs/ecs.go +++ b/pkg/app/piped/platformprovider/ecs/ecs.go @@ -57,7 +57,9 @@ type ECS interface { type ELB interface { GetListenerArns(ctx context.Context, targetGroup types.LoadBalancer) ([]string, error) + GetListenerRuleArns(ctx context.Context, listenerArns []string) ([]string, error) ModifyListeners(ctx context.Context, listenerArns []string, routingTrafficCfg RoutingTrafficConfig) error + ModifyRules(ctx context.Context, listenerRuleArns []string, routingTrafficCfg RoutingTrafficConfig) error } // Registry holds a pool of aws client wrappers.