Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
floreks committed Oct 26, 2023
2 parents a106dfe + de0d89c commit 0801896
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 181 deletions.
6 changes: 3 additions & 3 deletions pkg/module/agent_registrar/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ func TestRegister(t *testing.T) {
Return(zaptest.NewLogger(t))
mockRpcApi.EXPECT().
AgentInfo(gomock.Any(), gomock.Any()).
Return(&api.AgentInfo{Id: 123, ProjectId: 456}, nil)
Return(&api.AgentInfo{Id: 123, ClusterId: "456"}, nil)
mockAgentTracker.EXPECT().
RegisterConnection(gomock.Any(), gomock.Any()).
Do(func(ctx context.Context, connectedAgentInfo *agent_tracker.ConnectedAgentInfo) error {
assert.EqualValues(t, 123, connectedAgentInfo.AgentId)
assert.EqualValues(t, 456, connectedAgentInfo.ProjectId)
assert.EqualValues(t, "456", connectedAgentInfo.ClusterId)
assert.EqualValues(t, 123456789, connectedAgentInfo.ConnectionId)
return nil
})
Expand Down Expand Up @@ -67,7 +67,7 @@ func TestRegister_registerAgent_Error(t *testing.T) {
Return(zaptest.NewLogger(t))
mockRpcApi.EXPECT().
AgentInfo(gomock.Any(), gomock.Any()).
Return(&api.AgentInfo{Id: 1, ProjectId: 1}, nil)
Return(&api.AgentInfo{Id: 1, ClusterId: "1"}, nil)
mockAgentTracker.EXPECT().
RegisterConnection(gomock.Any(), gomock.Any()).
Return(expectedErr)
Expand Down
2 changes: 1 addition & 1 deletion pkg/module/agent_tracker/agent_tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func TestConnectedAgentInfoSize(t *testing.T) {
ConnectedAt: timestamppb.Now(),
ConnectionId: 1231232,
AgentId: 123123,
ProjectId: 3232323,
ClusterId: "3232323",
})
require.NoError(t, err)
data, err := proto.Marshal(&redistool.ExpiringValue{
Expand Down
169 changes: 26 additions & 143 deletions pkg/module/agent_tracker/tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,8 @@ var (
func TestRegisterConnection_HappyPath(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
r, connectedAgents, byAgentId, byProjectId, _, info := setupTracker(t)
r, connectedAgents, byAgentId, _, info := setupTracker(t)

byProjectId.EXPECT().
Set(gomock.Any(), info.ProjectId, info.ConnectionId, gomock.Any())
byAgentId.EXPECT().
Set(gomock.Any(), info.AgentId, info.ConnectionId, gomock.Any())
connectedAgents.EXPECT().
Expand All @@ -52,15 +50,12 @@ func TestRegisterConnection_HappyPath(t *testing.T) {
func TestRegisterConnection_AllCalledOnError(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
r, connectedAgents, byAgentId, byProjectId, _, info := setupTracker(t)
r, connectedAgents, byAgentId, _, info := setupTracker(t)

err1 := errors.New("err1")
err2 := errors.New("err2")
err3 := errors.New("err3")

byProjectId.EXPECT().
Set(gomock.Any(), info.ProjectId, info.ConnectionId, gomock.Any()).
Return(err1)
byAgentId.EXPECT().
Set(gomock.Any(), info.AgentId, info.ConnectionId, gomock.Any()).
Return(err2)
Expand All @@ -81,14 +76,8 @@ func TestRegisterConnection_AllCalledOnError(t *testing.T) {
func TestUnregisterConnection_HappyPath(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
r, connectedAgents, byAgentId, byProjectId, _, info := setupTracker(t)
r, connectedAgents, byAgentId, _, info := setupTracker(t)

gomock.InOrder(
byProjectId.EXPECT().
Set(gomock.Any(), info.ProjectId, info.ConnectionId, gomock.Any()),
byProjectId.EXPECT().
Unset(gomock.Any(), info.ProjectId, info.ConnectionId),
)
gomock.InOrder(
byAgentId.EXPECT().
Set(gomock.Any(), info.AgentId, info.ConnectionId, gomock.Any()),
Expand All @@ -113,18 +102,11 @@ func TestUnregisterConnection_HappyPath(t *testing.T) {
func TestUnregisterConnection_AllCalledOnError(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
r, connectedAgents, byAgentId, byProjectId, _, info := setupTracker(t)
r, connectedAgents, byAgentId, _, info := setupTracker(t)

err1 := errors.New("err1")
err2 := errors.New("err2")

gomock.InOrder(
byProjectId.EXPECT().
Set(gomock.Any(), info.ProjectId, info.ConnectionId, gomock.Any()),
byProjectId.EXPECT().
Unset(gomock.Any(), info.ProjectId, info.ConnectionId).
Return(err1),
)
gomock.InOrder(
byAgentId.EXPECT().
Set(gomock.Any(), info.AgentId, info.ConnectionId, gomock.Any()),
Expand All @@ -150,43 +132,33 @@ func TestUnregisterConnection_AllCalledOnError(t *testing.T) {
}

func TestGC_HappyPath(t *testing.T) {
r, connectedAgents, byAgentId, byProjectId, _, _ := setupTracker(t)
r, connectedAgents, byAgentId, _, _ := setupTracker(t)

wasCalled1 := false
wasCalled2 := false
wasCalled3 := false

connectedAgents.EXPECT().
GC().
Return(func(_ context.Context) (int, error) {
wasCalled3 = true
wasCalled2 = true
return 3, nil
})

byAgentId.EXPECT().
GC().
Return(func(_ context.Context) (int, error) {
wasCalled2 = true
return 2, nil
})

byProjectId.EXPECT().
GC().
Return(func(_ context.Context) (int, error) {
wasCalled1 = true
return 1, nil
return 2, nil
})

assert.EqualValues(t, 6, r.runGC(context.Background()))
assert.EqualValues(t, 5, r.runGC(context.Background()))
assert.True(t, wasCalled1)
assert.True(t, wasCalled2)
assert.True(t, wasCalled3)
}

func TestGC_AllCalledOnError(t *testing.T) {
r, connectedAgents, byAgentId, byProjectId, rep, _ := setupTracker(t)
r, connectedAgents, byAgentId, rep, _ := setupTracker(t)

wasCalled1 := false
wasCalled2 := false
wasCalled3 := false

Expand All @@ -212,37 +184,23 @@ func TestGC_AllCalledOnError(t *testing.T) {
HandleProcessingError(gomock.Any(), gomock.Any(), "Failed to GC data in Redis", matcher.ErrorEq("err2")),
)

gomock.InOrder(
byProjectId.EXPECT().
GC().
Return(func(_ context.Context) (int, error) {
wasCalled1 = true
return 1, errors.New("err1")
}),
rep.EXPECT().
HandleProcessingError(gomock.Any(), gomock.Any(), "Failed to GC data in Redis", matcher.ErrorEq("err1")),
)

assert.EqualValues(t, 6, r.runGC(context.Background()))
assert.True(t, wasCalled1)
assert.EqualValues(t, 5, r.runGC(context.Background()))
assert.True(t, wasCalled2)
assert.True(t, wasCalled3)
}

func TestRefresh_HappyPath(t *testing.T) {
r, connectedAgents, byAgentId, byProjectId, _, _ := setupTracker(t)
r, connectedAgents, byAgentId, _, _ := setupTracker(t)

connectedAgents.EXPECT().
Refresh(gomock.Any(), gomock.Any())
byAgentId.EXPECT().
Refresh(gomock.Any(), gomock.Any())
byProjectId.EXPECT().
Refresh(gomock.Any(), gomock.Any())
r.refreshRegistrations(context.Background(), time.Now())
}

func TestRefresh_AllCalledOnError(t *testing.T) {
r, connectedAgents, byAgentId, byProjectId, rep, _ := setupTracker(t)
r, connectedAgents, byAgentId, rep, _ := setupTracker(t)

gomock.InOrder(
connectedAgents.EXPECT().
Expand All @@ -258,84 +216,11 @@ func TestRefresh_AllCalledOnError(t *testing.T) {
rep.EXPECT().
HandleProcessingError(gomock.Any(), gomock.Any(), "Failed to refresh hash data in Redis", matcher.ErrorEq("err1")),
)
gomock.InOrder(
byProjectId.EXPECT().
Refresh(gomock.Any(), gomock.Any()).
Return(errors.New("err2")),
rep.EXPECT().
HandleProcessingError(gomock.Any(), gomock.Any(), "Failed to refresh hash data in Redis", matcher.ErrorEq("err2")),
)
r.refreshRegistrations(context.Background(), time.Now())
}

func TestGetConnectionsByProjectId_HappyPath(t *testing.T) {
r, _, _, byProjectId, _, info := setupTracker(t)
infoBytes, err := proto.Marshal(info)
require.NoError(t, err)
byProjectId.EXPECT().
Scan(gomock.Any(), info.ProjectId, gomock.Any()).
Do(func(ctx context.Context, key int64, cb redistool.ScanCallback) (int, error) {
var done bool
done, err = cb("k2", infoBytes, nil)
if err != nil || done {
return 0, err
}
return 0, nil
})
var cbCalled int
err = r.GetConnectionsByProjectId(context.Background(), info.ProjectId, func(i *ConnectedAgentInfo) (done bool, err error) {
cbCalled++
assert.Empty(t, cmp.Diff(i, info, protocmp.Transform()))
return false, nil
})
require.NoError(t, err)
assert.EqualValues(t, 1, cbCalled)
}

func TestGetConnectionsByProjectId_ScanError(t *testing.T) {
r, _, _, byProjectId, rep, info := setupTracker(t)
gomock.InOrder(
byProjectId.EXPECT().
Scan(gomock.Any(), info.ProjectId, gomock.Any()).
Do(func(ctx context.Context, key int64, cb redistool.ScanCallback) (int, error) {
done, err := cb("", nil, errors.New("intended error"))
require.NoError(t, err)
assert.False(t, done)
return 0, nil
}),
rep.EXPECT().
HandleProcessingError(gomock.Any(), gomock.Any(), "Redis hash scan", matcher.ErrorEq("intended error")),
)
err := r.GetConnectionsByProjectId(context.Background(), info.ProjectId, func(i *ConnectedAgentInfo) (done bool, err error) {
require.FailNow(t, "unexpected call")
return false, nil
})
require.NoError(t, err)
}

func TestGetConnectionsByProjectId_UnmarshalError(t *testing.T) {
r, _, _, byProjectId, rep, info := setupTracker(t)
gomock.InOrder(
byProjectId.EXPECT().
Scan(gomock.Any(), info.ProjectId, gomock.Any()).
Do(func(ctx context.Context, key int64, cb redistool.ScanCallback) (int, error) {
done, err := cb("k2", []byte{1, 2, 3}, nil) // invalid bytes
require.NoError(t, err) // ignores error to keep going
assert.False(t, done)
return 0, nil
}),
rep.EXPECT().
HandleProcessingError(gomock.Any(), gomock.Any(), "Redis proto.Unmarshal(ConnectedAgentInfo)", matcher.ErrorIs(proto.Error)),
)
err := r.GetConnectionsByProjectId(context.Background(), info.ProjectId, func(i *ConnectedAgentInfo) (done bool, err error) {
require.FailNow(t, "unexpected call")
return false, nil
})
require.NoError(t, err)
}

func TestGetConnectionsByAgentId_HappyPath(t *testing.T) {
r, _, byAgentId, _, _, info := setupTracker(t)
r, _, byAgentId, _, info := setupTracker(t)
infoBytes, err := proto.Marshal(info)
require.NoError(t, err)
byAgentId.EXPECT().
Expand All @@ -359,7 +244,7 @@ func TestGetConnectionsByAgentId_HappyPath(t *testing.T) {
}

func TestGetConnectionsByAgentId_ScanError(t *testing.T) {
r, _, byAgentId, _, rep, info := setupTracker(t)
r, _, byAgentId, rep, info := setupTracker(t)
gomock.InOrder(
byAgentId.EXPECT().
Scan(gomock.Any(), info.AgentId, gomock.Any()).
Expand All @@ -380,7 +265,7 @@ func TestGetConnectionsByAgentId_ScanError(t *testing.T) {
}

func TestGetConnectionsByAgentId_UnmarshalError(t *testing.T) {
r, _, byAgentId, _, rep, info := setupTracker(t)
r, _, byAgentId, rep, info := setupTracker(t)
byAgentId.EXPECT().
Scan(gomock.Any(), info.AgentId, gomock.Any()).
Do(func(ctx context.Context, key int64, cb redistool.ScanCallback) (int, error) {
Expand All @@ -399,7 +284,7 @@ func TestGetConnectionsByAgentId_UnmarshalError(t *testing.T) {
}

func TestGetConnectedAgentsCount_HappyPath(t *testing.T) {
r, connectedAgents, _, _, _, _ := setupTracker(t)
r, connectedAgents, _, _, _ := setupTracker(t)
connectedAgents.EXPECT().
Len(gomock.Any(), connectedAgentsKey).
Return(int64(1), nil)
Expand All @@ -409,7 +294,7 @@ func TestGetConnectedAgentsCount_HappyPath(t *testing.T) {
}

func TestGetConnectedAgentsCount_LenError(t *testing.T) {
r, connectedAgents, _, _, _, _ := setupTracker(t)
r, connectedAgents, _, _, _ := setupTracker(t)
connectedAgents.EXPECT().
Len(gomock.Any(), connectedAgentsKey).
Return(int64(0), errors.New("intended error"))
Expand All @@ -418,22 +303,20 @@ func TestGetConnectedAgentsCount_LenError(t *testing.T) {
assert.Zero(t, size)
}

func setupTracker(t *testing.T) (*RedisTracker, *mock_redis.MockExpiringHash[int64, int64], *mock_redis.MockExpiringHash[int64, int64], *mock_redis.MockExpiringHash[int64, int64], *mock_tool.MockErrReporter, *ConnectedAgentInfo) {
func setupTracker(t *testing.T) (*RedisTracker, *mock_redis.MockExpiringHash[int64, int64], *mock_redis.MockExpiringHash[int64, int64], *mock_tool.MockErrReporter, *ConnectedAgentInfo) {
ctrl := gomock.NewController(t)
rep := mock_tool.NewMockErrReporter(ctrl)
connectedAgents := mock_redis.NewMockExpiringHash[int64, int64](ctrl)
byAgentId := mock_redis.NewMockExpiringHash[int64, int64](ctrl)
byProjectId := mock_redis.NewMockExpiringHash[int64, int64](ctrl)
tr := &RedisTracker{
log: zaptest.NewLogger(t),
errRep: rep,
refreshPeriod: time.Minute,
gcPeriod: time.Minute,
connectionsByAgentId: byAgentId,
connectionsByClusterId: byProjectId,
connectedAgents: connectedAgents,
log: zaptest.NewLogger(t),
errRep: rep,
refreshPeriod: time.Minute,
gcPeriod: time.Minute,
connectionsByAgentId: byAgentId,
connectedAgents: connectedAgents,
}
return tr, connectedAgents, byAgentId, byProjectId, rep, connInfo()
return tr, connectedAgents, byAgentId, rep, connInfo()
}

func connInfo() *ConnectedAgentInfo {
Expand All @@ -447,6 +330,6 @@ func connInfo() *ConnectedAgentInfo {
ConnectedAt: timestamppb.Now(),
ConnectionId: 123,
AgentId: 345,
ProjectId: 456,
ClusterId: "456",
}
}
6 changes: 0 additions & 6 deletions pkg/module/kubernetes_api/agent/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,6 @@ func TestClientImpersonation(t *testing.T) {
Username: "iuser1",
Groups: []string{"ig1", "ig2"},
Uid: "iuid",
Extra: []*rpc.ExtraKeyVal{
{
Key: "ix",
Val: []string{"ix1", "ix2"},
},
},
}
requestHeader := http.Header{}
requestHeader.Set(transport.ImpersonateUserHeader, "huser1")
Expand Down
15 changes: 5 additions & 10 deletions pkg/module/kubernetes_api/server/factory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,19 @@ func Test_GetAuthorizedProxyUserCacheKeyFunc_AllFieldsUsed(t *testing.T) {

redisKeys := map[string]struct{}{}
redisKeys[keyFunc(proxyUserCacheKey{
agentId: 1,
accessType: "any",
accessKey: "any",
csrfToken: "any",
agentId: 1,
accessKey: "any",
clusterId: "any",
})] = struct{}{}
redisKeys[keyFunc(proxyUserCacheKey{
accessType: "any",
accessKey: "any",
accessKey: "any",
})] = struct{}{}
redisKeys[keyFunc(proxyUserCacheKey{
agentId: 1,
accessKey: "any",
csrfToken: "any",
})] = struct{}{}
redisKeys[keyFunc(proxyUserCacheKey{
agentId: 1,
accessType: "any",
csrfToken: "any",
agentId: 1,
})] = struct{}{}

assert.Equal(t, 4, len(redisKeys))
Expand Down
Loading

0 comments on commit 0801896

Please sign in to comment.