In [5]:
#Linear Regression
import numpy as np
import random

In [6]:
# inference: evaluate the linear model y = w * x + b for one input
def inference(w, b, x):
    """Return the model's prediction for input ``x``.

    Args:
        w: weight (slope) of the line.
        b: bias (intercept) of the line -- not a residual.
        x: input value.

    Returns:
        The predicted value ``w * x + b``.
    """
    return w * x + b

In [35]:
def eval_loss(w, b, x_list, gt_y_list):
    """Mean half-squared error of the line ``y = w * x + b`` over a dataset.

    Computes ``(1 / N) * sum_i 0.5 * (w * x_i + b - y_i) ** 2``. The 0.5 factor
    cancels the 2 from differentiation, so the per-sample derivatives are exactly
    ``(pred - y) * x`` and ``(pred - y)``.

    Args:
        w: weight (slope).
        b: bias (intercept).
        x_list: sequence of inputs.
        gt_y_list: sequence of ground-truth targets, same length as ``x_list``.

    Returns:
        The average loss.

    Raises:
        ValueError: if the sequences differ in length or are empty.
    """
    n = len(gt_y_list)
    # Previously the loop ran over len(x_list) but divided by len(gt_y_list), so a
    # length mismatch gave a silently wrong mean (or an IndexError).
    if len(x_list) != n:
        raise ValueError(
            'x_list and gt_y_list must have the same length, got {0} and {1}'.format(
                len(x_list), n))
    if n == 0:
        # The mean over zero samples is undefined; fail clearly instead of ZeroDivisionError.
        raise ValueError('cannot evaluate loss on an empty dataset')
    total = 0.0
    for x, gt_y in zip(x_list, gt_y_list):
        total += 0.5 * (w * x + b - gt_y) ** 2
    return total / n

In [36]:
# gradient: derivatives (dw, db) of the per-sample loss 0.5 * (pred_y - gt_y) ** 2
def gradient(pred_y, gt_y, x):
    """Return the per-sample gradient ``(dw, db)``.

    Args:
        pred_y: the model's prediction for ``x`` (as produced by ``inference``).
        gt_y: the known ground-truth target for ``x``.
        x: the input value.

    Returns:
        ``(err * x, err)`` where ``err = pred_y - gt_y``.
    """
    err = pred_y - gt_y
    return err * x, err

In [37]:
def cal_step_gradient(batch_x_list, batch_gt_y_list, w, b, lr):
    """Apply one mini-batch gradient-descent step to the line parameters.

    For every sample in the batch, the prediction comes from ``inference`` and the
    per-sample derivatives from ``gradient``. The derivatives are averaged over the
    batch, and the parameters move against that average, scaled by ``lr``.

    Args:
        batch_x_list: inputs of the mini-batch.
        batch_gt_y_list: ground-truth targets, indexed in step with ``batch_x_list``.
        w: current weight.
        b: current bias.
        lr: learning rate (step size).

    Returns:
        The updated ``(w, b)`` pair. An empty batch raises ZeroDivisionError.
    """
    n = len(batch_x_list)
    grad_w_sum, grad_b_sum = 0, 0
    for idx, x in enumerate(batch_x_list):
        dw, db = gradient(inference(w, b, x), batch_gt_y_list[idx], x)
        grad_w_sum += dw
        grad_b_sum += db
    grad_w_mean = grad_w_sum / n
    grad_b_mean = grad_b_sum / n
    w -= lr * grad_w_mean
    b -= lr * grad_b_mean
    return w, b
    

In [38]:
def train(x_list, gt_y_list, batch_size, lr, max_iter, log_every=1):
    """Fit ``y = w * x + b`` with mini-batch stochastic gradient descent.

    The parameters start at ``w = b = 0``. Each iteration draws ``batch_size``
    sample indices with ``np.random.choice``, which samples with replacement by
    default, so a batch may repeat samples. It then applies one
    ``cal_step_gradient`` update and periodically reports the parameters and the
    loss over the full dataset.

    Args:
        x_list: sequence of inputs.
        gt_y_list: sequence of targets, same length as ``x_list``.
        batch_size: number of sample indices drawn per iteration.
        lr: learning rate.
        max_iter: number of gradient steps.
        log_every: print progress every ``log_every`` iterations and after the
            last one. ``1`` (the default) reproduces per-step logging; ``0`` or a
            negative value disables logging. Each report evaluates the loss on the
            whole dataset (O(N)), so a larger value keeps long runs fast and the
            notebook output readable.

    Returns:
        The learned ``(w, b)``. Previously the result was discarded.

    Raises:
        ValueError: if the sequences differ in length or are empty.
    """
    num_samples = len(x_list)
    if num_samples != len(gt_y_list):
        raise ValueError(
            'x_list and gt_y_list must have the same length, got {0} and {1}'.format(
                num_samples, len(gt_y_list)))
    if num_samples == 0:
        raise ValueError('cannot train on an empty dataset')
    w = 0
    b = 0
    for i in range(max_iter):
        batch_idxs = np.random.choice(num_samples, batch_size)
        batch_x = [x_list[j] for j in batch_idxs]
        batch_y = [gt_y_list[j] for j in batch_idxs]
        w, b = cal_step_gradient(batch_x, batch_y, w, b, lr)
        step = i + 1
        if log_every > 0 and (step % log_every == 0 or step == max_iter):
            print('w:{0},b:{1}'.format(w, b))
            print('loss is {0}'.format(eval_loss(w, b, x_list, gt_y_list)))
    return w, b

In [39]:
def gen_sample_data(num_samples=100, seed=None):
    """Generate a noisy synthetic dataset from a random line ``y = w * x + b``.

    The true parameters are an integer part plus a fractional part, so ``w`` lies
    in [0, 11) and ``b`` in [0, 6). Each ``x`` is ``randint(0, 100) * random()``,
    so it lies in [0, 100). Each ``y`` gets noise ``random() * randint(-1, 1)``,
    which lies in (-1, 1) and is exactly zero whenever ``randint`` returns 0
    (about a third of the samples).

    Args:
        num_samples: number of (x, y) pairs to generate. The default of 100 matches
            the original hard-coded value.
        seed: optional seed for reproducibility. When given, a private
            ``random.Random(seed)`` is used, so the global ``random`` state is left
            untouched. When ``None``, the module-level ``random`` functions are used
            exactly as before, so a global ``random.seed(...)`` still applies.

    Returns:
        ``(x_list, y_list, w, b)``: the samples plus the true parameters used to
        generate them.
    """
    # The `random` module and a `random.Random` instance expose the same
    # randint/random API. The call order below is unchanged, so unseeded (or
    # globally seeded) output is identical to the original.
    rng = random if seed is None else random.Random(seed)
    w = rng.randint(0, 10) + rng.random()
    b = rng.randint(0, 5) + rng.random()
    x_list = []
    y_list = []
    for _ in range(num_samples):
        x = rng.randint(0, 100) * rng.random()
        y = w * x + b + rng.random() * rng.randint(-1, 1)
        x_list.append(x)
        y_list.append(y)
    return x_list, y_list, w, b


In [40]:
def run(batch_size=50, lr=0.001, max_iter=10000):
    """Generate a synthetic dataset and fit a line to it with mini-batch SGD.

    The defaults reproduce the original experiment: batch size 50, learning rate
    0.001 and 10000 iterations. The true ``w`` and ``b`` used to generate the data
    are printed first, so the logged estimates can be compared against them.
    Previously they were unpacked but never used.
    """
    x_list, y_list, true_w, true_b = gen_sample_data()
    print('true w:{0},b:{1}'.format(true_w, true_b))
    # NOTE(review): in the recorded output, w settles near its value within a few
    # steps while b is still drifting after thousands of steps. This is likely
    # because x spans [0, 100), so the w-gradient (scaled by x) dominates a single
    # shared lr. Normalizing x, or a larger lr / more iterations, would speed up b.
    train(x_list, y_list, batch_size, lr, max_iter)

In [41]:
# Entry point: run the whole experiment when this code executes as __main__.
# Inside a Jupyter kernel, __name__ is also '__main__', so running this cell
# triggers it.
if '__main__' == __name__:
    run()

w:5.5081747631266715,b:0.13226085002266885
loss is 124.37988881796588
w:5.96260980691037,b:0.14439067229287295
loss is 2.277704511749757
w:6.01897551274637,b:0.14643351460419557
loss is 1.1461148016722973
w:6.017431456081949,b:0.14735985558650125
loss is 1.135110154920532
w:6.0067910688766135,b:0.14815470187600618
loss is 1.126328177010027
w:6.012104372929059,b:0.14927326003272376
loss is 1.1156091266706196
w:6.001610287257139,b:0.14975754080732237
loss is 1.1603378349400644
w:6.0181362320220915,b:0.15105174332931934
loss is 1.1366482109096792
w:6.005320055932416,b:0.15171954375102534
loss is 1.1300220680984643
w:6.00337503624583,b:0.1526478265795404
loss is 1.1420116758957801
w:6.0067646440818265,b:0.1537240733730154
loss is 1.1206389776669605
w:6.000017334458275,b:0.15450122211414086
loss is 1.1708680424012001
w:6.014765339814773,b:0.15578464398605404
loss is 1.1155786535660959
w:6.006502056615039,b:0.1565602272314702
loss is 1.1188401072473184
w:6.010404666858153,b:0.157552373297044

w:5.995471749592485,b:0.802175443427
loss is 0.5963767515408304
w:5.991894675274716,b:0.8027389229015767
loss is 0.602713131021575
w:5.997918878487092,b:0.8036837194310117
loss is 0.598023309755171
w:5.994476962250056,b:0.8042490364874728
loss is 0.5956020844342258
w:5.984605101464345,b:0.8046898361692327
loss is 0.6531141881827678
w:5.994010155568357,b:0.8057140865386052
loss is 0.595201202083851
w:5.9927033530251315,b:0.8064886909891719
loss is 0.5974206162600114
w:5.987704980794571,b:0.8070235058936656
loss is 0.6228142688419098
w:5.99260839352305,b:0.80785130869501
loss is 0.5967086870589645
w:5.997967246681922,b:0.8086032751054397
loss is 0.5952187984545965
w:5.998086119730033,b:0.8091066306319844
loss is 0.5952186973697239
w:5.991749715204091,b:0.8096312128015057
loss is 0.5981631542041961
w:5.997680109859625,b:0.8104295146222403
loss is 0.5934736247076273
w:6.002827986863529,b:0.811168723001865
loss is 0.6172830235037464
w:5.992823965085598,b:0.8116757536538665
loss is 0.5934161

loss is 0.3545651885161063
w:5.988027664726804,b:1.2423522950076566
loss is 0.356612238285414
w:5.985150974200746,b:1.2426789474561584
loss is 0.3521837242277646
w:5.986563762168285,b:1.2431035464176912
loss is 0.35311185198833817
w:5.985064532429621,b:1.2435805510911084
loss is 0.3517672129944753
w:5.988250313740008,b:1.244115098898031
loss is 0.3566051008100427
w:5.98402282450828,b:1.2444861940272822
loss is 0.35182449778912495
w:5.984667799540075,b:1.244949686787129
loss is 0.35119383656413306
w:5.990086381653375,b:1.2454904730791219
loss is 0.3635512786401337
w:5.981568453378863,b:1.2457238908882542
loss is 0.35643936115375796
w:5.985649446918734,b:1.246187357887091
loss is 0.35080304302095144
w:5.981026129695492,b:1.2466273871388978
loss is 0.35788506973643375
w:5.98344413699547,b:1.2470715729156183
loss is 0.3512698048201265
w:5.989578820289526,b:1.2476432221063256
loss is 0.36040959773555725
w:5.984664089341709,b:1.2480277351989937
loss is 0.3497708713458673
w:5.9865635923733205

w:5.981467058246779,b:1.561774592979241
loss is 0.23667206688358244
w:5.979488701894418,b:1.5620976433220455
loss is 0.2306498089223556
w:5.976161142086744,b:1.5622608837453267
loss is 0.22916719575363287
w:5.977406030517939,b:1.5625889361482739
loss is 0.228341482258243
w:5.9760984130457135,b:1.5627992468988436
loss is 0.22905688975248606
w:5.977097491150962,b:1.5631208561517447
loss is 0.228205876836534
w:5.973603639352465,b:1.5634926005227852
loss is 0.23487957523060335
w:5.979650421751822,b:1.563933738835607
loss is 0.23051024625372005
w:5.97824994040144,b:1.5642492855912122
loss is 0.22822485428190237
w:5.97519759526731,b:1.5645988054920714
loss is 0.22990322053656295
w:5.977462886469255,b:1.5649403523103258
loss is 0.2276078531842985
w:5.978058812129735,b:1.5652555110270667
loss is 0.227771698959503
w:5.975456889183356,b:1.5655473675867277
loss is 0.22905935814576586
w:5.974958734160848,b:1.5658013045927281
loss is 0.22997928517364818
w:5.979766539961301,b:1.5662811711342572
loss

w:5.969198440060034,b:1.7988898196693315
loss is 0.16881595773047905
w:5.972832384708919,b:1.7992206819188254
loss is 0.1663348891937124
w:5.969876531018822,b:1.7993411149032936
loss is 0.16727502458790863
w:5.966083709725348,b:1.7994162610083138
loss is 0.18092691621474427
w:5.971129172720042,b:1.799670879535485
loss is 0.16575563549927214
w:5.9745088647952205,b:1.7999283646663726
loss is 0.16945562310026557
w:5.9733624873426185,b:1.8001005139634172
loss is 0.1669066298343631
w:5.975514174196377,b:1.800309456459141
loss is 0.17266963449505743
w:5.966425639726107,b:1.8002660974197944
loss is 0.1788362064245923
w:5.971935245953769,b:1.800617072550551
loss is 0.16544366602713828
w:5.972006757186715,b:1.8008195024166642
loss is 0.1654241121052809
w:5.971136207090641,b:1.8010608459890247
loss is 0.1654398386392981
w:5.967696761478086,b:1.80117861098997
loss is 0.17286634335938672
w:5.973673141749691,b:1.8015561023814854
loss is 0.1672244278643213
w:5.97357777785149,b:1.8017920271333974
los

w:5.96256993755241,b:1.9426994318258277
loss is 0.15539140140577276
w:5.967135156460226,b:1.9429399659409663
loss is 0.14027759427256206
w:5.970890679995162,b:1.9431579355657345
loss is 0.14307129863558932
w:5.969374358785729,b:1.9432284613027877
loss is 0.14026874342901738
w:5.967009039541416,b:1.9433386528051704
loss is 0.1403501234565782
w:5.968511756313988,b:1.9434902774066987
loss is 0.13963808331337876
w:5.9656934782472115,b:1.943502428285566
loss is 0.1427226953223729
w:5.968322841978396,b:1.9437249707601472
loss is 0.1395690809790452
w:5.967864366886864,b:1.9438775574544807
loss is 0.13960304717540514
w:5.9655706366218775,b:1.9440270724210698
loss is 0.14292215854721654
w:5.968562073252319,b:1.9441974508342328
loss is 0.13955335473640507
w:5.973426532388598,b:1.9445073054274908
loss is 0.1527015939917658
w:5.962182283381003,b:1.944398932760975
loss is 0.15711121871488243
w:5.967745284739089,b:1.9447968025646565
loss is 0.13950435206079553
w:5.969566589681519,b:1.944945255059046

loss is 0.13039918899302336
w:5.96582380545826,b:2.0348631415513148
loss is 0.12782369387697703
w:5.961914984164196,b:2.034795452893282
loss is 0.1360713674441508
w:5.963249722289774,b:2.034901890910647
loss is 0.13156903179307874
w:5.9646319965738686,b:2.0350735898931713
loss is 0.1287302896541167
w:5.967054177604471,b:2.035333045544059
loss is 0.12826571322595728
w:5.965381861739419,b:2.035336977033533
loss is 0.12795105865799397
w:5.965734753605619,b:2.035443638896398
loss is 0.12777872713962865
w:5.968747321895905,b:2.0355800267780815
loss is 0.13134257358625714
w:5.963248474291662,b:2.0356090998963094
loss is 0.13144958477854768
w:5.966240943608071,b:2.035774499238693
loss is 0.1277276143638946
w:5.970394458576785,b:2.036017415349748
loss is 0.13702788605195673
w:5.967776173706304,b:2.036048220244409
loss is 0.12919333917469739
w:5.963087076631611,b:2.036055651823041
loss is 0.13181786508865193
w:5.968592519402084,b:2.036405310174847
loss is 0.13090378239050315
w:5.965187227696691

loss is 0.12424525886967616
w:5.966096120566234,b:2.112551723465432
loss is 0.12246065041844618
w:5.965956408499134,b:2.1125921132468632
loss is 0.12220734781617958
w:5.964622342677249,b:2.112614516362303
loss is 0.12075447782860593
w:5.96527737152176,b:2.1126826737433744
loss is 0.12124792533366363
w:5.962621345547077,b:2.1126576278469513
loss is 0.12181561705759845
w:5.960406386098844,b:2.1126730906442233
loss is 0.12753011169761633
w:5.965960378869672,b:2.1128849611423792
loss is 0.12220461233263828
w:5.966191150288479,b:2.112971360687314
loss is 0.12263133434629453
w:5.961549222583981,b:2.112897709614317
loss is 0.12395465024020055
w:5.964057579797353,b:2.113026276281761
loss is 0.12062922734260757
w:5.961311775605642,b:2.112999066325443
loss is 0.12457112721513618
w:5.964587648697551,b:2.113232808521512
loss is 0.12069975737775
w:5.967382209528795,b:2.1133365572192906
loss is 0.12567178907588455
w:5.9648123436654155,b:2.1133899052702994
loss is 0.12080954005422127
w:5.966683497521

loss is 0.12144066542906477
w:5.963909483489158,b:2.1664859396244154
loss is 0.11778101046636708
w:5.959846605044968,b:2.1663857236763087
loss is 0.12173775876832052
w:5.9633349696392965,b:2.1664450539871645
loss is 0.11736623244280141
w:5.963402273710935,b:2.1665217748708443
loss is 0.11739537720432543
w:5.962416894053886,b:2.1666208387259784
loss is 0.11735689906735514
w:5.961373413072406,b:2.166575871241477
loss is 0.11835321448918093
w:5.962080993105749,b:2.166646063460496
loss is 0.11755908068382456
w:5.96372090792254,b:2.1667378616929174
loss is 0.11760106030269102
w:5.960498700892611,b:2.1667397384570504
loss is 0.11998423966207859
w:5.965143959125002,b:2.1669002866811753
loss is 0.11976218456500388
w:5.960622263016744,b:2.166825033071186
loss is 0.11969790900535388
w:5.9643411932484485,b:2.1669366362974634
loss is 0.11829755749252163
w:5.962905207723631,b:2.167029728809198
loss is 0.11723558068075698
w:5.9624158027083185,b:2.16707609652188
loss is 0.11732937766474333
w:5.960386

w:5.962265525649043,b:2.1943623889355766
loss is 0.11600982191672699
w:5.962826279863145,b:2.1945109406371293
loss is 0.11619067270895889
w:5.958935967571316,b:2.1944413759386188
loss is 0.12120195615214534
w:5.960680403884813,b:2.194515557079964
loss is 0.11713045013580076
w:5.963178006477901,b:2.194582423820933
loss is 0.11646246405842745
w:5.96327713847359,b:2.1946000332622093
loss is 0.11656104061922803
w:5.961702866206423,b:2.194682669972871
loss is 0.11611582605387838
w:5.961142746336739,b:2.194667329298161
loss is 0.11653994046859192
w:5.962229845957301,b:2.194717278464678
loss is 0.11599488453765426
w:5.96490073356753,b:2.1948594303818654
loss is 0.11954610987559952
w:5.965834102653119,b:2.1949710751950735
loss is 0.12242948593582803
w:5.962129533674263,b:2.194922289818702
loss is 0.11598852248654615
w:5.961416879217614,b:2.194884561387102
loss is 0.1162822782649344
w:5.962436547434198,b:2.1948968571805696
loss is 0.11601590936908615
w:5.960437587609072,b:2.1947840709510125
los

loss is 0.12003466272954023
w:5.962610957708995,b:2.2207358356630547
loss is 0.11565839577506415
w:5.959229458327353,b:2.220704138247405
loss is 0.11781194332161059
w:5.960535880658515,b:2.220840729218213
loss is 0.11565514426933171
w:5.958564580622741,b:2.220808394591446
loss is 0.11953298833250851
w:5.961437495065596,b:2.2208921823633356
loss is 0.11514055361451347
w:5.96335695373168,b:2.221002493215028
loss is 0.1166844960093918
w:5.961469755361413,b:2.220975838484262
loss is 0.11513441212886563
w:5.962976821096247,b:2.220990130478798
loss is 0.11609355585680613
w:5.961874760867358,b:2.2209914754791638
loss is 0.11517486686341531
w:5.960357329014445,b:2.2208841206662244
loss is 0.11584808540011578
w:5.9646147383940695,b:2.2210504734313457
loss is 0.11964312671841126
w:5.957491885666804,b:2.2209170741983866
loss is 0.12321620297937566
w:5.963560261475402,b:2.221058404185903
loss is 0.11705925684038972
w:5.9607104466674246,b:2.220985159918778
loss is 0.1154872941026175
w:5.96012110768

w:5.958739418894126,b:2.239083440720825
loss is 0.11749751108599851
w:5.961917243218949,b:2.239223186111827
loss is 0.11500802716327434
w:5.959763396475482,b:2.239151869666738
loss is 0.11561977139111292
w:5.9588841541132025,b:2.2391287442504586
loss is 0.11716748694309537
w:5.963112650570441,b:2.2392488437098015
loss is 0.11661745297589868
w:5.957694221897994,b:2.239077307635251
loss is 0.12046249829629842
w:5.961165313128909,b:2.2391074617155935
loss is 0.11471002400437572
w:5.960376925187331,b:2.2391198289852077
loss is 0.1149873263255897
w:5.964661844740146,b:2.2392147263931195
loss is 0.1207679456572167
w:5.962806270408744,b:2.239232773587974
loss is 0.11607222900877973
w:5.9632190947286805,b:2.23928433520953
loss is 0.11682917888817057
w:5.960538415129308,b:2.2392581855171767
loss is 0.1148767018176397
w:5.958999439923292,b:2.2391594214160535
loss is 0.11691977821270465
w:5.9603648099937985,b:2.2391657299195766
loss is 0.11499460323390527
w:5.9597170832451605,b:2.239185723623982


w:5.9605144749795995,b:2.2509745953087545
loss is 0.11456714122790436
w:5.961812193282242,b:2.251057093481384
loss is 0.11496487297637004
w:5.958671870053149,b:2.25102363993544
loss is 0.1168130275004172
w:5.959760231261682,b:2.251052206619422
loss is 0.11508567089118255
w:5.961365667801997,b:2.2511642647691508
loss is 0.11464245535438429
w:5.96580914963386,b:2.251336625977985
loss is 0.12651117031749629
w:5.965470688163719,b:2.2513966737501354
loss is 0.12493737610119143
w:5.964010171115864,b:2.2514901857979703
loss is 0.11940256592348289
w:5.961730296413659,b:2.251460041787146
loss is 0.1148940356199469
w:5.960619368322521,b:2.2514492000448287
loss is 0.1145293573528559
w:5.957887153375187,b:2.251366743462947
loss is 0.11874372333853986
w:5.958826058447311,b:2.251426954363567
loss is 0.11647388663784745
w:5.961426172134361,b:2.2514578517193726
loss is 0.11467487143641397
w:5.9613381349247225,b:2.251482369862052
loss is 0.11462816534801634
w:5.960222368615736,b:2.251488152395632
loss 

In [None]:
#第二章第1节：经典机器学习I：线性回归与逻辑回归