In [1]:
#pragma cling add_include_path("./libtorch/include")
#pragma cling add_include_path("./libtorch/include/torch/csrc/api/include")
#pragma cling add_library_path("./libtorch/lib")
#pragma cling load("libtorch")

In [2]:
#include <iostream>
#include <tuple>
#include <string>
#include <vector>
#include <memory>
#include <type_traits>
#include <torch/torch.h>

# 1 The relationship between nn::Module, nn::ModuleHolder and nn::AnyModule

## 1.1 nn::Module

~~~
class TORCH_API Module : public std::enable_shared_from_this<Module> {
 public:
  /// Returns the parameters of this `Module` and if `recurse` is true, also
  /// recursively of every submodule.
  std::vector<Tensor> parameters(bool recurse = true) const;

  /// Returns an `OrderedDict` with the parameters of this `Module` along with
  /// their keys, and if `recurse` is true also recursively of every submodule.
  OrderedDict<std::string, Tensor> named_parameters(bool recurse = true) const;
  
  /// Registers a submodule with this `Module`.
  ///
  /// Registering a module makes it available to methods such as `modules()`,
  /// `clone()` or `to()`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule::MyModule() {
  ///     submodule_ = register_module("linear", torch::nn::Linear(3, 4));
  ///   }
  /// \endrst
  template <typename ModuleType>
  std::shared_ptr<ModuleType> register_module(
      std::string name,
      std::shared_ptr<ModuleType> module);

  /// Registers a submodule with this `Module`.
  ///
  /// This method deals with `ModuleHolder`s.
  ///
  /// Registering a module makes it available to methods such as `modules()`,
  /// `clone()` or `to()`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule::MyModule() {
  ///     submodule_ = register_module("linear", torch::nn::Linear(3, 4));
  ///   }
  /// \endrst
  template <typename ModuleType>
  std::shared_ptr<ModuleType> register_module(
      std::string name,
      ModuleHolder<ModuleType> module_holder);
~~~

In [3]:
class NetImpl : public torch::nn::Module {
    public:
    NetImpl(int64_t N, int64_t M) {
        W = register_parameter("W", torch::randn({N, M}));
        b = register_parameter("b", torch::randn(M));
      }
      torch::Tensor forward(torch::Tensor input) {
            return torch::addmm(b, input, W);
      }
      torch::Tensor W, b;
};

In [4]:
// A bare NetImpl with N = 2, M = 3, used without any holder or type erasure.
NetImpl model_impl(2, 3);

In [5]:
torch::Tensor m = torch::randn({4,2});
torch::Tensor out = model_impl.forward(m);
std::cout << out << std::endl;

-1.5087  1.5692  0.2568
 0.5782  0.2423 -1.9050
-0.7325  1.1485 -0.5545
-4.7999  2.3417  3.7989
[ CPUFloatType{4,3} ]


## 1.2 nn::ModuleHolder

~~~
// https://github.com/pytorch/pytorch/blob/main/torch/csrc/api/include/torch/nn/pimpl.h

template <typename Contained>
class ModuleHolder : torch::detail::ModuleHolderIndicator {
 protected:
  /// The module pointer this class wraps.
  /// NOTE: Must be placed at the top of the class so that we can use it with
  /// trailing return types below.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::shared_ptr<Contained> impl_;

 public:
  using ContainedType = Contained;

  /// Default constructs the contained module if if has a default constructor,
  /// else produces a static error.
  ///
  /// NOTE: This uses the behavior of template
  /// classes in C++ that constructors (or any methods) are only compiled when
  /// actually used.
  ModuleHolder() : impl_(default_construct()) {
    static_assert(
        std::is_default_constructible<Contained>::value,
        "You are trying to default construct a module which has "
        "no default constructor. Use = nullptr to give it the empty state "
        "(e.g. `Linear linear = nullptr;` instead of `Linear linear;`).");
  }

  /// Constructs the `ModuleHolder` with an empty contained value. Access to
  /// the underlying module is not permitted and will throw an exception, until
  /// a value is assigned.
  /* implicit */ ModuleHolder(std::nullptr_t) : impl_(nullptr) {}
  
  /// Calls the `forward()` method of the contained module.
  template <typename... Args>
  auto operator()(Args&&... args)
      -> torch::detail::return_type_of_forward_t<Contained, Args...> {
    // This will not compile if the module does not have a `forward()` method
    // (as expected).
    // NOTE: `std::forward` is qualified to prevent VS2017 emitting
    // error C2872: 'std': ambiguous symbol
    return impl_->forward(::std::forward<Args>(args)...);
  }
~~~

~~~
#define TORCH_MODULE_IMPL(Name, ImplType)                              \
  class Name : public torch::nn::ModuleHolder<ImplType> { /* NOLINT */ \
   public:                                                             \
    using torch::nn::ModuleHolder<ImplType>::ModuleHolder;             \
    using Impl TORCH_UNUSED_EXCEPT_CUDA = ImplType;                    \
  }

/// Like `TORCH_MODULE_IMPL`, but defaults the `ImplType` name to `<Name>Impl`.
#define TORCH_MODULE(Name) TORCH_MODULE_IMPL(Name, Name##Impl)
~~~

In [6]:
// Expands (via TORCH_MODULE_IMPL, quoted above) to
// `class Net : public torch::nn::ModuleHolder<NetImpl>`, which inherits the
// ModuleHolder constructors and keeps the NetImpl behind a shared_ptr.
// Calling a Net with operator() forwards the arguments to NetImpl::forward.
TORCH_MODULE(Net);

## 1.3 nn::AnyModule

~~~
//https://github.com/pytorch/pytorch/blob/main/torch/csrc/api/include/torch/nn/modules/container/any.h

///   torch::nn::AnyModule module(torch::nn::Linear(3, 4));
///   std::shared_ptr<nn::Module> ptr = module.ptr();
///   torch::nn::Linear linear(module.get<torch::nn::Linear>());
/// \endrst
class AnyModule {
 public:
 
 /// The type erased module.
  std::unique_ptr<AnyModulePlaceholder> content_;
  
  
  /// A default-constructed `AnyModule` is in an empty state.
  AnyModule() = default;

  /// Constructs an `AnyModule` from a `shared_ptr` to concrete module object.
  template <typename ModuleType>
  explicit AnyModule(std::shared_ptr<ModuleType> module);

  /// Constructs an `AnyModule` from a concrete module object.
  template <
      typename ModuleType,
      typename = torch::detail::enable_if_module_t<ModuleType>>
  explicit AnyModule(ModuleType&& module);

  /// Constructs an `AnyModule` from a module holder.
  template <typename ModuleType>
  explicit AnyModule(const ModuleHolder<ModuleType>& module_holder);

  // ... (remaining members omitted) ...
};

// Out-of-line member definitions (excerpt):
template <typename... ArgumentTypes>
AnyValue AnyModule::any_forward(ArgumentTypes&&... arguments) {
  TORCH_CHECK(!is_empty(), "Cannot call forward() on an empty AnyModule");
  std::vector<AnyValue> values;
  values.reserve(sizeof...(ArgumentTypes));
  torch::apply(
      [&values](AnyValue&& value) { values.push_back(std::move(value)); },
      AnyValue(std::forward<ArgumentTypes>(arguments))...);
  return content_->forward(std::move(values));
}

template <typename ReturnType, typename... ArgumentTypes>
ReturnType AnyModule::forward(ArgumentTypes&&... arguments) {
  return any_forward(std::forward<ArgumentTypes>(arguments)...)
      .template get<ReturnType>();
}
~~~

### a. Construct from a ModuleHolder

In [7]:
// Type-erase a Linear holder (2 inputs -> 3 outputs) into an AnyModule.
torch::nn::AnyModule module{torch::nn::Linear(2, 3)};

In [8]:
torch::Tensor x = torch::randn({4,2});

In [9]:
// Run the wrapped Linear through AnyModule's type-erased forward().
torch::Tensor output{module.forward(x)};
std::cout << output << std::endl;

-0.1699  0.9128  0.3406
 1.3906 -0.3705  0.4861
-0.1416  1.3699  0.2101
 0.9539 -0.7656  0.6544
[ CPUFloatType{4,3} ]


### b. Construct from a std::shared_ptr pointing to an nn::Module

In [10]:
// A LinearImpl (2 -> 3) owned directly by a shared_ptr, with no ModuleHolder.
auto module_ptr = std::make_shared<torch::nn::LinearImpl>(2, 3);

In [11]:
// Type-erase the shared_ptr-owned module; AnyModule shares ownership of it.
torch::nn::AnyModule module_b{module_ptr};

In [12]:
torch::Tensor x2 = torch::randn({4,2});
torch::Tensor output2 = module_b.forward(x2);
std::cout << output2 << std::endl;

 0.7532 -0.1333  0.0107
 0.7393 -0.1115  0.3020
 1.0847 -0.4567 -0.0311
 0.9058 -0.2797  0.0773
[ CPUFloatType{4,3} ]


# 2 Using nn::Module polymorphically through nn::AnyModule

In [13]:
torch::Tensor model_impl_process(torch::nn::AnyModule model){
    torch::Tensor x = torch::randn({4,2});
    torch::Tensor output = model.forward(x);
    std::cout << output << std::endl;
    return output;
}

In [14]:
//torch::Tensor a = model_impl_process(model_impl);
torch::Tensor a = model_impl_process(module);

 0.7436  0.6978  0.2771
 0.1955  1.3015  0.1836
 0.6660 -0.1387  0.5194
 0.9377 -1.0845  0.7449
[ CPUFloatType{4,3} ]


In [15]:
torch::Tensor b = model_impl_process(module_b);

 0.2503  0.3578  0.0910
-0.0873  0.6288 -1.9189
-0.5216  1.0904 -0.5350
 1.2251 -0.5861  0.2206
[ CPUFloatType{4,3} ]


In [23]:
torch::nn::AnyModule net_any_module(Net{2,3});
torch::Tensor c = model_impl_process(net_any_module);

-1.9709 -1.7128  2.0667
 1.1812 -2.0922  0.4920
-1.5993 -0.4468  1.3714
 2.4313 -1.3531 -0.4786
[ CPUFloatType{4,3} ]
