diff --git a/CMakeLists.txt b/CMakeLists.txt index df77482c870..6c47eb03d45 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,14 +3,22 @@ project(torchvision) set(CMAKE_CXX_STANDARD 11) find_package(Torch REQUIRED) +find_package(OpenCV REQUIRED COMPONENTS core imgproc imgcodecs) + +file(GLOB HEADERS torchvision/csrc/vision.h torchvision/csrc/general.h) -file(GLOB HEADERS torchvision/csrc/vision.h) file(GLOB MODELS_HEADERS torchvision/csrc/models/*.h) file(GLOB MODELS_SOURCES torchvision/csrc/models/*.h torchvision/csrc/models/*.cpp) -add_library (${PROJECT_NAME} SHARED ${MODELS_SOURCES}) +file(GLOB DATASETS_HEADERS torchvision/csrc/datasets/*.h) +file(GLOB DATASETS_SOURCES torchvision/csrc/datasets/*.h torchvision/csrc/datasets/*.cpp) + +add_library (${PROJECT_NAME} SHARED ${MODELS_SOURCES} ${DATASETS_SOURCES}) +target_link_libraries(${PROJECT_NAME} PRIVATE "${OpenCV_LIBS}") target_link_libraries(${PROJECT_NAME} PUBLIC "${TORCH_LIBRARIES}") install(TARGETS ${PROJECT_NAME} DESTINATION ${CMAKE_INSTALL_PREFIX}/lib) install(FILES ${HEADERS} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/${PROJECT_NAME}) install(FILES ${MODELS_HEADERS} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/${PROJECT_NAME}/models) +#install(FILES ${DATASETS_HEADERS} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/${PROJECT_NAME}/datasets) + diff --git a/torchvision/csrc/datasets/datasetsimpl.cpp b/torchvision/csrc/datasets/datasetsimpl.cpp new file mode 100644 index 00000000000..795b5e2dd25 --- /dev/null +++ b/torchvision/csrc/datasets/datasetsimpl.cpp @@ -0,0 +1,90 @@ +#include "datasetsimpl.h" + +#include +#include +#include +#include +#include + +namespace vision { +namespace datasets { +namespace datasetsimpl { +std::vector lsdir(const std::string& path) { + std::vector list; + auto dp = opendir(path.c_str()); + + if (dp != nullptr) { + auto ep = readdir(dp); + + while (ep != nullptr) { + std::string name = ep->d_name; + if (name != "." && name != "..") + list.emplace_back(std::move(name)); + } + + closedir(dp); + } + + return list; +} + +std::string tolower(std::string str) { + std::transform(str.begin(), str.end(), str.begin(), ::tolower); + return str; +} + +inline bool comp(const std::string& A, const std::string& B) { + return tolower(A) < tolower(B); +}; + +void sort_names(std::vector& data) { + std::sort(data.begin(), data.end(), comp); +} + +bool isdir(const std::string& path) { + struct stat st; + if (stat(path.c_str(), &st) == 0) + return st.st_mode & S_IFDIR; + return false; +} + +bool isfile(const std::string& path) { + struct stat st; + if (stat(path.c_str(), &st) == 0) + return st.st_mode & S_IFREG; + return false; +} + +bool exists(const std::string& path) { + struct stat st; + return stat(path.c_str(), &st) == 0; +} + +torch::Tensor read_image(const std::string& path) { + auto mat = cv::imread(path); + TORCH_CHECK(!mat.empty(), "Failed to read image \"", path, "\"."); + + cv::cvtColor(mat, mat, cv::COLOR_BGR2RGB); + std::vector tensors; + std::vector channels(size_t(mat.channels())); + cv::split(mat, channels); + + for (auto& channel : channels) + tensors.push_back( + torch::from_blob(channel.ptr(), {mat.rows, mat.cols}, torch::kUInt8)); + + auto output = torch::cat(tensors) + .view({mat.channels(), mat.rows, mat.cols}) + .to(torch::kFloat); + return output / 255; +} + +std::string absolute_path(const std::string& path) { + char rpath[PATH_MAX]; + realpath(path.c_str(), rpath); + return std::string(rpath); +} + +} // namespace datasetsimpl +} // namespace datasets +} // namespace vision diff --git a/torchvision/csrc/datasets/datasetsimpl.h b/torchvision/csrc/datasets/datasetsimpl.h new file mode 100644 index 00000000000..9b446a81574 --- /dev/null +++ b/torchvision/csrc/datasets/datasetsimpl.h @@ -0,0 +1,42 @@ +#ifndef DATASETSIMPL_H +#define DATASETSIMPL_H + +#include + +#ifndef TORCH_CHECK +#define TORCH_CHECK AT_CHECK +#endif + +namespace vision { +namespace datasets { +namespace datasetsimpl { + +std::vector lsdir(const std::string& path); + +std::string tolower(std::string str); + +void sort_names(std::vector& data); + +bool isdir(const std::string& path); + +bool isfile(const std::string& path); + +bool exists(const std::string& path); + +std::string absolute_path(const std::string& path); + +inline std::string join(const std::string& str) { + return str; +} +template +inline std::string join(const std::string& head, Tail&&... tail) { + return head + "/" + join(std::forward(tail)...); +} + +torch::Tensor read_image(const std::string& path); + +} // namespace datasetsimpl +} // namespace datasets +} // namespace vision + +#endif // DATASETSIMPL_H diff --git a/torchvision/csrc/models/general.h b/torchvision/csrc/general.h similarity index 100% rename from torchvision/csrc/models/general.h rename to torchvision/csrc/general.h diff --git a/torchvision/csrc/models/alexnet.h b/torchvision/csrc/models/alexnet.h index 813531fb383..4c164da64d0 100644 --- a/torchvision/csrc/models/alexnet.h +++ b/torchvision/csrc/models/alexnet.h @@ -2,7 +2,7 @@ #define ALEXNET_H #include -#include "general.h" +#include "../general.h" namespace vision { namespace models { diff --git a/torchvision/csrc/models/densenet.h b/torchvision/csrc/models/densenet.h index 3ed6eba0837..f758b70ab1e 100644 --- a/torchvision/csrc/models/densenet.h +++ b/torchvision/csrc/models/densenet.h @@ -2,7 +2,7 @@ #define DENSENET_H #include -#include "general.h" +#include "../general.h" namespace vision { namespace models { diff --git a/torchvision/csrc/models/googlenet.h b/torchvision/csrc/models/googlenet.h index 94390fd5070..4bf334939ac 100644 --- a/torchvision/csrc/models/googlenet.h +++ b/torchvision/csrc/models/googlenet.h @@ -2,7 +2,7 @@ #define GOOGLENET_H #include -#include "general.h" +#include "../general.h" namespace vision { namespace models { diff --git a/torchvision/csrc/models/inception.h b/torchvision/csrc/models/inception.h index d4edcbadd47..c01bfb5c0b8 100644 --- a/torchvision/csrc/models/inception.h +++ b/torchvision/csrc/models/inception.h @@ -2,7 +2,7 @@ #define INCEPTION_H #include -#include "general.h" +#include "../general.h" namespace vision { namespace models { diff --git a/torchvision/csrc/models/mnasnet.h b/torchvision/csrc/models/mnasnet.h index e499a7df987..86ae177f4a7 100644 --- a/torchvision/csrc/models/mnasnet.h +++ b/torchvision/csrc/models/mnasnet.h @@ -2,7 +2,7 @@ #define MNASNET_H #include -#include "general.h" +#include "../general.h" namespace vision { namespace models { diff --git a/torchvision/csrc/models/mobilenet.h b/torchvision/csrc/models/mobilenet.h index e69f840a4c9..7bbf63925a0 100644 --- a/torchvision/csrc/models/mobilenet.h +++ b/torchvision/csrc/models/mobilenet.h @@ -2,7 +2,7 @@ #define MOBILENET_H #include -#include "general.h" +#include "../general.h" namespace vision { namespace models { diff --git a/torchvision/csrc/models/resnet.h b/torchvision/csrc/models/resnet.h index ae9f4613ebe..774fe52a206 100644 --- a/torchvision/csrc/models/resnet.h +++ b/torchvision/csrc/models/resnet.h @@ -2,7 +2,7 @@ #define RESNET_H #include -#include "general.h" +#include "../general.h" namespace vision { namespace models { diff --git a/torchvision/csrc/models/shufflenetv2.h b/torchvision/csrc/models/shufflenetv2.h index 0e357f2ce7d..000e4a2b37a 100644 --- a/torchvision/csrc/models/shufflenetv2.h +++ b/torchvision/csrc/models/shufflenetv2.h @@ -2,7 +2,7 @@ #define SHUFFLENETV2_H #include -#include "general.h" +#include "../general.h" namespace vision { namespace models { diff --git a/torchvision/csrc/models/squeezenet.h b/torchvision/csrc/models/squeezenet.h index 298f1f04095..8b74e056f43 100644 --- a/torchvision/csrc/models/squeezenet.h +++ b/torchvision/csrc/models/squeezenet.h @@ -2,7 +2,7 @@ #define SQUEEZENET_H #include -#include "general.h" +#include "../general.h" namespace vision { namespace models { diff --git a/torchvision/csrc/models/vgg.h b/torchvision/csrc/models/vgg.h index cc9f98aea77..e48c1d078b3 100644 --- a/torchvision/csrc/models/vgg.h +++ b/torchvision/csrc/models/vgg.h @@ -2,7 +2,7 @@ #define VGG_H #include -#include "general.h" +#include "../general.h" namespace vision { namespace models {