In [10]:
# Code in file tensor/two_layer_net_tensor.py
import torch

# Fit a fully connected network (affine -> ReLU -> affine) to random data.
# Gradients are derived by hand and applied with plain tensor ops (no autograd).
# Training prints the iteration index and the sum-of-squares loss at each step.
device = torch.device('cpu')
# device = torch.device('cuda')  # Uncomment this to run on GPU

# Batch size, input width, hidden width, output width.
N, D_in, H, D_out = 64, 1000, 100, 10

# Random inputs and regression targets (drawn in this order: x, then y).
x = torch.randn(N, D_in, device=device)
y = torch.randn(N, D_out, device=device)

# Random initial weights for the first and second layer (drawn after x, y).
w1 = torch.randn(D_in, H, device=device)
w2 = torch.randn(H, D_out, device=device)

learning_rate = 1e-6
for t in range(500):
    # ---- forward pass ----
    h = torch.mm(x, w1)               # hidden pre-activation, shape (N, H)
    h_relu = torch.clamp(h, min=0)    # ReLU
    y_pred = torch.mm(h_relu, w2)     # prediction, shape (N, D_out)

    # Sum of squared errors. `loss` is a 0-dim tensor; .item() unwraps it
    # into a Python float for printing.
    residual = y_pred - y
    loss = residual.pow(2).sum()
    print(t, loss.item())

    # ---- backward pass (chain rule, written out by hand) ----
    grad_y_pred = 2.0 * residual                    # d(loss)/d(y_pred)
    grad_w2 = torch.mm(h_relu.T, grad_y_pred)
    grad_h_relu = torch.mm(grad_y_pred, w2.T)
    # ReLU backward: zero the gradient wherever the pre-activation was negative.
    grad_h = grad_h_relu.masked_fill(h < 0, 0)
    grad_w1 = torch.mm(x.T, grad_h)

    # ---- plain gradient-descent step, applied in place ----
    w1.sub_(learning_rate * grad_w1)
    w2.sub_(learning_rate * grad_w2)

0 25951628.0
1 23591032.0
2 24758762.0
3 26092808.0
4 25027210.0
5 20455678.0
6 14128923.0
7 8488802.0
8 4785744.5
9 2735520.25
10 1685503.625
11 1145751.625
12 852346.3125
13 677248.8125
14 561417.625
15 477549.5625
16 412761.0
17 360273.5
18 316527.375
19 279478.6875
20 247770.34375
21 220418.0
22 196682.890625
23 175988.015625
24 157880.5
25 141974.921875
26 127956.6015625
27 115575.390625
28 104635.9140625
29 94924.7109375
30 86275.4296875
31 78553.484375
32 71640.984375
33 65446.015625
34 59876.51171875
35 54856.94921875
36 50328.69140625
37 46243.2421875
38 42542.26171875
39 39185.921875
40 36137.33203125
41 33362.95703125
42 30834.626953125
43 28529.130859375
44 26421.87890625
45 24494.16015625
46 22727.314453125
47 21107.912109375
48 19620.94921875
49 18253.822265625
50 16995.171875
51 15835.9892578125
52 14766.994140625
53 13780.0751953125
54 12868.2333984375
55 12025.3984375
56 11245.3359375
57 10522.6318359375
58 9852.537109375
59 9230.69921875
60 8653.833984375
61 8117.6020

385 0.03622458502650261
386 0.03506925702095032
387 0.033983178436756134
388 0.03291914239525795
389 0.03187290206551552
390 0.03086971677839756
391 0.029903540387749672
392 0.028961678966879845
393 0.028047848492860794
394 0.02718442864716053
395 0.026330430060625076
396 0.02550649642944336
397 0.024703102186322212
398 0.023935826495289803
399 0.02318372018635273
400 0.02246420457959175
401 0.021750612184405327
402 0.0210790503770113
403 0.020420663058757782
404 0.019775772467255592
405 0.01917247287929058
406 0.018575672060251236
407 0.017999794334173203
408 0.017432603985071182
409 0.016892703250050545
410 0.01637277938425541
411 0.01586846262216568
412 0.015375157818198204
413 0.01490082498639822
414 0.014439298771321774
415 0.013997436501085758
416 0.01356395985931158
417 0.013145070523023605
418 0.012746435590088367
419 0.012351022101938725
420 0.011974413879215717
421 0.011603656224906445
422 0.011253601871430874
423 0.010913815349340439
424 0.010578541085124016
425 0.0102611389

0 31247148.0
1 23192936.0
2 20996544.0
3 20534304.0
4 19578206.0
5 17166660.0
6 13393320.0
7 9407224.0
8 6105782.0
9 3844914.25
10 2443219.5
11 1620627.125
12 1138567.375
13 848968.9375
14 665648.6875
15 542484.0625
16 454305.21875
17 387687.8125
18 335139.875
19 292323.375
20 256658.265625
21 226544.640625
22 200845.65625
23 178673.453125
24 159428.671875
25 142631.765625
26 127914.1875
27 114953.640625
28 103514.1796875
29 93393.421875
30 84423.3515625
31 76447.2578125
32 69336.7578125
33 62977.234375
34 57284.4453125
35 52177.2109375
36 47587.05859375
37 43454.88671875
38 39728.51171875
39 36357.296875
40 33306.1015625
41 30540.71484375
42 28031.470703125
43 25752.36328125
44 23680.787109375
45 21793.6171875
46 20073.107421875
47 18501.68359375
48 17065.072265625
49 15750.47265625
50 14546.498046875
51 13443.1943359375
52 12430.884765625
53 11501.1025390625
54 10646.90625
55 9861.2861328125
56 9138.6533203125
57 8473.0
58 7860.068359375
59 7294.61083984375
60 6772.5673828125
61 6290

431 6.925024354131892e-05
432 6.835138628957793e-05
433 6.712444155709818e-05
434 6.563868373632431e-05
435 6.459031283156946e-05
436 6.407573528122157e-05
437 6.267166463658214e-05
438 6.17219993728213e-05
439 6.0757887695217505e-05
440 5.971883729216643e-05
441 5.8877383708022535e-05
442 5.776581019745208e-05
443 5.693183993571438e-05
444 5.615957343252376e-05
445 5.5469954531872645e-05
446 5.47063973499462e-05
447 5.389317811932415e-05
448 5.311513086780906e-05
449 5.231402246863581e-05
450 5.1365612307563424e-05
451 5.098305700812489e-05
452 5.0095823098672554e-05
453 4.9236841732636094e-05
454 4.874240039498545e-05
455 4.816383443539962e-05
456 4.7199009713949636e-05
457 4.636741505237296e-05
458 4.589073796523735e-05
459 4.579328015097417e-05
460 4.497222835198045e-05
461 4.459972842596471e-05
462 4.384338535601273e-05
463 4.3287182052154094e-05
464 4.284672832000069e-05
465 4.2286297684768215e-05
466 4.142001853324473e-05
467 4.1113227780442685e-05
468 4.0614595491206273e-05
469