Skip to content

Commit

Permalink
Add PyTorchPredictorContainer (#15899)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #15899

Add PyTorchPredictorContainer to support multiple jit script modules

Reviewed By: pritamdamania87

Differential Revision: D13596139

fbshipit-source-id: 3ce0bdf2f4dbba7aa1d20e824d03e5ac98f5d887
  • Loading branch information
houseroad authored and facebook-github-bot committed Jan 15, 2019
1 parent 1065e7c commit b329e03
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 25 deletions.
4 changes: 2 additions & 2 deletions caffe2/serialize/file_adapter.h
Expand Up @@ -3,14 +3,14 @@
#include <fstream>
#include <memory>

#include <c10/macros/Macros.h>
#include "c10/macros/Macros.h"
#include "caffe2/serialize/istream_adapter.h"
#include "caffe2/serialize/read_adapter_interface.h"

namespace caffe2 {
namespace serialize {

class FileAdapter final : public ReadAdapterInterface {
class CAFFE2_API FileAdapter final : public ReadAdapterInterface {
public:
C10_DISABLE_COPY_AND_ASSIGN(FileAdapter);
explicit FileAdapter(const std::string& file_name);
Expand Down
5 changes: 2 additions & 3 deletions caffe2/serialize/istream_adapter.h
Expand Up @@ -2,15 +2,14 @@

#include <istream>

#include <c10/macros/Macros.h>

#include "c10/macros/Macros.h"
#include "caffe2/serialize/read_adapter_interface.h"

namespace caffe2 {
namespace serialize {

// this is a reader implemented by std::istream
class IStreamAdapter final : public ReadAdapterInterface {
class CAFFE2_API IStreamAdapter final : public ReadAdapterInterface {
public:
C10_DISABLE_COPY_AND_ASSIGN(IStreamAdapter);
explicit IStreamAdapter(std::istream* istream);
Expand Down
4 changes: 3 additions & 1 deletion caffe2/serialize/read_adapter_interface.h
Expand Up @@ -3,13 +3,15 @@
#include <cstddef>
#include <cstdint>

#include "c10/macros/Macros.h"

namespace caffe2 {
namespace serialize {

// this is the interface for the (file/stream/memory) reader in
// PyTorchStreamReader. with this interface, we can extend the support
// besides standard istream
class ReadAdapterInterface {
class CAFFE2_API ReadAdapterInterface {
public:
virtual size_t size() const = 0;
virtual size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
Expand Down
60 changes: 41 additions & 19 deletions torch/csrc/jit/import.cpp
Expand Up @@ -8,10 +8,13 @@
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/utils/functional.h>

#include <caffe2/core/types.h>
#include <caffe2/proto/caffe2_pb.h>
#include <caffe2/proto/torch_pb.h>
#include <caffe2/serialize/inline_container.h>
#include "caffe2/core/common.h"
#include "caffe2/core/types.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/proto/torch_pb.h"
#include "caffe2/serialize/file_adapter.h"
#include "caffe2/serialize/inline_container.h"
#include "caffe2/serialize/istream_adapter.h"

#include <ATen/ATen.h>

Expand All @@ -23,6 +26,10 @@
namespace torch {
namespace jit {

using caffe2::serialize::ReadAdapterInterface;
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::FileAdapter;

namespace {

// this is a deserializer class which loads script modules from pt files. the
Expand All @@ -34,9 +41,8 @@ namespace {
class ScriptModuleDeserializer final {
public:
ScriptModuleDeserializer(const std::string& filename);

ScriptModuleDeserializer(std::istream* is);

explicit ScriptModuleDeserializer(std::unique_ptr<ReadAdapterInterface> rai);
void deserialize(
ModuleLookup module_lookup,
c10::optional<at::Device> device);
Expand Down Expand Up @@ -68,6 +74,9 @@ ScriptModuleDeserializer::ScriptModuleDeserializer(const std::string& filename)
ScriptModuleDeserializer::ScriptModuleDeserializer(std::istream* is)
: reader_(is) {}

ScriptModuleDeserializer::ScriptModuleDeserializer(std::unique_ptr<ReadAdapterInterface> rai)
: reader_(std::move(rai)) {}

void ScriptModuleDeserializer::deserialize(
ModuleLookup module_lookup,
c10::optional<at::Device> device) {
Expand Down Expand Up @@ -229,9 +238,34 @@ void import_ir_module(
deserializer.deserialize(module_lookup, device);
}

void import_ir_module(
ModuleLookup module_lookup,
std::unique_ptr<ReadAdapterInterface> rai,
c10::optional<at::Device> device) {
ScriptModuleDeserializer deserializer(std::move(rai));
deserializer.deserialize(module_lookup, device);
}

std::shared_ptr<script::Module> load(
std::istream& in,
c10::optional<at::Device> device) {
std::unique_ptr<IStreamAdapter> rai =
caffe2::make_unique<IStreamAdapter>(&in);
auto module = load(std::move(rai), device);
return module;
}

std::shared_ptr<script::Module> load(
const std::string& filename,
c10::optional<at::Device> device) {
std::unique_ptr<FileAdapter> rai = caffe2::make_unique<FileAdapter>(filename);
auto module = load(std::move(rai), device);
return module;
}

std::shared_ptr<script::Module> load(
std::unique_ptr<ReadAdapterInterface> rai,
c10::optional<c10::Device> device) {
auto module = std::make_shared<script::Module>();

auto module_lookup = [&](const std::vector<std::string>& qualified_name) {
Expand All @@ -245,23 +279,11 @@ std::shared_ptr<script::Module> load(
return curr;
};

ScriptModuleDeserializer deserializer(&in);
ScriptModuleDeserializer deserializer(std::move(rai));
deserializer.deserialize(module_lookup, device);

return module;
}

std::shared_ptr<script::Module> load(
const std::string& filename,
c10::optional<at::Device> device) {
std::ifstream in(filename, std::ios_base::binary);

AT_CHECK(!in.fail(), "load: could not open file ", filename);

auto module = load(in, device);

return module;
}

} // namespace jit
} // namespace torch
21 changes: 21 additions & 0 deletions torch/csrc/jit/import.h
Expand Up @@ -5,6 +5,12 @@

#include <istream>

namespace caffe2 {
namespace serialize {
class ReadAdapterInterface;
} // namespace serialize
} // namespace caffe2

namespace torch {
namespace jit {

Expand All @@ -21,6 +27,11 @@ TORCH_API void import_ir_module(
std::istream& in,
c10::optional<c10::Device> device = c10::nullopt);

TORCH_API void import_ir_module(
ModuleLookup module_lookup,
std::unique_ptr<caffe2::serialize::ReadAdapterInterface> rai,
c10::optional<c10::Device> device = c10::nullopt);

/// Loads a serialized `script::Module` from the given `istream`.
///
/// The istream must contain a serialized `script::Module`, exported via
Expand All @@ -38,5 +49,15 @@ TORCH_API std::shared_ptr<script::Module> load(
const std::string& filename,
c10::optional<c10::Device> device = c10::nullopt);

/// Loads a serialized `script::Module` from the given `rai`.
///
/// The reader adapter, which is for customized input stream, must contain a
/// serialized `script::Module`, exported either via `ScriptModule.save()` in
/// Python or `torch::jit::ExportModule` in C++.
TORCH_API std::shared_ptr<script::Module> load(
std::unique_ptr<caffe2::serialize::ReadAdapterInterface> rai,
c10::optional<c10::Device> device = c10::nullopt);


} // namespace jit
} // namespace torch

0 comments on commit b329e03

Please sign in to comment.