In [8]:
class LinearRegression:
    """Univariate linear regression ``y = w * x + bias`` trained by batch gradient descent.

    ``fit`` works on inputs that are ALREADY normalized, whereas ``predict``
    takes a raw input and normalizes it itself via ``(x + shift) / scale``
    (see ``normalize``). With the default ``shift=100, scale=200`` raw values
    in [-100, 100] map onto [0, 1].
    """

    def __init__(self, eta=0.01, iterations=10, shift=100, scale=200):
        """
        Args:
            eta: learning rate used for every gradient-descent step.
            iterations: number of full-batch gradient steps performed by ``fit``.
            shift: offset added to a raw input before scaling in ``normalize``.
            scale: divisor applied after shifting in ``normalize``.
        """
        self.lr = eta
        self.iterations = iterations
        self.shift = shift
        self.scale = scale
        self.w = 0.0
        self.bias = 0.0

    @staticmethod
    def _check_inputs(X, Y):
        """Return ``len(X)`` after verifying X and Y are non-empty and equally long."""
        n = len(X)
        if n == 0 or n != len(Y):
            raise ValueError(
                f"X and Y must be non-empty and of equal length, got {n} and {len(Y)}"
            )
        return n

    def fit(self, X, Y):
        """Run ``self.iterations`` gradient-descent steps on already-normalized data.

        Args:
            X: sequence of normalized inputs (e.g. ``self.normalize(raw)`` values).
            Y: sequence of targets, same length as ``X``.

        Returns:
            ``(weight, bias, cost_history)`` where ``cost_history[i]`` is the
            mean squared error measured after step ``i``.

        Raises:
            ValueError: if ``X`` is empty or ``X`` and ``Y`` differ in length.
        """
        self._check_inputs(X, Y)
        cost_history = []

        for i in range(self.iterations):
            self.w, self.bias = self.update_weights(X, Y, self.w, self.bias, self.lr)
            # Record the loss after every step so training progress can be inspected.
            cost = self.cost_function(X, Y, self.w, self.bias)
            cost_history.append(cost)

            if i % 10 == 0:
                print(f"iter={i} weight={self.w} bias={self.bias} cost={cost}")

        return self.w, self.bias, cost_history

    def normalize(self, x):
        """Map a raw input into the feature space expected by ``fit``."""
        return (x + self.shift) / self.scale

    def predict(self, x):
        """Predict the target for a RAW (un-normalized) input ``x``."""
        return self.w * self.normalize(x) + self.bias

    def cost_function(self, X, Y, weight, bias):
        """Return the mean squared error of ``weight * x + bias`` over (X, Y).

        Raises:
            ValueError: if ``X`` is empty or ``X`` and ``Y`` differ in length.
        """
        n = self._check_inputs(X, Y)
        total_err = sum((y - (weight * x + bias)) ** 2 for x, y in zip(X, Y))
        return total_err / n

    def update_weights(self, X, Y, weight, bias, learning_rate):
        """Perform one batch gradient-descent step and return ``(weight, bias)``.

        Raises:
            ValueError: if ``X`` is empty or ``X`` and ``Y`` differ in length.
        """
        n = self._check_inputs(X, Y)
        dw = 0
        db = 0

        # The MSE gradients are accumulated analytically here; cost_function
        # itself is not called.
        for x, y in zip(X, Y):
            residual = y - (weight * x + bias)
            dw += -2 * x * residual
            db += -2 * residual

        weight -= (dw / n) * learning_rate
        bias -= (db / n) * learning_rate

        return weight, bias

In [9]:
x = [1, 2, 3, 10, 20, -2, -10, -100, -5, -20]
y = [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]

# 500 gradient-descent steps with a learning rate of 0.01.
model = LinearRegression(eta=0.01, iterations=500)

# Normalize the raw inputs with the same (x + 100) / 200 mapping that predict() applies.
X = [(value + 100) / 200 for value in x]

model.fit(X, y)

iter=0 weight=0.00536 bias=0.01 cost=0.48728175110763206
iter=10 weight=0.053109070565445005 bias=0.09729210467195501 cost=0.38899297733851157
iter=20 weight=0.09185949793896436 bias=0.1650501725645301 cost=0.32850118905971937
iter=30 weight=0.12355559887601003 bias=0.21750957932395926 cost=0.29120225586961296
iter=40 weight=0.14972040136851605 bias=0.2579879308876827 cost=0.26813541813096137
iter=50 weight=0.17154693316539277 bias=0.2890839340760068 cost=0.2538025543730621
iter=60 weight=0.18996972908046258 bias=0.31283317412540457 cost=0.2448301356419118
iter=70 weight=0.20572084345786376 bias=0.33083013695269725 cost=0.23914823251671202
iter=80 weight=0.21937372535596153 bias=0.3443237905436823 cost=0.2354867733433102
iter=90 weight=0.23137758646445397 bias=0.35429245491156863 cost=0.23306646026790506
iter=100 weight=0.24208432187426088 bias=0.3615024485631224 cost=0.23140912266667318
iter=110 weight=0.25176959741464777 bias=0.36655402692391076 cost=0.23022135452048395
iter=120 weig

(0.48287606601065225,
 0.2947980037186236,
 [0.48728175110763206,
  0.47516750697296073,
  0.4636284757791754,
  0.4526372382525622,
  0.44216768223152825,
  0.4321949403548075,
  0.42269533072015475,
  0.413646300371919,
  0.40502637148263976,
  0.39681509010024046,
  0.38899297733851157,
  0.38154148289440826,
  0.3744429407812423,
  0.36768052717213195,
  0.36123822025311086,
  0.3551007619900942,
  0.34925362171846397,
  0.34368296146838867,
  0.3383756029431314,
  0.33331899607154614,
  0.32850118905971937,
  0.32391079987029014,
  0.31953698906138883,
  0.3153694339203815,
  0.3113983038306921,
  0.3076142368129218,
  0.30400831718428567,
  0.30057205428305317,
  0.297297362207223,
  0.2941765405190838,
  0.29120225586961296,
  0.28836752449886616,
  0.2856656955705966,
  0.28309043530133504,
  0.2806357118460584,
  0.2782957809043789,
  0.276065172012905,
  0.2739386754910644,
  0.27191133000923756,
  0.2699784107495339,
  0.26813541813096137,
  0.26637806707208067,
  0.26470227

In [10]:
test_x = [90, 80, 81, 82, 75, 40, 32, 15, 5, 1, -1, -15, -20, -22, -33, -45, -60, -90]

# Print the trained model's prediction for each raw test input;
# predict() applies the normalization itself.
for value in test_x:
    print(f'input:{value}  =>  predict:{model.predict(value)}')

input:90  =>  predict:0.7535302664287432
input:80  =>  predict:0.7293864631282105
input:81  =>  predict:0.7318008434582639
input:82  =>  predict:0.7342152237883172
input:75  =>  predict:0.7173145614779444
input:40  =>  predict:0.6328112499260801
input:32  =>  predict:0.613496207285654
input:15  =>  predict:0.5724517416747485
input:5  =>  predict:0.548307938374216
input:1  =>  predict:0.538650417054003
input:-1  =>  predict:0.5338216563938964
input:-15  =>  predict:0.5000203317731508
input:-20  =>  predict:0.48794843012288447
input:-22  =>  predict:0.48311966946277796
input:-33  =>  predict:0.4565614858321921
input:-45  =>  predict:0.42758892187155295
input:-60  =>  predict:0.39137321692075405
input:-90  =>  predict:0.3189418070191562
