# How to use an optimizer
To use torch.optim you have to construct an optimizer object, that will hold the current state and will update the parameters based on the computed gradients.

In [1]:
#pragma cling add_include_path("../../libtorch/include")
#pragma cling add_include_path("../../libtorch/include/torch/csrc/api/include")
#pragma cling add_library_path("../../libtorch/lib")
#pragma cling load("libtorch")

In [2]:
#include <iostream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <memory>
#include <utility>
#include <torch/torch.h>
#include <c10/util/flat_hash_map.h>
namespace nn = torch::nn;

note: the grad of a parameter can be overwritten

In [3]:
torch::Tensor w = torch::tensor({1.0, 2.0}, torch::requires_grad());
std::cout << w << std::endl;
std::cout << w.requires_grad() << std::endl;

 1
 2
[ CPUFloatType{2} ]
1


In [4]:
w.sum().backward();

In [5]:
std::cout << w.grad();

 1
 1
[ CPUFloatType{2} ]

In [6]:
auto option = torch::optim::SGDOptions(0.1);

In [7]:
torch::optim::SGD oprimizer({w}, option);

In [8]:
oprimizer.step();

In [9]:
std::cout << w << std::endl;

 0.9000
 1.9000
[ CPUFloatType{2} ]


## understand SGD

In [10]:
class CustomSGDOptimizer{
    public:
    torch::Tensor param;
    float lr;
    
    CustomSGDOptimizer(torch::Tensor param, float lr):param(param), lr(lr){
    }
    
    void step(){
        {
            torch::NoGradGuard no_grad;
            param.data().add_(param.grad(), -1 * lr);
        }
    }
}

In [11]:
torch::Tensor custom_w = torch::tensor({1.0, 2.0}, torch::requires_grad());
std::cout << custom_w << std::endl;
std::cout << custom_w.requires_grad() << std::endl;

 1
 2
[ CPUFloatType{2} ]
1


In [12]:
custom_w.sum().backward();

In [13]:
std::cout << custom_w.grad();

 1
 1
[ CPUFloatType{2} ]

In [14]:
CustomSGDOptimizer custom_optimizer(custom_w, 0.1);

In [15]:
custom_optimizer.step();

In [16]:
std::cout << custom_w << std::endl;

 0.9000
 1.9000
[ CPUFloatType{2} ]


# understand Optimizer source code
https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/include/torch/optim/optimizer.h

### 1 OptimizerParamState

In [17]:
class OptimizerParamState {
 public:
  virtual std::unique_ptr<OptimizerParamState> clone() const;
  virtual void serialize(torch::serialize::InputArchive& archive);
  virtual void serialize(torch::serialize::OutputArchive& archive) const;
  virtual ~OptimizerParamState() = default;
};

In [19]:
std::unique_ptr<OptimizerParamState> OptimizerParamState::clone() const {
      throw std::runtime_error("clone() has not been implemented for torch::optim::OptimizerParamState. ");
};

In [20]:
void OptimizerParamState::serialize(torch::serialize::InputArchive& archive) {
  throw std::runtime_error("void serialize(torch::serialize::InputArchive& archive) has not been implemented for torch::optim::OptimizerParamState. ");
};

In [21]:
void OptimizerParamState::serialize(torch::serialize::OutputArchive& archive) const {
  throw std::runtime_error("void serialize(torch::serialize::OutputArchive& archive) has not been implemented for torch::optim::OptimizerParamState. ");
};

In [None]:
template <typename Derived>
class OptimizerCloneableParamState : public OptimizerParamState {
  std::unique_ptr<OptimizerParamState> clone() const override {
    return std::make_unique<Derived>(static_cast<const Derived&>(*this));
  }
};

### 2 OptimizerOptions

In [22]:
class OptimizerOptions {
 public:
  virtual std::unique_ptr<OptimizerOptions> clone() const;
  virtual void serialize(torch::serialize::InputArchive& archive);
  virtual void serialize(torch::serialize::OutputArchive& archive) const;
  virtual ~OptimizerOptions() = default;
  virtual double get_lr() const;
  virtual void set_lr(const double lr);
};

In [23]:
template <typename Derived>
class OptimizerCloneableOptions : public OptimizerOptions {
 private:
  std::unique_ptr<OptimizerOptions> clone() const override {
    return std::make_unique<Derived>(static_cast<const Derived&>(*this));
  }
};

In [24]:
double OptimizerOptions::get_lr() const {
  throw std::runtime_error("double get_lr() has not been overidden and implemented in subclass of torch::optim::OptimizerOptions, you must override it in your subclass.");
};

In [25]:
void OptimizerOptions::set_lr(const double lr) {
  throw std::runtime_error("double set_lr() has not been overidden and implemented in subclass of torch::optim::OptimizerOptions, you must override it in your subclass.");
};

In [26]:
std::unique_ptr<OptimizerOptions> OptimizerOptions::clone() const {
  throw std::runtime_error("clone() has not been implemented for torch::optim::OptimizerOptions. ");
}

In [27]:
void OptimizerOptions::serialize(torch::serialize::InputArchive& archive) {
  throw std::runtime_error("void serialize(torch::serialize::InputArchive& archive) has not been implemented for torch::optim::OptimizerOptions. ");
};

In [28]:
void OptimizerOptions::serialize(torch::serialize::OutputArchive& archive) const {
throw std::runtime_error("void serialize(torch::serialize::OutputArchive& archive) has not been implemented for torch::optim::OptimizerOptions. ");
};

### 3 OptimizerParamGroup

In [29]:
class OptimizerParamGroup {
  protected:
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<torch::Tensor> params_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::unique_ptr<OptimizerOptions> options_;

 public:
  // NOTE: In order to store `OptimizerParamGroup` in a `std::vector`, it has to
  // be copy-constructible.
  OptimizerParamGroup(const OptimizerParamGroup& param_group)
      : params_(param_group.params()),
        options_(
            param_group.has_options() ? param_group.options().clone()
                                      : nullptr) {}
  OptimizerParamGroup(std::vector<torch::Tensor> params)
      : params_(std::move(params)) {}
    
  OptimizerParamGroup(
      std::vector<torch::Tensor> params,
      std::unique_ptr<OptimizerOptions> options)
      : params_(std::move(params)), options_(std::move(options)) {}

  bool has_options() const;
  OptimizerOptions& options();
  const OptimizerOptions& options() const;
  void set_options(std::unique_ptr<OptimizerOptions> options);
    
  std::vector<torch::Tensor>& params();
  const std::vector<torch::Tensor>& params() const{
      return params_;
    }
};

In [30]:
bool OptimizerParamGroup::has_options() const {
  return options_ != nullptr;
};

In [31]:
OptimizerOptions& OptimizerParamGroup::options() {
  return *options_.get();
};

In [32]:
const OptimizerOptions& OptimizerParamGroup::options() const {
  return *options_.get();
};

In [33]:
void OptimizerParamGroup::set_options(std::unique_ptr<OptimizerOptions> options) {
  options_ = std::move(options);
};

In [34]:
std::vector<torch::Tensor>& OptimizerParamGroup::params() {
  return params_;
};

In [35]:
//const std::vector<torch::Tensor>& OptimizerParamGroup::params() const {
//  return params_;
//};

### 4 Optimizer

In [36]:
class Optimizer {
  protected:
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<OptimizerParamGroup> param_groups_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>> state_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::unique_ptr<OptimizerOptions> defaults_;

    
 public:
  // The copy constructor is deleted, because the user should use the
  // `state_dict` / `load_state_dict` API to copy an optimizer instead.
  Optimizer(const Optimizer& optimizer) = delete;
  Optimizer(Optimizer&& optimizer) = default;

  explicit Optimizer(
      std::vector<OptimizerParamGroup> param_groups,
      std::unique_ptr<OptimizerOptions> defaults)
      : defaults_(std::move(defaults)) {
    for (const auto& param_group : param_groups) {
      add_param_group(param_group);
    }
  }

  /// Constructs the `Optimizer` from a vector of parameters.
  // NOLINTNEXTLINE(performance-move-const-arg)
  explicit Optimizer(
      std::vector<torch::Tensor> parameters,
      std::unique_ptr<OptimizerOptions> defaults)
      : Optimizer(
            {std::move(OptimizerParamGroup(parameters))},
            std::move(defaults)){};

  /// Adds the given param_group to the optimizer's param_group list.
  void add_param_group(const OptimizerParamGroup& param_group);
  /// Adds the given vector of parameters to the optimizer's parameter list.
  void add_parameters(const std::vector<torch::Tensor>& parameters);

    
  virtual ~Optimizer() = default;

  using LossClosure = std::function<torch::Tensor()>;
  /// A loss function closure, which is expected to return the loss value.
  virtual torch::Tensor step(LossClosure closure = nullptr) = 0;


  /// Zeros out the gradients of all parameters.
  void zero_grad();

  /// Provides a reference to the parameters in the first param_group this
  /// optimizer holds.
  std::vector<torch::Tensor>& parameters();
  const std::vector<torch::Tensor>& parameters() const;
    
  /// Returns the number of parameters referenced by the optimizer.
  size_t size() const;

  OptimizerOptions& defaults();
  const OptimizerOptions& defaults() const;

  /// Provides a reference to the param_groups this optimizer holds.
  std::vector<OptimizerParamGroup>& param_groups();
  /// Provides a const reference to the param_groups this optimizer holds.
  const std::vector<OptimizerParamGroup>& param_groups() const;
    
  /// Provides a reference to the state this optimizer holds
  ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>& state();

  /// Serializes the optimizer state into the given `archive`.
  virtual void save(torch::serialize::OutputArchive& archive) const;

  /// Deserializes the optimizer state from the given `archive`.
  virtual void load(torch::serialize::InputArchive& archive);

};

In [37]:
void Optimizer::add_param_group(const OptimizerParamGroup& param_group) {
  for (const auto& param : param_group.params()) {
    if(param.is_leaf()){
        std::cout << "can't optimize a non-leaf Tensor" << std::endl;
    }
  }
  
  OptimizerParamGroup param_group_(param_group.params());
  if (!param_group.has_options()) {
    param_group_.set_options(defaults_->clone());
  } else {
    param_group_.set_options(param_group.options().clone());
  }

  param_groups_.emplace_back(std::move(param_group_));
}

In [38]:
void Optimizer::add_parameters(const std::vector<torch::Tensor>& parameters) {
  auto& parameters_ = param_groups_[0].params();
  parameters_.insert(parameters_.end(), parameters.begin(), parameters.end());
};

In [39]:
void Optimizer::zero_grad() {
  for (auto& group : param_groups_) {
    for (auto& p : group.params()) {
      if (p.grad().defined()) {
        p.grad().detach_();
        p.grad().zero_();
      }
    }
  }
};

In [40]:
std::vector<torch::Tensor>& Optimizer::parameters() {
  return param_groups_.at(0).params();
};

In [41]:
const std::vector<torch::Tensor>& Optimizer::parameters() const{
  return param_groups_.at(0).params();
};

[1minput_line_49:2:65: [0m[0;1;31merror: [0m[1mfunction definition is not allowed here[0m
 const std::vector<torch::Tensor>& Optimizer::parameters() const{
[0;1;32m                                                                ^
[0m

Interpreter Error: 

In [42]:
size_t Optimizer::size() const {
  size_t count = 0;
  for (const auto& group : param_groups_) {
    count += group.params().size();
  }
  return count;
}

In [43]:
OptimizerOptions& Optimizer::defaults() {
  return *defaults_.get();
};

In [44]:
const OptimizerOptions& Optimizer::defaults() const {
  return *defaults_.get();
};

In [45]:
std::vector<OptimizerParamGroup>& Optimizer::param_groups() {
  return param_groups_;
};

In [46]:
const std::vector<OptimizerParamGroup>& Optimizer::param_groups() const {
  return param_groups_;
};

[1minput_line_54:2:74: [0m[0;1;31merror: [0m[1mfunction definition is not allowed here[0m
 const std::vector<OptimizerParamGroup>& Optimizer::param_groups() const {
[0;1;32m                                                                         ^
[0m

Interpreter Error: 

In [47]:
ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>& Optimizer::state(){
  return state_;
}

[1minput_line_55:2:91: [0m[0;1;31merror: [0m[1mfunction definition is not allowed here[0m
  ...std::unique_ptr<OptimizerParamState>>& Optimizer::state(){
[0;1;32m                                                              ^
[0m

Interpreter Error: 

In [48]:
const ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>& Optimizer::state() const {
  return state_;
}

[1minput_line_56:2:104: [0m[0;1;31merror: [0m[1mfunction definition is not allowed here[0m
  ...std::unique_ptr<OptimizerParamState>>& Optimizer::state() const {
[0;1;32m                                                                     ^
[0m

Interpreter Error: 

In [49]:
void Optimizer::save(torch::serialize::OutputArchive& archive) const {};

In [50]:
void Optimizer::load(torch::serialize::InputArchive& archive) {};

### 5 SGD

In [None]:
#define TORCH_ARG(T, name)                                              \
 public:                                                                \
  inline auto name(const T& new_##name)->decltype(*this) { /* NOLINT */ \
    this->name##_ = new_##name;                                         \
    return *this;                                                       \
  }                                                                     \
  inline auto name(T&& new_##name)->decltype(*this) { /* NOLINT */      \
    this->name##_ = std::move(new_##name);                              \
    return *this;                                                       \
  }                                                                     \
  inline const T& name() const noexcept { /* NOLINT */                  \
    return this->name##_;                                               \
  }                                                                     \
  inline T& name() noexcept { /* NOLINT */                              \
    return this->name##_;                                               \
  }                                                                     \
                                                                        \
 private:                                                               \
  T name##_ /* NOLINT */

In [52]:
class SGDOptions : public OptimizerCloneableOptions<SGDOptions> {
  SGDOptions(double lr);
  TORCH_ARG(double, lr);

　public:
   void serialize(torch::serialize::InputArchive& archive) override;
   void serialize(torch::serialize::OutputArchive& archive) const override;
   ~SGDOptions() override = default;
   double get_lr() const override;
   void set_lr(const double lr) override;
};

　public:
[0;1;32m^~
[0m

In [55]:
SGDOptions::SGDOptions(double lr) : lr_(lr) {};

[1minput_line_63:1:13: [0m[0;1;31merror: [0m[1mredefinition of 'SGDOptions'[0m
SGDOptions::SGDOptions(double lr) : lr_(lr) {};
[0;1;32m            ^
[0m[1minput_line_61:1:13: [0m[0;1;30mnote: [0mprevious definition is here[0m
SGDOptions::SGDOptions(double lr) : lr_(lr) {};
[0;1;32m            ^
[0m

Interpreter Error: 

In [56]:
double SGDOptions::get_lr() const {
  return lr();
};

In [57]:
void SGDOptions::set_lr(const double lr) {
  this->lr(lr);
};

In [58]:
void SGDOptions::serialize(torch::serialize::OutputArchive& archive) const {
  //_TORCH_OPTIM_SERIALIZE_TORCH_ARG(lr);
};

In [59]:
void SGDOptions::serialize(torch::serialize::InputArchive& archive) {
  //_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, lr);
};

In [60]:
class SGDParamState: public OptimizerCloneableParamState<SGDParamState> {
  TORCH_ARG(torch::Tensor, momentum_buffer);

 public:
  void serialize(torch::serialize::InputArchive& archive) override;
  void serialize(torch::serialize::OutputArchive& archive) const override;
  ~SGDParamState() override = default;
};

In [61]:
void SGDParamState::serialize(torch::serialize::OutputArchive& archive) const {
  //_TORCH_OPTIM_SERIALIZE_TORCH_ARG(momentum_buffer);
};

In [62]:
void SGDParamState::serialize(torch::serialize::InputArchive& archive) {
  //_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, momentum_buffer);
}

In [63]:
class SGD : public Optimizer {
 public:
  explicit SGD(
      std::vector<OptimizerParamGroup> param_groups,
      SGDOptions defaults)
      : Optimizer(
            std::move(param_groups),
            std::make_unique<SGDOptions>(defaults)) {
      }

  explicit SGD(
      std::vector<torch::Tensor> params,
      // NOLINTNEXTLINE(performance-move-const-arg)
      SGDOptions defaults)
      : SGD({std::move(OptimizerParamGroup(params))}, defaults) {}

  torch::Tensor step(LossClosure closure = nullptr) override;

  void save(torch::serialize::OutputArchive& archive) const override;
  void load(torch::serialize::InputArchive& archive) override;

 private:
  template <typename Self, typename Archive>
  static void serialize(Self& self, Archive& archive) {
  }
};

In [65]:
torch::Tensor SGD::step(LossClosure closure) {
  torch::NoGradGuard no_grad;
  torch::Tensor loss = {};
  if (closure != nullptr) {
    at::AutoGradMode enable_grad(true);
    loss = closure();
  }
  for (auto& group : param_groups_) {
    auto& options = static_cast<SGDOptions&>(group.options());

    for (auto& p : group.params()) {
      if (!p.grad().defined()) {
        continue;
      }
      auto d_p = p.grad().data();
      p.data().add_(d_p, -1 * options.lr());
    }
  }
  return loss;
}

[1minput_line_73:6:12: [0m[0;1;31merror: [0m[1mtype '__cling_N535::Optimizer::LossClosure' (aka 'function<torch::Tensor ()>')
      does not provide a call operator[0m
    loss = closure();
[0;1;32m           ^~~~~~~
[0m

Interpreter Error: 

In [66]:
void SGD::save(torch::serialize::OutputArchive& archive) const {
  serialize(*this, archive);
};

In [67]:
void SGD::load(torch::serialize::InputArchive& archive) {
  torch::IValue pytorch_version;
  if (archive.try_read("pytorch_version", pytorch_version)) {
    serialize(*this, archive);
  }
}