diff --git a/pkg/cluster/cluster.go b/pkg/cluster/cluster.go index 967f9d530..570bce4c1 100644 --- a/pkg/cluster/cluster.go +++ b/pkg/cluster/cluster.go @@ -516,7 +516,7 @@ func (c *Cluster) compareContainers(description string, setA, setB []v1.Containe newCheck("new statefulset %s's %s (index %d) name does not match the current one", func(a, b v1.Container) bool { return a.Name != b.Name }), newCheck("new statefulset %s's %s (index %d) ports do not match the current one", - func(a, b v1.Container) bool { return !reflect.DeepEqual(a.Ports, b.Ports) }), + func(a, b v1.Container) bool { return !comparePorts(a.Ports, b.Ports) }), newCheck("new statefulset %s's %s (index %d) resources do not match the current ones", func(a, b v1.Container) bool { return !compareResources(&a.Resources, &b.Resources) }), newCheck("new statefulset %s's %s (index %d) environment does not match the current one", @@ -627,6 +627,46 @@ func compareSpiloConfiguration(configa, configb string) bool { return reflect.DeepEqual(oa, ob) } +func areProtocolsEqual(a, b v1.Protocol) bool { + return a == b || + (a == "" && b == v1.ProtocolTCP) || + (a == v1.ProtocolTCP && b == "") +} + +func comparePorts(a, b []v1.ContainerPort) bool { + if len(a) != len(b) { + return false + } + + areContainerPortsEqual := func(a, b v1.ContainerPort) bool { + return a.Name == b.Name && + a.HostPort == b.HostPort && + areProtocolsEqual(a.Protocol, b.Protocol) && + a.HostIP == b.HostIP + } + + findByPortValue := func(portSpecs []v1.ContainerPort, port int32) (v1.ContainerPort, bool) { + for _, portSpec := range portSpecs { + if portSpec.ContainerPort == port { + return portSpec, true + } + } + return v1.ContainerPort{}, false + } + + for _, portA := range a { + portB, found := findByPortValue(b, portA.ContainerPort) + if !found { + return false + } + if !areContainerPortsEqual(portA, portB) { + return false + } + } + + return true +} + func (c *Cluster) enforceMinResourceLimits(spec *acidv1.PostgresSpec) error { var ( diff --git a/pkg/cluster/cluster_test.go b/pkg/cluster/cluster_test.go index dc1f5ff03..d06cc21e1 100644 --- a/pkg/cluster/cluster_test.go +++ b/pkg/cluster/cluster_test.go @@ -5,6 +5,8 @@ import ( "reflect" "testing" + "github.com/stretchr/testify/assert" + "github.com/sirupsen/logrus" acidv1 "github.com/zalando/postgres-operator/pkg/apis/acid.zalan.do/v1" fakeacidv1 "github.com/zalando/postgres-operator/pkg/generated/clientset/versioned/fake" @@ -1088,3 +1090,102 @@ func TestValidUsernames(t *testing.T) { } } } + +func TestComparePorts(t *testing.T) { + testCases := []struct { + name string + setA []v1.ContainerPort + setB []v1.ContainerPort + expected bool + }{ + { + name: "different ports", + setA: []v1.ContainerPort{ + { + Name: "metrics", + ContainerPort: 9187, + Protocol: v1.ProtocolTCP, + }, + }, + + setB: []v1.ContainerPort{ + { + Name: "http", + ContainerPort: 80, + Protocol: v1.ProtocolTCP, + }, + }, + expected: false, + }, + { + name: "no difference", + setA: []v1.ContainerPort{ + { + Name: "metrics", + ContainerPort: 9187, + Protocol: v1.ProtocolTCP, + }, + }, + setB: []v1.ContainerPort{ + { + Name: "metrics", + ContainerPort: 9187, + Protocol: v1.ProtocolTCP, + }, + }, + expected: true, + }, + { + name: "same ports, different order", + setA: []v1.ContainerPort{ + { + Name: "metrics", + ContainerPort: 9187, + Protocol: v1.ProtocolTCP, + }, + { + Name: "http", + ContainerPort: 80, + Protocol: v1.ProtocolTCP, + }, + }, + setB: []v1.ContainerPort{ + { + Name: "http", + ContainerPort: 80, + Protocol: v1.ProtocolTCP, + }, + { + Name: "metrics", + ContainerPort: 9187, + Protocol: v1.ProtocolTCP, + }, + }, + expected: true, + }, + { + name: "same ports, but one with default protocol", + setA: []v1.ContainerPort{ + { + Name: "metrics", + ContainerPort: 9187, + Protocol: v1.ProtocolTCP, + }, + }, + setB: []v1.ContainerPort{ + { + Name: "metrics", + ContainerPort: 9187, + }, + }, + expected: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + got := comparePorts(testCase.setA, testCase.setB) + assert.Equal(t, testCase.expected, got) + }) + } +}