In [1]:
#pragma cling add_include_path("../../libtorch/include")
#pragma cling add_include_path("../../libtorch/include/torch/csrc/api/include")
#pragma cling add_library_path("../../libtorch/lib")
#pragma cling load("libtorch")

In [2]:
#include <iostream>
#include <tuple>
#include <string>
#include <vector>
#include <memory>
#include <functional>
#include <type_traits>
#include <torch/torch.h>
#include <torch/script.h>
namespace nn = torch::nn;

In [3]:
std::string join_name(const std::string& name_prefix, const std::string& name) {
  size_t total_size = name.size();
  if (!name_prefix.empty()) {
    total_size += name_prefix.size() + 1;
  }
  std::string full_name;
  full_name.reserve(total_size);
  if (!name_prefix.empty()) {
    full_name += name_prefix;
    full_name.push_back('.');
  }
  full_name += name;
  return full_name;
};

In [4]:
class Module{
    public:
        using NamedModuleApplyFunction = std::function<void(const std::string&, Module&)>;    
        using ConstNamedModuleApplyFunction = std::function<void(const std::string&, const Module&)>;
        using NamedModulePointerApplyFunction = std::function<void(const std::string&, const std::shared_ptr<Module>&)>;

         /// Whether the module is in training mode.
        bool is_training{true};
    
    public:
    torch::OrderedDict<std::string, torch::Tensor> parameters_;
    torch::OrderedDict<std::string, std::shared_ptr<Module>> children_;
    
    public:
    torch::Tensor& register_parameter(std::string name, torch::Tensor tensor, bool requires_grad=true){
          if (!tensor.defined()) {
            if (requires_grad) {
                std::cout << "An undefined tensor cannot require grad. ";
            }
          } else {
            tensor.set_requires_grad(requires_grad);
          }
          return parameters_.insert(std::move(name), std::move(tensor));
    }
    
    template <typename ModuleType>
    std::shared_ptr<ModuleType> register_module(std::string name, std::shared_ptr<ModuleType> module){
        auto& base_module = children_.insert(std::move(name), std::move(module)); //std::move
        return std::dynamic_pointer_cast<ModuleType>(base_module);
    }
    
    // -------------------------------------------------------
    torch::OrderedDict<std::string, torch::Tensor> named_parameters(bool recurse = true) const{
        
      torch::OrderedDict<std::string, torch::Tensor> result;
      if (!recurse) {
        for (const auto& parameter : parameters_) {
              if (parameter.value().defined()) {
                result.insert(parameter.key(), parameter.value());
              }
        }
      } else {
            apply([&result](const std::string& name, const Module& module) {
                for (const auto& parameter : module.named_parameters(/*recurse=*/false)) {
                    result.insert(join_name(name, parameter.key()), parameter.value());
                }
            }); 
      }
      return result;
    }
    
    torch::OrderedDict<std::string, std::shared_ptr<Module>> named_modules(const std::string& name_prefix = std::string(), bool include_self = true) const {
      torch::OrderedDict<std::string, std::shared_ptr<Module>> result;
      if (include_self) {
         apply([&result](const std::string& key, const std::shared_ptr<Module>& module) {
              result.insert(key, module);
            }, name_prefix);
      } else {
        apply_to_submodules([&result](const std::string& key, const std::shared_ptr<Module>& module) {
              result.insert(key, module);
            }, name_prefix);
      }
      return result;
    }
        
    void apply_to_submodules(const NamedModulePointerApplyFunction& function, const std::string& name_prefix = std::string()) const {
       for (const auto& child : children_) {
            auto qualified_name = child.key();
            function(qualified_name, child.value());
            child.value()->apply_to_submodules(function, qualified_name);
        }    
    }
    
    ///   module->apply([](const std::string& key, nn::Module& module) {
    ///   std::cout << key << ": " << module.name() << std::endl;
    void apply(const NamedModuleApplyFunction& function, const std::string& name_prefix = std::string()){
        function(/*name=*/name_prefix, *this);
        apply_to_submodules([&function](const std::string& name, const std::shared_ptr<Module>& module) {
                                        function(name, *module);
                                        },
                            name_prefix);
    }
    
    ///   module->apply([](const std::string& key, const nn::Module& module) {
    ///   std::cout << key << ": " << module.name() << std::endl;
    void apply(const ConstNamedModuleApplyFunction& function, const std::string& name_prefix = std::string()) const{
          function(/*name=*/name_prefix, *this);
          apply_to_submodules([&function](const std::string& name, const std::shared_ptr<Module>& module) {
                function(name, *module);
            }, name_prefix);
    }
    
    ///   module->apply([](const std::string& key,
    ///                    const std::shared_ptr<nn::Module>& module) {
    ///     std::cout << key << ": " << module->name() << std::endl;
    ///   });
    void apply(const NamedModulePointerApplyFunction& function,const std::string& name_prefix = std::string()) const{
      function(/*name=*/name_prefix, shared_from_this_checked());
      apply_to_submodules(function, name_prefix);

    }
    
    std::shared_ptr<Module> shared_from_this_checked() const {
      std::shared_ptr<const Module> ptr{this};
      return std::const_pointer_cast<Module>(ptr);
    }
    
    virtual torch::Tensor forward(const torch::Tensor& input) = 0;
    
}

In [5]:
class LinearImpl : public Module{
    public:
    torch::Tensor weight;
    
    LinearImpl(int64_t in_features, int64_t out_features, bool is_training_=true){
        weight = register_parameter("weight", torch::empty({in_features, out_features}));
        
        is_training = is_training_;
      }
    torch::Tensor forward(const torch::Tensor& input) {
          return input.mm(weight);
    }
}

In [6]:
LinearImpl linear{3,4};

In [7]:
torch::OrderedDict<std::string, torch::Tensor> ordered_parameter_dict = linear.named_parameters();
for (const auto& pair : ordered_parameter_dict) {
  std::cout << pair.key() << ": " << pair.value() << std::endl;
}

weight:  1.7034e+25  8.5265e+08  1.3219e+19  2.0152e+34
 1.0314e-08  1.8389e+25  7.7128e+31  1.1259e+24
 6.6991e+31  4.2915e+24  7.0813e+31  7.3961e+31
[ CPUFloatType{3,4} ]


In [8]:
torch::Tensor inputs = torch::randn({2,3});

In [9]:
std::cout << inputs << std::endl;

 0.7484  0.1145  0.3855
 0.0872 -0.0677 -0.8913
[ CPUFloatType{2,3} ]


In [10]:
torch::Tensor output = linear.forward(inputs);

In [11]:
std::cout << output << std::endl;

 2.5824e+31  3.7600e+24  3.6129e+31  1.5110e+34
-5.9711e+31 -5.0694e+24 -6.8336e+31  1.6920e+33
[ CPUFloatType{2,4} ]


In [12]:
std::make_shared<LinearImpl>(3, 5);

In [13]:
std::shared_ptr<Module> linear1_instance = std::make_shared<LinearImpl>(3, 5);
torch::OrderedDict<std::string, torch::Tensor> ordered_parameter_dict = linear1_instance->named_parameters();
for (const auto& pair : ordered_parameter_dict) {
  std::cout << pair.key() << ": " << pair.value() << std::endl;
}

weight: 1e-31 *
 0.0000  0.0000  0.0000  0.0000  4.1516
  0.0000  3.4675  0.0000  0.0000  0.0000
  0.0000  0.0000  2.9759  0.0000  1.6749
[ CPUFloatType{3,5} ]


In [14]:
std::shared_ptr<LinearImpl> children_instance = std::dynamic_pointer_cast<LinearImpl>(linear1_instance);
torch::OrderedDict<std::string, torch::Tensor> ordered_parameter_dict = children_instance->named_parameters();
for (const auto& pair : ordered_parameter_dict) {
  std::cout << pair.key() << ": " << pair.value() << std::endl;
}

weight: 1e-31 *
 0.0000  0.0000  0.0000  0.0000  4.1516
  0.0000  3.4675  0.0000  0.0000  0.0000
  0.0000  0.0000  2.9759  0.0000  1.6749
[ CPUFloatType{3,5} ]


In [15]:
class CustomModule : public Module{
    public:
    std::shared_ptr<LinearImpl> linear1{nullptr};
    std::shared_ptr<LinearImpl> linear2{nullptr};
    
    CustomModule(int64_t in_features, int64_t out_features){
        std::shared_ptr<LinearImpl> linear1_instance = std::make_shared<LinearImpl>(in_features, 5);
        std::shared_ptr<LinearImpl> linear2_instance = std::make_shared<LinearImpl>(5, out_features, false);
        
        linear1 = register_module<LinearImpl>("linear1", linear1_instance);
        linear2 = register_module<LinearImpl>("linear2", linear2_instance);
    }
    
    torch::Tensor forward(const torch::Tensor& input) {
        torch::Tensor output1,output2;
        output1 =  linear1->forward(input);
        output2 = linear2 ->forward(output1);
        return output2;
    }
}

In [16]:
CustomModule custom_module{3,5};

In [17]:
torch::OrderedDict<std::string, torch::Tensor> ordered_parameter_dict = custom_module.named_parameters();
for (const auto& pair : ordered_parameter_dict) {
  std::cout << pair.key() << ": " << pair.value() << std::endl;
}

linear1.weight:  1.8389e+25  7.7128e+31  1.1259e+24  6.6991e+31  4.2915e+24
 7.0813e+31  7.3961e+31  3.6366e+03  3.7148e+21  8.4727e+11
 1.3959e+31  7.1538e+22  7.6177e+31  3.0019e-09  4.1205e+21
[ CPUFloatType{3,5} ]
linear2.weight:  4.7429e+30  2.2755e-07  4.6226e+30  8.8603e+11  1.6084e+19
 3.6157e-09  8.2177e+20  4.3064e+21  4.8418e+30  9.4794e+05
 3.5554e-09  2.6792e+20  4.8403e+30  1.9364e+31  8.4777e+11
 3.6366e+03  3.7963e+03  2.8221e+03  2.8381e+03  2.8522e+03
 1.5226e+19  2.2854e+05  1.6750e+19  6.7422e+22  1.3354e+19
[ CPUFloatType{5,5} ]


In [18]:
torch::OrderedDict<std::string, std::shared_ptr<Module>> named_modules = custom_module.named_modules();

In [19]:
for (const auto& pair : named_modules) {
  std::cout << pair.key() << ": " << pair.value()->is_training << std::endl; //
}

: 1
linear1: 1
linear2: 0


In [21]:
void print_func(const std::string& key, const Module& module) {
    std::cout << key << ": " << module.is_training << std::endl;
};

custom_module.apply(print_func);

: 1
linear1: 1
linear2: 0


CustomModule custom{3,5};

torch::OrderedDict<std::string, torch::Tensor> ordered_parameter_dict = custom.named_parameters();
for (const auto& pair : ordered_parameter_dict) {
  std::cout << pair.key() << ": " << pair.value() << std::endl;
}