Skip to content

Commit

Permalink
add operation level interceptor support
Browse files Browse the repository at this point in the history
  • Loading branch information
AmaliMatharaarachchi committed Jan 4, 2022
1 parent 3ee9e48 commit 3bd3f79
Show file tree
Hide file tree
Showing 7 changed files with 185 additions and 64 deletions.
69 changes: 48 additions & 21 deletions adapter/internal/interceptor/interceptor.go
Expand Up @@ -19,22 +19,29 @@ package interceptor

import (
"bytes"
logger "github.com/wso2/product-microgateway/adapter/internal/loggers"
"text/template"

logger "github.com/wso2/product-microgateway/adapter/internal/loggers"
)

//Interceptor hold values used for interceptor
type Interceptor struct {
Context *InvocationContext
RequestExternalCall *HTTPCallConfig
ResponseExternalCall *HTTPCallConfig
ReqFlowInclude *RequestInclusions
RespFlowInclude *RequestInclusions
Context *InvocationContext
RequestFlowEnable bool
ResponseFlowEnable bool
RequestFlow map[string]Config // key:operation method -> value:config
ResponseFlow map[string]Config // key:operation method -> value:config
}

//HTTPCallConfig hold values used for external interceptor engine
type Config struct {
Enable bool
ExternalCall *HTTPCallConfig
Include *RequestInclusions
}

//HTTPCallConfig hold values used for external interceptor engine
type HTTPCallConfig struct {
Enable bool
ClusterName string
Timeout string // in milli seconds
}
Expand Down Expand Up @@ -69,11 +76,14 @@ var (
// Note: this template only applies if request or response interceptor is enabled
commonTemplate = `
local interceptor = require 'home.wso2.interceptor.lib.interceptor'
{{if .ResponseExternalCall.Enable}} {{/* resp_flow details are required in req flow if request info needed in resp flow */}}
local resp_flow = {invocationContext={{.RespFlowInclude.InvocationContext}}, requestHeaders={{.RespFlowInclude.RequestHeaders}}, requestBody={{.RespFlowInclude.RequestBody}}, requestTrailer={{.RespFlowInclude.RequestTrailer}},
responseHeaders={{.RespFlowInclude.ResponseHeaders}}, responseBody={{.RespFlowInclude.ResponseBody}}, responseTrailers={{.RespFlowInclude.ResponseTrailers}}}
{{else}}local resp_flow = {}{{end}} {{/* if resp_flow disabled no need req info in resp path */}}
{{if or .ReqFlowInclude.InvocationContext .RespFlowInclude.InvocationContext}}
{{if .ResponseFlowEnable}} {{/* resp_flow details are required in req flow if request info needed in resp flow */}}
local resp_flow_list = {
{{ range $key, $value := .ResponseFlow }}
{{ $key }} = {invocationContext={{$value.Include.InvocationContext}}, requestHeaders={{$value.Include.RequestHeaders}}, requestBody={{$value.Include.RequestBody}}, requestTrailer={{$value.Include.RequestTrailer}},
responseHeaders={{$value.Include.ResponseHeaders}}, responseBody={{$value.Include.ResponseBody}}, responseTrailers={{$value.Include.ResponseTrailers}}}
{{ end }}
}
{{else}}local resp_flow_list = {}{{end}} {{/* if resp_flow disabled no need req info in resp path */}}
local inv_context = {
organizationId = "{{.Context.OrganizationID}}",
basePath = "{{.Context.BasePath}}",
Expand All @@ -85,32 +95,49 @@ local inv_context = {
prodClusterName = "{{.Context.ProdClusterName}}",
sandClusterName = "{{.Context.SandClusterName}}"
}
{{else}}local inv_context = nil{{end}}
`
requestInterceptorTemplate = `
local req_flow = {invocationContext={{.ReqFlowInclude.InvocationContext}}, requestHeaders={{.ReqFlowInclude.RequestHeaders}}, requestBody={{.ReqFlowInclude.RequestBody}}, requestTrailer={{.ReqFlowInclude.RequestTrailer}}}
local req_flow_list = {
{{ range $key, $value := .RequestFlow }}
{{ $key }}= {invocationContext={{$value.Include.InvocationContext}}, requestHeaders={{$value.Include.RequestHeaders}}, requestBody={{$value.Include.RequestBody}}, requestTrailer={{$value.Include.RequestTrailer}}}
{{ end }}
}
local req_call_config = {
{{ range $key, $value := .RequestFlow }}
{{ $key }}={ClusterName={{$value.ExternalCall.ClusterName}}, Timeout={{$value.ExternalCall.Timeout}}
{{ end }}
}
function envoy_on_request(request_handle)
method=request_handle:headers():get(":method")
interceptor.handle_request_interceptor(
request_handle,
{cluster_name="{{.RequestExternalCall.ClusterName}}", timeout={{.RequestExternalCall.Timeout}}},
req_flow, resp_flow, inv_context
{cluster_name=req_call_config[method].ClusterName, timeout=req_call_config[method].Timeout},
req_flow_list[method], resp_flow_list[method], inv_context
)
end
`
//get method in response flow
responseInterceptorTemplate = `
local res_call_config = {
{{ range $key, $value := .ResponseFlow }}
{{ $key }}= {ClusterName={{$value.ExternalCall.ClusterName}}, Timeout={{$value.ExternalCall.Timeout}}
{{ end }}
}
function envoy_on_response(response_handle)
interceptor.handle_response_interceptor(
method=request_handle:headers():get(":method")
response_handle,
{cluster_name="{{.ResponseExternalCall.ClusterName}}", timeout={{.ResponseExternalCall.Timeout}}},
resp_flow
{cluster_name=req_call_config[method].ClusterName, timeout=req_call_config[method].Timeout},
resp_flow_list[method]
)
end
`
// defaultRequestInterceptorTemplate is the template that is applied when request flow is disabled
// just updated req flow info with resp flow without calling interceptor service
defaultRequestInterceptorTemplate = `
function envoy_on_request(request_handle)
interceptor.handle_request_interceptor(request_handle, {}, {}, resp_flow, inv_context, true)
method=request_handle:headers():get(":method")
interceptor.handle_request_interceptor(request_handle, {}, {}, resp_flow_list[method], inv_context, true)
end
`
// defaultResponseInterceptorTemplate is the template that is applied when response flow is disabled
Expand All @@ -123,8 +150,8 @@ end
//GetInterceptor inject values and get request interceptor
// Note: This method is called only if one of request or response interceptor is enabled
func GetInterceptor(values *Interceptor) string {
templ := template.Must(template.New("lua-filter").Parse(getTemplate(values.RequestExternalCall.Enable,
values.ResponseExternalCall.Enable)))
templ := template.Must(template.New("lua-filter").Parse(getTemplate(values.RequestFlowEnable,
values.ResponseFlowEnable)))
var out bytes.Buffer
err := templ.Execute(&out, values)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions adapter/internal/oasparser/envoyconf/internal_dtos.go
Expand Up @@ -37,7 +37,7 @@ type routeCreateParams struct {
prodRouteConfig *model.EndpointConfig
sandRouteConfig *model.EndpointConfig
AuthHeader string
requestInterceptor model.InterceptEndpoint
responseInterceptor model.InterceptEndpoint
requestInterceptor map[string]model.InterceptEndpoint
responseInterceptor map[string]model.InterceptEndpoint
corsPolicy *model.CorsConfig
}
115 changes: 82 additions & 33 deletions adapter/internal/oasparser/envoyconf/routes_with_clusters.go
Expand Up @@ -151,7 +151,7 @@ func CreateRoutesWithClusters(mgwSwagger model.MgwSwagger, upstreamCerts map[str
}

var interceptorErr error
apiRequestInterceptor, interceptorErr = mgwSwagger.GetInterceptor(mgwSwagger.GetVendorExtensions(), xWso2requestInterceptor)
apiRequestInterceptor, interceptorErr = mgwSwagger.GetInterceptor(mgwSwagger.GetVendorExtensions(), xWso2requestInterceptor, "api")
// if lua filter exists on api level, add cluster
if interceptorErr == nil && apiRequestInterceptor.Enable {
logger.LoggerOasparser.Debugf("API level request interceptors found for %v : %v", apiTitle, apiVersion)
Expand All @@ -167,7 +167,7 @@ func CreateRoutesWithClusters(mgwSwagger model.MgwSwagger, upstreamCerts map[str
endpoints = append(endpoints, addresses...)
}
}
apiResponseInterceptor, interceptorErr = mgwSwagger.GetInterceptor(mgwSwagger.GetVendorExtensions(), xWso2responseInterceptor)
apiResponseInterceptor, interceptorErr = mgwSwagger.GetInterceptor(mgwSwagger.GetVendorExtensions(), xWso2responseInterceptor, "api")
// if lua filter exists on api level, add cluster
if interceptorErr == nil && apiResponseInterceptor.Enable {
logger.LoggerOasparser.Debugln("API level response interceptors found for " + mgwSwagger.GetID())
Expand Down Expand Up @@ -269,7 +269,8 @@ func CreateRoutesWithClusters(mgwSwagger model.MgwSwagger, upstreamCerts map[str
clusterNameSand = ""
}

reqInterceptorVal, err := mgwSwagger.GetInterceptor(resource.GetVendorExtensions(), xWso2requestInterceptor)
// create resource level request interceptor cluster
reqInterceptorVal, err := mgwSwagger.GetInterceptor(resource.GetVendorExtensions(), xWso2requestInterceptor, "resource")
if err == nil && reqInterceptorVal.Enable {
logger.LoggerOasparser.Debugf("Resource level request interceptors found for %v:%v-%v", apiTitle, apiVersion, resource.GetPath())
reqInterceptorVal.ClusterName = getClusterName(requestInterceptClustersNamePrefix, organizationID, vHost,
Expand All @@ -284,7 +285,30 @@ func CreateRoutesWithClusters(mgwSwagger model.MgwSwagger, upstreamCerts map[str
endpoints = append(endpoints, addresses...)
}
}
respInterceptorVal, err := mgwSwagger.GetInterceptor(resource.GetVendorExtensions(), xWso2responseInterceptor)

// create operational level response interceptor clusters
operationalReqInterceptors := mgwSwagger.GetOperationInterceptors(apiRequestInterceptor, resourceRequestInterceptor, resource.GetMethod(),
xWso2requestInterceptor)
for method, opI := range operationalReqInterceptors {
if opI.Enable && opI.Level == "operation" {
logger.LoggerOasparser.Debugf("Operation level request interceptors found for %v:%v-%v-%v", apiTitle, apiVersion, resource.GetPath(),
opI.ClusterName)
opI.ClusterName = getClusterName(requestInterceptClustersNamePrefix, organizationID, vHost, apiTitle, apiVersion, opI.ClusterName)
cluster, addresses, err := CreateLuaCluster(interceptorCerts, opI)
if err != nil {
logger.LoggerOasparser.Errorf("Error while adding operational level request intercept external cluster for %s. %v",
apiTitle, err.Error())
// setting resource level interceptor to failed operation level interceptor.
operationalReqInterceptors[method] = resourceRequestInterceptor
} else {
clusters = append(clusters, cluster)
endpoints = append(endpoints, addresses...)
}
}
}

// create resource level response interceptor cluster
respInterceptorVal, err := mgwSwagger.GetInterceptor(resource.GetVendorExtensions(), xWso2responseInterceptor, "resource")
if err == nil && respInterceptorVal.Enable {
logger.LoggerOasparser.Debugf("Resource level response interceptors found for %v:%v-%v"+apiTitle, apiVersion, resource.GetPath())
respInterceptorVal.ClusterName = getClusterName(responseInterceptClustersNamePrefix, organizationID,
Expand All @@ -300,13 +324,34 @@ func CreateRoutesWithClusters(mgwSwagger model.MgwSwagger, upstreamCerts map[str
}
}

// create operation level response interceptor clusters
operationalRespInterceptorVal := mgwSwagger.GetOperationInterceptors(apiResponseInterceptor, resourceResponseInterceptor, resource.GetMethod(),
xWso2responseInterceptor)
for method, opI := range operationalRespInterceptorVal {
if opI.Enable && opI.Level == "operation" {
logger.LoggerOasparser.Debugf("Operational level response interceptors found for %v:%v-%v-%v", apiTitle, apiVersion, resource.GetPath(),
opI.ClusterName)
opI.ClusterName = getClusterName(responseInterceptClustersNamePrefix, organizationID, vHost, apiTitle, apiVersion, opI.ClusterName)
cluster, addresses, err := CreateLuaCluster(interceptorCerts, opI)
if err != nil {
logger.LoggerOasparser.Errorf("Error while adding operational level response intercept external cluster for %s. %v",
apiTitle, err.Error())
// setting resource level interceptor to failed operation level interceptor.
operationalRespInterceptorVal[method] = resourceResponseInterceptor
} else {
clusters = append(clusters, cluster)
endpoints = append(endpoints, addresses...)
}
}
}

routeP := createRoute(genRouteCreateParams(&mgwSwagger, resource, vHost, resourceBasePath, clusterNameProd,
clusterNameSand, resourceRequestInterceptor, resourceResponseInterceptor, organizationID))
clusterNameSand, operationalReqInterceptors, operationalRespInterceptorVal, organizationID))
routes = append(routes, routeP)
}
if mgwSwagger.GetAPIType() == model.WS {
routesP := createRoute(genRouteCreateParams(&mgwSwagger, nil, vHost, apiLevelbasePath, apiLevelClusterNameProd,
apiLevelClusterNameSand, apiRequestInterceptor, apiResponseInterceptor, organizationID))
apiLevelClusterNameSand, nil, nil, organizationID))
routes = append(routes, routesP)
}
return routes, clusters, endpoints
Expand Down Expand Up @@ -723,7 +768,7 @@ func createRoute(params *routeCreateParams) *routev3.Route {
}

var luaPerFilterConfig lua.LuaPerRoute
if !requestInterceptor.Enable && !responseInterceptor.Enable {
if len(requestInterceptor) < 1 && len(responseInterceptor) < 1 {
luaPerFilterConfig = lua.LuaPerRoute{
Override: &lua.LuaPerRoute_Disabled{Disabled: true},
}
Expand Down Expand Up @@ -841,35 +886,39 @@ func createRoute(params *routeCreateParams) *routev3.Route {
return &router
}

func getInlineLuaScript(requestInterceptor model.InterceptEndpoint, responseInterceptor model.InterceptEndpoint,
func getInlineLuaScript(requestInterceptor map[string]model.InterceptEndpoint, responseInterceptor map[string]model.InterceptEndpoint,
requestContext *interceptor.InvocationContext) string {

i := &interceptor.Interceptor{
Context: requestContext,
RequestExternalCall: &interceptor.HTTPCallConfig{}, // assign default values ("false" if req flow disabled)
ResponseExternalCall: &interceptor.HTTPCallConfig{},
ReqFlowInclude: &interceptor.RequestInclusions{}, // assign default values ("false" if headers not included req details)
RespFlowInclude: &interceptor.RequestInclusions{},
}
if requestInterceptor.Enable {
i.RequestExternalCall = &interceptor.HTTPCallConfig{
Enable: true,
ClusterName: requestInterceptor.ClusterName,
// multiplying in seconds here because in configs we are directly getting config to time.Duration
// which is in nano seconds, so multiplying it in seconds here
Timeout: strconv.FormatInt((requestInterceptor.RequestTimeout * time.Second).Milliseconds(), 10),
Context: requestContext,
}
if len(requestInterceptor) > 0 {
i.RequestFlowEnable = true
for method, op := range requestInterceptor {
i.RequestFlow[method] = interceptor.Config{
ExternalCall: &interceptor.HTTPCallConfig{
ClusterName: op.ClusterName,
// multiplying in seconds here because in configs we are directly getting config to time.Duration
// which is in nano seconds, so multiplying it in seconds here
Timeout: strconv.FormatInt((op.RequestTimeout * time.Second).Milliseconds(), 10),
},
Include: op.Includes,
}
}
i.ReqFlowInclude = requestInterceptor.Includes
}
if responseInterceptor.Enable {
i.ResponseExternalCall = &interceptor.HTTPCallConfig{
Enable: true,
ClusterName: responseInterceptor.ClusterName,
// multiplying in seconds here because in configs we are directly getting config to time.Duration
// which is in nano seconds, so multiplying it in seconds here
Timeout: strconv.FormatInt((requestInterceptor.RequestTimeout * time.Second).Milliseconds(), 10),
}
if len(responseInterceptor) > 0 {
i.ResponseFlowEnable = true
for method, op := range requestInterceptor {
i.RequestFlow[method] = interceptor.Config{
ExternalCall: &interceptor.HTTPCallConfig{
ClusterName: op.ClusterName,
// multiplying in seconds here because in configs we are directly getting config to time.Duration
// which is in nano seconds, so multiplying it in seconds here
Timeout: strconv.FormatInt((op.RequestTimeout * time.Second).Milliseconds(), 10),
},
Include: op.Includes,
}
}
i.RespFlowInclude = responseInterceptor.Includes
}
return interceptor.GetInterceptor(i)
}
Expand Down Expand Up @@ -1170,8 +1219,8 @@ func getCorsPolicy(corsConfig *model.CorsConfig) *routev3.CorsPolicy {
}

func genRouteCreateParams(swagger *model.MgwSwagger, resource *model.Resource, vHost, endpointBasePath string,
prodClusterName string, sandClusterName string, requestInterceptor model.InterceptEndpoint,
responseInterceptor model.InterceptEndpoint, organizationID string) *routeCreateParams {
prodClusterName string, sandClusterName string, requestInterceptor map[string]model.InterceptEndpoint,
responseInterceptor map[string]model.InterceptEndpoint, organizationID string) *routeCreateParams {
params := &routeCreateParams{
organizationID: organizationID,
title: swagger.GetTitle(),
Expand Down
27 changes: 22 additions & 5 deletions adapter/internal/oasparser/model/api_operation.go
Expand Up @@ -19,12 +19,16 @@
// and create a common model which can represent both types.
package model

import "github.com/google/uuid"

// Operation type object holds data about each http method in the REST API.
type Operation struct {
method string
security []map[string][]string
tier string
disableSecurity bool
iD string
method string
security []map[string][]string
tier string
disableSecurity bool
vendorExtensions map[string]interface{}
}

// GetMethod returns the http method name of the give API operation
Expand Down Expand Up @@ -52,9 +56,22 @@ func (operation *Operation) GetTier() string {
return operation.tier
}

// GetVendorExtensions returns vendor extensions which are explicitly defined under
// a given resource.
func (operation *Operation) GetVendorExtensions() map[string]interface{} {
return operation.vendorExtensions
}

// GetID returns the id of a given resource.
// This is a randomly generated UUID
func (operation *Operation) GetID() string {
return operation.iD
}

// NewOperation Creates and returns operation type object
func NewOperation(method string, security []map[string][]string, extensions map[string]interface{}) *Operation {
tier := ResolveThrottlingTier(extensions)
disableSecurity := ResolveDisableSecurity(extensions)
return &Operation{method, security, tier, disableSecurity}
id := uuid.New().String()
return &Operation{id, method, security, tier, disableSecurity, extensions}
}

0 comments on commit 3bd3f79

Please sign in to comment.