diff --git a/lualib/rspamadm/clickhouse.lua b/lualib/rspamadm/clickhouse.lua index 2ca4eab184..711437c944 100644 --- a/lualib/rspamadm/clickhouse.lua +++ b/lualib/rspamadm/clickhouse.lua @@ -16,9 +16,13 @@ limitations under the License. local argparse = require "argparse" local lua_clickhouse = require "lua_clickhouse" +local lua_util = require "lua_util" +local rspamd_http = require "rspamd_http" local rspamd_upstream_list = require "rspamd_upstream_list" local ucl = require "ucl" +local E = {} + -- Define command line options local parser = argparse() :name 'rspamadm clickhouse' @@ -80,6 +84,44 @@ neural_profile:option '--settings-id' :argname('settings_id') :default('') +local neural_train = parser:command 'neural_train' + :description 'Train neural using data from Clickhouse' +neural_train:option '--days' + :description 'Number of days to query data for' + :argname('days') + :default('7') +neural_train:option '--column-name-digest' + :description 'Name of neural profile digest column in Clickhouse' + :argname('column_name_digest') + :default('NeuralDigest') +neural_train:option '--column-name-vector' + :description 'Name of neural training vector column in Clickhouse' + :argname('column_name_vector') + :default('NeuralMpack') +neural_train:option '--limit -l' + :description 'Maximum rows to fetch per day' + :argname('limit') +neural_train:option '--profile -p' + :description 'Profile to use for training' + :argname('profile') + :default('default') +neural_train:option '--rule -r' + :description 'Rule to train' + :argname('rule') + :default('default') +neural_train:option '--spam -s' + :description 'WHERE clause to use for spam' + :argname('spam') + :default("Action == 'reject'") +neural_train:option '--ham -h' + :description 'WHERE clause to use for ham' + :argname('ham') + :default('Score < 0') +neural_train:option '--url -u' + :description 'URL to use for training' + :argname('url') + :default('http://127.0.0.1:11334/plugins/neural/learn') + local http_params = { config = rspamd_config, ev_base = rspamadm_ev_base, @@ -97,6 +139,18 @@ local function load_config(config_file) end end +local function days_list(days) + -- Create list of days to query starting with yesterday + local query_days = {} + local previous_date = os.time() - 86400 + local num_days = tonumber(days) + for _ = 1, num_days do + table.insert(query_days, os.date('%Y-%m-%d', previous_date)) + previous_date = previous_date - 86400 + end + return query_days +end + local function get_excluded_symbols(known_symbols, correlations, seen_total) -- Walk results once to collect all symbols & count ocurrences @@ -202,15 +256,7 @@ local function handle_neural_profile(args) end end - -- Create list of days to query starting with yesterday - local query_days = {} - local previous_date = os.time() - 86400 - local num_days = tonumber(args.days) - for _ = 1, num_days do - table.insert(query_days, os.date('%Y-%m-%d', previous_date)) - previous_date = previous_date - 86400 - end - + local query_days = days_list(args.days) local conditions = {} table.insert(conditions, string.format("SettingsId = '%s'", args.settings_id)) local limit = '' @@ -263,8 +309,139 @@ local function handle_neural_profile(args) io.stdout:write(ucl.to_format(json_output, 'json')) end +local function post_neural_training(url, rule, spam_rows, ham_rows) + -- Prepare JSON payload + local payload = ucl.to_format( + { + ham_vec = ham_rows, + rule = rule, + spam_vec = spam_rows, + }, 'json') + + -- POST the payload + local err, response = rspamd_http.request({ + body = payload, + config = rspamd_config, + ev_base = rspamadm_ev_base, + log_obj = rspamd_config, + resolver = rspamadm_dns_resolver, + session = rspamadm_session, + url = url, + }) + + if err then + io.stderr:write(string.format('HTTP error: %s\n', err)) + os.exit(1) + end + if response.code ~= 200 then + io.stderr:write(string.format('bad HTTP code: %d\n', response.code)) + os.exit(1) + end + io.stdout:write(string.format('%s\n', response.content)) +end + +local function handle_neural_train(args) + + local this_where -- which class of messages are we collecting data for + local ham_rows, spam_rows = {}, {} + local want_spam, want_ham = true, true -- keep collecting while true + local ucl_parser = ucl.parser() + + -- Try find profile in config + local neural_opts = rspamd_config:get_all_opt('neural') + local symbols_profile = ((((neural_opts or E).rules or E)[args.rule] or E).profile or E)[args.profile] + if not symbols_profile then + io.stderr:write(string.format("Couldn't find profile %s in rule %s\n", args.profile, args.rule)) + os.exit(1) + end + -- Try find max_trains + local max_trains = (neural_opts.rules[args.rule].train or E).max_trains or 1000 + + -- Callback used to process rows from Clickhouse + local function process_row(r) + local destination -- which table to collect this information in + if this_where == args.ham then + destination = ham_rows + if #destination >= max_trains then + want_ham = false + return + end + else + destination = spam_rows + if #destination >= max_trains then + want_spam = false + return + end + end + local ok, err = ucl_parser:parse_string(r[args.column_name_vector], 'msgpack') + if not ok then + io.stderr:write(string.format("Couldn't parse [%s]: %s", r[args.column_name_vector], err)) + os.exit(1) + end + table.insert(destination, ucl_parser:get_object()) + end + + -- Generate symbols digest + local symbols_digest = lua_util.table_digest(symbols_profile) + -- Create list of days to query data for + local query_days = days_list(args.days) + -- Set value for limit + local limit = '' + local num_limit = tonumber(args.limit) + if num_limit then + limit = string.format(' LIMIT %d', num_limit) -- Contains leading space + end + -- Prepare query elements + local conditions = {string.format("%s = '%s'", args.column_name_digest, symbols_digest)} + local query_fmt = 'SELECT %s FROM rspamd WHERE %s%s' + + -- Run queries + for _, the_where in ipairs({args.ham, args.spam}) do + -- Inform callback which group of vectors we're collecting + this_where = the_where + table.insert(conditions, the_where) -- should be 2nd from last condition + -- Loop over days and try collect data + for _, query_day in ipairs(query_days) do + -- Break the loop if we have enough data already + if this_where == args.ham then + if not want_ham then + break + end + else + if not want_spam then + break + end + end + -- Date should be the last condition + table.insert(conditions, string.format("Date = '%s'", query_day)) + local query = string.format(query_fmt, args.column_name_vector, table.concat(conditions, ' AND '), limit) + local upstream = args.upstream:get_upstream_round_robin() + local err = lua_clickhouse.select_sync(upstream, args, http_params, query, process_row) + if err ~= nil then + io.stderr:write(string.format('Error querying Clickhouse: %s\n', err)) + os.exit(1) + end + conditions[#conditions] = nil -- remove Date condition + end + conditions[#conditions] = nil -- remove spam/ham condition + end + + -- Make sure we collected enough data for training + if #ham_rows < max_trains then + io.stderr:write(string.format('Insufficient ham rows: %d/%d\n', #ham_rows, max_trains)) + os.exit(1) + end + if #spam_rows < max_trains then + io.stderr:write(string.format('Insufficient spam rows: %d/%d\n', #spam_rows, max_trains)) + os.exit(1) + end + + return post_neural_training(args.url, args.rule, spam_rows, ham_rows) +end + local command_handlers = { neural_profile = handle_neural_profile, + neural_train = handle_neural_train, } local function handler(args)