Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added tensor constructors and get_raw_data to work with pointers of p… #202

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/cppflow/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ namespace cppflow {
FROZEN_GRAPH,
};

model() = default;
explicit model(const std::string& filename, const TYPE type=TYPE::SAVED_MODEL);

std::vector<std::string> get_operations() const;
Expand Down
62 changes: 57 additions & 5 deletions include/cppflow/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,16 @@ namespace cppflow {
template<typename T>
tensor(const std::vector<T>& values, const std::vector<int64_t>& shape);

/**
* Creates a flat tensor with the given values, and specified length and shape
* @tparam T A type that can be convertible into a tensor
* @param values The values to be converted
* @param len The length of the converted tensor
* @param shape The shape of the converted tensor
*/
template<typename T>
tensor(T *values, size_t len, const std::vector<int64_t>& shape);

/**
* Creates a flat tensor with the given values
* @tparam T A type that can be convertible into a tensor
Expand Down Expand Up @@ -69,6 +79,23 @@ namespace cppflow {
*/
datatype dtype() const;

/**
* Converts the tensor into a pointer of primitive type T
* @tparam T The c++ type (must be equivalent to the tensor type)
* @return A pointer of type T representing the flat tensor
*/
template<typename T>
T *get_raw_data() const;

/**
* Converts the tensor into a pointer of primitive type T
* @tparam T The c++ type (must be equivalent to the tensor type)
* @return A pointer of type T representing the flat tensor
* @return The size of the array
*/
template<typename T>
T *get_raw_data(size_t &size) const;

/**
* Converts the tensor into a C++ vector
* @tparam T The c++ type (must be equivalent to the tensor type)
Expand All @@ -77,7 +104,6 @@ namespace cppflow {
template<typename T>
std::vector<T> get_data() const;


~tensor() = default;
tensor(const tensor &tensor) = default;
tensor(tensor &&tensor) = default;
Expand Down Expand Up @@ -140,6 +166,11 @@ namespace cppflow {
tensor::tensor(const std::vector<T>& values, const std::vector<int64_t>& shape) :
tensor(deduce_tf_type<T>(), values.data(), values.size() * sizeof(T), shape) {}


template<typename T>
tensor::tensor(T *values, size_t len, const std::vector<int64_t>& shape) :
tensor(deduce_tf_type<T>(), values, len * sizeof(T), shape) {}

template<typename T>
tensor::tensor(const std::initializer_list<T>& values) :
tensor(std::vector<T>(values), {(int64_t) values.size()}) {}
Expand Down Expand Up @@ -213,28 +244,49 @@ namespace cppflow {
return res;
}


template<typename T>
std::vector<T> tensor::get_data() const {
T *tensor::get_raw_data(size_t &size) const {

// Check if asked datatype and tensor datatype match
if (this->dtype() != deduce_tf_type<T>()) {
auto type1 = cppflow::to_string(deduce_tf_type<T>());
auto type2 = cppflow::to_string(this->dtype());
auto error = "Datatype in function get_data (" + type1 + ") does not match tensor datatype (" + type2 + ")";
auto error = "Datatype in function get_raw_data (" + type1 + ") does not match tensor datatype (" + type2 + ")";
throw std::runtime_error(error);
}


auto res_tensor = get_tensor();

// Check tensor data is not empty
auto raw_data = TF_TensorData(res_tensor.get());
//this->error_check(raw_data != nullptr, "Tensor data is empty");

size_t size = TF_TensorByteSize(res_tensor.get()) / TF_DataTypeSize(TF_TensorType(res_tensor.get()));
// Get size of array
size = TF_TensorByteSize(res_tensor.get()) / TF_DataTypeSize(TF_TensorType(res_tensor.get()));

// Convert to correct type
const auto T_data = static_cast<T*>(raw_data);

return T_data;
}

template<typename T>
T *tensor::get_raw_data() const {

// Get the raw data and return
size_t size = 0;
const auto T_data = this->get_raw_data<T>(size);

return T_data;
}

template<typename T>
std::vector<T> tensor::get_data() const {

// Get the raw data and size of array
size_t size = 0;
const auto T_data = this->get_raw_data<T>(size);
std::vector<T> r(T_data, T_data + size);

return r;
Expand Down