Skip to content

Commit

Permalink
[Rework] Lua_util: Another rework for extract_specific_urls
Browse files Browse the repository at this point in the history
  • Loading branch information
vstakhov committed Aug 20, 2019
1 parent 9cf1605 commit 2bccf65
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 65 deletions.
134 changes: 72 additions & 62 deletions lualib/lua_util.lua
Expand Up @@ -682,6 +682,7 @@ exports.filter_specific_urls = function (urls, params)
end
local function process_single_url(u)
local priority = 1 -- Normal priority
local esld = u:get_tld()
if params.ignore_redirected and u:is_redirected() then
Expand All @@ -697,36 +698,40 @@ exports.filter_specific_urls = function (urls, params)
local str_hash = tostring(u)
if esld then
-- Special cases
if (u:get_protocol() ~= 'mailto') and (not u:is_html_displayed()) then
if u:is_obscured() then
priority = 2
else
if u:get_user() then
priority = 2
elseif u:is_subject() or u:is_phished() then
priority = 2
end
end
elseif u:is_html_displayed() then
priority = 0
end
if not eslds[esld] then
eslds[esld] = {{str_hash, u}}
eslds[esld] = {{str_hash, u, priority}}
neslds = neslds + 1
else
if #eslds[esld] < params.esld_limit then
table.insert(eslds[esld], {str_hash, u})
table.insert(eslds[esld], {str_hash, u, priority})
end
end
-- eSLD - 1 part => tld
local parts = rspamd_str_split(esld, '.')
local tld = table.concat(fun.totable(fun.tail(parts)), '.')
if not tlds[tld] then
tlds[tld] = {{str_hash, u}}
tlds[tld] = {{str_hash, u, priority}}
ntlds = ntlds + 1
else
table.insert(tlds[tld], {str_hash, u})
end
-- Special cases
if not u:get_protocol() == 'mailto' and not u:is_html_displayed() then
if u:is_obscured() then
insert_url(str_hash, u)
else
if u:get_user() then
insert_url(str_hash, u)
elseif u:is_subject() or u:is_phished() then
insert_url(str_hash, u)
end
end
table.insert(tlds[tld], {str_hash, u, priority})
end
end
end
Expand All @@ -737,41 +742,59 @@ exports.filter_specific_urls = function (urls, params)
local limit = params.limit
limit = limit - nres
if limit <= 0 then limit = 1 end
if neslds <= limit then
-- We can get urls based on their eslds
repeat
local item_found = false
for _,lurls in pairs(eslds) do
if #lurls > 0 then
local last = table.remove(lurls)
insert_url(last[1], last[2])
limit = limit - 1
item_found = true
end
end
until limit <= 0 or not item_found
if limit < 0 then limit = 0 end
if limit == 0 then
res = exports.values(res)
if params.task and not params.no_cache then
params.task:cache_set(cache_key, res)
end
return res
end
if ntlds <= limit then
while limit > 0 do
for _,lurls in pairs(tlds) do
-- Sort eSLDs and tlds
local function sort_stuff(tbl)
-- Sort according to max priority
table.sort(tbl, function(e1, e2)
-- Sort by priority so max priority is at the end
table.sort(e1, function(tr1, tr2)
return tr1[3] < tr2[3]
end)
table.sort(e2, function(tr1, tr2)
return tr1[3] < tr2[3]
end)
if e1[#e1][3] ~= e2[#e2][3] then
-- Sort by priority so max priority is at the beginning
return e1[#e1][3] > e2[#e2][3]
else
-- Prefer less urls to more urls per esld
return #e1 < #e2
end
end)
return tbl
end
eslds = sort_stuff(exports.values(eslds))
neslds = #eslds
if neslds <= limit then
-- Number of eslds < limit
repeat
local item_found = false
for _,lurls in ipairs(eslds) do
if #lurls > 0 then
local last = table.remove(lurls)
insert_url(last[1], last[2])
limit = limit - 1
item_found = true
end
end
end
until limit <= 0 or not item_found
res = exports.values(res)
if params.task and not params.no_cache then
Expand All @@ -780,38 +803,25 @@ exports.filter_specific_urls = function (urls, params)
return res
end
-- We need to sort tlds table first
local tlds_keys = {}
for k,_ in pairs(tlds) do table.insert(tlds_keys, k) end
table.sort(tlds_keys, function (t1, t2)
return #tlds[t1] < #tlds[t2]
end)
ntlds = #tlds_keys
for i=1,ntlds / 2 do
local tld1 = tlds[tlds_keys[i]]
local tld2 = tlds[tlds_keys[ntlds - i]]
if #tld1 > 0 then
local last = table.remove(tld1)
insert_url(last[1], last[2])
limit = limit - 1
end
if #tld2 > 0 then
local last = table.remove(tld2)
insert_url(last[1], last[2])
limit = limit - 1
end
tlds = sort_stuff(exports.values(tlds))
ntlds = #tlds
if limit <= 0 then
break
-- Number of tlds < limit
while limit > 0 do
for _,lurls in ipairs(tlds) do
if #lurls > 0 then
local last = table.remove(lurls)
insert_url(last[1], last[2])
limit = limit - 1
end
if limit == 0 then break end
end
end
res = exports.values(res)
if params.task and not params.no_cache then
params.task:cache_set(cache_key, res)
end
return res
end
Expand Down
6 changes: 3 additions & 3 deletions test/lua/unit/lua_util.extract_specific_urls.lua
Expand Up @@ -192,8 +192,8 @@ context("Lua util - extract_specific_urls", function()
local actual = util.extract_specific_urls({
task = task,
limit = 2,
esld_limit = 2,
limit = 1,
esld_limit = 1,
})
local actual_result = prepare_actual_result(actual)
Expand All @@ -202,7 +202,7 @@ context("Lua util - extract_specific_urls", function()
local s = logger.slog("case[%1] %2 =?= %3", i, expect, actual_result)
print(s) --]]
assert_equal("domain.com", actual_result[1], "checking that first url is the one with highest suspiciousness level")
assert_rspamd_table_eq({actual = actual_result, expect = {"domain.com"}})
end)
end)
Expand Down

0 comments on commit 2bccf65

Please sign in to comment.