In [21]:
import tensorflow as tf
import numpy as np
from sklearn.datasets import fetch_california_housing
from sklearn.preprocessing import StandardScaler

n_epochs = 10000
learning_rate = 0.01

# 下载数据
housing = fetch_california_housing(data_home='./data', download_if_missing=True)


m, n = housing.data.shape
housing_data_plus_bias = np.c_[np.ones((m, 1)), housing.data]

# 用方差做数据归一化
scaler = StandardScaler().fit(housing_data_plus_bias)
scaled_housing_data_plus_bias = scaler.transform(housing_data_plus_bias)

X = tf.constant(scaled_housing_data_plus_bias, dtype=tf.float32, name="x")
Y = tf.constant(housing.target.reshape(-1,1), dtype=tf.float32, name="y")


# 初始化theta
theta = tf.Variable(tf.random_uniform([n + 1, 1], -1.0, 1.0), name="theta")



# 构建mse函数
y_pred = tf.matmul(X, theta, name="predictions")
error = y_pred - Y
mse = tf.reduce_mean(tf.square(error), name="mse")


''' 手动构建梯度下降公式 '''
# # 梯度公式
# gradient = 2 / m * tf.matmul(tf.transpose(X), error)

# # 执行梯度下降
# training_op = tf.assign(theta, theta - learning_rate * gradient)


''' 使用tensorFlow封装的梯度下降 '''
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
training_op = optimizer.minimize(mse)


init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    
    # 记录MSE的值， 方便打印输出
    result = []
    
    last_mse = -1
    for epoch in range(n_epochs):

        # MSE 差值小于0.00001时停止迭代
        if last_mse >= 0 and (abs(mse.eval() - last_mse)) <= 0.00001:
            break
        last_mse = mse.eval()
        
        if epoch % 100 == 0:
            result.append(mse.eval())

        sess.run(training_op)
        
        
        
    print(result)
    print(theta.eval())



' 手动构建梯度下降公式 '

' 使用tensorFlow封装的梯度下降 '

[6.456498, 5.179294, 5.0676937, 5.0003524, 4.951445, 4.9152126, 4.888237, 4.868073, 4.852943, 4.8415394, 4.8329077, 4.8263416, 4.8213205, 4.8174644, 4.814486, 4.8121734, 4.8103704, 4.8089542]
[[ 0.3392315 ]
 [ 0.87871593]
 [ 0.14141214]
 [-0.33312058]
 [ 0.3505228 ]
 [ 0.00310326]
 [-0.04229001]
 [-0.6885941 ]
 [-0.66374636]]
