Skip to content

Commit

Permalink
Merge pull request Kong#381 from Mashape/feature/wildcard-subdomains
Browse files Browse the repository at this point in the history
[feature] implement wildcard subdomains for APIs' public_dns

Former-commit-id: a66441044105d0a70ede72c708dc381efa370da1
  • Loading branch information
thibaultcha committed Jul 9, 2015
2 parents d0b59b9 + 2259018 commit d6731f2
Show file tree
Hide file tree
Showing 10 changed files with 309 additions and 216 deletions.
30 changes: 17 additions & 13 deletions kong/dao/cassandra/base_dao.lua
Original file line number Diff line number Diff line change
Expand Up @@ -351,10 +351,11 @@ function BaseDao:insert(t)
assert(t ~= nil, "Cannot insert a nil element")
assert(type(t) == "table", "Entity to insert must be a table")

local ok, db_err, errors
local ok, db_err, errors, self_err

-- Populate the entity with any default/overriden values and validate it
errors = validations.validate(t, self, {
ok, errors, self_err = validations.validate_entity(t, self._schema, {
dao = self._factory,
dao_insert = function(field)
if field.type == "id" then
return uuid()
Expand All @@ -363,13 +364,10 @@ function BaseDao:insert(t)
end
end
})
if errors then
return nil, errors
end

ok, errors = validations.on_insert(t, self._schema, self._factory)
if not ok then
return nil, errors
if self_err then
return nil, self_err
elseif not ok then
return nil, DaoError(errors, error_types.SCHEMA)
end

ok, errors, db_err = self:check_unique_fields(t)
Expand Down Expand Up @@ -440,7 +438,7 @@ function BaseDao:update(t, full)
assert(t ~= nil, "Cannot update a nil element")
assert(type(t) == "table", "Entity to update must be a table")

local ok, db_err, errors
local ok, db_err, errors, self_err

-- Check if exists to prevent upsert
local res, err = self:find_by_primary_key(t)
Expand All @@ -455,9 +453,15 @@ function BaseDao:update(t, full)
end

-- Validate schema
errors = validations.validate(t, self, {partial_update = not full, full_update = full})
if errors then
return nil, errors
ok, errors, self_err = validations.validate_entity(t, self._schema, {
partial_update = not full,
full_update = full,
dao = self._factory
})
if self_err then
return nil, self_err
elseif not ok then
return nil, DaoError(errors, error_types.SCHEMA)
end

ok, errors, db_err = self:check_unique_fields(t, true)
Expand Down
22 changes: 20 additions & 2 deletions kong/dao/schemas/apis.lua
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,30 @@ local function check_public_dns_and_path(value, api_t)
return false, "At least a 'public_dns' or a 'path' must be specified"
end

return true
-- Validate wildcard public_dns
if public_dns then
local _, count = public_dns:gsub("%*", "")
if count > 1 then
return false, "Only one wildcard is allowed: "..public_dns
elseif count > 0 then
local pos = public_dns:find("%*")
local valid
if pos == 1 then
valid = public_dns:match("^%*%.") ~= nil
elseif pos == string.len(public_dns) then
valid = public_dns:match(".%.%*$") ~= nil
end

if not valid then
return false, "Invalid wildcard placement: "..public_dns
end
end
end
end

local function check_path(path, api_t)
local valid, err = check_public_dns_and_path(path, api_t)
if not valid then
if valid == false then
return false, err
end

Expand Down
30 changes: 15 additions & 15 deletions kong/dao/schemas/plugins_configurations.lua
Original file line number Diff line number Diff line change
Expand Up @@ -26,32 +26,32 @@ return {
value = { type = "table", schema = load_value_schema },
enabled = { type = "boolean", default = true }
},
on_insert = function(plugin_t, dao, schema)
self_check = function(self, plugin_t, dao, is_update)
-- Load the value schema
local value_schema, err = schema.fields.value.schema(plugin_t)
local value_schema, err = self.fields.value.schema(plugin_t)
if err then
return false, err
return false, DaoError(err, constants.DATABASE_ERROR_TYPES.SCHEMA)
end

-- Check if the schema has a `no_consumer` field
if value_schema.no_consumer and plugin_t.consumer_id ~= nil and plugin_t.consumer_id ~= constants.DATABASE_NULL_ID then
return false, DaoError("No consumer can be configured for that plugin", constants.DATABASE_ERROR_TYPES.SCHEMA)
end

local res, err = dao.plugins_configurations:find_by_keys({
name = plugin_t.name,
api_id = plugin_t.api_id,
consumer_id = plugin_t.consumer_id
})
if not is_update then
local res, err = dao.plugins_configurations:find_by_keys({
name = plugin_t.name,
api_id = plugin_t.api_id,
consumer_id = plugin_t.consumer_id
})

if err then
return nil, DaoError(err, constants.DATABASE_ERROR_TYPES.DATABASE)
end
if err then
return nil, DaoError(err, constants.DATABASE_ERROR_TYPES.DATABASE)
end

if res and #res > 0 then
return false, DaoError("Plugin configuration already exists", constants.DATABASE_ERROR_TYPES.UNIQUE)
else
return true
if res and #res > 0 then
return false, DaoError("Plugin configuration already exists", constants.DATABASE_ERROR_TYPES.UNIQUE)
end
end
end
}
34 changes: 8 additions & 26 deletions kong/dao/schemas_validation.lua
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
local utils = require "kong.tools.utils"
local stringy = require "stringy"
local DaoError = require "kong.dao.error"
local constants = require "kong.constants"
local error_types = constants.DATABASE_ERROR_TYPES

local POSSIBLE_TYPES = {
id = true,
Expand Down Expand Up @@ -44,7 +41,7 @@ local _M = {}
-- `is_update` For an entity update, check immutable fields. Set to true.
-- @return `valid` Success of validation. True or false.
-- @return `errors` A list of encountered errors during the validation.
function _M.validate_fields(t, schema, options)
function _M.validate_entity(t, schema, options)
if not options then options = {} end
local errors

Expand Down Expand Up @@ -151,7 +148,7 @@ function _M.validate_fields(t, schema, options)

if t[column] and type(t[column]) == "table" then
-- Actually validating the sub-schema
local s_ok, s_errors = _M.validate_fields(t[column], sub_schema, options)
local s_ok, s_errors = _M.validate_entity(t[column], sub_schema, options)
if not s_ok then
for s_k, s_v in pairs(s_errors) do
errors = utils.add_error(errors, column.."."..s_k, s_v)
Expand All @@ -172,7 +169,7 @@ function _M.validate_fields(t, schema, options)
-- [FUNC] Check field against a custom function
-- only if there is no error on that field already.
local ok, err, new_fields = v.func(t[column], t, column)
if not ok and err then
if ok == false and err then
errors = utils.add_error(errors, column, err)
elseif new_fields then
for k, v in pairs(new_fields) do
Expand All @@ -190,29 +187,14 @@ function _M.validate_fields(t, schema, options)
end
end

return errors == nil, errors
end

function _M.on_insert(t, schema, dao)
if schema.on_insert and type(schema.on_insert) == "function" then
local valid, err = schema.on_insert(t, dao, schema)
if not valid or err then
return false, err
else
return true
if errors == nil and type(schema.self_check) == "function" then
local ok, err = schema.self_check(schema, t, options.dao, (options.partial_update or options.full_update))
if ok == false then
return false, nil, err
end
else
return true
end
end

function _M.validate(t, dao, options)
local ok, errors

ok, errors = _M.validate_fields(t, dao._schema, options)
if not ok then
return DaoError(errors, error_types.SCHEMA)
end
return errors == nil, errors
end

local digit = "[0-9a-f]"
Expand Down
86 changes: 61 additions & 25 deletions kong/resolver/access.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,49 @@ local responses = require "kong.tools.responses"

local _M = {}

-- Take a public_dns and make it a pattern for wildcard matching.
-- Only do so if the public_dns actually has a wildcard.
local function create_wildcard_pattern(public_dns)
if string.find(public_dns, "*", 1, true) then
local pattern = string.gsub(public_dns, "%.", "%%.")
pattern = string.gsub(pattern, "*", ".+")
pattern = string.format("^%s$", pattern)
return pattern
end
end

-- Load all APIs in memory.
-- Sort the data for faster lookup: dictionary per public_dns, host,
-- and an array of wildcard public_dns.
local function load_apis_in_memory()
local apis, err = dao.apis:find_all()
if err then
return nil, err
end

-- build dictionnaries of public_dns:api and path:apis for efficient O(1) lookup.
-- we only do O(n) lookup for wildcard public_dns that are in an array.
local dns_dic, dns_wildcard, path_dic = {}, {}, {}
for _, api in ipairs(apis) do
if api.public_dns then
local pattern = create_wildcard_pattern(api.public_dns)
if pattern then
-- If the public_dns is a wildcard, we have a pattern and we can
-- store it in an array for later lookup.
table.insert(dns_wildcard, {pattern = pattern, api = api})
else
-- Keep non-wildcard public_dns in a dictionary for faster lookup.
dns_dic[api.public_dns] = api
end
end
if api.path then
path_dic[api.path] = api
end
end

return {by_dns = dns_dic, wildcard_dns = dns_wildcard, by_path = path_dic}
end

local function get_backend_url(api)
local result = api.target_url

Expand Down Expand Up @@ -37,7 +80,8 @@ end
-- matching the API's `public_dns`, either from the `request_uri` matching the API's `path`.
--
-- To perform this, we need to query _ALL_ APIs in memory. It is the only way to compare the `request_uri`
-- as a regex to the values set in DB. We keep APIs in the database cache for a longer time than usual.
-- as a regex to the values set in DB, as well as matching wildcard dns.
-- We keep APIs in the database cache for a longer time than usual.
-- @see https://github.com/Mashape/kong/issues/15 for an improvement on this.
--
-- @param `request_uri` The URI for this request.
Expand All @@ -49,31 +93,14 @@ end
local function find_api(request_uri)
local retrieved_api

-- retrieve all APIs
local apis_dics, err = cache.get_or_set("ALL_APIS_BY_DIC", function()
local apis, err = dao.apis:find_all()
if err then
return nil, err
end

-- build dictionnaries of public_dns:api and path:apis for efficient lookup.
local dns_dic, path_dic = {}, {}
for _, api in ipairs(apis) do
if api.public_dns then
dns_dic[api.public_dns] = api
end
if api.path then
path_dic[api.path] = api
end
end
return {dns = dns_dic, path = path_dic}
end, 60) -- 60 seconds cache
-- Retrieve all APIs
local apis_dics, err = cache.get_or_set("ALL_APIS_BY_DIC", load_apis_in_memory, 60) -- 60 seconds cache, longer than usual

if err then
return err
end

-- find by Host header
-- Find by Host header
local all_hosts = {}
for _, header_name in ipairs({"Host", constants.HEADERS.HOST_OVERRIDE}) do
local hosts = ngx.req.get_headers()[header_name]
Expand All @@ -85,9 +112,18 @@ local function find_api(request_uri)
for _, host in ipairs(hosts) do
host = unpack(stringy.split(host, ":"))
table.insert(all_hosts, host)
if apis_dics.dns[host] then
retrieved_api = apis_dics.dns[host]
break
if apis_dics.by_dns[host] then
retrieved_api = apis_dics.by_dns[host]
--break
else
-- If the API was not found in the dictionary, maybe it is a wildcard public_dns.
-- In that case, we need to loop over all of them.
for _, wildcard_dns in ipairs(apis_dics.wildcard_dns) do
if string.match(host, wildcard_dns.pattern) then
retrieved_api = wildcard_dns.api
break
end
end
end
end
end
Expand All @@ -99,7 +135,7 @@ local function find_api(request_uri)
end

-- Otherwise, we look for it by path. We have to loop over all APIs and compare the requested URI.
for path, api in pairs(apis_dics.path) do
for path, api in pairs(apis_dics.by_path) do
local m, err = ngx.re.match(request_uri, "^"..path)
if err then
ngx.log(ngx.ERR, "[resolver] error matching requested path: "..err)
Expand Down
Loading

0 comments on commit d6731f2

Please sign in to comment.