diff --git a/cql3/functions/aggregate_fcts.cc b/cql3/functions/aggregate_fcts.cc index f9e08fb93e27..306c3f23cd11 100644 --- a/cql3/functions/aggregate_fcts.cc +++ b/cql3/functions/aggregate_fcts.cc @@ -290,25 +290,27 @@ struct aggregate_type_for { using type = time_native_type::primary_type; }; +} // anonymous namespace + /** * Creates a COUNT function for the specified type. * - * @param inputType the function input type + * @param input_type the function input type * @return a COUNT function for the specified type. */ -template -static shared_ptr make_count_function() { +shared_ptr +aggregate_fcts::make_count_function(data_type input_type) { return make_shared( db::functions::stateless_aggregate_function{ .name = function_name::native_function("count"), .state_type = long_type, .result_type = long_type, - .argument_types = {data_type_for()}, + .argument_types = {input_type}, .initial_state = data_value(int64_t(0)).serialize(), .aggregation_function = ::make_shared( "count_step", long_type, - std::vector({long_type, data_type_for()}), + std::vector({long_type, input_type}), [] (std::span args) { if (!args[1]) { return args[0]; @@ -321,7 +323,6 @@ static shared_ptr make_count_function() { .state_reduction_function = make_internal_scalar_function("count_reducer", return_any_nonnull, [] (int64_t c1, int64_t c2) { return c1 + c2; }), }); } -} // Drops the first arg type from the types declaration (which denotes the accumulator) // in order to compute the actual type of given user-defined-aggregate (UDA) @@ -466,44 +467,6 @@ aggregate_fcts::make_min_function(data_type io_type) { void cql3::functions::add_agg_functions(declared_t& funcs) { auto declare = [&funcs] (shared_ptr f) { funcs.emplace(f->name(), f); }; - declare(make_count_function()); - - declare(make_count_function()); - - declare(make_count_function()); - - declare(make_count_function()); - - declare(make_count_function()); - - declare(make_count_function()); - - declare(make_count_function()); - - declare(make_count_function()); - - declare(make_count_function()); - - declare(make_count_function()); - - declare(make_count_function()); - - declare(make_count_function()); - - declare(make_count_function()); - - declare(make_count_function()); - - declare(make_count_function()); - - declare(make_count_function()); - - declare(make_count_function()); - - declare(make_count_function()); - - // FIXME: more count/min/max - declare(make_sum_function()); declare(make_sum_function()); declare(make_sum_function()); diff --git a/cql3/functions/aggregate_fcts.hh b/cql3/functions/aggregate_fcts.hh index 17627f33f718..3c77f05a9287 100644 --- a/cql3/functions/aggregate_fcts.hh +++ b/cql3/functions/aggregate_fcts.hh @@ -32,6 +32,10 @@ make_max_function(data_type io_type); /// The same as `make_min_function()' but with type provided in runtime. shared_ptr make_min_function(data_type io_type); + +/// count(col) function for the specified type +shared_ptr make_count_function(data_type input_type); + } } } diff --git a/cql3/functions/functions.cc b/cql3/functions/functions.cc index 8d1eda634aff..eaeb24467002 100644 --- a/cql3/functions/functions.cc +++ b/cql3/functions/functions.cc @@ -295,9 +295,7 @@ static shared_ptr get_dynamic_aggregate(const function_name &name, con } auto& arg = arg_types[0]; - if (arg->is_collection() || arg->is_tuple() || arg->is_user_type()) { - return aggregate_fcts::make_count_rows_function(); - } + return aggregate_fcts::make_count_function(arg); } else if (name.has_keyspace() ? name == COUNT_ROWS_NAME : name.name == COUNT_ROWS_NAME.name) { diff --git a/test/cql-pytest/test_aggregate.py b/test/cql-pytest/test_aggregate.py index e6937d154717..8bd474eb5e39 100644 --- a/test/cql-pytest/test_aggregate.py +++ b/test/cql-pytest/test_aggregate.py @@ -9,7 +9,7 @@ import pytest import math from decimal import Decimal -from util import new_test_table, unique_key_int, project +from util import new_test_table, unique_key_int, project, new_type from cassandra.util import Date @pytest.fixture(scope="module") @@ -72,9 +72,9 @@ def test_count_in_partition(cql, table1): cql.execute(stmt, [p, 3, 3]) assert [(3,)] == list(cql.execute(f"select count(*) from {table1} where p = {p}")) -# Using count(v) instead of count(*) allows counting only rows with a set +# Using count(v) instead of count(*) allows counting only rows with a non-NULL # value in v -def test_count_specific_column(cql, table1): +def test_count_specific_column(cql, test_keyspace, table1): p = unique_key_int() stmt = cql.prepare(f"insert into {table1} (p, c, v) values (?, ?, ?)") cql.execute(stmt, [p, 1, 1]) @@ -83,6 +83,16 @@ def test_count_specific_column(cql, table1): cql.execute(stmt, [p, 4, None]) assert [(4,)] == list(cql.execute(f"select count(*) from {table1} where p = {p}")) assert [(3,)] == list(cql.execute(f"select count(v) from {table1} where p = {p}")) + # Check with non-scalar types too, reproduces #14198 + with new_type(cql, test_keyspace, '(i int, t text)') as udt: + with new_test_table(cql, test_keyspace, f'p int, c int, fli frozen>, tup tuple, udt {udt}, PRIMARY KEY (p, c)') as table2: + cql.execute(f'INSERT INTO {table2}(p, c, fli, tup, udt) VALUES({p}, 5, [1, 2], (3, 4), {{i: 3, t: \'text\'}})') + cql.execute(f'INSERT INTO {table2}(p, c) VALUES({p}, 6)') + + assert [(1,)] == list(cql.execute(f"select count(fli) from {table2} where p = {p}")) + assert [(1,)] == list(cql.execute(f"select count(tup) from {table2} where p = {p}")) + assert [(1,)] == list(cql.execute(f"select count(udt) from {table2} where p = {p}")) + assert [(2,)] == list(cql.execute(f"select count(*) from {table2} where p = {p}")) # COUNT can be combined with GROUP BY to count separately for each partition # or row.