In [1]:
# Code in file tensor/two_layer_net_tensor.py
# Fit a two-layer network (x -> mm(w1) -> ReLU -> mm(w2)) to random data using
# plain PyTorch tensors: the forward pass, the backward pass, and the
# gradient-descent update are all written out by hand (no autograd).
import torch

# Use the GPU when one is available; otherwise fall back to the CPU so the
# script does not crash on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device)
y = torch.randn(N, D_out, device=device)

# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device)
w2 = torch.randn(H, D_out, device=device)

learning_rate = 1e-6
num_iterations = 500
for t in range(num_iterations):
    # Forward pass: compute predicted y
    h = x.mm(w1)
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)

    # Compute and print loss (sum of squared errors). `.sum()` yields a
    # 0-dim tensor; loss.item() extracts its value as a Python number.
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)  # d(loss)/d(y_pred) for squared error
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    # Copy before masking so grad_h_relu itself is left unmodified.
    grad_h = grad_h_relu.clone()
    # ReLU passes no gradient where its input was negative.
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

    # Update weights in place using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 43221520.0
1 48380608.0
2 54741836.0
3 48879860.0
4 29612112.0
5 12586838.0
6 4951349.0
7 2524984.0
8 1711146.625
9 1337560.75
10 1102834.0
11 928793.125
12 791073.0
13 679262.375
14 587107.25
15 510523.75
16 446258.75
17 391986.0625
18 345842.875
19 306372.0625
20 272419.84375
21 243075.53125
22 217597.375
23 195350.609375
24 175885.40625
25 158768.53125
26 143661.09375
27 130285.0234375
28 118400.6328125
29 107820.8828125
30 98365.515625
31 89902.0
32 82318.0546875
33 75495.09375
34 69338.65625
35 63776.79296875
36 58749.3984375
37 54184.2890625
38 50032.93359375
39 46256.0859375
40 42815.56640625
41 39670.41015625
42 36790.47265625
43 34150.98046875
44 31729.328125
45 29504.1015625
46 27456.501953125
47 25572.8125
48 23835.259765625
49 22229.591796875
50 20746.126953125
51 19373.181640625
52 18101.966796875
53 16925.4453125
54 15834.3798828125
55 14821.2900390625
56 13880.06640625
57 13004.6357421875
58 12190.330078125
59 11432.18359375
60 10726.4814453125
61 10068.6953125
62 9454