Skip to content

Commit

Permalink
Update after review
Browse files Browse the repository at this point in the history
  • Loading branch information
jedelbo committed Jan 26, 2024
1 parent e3cfbdd commit d82eca7
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 127 deletions.
81 changes: 49 additions & 32 deletions src/realm/parser/driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,16 @@ inline T string_to(const std::string& s)
return value;
}

template <>
inline Decimal128 string_to<Decimal128>(const std::string& s)
{
Decimal128 value(s);
if (value.is_nan()) {
throw InvalidQueryArgError(util::format("Cannot convert '%1' to a %2", s, get_type_name<Decimal128>()));
}
return value;
}

class MixedArguments : public query_parser::Arguments {
public:
using Arg = mpark::variant<Mixed, std::vector<Mixed>>;
Expand Down Expand Up @@ -261,12 +271,12 @@ class MixedArguments : public query_parser::Arguments {
static_assert(std::is_same_v<mpark::variant_alternative_t<1, Arg>, std::vector<Mixed>>);
return m_args[n].index() == 1;
}
DataType type_for_argument(size_t n)
DataType type_for_argument(size_t n) final
{
return mixed_for_argument(n).get_type();
}

Mixed mixed_for_argument(size_t n) override
Mixed mixed_for_argument(size_t n) final
{
Arguments::verify_ndx(n);
if (is_argument_list(n)) {
Expand Down Expand Up @@ -1196,14 +1206,14 @@ void ConstantNode::decode_b64()
{
const size_t encoded_size = text.size() - 5;
size_t buffer_size = util::base64_decoded_size(encoded_size);
decode_buffer.resize(buffer_size);
m_decode_buffer.resize(buffer_size);
StringData window(text.c_str() + 4, encoded_size);
util::Optional<size_t> decoded_size = util::base64_decode(window, decode_buffer.data(), buffer_size);
util::Optional<size_t> decoded_size = util::base64_decode(window, m_decode_buffer.data(), buffer_size);
if (!decoded_size) {
throw SyntaxError("Invalid base64 value");
}
REALM_ASSERT_DEBUG_EX(*decoded_size <= encoded_size, *decoded_size, encoded_size);
decode_buffer.resize(*decoded_size); // truncate
m_decode_buffer.resize(*decoded_size); // truncate
}

Mixed ConstantNode::get_value()
Expand All @@ -1227,7 +1237,7 @@ Mixed ConstantNode::get_value()
return StringData(text.data() + 1, text.size() - 2);
case Type::STRING_BASE64:
decode_b64();
return StringData(decode_buffer.data(), decode_buffer.size());
return StringData(m_decode_buffer.data(), m_decode_buffer.size());
case Type::TIMESTAMP: {
auto s = text;
int64_t seconds;
Expand Down Expand Up @@ -1293,7 +1303,7 @@ Mixed ConstantNode::get_value()
}
case BINARY_BASE64:
decode_b64();
return BinaryData(decode_buffer.data(), decode_buffer.size());
return BinaryData(m_decode_buffer.data(), m_decode_buffer.size());
}
return {};
}
Expand Down Expand Up @@ -1329,26 +1339,30 @@ std::unique_ptr<Subexpr> ConstantNode::visit(ParserDriver* drv, DataType hint)
print_pretty_objlink(value.get<ObjLink>(), drv->m_base_table->get_parent_group()));
}
else {
explain_value_message = util::format("argument %1 of type '%2'", explain_value_message,
get_data_type_name(value.get_type()));
explain_value_message = util::format("argument %1 with value '%2'", explain_value_message, value);
if (!(m_target_table || Mixed::data_types_are_comparable(value.get_type(), hint) ||
Mixed::is_numeric(hint) || (value.is_type(type_String) && hint == type_TypeOfValue))) {
throw InvalidQueryArgError(
util::format("Cannot compare %1 to a %2", explain_value_message, get_data_type_name(hint)));
}
}
}
}
else {
value = get_value();
}

if (target_table.length()) {
if (m_target_table) {
// There is a table name set. This must be an ObjLink
const Group* g = drv->m_base_table->get_parent_group();
auto table = g->get_table(target_table);
auto table = g->get_table(m_target_table);
if (!table) {
// Perhaps class prefix is missing
Group::TableNameBuffer buffer;
table = g->get_table(Group::class_name_to_table_name(target_table, buffer));
table = g->get_table(Group::class_name_to_table_name(m_target_table, buffer));
}
if (!table) {
throw InvalidQueryError(util::format("Unknown object type '%1'", target_table));
throw InvalidQueryError(util::format("Unknown object type '%1'", m_target_table));
}
auto obj_key = table->find_primary_key(value);
value = ObjLink(table->get_key(), ObjKey(obj_key));
Expand Down Expand Up @@ -1385,26 +1399,29 @@ std::unique_ptr<Subexpr> ConstantNode::visit(ParserDriver* drv, DataType hint)
break;
case type_Double: {
auto double_val = value.get_double();

if (hint == type_Int) {
int64_t int_val = int64_t(double_val);
// Only return an integer if it precisely represents val
if (double(int_val) == double_val)
ret = std::make_unique<Value<int64_t>>(int_val);
}
else if (hint == type_Float) {
ret = std::make_unique<Value<float>>(double_val);
}
else if (hint == type_Decimal) {
ret = std::make_unique<Value<Decimal128>>(Decimal128(text));
if (std::isinf(double_val) && (!Mixed::is_numeric(hint) || hint == type_Int)) {
throw InvalidQueryError(util::format("Infinity not supported for %1", get_data_type_name(hint)));
}
if (!ret) {
constexpr auto inf = std::numeric_limits<double>::infinity();
bool is_infinity = (double_val == inf || double_val == -inf);
if (is_infinity && hint != type_Double) {
throw InvalidQueryError(util::format("Infinity not supported for %1", get_data_type_name(hint)));

switch (hint) {
case type_Float:
ret = std::make_unique<Value<float>>(double_val);
break;
case type_Decimal:
ret = std::make_unique<Value<Decimal128>>(Decimal128(text));
break;
case type_Int: {
int64_t int_val = int64_t(double_val);
// Only return an integer if it precisely represents val
if (double(int_val) == double_val) {
ret = std::make_unique<Value<int64_t>>(int_val);
break;
}
[[fallthrough]];
}
ret = std::make_unique<Value<double>>(double_val);
default:
ret = std::make_unique<Value<double>>(double_val);
break;
}
break;
}
Expand All @@ -1421,7 +1438,7 @@ std::unique_ptr<Subexpr> ConstantNode::visit(ParserDriver* drv, DataType hint)
ret = std::make_unique<Value<double>>(string_to<double>(str));
break;
case type_Decimal:
ret = std::make_unique<Value<Decimal128>>(Decimal128(str));
ret = std::make_unique<Value<Decimal128>>(string_to<Decimal128>(str));
break;
default:
if (hint == type_TypeOfValue) {
Expand Down
9 changes: 4 additions & 5 deletions src/realm/parser/driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,18 +191,17 @@ class ConstantNode : public ValueNode {
}
void add_table(std::string table_name)
{
target_table = table_name.substr(1, table_name.size() - 2);
m_target_table = table_name.substr(1, table_name.size() - 2);
}

std::unique_ptr<ConstantMixedList> copy_list_of_args(std::vector<Mixed>&);
std::unique_ptr<Subexpr> copy_arg(ParserDriver*, DataType, size_t, DataType, std::string&);
std::unique_ptr<Subexpr> visit(ParserDriver*, DataType) override;
Mixed get_value();
util::Optional<ExpressionComparisonType> m_comp_type;
std::string target_table;

private:
std::string decode_buffer;
std::string m_decode_buffer;
std::optional<ExpressionComparisonType> m_comp_type;
std::optional<std::string> m_target_table;
void decode_b64();
};

Expand Down
Loading

0 comments on commit d82eca7

Please sign in to comment.