In [None]:
%matplotlib inline


PyTorch: Tensors
----------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation uses PyTorch tensors to manually compute the forward pass,
loss, and backward pass.

A PyTorch Tensor is basically the same as a numpy array: it does not know
anything about deep learning or computational graphs or gradients, and is just
a generic n-dimensional array to be used for arbitrary numeric computation.

The biggest difference between a numpy array and a PyTorch Tensor is that
a PyTorch Tensor can run on either CPU or GPU. To run operations on the GPU,
just cast the Tensor to a cuda datatype.



In [3]:
import torch


dtype = torch.float
device = torch.device("cpu")
#device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.mm(w1)
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum().item()
    print(t, loss)

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

    # Update weights using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

(0, 21660224.0)
(1, 14853343.0)
(2, 11161867.0)
(3, 8868970.0)
(4, 7250881.5)
(5, 6010645.0)
(6, 4984365.0)
(7, 4119721.0)
(8, 3378546.0)
(9, 2756118.25)
(10, 2235457.75)
(11, 1810500.5)
(12, 1465018.125)
(13, 1188746.875)
(14, 968156.5)
(15, 793435.875)
(16, 654767.6875)
(17, 544810.5)
(18, 457053.9375)
(19, 386836.875)
(20, 330159.90625)
(21, 284100.59375)
(22, 246336.171875)
(23, 215128.65625)
(24, 189086.0625)
(25, 167170.9375)
(26, 148567.796875)
(27, 132655.15625)
(28, 118949.328125)
(29, 107058.6875)
(30, 96680.5703125)
(31, 87575.9453125)
(32, 79545.1328125)
(33, 72432.0)
(34, 66100.2578125)
(35, 60443.26171875)
(36, 55372.40234375)
(37, 50811.1875)
(38, 46698.4140625)
(39, 42981.6328125)
(40, 39612.99609375)
(41, 36555.28515625)
(42, 33774.33203125)
(43, 31241.603515625)
(44, 28930.623046875)
(45, 26817.19140625)
(46, 24881.4453125)
(47, 23105.962890625)
(48, 21475.498046875)
(49, 19975.9921875)
(50, 18596.16015625)
(51, 17324.484375)
(52, 16152.2548828125)
(53, 15069.79003906