[Fix] Neural: Another bunch of fixes
vstakhov committed Jul 15, 2019
1 parent 437c520 commit 6d11758
Showing 1 changed file with 39 additions and 6 deletions.
45 changes: 39 additions & 6 deletions src/plugins/lua/neural.lua
@@ -541,28 +541,47 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
else
local inputs, outputs = {}, {}

-- Used to show sparse vectors in a convenient format (for debugging only)
--[[
local function debug_vec(t)
local ret = {}
for i,v in ipairs(t) do
if v ~= 0 then
ret[#ret + 1] = string.format('%d=%.2f', i, v)
end
end
return ret
end
]]--

-- Make training set by joining vectors
-- KANN automatically shuffles those samples
-- 1.0 is used for spam and -1.0 is used for ham
-- It implies that the output layer can express that (e.g. tanh output)
for _,e in ipairs(spam_vec) do
inputs[#inputs + 1] = e
outputs[#outputs + 1] = {1.0}
--rspamd_logger.debugm(N, rspamd_config, 'spam vector: %s', debug_vec(e))
end
for _,e in ipairs(ham_vec) do
inputs[#inputs + 1] = e
outputs[#outputs + 1] = {-1.0}
--rspamd_logger.debugm(N, rspamd_config, 'ham vector: %s', debug_vec(e))
end

-- Called in child process
local function train()
local log_thresh = rule.train.max_iterations / 10
train_ann:train1(inputs, outputs, {
lr = rule.train.learning_rate,
max_epoch = rule.train.max_iterations,
cb = function(iter, train_cost, _)
if math.floor(iter / rule.train.max_iterations * 10) % 10 == 0 then
rspamd_logger.infox(rspamd_config, "ANN %s:%s: learned %s iterations, error: %s",
if (iter * (rule.train.max_iterations / log_thresh)) % (rule.train.max_iterations) == 0 then
rspamd_logger.infox(rspamd_config,
"ANN %s:%s: learned from %s redis key in %s iterations, error: %s",
rule.prefix, set.name,
ann_key,
iter, train_cost)
end
end
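
A minimal sketch (added for illustration, not part of the commit) of how the rewritten progress check behaves: with log_thresh = rule.train.max_iterations / 10, the condition (iter * (max_iterations / log_thresh)) % max_iterations == 0 simplifies to (iter * 10) % max_iterations == 0, so a progress line is logged roughly every tenth of the configured epochs, assuming max_iterations is a multiple of 10.

-- Standalone Lua illustration of the progress-logging condition used in train()
local max_iterations = 1000
local log_thresh = max_iterations / 10  -- aim for roughly 10 progress lines per run
for iter = 1, max_iterations do
  if (iter * (max_iterations / log_thresh)) % max_iterations == 0 then
    print(string.format('learned %d of %d iterations', iter, max_iterations))
  end
end
-- prints at iter = 100, 200, ..., 1000
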
@@ -589,7 +608,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
)
else
rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
rule.prefix, set.name, ann_key)
rule.prefix, set.name, set.ann.redis_key)
end
end

@@ -608,8 +627,6 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
{ann_key, 'lock'}
)
else
rspamd_logger.infox(rspamd_config, 'trained ANN %s:%s, %s bytes; redis key: %s',
rule.prefix, set.name, #data, ann_key)
local ann_data = rspamd_util.zstd_compress(data)
if not set.ann then
set.ann = {
@@ -637,6 +654,10 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
local ucl = require "ucl"
local profile_serialized = ucl.to_format(profile, 'json-compact', true)

rspamd_logger.infox(rspamd_config,
'trained ANN %s:%s, %s bytes; redis key: %s (old key %s)',
rule.prefix, set.name, #data, set.ann.redis_key, ann_key)

lua_redis.exec_redis_script(redis_save_unlock_id,
{ev_base = ev_base, is_write = true},
redis_save_cb,
@@ -1131,8 +1152,20 @@ local function process_rules_settings()
rule.prefix, selt.name)
end

local function filter_symbols_predicate(sname)
local fl = rspamd_config:get_symbol_flags(sname)
if fl then
fl = lua_util.list_to_hash(fl)

return not (fl.nostat or fl.idempotent or fl.skip)
end

return false
end

-- Generic stuff
table.sort(selt.symbols)
table.sort(fun.totable(fun.filter(filter_symbols_predicate, selt.symbols)))

selt.digest = lua_util.table_digest(selt.symbols)
selt.prefix = redis_ann_prefix(rule, selt.name)

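A hedged sketch (added for illustration, not part of the commit) of what filter_symbols_predicate is meant to do: symbols flagged nostat, idempotent or skip, as well as symbols with no flag information, are dropped before the remaining names are sorted. Below, a plain Lua table stands in for rspamd_config:get_symbol_flags() and a local helper stands in for lua_util.list_to_hash; the symbol names are hypothetical.

-- Illustrative stand-in; the real code queries rspamd_config:get_symbol_flags(sname)
local symbol_flags = {                 -- hypothetical flag lists keyed by symbol name
  SOME_IDEMPOTENT_SYM = { 'idempotent' },
  SOME_NOSTAT_SYM = { 'nostat' },
  SOME_REGULAR_SYM = {},
}

local function list_to_hash(list)      -- mirrors what lua_util.list_to_hash provides
  local h = {}
  for _, v in ipairs(list) do h[v] = true end
  return h
end

local function filter_symbols_predicate(sname)
  local fl = symbol_flags[sname]
  if fl then
    fl = list_to_hash(fl)
    return not (fl.nostat or fl.idempotent or fl.skip)
  end
  return false
end

print(filter_symbols_predicate('SOME_REGULAR_SYM'))     -- true: kept
print(filter_symbols_predicate('SOME_IDEMPOTENT_SYM'))  -- false: filtered out
print(filter_symbols_predicate('UNKNOWN_SYM'))          -- false: no flag information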