In [126]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.autograd import Variable

In [127]:
# N is batch size; D_in is input dimension; H is hidden dimension
N, D_in, D_noise, H = 100, 10, 100, 1000
x1 = Variable(torch.randn(N, D_in), requires_grad=False)
x2 = Variable(torch.randn(N, D_in), requires_grad=False)
x3 = Variable(torch.randn(N, D_noise), requires_grad=False)

w1 = Variable(torch.randn(D_in, 1), requires_grad=False)
w2 = Variable(torch.randn(D_in, 1), requires_grad=False)

print(torch.min(w1), torch.max(w1), torch.min(w2), torch.max(w2))

y = torch.mm(x1, w1) + torch.mm(x2, w2)

model1 = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, 1),
)

model2 = torch.nn.Sequential(
    torch.nn.Linear(D_in * 2, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, 1),
)

model3 = torch.nn.Sequential(
    torch.nn.Linear(D_in + D_noise, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, 1),
)

loss_fn = torch.nn.MSELoss()

N_test = 10000
x1_test = Variable(torch.randn(N_test, D_in))
x2_test = Variable(torch.randn(N_test, D_in))
x3_test = Variable(torch.randn(N_test, D_noise))
y_test = torch.mm(x1_test, w1) + torch.mm(x2_test, w2)

Variable containing:
-2.1349
[torch.FloatTensor of size 1]
 Variable containing:
 1.7703
[torch.FloatTensor of size 1]
 Variable containing:
-2.8305
[torch.FloatTensor of size 1]
 Variable containing:
 0.8744
[torch.FloatTensor of size 1]



In [128]:
print(x1.size(), x2.size(), x3.size(), y.size(), y_test.size())

torch.Size([100, 10]) torch.Size([100, 10]) torch.Size([100, 100]) torch.Size([100, 1]) torch.Size([10000, 1])


In [129]:
learning_rate = 1e-4
batch_size = 32
optimizer = torch.optim.Adam(model1.parameters(), lr=learning_rate, weight_decay=0.0001)
for t in range(500):
    index = torch.randperm(N)
    y_pred = model1(x1[index[0:batch_size],:])

    loss = loss_fn(y_pred, y[index[0:batch_size]])
    print(t, loss.data[0])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
y_test_pred = model1(x1_test)
test_loss = loss_fn(y_test_pred, y_test)
print(test_loss.data[0])

0 22.699222564697266
1 23.458572387695312
2 18.416669845581055
3 27.2829532623291
4 18.44976043701172
5 23.350513458251953
6 25.98028564453125
7 21.13286018371582
8 14.495930671691895
9 19.305130004882812
10 20.379884719848633
11 20.00303077697754
12 15.482952117919922
13 27.90176010131836
14 20.249065399169922
15 30.88016128540039
16 21.072797775268555
17 15.596909523010254
18 20.09552764892578
19 14.043254852294922
20 14.890569686889648
21 22.658079147338867
22 21.907791137695312
23 15.235745429992676
24 24.32879638671875
25 19.113901138305664
26 24.78824234008789
27 23.698211669921875
28 24.54877471923828
29 17.03287124633789
30 10.559033393859863
31 21.805797576904297
32 26.422170639038086
33 15.843240737915039
34 22.465112686157227
35 15.53139877319336
36 20.273080825805664
37 23.928346633911133
38 18.807933807373047
39 15.674638748168945
40 20.62818145751953
41 18.33435821533203
42 17.210800170898438
43 28.719661712646484
44 27.6954345703125
45 15.024020195007324
46 23.9028015136

403 9.566397666931152
404 4.914907932281494
405 9.504349708557129
406 5.361973285675049
407 7.244177341461182
408 9.45284366607666
409 5.635361671447754
410 9.2537202835083
411 6.816647529602051
412 9.570273399353027
413 10.947955131530762
414 7.072892189025879
415 5.099292278289795
416 8.228792190551758
417 10.126411437988281
418 7.135617733001709
419 4.7916364669799805
420 6.470151424407959
421 5.276288986206055
422 8.36807918548584
423 7.58637809753418
424 6.685788154602051
425 5.547449588775635
426 10.65848159790039
427 9.29642105102539
428 6.101983070373535
429 6.504255294799805
430 4.321829795837402
431 6.971958160400391
432 5.4182281494140625
433 6.035802841186523
434 5.5955071449279785
435 7.688999652862549
436 5.1820759773254395
437 3.7370376586914062
438 8.3994779586792
439 5.69775915145874
440 9.378360748291016
441 8.963579177856445
442 6.51920223236084
443 9.404097557067871
444 7.101121425628662
445 6.5730156898498535
446 9.681452751159668
447 7.038683891296387
448 5.383360

In [130]:
optimizer = torch.optim.Adam(model2.parameters(), lr=learning_rate, weight_decay=0.0001)
for t in range(500):
    index = torch.randperm(N)
    y_pred = model2(torch.cat((x1[index[0:batch_size],:], x2[index[0:batch_size],:]), 1))

    loss = loss_fn(y_pred, y[index[0:batch_size],:])
    print(t, loss.data[0])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
y_test_pred = model2(torch.cat((x1_test, x2_test), 1))
test_loss = loss_fn(y_test_pred, y_test)
print(test_loss.data[0])

0 16.925355911254883
1 20.96596336364746
2 22.69664764404297
3 15.431150436401367
4 26.36935043334961
5 21.43410873413086
6 20.110994338989258
7 23.082677841186523
8 26.02195930480957
9 27.166515350341797
10 23.952909469604492
11 17.486570358276367
12 14.738186836242676
13 25.207508087158203
14 17.88997459411621
15 28.424161911010742
16 24.277286529541016
17 16.115678787231445
18 24.683963775634766
19 18.352615356445312
20 15.384073257446289
21 25.128713607788086
22 23.513647079467773
23 30.098785400390625
24 15.49545669555664
25 17.680383682250977
26 20.99519157409668
27 22.895931243896484
28 16.312349319458008
29 19.502695083618164
30 16.958480834960938
31 15.671489715576172
32 17.41341209411621
33 22.835983276367188
34 17.519819259643555
35 19.67540168762207
36 13.733476638793945
37 20.012474060058594
38 15.164956092834473
39 17.451324462890625
40 20.75246238708496
41 23.9717960357666
42 21.648635864257812
43 26.736297607421875
44 19.488252639770508
45 15.22969913482666
46 18.959274

380 3.6037545204162598
381 3.328852653503418
382 3.7398786544799805
383 2.823195219039917
384 5.004744529724121
385 3.0245935916900635
386 2.772092819213867
387 2.9307737350463867
388 2.70703387260437
389 2.150507688522339
390 3.0735747814178467
391 2.326432704925537
392 4.388235569000244
393 3.8955202102661133
394 4.206214904785156
395 3.5463199615478516
396 2.372321844100952
397 2.42350435256958
398 3.5455076694488525
399 3.264883279800415
400 4.052709102630615
401 4.357390403747559
402 4.356958389282227
403 3.942857265472412
404 3.6445443630218506
405 3.9862172603607178
406 4.030275344848633
407 3.4663338661193848
408 2.221125841140747
409 3.8166286945343018
410 3.9755678176879883
411 2.9836955070495605
412 3.211559772491455
413 2.972045660018921
414 2.9078726768493652
415 4.3148908615112305
416 2.698582649230957
417 2.5513339042663574
418 2.6423728466033936
419 2.8489389419555664
420 2.836050510406494
421 3.128516435623169
422 1.6090927124023438
423 3.282806158065796
424 2.65580391

In [131]:
optimizer = torch.optim.Adam(model3.parameters(), lr=learning_rate, weight_decay=0.0001)
for t in range(500):
    index = torch.randperm(N)
    y_pred = model3(torch.cat((x1[index[0:batch_size],:], x3[index[0:batch_size],:]), 1))

    loss = loss_fn(y_pred, y[index[0:batch_size],:])
    print(t, loss.data[0])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
y_test_pred = model3(torch.cat((x1_test, x3_test), 1))
test_loss = loss_fn(y_test_pred, y_test)
print(test_loss.data[0])

0 16.96501922607422
1 25.16316032409668
2 33.521080017089844
3 20.910531997680664
4 18.662824630737305
5 21.706912994384766
6 22.953046798706055
7 23.470245361328125
8 25.775096893310547
9 26.61598777770996
10 23.17015266418457
11 13.787217140197754
12 20.268667221069336
13 22.85104751586914
14 16.97904396057129
15 22.614538192749023
16 15.852815628051758
17 26.186115264892578
18 19.958303451538086
19 20.364883422851562
20 23.34739112854004
21 14.897451400756836
22 22.7756404876709
23 26.582046508789062
24 16.635631561279297
25 18.212568283081055
26 22.446495056152344
27 19.758268356323242
28 16.73517608642578
29 17.770885467529297
30 20.240001678466797
31 16.089494705200195
32 21.136995315551758
33 22.263626098632812
34 14.656776428222656
35 19.299232482910156
36 20.826547622680664
37 16.000463485717773
38 8.254864692687988
39 24.014625549316406
40 13.839470863342285
41 21.430540084838867
42 13.497140884399414
43 18.55499839782715
44 19.713111877441406
45 21.267675399780273
46 16.8820

387 0.8794196248054504
388 0.8477887511253357
389 0.5670363306999207
390 0.6109449863433838
391 1.2150986194610596
392 0.5471371412277222
393 0.6763240694999695
394 0.7152014374732971
395 0.6182607412338257
396 0.6095579862594604
397 0.6677837371826172
398 0.7875270843505859
399 0.6656719446182251
400 0.6082417368888855
401 0.8058264851570129
402 0.6543492078781128
403 0.4894472062587738
404 0.7383871078491211
405 0.512421190738678
406 0.5349602103233337
407 0.621405839920044
408 0.6951754689216614
409 0.5437887907028198
410 0.5949397087097168
411 0.5526716113090515
412 0.4873393177986145
413 0.5036922693252563
414 0.4883797764778137
415 0.41262581944465637
416 0.44504493474960327
417 0.4990594685077667
418 0.6606658101081848
419 0.4770664870738983
420 0.45853137969970703
421 0.5789499878883362
422 0.3537692725658417
423 0.5450637936592102
424 0.3760782778263092
425 0.4731403887271881
426 0.42648616433143616
427 0.4489218592643738
428 0.4978477358818054
429 0.42477571964263916
430 0.25