## 梯度下降

In [1]:
import scipy.sparse as sp
import numpy as np

定义一个凸的最小二乘目标函数：
$$
\min_{x} \frac{1}{2} \|Ax-b\|_{2}^{2}
$$

定义一个读取数据的类`Fun`：它从`.npz`文件中读取稀疏参数$A$和$b$，并提供方法`min_f`，用于计算残差范数$\|Ax-b\|_{2}$（该值越小，目标函数$\frac{1}{2}\|Ax-b\|_{2}^{2}$也越小，二者单调一致）：

In [2]:
class Fun(object):
    """Data and residual evaluation for the problem min_x 1/2 * ||Ax - b||^2.

    ``A`` and ``b`` are loaded from ``.npz`` files written by
    ``scipy.sparse.save_npz`` and converted to CSR.

    Attributes:
        file_a_path: path the matrix ``A`` was loaded from.
        file_b_path: path the right-hand side ``b`` was loaded from.
        A: ``A`` in CSR format (news20 data: shape (15935, 62061)).
        b: ``b`` in CSR format (news20 data: shape (15935, 1)).
        A_T: transpose of ``A``, cached because every gradient step needs it.
        x_init: starting point of shape (A.shape[-1], b.shape[-1]) built with
            ``sp.eye``: 1 in entry (0, 0) and 0 everywhere else.
    """

    def __init__(self, path_a, path_b):
        self.file_a_path = path_a
        self.file_b_path = path_b
        self.A, self.b = self._get_parameter()  # load parameters A and b
        # Cache the transpose once; gradient steps use A^T repeatedly.
        self.A_T = self.A.transpose()
        # Initial solution with the shape of x: (columns of A, columns of b).
        self.x_init = sp.eye(self.A.shape[-1], self.b.shape[-1]).tocsr()

    def min_f(self, x):
        """Return the residual norm ||Ax - b||_2 at ``x``.

        This equals sqrt(2 * f(x)) for the objective f(x) = 1/2 * ||Ax - b||^2,
        so it decreases exactly when the objective decreases.

        Args:
            x: sparse matrix/array or dense array with ``A.shape[-1]`` rows.

        Returns:
            The 2-norm of the residual ``A @ x - b`` as a float.
        """
        # ``@`` is matrix multiplication for sparse matrices, sparse arrays and
        # ndarrays alike, whereas ``*`` is element-wise for scipy sparse arrays.
        residual = self.A @ x - self.b
        # A dense ``x`` yields a dense residual (no ``.toarray()``); normalise
        # to a plain ndarray in both cases.
        if sp.issparse(residual):
            residual = residual.toarray()
        else:
            residual = np.asarray(residual)
        # ord=2 on a 2-D array is the largest singular value, which for a
        # single-column residual is its Euclidean norm.
        return np.linalg.norm(residual, ord=2)

    def _get_parameter(self):
        """Load ``A`` and ``b`` from the configured paths and return them as CSR."""
        A = sp.load_npz(self.file_a_path).tocsr()  # news20: shape = (15935, 62061)
        b = sp.load_npz(self.file_b_path).tocsr()  # news20: shape = (15935, 1)
        return A, b

目标函数$f(x)=\frac{1}{2}\|Ax-b\|_{2}^{2}$的梯度为$\nabla f(x) = A^{T}(Ax-b)$，因此梯度下降的迭代公式为（$\eta$为步长，即学习率）：

$$
x_{k+1} = x_{k} - \eta A^{T}(Ax_{k}-b)
$$

因为$A$的维度为`(15935, 62061)`，$b$的维度为`(15935, 1)`，所以$x$（以及每一步的$x_{k}$）的维度为`(62061, 1)`。$Ax_{k}-b$的维度为`(15935, 1)`，$A^{T}$的维度为`(62061, 15935)`，因此$A^{T}(Ax_{k}-b)$的维度为`(62061, 1)`，与$x_{k}$的维度一致，可以直接相减。



In [3]:
class GradientDescent(Fun):
    """Plain gradient descent for min_x 1/2 * ||Ax - b||^2.

    Construction (loading ``A``, ``b`` and building ``A_T`` / ``x_init``) and
    the residual norm ``min_f`` are inherited unchanged from ``Fun``.
    """

    def gradient_decs(self, eta=0.002, iter_times=1, x_input=None, verbose=True):
        """Run ``iter_times`` steps of x <- x - eta * A^T (A x - b).

        Args:
            eta: step size (learning rate).
            iter_times: number of iterations; 0 returns the start point unchanged.
            x_input: starting point; ``self.x_init`` is used when it is None.
            verbose: when True, print the residual norm before and after each step.

        Returns:
            The iterate after ``iter_times`` steps.
        """
        x = self.x_init if x_input is None else x_input
        y_prev = self.min_f(x) if verbose else None
        for _ in range(iter_times):
            # Apply ``eta`` after the mat-vec products: ``eta * A_T`` would build
            # a scaled copy of the whole transposed matrix on every iteration.
            gradient = self.A_T @ (self.A @ x - self.b)
            x = x - eta * gradient  # update x
            if verbose:
                y_curr = self.min_f(x)
                print("pre_y is {}  and y is {}".format(y_prev, y_curr))
                y_prev = y_curr
        return x

In [4]:
if __name__ == "__main__":
    # Data files for the news20 least-squares problem.
    PATH_A = './news20_A.npz'
    PATH_B = './news20_b.npz'
    ETA = 0.002     # gradient-descent step size
    N_ITERS = 20    # number of gradient steps

    GD = GradientDescent(path_a=PATH_A, path_b=PATH_B)

    # Keep the final iterate so later cells can inspect or reuse it.
    x_gd = GD.gradient_decs(eta=ETA, iter_times=N_ITERS, x_input=GD.x_init)

pre_y is 15.722300259397239  and y is 13.255437472728753
pre_y is 13.255437472728753  and y is 12.830725844332473
pre_y is 12.830725844332473  and y is 12.51692906968907
pre_y is 12.51692906968907  and y is 12.27197839958532
pre_y is 12.27197839958532  and y is 12.071656835698864
pre_y is 12.071656835698864  and y is 11.900847835979816
pre_y is 11.900847835979816  and y is 11.750127725254952
pre_y is 11.750127725254952  and y is 11.613593698606474
pre_y is 11.613593698606474  and y is 11.487508355089918
pre_y is 11.487508355089918  and y is 11.369465105891495
pre_y is 11.369465105891495  and y is 11.25787798929786
pre_y is 11.25787798929786  and y is 11.151670710977404
pre_y is 11.151670710977404  and y is 11.05008706842521
pre_y is 11.05008706842521  and y is 10.9525750410614
pre_y is 10.9525750410614  and y is 10.858715532640206
pre_y is 10.858715532640206  and y is 10.768178203651498
pre_y is 10.768178203651498  and y is 10.680693783528262
pre_y is 10.680693783528262  and y is 10.59