In [2]:
import numpy as np
import random

In [3]:
# Inference / test / predict: applying the model once training is done.
# Linear hypothesis: h(x) = theta_0 + theta_1 * x, with w = theta_1 and b = theta_0.
def inference(w, b, x):
    """Return the model's prediction for input ``x``.

    Args:
        w: slope of the line (theta_1).
        b: intercept of the line (theta_0).
        x: input value.

    Returns:
        The predicted value ``w * x + b``.
    """
    scaled = w * x
    return scaled + b

In [4]:
# Loss of the line (w, b) over a dataset.
def eval_loss(w, b, x_list, gt_y_list):
    """Mean half-squared-error of the line ``y = w * x + b`` over a dataset.

    J(w, b) = 1/(2m) * sum_i (w * x_i + b - y_i)^2

    Args:
        w: slope (theta_1).
        b: intercept (theta_0).
        x_list: sequence of inputs.
        gt_y_list: sequence of ground-truth targets, same length as ``x_list``.

    Returns:
        The average loss.

    Raises:
        ValueError: if the two sequences differ in length or are empty.
            A length mismatch would otherwise either raise IndexError or
            silently average over the wrong sample count.
    """
    num_samples = len(x_list)
    if num_samples != len(gt_y_list):
        raise ValueError(
            'x_list and gt_y_list differ in length: {0} != {1}'.format(
                num_samples, len(gt_y_list)))
    if num_samples == 0:
        raise ValueError('cannot evaluate loss on an empty dataset')
    total_loss = 0.0
    for x, gt_y in zip(x_list, gt_y_list):
        # Same hypothesis as inference(): w * x + b. The 0.5 factor makes the
        # gradient the plain residual (see gradient()).
        total_loss += 0.5 * (w * x + b - gt_y) ** 2
    return total_loss / num_samples

In [11]:
# For J(θ) = 1/(2m) * ∑ (h(x_i) - y_i)^2:
#   ∂J/∂θ0 = 1/m * ∑ (h(x_i) - y_i)
#   ∂J/∂θ1 = 1/m * ∑ (h(x_i) - y_i) * x_i
# This function returns the per-sample term; the 1/m averaging is done by the caller.
def gradient(pred_y, gt_y, x):
    """Per-sample gradient of the half-squared-error loss.

    Args:
        pred_y: model prediction for ``x``.
        gt_y: ground-truth target for ``x``.
        x: the input value.

    Returns:
        Tuple ``(dw, db)``: the residual times ``x`` and the residual itself.
    """
    residual = pred_y - gt_y
    return residual * x, residual

In [6]:
# One mini-batch gradient-descent step ("batch" = a subset of the training samples).
def cal_step_gradient(batch_x_list, batch_gt_y_list, w, b, lr):
    """Apply one gradient-descent update to ``(w, b)`` using a mini-batch.

    The per-sample gradients from ``gradient()`` are averaged over the batch.
    Each parameter then moves against its mean gradient, scaled by ``lr``.

    Args:
        batch_x_list: inputs in the batch.
        batch_gt_y_list: ground-truth targets, indexed in step with
            ``batch_x_list``.
        w: current slope (theta_1).
        b: current intercept (theta_0).
        lr: learning rate.

    Returns:
        Tuple ``(w, b)`` after the update. An empty batch raises
        ZeroDivisionError.
    """
    batch_size = len(batch_x_list)
    sum_dw = 0
    sum_db = 0
    for idx, x in enumerate(batch_x_list):
        # Forward pass with the current parameters, then the per-sample gradient.
        prediction = inference(w, b, x)
        grad_w, grad_b = gradient(prediction, batch_gt_y_list[idx], x)
        sum_dw += grad_w
        sum_db += grad_b
    mean_dw = sum_dw / batch_size
    mean_db = sum_db / batch_size
    w -= lr * mean_dw
    b -= lr * mean_db
    return w, b

In [16]:
def train(x_list, gt_y_list, batch_size, lr, max_iter, log_every=1):
    """Fit ``y = w * x + b`` with mini-batch gradient descent.

    Training starts from ``w = b = 0``. Each iteration does the following:
    - draw ``batch_size`` indices uniformly *with replacement*
      (``np.random.choice`` default);
    - take one step with ``cal_step_gradient``.

    Args:
        x_list: training inputs.
        gt_y_list: training targets, same length as ``x_list``.
        batch_size: number of samples drawn per step.
        lr: learning rate.
        max_iter: number of gradient steps.
        log_every: print ``w``, ``b`` and the full-dataset loss after every
            ``log_every``-th step. The default of 1 logs every step. Pass 0 or
            None to disable logging, which also skips the full-dataset
            ``eval_loss`` pass.

    Returns:
        Tuple ``(w, b)`` of learned parameters (``(0, 0)`` if ``max_iter`` is 0).
    """
    w = 0
    b = 0
    num_samples = len(x_list)
    for step in range(max_iter):
        batch_idxs = np.random.choice(num_samples, batch_size)
        batch_x = [x_list[j] for j in batch_idxs]
        batch_y = [gt_y_list[j] for j in batch_idxs]
        w, b = cal_step_gradient(batch_x, batch_y, w, b, lr)
        if log_every and (step + 1) % log_every == 0:
            print('w:{0}, b:{1}'.format(w, b))
            print('loss is {0}'.format(eval_loss(w, b, x_list, gt_y_list)))
    return w, b

In [8]:
def gen_sample_data(num_samples=100, rng=None):
    """Generate a noisy synthetic dataset from a randomly chosen line.

    Args:
        num_samples: number of ``(x, y)`` pairs to generate (default 100).
        rng: source of randomness exposing ``randint()`` and ``random()``.
            For example, pass ``random.Random(seed)`` for reproducible data.
            Defaults to the global ``random`` module. The draw order is the
            same either way, so seeding the global module also works.

    Returns:
        Tuple ``(x_list, y_list, w, b)``: the inputs, the noisy targets, and
        the true slope and intercept used to generate them.
    """
    if rng is None:
        rng = random
    w = rng.randint(0, 10) + rng.random()   # true slope in [0, 11)
    b = rng.randint(0, 5) + rng.random()    # true intercept in [0, 6)
    x_list = []
    y_list = []
    for _ in range(num_samples):
        # x in [0, 100): a random integer scaled by a uniform fraction.
        x = rng.randint(0, 100) * rng.random()
        # Noise is U[0, 1) times -1, 0 or +1, i.e. in (-1, 1), and exactly
        # zero whenever randint picks 0.
        y = w * x + b + rng.random() * rng.randint(-1, 1)
        x_list.append(x)
        y_list.append(y)
    return x_list, y_list, w, b

In [18]:
def run(lr=0.00001, max_iter=10000, batch_size=50):
    """Generate synthetic data and fit it with mini-batch gradient descent.

    Args:
        lr: learning rate. It is kept small because the gradient for ``w`` is
            ``residual * x`` and ``x`` ranges up to ~100, so larger steps
            may diverge.
        max_iter: number of gradient steps.
        batch_size: samples drawn per step.

    Returns:
        Whatever ``train()`` returns.
    """
    x_list, y_list, true_w, true_b = gen_sample_data()
    # Report the generating parameters so the learned w, b can be compared.
    print('true w:{0}, true b:{1}'.format(true_w, true_b))
    return train(x_list, y_list, batch_size, lr, max_iter)

In [19]:
# Entry point: runs the generate-and-train demo when this code is executed as
# the main module (this includes running the cell in a notebook kernel).
if __name__ == '__main__':
    run()

w:0.12199538988946165, b:0.0020330346231504067
loss is 31066.181418009557
w:0.22191221570278175, b:0.003893753459513401
loss is 30057.522957067144
w:0.3293324106793329, b:0.005791081451767756
loss is 28991.707038404747
w:0.4140572325777036, b:0.00745425836531514
loss is 28164.616687133515
w:0.5016126755708644, b:0.00905139060191852
loss is 27322.492275625646
w:0.5919771120102303, b:0.010751400939419101
loss is 26466.74676098428
w:0.6967197692496324, b:0.012556651040447316
loss is 25491.907817185845
w:0.7853443291821772, b:0.014199786715208095
loss is 24681.348657024155
w:0.8940782787815713, b:0.01605244329174599
loss is 23704.788617839404
w:0.9622357634877617, b:0.01744338397227294
loss is 23102.668495496855
w:1.04904572700973, b:0.019149739344645995
loss is 22346.995557612776
w:1.1308939829234927, b:0.020613425288379335
loss is 21646.046851921357
w:1.2196383755588585, b:0.02234160061693875
loss is 20898.640540314296
w:1.3219185021756634, b:0.024057543937081996
loss is 20053.5686206459

w:6.22552490170312, b:0.13022723183705323
loss is 3.9374645332254783
w:6.225633255700452, b:0.1302481562213336
loss is 3.937470249473884
w:6.225657569838252, b:0.13026897134040966
loss is 3.9374443840672404
w:6.225644346263572, b:0.13028546566902052
loss is 3.9374067997099917
w:6.225663500910364, b:0.13030548266987396
loss is 3.9373800181775755
w:6.225782004112778, b:0.13032892730045043
loss is 3.93741399160364
w:6.225733180502349, b:0.13034760341968343
loss is 3.937344727510955
w:6.225744086002607, b:0.13036527731055467
loss is 3.9373194589540144
w:6.225668234265878, b:0.1303828918296002
loss is 3.9372393225773523
w:6.225351274522785, b:0.13039503006273154
loss is 3.937121729395827
w:6.225468091250417, b:0.13041912918532425
loss is 3.9370925373022327
w:6.22547762829418, b:0.1304391549889757
loss is 3.937057625563276
w:6.225457711224731, b:0.1304586229056385
loss is 3.93701684937383
w:6.225084483037857, b:0.13046568256098584
loss is 3.937039370692046
w:6.2250087196213695, b:0.130483286

w:6.225104114336098, b:0.14429964632080938
loss is 3.9112777947961934
w:6.225043996797185, b:0.1443163237753544
loss is 3.911247202411448
w:6.225014309059624, b:0.14433679571561275
loss is 3.9112115667788148
w:6.225060996715073, b:0.14435830069347658
loss is 3.9111685868286346
w:6.225166680275177, b:0.1443787003796626
loss is 3.9111373343127633
w:6.225217302349799, b:0.14439802782741543
loss is 3.9111113133163324
w:6.225371364001658, b:0.1444171146325032
loss is 3.91113213178787
w:6.225521289583478, b:0.1444337874273851
loss is 3.9111940396847844
w:6.225330401064304, b:0.1444483319836468
loss is 3.911055659265979
w:6.225332186347704, b:0.14446532969479978
loss is 3.9110250152112265
w:6.224893314600095, b:0.14447667361515917
loss is 3.9109761098404636
w:6.225004653113581, b:0.14449792398028877
loss is 3.9109133405601564
w:6.225052595602127, b:0.14451620683485253
loss is 3.910875768566839
w:6.224921968212907, b:0.1445321534847181
loss is 3.910864946645577
w:6.224929272856413, b:0.1445539

loss is 3.8858409732974866
w:6.224653433358828, b:0.15805373939698736
loss is 3.8858245703484418
w:6.224702682182565, b:0.15807326887867104
loss is 3.885776029734919
w:6.224463356489797, b:0.1580876612262531
loss is 3.8858470376331353
w:6.224394559108197, b:0.1581058630616301
loss is 3.8858588810245553
w:6.224523088960019, b:0.15812589257386864
loss is 3.885742668501136
w:6.2246050243060544, b:0.15814609458730097
loss is 3.8856692388338403
w:6.224414764858915, b:0.15815827562577034
loss is 3.8857470628299313
w:6.224600517964086, b:0.1581773267397128
loss is 3.885612952692307
w:6.224780237518912, b:0.15820042553857494
loss is 3.8855294951650476
w:6.224612315408377, b:0.1582139642900867
loss is 3.885540642061793
w:6.224734382279307, b:0.15823516343385735
loss is 3.885470402123348
w:6.2248256334750875, b:0.15825918065119987
loss is 3.88541910840466
w:6.2248807495304135, b:0.15827965117647053
loss is 3.8853838944831796
w:6.225025862258068, b:0.15829731485986742
loss is 3.8853824979603138
w

loss is 3.8630818487223006
w:6.225675027902655, b:0.17102189618637084
loss is 3.8628296533604862
w:6.225680538386716, b:0.17103725885131224
loss is 3.8628117939835653
w:6.225822459783242, b:0.17105640716926912
loss is 3.863051125465737
w:6.2259658896643995, b:0.17107834918821568
loss is 3.863322602897044
w:6.226081645264685, b:0.17109899439190665
loss is 3.8635615557155036
w:6.226143356860135, b:0.17111996439174534
loss is 3.8636802306656186
w:6.2262392167677545, b:0.17113753511480173
loss is 3.8639042260642698
w:6.226282183643213, b:0.1711538207910826
loss is 3.8639945564911398
w:6.226232686875659, b:0.17116742147202516
loss is 3.8638327212847035
w:6.226152594072778, b:0.17118437371488754
loss is 3.8635884922134767
w:6.226071654975974, b:0.17119819738896927
loss is 3.8633584088926534
w:6.225948864691164, b:0.17121422916510073
loss is 3.86303897462053
w:6.22602922201001, b:0.1712347143315462
loss is 3.8631894390277015
w:6.225922060815891, b:0.1712503201781036
loss is 3.862913805080368


w:6.224043956465789, b:0.18195985405157752
loss is 3.8418219523389974
w:6.224056174019013, b:0.18197855612648228
loss is 3.8417803466053146
w:6.224054142318911, b:0.18199698305157946
loss is 3.8417474308495905
w:6.2241795876439365, b:0.18201889309563543
loss is 3.8416489191867287
w:6.224276340406625, b:0.1820377937183686
loss is 3.841587278757968
w:6.224322160410417, b:0.18205797783422836
loss is 3.8415428934321336
w:6.2243567502162, b:0.18207773078973086
loss is 3.8415034344805523
w:6.224306629217109, b:0.18209466484105027
loss is 3.841477402419346
w:6.224110585676211, b:0.18210933655769243
loss is 3.8415105351112375
w:6.224076349661758, b:0.18212951180752637
loss is 3.8414903246569794
w:6.223940913417161, b:0.18214766397597057
loss is 3.8415433083975588
w:6.22388875479647, b:0.18216312381231423
loss is 3.841556128530661
w:6.2239234166911315, b:0.1821818950168695
loss is 3.84149330289333
w:6.223875013877346, b:0.18219744495048534
loss is 3.8415041733359123
w:6.223888153362798, b:0.182

loss is 3.8183378640066734
w:6.224877074581352, b:0.1948865760849544
loss is 3.818418558458548
w:6.225006123748116, b:0.1949055943850163
loss is 3.8185520192097653
w:6.22505723462073, b:0.1949248321876265
loss is 3.81859145639039
w:6.22516126263918, b:0.19494864751064367
loss is 3.8187130045761264
w:6.225283906920984, b:0.19497108193481513
loss is 3.81888981022592
w:6.225284205964058, b:0.19499068649902457
loss is 3.8188551096226617
w:6.225330318661089, b:0.19500905780046138
loss is 3.8189103066432097
w:6.2252472903357905, b:0.19502619450258296
loss is 3.8187231400920774
w:6.225360403708045, b:0.19504620076379095
loss is 3.8189030678582214
w:6.225339098001505, b:0.19506462365934069
loss is 3.818827661975193
w:6.225363884312316, b:0.19508464580413754
loss is 3.8188409986070293
w:6.225149191901623, b:0.19509653121062293
loss is 3.8184264052023535
w:6.225001943058526, b:0.1951151814640089
loss is 3.818167399761265
w:6.224986356592426, b:0.19513362786799857
loss is 3.818112315910687
w:6.22

w:6.224930501473718, b:0.20731033757582162
loss is 3.796082516119964
w:6.225027319111448, b:0.2073283676740612
loss is 3.796218418597805
w:6.22506316187211, b:0.20734604931726527
loss is 3.796252985435911
w:6.225031566297267, b:0.20736475642102944
loss is 3.7961608825856374
w:6.225059602372756, b:0.20738247012748875
loss is 3.7961810039751147
w:6.225058992916375, b:0.20740053001662495
loss is 3.796147480556057
w:6.225073241003878, b:0.20742023608121904
loss is 3.796139026115442
w:6.22509305532992, b:0.20744168267773094
loss is 3.7961385361586526
w:6.224971915766473, b:0.20745776894549084
loss is 3.795887840905977
w:6.225007120656859, b:0.2074758890201013
loss is 3.795917282479388
w:6.224969109369512, b:0.20749204784037342
loss is 3.7958214386042783
w:6.22478524478815, b:0.2075107248579729
loss is 3.795498442530207
w:6.2248636649578195, b:0.207531560454461
loss is 3.795577531780777
w:6.2248154880510365, b:0.20754779345475802
loss is 3.795475452232944
w:6.224834347923442, b:0.20756781730

loss is 3.7724970169415957
w:6.2234841907318135, b:0.21980688145253335
loss is 3.7725053662239008
w:6.223302418576647, b:0.21982119984620177
loss is 3.772574949403724
w:6.223252806783248, b:0.21983926552843838
loss is 3.7725774980029056
w:6.223137748790942, b:0.21985285547295144
loss is 3.7726510323842
w:6.223205946407516, b:0.2198759258759323
loss is 3.7725476340886868
w:6.223165345156914, b:0.2198930247539563
loss is 3.772551584035625
w:6.22323352999355, b:0.21991554080780798
loss is 3.77245245254238
w:6.223356866176998, b:0.21993594623121498
loss is 3.7723302853229184
w:6.223420313673675, b:0.21995497364711306
loss is 3.7722617633457833
w:6.223370763884762, b:0.2199749462624526
loss is 3.7722509041481582
w:6.2233265094463395, b:0.21999419145627358
loss is 3.7722420355386355
w:6.2231785055307345, b:0.2200124527852457
loss is 3.7723204095597964
w:6.2232436130208315, b:0.22003255235429733
loss is 3.7722298266775844
w:6.223272737597355, b:0.22005144652937275
loss is 3.7721734537887017
w

w:6.223685495642522, b:0.2342497174189885
loss is 3.7462521745024513
w:6.223454499073064, b:0.23426177897180162
loss is 3.746183329124252
w:6.223488691256571, b:0.2342821315282101
loss is 3.7461477755181147
w:6.223547698029081, b:0.23430453561051487
loss is 3.746114144659407
w:6.223727975720474, b:0.2343265847846801
loss is 3.7461316453897218
w:6.223787138554858, b:0.23434742738438039
loss is 3.7461246922248055
w:6.223841453193011, b:0.23437036016914423
loss is 3.746116663119125
w:6.22383278441354, b:0.23438908965378458
loss is 3.7460772405362337
w:6.223737683902145, b:0.23440501110761697
loss is 3.7459946443106418
w:6.2238112816441875, b:0.2344260147604123
loss is 3.745997147375171
w:6.223804627274624, b:0.23444210068355784
loss is 3.7459641041963327
w:6.223755243445304, b:0.2344583244449977
loss is 3.7459072123398314
w:6.22354064280435, b:0.23447269547251814
loss is 3.7458083228672847
w:6.223525358367496, b:0.2344930647291665
loss is 3.7457691417618797
w:6.223364680735844, b:0.234510

loss is 3.7212203009524334
w:6.222506091105311, b:0.24832132829298964
loss is 3.721106360651385
w:6.22259169337564, b:0.2483433459585392
loss is 3.7209742525670126
w:6.222555527484057, b:0.2483603918209001
loss is 3.720980408418055
w:6.222411017024671, b:0.24837213133837363
loss is 3.721129662082616
w:6.222227206023466, b:0.24838771589404351
loss is 3.721368444217106
w:6.2223014243182275, b:0.24840722558318135
loss is 3.7212179195633386
w:6.222340226768537, b:0.24842518019689128
loss is 3.7211286895894333
w:6.222481741314282, b:0.2484474013446529
loss is 3.7209040392296453
w:6.222433873007583, b:0.248468712675914
loss is 3.720923483071962
w:6.222478605924284, b:0.24848583748744008
loss is 3.720837420383163
w:6.2222874841754585, b:0.24850196755047332
loss is 3.721064848697499
w:6.222361986867723, b:0.24852528051423953
loss is 3.720914708350082
w:6.22229693017978, b:0.24854326793011153
loss is 3.7209749225904267
w:6.222434851294516, b:0.2485627351310056
loss is 3.7207500855658107
w:6.222

w:6.222951092494121, b:0.2613478670701753
loss is 3.6971943236141236
w:6.223029383901471, b:0.2613654592734062
loss is 3.697166983702199
w:6.222989858003918, b:0.26138507112713244
loss is 3.6971281686670263
w:6.223078342913166, b:0.26140542144970597
loss is 3.69710302527577
w:6.22285062602396, b:0.26141916440213187
loss is 3.697074988872062
w:6.222867039037821, b:0.26143669263859476
loss is 3.6970406971366168
w:6.222826022639313, b:0.26145301068077687
loss is 3.6970186642097955
w:6.223014853124329, b:0.26147261916176145
loss is 3.696972438393285
w:6.222874075065026, b:0.26148969733261296
loss is 3.6969440340413304
w:6.222870844964519, b:0.2615055240589909
loss is 3.6969159088878625
w:6.22294072523654, b:0.2615237876365861
loss is 3.696877346801239
w:6.2230881455177505, b:0.2615426505996575
loss is 3.6968583081073767
w:6.2230595107234254, b:0.26156004147518963
loss is 3.6968212768336497
w:6.222988271055814, b:0.2615783530818453
loss is 3.6967798983989253
w:6.223076323487436, b:0.2615995

w:6.222524750971746, b:0.27425604232380685
loss is 3.673997765901866
w:6.22248849848855, b:0.274272327232333
loss is 3.673981474095076
w:6.2225082466427155, b:0.2742918432394661
loss is 3.6739388962642243
w:6.222541547969771, b:0.27430935712045956
loss is 3.6738963676218463
w:6.222542617073126, b:0.27432769394449613
loss is 3.673863002881074
w:6.222403653979222, b:0.27434393319176764
loss is 3.673891271272926
w:6.222203161352447, b:0.2743599046326367
loss is 3.6740020408078347
w:6.22228501860497, b:0.27437928077020485
loss is 3.6739018796909697
w:6.222366097729837, b:0.2743987329791249
loss is 3.6738133026837954
w:6.222354281439308, b:0.2744181414680045
loss is 3.673785315589307
w:6.222388385676406, b:0.27443740488923013
loss is 3.673730685576465
w:6.22240775196868, b:0.27445359084009047
loss is 3.673691068051271
w:6.222412325127526, b:0.27447421113801546
loss is 3.6736514721071627
w:6.222296139422391, b:0.2744885091602705
loss is 3.673696269123444
w:6.2224828906576635, b:0.27451040298

w:6.223492926097711, b:0.28797154491486743
loss is 3.650229418940076
w:6.223509815797738, b:0.28799450642937524
loss is 3.650217923144815
w:6.223288732538632, b:0.2880090262943966
loss is 3.649851479824689
w:6.223310625383341, b:0.28802692267145724
loss is 3.6498500140617285
w:6.223305597157667, b:0.2880458882190974
loss is 3.6498095421694656
w:6.223389388465107, b:0.28806793872418224
loss is 3.649892738038862
w:6.223421488383284, b:0.2880884000683989
loss is 3.64990657369676
w:6.22337159906828, b:0.2881050428472514
loss is 3.649800408820292
w:6.223218605393729, b:0.2881199460329678
loss is 3.6495642687637218
w:6.223178160366211, b:0.2881359933808975
loss is 3.649486920382777
w:6.223199515867447, b:0.28815237797394244
loss is 3.649483468146073
w:6.2231799715800715, b:0.2881726269670942
loss is 3.6494242518130586
w:6.223157468862722, b:0.2881857691792834
loss is 3.649374831438405
w:6.223254016093054, b:0.28820543261706955
loss is 3.649458341276479
w:6.223222195593787, b:0.28822384001757