[Feature] rspamadm clickhouse neural_train subcommand
fatalbanana committed Jan 15, 2021
1 parent 5f7e2d6 commit 4cff597
Showing 1 changed file with 186 additions and 9 deletions.
195 changes: 186 additions & 9 deletions lualib/rspamadm/clickhouse.lua
@@ -16,9 +16,13 @@ limitations under the License.

local argparse = require "argparse"
local lua_clickhouse = require "lua_clickhouse"
local lua_util = require "lua_util"
local rspamd_http = require "rspamd_http"
local rspamd_upstream_list = require "rspamd_upstream_list"
local ucl = require "ucl"

local E = {}

-- Define command line options
local parser = argparse()
  :name 'rspamadm clickhouse'
@@ -80,6 +84,44 @@ neural_profile:option '--settings-id'
  :argname('settings_id')
  :default('')

local neural_train = parser:command 'neural_train'
  :description 'Train neural using data from Clickhouse'
neural_train:option '--days'
  :description 'Number of days to query data for'
  :argname('days')
  :default('7')
neural_train:option '--column-name-digest'
  :description 'Name of neural profile digest column in Clickhouse'
  :argname('column_name_digest')
  :default('NeuralDigest')
neural_train:option '--column-name-vector'
  :description 'Name of neural training vector column in Clickhouse'
  :argname('column_name_vector')
  :default('NeuralMpack')
neural_train:option '--limit -l'
  :description 'Maximum rows to fetch per day'
  :argname('limit')
neural_train:option '--profile -p'
  :description 'Profile to use for training'
  :argname('profile')
  :default('default')
neural_train:option '--rule -r'
  :description 'Rule to train'
  :argname('rule')
  :default('default')
neural_train:option '--spam -s'
  :description 'WHERE clause to use for spam'
  :argname('spam')
  :default("Action == 'reject'")
neural_train:option '--ham -h'
  :description 'WHERE clause to use for ham'
  :argname('ham')
  :default('Score < 0')
neural_train:option '--url -u'
  :description 'URL to use for training'
  :argname('url')
  :default('http://127.0.0.1:11334/plugins/neural/learn')
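-- Example invocation (illustrative values; ClickHouse connection options
-- defined on the parent command are omitted):
--   rspamadm clickhouse neural_train --days 3 --limit 1000 --spam "Action == 'rewrite subject'"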

local http_params = {
  config = rspamd_config,
  ev_base = rspamadm_ev_base,
@@ -97,6 +139,18 @@ local function load_config(config_file)
end
end

local function days_list(days)
  -- Create list of days to query starting with yesterday
  local query_days = {}
  local previous_date = os.time() - 86400
  local num_days = tonumber(days)
  for _ = 1, num_days do
    table.insert(query_days, os.date('%Y-%m-%d', previous_date))
    previous_date = previous_date - 86400
  end
  return query_days
end
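-- For example, days_list('3') called on 2021-01-15 would return
-- { '2021-01-14', '2021-01-13', '2021-01-12' }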

local function get_excluded_symbols(known_symbols, correlations, seen_total)
  -- Walk results once to collect all symbols & count occurrences

@@ -202,15 +256,7 @@ local function handle_neural_profile(args)
end
end

-  -- Create list of days to query starting with yesterday
-  local query_days = {}
-  local previous_date = os.time() - 86400
-  local num_days = tonumber(args.days)
-  for _ = 1, num_days do
-    table.insert(query_days, os.date('%Y-%m-%d', previous_date))
-    previous_date = previous_date - 86400
-  end
-
+  local query_days = days_list(args.days)
  local conditions = {}
  table.insert(conditions, string.format("SettingsId = '%s'", args.settings_id))
  local limit = ''
@@ -263,8 +309,139 @@
  io.stdout:write(ucl.to_format(json_output, 'json'))
end

local function post_neural_training(url, rule, spam_rows, ham_rows)
  -- Prepare JSON payload
  local payload = ucl.to_format(
    {
      ham_vec = ham_rows,
      rule = rule,
      spam_vec = spam_rows,
    }, 'json')

  -- POST the payload
  local err, response = rspamd_http.request({
    body = payload,
    config = rspamd_config,
    ev_base = rspamadm_ev_base,
    log_obj = rspamd_config,
    resolver = rspamadm_dns_resolver,
    session = rspamadm_session,
    url = url,
  })

  if err then
    io.stderr:write(string.format('HTTP error: %s\n', err))
    os.exit(1)
  end
  if response.code ~= 200 then
    io.stderr:write(string.format('bad HTTP code: %d\n', response.code))
    os.exit(1)
  end
  io.stdout:write(string.format('%s\n', response.content))
end
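-- The JSON body POSTed above looks roughly like this (vectors elided):
--   {"ham_vec": [[...], ...], "rule": "default", "spam_vec": [[...], ...]}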

local function handle_neural_train(args)

  local this_where -- which class of messages are we collecting data for
  local ham_rows, spam_rows = {}, {}
  local want_spam, want_ham = true, true -- keep collecting while true
  local ucl_parser = ucl.parser()

  -- Try to find the profile in the config
  local neural_opts = rspamd_config:get_all_opt('neural')
  local symbols_profile = ((((neural_opts or E).rules or E)[args.rule] or E).profile or E)[args.profile]
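  -- (E, the shared empty table, keeps each step of this nested lookup nil-safe)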
  if not symbols_profile then
    io.stderr:write(string.format("Couldn't find profile %s in rule %s\n", args.profile, args.rule))
    os.exit(1)
  end
  -- Try to find max_trains
  local max_trains = (neural_opts.rules[args.rule].train or E).max_trains or 1000
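  -- Sketch of the configuration layout these lookups assume:
  --   neural {
  --     rules {
  --       default {
  --         profile { default = [ ...symbol names... ]; }
  --         train { max_trains = 1000; }
  --       }
  --     }
  --   }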

  -- Callback used to process rows from Clickhouse
  local function process_row(r)
    local destination -- which table to collect this information in
    if this_where == args.ham then
      destination = ham_rows
      if #destination >= max_trains then
        want_ham = false
        return
      end
    else
      destination = spam_rows
      if #destination >= max_trains then
        want_spam = false
        return
      end
    end
    local ok, err = ucl_parser:parse_string(r[args.column_name_vector], 'msgpack')
    if not ok then
      io.stderr:write(string.format("Couldn't parse [%s]: %s\n", r[args.column_name_vector], err))
      os.exit(1)
    end
    table.insert(destination, ucl_parser:get_object())
  end

  -- Generate symbols digest
  local symbols_digest = lua_util.table_digest(symbols_profile)
  -- Create list of days to query data for
  local query_days = days_list(args.days)
  -- Set value for limit
  local limit = ''
  local num_limit = tonumber(args.limit)
  if num_limit then
    limit = string.format(' LIMIT %d', num_limit) -- Contains leading space
  end
  -- Prepare query elements
  local conditions = {string.format("%s = '%s'", args.column_name_digest, symbols_digest)}
  local query_fmt = 'SELECT %s FROM rspamd WHERE %s%s'
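  -- With the defaults, the first query built below would be roughly
  -- (digest and date values are illustrative):
  --   SELECT NeuralMpack FROM rspamd
  --   WHERE NeuralDigest = '<digest>' AND Score < 0 AND Date = '2021-01-14'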

  -- Run queries
  for _, the_where in ipairs({args.ham, args.spam}) do
    -- Inform callback which group of vectors we're collecting
    this_where = the_where
    table.insert(conditions, the_where) -- should be 2nd from last condition
    -- Loop over days and try to collect data
    for _, query_day in ipairs(query_days) do
      -- Break the loop if we have enough data already
      if this_where == args.ham then
        if not want_ham then
          break
        end
      else
        if not want_spam then
          break
        end
      end
      -- Date should be the last condition
      table.insert(conditions, string.format("Date = '%s'", query_day))
      local query = string.format(query_fmt, args.column_name_vector, table.concat(conditions, ' AND '), limit)
      local upstream = args.upstream:get_upstream_round_robin()
      local err = lua_clickhouse.select_sync(upstream, args, http_params, query, process_row)
      if err ~= nil then
        io.stderr:write(string.format('Error querying Clickhouse: %s\n', err))
        os.exit(1)
      end
      conditions[#conditions] = nil -- remove Date condition
    end
    conditions[#conditions] = nil -- remove spam/ham condition
  end

  -- Make sure we collected enough data for training
  if #ham_rows < max_trains then
    io.stderr:write(string.format('Insufficient ham rows: %d/%d\n', #ham_rows, max_trains))
    os.exit(1)
  end
  if #spam_rows < max_trains then
    io.stderr:write(string.format('Insufficient spam rows: %d/%d\n', #spam_rows, max_trains))
    os.exit(1)
  end

  return post_neural_training(args.url, args.rule, spam_rows, ham_rows)
end

local command_handlers = {
  neural_profile = handle_neural_profile,
  neural_train = handle_neural_train,
}

local function handler(args)
