Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[WIP] add SVM regressor and kernel coreml exporters
- Loading branch information
Showing
15 changed files
with
789 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,44 +1,46 @@ | ||
MergeCFLAGS() | ||
include(ExternalProject) | ||
IF (NOT TARGET GoogleMock) | ||
MergeCFLAGS() | ||
include(ExternalProject) | ||
|
||
IF (MSVC) | ||
SET (CUSTOM_CMAKE_ARGS -DCMAKE_ARCHIVE_OUTPUT_DIRECTORY:PATH=${THIRD_PARTY_DIR}/libs/gmock | ||
-DCMAKE_LIBRARY_OUTPUT_DIRECTORY:PATH=${THIRD_PARTY_DIR}/libs/gmock | ||
-DCMAKE_CXX_FLAGS:STRING=${CMAKE_CXX_FLAGS}${CMAKE_DEFINITIONS} | ||
-DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE} | ||
-DCMAKE_CXX_FLAGS_DISTRIBUTION:STRING=${CMAKE_CXX_FLAGS_DISTRIBUTION} | ||
-DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG} | ||
-DCMAKE_C_COMPILER:STRING=${CMAKE_C_COMPILER} | ||
-DCMAKE_CXX_COMPILER:STRING=${CMAKE_CXX_COMPILER} | ||
) | ||
ELSE () | ||
SET(MERGED_CXX_FLAGS "${MERGED_CXX_FLAGS} -fPIC") | ||
SET (CUSTOM_CMAKE_ARGS -DCMAKE_ARCHIVE_OUTPUT_DIRECTORY:PATH=${THIRD_PARTY_DIR}/libs/gmock | ||
-DCMAKE_LIBRARY_OUTPUT_DIRECTORY:PATH=${THIRD_PARTY_DIR}/libs/gmock | ||
-DCMAKE_CXX_FLAGS:STRING=${MERGED_CXX_FLAGS}${CMAKE_DEFINITIONS} | ||
-DCMAKE_C_COMPILER:STRING=${CMAKE_C_COMPILER} | ||
-DCMAKE_CXX_COMPILER:STRING=${CMAKE_CXX_COMPILER} | ||
) | ||
ENDIF() | ||
IF (MSVC) | ||
SET (CUSTOM_CMAKE_ARGS -DCMAKE_ARCHIVE_OUTPUT_DIRECTORY:PATH=${THIRD_PARTY_DIR}/libs/gmock | ||
-DCMAKE_LIBRARY_OUTPUT_DIRECTORY:PATH=${THIRD_PARTY_DIR}/libs/gmock | ||
-DCMAKE_CXX_FLAGS:STRING=${CMAKE_CXX_FLAGS}${CMAKE_DEFINITIONS} | ||
-DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE} | ||
-DCMAKE_CXX_FLAGS_DISTRIBUTION:STRING=${CMAKE_CXX_FLAGS_DISTRIBUTION} | ||
-DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG} | ||
-DCMAKE_C_COMPILER:STRING=${CMAKE_C_COMPILER} | ||
-DCMAKE_CXX_COMPILER:STRING=${CMAKE_CXX_COMPILER} | ||
) | ||
ELSE () | ||
SET(MERGED_CXX_FLAGS "${MERGED_CXX_FLAGS} -fPIC") | ||
SET (CUSTOM_CMAKE_ARGS -DCMAKE_ARCHIVE_OUTPUT_DIRECTORY:PATH=${THIRD_PARTY_DIR}/libs/gmock | ||
-DCMAKE_LIBRARY_OUTPUT_DIRECTORY:PATH=${THIRD_PARTY_DIR}/libs/gmock | ||
-DCMAKE_CXX_FLAGS:STRING=${MERGED_CXX_FLAGS}${CMAKE_DEFINITIONS} | ||
-DCMAKE_C_COMPILER:STRING=${CMAKE_C_COMPILER} | ||
-DCMAKE_CXX_COMPILER:STRING=${CMAKE_CXX_COMPILER} | ||
) | ||
ENDIF() | ||
|
||
IF(EXISTS /usr/src/googletest) | ||
ExternalProject_Add( | ||
GoogleMock | ||
DOWNLOAD_COMMAND "" | ||
SOURCE_DIR /usr/src/googletest | ||
PREFIX ${CMAKE_BINARY_DIR}/GoogleMock | ||
INSTALL_COMMAND "" | ||
CMAKE_ARGS ${CUSTOM_CMAKE_ARGS} | ||
) | ||
ELSE() | ||
ExternalProject_Add( | ||
GoogleMock | ||
URL https://github.com/google/googletest/archive/release-1.8.1.tar.gz | ||
URL_MD5 2e6fbeb6a91310a16efe181886c59596 | ||
TIMEOUT 10 | ||
PREFIX ${CMAKE_BINARY_DIR}/GoogleMock | ||
DOWNLOAD_DIR ${THIRD_PARTY_DIR}/GoogleMock | ||
INSTALL_COMMAND "" | ||
CMAKE_ARGS ${CUSTOM_CMAKE_ARGS} | ||
) | ||
IF(EXISTS /usr/src/googletest) | ||
ExternalProject_Add( | ||
GoogleMock | ||
DOWNLOAD_COMMAND "" | ||
SOURCE_DIR /usr/src/googletest | ||
PREFIX ${CMAKE_BINARY_DIR}/GoogleMock | ||
INSTALL_COMMAND "" | ||
CMAKE_ARGS ${CUSTOM_CMAKE_ARGS} | ||
) | ||
ELSE() | ||
ExternalProject_Add( | ||
GoogleMock | ||
URL https://github.com/google/googletest/archive/release-1.8.1.tar.gz | ||
URL_MD5 2e6fbeb6a91310a16efe181886c59596 | ||
TIMEOUT 10 | ||
PREFIX ${CMAKE_BINARY_DIR}/GoogleMock | ||
DOWNLOAD_DIR ${THIRD_PARTY_DIR}/GoogleMock | ||
INSTALL_COMMAND "" | ||
CMAKE_ARGS ${CUSTOM_CMAKE_ARGS} | ||
) | ||
ENDIF() | ||
ENDIF() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
#include "CoreMLConverter.h" | ||
#include "SVMConverter.h" | ||
|
||
#include <shogun/lib/exception/NotFittedException.h> | ||
|
||
|
||
using namespace shogun; | ||
using namespace shogun::coreml; | ||
|
||
std::shared_ptr<CoreMLModel> shogun::coreml::convert(const CMachine* m) | ||
{ | ||
/* | ||
auto converter_registry = ConverterFactory::instance(); | ||
std::string machine_name(m->get_name()); | ||
if (!m->is_trained()) | ||
throw NotFittedException("The supplied machine is not trained!"); | ||
auto model = std::make_shared<CoreMLModel>(); | ||
auto spec = model->get_specification(); | ||
//(*converter_registry)(machine_name)(m, spec); | ||
return model; | ||
*/ | ||
return nullptr; | ||
} | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
#ifndef __COREML_CONVERTER_H__ | ||
#define __COREML_CONVERTER_H__ | ||
|
||
#include <functional> | ||
#include <unordered_map> | ||
#include <unordered_set> | ||
|
||
#include <shogun/machine/Machine.h> | ||
|
||
#include "CoreMLModel.h" | ||
|
||
namespace shogun | ||
{ | ||
namespace coreml | ||
{ | ||
static constexpr int32_t SPECIFICATION_VERSION = 1; | ||
|
||
struct visitor; | ||
|
||
struct ICoreMLConverter { | ||
virtual void convert(visitor& v, const CSGObject* o) const = 0; | ||
}; | ||
|
||
template <class, class> | ||
struct CoreMLConverter; | ||
|
||
struct visitor { | ||
template<typename T, typename I> | ||
void visit(T*, I*); | ||
}; | ||
|
||
template <class I, class O> | ||
struct CoreMLConverter: ICoreMLConverter { | ||
typedef I input_type; | ||
typedef O output_type; | ||
|
||
std::shared_ptr<CoreML::Specification::Model> m_spec; | ||
CoreMLConverter(std::shared_ptr<CoreML::Specification::Model> spec): m_spec(spec) {} | ||
|
||
void convert(visitor& c, const CSGObject* o) const override | ||
{ | ||
auto x = o->as<const I>(); | ||
c.visit(this, x); | ||
} | ||
|
||
static O* convert(const I* o); | ||
static const std::unordered_set<std::string> supported_types; | ||
}; | ||
/* | ||
class ConverterFactory | ||
{ | ||
typedef std::function<std::shared_ptr<ICoreMLConverter>(std::shared_ptr<CoreML::Specification::Model>)> ConverterFactoryFunction; | ||
public: | ||
bool register_converter(const std::string& machine_name, ConverterFunction f) | ||
{ | ||
return m_registry.emplace(std::make_pair(machine_name, f)).second; | ||
} | ||
ConverterFunction& operator()(const std::string& m) | ||
{ | ||
auto f = m_registry.find(m); | ||
if (f == m_registry.end()) | ||
throw std::runtime_error("The provided machine cannot be converted to CoreML format!"); | ||
return f->second; | ||
} | ||
static ConverterFactory* instance() | ||
{ | ||
static ConverterFactory* f = new ConverterFactory(); | ||
return f; | ||
} | ||
private: | ||
std::unordered_map<std::string, ConverterFactoryFunction> m_registry; | ||
}; | ||
*/ | ||
|
||
std::shared_ptr<CoreMLModel> convert(const CMachine* m); | ||
|
||
#define REGISTER_COREML_CONVERTER(factory, classname, machines, function) \ | ||
static int register_converter##classname = []() { \ | ||
for (auto m = machines.cbegin(); m != machines.cend(); ++m) \ | ||
factory->register_converter(*m, function); \ | ||
return factory->size(); \ | ||
}(); | ||
|
||
#define REGISTER_CONVERTER(classname, machines, function) \ | ||
REGISTER_COREML_CONVERTER(ConverterFactory::instance(), classname, machines, function) | ||
} | ||
} | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#include "CoreMLModel.h" | ||
|
||
#include <fstream> | ||
|
||
#include <google/protobuf/io/zero_copy_stream_impl.h> | ||
|
||
#include "Model.pb.h" | ||
|
||
using namespace shogun::coreml; | ||
|
||
CoreMLModel::CoreMLModel(): | ||
m_spec(std::make_shared<CoreML::Specification::Model>()) | ||
{ | ||
} | ||
|
||
CoreMLModel::~CoreMLModel() | ||
{ | ||
m_spec.reset(); | ||
} | ||
|
||
void CoreMLModel::save(const std::string& filename) const | ||
{ | ||
std::fstream out(filename, std::ios::binary | std::ios::out); | ||
this->save(out); | ||
out.close(); | ||
} | ||
|
||
void CoreMLModel::save(std::ostream& out) const | ||
{ | ||
::google::protobuf::io::OstreamOutputStream pb_out(&out); | ||
if (!m_spec->SerializeToZeroCopyStream(&pb_out)) | ||
throw std::runtime_error("could not save"); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
#ifndef __COREML_MODEL_H__ | ||
#define __COREML_MODEL_H__ | ||
|
||
#include <memory> | ||
#include <ostream> | ||
#include <string> | ||
|
||
namespace CoreML | ||
{ | ||
namespace Specification | ||
{ | ||
class Model; | ||
} | ||
} | ||
|
||
namespace shogun | ||
{ | ||
namespace coreml | ||
{ | ||
class CoreMLModel | ||
{ | ||
public: | ||
CoreMLModel(); | ||
~CoreMLModel(); | ||
|
||
void save(const std::string& filename) const; | ||
void save(std::ostream& out) const; | ||
|
||
std::shared_ptr<CoreML::Specification::Model> get_specification() const | ||
{ | ||
return m_spec; | ||
} | ||
|
||
private: | ||
std::shared_ptr<CoreML::Specification::Model> m_spec; | ||
}; | ||
} | ||
} | ||
#endif |
Oops, something went wrong.