diff --git a/cmake/external/date.cmake b/cmake/external/date.cmake new file mode 100644 index 00000000000..fdefbfd920d --- /dev/null +++ b/cmake/external/date.cmake @@ -0,0 +1,16 @@ +include(ExternalProject) +set(DATE_PREFIX ${CMAKE_BINARY_DIR}/date) +set(DATE_SOURCE_DIR "${THIRD_PARTY_DIR}/date") +set(DATE_INCLUDE_DIR "${DATE_SOURCE_DIR}/include") +ExternalProject_Add( + date + PREFIX ${DATE_PREFIX} + SOURCE_DIR ${DATE_SOURCE_DIR} + GIT_REPOSITORY https://github.com/HowardHinnant/date.git + GIT_TAG e7e1482087f58913b80a20b04d5c58d9d6d90155 + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory ${DATE_INCLUDE_DIR}/ ${THIRD_PARTY_INCLUDE_DIR}/ +) + +add_dependencies(libshogun date) \ No newline at end of file diff --git a/src/interfaces/swig/IO.i b/src/interfaces/swig/IO.i index ffc058f5e8a..fc6d4b6146b 100644 --- a/src/interfaces/swig/IO.i +++ b/src/interfaces/swig/IO.i @@ -79,6 +79,7 @@ %shared_ptr(shogun::StreamingFileFromDenseFeatures) #endif +%include %include %include %include diff --git a/src/interfaces/swig/IO_includes.i b/src/interfaces/swig/IO_includes.i index d1d1fcbc8c8..afd6170a471 100644 --- a/src/interfaces/swig/IO_includes.i +++ b/src/interfaces/swig/IO_includes.i @@ -1,4 +1,5 @@ %{ +#include #include #include #include diff --git a/src/shogun/CMakeLists.txt b/src/shogun/CMakeLists.txt index 78a40e236a9..59a1c0b32df 100644 --- a/src/shogun/CMakeLists.txt +++ b/src/shogun/CMakeLists.txt @@ -522,6 +522,12 @@ SHOGUN_DEPENDENCIES( SCOPE PRIVATE CONFIG_FLAG HAVE_ARPREC) +# Date external lib +include(external/date) +SHOGUN_INCLUDE_DIRS(SCOPE PUBLIC + $ + $) + ############################ HMM OPTION(USE_HMMDEBUG "HMM debug mode" OFF) OPTION(USE_HMMCACHE "HMM cache" ON) diff --git a/src/shogun/io/ARFFFile.cpp b/src/shogun/io/ARFFFile.cpp new file mode 100644 index 00000000000..542ffef436e --- /dev/null +++ b/src/shogun/io/ARFFFile.cpp @@ -0,0 +1,729 @@ +/* + * This software is distributed under BSD 3-clause license (see 
LICENSE file). + * + * Authors: Gil Hoben + */ + +#include +#include +#include + +#include + +using namespace shogun; +using namespace shogun::arff_detail; + +/** + * Visitor pattern to reserve memory for a std::vector + * wrapped in a variant class. + */ +struct VectorResizeVisitor +{ + VectorResizeVisitor(size_t size) : m_size(size){}; + template + void operator()(std::vector& v) const noexcept + { + v.reserve(m_size); + } + size_t m_size; +}; + +/** + * Visitor pattern to determine size of a std::vector + * wrapped in a variant class. + */ +struct VectorSizeVisitor +{ + template + size_t operator()(const std::vector& v) const noexcept + { + return v.size(); + } +}; + +template +T buffer_to_type(const std::string& buffer) +{ + error( + "No conversion from {} to {}!\n", buffer.c_str(), + demangled_type()); +} +template <> +int8_t buffer_to_type(const std::string& buffer) +{ + return static_cast(std::stoi(buffer)); +} +template <> +int16_t buffer_to_type(const std::string& buffer) +{ + return static_cast(std::stoi(buffer)); +} +template <> +int32_t buffer_to_type(const std::string& buffer) +{ + return std::stoi(buffer); +} +template <> +int64_t buffer_to_type(const std::string& buffer) +{ + return std::stoll(buffer); +} +template <> +float32_t buffer_to_type(const std::string& buffer) +{ + return std::stof(buffer); +} +template <> +float64_t buffer_to_type(const std::string& buffer) +{ + return std::stod(buffer); +} +template <> +floatmax_t buffer_to_type(const std::string& buffer) +{ + return std::stold(buffer); +} + +template +void ARFFDeserializer::read_helper() +{ + std::vector, std::vector>>> + data_vectors; + m_line_number = 0; + m_row_count = 0; + m_file_done = false; + auto read_comment = [this]() { + if (to_lower(m_current_line.substr(0, 1)) == m_comment_string) + m_comments.push_back(m_current_line.substr(1, std::string::npos)); + else if ( + to_lower(m_current_line.substr(0, m_relation_string.size())) == + m_relation_string) + m_state = true; + }; + auto 
check_comment = []() { return true; }; + process_chunk(read_comment, check_comment, false); + + auto read_relation = [this]() { + if (to_lower(m_current_line.substr(0, m_relation_string.size())) == + m_relation_string) + m_relation = remove_whitespace( + m_current_line.substr(m_relation_string.size())); + else if ( + to_lower(m_current_line.substr(0, m_attribute_string.size())) == + m_attribute_string) + m_state = true; + }; + // a relation has to be defined + auto check_relation = [this]() { return !m_relation.empty(); }; + process_chunk(read_relation, check_relation, true); + + // parse the @attributes section + auto read_attributes = [this, &data_vectors]() { + if (to_lower(m_current_line.substr(0, m_attribute_string.size())) == + m_attribute_string) + { + std::string name, type; + auto inner_string = + m_current_line.substr(m_attribute_string.size()); + left_trim(inner_string, [](const auto& val) { + return !std::isspace(val); + }); + auto it = inner_string.begin(); + if (is_part_of(*it, "\"\'")) + { + auto quote_type = *it; + ++it; + auto begin = it; + while (*it != quote_type && it != inner_string.end()) + ++it; + if (it == inner_string.end()) + error( + "Encountered unbalanced parenthesis in attribute " + "declaration on line {}: \"{}\"\n", + m_line_number, m_current_line); + name = {begin, it}; + type = trim({std::next(it), inner_string.end()}); + } + else + { + auto begin = it; + while (!std::isspace(*it)) + ++it; + if (it == inner_string.end() && it != inner_string.end()) + error( + "Expected at least two elements in attribute " + "declaration on line {}: \"{}\"", + m_line_number, m_current_line); + name = {begin, it}; + type = trim({std::next(it), inner_string.end()}); + } + + SG_DEBUG("name: {}\n", name); + SG_DEBUG("type: {}\n", type); + + if (name.empty() || type.empty()) + error( + "Could not find the name and type on line {}: \"{}\".\n", + m_line_number, m_current_line); + if (it == inner_string.end()) + error( + "Could not split attibute name and 
type on line {}: " + "\"{}\".\n", + m_line_number, m_current_line); + + // check if it is nominal + if (type[0] == '{') + { + // @ATTRIBUTE class {Iris-setosa,Iris-versicolor,Iris-virginica} + std::vector attributes; + // split nominal values: "{A, B, C}" to vector{A, B, C} + split( + type.substr(1, type.size() - 2), ", ", + std::back_inserter(attributes), "\'\""); + auto processed_name = trim(name, [](const auto& val) { + return !std::isspace(val) && val != '\'' && val != '\"'; + }); + m_attribute_names.emplace_back(processed_name); + m_nominal_attributes.emplace_back( + std::make_pair(name, attributes)); + m_attributes.push_back(Attribute::NOMINAL); + data_vectors.emplace_back(std::vector{}); + return; + } + + auto is_date = type.find("date") != std::string::npos; + if (is_date) + { + std::vector date_elements; + // split "date [[date-format]]" or "name date [[date-format]]" + split(type, " ", std::back_inserter(date_elements), "\""); + if (date_elements[0] == "date" && date_elements.size() < 3) + { + // @attribute date [[date-format]] + if (type.size() == 1) + m_date_formats.emplace_back(m_default_date_format); + else + m_date_formats.push_back( + javatime_to_cpptime(date_elements[1])); + name = ""; + } + else if (date_elements[1] == "date" && date_elements.size() < 4) + { + // @attribute name date [[date-format]] + if (date_elements.size() == 2) + m_date_formats.emplace_back(m_default_date_format); + else + m_date_formats.push_back( + javatime_to_cpptime(date_elements[2])); + } + else + { + error( + "Error parsing date on line {}: {}\n", m_line_number, + m_current_line); + } + m_attributes.push_back(Attribute::DATE); + data_vectors.emplace_back(std::vector{}); + } + else if (is_primitive_type(type)) + { + type = to_lower(type); + // numeric attributes + if (type == "numeric") + { + m_attributes.push_back(Attribute::NUMERIC); + data_vectors.emplace_back(std::vector{}); + } + else if (type == "integer") + { + m_attributes.push_back(Attribute::INTEGER); + 
data_vectors.emplace_back(std::vector{}); + } + else if (type == "real") + { + m_attributes.push_back(Attribute::REAL); + data_vectors.emplace_back(std::vector{}); + } + else if (type == "string") + { + // @ATTRIBUTE LCC string + m_attributes.push_back(Attribute::STRING); + data_vectors.emplace_back( + std::vector>{}); + } + else + error( + "Unexpected attribute type identifier \"{}\" " + "on line {}: {}\n", + type, m_line_number, m_current_line); + } + else + error( + "Unexpected format in @ATTRIBUTE on line {}: {}\n", + m_line_number, m_current_line); + auto processed_name = trim(name, [](const auto& val) { + return !std::isspace(val) && val != '\'' && val != '\"'; + }); + m_attribute_names.emplace_back(processed_name); + } + // comments in this section are ignored + else if (m_current_line.substr(0, 1) == m_comment_string) + { + } + else if ( + to_lower(m_current_line.substr(0, m_data_string.size())) == + m_data_string) + m_state = true; + }; + + auto check_attributes = [this]() { + // attributes cannot be empty + return !m_attributes.empty(); + }; + process_chunk(read_attributes, check_attributes, true); + + // estimate the size of the @data section + auto pos = m_stream->tellg(); + auto approx_data_line_count = std::count( + std::istreambuf_iterator(*m_stream), + std::istreambuf_iterator(), '\n'); + reserve_vector_memory(approx_data_line_count, data_vectors); + m_stream->seekg(pos); + + std::vector> elems; + elems.reserve(m_attributes.size()); + + // read the @data section + auto read_data = [this, &data_vectors, &elems]() { + // it's a comment and can be skipped + if (m_current_line.substr(0, 1) == m_comment_string) + return; + // it's the data string (i.e. 
@data"), does not provide information + if (to_lower(m_current_line.substr(0, m_data_string.size())) == + m_data_string) + return; + + // assumes that until EOF we should expect comma delimited values + elems.clear(); + split(m_current_line, ",", std::back_inserter(elems), "\'\""); + if (elems.size() != m_attributes.size()) + error( + "Unexpected number of values on line {}, expected {} " + "values, but found {}.\n", + m_line_number, m_attributes.size(), elems.size()); + // only parse rows that do not contain missing values + if (std::find(elems.begin(), elems.end(), m_missing_value_string) == + elems.end()) + { + auto nominal_pos = m_nominal_attributes.begin(); + auto date_pos = m_date_formats.begin(); + for (int i = 0; i < elems.size(); ++i) + { + Attribute type = m_attributes[i]; + switch (type) + { + case (Attribute::NUMERIC): + case (Attribute::INTEGER): + case (Attribute::REAL): + { + try + { + std::get>(data_vectors[i]) + .push_back(buffer_to_type(elems[i])); + } + catch (const std::invalid_argument&) + { + error( + "Failed to covert \"{}\" to numeric on line %d.\n", + elems[i], m_line_number); + } + } + break; + case (Attribute::NOMINAL): + { + if (nominal_pos == m_nominal_attributes.end()) + error( + "Unexpected nominal value \"{}\" on line {}\n", + elems[i].c_str(), m_line_number); + auto encoding = (*nominal_pos).second; + auto trimmed_el = trim(elems[i]); + remove_char_inplace(trimmed_el, '\''); + auto pos = + std::find(encoding.begin(), encoding.end(), trimmed_el); + if (pos == encoding.end()) + error( + "Unexpected value \"{}\" on line %d\n", + trimmed_el, m_line_number); + ScalarType idx = std::distance(encoding.begin(), pos); + std::get>(data_vectors[i]) + .push_back(idx); + ++nominal_pos; + } + break; + case (Attribute::DATE): + { + date::sys_seconds t; + std::istringstream ss(elems[i]); + if (date_pos == m_date_formats.end()) + error( + "Unexpected date value \"{}\" on line {}.\n", + elems[i], m_line_number); + ss >> date::parse(*date_pos, t); + 
if (bool(ss)) + { + auto value_timestamp = t.time_since_epoch().count(); + std::get>(data_vectors[i]) + .push_back(value_timestamp); + } + else + error( + "Error parsing date \"{}\" with date format \"{}\" " + "on line {}.\n", + elems[i], *date_pos, + m_line_number); + ++date_pos; + } + break; + case (Attribute::STRING): + std::get>>( + data_vectors[i]) + .emplace_back(elems[i]); + } + } + ++m_row_count; + } + }; + auto check_data = [&data_vectors]() { + if (!data_vectors.empty()) + { + auto feature_count = data_vectors.size(); + size_t row_count = + std::visit(VectorSizeVisitor{}, data_vectors[0]); + for (int i = 1; i < feature_count; ++i) + { + require( + std::visit(VectorSizeVisitor{}, data_vectors[i]) == + row_count, + "All columns must have the same number of features!\n"); + } + } + else + return false; + return true; + }; + process_chunk(read_data, check_data, true); + + // transform data into a feature object + index_t row_count = std::visit(VectorSizeVisitor{}, data_vectors[0]); + for (int i = 0; i < data_vectors.size(); ++i) + { + Attribute att = m_attributes[i]; + auto vec = data_vectors[i]; + switch (att) + { + case Attribute::NUMERIC: + case Attribute::INTEGER: + case Attribute::REAL: + case Attribute::DATE: + case Attribute::NOMINAL: + { + auto casted_vec = std::get>(vec); + SGMatrix mat(1, row_count); + memcpy( + mat.matrix, casted_vec.data(), + casted_vec.size() * sizeof(ScalarType)); + m_features.push_back(std::make_shared>(mat)); + } + break; + case Attribute::STRING: + { + auto casted_vec = + std::get>>(vec); + index_t max_string_length = 0; + for (const auto& el : casted_vec) + { + if (max_string_length < el.size()) + max_string_length = el.size(); + } + std::vector> strings(row_count, max_string_length); + for (int j = 0; j < row_count; ++j) + { + SGVector current(max_string_length); + memcpy( + current.vector, casted_vec[j].data(), + (casted_vec.size() + 1) * sizeof(CharType)); + strings[j] = current; + } + m_features.push_back( + 
std::make_shared>(strings, EAlphabet::RAWBYTE)); + } + } + } +} + +template +void ARFFDeserializer::read_string_dispatcher() +{ + switch (m_string_primitive_type) + { + case EPrimitiveType::PT_UINT8: + { + read_helper(); + } + break; + case EPrimitiveType::PT_UINT16: + { + error("16-bit wide string conversion not available."); + } + break; + default: + error("The provided type for string parsing is not valid!\n"); + } +} + +void ARFFDeserializer::read() +{ + switch (m_primitive_type) + { + case EPrimitiveType::PT_INT8: + { + read_string_dispatcher(); + } + break; + case EPrimitiveType::PT_INT16: + { + read_string_dispatcher(); + } + break; + case EPrimitiveType::PT_INT32: + { + read_string_dispatcher(); + } + break; + case EPrimitiveType::PT_INT64: + { + read_string_dispatcher(); + } + break; + case EPrimitiveType::PT_FLOAT32: + { + read_string_dispatcher(); + } + break; + case EPrimitiveType::PT_FLOAT64: + { + read_string_dispatcher(); + } + break; + case EPrimitiveType::PT_FLOATMAX: + { + read_string_dispatcher(); + } + break; + default: + error("The provided type for scalar parsing is not valid!\n"); + } +} + +template +void ARFFDeserializer::reserve_vector_memory( + size_t line_count, + std::vector, std::vector>>>& v) +{ + VectorResizeVisitor visitor{line_count}; + for (auto& vec : v) + std::visit(visitor, vec); +} + +/** + * Very type unsafe, but no UB! 
+ * @param obj + * @return + */ +std::vector features_to_string(const std::shared_ptr& obj, Attribute att) +{ + std::vector result_string; + switch (att) + { + case Attribute::NUMERIC: + case Attribute::REAL: + { + auto mat_to_string = [&result_string](const auto& mat) { + result_string.reserve(mat.size()); + for (int i = 0; i < mat.size(); ++i) + { + result_string.push_back(std::to_string(mat[i])); + } + }; + + for (const auto& param : obj->get_params()) + if (param.first == "feature_matrix") + { + sg_any_dispatch( + param.second->get_value(), sg_matrix_typemap, + shogun::None{}, shogun::None{}, mat_to_string); + return result_string; + } + } + break; + case Attribute::INTEGER: + { + auto mat_to_string = [&result_string](const auto& mat) { + result_string.reserve(mat.size()); + for (int i = 0; i < mat.size(); ++i) + { + result_string.push_back( + std::to_string(static_cast(mat[i]))); + } + }; + + for (const auto& param : obj->get_params()) + if (param.first == "feature_matrix") + { + sg_any_dispatch( + param.second->get_value(), sg_matrix_typemap, + shogun::None{}, shogun::None{}, mat_to_string); + return result_string; + } + } + break; + default: + error("Unsupported type: {}\n", static_cast(att)); + } + error("The provided feature object does not have a feature matrix!\n"); + return std::vector{}; +} + +std::vector features_to_string( + const std::shared_ptr& obj, const std::vector& nominal_values) +{ + std::vector result_string; + auto mat_to_string = [&result_string, &nominal_values](const auto& mat) { + result_string.reserve(mat.size()); + for (int i = 0; i < mat.size(); ++i) + { + result_string.emplace_back( + "\"" + nominal_values[static_cast(mat[i])] + "\""); + } + }; + + for (const auto& param : obj->get_params()) + if (param.first == "feature_matrix") + { + sg_any_dispatch( + param.second->get_value(), sg_matrix_typemap, shogun::None{}, + shogun::None{}, mat_to_string); + return result_string; + } + error("The provided feature object does not have a 
feature matrix!\n"); + return std::vector{}; +} + +std::unique_ptr ARFFSerializer::write() +{ + auto ss = std::make_unique(); + + // @relation + *ss << ARFFDeserializer::m_relation_string << " " << m_name << "\n\n"; + + // @attribute + for (const auto& att : m_attributes) + { + switch (att.second) + { + case Attribute::NUMERIC: + *ss << ARFFDeserializer::m_attribute_string << " " << att.first + << " numeric\n"; + break; + case Attribute::INTEGER: + *ss << ARFFDeserializer::m_attribute_string << " " << att.first + << " integer\n"; + break; + case Attribute::REAL: + *ss << ARFFDeserializer::m_attribute_string << " " << att.first + << " real\n"; + break; + case Attribute::STRING: + *ss << ARFFDeserializer::m_attribute_string << " " << att.first + << " string\n"; + break; + case Attribute::DATE: + error("C++ to Java date format conversion is not implement!"); + break; + case Attribute::NOMINAL: + { + *ss << ARFFDeserializer::m_attribute_string << " " << att.first + << " "; + auto nominal_values_vector = m_nominal_mapping.at(att.first); + std::string nominal_values_string = std::accumulate( + nominal_values_vector.begin(), nominal_values_vector.end(), + "{\"" + nominal_values_vector[0] + "\"", + [](std::string& lhs, const std::string& rhs) { + return lhs += ",\"" + rhs + "\""; + }); + nominal_values_string.append("}\n"); + *ss << nominal_values_string; + } + } + } + + // @data + *ss << "\n" << ARFFDeserializer::m_data_string << "\n\n"; + + auto num_vectors = m_feature_list.back()->as()->get_num_vectors(); + std::vector> result; + auto att_iter = m_attributes.begin(); + + for (const auto& feature: m_feature_list) + { + auto n_i = feature->as()->get_num_vectors(); + require( + n_i == num_vectors, + "Expected all features to have the same number of examples!\n"); + + switch (att_iter->second) + { + case Attribute::NUMERIC: + case Attribute::REAL: + case Attribute::INTEGER: + result.push_back(features_to_string(feature, att_iter->second)); + break; + case 
Attribute::NOMINAL: + result.push_back( + features_to_string(feature, m_nominal_mapping.at(att_iter->first))); + break; + case Attribute::DATE: + case Attribute::STRING: + error("Writing out strings and dates has not been implemented!"); + } + ++att_iter; + } + + std::vector result_rows(num_vectors); + + for (size_t col = 0; col != result.size(); ++col) + { + if (col != result.size() - 1) + for (auto row = 0; row != num_vectors; ++row) + result_rows[row].append(result[col][row] + ","); + else + for (auto row = 0; row != num_vectors; ++row) + result_rows[row].append(result[col][row] + "\n"); + } + + for (const auto& row : result_rows) + *ss << row; + + return ss; +} + +void ARFFSerializer::write(const std::string& filename) +{ + auto result = write(); + std::ofstream myfile; + myfile.open(filename); + myfile << result->str(); + myfile.close(); +} diff --git a/src/shogun/io/ARFFFile.h b/src/shogun/io/ARFFFile.h new file mode 100644 index 00000000000..622ea3fca2c --- /dev/null +++ b/src/shogun/io/ARFFFile.h @@ -0,0 +1,722 @@ +/* + * This software is distributed under BSD 3-clause license (see LICENSE file). 
+ * + * Authors: Gil Hoben + */ + +#ifndef SHOGUN_ARFFFILE_H +#define SHOGUN_ARFFFILE_H + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace shogun +{ +#ifndef SWIG + /** Contains miscellaneous string manipulation functions using the STL and + * Java to C++ date format utilities */ + namespace arff_detail + { + /** + * Checks if string is blank + * @param line to check + * @return bool whether line is empty + */ + SG_FORCED_INLINE bool is_blank(const std::string& line) + { + return line.find_first_not_of(" \t\r\f\v") == std::string::npos; + } + + /** + * Checks if the character in the lhs is in the rhs + * @param lhs the single character to find + * @param rhs a string with the characters to compare against the lhs + * @return whether the character of the lhs is in the rhs + */ + SG_FORCED_INLINE bool is_part_of(char lhs, const std::string& rhs) + { + auto result = rhs.find(lhs); + return result != std::string::npos; + } + + /** + * Trims to the left of a string in place according to trim_func + * @param s the string to trim + * @param trim_func a unary function that determines what values to + * erase + */ + template + SG_FORCED_INLINE void left_trim(std::string& s, FunctorT trim_func) + { + s.erase(s.begin(), std::find_if(s.begin(), s.end(), trim_func)); + } + + /** + * Trims to the right of a string in place according to trim_func + * @param s the string to trim + * @param trim_func a unary function that determines what values to + * erase + */ + template + SG_FORCED_INLINE void right_trim(std::string& s, FunctorT trim_func) + { + s.erase( + std::find_if(s.rbegin(), s.rend(), trim_func).base(), s.end()); + } + + const auto lambda_is_space = [](const auto& val) { + return !std::isspace(val); + }; + /** + * Returns the string trimmed to the left and right according to + * trim_func. 
By default this function trims whitespaces + * @param s the string to trim + * @param trim_func a unary function that determines what values to + * erase + */ + template + SG_FORCED_INLINE std::string + trim(std::string line, FunctorT trim_func = lambda_is_space) + { + left_trim(line, trim_func); + right_trim(line, trim_func); + return line; + } + + /** + * Splits a line given a set of delimiter characters + * + * @param s string to split + * @param delimiters a set of delimiter character + * @param result dynamic container inserter where tokens are stored + * @param quotes a string with the characters that are considered + * quotes, i.e. any text between quotes is kept together. + */ + template + SG_FORCED_INLINE void split( + const std::basic_string& s, const std::string& delimiters, + std::back_insert_iterator>> result, + const std::string& quotes) + { + auto it = s.begin(); + auto begin = s.begin(); + while (is_part_of(*it, delimiters)) + { + ++it; + begin = it; + } + while (it != s.end()) + { + if (is_part_of(*it, delimiters)) + { + } + else if (is_part_of(*it, quotes)) + { + auto quote_type = *it; + ++it; + begin = it; + while (*it != quote_type) + { + ++it; + } + if (it == s.end()) + error( + "Encountered unbalanced parenthesis in \"{}\"\n", + std::string(std::prev(begin), it)); + *(result++) = {begin, it}; + } + else + { + begin = it; + while (!is_part_of(*it, delimiters) && it != s.end()) + { + ++it; + } + std::basic_string token{begin, it}; + if (!is_blank(token)) + *(result++) = token; + } + if (it != s.end()) + { + ++it; + begin = it; + } + } + } + + /** + * Returns a string in lowercase. 
+ * + * @param line string to process + * @return lowercase string + */ + SG_FORCED_INLINE std::string to_lower(const std::string& line) + { + std::string result; + std::transform( + line.begin(), line.end(), std::back_inserter(result), + [](const auto& val) { return std::tolower(val); }); + return result; + } + + /** + * Returns string without whitespace + * @param line string to process + * @return string without whitespace + */ + SG_FORCED_INLINE std::string remove_whitespace(const std::string& line) + { + std::string result = line; + result.erase( + std::remove_if(result.begin(), result.end(), ::isspace), + result.end()); + return result; + } + + /** + * Removes all occurences of a character in place + * @param line string to process + * @param character char to remove + */ + SG_FORCED_INLINE void + remove_char_inplace(std::string& line, char character) + { + line.erase( + std::remove_if( + line.begin(), line.end(), + [&character](auto const& val) { return val == character; }), + line.end()); + } + + /** + * Java to C++ time format token converter. 
+ * + * Java tokens taken from: + * http://tutorials.jenkov.com/java-date-time/parsing-formatting-dates.html + * C++ tokens taken from: + * https://www.ibm.com/support/knowledgecenter/en/ssw_ibm_i_71/rtref/strpti.htm + * @param java_token + * @return C++ equivalent + */ + SG_FORCED_INLINE const char* + process_javatoken(const std::string& java_token) + { + if (java_token == "yy") + return "%y"; + if (java_token == "yyyy") + return "%Y"; + if (java_token == "MM") + return "%m"; + if (java_token == "dd") + return "%d"; + if (java_token == "hh") + return "%I"; + if (java_token == "HH") + return "%H"; + if (java_token == "mm") + return "%M"; + if (java_token == "ss") + return "%S"; + if (java_token == "Z") + return "%z"; + if (java_token == "z") + error( + "Timezone abbreviations are currently not supported.\n"); + if (java_token.empty()) + return ""; + if (java_token == "SSS") + return nullptr; + return nullptr; + } + SG_FORCED_INLINE const char* process_javatoken(char java_token) + { + if (java_token == ':') + return ":"; + if (java_token == '\'') + return ""; + if (java_token == '-') + return "-"; + if (java_token == ' ') + return " "; + return nullptr; + } + + /** + * Checks if a Java token is valid and returns string representing the + * C++ token. 
+ * @param java_time_token token to check and translate + * @return translated C++ token + */ + SG_FORCED_INLINE const char* + check_and_append_j2cpp(const std::string& java_time_token) + { + if (auto cpp_token = process_javatoken(java_time_token)) + return cpp_token; + else + error( + "Could not convert Java time token \"{}\" to C++ time " + "token.\n", + java_time_token); + return nullptr; + } + + SG_FORCED_INLINE const char* + check_and_append_j2cpp(char java_time_token) + { + if (auto cpp_token = process_javatoken(java_time_token)) + return cpp_token; + else + error( + "Could not convert Java time token \"{}\" to C++ time " + "token.\n", + java_time_token); + return nullptr; + } + + /** + * Converts a Java SimpleDateFormat to a C++ date format + * @param java_time the string to translate + * @return the C++ format equivalent of java_time + */ + SG_FORCED_INLINE std::string + javatime_to_cpptime(const std::string& java_time) + { + std::string cpp_time; + std::string token; + auto begin = java_time.begin(); + auto it = java_time.begin(); + while (it != java_time.end()) + { + if (*it == '-' || *it == ' ' || *it == ':') + { + token = {begin, it}; + cpp_time.append(check_and_append_j2cpp(token)); + cpp_time.append(check_and_append_j2cpp(*it)); + begin = std::next(it); + } + else if (*it == '\'') + { + token = {begin, it}; + cpp_time.append(check_and_append_j2cpp(token)); + cpp_time.append(check_and_append_j2cpp(*it)); + begin = std::next(it); + begin = it; + ++it; + while (*it != '\'') + { + ++it; + } + token = {std::next(begin), it}; + cpp_time.append(token); + begin = std::next(it); + } + else if (std::next(it) == java_time.end()) + { + token = {begin, std::next(it)}; + if (auto cpp_token = process_javatoken(token)) + { + cpp_time.append(cpp_token); + } + else + error( + "Could not convert Java time token {} to C++ time " + "token.\n", + token); + } + ++it; + } + return cpp_time; + } + } // namespace arff_detail +#endif // SWIG + + /** + * The attributes supported 
in the ARFF format + */ + enum class Attribute + { + NUMERIC = 0, + INTEGER = 1, + REAL = 2, + STRING = 3, + DATE = 4, + NOMINAL = 5 + }; + + class ARFFSerializer; + + /** + * ARFFDeserializer parses files in the ARFF format. + * For information about this format see + * https://waikato.github.io/weka-wiki/arff_stable/ + */ + class ARFFDeserializer + { + public: + friend class ARFFSerializer; + /** + * ARFFDeserializer constructor with a filename. + * Performs a check to see if a file can be streamed. + * Fails if file does not exist, or it cannot be opened, + * i.e. not the correct permission. + * + * @param filename the name of the file to parse + * @param primitive_type the type to parse the scalars in, i.e. numeric + * attributes + */ + explicit ARFFDeserializer( + const std::string& filename, + EPrimitiveType primitive_type = PT_FLOAT64, + EPrimitiveType string_primitive_type = PT_UINT8) + : m_primitive_type(primitive_type), + m_string_primitive_type(string_primitive_type) + { + auto* file_stream = new std::ifstream(filename); + if (file_stream->fail()) + { + error( + "Cannot open {}. Please check if file exists and if you " + "have the right permissions to open it.\n", + filename); + } + m_stream = std::unique_ptr(file_stream); + } +#ifndef SWIG + /** + * ARFFDeserializer constructor with an input stream. + * This constructors copies the stream and takes care + * of proper deletion. + * + * @param stream the input stream + * @param primitive_type the type to parse the scalars in, i.e. numeric + * attributes + */ + explicit ARFFDeserializer( + std::shared_ptr& stream, + EPrimitiveType primitive_type = PT_FLOAT64, + EPrimitiveType string_primitive_type = PT_UINT8) + : m_stream(stream), m_primitive_type(primitive_type), + m_string_primitive_type(string_primitive_type) + { + } +#endif // SWIG + /** + * Parse the file. 
+ */ + void read(); + + /** + * Returns string parsed in @relation line + * @return the relation string + */ + std::string get_relation() const noexcept + { + return m_relation; + } + + /** + * Returns the name of the features parsed in "@attribute" + * @return the relation string + */ + std::vector get_feature_names() const noexcept + { + return m_attribute_names; + } + + /** + * Get list of features from parsed data. The label name indicates the + * column to be excluded, i.e. it's a label and not a feature. + * @return a list of features + */ + std::vector> get_features(const std::string& label_name) const + { + auto find_label = std::find( + m_attribute_names.begin(), m_attribute_names.end(), label_name); + if (find_label == m_attribute_names.end()) + error( + "The provided label \"{}\" was not found!\n", + label_name); + + std::vector> result; + + int idx = 0; + int label_idx = + std::distance(m_attribute_names.begin(), find_label); + for (const auto& feat : m_features) + { + if (idx != label_idx) + result.push_back(feat); + ++idx; + } + + return result; + } + + /** + * Get list of features from parsed data. + * @return a list of features + */ + std::vector> get_features() const + { + return m_features; + } + + /** + * Get feature by name. + * @return the requested feature if it exists. + */ + std::shared_ptr get_feature(const std::string& feature_name) const + { + auto find_feature = std::find( + m_attribute_names.begin(), m_attribute_names.end(), + feature_name); + if (find_feature == m_attribute_names.end()) + error( + "The provided label \"{}\" was not found!\n", + feature_name); + int feature_idx = + std::distance(m_attribute_names.begin(), find_feature); + auto result = m_features[feature_idx]; + return result; + } + + /** + * Get ARFF attribute types. + */ + std::vector get_attribute_types() const noexcept + { + return m_attributes; + } + + /** + * Returns the nominal values in the order of encoding. 
+ * For example for an ARFF file with "@ATTRIBUTE class + * {Iris-setosa,Iris-versicolor,Iris-virginica}" + * get_nominal_values("class") returns the vector + * {"Iris-setosa","Iris-versicolor","Iris-virginica"} + * @return nominal values + */ + std::vector + get_nominal_values(const std::string& feature_name) const + { + + for (const auto& nom_att : m_nominal_attributes) + { + if (nom_att.first == feature_name) + return nom_att.second; + } + error("The provided feature name is not a nominal feature!\n"); + return std::vector{}; + } + + protected: + /** character used in file to comment out a line */ + static constexpr std::string_view m_comment_string = "%"; + /** characters to declare relations, i.e. @relation */ + static constexpr std::string_view m_relation_string = "@relation"; + /** characters to declare attributes, i.e. @attribute */ + static constexpr std::string_view m_attribute_string = "@attribute"; + /** characters to declare data fields, i.e. @data */ + static constexpr std::string_view m_data_string = "@data"; + /** the default C++ date format specified by the ARFF standard */ + static constexpr std::string_view m_default_date_format = "%Y-%M-%DT%H:%M:%S"; + /** missing data */ + static constexpr std::string_view m_missing_value_string = "?"; + + private: + /** + * Templated parser helper for string container primitive type. + * + */ + template + void read_string_dispatcher(); + + /** + * Templated parser helper. + * + */ + template + void read_helper(); + + /** + * Processes a chunk. A chunk is defined as a set of lines that + * are processed in the same way. A chunk ends when the func + * sets the internal m_state to false. + * Parsing can also end when the stream reaches EOF. 
+ * + * @tparam LambdaT type of processing function + * @tparam CheckT type of check function + * @param func processing function that reads each line + * @param check_func function that checks the result from the processing + * function + * @param skip_first whether to stream the first line + */ + template + void process_chunk(LambdaT&& func, CheckT&& check_func, bool skip_first) + { + m_state = false; + + if (skip_first && !m_stream->eof()) + func(); + + while (!m_state && !m_file_done) + { + consume_line(func); + } + if (!check_func()) + { + error( + "Parsing error on line {}: {}\n", m_line_number, + m_current_line); + } + } + + /** + * Function called by process_chunk to process a "chunk" line by line. + * This function also checks if EOF has been reached. + * + * @tparam T type of processing function + * @param func line processing function + */ + template + void consume_line(T&& func) + { + if (m_stream->eof()) + { + m_file_done = true; + return; + } + std::getline(*m_stream, m_current_line); + m_line_number++; + if (!arff_detail::is_blank(m_current_line)) + func(); + } +#ifndef SWIG + /** + * Checks if a token represented by a string + * denotes a primitive type in the ARFF format + * @param token the token to be checked + * @return whether the token denotes a primitive type + */ + SG_FORCED_INLINE bool is_primitive_type(const std::string& token) const + noexcept + { + return token.find_first_of("numeric") != std::string::npos || + token.find_first_of("integer") != std::string::npos || + token.find_first_of("real") != std::string::npos || + token.find_first_of("string") != std::string::npos; + } +#endif // SWIG + template + void reserve_vector_memory( + size_t line_count, + std::vector, + std::vector>>>& v); + + /** the name of the attributes */ + std::vector m_attribute_names; + + /** the input stream */ + std::shared_ptr m_stream; + /** the scalar type used for parsing */ + EPrimitiveType m_primitive_type; + /** the string underlying type used for parsing 
*/
+        EPrimitiveType m_string_primitive_type;
+
+        /** number of lines read so far; reported in parsing error messages */
+        size_t m_line_number;
+        /** set to true once no further line can be read from the stream */
+        bool m_file_done;
+        /** chunk flag: process_chunk resets it to false and keeps consuming lines until it becomes true */
+        bool m_state;
+        /** current row count of data */
+        size_t m_row_count;
+        /** the relation name, i.e. the string after m_relation_string */
+        std::string m_relation;
+        /** the list in which comments are stored */
+        std::vector m_comments;
+        /** the string representing the current line being parsed */
+        std::string m_current_line;
+        /** the attribute types in the order they are parsed */
+        std::vector m_attributes;
+        /** stores the date formats */
+        std::vector m_date_formats;
+        /** (attribute name, nominal values) entries; searched by name in get_nominal_values */
+        std::vector>>
+            m_nominal_attributes;
+
+        /** the parsed features; indexed in parallel with m_attribute_names */
+        std::vector> m_features;
+    };
+
+    /**
+     * ARFFSerializer writes out files in the ARFF format.
+     * For information about this format see
+     * https://waikato.github.io/weka-wiki/arff_stable/
+     */
+    class ARFFSerializer
+    {
+    public:
+        /**
+         * The ARFFSerializer constructor. All arguments are stored in the
+         * serializer; feature_list is taken by value and moved in.
+         *
+         * @param name the name of the dataset
+         * @param feature_list a list with individual features
+         * @param attributes a list associating the feature names with the
+         * ARFF type the features translate to
+         * @param nominal_mapping a mapping of nominal features to a vector of
+         * strings whose index will be used to infer the nominal value
+         */
+        ARFFSerializer(
+            const std::string& name, std::vector> feature_list,
+            const std::vector>& attributes,
+            const std::unordered_map>&
+                nominal_mapping)
+            : m_name(name), m_attributes(attributes),
+              m_nominal_mapping(nominal_mapping)
+        {
+            m_feature_list = std::move(feature_list);
+        }
+
+#ifndef SWIG
+        /**
+         * Writes out features to an output stream.
+         * Not exposed to SWIG (guarded by the surrounding #ifndef).
+         * @return the output stream
+         */
+        std::unique_ptr write();
+#endif // SWIG
+
+        /**
+         * Writes out features with the provided information
+         * used in the constructor.
+         *
+         * @param filename the file to write to
+         */
+        void write(const std::string& filename);
+
+    private:
+        /** the name of the dataset, as passed to the constructor */
+        std::string m_name;
+        /** the list of features to write out */
+        std::vector> m_feature_list;
+        /** the attributes, as passed to the constructor */
+        std::vector> m_attributes;
+        /** the nominal attributes, if any */
+        std::unordered_map>
+            m_nominal_mapping;
+    };
+
+} // namespace shogun
+
+#endif // SHOGUN_ARFFFILE_H
diff --git a/tests/unit/io/ARFFFile_unittest.cc b/tests/unit/io/ARFFFile_unittest.cc
new file mode 100644
index 00000000000..0ec17d59791
--- /dev/null
+++ b/tests/unit/io/ARFFFile_unittest.cc
@@ -0,0 +1,253 @@
+/*
+ * This software is distributed under BSD 3-clause license (see LICENSE file).
+ *
+ * Authors: Gil Hoben
+ */
+
+#include
+#include
+
+#include
+
+using namespace shogun;
+
+// Tolerance used by the EXPECT_NEAR checks below: the std::numeric_limits epsilon
+template
+constexpr T get_epsilon()
+{
+    return std::numeric_limits::epsilon();
+}
+
+// Maps a C++ type to its EPrimitiveType tag; this primary template is the PT_UNDEFINED fallback
+template
+constexpr EPrimitiveType convert_type_to_enum()
+{
+    return EPrimitiveType::PT_UNDEFINED;
+}
+
+template <>
+constexpr EPrimitiveType convert_type_to_enum()
+{
+    return EPrimitiveType::PT_INT8;
+}
+
+template <>
+constexpr EPrimitiveType convert_type_to_enum()
+{
+    return EPrimitiveType::PT_INT16;
+}
+
+template <>
+constexpr EPrimitiveType convert_type_to_enum()
+{
+    return EPrimitiveType::PT_INT32;
+}
+
+template <>
+constexpr EPrimitiveType convert_type_to_enum()
+{
+    return EPrimitiveType::PT_INT64;
+}
+
+template <>
+constexpr EPrimitiveType convert_type_to_enum()
+{
+    return EPrimitiveType::PT_FLOAT32;
+}
+
+template <>
+constexpr EPrimitiveType convert_type_to_enum()
+{
+    return EPrimitiveType::PT_FLOAT64;
+}
+
+template <>
+constexpr EPrimitiveType convert_type_to_enum()
+{
+    return EPrimitiveType::PT_FLOATMAX;
+}
+
+template
+class ARFF_typed_tests : public ::testing::Test
+{
+};
+
+SG_TYPED_TEST_CASE(
+    ARFF_typed_tests,
+    Types);
+
+TYPED_TEST(ARFF_typed_tests, Parse_numeric)
+{
+    
auto type = convert_type_to_enum(); // primitive type tag handed to the parser below
+
+    std::string test = "@relation test_numeric \n"
+                       "%\n"
+                       "% \n"
+                       "@attribute VAR1 numeric \n"
+                       "@attribute VAR2 real \n"
+                       "% \n"
+                       "% \n"
+                       "@data \n"
+                       "50, 5.1 \n"
+                       "45, 4.13 "; // last row has no trailing newline
+    auto ss = std::make_shared(test);
+    auto s = std::shared_ptr(ss);
+
+    auto parser = std::make_unique(s, type);
+    parser->read();
+    auto result = parser->get_features();
+    ASSERT_EQ(result.size(), 2); // one feature per @attribute
+
+    auto col1 =
+        result[0]->template as>();
+    auto col2 =
+        result[1]->template as>();
+
+    SGVector solution1{50, 45};
+    SGVector solution2{static_cast(5.1),
+                       static_cast(4.13)};
+
+    ASSERT_EQ(col1->get_feature_matrix()[0], solution1[0]);
+    ASSERT_EQ(col1->get_feature_matrix()[1], solution1[1]);
+
+    EXPECT_NEAR(
+        col2->get_feature_matrix()[0], solution2[0], get_epsilon());
+    EXPECT_NEAR(
+        col2->get_feature_matrix()[1], solution2[1], get_epsilon());
+    ASSERT_EQ(parser->get_relation(), "test_numeric"); // padding after the relation name is not part of it
+}
+
+TEST(ARFFFileTest, Parse_datetime) // date attribute with an explicit format string and UTC offsets
+{
+    std::string test = "@relation test_date \n"
+                       "% \n"
+                       "% \n"
+                       "@attribute PERIOD_DATE date \"yyyy-MM-dd Z\" \n"
+                       "@attribute VAR1 numeric \n"
+                       "% \n"
+                       "% \n"
+                       "@data \n"
+                       "\"2019-01-10 +0000\", 50 \n"
+                       "\"2019-02-10 -0100\", 26 \n" // -0100 offset: expected value is one hour past midnight UTC
+                       "\"2019-03-10 +0000\", 34 \n"
+                       "\"2019-04-10 +0000\", 41 \n"
+                       "\"2019-05-10 +0000\", 44 \n"
+                       "\"2019-06-10 +0000\", 45 ";
+
+    auto ss = std::make_shared(test);
+    auto s = std::shared_ptr(ss);
+
+    SGVector solution1{1547078400, 1549760400, 1552176000, // expected dates as Unix epoch seconds (UTC)
+                       1554854400, 1557446400, 1560124800};
+
+    SGVector solution2{50, 26, 34, 41, 44, 45};
+
+    auto parser = std::make_unique(s);
+    parser->read();
+    auto result = parser->get_features();
+    ASSERT_EQ(result.size(), 2);
+
+    auto col1 = result[0]->as>();
+    auto mat1 = col1->get_feature_matrix();
+    auto col2 = result[1]->as>();
+    auto mat2 = col2->get_feature_matrix();
+    ASSERT_EQ(mat1.size(), 6); // six data rows
+    for (int i = 0; i < 6; ++i)
+    {
+        ASSERT_EQ(mat1[i], solution1[i]);
+        ASSERT_EQ(mat2[i], solution2[i]);
+    }
+    
ASSERT_EQ(parser->get_relation(), "test_date");
+}
+
+TEST(ARFFFileTest, Parse_string) // string attribute with quoted and unquoted values
+{
+    std::string test = "@relation test_string \n"
+                       "@attribute VAR1 string \n"
+                       "@attribute VAR2 numeric \n"
+                       "@data \n"
+                       "\"test1\", 50 \n"
+                       "\"test2\", 26 \n"
+                       "\"test3\", 34 \n"
+                       "test1, 41 \n"
+                       "test2, 44 \n"
+                       "test3, 45 ";
+
+    auto ss = std::make_shared(test);
+    auto s = std::shared_ptr(ss);
+
+    std::vector solution1{"test1", "test2", "test3", // quoted and unquoted rows must yield the same text
+                          "test1", "test2", "test3"};
+    SGVector solution2{50, 26, 34, 41, 44, 45};
+
+    auto parser = std::make_unique(s);
+    parser->read();
+    auto result = parser->get_features();
+    ASSERT_EQ(result.size(), 2);
+    auto col1 = result[0]->as>();
+    auto col2 = result[1]->as>();
+    auto mat2 = col2->get_feature_matrix();
+    ASSERT_EQ(col1->get_num_vectors(), 6);
+    for (int i = 0; i < col1->get_num_vectors(); ++i)
+    {
+        auto row = col1->get_feature_vector(i);
+        for (auto j = 0; j < col1->get_max_vector_length(); ++j) // every expected string has the same length
+            ASSERT_EQ(row[j], solution1[i][j]);
+        ASSERT_EQ(mat2[i], solution2[i]);
+    }
+    ASSERT_EQ(parser->get_relation(), "test_string");
+}
+
+TEST(ARFFFileTest, Parse_nominal) // nominal attribute with mixed quoting, padding and non-ASCII values
+{
+    std::string test =
+        "@relation test_nominal \n"
+        "% \n"
+        "% \n"
+        "@attribute \"Twist n\' Shout\" {\"a\", b, \"c 1\", \'¯\\_(ツ)_/¯\'} \n"
+        "@attribute VAR2 numeric \n"
+        "% \n"
+        "% \n"
+        "@data \n"
+        " \'a\', 50 \n"
+        "b, 26 \n"
+        "\"b\" , 34 \n"
+        " \'c 1\' , 41 \n"
+        "% the row below can be replaced if it causes issues...\n"
+        "\"¯\\_(ツ)_/¯\", 44 \n"
+        "a, 45 ";
+
+    auto ss = std::make_shared(test);
+    auto s = std::shared_ptr(ss);
+
+    SGVector solution1{0, 1, 1, 2, 3, 0}; // index of each row's value within the declared nominal list
+    SGVector solution2{50, 26, 34, 41, 44, 45};
+    std::vector nom_values_result{"a", "b", "c 1", "¯\\_(ツ)_/¯"}; // declared values without their quotes
+
+    auto parser = std::make_unique(s);
+    parser->read();
+    auto result = parser->get_features();
+    ASSERT_EQ(result.size(), 2);
+
+    auto col1 = result[0]->as>();
+    auto mat1 = col1->get_feature_matrix();
+    auto col2 = result[1]->as>();
+    auto mat2 =
col2->get_feature_matrix(); + ASSERT_EQ(mat1.size(), 6); + for (int i = 0; i < 6; ++i) + { + ASSERT_EQ(mat1[i], solution1[i]); + ASSERT_EQ(mat2[i], solution2[i]); + } + auto nom_values = parser->get_nominal_values("Twist n\' Shout"); + ASSERT_EQ(nom_values.size(), nom_values_result.size()); + for (int i = 0; i < nom_values.size(); ++i) + { + ASSERT_EQ(nom_values[i], nom_values_result[i]); + } + + ASSERT_EQ(parser->get_relation(), "test_nominal"); + ASSERT_EQ(parser->get_feature_names().size(), 2); + ASSERT_EQ(parser->get_feature_names()[0], "Twist n\' Shout"); + ASSERT_EQ(parser->get_feature_names()[1], "VAR2"); +} \ No newline at end of file