In [1]:
%matplotlib inline


PyTorch: Tensors and autograd
-------------------------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation computes the forward pass using operations on PyTorch
Tensors, and uses PyTorch autograd to compute gradients.


A PyTorch Tensor represents a node in a computational graph. If ``x`` is a
Tensor that has ``x.requires_grad=True`` then ``x.grad`` is another Tensor
holding the gradient of ``x`` with respect to some scalar value.



In [6]:
import torch

dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs.
# Setting requires_grad=False indicates that we do not need to compute gradients
# with respect to these Tensors during the backward pass.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Create random Tensors for weights.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y using operations on Tensors; these
    # are exactly the same operations we used to compute the forward pass using
    # Tensors, but we do not need to keep references to intermediate values since
    # we are not implementing the backward pass by hand.
    y_pred = x.mm(w1).clamp(min=0).mm(w2)

    # Compute and print loss using operations on Tensors.
    # Now loss is a Tensor of shape (1,)
    # loss.item() gets the a scalar value held in the loss.
    loss = (y_pred - y).pow(2).sum()
    print("%4d\t%.6f" % (t, loss))

    # Use autograd to compute the backward pass. This call will compute the
    # gradient of loss with respect to all Tensors with requires_grad=True.
    # After this call w1.grad and w2.grad will be Tensors holding the gradient
    # of the loss with respect to w1 and w2 respectively.
    loss.backward()

    # Manually update weights using gradient descent. Wrap in torch.no_grad()
    # because weights have requires_grad=True, but we don't need to track this
    # in autograd.
    # An alternative way is to operate on weight.data and weight.grad.data.
    # Recall that tensor.data gives a tensor that shares the storage with
    # tensor, but doesn't track history.
    # You can also use torch.optim.SGD to achieve this.
    with torch.no_grad():    # temporarily set all the requires_grad flag to false
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # Manually zero the gradients after updating weights
        w1.grad.zero_()
        w2.grad.zero_()

   0	23547240.000000
   1	16368425.000000
   2	13744135.000000
   3	13129416.000000
   4	13335319.000000
   5	13551493.000000
   6	13166821.000000
   7	11929022.000000
   8	9944028.000000
   9	7674540.500000
  10	5551736.000000
  11	3845173.500000
  12	2603190.000000
  13	1760653.500000
  14	1208902.625000
  15	853298.750000
  16	623273.375000
  17	472582.937500
  18	371278.718750
  19	300988.062500
  20	250436.015625
  21	212696.187500
  22	183497.500000
  23	160193.828125
  24	141101.015625
  25	125113.835938
  26	111497.757812
  27	99757.968750
  28	89540.351562
  29	80586.312500
  30	72688.867188
  31	65693.945312
  32	59473.828125
  33	53922.578125
  34	48958.710938
  35	44510.042969
  36	40514.101562
  37	36914.261719
  38	33670.683594
  39	30745.968750
  40	28102.492188
  41	25708.359375
  42	23536.910156
  43	21564.917969
  44	19772.941406
  45	18143.048828
  46	16663.730469
  47	15315.412109
  48	14085.493164
  49	12961.746094
  50	11934.931641
  51	10999.587891
  52	10144.725