/
momentum_sgd.rs
118 lines (112 loc) · 3.64 KB
/
momentum_sgd.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
//! Momentum SGD optimizer
use crate::optimizers::Optimizer;
use crate::tensor::Tensor;
use crate::tensor_ops::gradient_descent_ops::sgd;
use crate::variable::VariableID;
use crate::{Context, Float, VariableEnvironment};
/// Momentum gradient descent optimizer
///
/// Use `ag::tensor_ops::gradient_descent` for the vanilla sgd.
///
/// ```
/// use autograd as ag;
/// use ag::prelude::*;
/// use ag::optimizers;
/// use ag::optimizers::MomentumSGD;
///
/// type Tensor<'g> = ag::Tensor<'g, f64>;
/// let mut env = ag::VariableEnvironment::new();
/// let opt = MomentumSGD::default("sgd", env.default_namespace().current_var_ids(), &mut env);
///
/// env.run(|g| {
/// let p = g.placeholder("p", &[]);
///
/// let mut feeder = ag::Feeder::new();
/// let feed = ag::ndarray::arr0(2.);
/// feeder.push(p, feed.view());
///
/// let (params, grads): (&[Tensor], &[Tensor]) = (&[], &[]); // dummy here
/// opt.update(params, grads, g, feeder); // do parameter update
/// });
/// ```
pub struct MomentumSGD<F> {
    /// Learning rate used by each update step.
    pub alpha: F,
    /// Momentum coefficient (the recommended default is 0.9).
    pub momentum: F,
    /// Namespace id under which the per-variable state arrays are registered.
    pub momentum_sgd_namespace_id: &'static str,
}
// NOTE: the previous `impl<'t, 'g, ...>` declared lifetimes that were never
// used anywhere in the impl; they are removed here (clippy: extra_unused_lifetimes).
impl<F: Float> MomentumSGD<F> {
    /// Instantiates `MomentumSGD` optimizer with the recommended parameters.
    ///
    /// Equivalent to [`MomentumSGD::new`] with `alpha = 0.01` and `momentum = 0.9`.
    pub fn default(
        unique_namespace_id: &'static str,
        var_id_list: impl IntoIterator<Item = VariableID>,
        env_handle: &mut VariableEnvironment<F>,
    ) -> MomentumSGD<F> {
        MomentumSGD::new(
            F::from(0.01).unwrap(),
            F::from(0.9).unwrap(),
            var_id_list,
            env_handle,
            unique_namespace_id,
        )
    }

    /// Instantiates `MomentumSGD` optimizer with given params.
    ///
    /// For every variable id in `var_id_list`, a zero-initialized array with the
    /// same shape as that variable is registered in the namespace
    /// `momentum_sgd_namespace_id`, keyed by the id's string form. These arrays
    /// hold the optimizer's per-parameter state and are looked up again during
    /// `compute_updates`.
    ///
    /// # Panics
    ///
    /// Panics if an id in `var_id_list` has no backing array in `env`.
    pub fn new(
        alpha: F,
        momentum: F,
        var_id_list: impl IntoIterator<Item = VariableID>,
        env: &mut VariableEnvironment<F>,
        momentum_sgd_namespace_id: &'static str,
    ) -> MomentumSGD<F> {
        for vid in var_id_list.into_iter() {
            // The state slot is keyed by the variable id's string form so that
            // `compute_updates` can find it via `variable_by_name`.
            let v_name = vid.to_string();
            // Zero-initialized state array shaped like the target variable.
            // The inner scope ends the immutable borrow of `env` before the
            // mutable `namespace_mut` borrow below.
            let v = {
                let target_var = env
                    .get_array_by_id(vid)
                    .expect("variable array not found")
                    .borrow();
                crate::ndarray_ext::zeros(target_var.shape())
            };
            env.namespace_mut(momentum_sgd_namespace_id)
                .slot()
                .name(v_name)
                .set(v);
        }
        MomentumSGD {
            alpha,
            momentum,
            momentum_sgd_namespace_id,
        }
    }
}
impl<F: Float> Optimizer<F> for MomentumSGD<F> {
    /// Builds one momentum-SGD update node per `(param, grad)` pair.
    ///
    /// Each update node takes the parameter (mutable), its gradient (read-only)
    /// and the per-parameter state array stored under
    /// `self.momentum_sgd_namespace_id` (mutable), and applies
    /// `sgd::MomentumSGDOp` with this optimizer's `alpha` and `momentum`.
    ///
    /// Panics if `params` and `grads` differ in length, or if a tensor in
    /// `params` is not a variable.
    fn compute_updates<'g, A, B>(
        &self,
        params: &[A],
        grads: &[B],
        g: &'g Context<F>,
    ) -> Vec<Tensor<'g, F>>
    where
        A: AsRef<Tensor<'g, F>> + Copy,
        B: AsRef<Tensor<'g, F>> + Copy,
    {
        assert_eq!(params.len(), grads.len());
        params
            .iter()
            .zip(grads)
            .map(|(param, grad)| {
                let param = param.as_ref();
                // State arrays were registered in `new` keyed by the variable id's
                // string form; fetch the matching one for this parameter.
                let ns = g.env().namespace(self.momentum_sgd_namespace_id);
                let vid = param.get_variable_id().expect("Got non-variable tensor");
                let state = g.variable_by_name(&format!("{}", vid), &ns);
                Tensor::builder(g)
                    .append_input(param, true)
                    .append_input(grad.as_ref(), false)
                    .append_input(&state, true)
                    .build(sgd::MomentumSGDOp {
                        lr: self.alpha,
                        momentum: self.momentum,
                    })
            })
            .collect()
    }
}