-
Notifications
You must be signed in to change notification settings - Fork 1.7k
optim
wangzhaode edited this page Feb 16, 2023
·
1 revision
module optim
optim时优化器模块,提供了一个优化器基类Optimizer
,并提供了SGD
和ADAM
优化器实现;主要用于训练阶段迭代优化
优化器的正则化方法,提供了L1和L2正则化方法
- 类型:
Enum
- 枚举值:
L1
L2
L1L2
创建一个SGD优化器
参数:
-
module:_Module
模型实例 -
lr:float
学习率 -
momentum:float
动量,默认为0.9 -
weight_decay:float
权重衰减,默认为0.0 -
regularization_method:RegularizationMethod
正则化方法,默认为L2正则化
返回:SGD优化器实例
返回类型:Optimizer
示例:
model = Net()
sgd = optim.SGD(model, 0.001, 0.9, 0.0005, optim.Regularization_Method.L2)
# feed some date to the model, then get the loss
loss = ...
sgd.step(loss) # backward and update parameters in the model
创建一个ADAM优化器
参数:
-
module:_Module
模型实例 -
lr:float
学习率 -
momentum:float
动量,默认为0.9 -
momentum2:float
动量2,默认为0.999 -
weight_decay:float
权重衰减,默认为0.0 -
eps:float
正则化阈值,默认为1e-8 -
regularization_method:RegularizationMethod
正则化方法,默认为L2正则化
返回:ADAM优化器实例
返回类型:Optimizer
示例:
model = Net()
sgd = optim.ADAM(model, 0.001)
# feed some date to the model, then get the loss
loss = ...
sgd.step(loss) # backward and update parameters in the model