import matplotlib
matplotlib.use('TkAgg')  # select the Tk backend before pyplot is imported (e.g. inside utils)
import torch
import torch.optim as optim
import utils
print('Loading data')
features, labels = utils.get_nasa_data()
print('Defining Adam')
def init_adam_states():
    # First (v) and second (s) moment estimates for the weight and bias, zero-initialized.
    v_w, v_b = torch.zeros((features.shape[1], 1), dtype=torch.float32), torch.zeros(1, dtype=torch.float32)
    s_w, s_b = torch.zeros((features.shape[1], 1), dtype=torch.float32), torch.zeros(1, dtype=torch.float32)
    return ((v_w, s_w), (v_b, s_b))
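
# Adam update, matching the implementation below:
#   v <- beta1 * v + (1 - beta1) * g           (first moment)
#   s <- beta2 * s + (1 - beta2) * g**2        (second moment)
#   v_hat = v / (1 - beta1**t),  s_hat = s / (1 - beta2**t)   (bias correction)
#   p <- p - lr * v_hat / (sqrt(s_hat) + eps)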
def adam(params, states, hyperparams):
    beta1, beta2, eps = 0.9, 0.999, 1e-6
    for p, (v, s) in zip(params, states):
        # Update the biased first and second moment estimates.
        v[:] = beta1 * v + (1 - beta1) * p.grad.data
        s[:] = beta2 * s + (1 - beta2) * p.grad.data**2
        # Bias-correct both estimates, then take the step.
        v_bias_corr = v / (1 - beta1 ** hyperparams['t'])
        s_bias_corr = s / (1 - beta2 ** hyperparams['t'])
        p.data -= hyperparams['lr'] * v_bias_corr / (torch.sqrt(s_bias_corr) + eps)
    hyperparams['t'] += 1  # advance the time step used for bias correction
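
# --- Optional sanity check (not part of the original script) ---
# A minimal sketch comparing one step of the scratch adam() above against
# torch.optim.Adam on identical toy parameters. torch.optim.Adam defaults to
# eps=1e-8, so eps=1e-6 is passed explicitly to match the constant above;
# with matching eps the two updates should agree up to floating-point round-off.
def sanity_check_adam():
    torch.manual_seed(0)
    w = torch.randn(3, 1, requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    w_ref = w.detach().clone().requires_grad_(True)
    b_ref = b.detach().clone().requires_grad_(True)
    # Give both parameter copies identical dummy gradients.
    (w.sum() + b.sum()).backward()
    (w_ref.sum() + b_ref.sum()).backward()
    # Zero-initialized (v, s) pairs, mirroring init_adam_states().
    states = ((torch.zeros(3, 1), torch.zeros(3, 1)),
              (torch.zeros(1), torch.zeros(1)))
    adam([w, b], states, {'lr': 0.01, 't': 1})
    optim.Adam([w_ref, b_ref], lr=0.01, eps=1e-6).step()
    print('max |w - w_ref| =', (w - w_ref).abs().max().item())
# sanity_check_adam()  # uncomment to run the check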
print('Training with lr=0.01')
utils.train_opt(adam, init_adam_states(), {'lr': 0.01, 't': 1}, features, labels)
print('Concise implementation')
utils.train_opt_pytorch(optim.Adam, {'lr': 0.01}, features, labels)
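# Note: torch.optim.Adam defaults to betas=(0.9, 0.999) and eps=1e-8, so the
# built-in optimizer uses a slightly smaller eps than the scratch version above.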