diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 36e2e8f5f276..fe23acefd681 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -1175,72 +1175,6 @@ class String : public ObjectRef { */ inline String& operator=(const char* other); - /*! - * \brief Compare is less than other std::string - * - * \param other The other string - * - * \return the comparison result - */ - bool operator<(const std::string& other) const { return this->compare(other) < 0; } - bool operator<(const String& other) const { return this->compare(other) < 0; } - bool operator<(const char* other) const { return this->compare(other) < 0; } - - /*! - * \brief Compare is greater than other std::string - * - * \param other The other string - * - * \return the comparison result - */ - bool operator>(const std::string& other) const { return this->compare(other) > 0; } - bool operator>(const String& other) const { return this->compare(other) > 0; } - bool operator>(const char* other) const { return this->compare(other) > 0; } - - /*! - * \brief Compare is less than or equal to other std::string - * - * \param other The other string - * - * \return the comparison result - */ - bool operator<=(const std::string& other) const { return this->compare(other) <= 0; } - bool operator<=(const String& other) const { return this->compare(other) <= 0; } - bool operator<=(const char* other) const { return this->compare(other) <= 0; } - - /*! - * \brief Compare is greater than or equal to other std::string - * - * \param other The other string - * - * \return the comparison result - */ - bool operator>=(const std::string& other) const { return this->compare(other) >= 0; } - bool operator>=(const String& other) const { return this->compare(other) >= 0; } - bool operator>=(const char* other) const { return this->compare(other) >= 0; } - - /*! - * \brief Compare is equal to other std::string - * - * \param other The other string - * - * \return the comparison result - */ - bool operator==(const std::string& other) const { return this->compare(other) == 0; } - bool operator==(const String& other) const { return this->compare(other) == 0; } - bool operator==(const char* other) const { return compare(other) == 0; } - - /*! - * \brief Compare is not equal to other std::string - * - * \param other The other string - * - * \return the comparison result - */ - bool operator!=(const std::string& other) const { return this->compare(other) != 0; } - bool operator!=(const String& other) const { return this->compare(other) != 0; } - bool operator!=(const char* other) const { return this->compare(other) != 0; } - /*! * \brief Compares this String object to other * @@ -1372,6 +1306,29 @@ class String : public ObjectRef { */ static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count); + /*! + * \brief Concatenate two char sequences + * + * \param lhs Pointers to the lhs char array + * \param lhs_size The size of the lhs char array + * \param rhs Pointers to the rhs char array + * \param rhs_size The size of the rhs char array + * + * \return The concatenated char sequence + */ + static String Concat(const char* lhs, size_t lhs_size, const char* rhs, size_t rhs_size) { + std::string ret(lhs, lhs_size); + ret.append(rhs, rhs_size); + return String(ret); + } + + // Overload + operator + friend String operator+(const String& lhs, const String& rhs); + friend String operator+(const String& lhs, const std::string& rhs); + friend String operator+(const std::string& lhs, const String& rhs); + friend String operator+(const String& lhs, const char* rhs); + friend String operator+(const char* lhs, const String& rhs); + friend struct tvm::ObjectEqual; }; @@ -1410,10 +1367,102 @@ inline String& String::operator=(std::string other) { inline String& String::operator=(const char* other) { return operator=(std::string(other)); } -inline String operator+(const std::string lhs, const String& rhs) { - return lhs + rhs.operator std::string(); +inline String operator+(const String& lhs, const String& rhs) { + size_t lhs_size = lhs.size(); + size_t rhs_size = rhs.size(); + return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); } +inline String operator+(const String& lhs, const std::string& rhs) { + size_t lhs_size = lhs.size(); + size_t rhs_size = rhs.size(); + return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); +} + +inline String operator+(const std::string& lhs, const String& rhs) { + size_t lhs_size = lhs.size(); + size_t rhs_size = rhs.size(); + return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); +} + +inline String operator+(const char* lhs, const String& rhs) { + size_t lhs_size = std::strlen(lhs); + size_t rhs_size = rhs.size(); + return String::Concat(lhs, lhs_size, rhs.data(), rhs_size); +} + +inline String operator+(const String& lhs, const char* rhs) { + size_t lhs_size = lhs.size(); + size_t rhs_size = std::strlen(rhs); + return String::Concat(lhs.data(), lhs_size, rhs, rhs_size); +} + +// Overload < operator +inline bool operator<(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) < 0; } + +inline bool operator<(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) > 0; } + +inline bool operator<(const String& lhs, const String& rhs) { return lhs.compare(rhs) < 0; } + +inline bool operator<(const String& lhs, const char* rhs) { return lhs.compare(rhs) < 0; } + +inline bool operator<(const char* lhs, const String& rhs) { return rhs.compare(lhs) > 0; } + +// Overload > operator +inline bool operator>(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) > 0; } + +inline bool operator>(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) < 0; } + +inline bool operator>(const String& lhs, const String& rhs) { return lhs.compare(rhs) > 0; } + +inline bool operator>(const String& lhs, const char* rhs) { return lhs.compare(rhs) > 0; } + +inline bool operator>(const char* lhs, const String& rhs) { return rhs.compare(lhs) < 0; } + +// Overload <= operator +inline bool operator<=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) <= 0; } + +inline bool operator<=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) >= 0; } + +inline bool operator<=(const String& lhs, const String& rhs) { return lhs.compare(rhs) <= 0; } + +inline bool operator<=(const String& lhs, const char* rhs) { return lhs.compare(rhs) <= 0; } + +inline bool operator<=(const char* lhs, const String& rhs) { return rhs.compare(lhs) >= 0; } + +// Overload >= operator +inline bool operator>=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) >= 0; } + +inline bool operator>=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) <= 0; } + +inline bool operator>=(const String& lhs, const String& rhs) { return lhs.compare(rhs) >= 0; } + +inline bool operator>=(const String& lhs, const char* rhs) { return lhs.compare(rhs) >= 0; } + +inline bool operator>=(const char* lhs, const String& rhs) { return rhs.compare(rhs) <= 0; } + +// Overload == operator +inline bool operator==(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) == 0; } + +inline bool operator==(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) == 0; } + +inline bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; } + +inline bool operator==(const String& lhs, const char* rhs) { return lhs.compare(rhs) == 0; } + +inline bool operator==(const char* lhs, const String& rhs) { return rhs.compare(lhs) == 0; } + +// Overload != operator +inline bool operator!=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) != 0; } + +inline bool operator!=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) != 0; } + +inline bool operator!=(const String& lhs, const String& rhs) { return lhs.compare(rhs) != 0; } + +inline bool operator!=(const String& lhs, const char* rhs) { return lhs.compare(rhs) != 0; } + +inline bool operator!=(const char* lhs, const String& rhs) { return rhs.compare(lhs) != 0; } + inline std::ostream& operator<<(std::ostream& out, const String& input) { out.write(input.data(), input.size()); return out; diff --git a/src/ir/module.cc b/src/ir/module.cc index c7393749dc37..0d6eeb130ab0 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -365,7 +365,7 @@ void IRModuleNode::ImportFromStd(const String& path) { auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path"); CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path."; std::string std_path = (*f)(); - this->Import(std_path + "/" + path.operator std::string()); + this->Import(std_path + "/" + path); } std::unordered_set IRModuleNode::Imports() const { return this->import_set_; } diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index a09e24b12429..bf40f4bdb672 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -456,9 +456,7 @@ Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) { return PrintFunc(Doc::Text("fn "), GetRef(op)); } -Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { - return Doc::Text('@' + op->name_hint.operator std::string()); -} +Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { return Doc::Text("@" + op->name_hint); } Doc RelayTextPrinter::VisitExpr_(const OpNode* op) { return Doc::Text(op->name); } diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 94efff5180fb..9390feada456 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -83,7 +83,7 @@ Var Var::copy_with_suffix(const String& suffix) const { } else { new_ptr = make_object(*node); } - new_ptr->name_hint = new_ptr->name_hint.operator std::string() + suffix.operator std::string(); + new_ptr->name_hint = new_ptr->name_hint + suffix; return Var(new_ptr); } diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index efd6ac7e406f..eca65ee7af21 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -458,23 +458,54 @@ TEST(String, compare) { // compare with string CHECK_EQ(str_source.compare(source), 0); + CHECK(str_source == source); + CHECK(source == str_source); + CHECK(str_source <= source); + CHECK(source <= str_source); + CHECK(str_source >= source); + CHECK(source >= str_source); CHECK_LT(str_source.compare(mismatch1), 0); + CHECK(str_source < mismatch1); + CHECK(mismatch1 != str_source); CHECK_GT(str_source.compare(mismatch2), 0); + CHECK(str_source > mismatch2); + CHECK(mismatch2 < str_source); CHECK_GT(str_source.compare(mismatch3), 0); + CHECK(str_source > mismatch3); CHECK_LT(str_source.compare(mismatch4), 0); + CHECK(str_source < mismatch4); + CHECK(mismatch4 > str_source); // compare with char* CHECK_EQ(str_source.compare(source.data()), 0); + CHECK(str_source == source.data()); + CHECK(source.data() == str_source); + CHECK(str_source <= source.data()); + CHECK(source <= str_source.data()); + CHECK(str_source >= source.data()); + CHECK(source >= str_source.data()); CHECK_LT(str_source.compare(mismatch1.data()), 0); + CHECK(str_source < mismatch1.data()); + CHECK(str_source != mismatch1.data()); + CHECK(mismatch1.data() != str_source); CHECK_GT(str_source.compare(mismatch2.data()), 0); + CHECK(str_source > mismatch2.data()); + CHECK(mismatch2.data() < str_source); CHECK_GT(str_source.compare(mismatch3.data()), 0); + CHECK(str_source > mismatch3.data()); CHECK_LT(str_source.compare(mismatch4.data()), 0); + CHECK(str_source < mismatch4.data()); + CHECK(mismatch4.data() > str_source); // compare with String CHECK_LT(str_source.compare(str_mismatch1), 0); + CHECK(str_source < str_mismatch1); CHECK_GT(str_source.compare(str_mismatch2), 0); + CHECK(str_source > str_mismatch2); CHECK_GT(str_source.compare(str_mismatch3), 0); + CHECK(str_source > str_mismatch3); CHECK_LT(str_source.compare(str_mismatch4), 0); + CHECK(str_source < str_mismatch4); } TEST(String, c_str) { @@ -513,6 +544,23 @@ TEST(String, Cast) { String s2 = Downcast(r); } +TEST(String, Concat) { + String s1("hello"); + String s2("world"); + std::string s3("world"); + String res1 = s1 + s2; + String res2 = s1 + s3; + String res3 = s3 + s1; + String res4 = s1 + "world"; + String res5 = "world" + s1; + + CHECK_EQ(res1.compare("helloworld"), 0); + CHECK_EQ(res2.compare("helloworld"), 0); + CHECK_EQ(res3.compare("worldhello"), 0); + CHECK_EQ(res4.compare("helloworld"), 0); + CHECK_EQ(res5.compare("worldhello"), 0); +} + TEST(Optional, Composition) { Optional opt0(nullptr); Optional opt1 = String("xyz");