Skip to content

Commit

Permalink
cql3: functions: fix count(col) for non-scalar types
Browse files Browse the repository at this point in the history
count(col), unlike count(*), does not count rows for which col is NULL.
However, if col's data type is not a scalar (e.g. a collection, tuple,
or user-defined type) it behaves like count(*), counting NULLs too.

The cause is that get_dynamic_aggregate() converts count(col) into
the count(*) version when col's type is non-scalar. Scalars are unaffected
because get_dynamic_aggregate() intentionally declines to match scalar
arguments, so functions::get() falls through and matches them against the
pre-declared per-type count functions.

As we can only pre-declare count(scalar) (there's an infinite number
of non-scalar types), we change the approach to be the same as min/max:
we make count() a generic function. In fact count(col) is much better
as a generic function, as it only examines its input to see if it is
NULL.

A unit test is added. It passes with Cassandra as well.

Fixes #14198.

Closes #14199
  • Loading branch information
avikivity authored and nyh committed Jun 13, 2023
1 parent e0855b1 commit 78f4ee3
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 50 deletions.
51 changes: 7 additions & 44 deletions cql3/functions/aggregate_fcts.cc
Expand Up @@ -290,25 +290,27 @@ struct aggregate_type_for<time_native_type> {
using type = time_native_type::primary_type;
};

} // anonymous namespace

/**
* Creates a COUNT function for the specified type.
*
* @param inputType the function input type
* @param input_type the function input type
* @return a COUNT function for the specified type.
*/
template <typename Type>
static shared_ptr<aggregate_function> make_count_function() {
shared_ptr<aggregate_function>
aggregate_fcts::make_count_function(data_type input_type) {
return make_shared<db::functions::aggregate_function>(
db::functions::stateless_aggregate_function{
.name = function_name::native_function("count"),
.state_type = long_type,
.result_type = long_type,
.argument_types = {data_type_for<Type>()},
.argument_types = {input_type},
.initial_state = data_value(int64_t(0)).serialize(),
.aggregation_function = ::make_shared<internal_scalar_function>(
"count_step",
long_type,
std::vector<data_type>({long_type, data_type_for<Type>()}),
std::vector<data_type>({long_type, input_type}),
[] (std::span<const bytes_opt> args) {
if (!args[1]) {
return args[0];
Expand All @@ -321,7 +323,6 @@ static shared_ptr<aggregate_function> make_count_function() {
.state_reduction_function = make_internal_scalar_function("count_reducer", return_any_nonnull, [] (int64_t c1, int64_t c2) { return c1 + c2; }),
});
}
}

// Drops the first arg type from the types declaration (which denotes the accumulator)
// in order to compute the actual type of given user-defined-aggregate (UDA)
Expand Down Expand Up @@ -466,44 +467,6 @@ aggregate_fcts::make_min_function(data_type io_type) {
void cql3::functions::add_agg_functions(declared_t& funcs) {
auto declare = [&funcs] (shared_ptr<function> f) { funcs.emplace(f->name(), f); };

declare(make_count_function<int8_t>());

declare(make_count_function<int16_t>());

declare(make_count_function<int32_t>());

declare(make_count_function<int64_t>());

declare(make_count_function<utils::multiprecision_int>());

declare(make_count_function<big_decimal>());

declare(make_count_function<float>());

declare(make_count_function<double>());

declare(make_count_function<sstring>());

declare(make_count_function<ascii_native_type>());

declare(make_count_function<simple_date_native_type>());

declare(make_count_function<db_clock::time_point>());

declare(make_count_function<timeuuid_native_type>());

declare(make_count_function<time_native_type>());

declare(make_count_function<utils::UUID>());

declare(make_count_function<bytes>());

declare(make_count_function<bool>());

declare(make_count_function<net::inet_address>());

// FIXME: more count/min/max

declare(make_sum_function<int8_t>());
declare(make_sum_function<int16_t>());
declare(make_sum_function<int32_t>());
Expand Down
4 changes: 4 additions & 0 deletions cql3/functions/aggregate_fcts.hh
Expand Up @@ -32,6 +32,10 @@ make_max_function(data_type io_type);
/// The same as `make_min_function()' but with type provided in runtime.
shared_ptr<aggregate_function>
make_min_function(data_type io_type);

/// count(col) function for the specified type
shared_ptr<aggregate_function> make_count_function(data_type input_type);

}
}
}
4 changes: 1 addition & 3 deletions cql3/functions/functions.cc
Expand Up @@ -295,9 +295,7 @@ static shared_ptr<function> get_dynamic_aggregate(const function_name &name, con
}

auto& arg = arg_types[0];
if (arg->is_collection() || arg->is_tuple() || arg->is_user_type()) {
return aggregate_fcts::make_count_rows_function();
}
return aggregate_fcts::make_count_function(arg);
} else if (name.has_keyspace()
? name == COUNT_ROWS_NAME
: name.name == COUNT_ROWS_NAME.name) {
Expand Down
16 changes: 13 additions & 3 deletions test/cql-pytest/test_aggregate.py
Expand Up @@ -9,7 +9,7 @@
import pytest
import math
from decimal import Decimal
from util import new_test_table, unique_key_int, project
from util import new_test_table, unique_key_int, project, new_type
from cassandra.util import Date

@pytest.fixture(scope="module")
Expand Down Expand Up @@ -72,9 +72,9 @@ def test_count_in_partition(cql, table1):
cql.execute(stmt, [p, 3, 3])
assert [(3,)] == list(cql.execute(f"select count(*) from {table1} where p = {p}"))

# Using count(v) instead of count(*) allows counting only rows with a set
# Using count(v) instead of count(*) allows counting only rows with a non-NULL
# value in v
def test_count_specific_column(cql, table1):
def test_count_specific_column(cql, test_keyspace, table1):
p = unique_key_int()
stmt = cql.prepare(f"insert into {table1} (p, c, v) values (?, ?, ?)")
cql.execute(stmt, [p, 1, 1])
Expand All @@ -83,6 +83,16 @@ def test_count_specific_column(cql, table1):
cql.execute(stmt, [p, 4, None])
assert [(4,)] == list(cql.execute(f"select count(*) from {table1} where p = {p}"))
assert [(3,)] == list(cql.execute(f"select count(v) from {table1} where p = {p}"))
# Check with non-scalar types too, reproduces #14198
with new_type(cql, test_keyspace, '(i int, t text)') as udt:
with new_test_table(cql, test_keyspace, f'p int, c int, fli frozen<list<int>>, tup tuple<int, bigint>, udt {udt}, PRIMARY KEY (p, c)') as table2:
cql.execute(f'INSERT INTO {table2}(p, c, fli, tup, udt) VALUES({p}, 5, [1, 2], (3, 4), {{i: 3, t: \'text\'}})')
cql.execute(f'INSERT INTO {table2}(p, c) VALUES({p}, 6)')

assert [(1,)] == list(cql.execute(f"select count(fli) from {table2} where p = {p}"))
assert [(1,)] == list(cql.execute(f"select count(tup) from {table2} where p = {p}"))
assert [(1,)] == list(cql.execute(f"select count(udt) from {table2} where p = {p}"))
assert [(2,)] == list(cql.execute(f"select count(*) from {table2} where p = {p}"))

# COUNT can be combined with GROUP BY to count separately for each partition
# or row.
Expand Down

0 comments on commit 78f4ee3

Please sign in to comment.