Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor:优化鉴权相关功能开关以及修复NPE问题 #1155

Merged
merged 12 commits into from
Jun 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion apiserver/eurekaserver/access.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ func (h *EurekaServer) UpdateStatus(req *restful.Request, rsp *restful.Response)
}
code := h.updateStatus(context.Background(), namespace, appId, instId, status, false)
writePolarisStatusCode(req, code)
if code == api.ExecuteSuccess {
if code == api.ExecuteSuccess || code == api.NoNeedUpdate {
log.Infof("[EUREKA-SERVER]instance (namespace=%s, instId=%s, appId=%s) has been updated successfully",
namespace, instId, appId)
writeHeader(http.StatusOK, rsp)
Expand Down
1 change: 0 additions & 1 deletion apiserver/grpcserver/config/client_access.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,5 @@ func (g *ConfigGRPCServer) WatchConfigFiles(ctx context.Context,
if err != nil {
return nil, err
}

return callback(), nil
}
13 changes: 2 additions & 11 deletions apiserver/httpserver/config_client_access.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,7 @@ func (h *HTTPServer) ClientGetConfigFile(req *restful.Request, rsp *restful.Resp
Response: rsp,
}

version, err := strconv.ParseUint(handler.Request.QueryParameter("version"), 10, 64)
if err != nil {
handler.WriteHeaderAndProto(api.NewConfigClientResponseWithMessage(
apimodel.Code_BadRequest, "version must be number"))
}

version, _ := strconv.ParseUint(handler.Request.QueryParameter("version"), 10, 64)
configFile := &apiconfig.ClientConfigFileInfo{
Namespace: &wrapperspb.StringValue{Value: handler.Request.QueryParameter("namespace")},
Group: &wrapperspb.StringValue{Value: handler.Request.QueryParameter("group")},
Expand All @@ -49,7 +44,6 @@ func (h *HTTPServer) ClientGetConfigFile(req *restful.Request, rsp *restful.Resp
}

response := h.configServer.GetConfigFileForClient(handler.ParseHeaderContext(), configFile)

handler.WriteHeaderAndProto(response)
}

Expand All @@ -61,9 +55,7 @@ func (h *HTTPServer) ClientWatchConfigFile(req *restful.Request, rsp *restful.Re

// 1. 解析出客户端监听的配置文件列表
watchConfigFileRequest := &apiconfig.ClientWatchConfigFileRequest{}

_, err := handler.Parse(watchConfigFileRequest)
if err != nil {
if _, err := handler.Parse(watchConfigFileRequest); err != nil {
handler.WriteHeaderAndProto(api.NewResponseWithMsg(apimodel.Code_ParseException, err.Error()))
return
}
Expand All @@ -74,6 +66,5 @@ func (h *HTTPServer) ClientWatchConfigFile(req *restful.Request, rsp *restful.Re
handler.WriteHeaderAndProto(api.NewResponseWithMsg(apimodel.Code_ExecuteException, err.Error()))
return
}

handler.WriteHeaderAndProto(callback())
}
12 changes: 5 additions & 7 deletions apiserver/xdsserverv3/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,16 +282,14 @@ func (x *XDSServer) getRegistryInfoWithCache(ctx context.Context, registryInfo m
}

// 获取routing配置
routeResp := x.namingServer.GetRoutingConfigWithCache(ctx, s)
if routeResp.GetCode().Value != api.ExecuteSuccess {
log.Errorf("error sync routing for %s, info : %s", svc.Name, routeResp.Info.GetValue())
routerRule, err := x.namingServer.Cache().RoutingConfig().GetRouterConfigV2("", svc.Name, svc.Namespace)
if err != nil {
log.Errorf("error sync routing for %s, info : %s", svc.Name, err.Error())
return fmt.Errorf("[XDSV3] error sync routing for %s", svc.Name)
}

if routeResp.Routing != nil {
svc.SvcRoutingRevision = routeResp.Routing.Revision.Value
svc.Routing = routeResp.Routing
}
svc.SvcRoutingRevision = routerRule.GetRevision().GetValue()
svc.Routing = routerRule

// 获取instance配置
resp := x.namingServer.ServiceInstancesCache(ctx, s)
Expand Down
10 changes: 8 additions & 2 deletions auth/defaultauth/auth_checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ func (d *defaultAuthChecker) Initialize(options *auth.Config, s store.Store, cac
if err := cfg.Verify(); err != nil {
return err
}

// 兼容原本老的配置逻辑
if cfg.Strict {
cfg.ConsoleOpen = cfg.Strict
}
AuthOption = cfg
d.cacheMgn = cacheMgn
return nil
Expand Down Expand Up @@ -208,7 +211,10 @@ func canDowngradeAnonymous(authCtx *model.AcquireContext, err error) bool {
if authCtx.GetModule() == model.AuthModule {
return false
}
if AuthOption.Strict {
if authCtx.IsFromClient() && AuthOption.ClientStrict {
return false
}
if authCtx.IsFromConsole() && AuthOption.ConsoleStrict {
return false
}
if errors.Is(err, model.ErrorTokenInvalid) {
Expand Down
19 changes: 11 additions & 8 deletions auth/defaultauth/auth_checker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1126,10 +1126,12 @@ func Test_defaultAuthChecker_Initialize(t *testing.T) {
err := authChecker.Initialize(cfg, storage, cacheMgn)
assert.NoError(t, err)
assert.Equal(t, &AuthConfig{
ConsoleOpen: true,
ClientOpen: true,
Salt: "polarismesh@2021",
Strict: false,
ConsoleOpen: true,
ClientOpen: true,
Salt: "polarismesh@2021",
Strict: false,
ConsoleStrict: true,
ClientStrict: false,
}, AuthOption)
})

Expand All @@ -1155,10 +1157,11 @@ func Test_defaultAuthChecker_Initialize(t *testing.T) {
err := authChecker.Initialize(cfg, storage, cacheMgn)
assert.NoError(t, err)
assert.Equal(t, &AuthConfig{
ConsoleOpen: true,
ClientOpen: true,
Salt: "polarismesh@2021",
Strict: false,
ConsoleOpen: true,
ClientOpen: true,
Salt: "polarismesh@2021",
Strict: false,
ConsoleStrict: true,
}, AuthOption)
})

Expand Down
2 changes: 2 additions & 0 deletions auth/defaultauth/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ func reset(strict bool) {
AuthOption.ClientOpen = true
AuthOption.ConsoleOpen = true
AuthOption.Strict = strict
AuthOption.ConsoleStrict = strict
AuthOption.ClientStrict = strict
}

func initCache(ctrl *gomock.Controller) (*cache.Config, *storemock.MockStore) {
Expand Down
14 changes: 11 additions & 3 deletions auth/defaultauth/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@ type AuthConfig struct {
// Salt 相关密码、token加密的salt
Salt string `json:"salt" xml:"salt"`
// Strict 是否启用鉴权的严格模式,即对于没有任何鉴权策略的资源,也必须带上正确的token才能操作, 默认关闭
// Deprecated
Strict bool `json:"strict"`
// ConsoleStrict 是否启用鉴权的严格模式,即对于没有任何鉴权策略的资源,也必须带上正确的token才能操作, 默认关闭
ConsoleStrict bool `json:"consoleStrict"`
// ClientStrict 是否启用鉴权的严格模式,即对于没有任何鉴权策略的资源,也必须带上正确的token才能操作, 默认关闭
ClientStrict bool `json:"clientStrict"`
}

// Verify 检查配置是否合法
Expand All @@ -54,8 +59,11 @@ func DefaultAuthConfig() *AuthConfig {
ConsoleOpen: true,
// 针对客户端接口,默认不开启鉴权操作
ClientOpen: false,
Salt: "polarismesh@2021",
// 这里默认开启强 Token 检查模式
Strict: true,
// Salt token 加密 key
Salt: "polarismesh@2021",
// 这里默认开启 OpenAPI 的强 Token 检查模式
ConsoleStrict: true,
// 客户端接口默认不开启 token 强检查模式
ClientStrict: false,
}
}
19 changes: 10 additions & 9 deletions bootstrap/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,16 +252,17 @@ func StartDiscoverComponents(ctx context.Context, cfg *boot_config.Config, s sto
if err != nil {
return err
}
cacheProvider, err := healthCheckServer.CacheProvider()
if err != nil {
return err
if cfg.HealthChecks.Open {
cacheProvider, err := healthCheckServer.CacheProvider()
if err != nil {
return err
}
healthCheckServer.SetServiceCache(cacheMgn.Service())
healthCheckServer.SetInstanceCache(cacheMgn.Instance())
// 为 instance 的 cache 添加 健康检查的 Listener
cacheMgn.AddListener(cache.CacheNameInstance, []cache.Listener{cacheProvider})
cacheMgn.AddListener(cache.CacheNameClient, []cache.Listener{cacheProvider})
}
healthCheckServer.SetServiceCache(cacheMgn.Service())
healthCheckServer.SetInstanceCache(cacheMgn.Instance())

// 为 instance 的 cache 添加 健康检查的 Listener
cacheMgn.AddListener(cache.CacheNameInstance, []cache.Listener{cacheProvider})
cacheMgn.AddListener(cache.CacheNameClient, []cache.Listener{cacheProvider})

namespaceSvr, err := namespace.GetServer()
if err != nil {
Expand Down
6 changes: 5 additions & 1 deletion cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,11 @@ func (bc *baseCache) doCacheUpdate(name string, executor func() (map[string]time
log.Warnf("[Cache][%s] get store timestamp fail, skip update lastMtime, err : %v", name, err)
}
defer func() {
bc.lastFetchTime = curStoreTime
if err := recover(); err != nil {
log.Errorf("[Cache][%s] run cache update panic: %+v", name, err)
} else {
bc.lastFetchTime = curStoreTime
}
}()

start := time.Now()
Expand Down
16 changes: 14 additions & 2 deletions cache/routing_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ type (
Cache
// GetRouterConfig Obtain routing configuration based on serviceid
GetRouterConfig(id, service, namespace string) (*apitraffic.Routing, error)
// GetRouterConfig Obtain routing configuration based on serviceid
GetRouterConfigV2(id, service, namespace string) (*apitraffic.Routing, error)
// GetRoutingConfigCount Get the total number of routing configuration cache
GetRoutingConfigCount() int
// QueryRoutingConfigsV2 Query Route Configuration List
Expand Down Expand Up @@ -147,8 +149,8 @@ func (rc *routingConfigCache) name() string {
return RoutingConfigName
}

// GetRouterConfig Obtain routing configuration based on serviceid
func (rc *routingConfigCache) GetRouterConfig(id, service, namespace string) (*apitraffic.Routing, error) {
// GetRouterConfigV2 Obtain routing configuration based on serviceid
func (rc *routingConfigCache) GetRouterConfigV2(id, service, namespace string) (*apitraffic.Routing, error) {
if id == "" && service == "" && namespace == "" {
return nil, nil
}
Expand Down Expand Up @@ -186,6 +188,16 @@ func (rc *routingConfigCache) GetRouterConfig(id, service, namespace string) (*a
return formatRoutingResponseV1(resp), nil
}

// GetRouterConfig Obtain routing configuration based on serviceid
func (rc *routingConfigCache) GetRouterConfig(id, service, namespace string) (*apitraffic.Routing, error) {
ret, err := rc.GetRouterConfigV2(id, service, namespace)
if err != nil {
return nil, err
}
ret.Rules = nil
return ret, nil
}

// formatRoutingResponseV1 Give the client's cache, no need to expose EXTENDINFO information data
func formatRoutingResponseV1(ret *apitraffic.Routing) *apitraffic.Routing {
inBounds := ret.Inbounds
Expand Down
29 changes: 15 additions & 14 deletions config/client_config_file.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,15 @@ func (s *Server) GetConfigFileForClient(ctx context.Context,
apimodel.Code_ExecuteException, "load config file error")
}
}

log.Info("[Config][Client] client get config file success.", utils.ZapRequestIDByCtx(ctx),
zap.String("client", utils.ParseClientAddress(ctx)), zap.String("file", fileName),
zap.Uint64("version", entry.Version))

configFile, err := transferEntry2APIModel(client, entry)
if err != nil {
log.Error("[Config][Service] transfer entry to api model error.", utils.ZapRequestIDByCtx(ctx), zap.Error(err))
return api.NewConfigClientResponseWithMessage(
apimodel.Code_ExecuteException, "transfer entry to api model error")
}
log.Info("[Config][Client] client get config file success.", utils.ZapRequestIDByCtx(ctx),
zap.String("client", utils.ParseClientAddress(ctx)), zap.String("file", fileName),
zap.Uint64("version", entry.Version))
return api.NewConfigClientResponse(apimodel.Code_ExecuteSuccess, configFile)
}

Expand Down Expand Up @@ -116,17 +114,16 @@ func (s *Server) WatchConfigFiles(ctx context.Context,
clientAddr := utils.ParseClientAddress(ctx)
watchFiles := request.GetWatchFiles()
// 2. 检查客户端是否有版本落后
if resp := s.DoCheckClientConfigFile(ctx, watchFiles, compareByVersion); resp.Code.GetValue() != api.DataNoChange {
resp := s.DoCheckClientConfigFile(ctx, watchFiles, compareByVersion)
if resp.Code.GetValue() != api.DataNoChange {
return func() *apiconfig.ConfigClientResponse {
return resp
}, nil
}

// 3. 监听配置变更,hold 请求 30s,30s 内如果有配置发布,则响应请求
clientId := clientAddr + "@" + utils.NewUUID()[0:8]

finishChan := s.ConnManager().AddConn(clientId, watchFiles)

return func() *apiconfig.ConfigClientResponse {
return <-finishChan
}, nil
Expand Down Expand Up @@ -201,27 +198,31 @@ func transferEntry2APIModel(client *apiconfig.ClientConfigFileInfo,

dataKey := entry.GetDataKey()
encryptAlgo := entry.GetEncryptAlgo()
if dataKey != "" && publicKey != "" {
if !(dataKey != "") {
dataKeyBytes, err := base64.StdEncoding.DecodeString(dataKey)
if err != nil {
log.Error("[Config][Service] base64 decode data key error.", zap.String("dataKey", dataKey), zap.Error(err))
return nil, err
}
cipherDataKey, err := rsa.EncryptToBase64(dataKeyBytes, publicKey)
if err != nil {
log.Error("[Config][Service] rsa encrypt data key error.", zap.String("dataKey", dataKey), zap.Error(err))
if publicKey != "" {
cipherDataKey, err := rsa.EncryptToBase64(dataKeyBytes, publicKey)
if err != nil {
log.Error("[Config][Service] rsa encrypt data key error.",
zap.String("dataKey", dataKey), zap.Error(err))
} else {
dataKey = cipherDataKey
}
}
configFile.Tags = append(configFile.Tags,
&apiconfig.ConfigFileTag{
Key: utils.NewStringValue(utils.ConfigFileTagKeyDataKey),
Value: utils.NewStringValue(cipherDataKey),
Value: utils.NewStringValue(dataKey),
},
&apiconfig.ConfigFileTag{
Key: utils.NewStringValue(utils.ConfigFileTagKeyEncryptAlgo),
Value: utils.NewStringValue(encryptAlgo),
},
)
return configFile, nil
}
return configFile, nil
}
4 changes: 0 additions & 4 deletions config/config_file.go
Original file line number Diff line number Diff line change
Expand Up @@ -845,10 +845,6 @@ func (s *Server) decryptConfigFile(ctx context.Context, configFile *apiconfig.Co
}
configFile.Tags = filterTags
}
// 非创建人请求不解密
if utils.ParseUserName(ctx) != configFile.CreateBy.GetValue() {
return nil
}
// 非加密文件不解密
if dataKey == "" {
return nil
Expand Down
4 changes: 2 additions & 2 deletions config/config_file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,7 @@ func Test_decryptConfigFile(t *testing.T) {
wantErr: nil,
},
{
name: "non creator don't decrypt config file",
name: "non creator can decrypt config file",
args: args{
ctx: context.WithValue(context.Background(), utils.ContextUserNameKey, "test"),
configFile: &apiconfig.ConfigFile{
Expand All @@ -654,7 +654,7 @@ func Test_decryptConfigFile(t *testing.T) {
CreateBy: utils.NewStringValue("polaris"),
},
},
want: "YnLZ0SYuujFBHjYHAZVN5A==",
want: "polaris",
wantErr: nil,
},
}
Expand Down
6 changes: 2 additions & 4 deletions config/connection_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,7 @@ func (c *connManager) removeConn(clientId string) {
return
}
connObj := conn.(*connection)

c.watchCenter.RemoveWatcher(clientId, connObj.watchConfigFiles)

cm.conns.Delete(clientId)
}

Expand All @@ -112,11 +110,11 @@ func (c *connManager) startHandleTimeoutRequestWorker(ctx context.Context) {
if cm.conns == nil {
continue
}
tNow := time.Now()
cm.conns.Range(func(client, conn interface{}) bool {
connCtx := conn.(*connection)
if time.Now().After(connCtx.finishTime) {
if tNow.After(connCtx.finishTime) {
connCtx.finishChan <- notModifiedResponse

c.removeConn(client.(string))
}
return true
Expand Down