Linear Regression — fitting y = w·x + b to synthetic data with mini-batch gradient descent

In [24]:
import numpy as np
import random

In [25]:
def inference(w,b,x):
    """Evaluate the linear model ``w * x + b``.

    ``x`` may be a scalar or a numpy array; with an array the result is computed
    elementwise via broadcasting against the scalar ``w`` and ``b``.
    """
    return b + w * x

In [26]:
def eval_loss(w,b,x,y):
    """Mean half-squared-error of the linear model on (x, y).

    Returns ``mean(0.5 * (w*x + b - y)**2)``; the 1/2 factor makes the gradient
    with respect to the prediction equal to the plain residual.
    """
    residual = inference(w,b,x) - y
    return (0.5 * residual**2).mean()

In [27]:
def gradient(pred_y,gt_y,x):
    """Per-sample gradients of ``0.5 * (pred_y - gt_y)**2`` w.r.t. w and b.

    Returns a tuple ``(d_loss/dw, d_loss/db)`` computed elementwise; averaging
    over the batch is left to the caller.
    """
    residual = pred_y - gt_y
    return residual * x, residual

In [28]:
def cal_step_gradient(batch_x_list, batch_gt_y_list, w,b,lr):
    """Perform one mini-batch gradient-descent update of (w, b).

    Parameters
    ----------
    batch_x_list, batch_gt_y_list : array-like
        Inputs and targets of the batch. They are converted with ``np.asarray``
        so plain Python lists work as well as numpy arrays.
    w, b : float
        Current slope and intercept.
    lr : float
        Learning rate.

    Returns
    -------
    (w, b) : tuple
        Parameters after one step along the batch-averaged gradient.

    Raises
    ------
    ValueError
        If the batch is empty (the averaged gradient would be NaN and silently
        poison every later step).
    """
    batch_x = np.asarray(batch_x_list)
    batch_y = np.asarray(batch_gt_y_list)
    if batch_x.size == 0:
        raise ValueError("cal_step_gradient() received an empty batch")

    pred_y = inference(w, b, batch_x)
    dw, db = gradient(pred_y, batch_y, batch_x)
    # Average the per-sample gradients so the step size does not scale with
    # batch size. Rebind instead of ``-=`` so an ndarray passed by the caller
    # is not mutated in place.
    w = w - lr * dw.mean()
    b = b - lr * db.mean()
    return w, b

In [29]:
def train(x_list, gt_y_list, batch_size,lr,max_iter, print_every=1):
    """Fit ``y ≈ w*x + b`` by mini-batch gradient descent starting from w = b = 0.

    Parameters
    ----------
    x_list, gt_y_list : array-like
        Full dataset of inputs and targets (converted with ``np.asarray``).
    batch_size : int
        Number of samples drawn (with replacement) for each step.
    lr : float
        Learning rate.
    max_iter : int
        Number of gradient steps.
    print_every : int, default 1
        Print the parameters and the loss on the current mini-batch every
        ``print_every`` iterations. The default matches the original behaviour
        of printing every step. Pass 0 or None to silence progress output.

    Returns
    -------
    (w, b) : tuple
        The learned slope and intercept.
    """
    x_all = np.asarray(x_list)
    y_all = np.asarray(gt_y_list)
    num_samples = len(x_all)

    w, b = 0, 0
    for i in range(max_iter):
        # Indices are drawn with replacement, so a batch may repeat samples.
        batch_idx = np.random.choice(num_samples, size=batch_size, replace=True)
        x = x_all[batch_idx]
        y = y_all[batch_idx]

        w, b = cal_step_gradient(x, y, w, b, lr)
        if print_every and i % print_every == 0:
            print('w:{0}, b:{1}'.format(w,b))
            # Loss is measured on the mini-batch just used, so it is noisy.
            print('loss is {0}'.format(eval_loss(w,b,x,y)))
    return w, b

In [30]:
def gen_sample_data(num_samples=100):
    """Generate a noisy synthetic dataset ``y = w*x + b + noise``.

    The true slope ``w`` lies in [0, 11) and the intercept ``b`` in [0, 6). Both
    are drawn with the ``random`` module and printed, so the fitted values can
    be compared against them.

    Parameters
    ----------
    num_samples : int, default 100
        Number of (x, y) pairs to generate.

    Returns
    -------
    x, y : np.ndarray, shape (num_samples,)
        Inputs and noisy targets.
    w, b : float
        The true parameters used to generate ``y``.
    """
    w = random.randint(0,10) + random.random()
    b = random.randint(0,5) + random.random()

    print(w,b)

    x = np.random.randint(0,100,num_samples)*np.random.random(num_samples)
    # Zero-mean noise in [-1, 1). The previous expression,
    # random() * randint(-1, 1), could only multiply by -1 or 0 because
    # randint's upper bound is exclusive. The noise was then never positive
    # (mean -0.25), which biased the fitted intercept downward.
    noise = np.random.uniform(-1.0, 1.0, num_samples)
    y = w*x + b + noise
    return x, y, w, b

In [31]:
def run():
    """Build a synthetic dataset and fit it with mini-batch gradient descent."""
    x_list, y_list, _, _ = gen_sample_data()
    batch_size = 50
    train(x_list, y_list, batch_size, lr=0.001, max_iter=10000)
    

In [33]:
if __name__ == '__main__':
    # Seed both RNGs the notebook uses: ``random`` draws the true w/b, and
    # ``np.random`` draws the data, noise and mini-batches. This lets
    # Restart & Run All reproduce the same printed trajectory.
    SEED = 0
    random.seed(SEED)
    np.random.seed(SEED)
    run()

1.6079110624594475 1.6527916027547864
w:1.9640838855933644, b:0.04325827071510774
loss is 64.08425872047131
w:1.5174598798468804, b:0.0347755605553971
loss is 9.642445158808854
w:1.6623931256363298, b:0.038498041875128584
loss is 0.7665741041621661
w:1.618596013626598, b:0.03817510687043517
loss is 0.6007561262973675
w:1.6413171503561466, b:0.03925445934813741
loss is 0.47836933330253806
w:1.63148013456455, b:0.0395358106526109
loss is 0.3778970395807451
w:1.6383158773393098, b:0.04027108308392897
loss is 0.45122480621836464
w:1.636773115899162, b:0.04085460941143519
loss is 0.46933179934776403
w:1.632144740346523, b:0.041283698583030354
loss is 0.4157713437686549
w:1.637775512811395, b:0.04190409295491535
loss is 0.3918159388513882
w:1.6333961783059114, b:0.042359025765879826
loss is 0.4267447047078398
w:1.6320068542453967, b:0.04286798939852729
loss is 0.39213522272879886
w:1.6394381115421095, b:0.04359214698643095
loss is 0.41701221276615463
w:1.6340772771620027, b:0.043953633990769

w:1.628773486505064, b:0.5062632833355004
loss is 0.21497013148330185
w:1.627286584685839, b:0.5066686502223686
loss is 0.24939313598459772
w:1.6284234956160824, b:0.50704774232624
loss is 0.20899738670747856
w:1.6251081362156845, b:0.5073342823011195
loss is 0.2090841873749431
w:1.6273292596805462, b:0.5077630690549808
loss is 0.2219065874747784
w:1.6285074425187729, b:0.5081440933636243
loss is 0.21712971419910265
w:1.6231652581585971, b:0.5084403012056143
loss is 0.2192566427352389
w:1.6272366981660753, b:0.5088851833336513
loss is 0.22144322592368584
w:1.6239779706872512, b:0.5091288770994977
loss is 0.2006697488438177
w:1.6272167050860193, b:0.509484094755909
loss is 0.17447978265861944
w:1.625509073022503, b:0.5097286471621728
loss is 0.18824865718252431
w:1.6219271542161586, b:0.5099819733415799
loss is 0.19568249326546666
w:1.6234651459924347, b:0.5103444430105072
loss is 0.2162083359203344
w:1.6287441170725785, b:0.5107811103247688
loss is 0.18786827027907457
w:1.6270471574872

w:1.6188978904775602, b:0.768770084908101
loss is 0.1192037945679934
w:1.6176515683296926, b:0.7690616089316106
loss is 0.1770072896116818
w:1.6218548506219457, b:0.769340276179042
loss is 0.113388402490628
w:1.6212239170839327, b:0.769599239039625
loss is 0.15225213362779597
w:1.6205499163297503, b:0.7698701653901214
loss is 0.15116674208574601
w:1.6225503851907184, b:0.7701282799311682
loss is 0.11906090744221423
w:1.616086284885638, b:0.7702334413654268
loss is 0.13742669807657165
w:1.6236942731526103, b:0.7705554624307946
loss is 0.11355180391801573
w:1.6214828775613406, b:0.7707336959658413
loss is 0.13163503958037104
w:1.6247383423560475, b:0.7710035411886127
loss is 0.09161359997011535
w:1.615446221360926, b:0.770954376659744
loss is 0.11969529493587004
w:1.6197690451167848, b:0.7712821102983676
loss is 0.13686312417225258
w:1.6236887894261693, b:0.771565393667794
loss is 0.1136937806271672
w:1.618187279071305, b:0.7716946463680093
loss is 0.13066050932631462
w:1.624391606951138

w:1.6209136109342277, b:0.9607260385048643
loss is 0.09730908489361405
w:1.6188767251508005, b:0.9608451041005726
loss is 0.0926142011397082
w:1.6184534990630186, b:0.961001313533878
loss is 0.08024623055944805
w:1.6153659533231277, b:0.9610591286266906
loss is 0.07503249718986309
w:1.6177109891636507, b:0.9612321929673102
loss is 0.07945637674627166
w:1.6183127845839718, b:0.9614364202230149
loss is 0.09583247045112284
w:1.614617145851704, b:0.9614936748185668
loss is 0.07582849548696165
w:1.615831328913038, b:0.9617286730975557
loss is 0.0999258943160262
w:1.6173024727826473, b:0.9619134779550602
loss is 0.0862718659893363
w:1.6139301578313818, b:0.9620065940321906
loss is 0.0960563515174157
w:1.6134785191578822, b:0.9621721035538201
loss is 0.10009548045090946
w:1.6141699349097394, b:0.9624076453327272
loss is 0.11341290653163962
w:1.6175900775189012, b:0.9626239979859447
loss is 0.0836943339762957
w:1.6152210387920891, b:0.9627535005114004
loss is 0.10593131436074843
w:1.6168903249

loss is 0.05988994987983607
w:1.615119463722538, b:1.0972382599475057
loss is 0.0655107595537216
w:1.6152014724885557, b:1.0973469433335292
loss is 0.0678234602967529
w:1.6148189874125993, b:1.097489647072917
loss is 0.08080608219954195
w:1.6142167737913244, b:1.0976237283384227
loss is 0.07650901088528732
w:1.6113934714190354, b:1.0976892780207463
loss is 0.08428155346323743
w:1.615124494681869, b:1.0978761630511946
loss is 0.06421127396057119
w:1.6113547519428875, b:1.0979105252783845
loss is 0.07104023901614921
w:1.6157829210909131, b:1.09819986653885
loss is 0.08959124702449361
w:1.6140007022290015, b:1.0982631725657361
loss is 0.060489516824283045
w:1.6109917484668574, b:1.0983824255932757
loss is 0.08723463383330449
w:1.6148386787916538, b:1.0986199504182759
loss is 0.08027473847739873
w:1.6144917351096033, b:1.0987283746956698
loss is 0.06941352588684618
w:1.6130872068583153, b:1.0987541791960767
loss is 0.06151012968578933
w:1.6127514492310697, b:1.0988312546105055
loss is 0.07

w:1.612802659954815, b:1.190602965518658
loss is 0.055228112498444865
w:1.6108835921736613, b:1.1906210607703287
loss is 0.06384786317933697
w:1.612846748296089, b:1.1907665713051268
loss is 0.07319358579903815
w:1.612979121838492, b:1.1908270370229774
loss is 0.0541820469435215
w:1.6140365188965444, b:1.1909152850855194
loss is 0.05838708822729221
w:1.6046103303761552, b:1.19083360990831
loss is 0.09609932352732642
w:1.6156647094992436, b:1.1911771954936998
loss is 0.0577142692517182
w:1.606680080065324, b:1.1911000989986396
loss is 0.08633002568998853
w:1.610270524137323, b:1.1913096048554248
loss is 0.0771472082117164
w:1.6116557324055056, b:1.1914486682966017
loss is 0.07193996002222584
w:1.6117803980239942, b:1.191510925806184
loss is 0.06938145309137236
w:1.6113574054450002, b:1.191601914562019
loss is 0.06916443407631218
w:1.611416007768045, b:1.1916417629340315
loss is 0.06951356414486975
w:1.61044013578157, b:1.1917127524963351
loss is 0.06116947567256929
w:1.6130231896112153,

w:1.610806350493602, b:1.2519698409492148
loss is 0.07247973937956315
w:1.6089171655162693, b:1.2519603901837946
loss is 0.06554836705673361
w:1.6097878803529866, b:1.2520541263524865
loss is 0.0699937304221857
w:1.6094342312748011, b:1.252007158618372
loss is 0.061152566065297904
w:1.6119605343964944, b:1.2521286199345916
loss is 0.06083378178958989
w:1.610867601651201, b:1.2521614973801942
loss is 0.04609176519134763
w:1.6113867282412104, b:1.2522679589652306
loss is 0.0601692269896911
w:1.6096136392651095, b:1.2522819278148005
loss is 0.06289668235004771
w:1.608004786764068, b:1.2523023322546893
loss is 0.06682370008712181
w:1.6109386249324398, b:1.2524156023249589
loss is 0.048150938986449836
w:1.6099952797485244, b:1.2524628892445608
loss is 0.05980993629860287
w:1.6080309036277474, b:1.252462562827316
loss is 0.07493487403464959
w:1.6130800071252296, b:1.2526237215999336
loss is 0.05056072572340429
w:1.6113724115401602, b:1.2526431353021092
loss is 0.05420195540498506
w:1.6115544

w:1.608276788246078, b:1.289654817306792
loss is 0.06239184787174938
w:1.609829621909749, b:1.2897129758560462
loss is 0.0595377426636434
w:1.6132398757095037, b:1.2897753814535649
loss is 0.04247598522527383
w:1.6059974470907328, b:1.2897003152106687
loss is 0.07178902087170928
w:1.6125668103787445, b:1.2898672353602347
loss is 0.05994064497032939
w:1.610003505942891, b:1.289791421054517
loss is 0.06523146712670813
w:1.6096060509747436, b:1.2897851872515234
loss is 0.06514200913570117
w:1.6102397946278753, b:1.2898370275709359
loss is 0.0496970482751614
w:1.6063242688033497, b:1.2897680859434102
loss is 0.07084967680281672
w:1.6113065573576537, b:1.289967092124959
loss is 0.05336015055384311
w:1.6103054174805134, b:1.2899001321461974
loss is 0.05520955978729096
w:1.6094436753035157, b:1.2899502897461297
loss is 0.05975168025991505
w:1.6125843930109423, b:1.290075798012542
loss is 0.05366719025137831
w:1.6069339856205764, b:1.290019767102918
loss is 0.06390027046275748
w:1.608484007579

w:1.6061389280580445, b:1.3150335541548994
loss is 0.06790381184758519
w:1.6083330938694844, b:1.3151330921232467
loss is 0.06448039916034824
w:1.6097409682040362, b:1.3151523087517016
loss is 0.05660462082445142
w:1.6107444782305471, b:1.3152048415705646
loss is 0.04955036852890701
w:1.6112881315348926, b:1.3152496717284603
loss is 0.04194380791296491
w:1.6027280615090551, b:1.3151307359774806
loss is 0.08077461992507087
w:1.6108783363687593, b:1.3152471490607416
loss is 0.05678951592404264
w:1.6065619653415115, b:1.3152088077927442
loss is 0.0694479999674528
w:1.6081364674100493, b:1.3152078821830149
loss is 0.06568281364181176
w:1.6056595463468153, b:1.315177058087701
loss is 0.06720968769812079
w:1.6097719509146338, b:1.315266179855167
loss is 0.06399490709230254
w:1.6077436728385763, b:1.3152484993123144
loss is 0.0626012370089492
w:1.60824980197611, b:1.3153186396014793
loss is 0.06484300312160209
w:1.6110074915201842, b:1.3153877516223906
loss is 0.056813298721556243
w:1.6052505

loss is 0.052535944258358454
w:1.610538330743253, b:1.3367881133950206
loss is 0.04713429397962862
w:1.6029753874420207, b:1.33668964661573
loss is 0.0724570917297014
w:1.6117277182523686, b:1.3368921837861536
loss is 0.05707760322986041
w:1.6072963667447997, b:1.336782537113145
loss is 0.06902142396073846
w:1.608671980002592, b:1.336823850881129
loss is 0.06096750868475791
w:1.6077316380124598, b:1.3368082767316545
loss is 0.06196165409714818
w:1.6077240387206553, b:1.3368307473460501
loss is 0.0652277437147463
w:1.609476837245769, b:1.336888398015105
loss is 0.04789398470016541
w:1.6084836952757007, b:1.336859024655036
loss is 0.05679276671254108
w:1.6084876062155882, b:1.3368876626184878
loss is 0.05771745144481383
w:1.6065141576153479, b:1.3368036055455463
loss is 0.05880672129544153
w:1.60964468466106, b:1.3369026891911582
loss is 0.05984468856802211
w:1.611420741980959, b:1.336939228928905
loss is 0.0427654002097214
w:1.6084373251631523, b:1.336870439787735
loss is 0.059632875953

w:1.6058483111787798, b:1.348443767561976
loss is 0.06445376543850537
w:1.6096322597557748, b:1.3485399260733366
loss is 0.05914281811088211
w:1.606087957334652, b:1.3484646611636908
loss is 0.06733557259252979
w:1.6092786644113506, b:1.348593682302427
loss is 0.04794950645623647
w:1.611662538503388, b:1.348690601528966
loss is 0.03461524577738896
w:1.6107398120296157, b:1.3486826832678132
loss is 0.04650657796284666
w:1.6077996129630032, b:1.3486743097438978
loss is 0.05256095273317384
w:1.6100130367953858, b:1.3487354517515016
loss is 0.05332928307354668
w:1.6098441340058662, b:1.3487394482481143
loss is 0.05017437070485378
w:1.6082653984297723, b:1.3487490650567822
loss is 0.05941989867239203
w:1.607673916257236, b:1.3487998947833053
loss is 0.051129713098407796
w:1.6069266613778126, b:1.3488223173538934
loss is 0.06647816100348294
w:1.6054700135882132, b:1.348831847086544
loss is 0.06621110242624426
w:1.6097357753803565, b:1.3489160624436163
loss is 0.06206466561609437
w:1.60955357