In [0]:
# Install Pytorch.
# Builds a wheel filename from the running interpreter's tags, selects the
# 'cu80' or 'cpu' download directory, pip-installs that torch wheel together
# with torchvision, and finally imports torch.
from os import path
from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
# Interpreter tag interpolated into the wheel filename below, formatted as
# get_abbr_impl() + get_impl_ver() + '-' + get_abi_tag()
# (presumably something like "cp36-cp36m" — verify on the target runtime).
# NOTE(review): wheel.pep425tags is an internal module of `wheel` and is
# reportedly absent from newer releases — confirm this import still resolves.
platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())

# The presence of /opt/bin/nvidia-smi is the only signal used to choose the
# 'cu80' build (presumably CUDA 8.0) over the 'cpu' build.
accelerator = 'cu80' if path.exists('/opt/bin/nvidia-smi') else 'cpu'

# NOTE(review): the wheel is fetched over plain http, and only torch is
# version-pinned (0.3.0.post4); torchvision is installed unpinned.
!pip install -q http://download.pytorch.org/whl/{accelerator}/torch-0.3.0.post4-{platform}-linux_x86_64.whl torchvision
import torch

In [0]:
# Render matplotlib figures inline in the notebook's output area.
%matplotlib inline


Warm-up: numpy
--------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing the squared Euclidean distance between its
predictions and the targets.

This implementation uses numpy to manually compute the forward pass, loss, and
backward pass.

A numpy array is a generic n-dimensional array; it does not know anything about
deep learning or gradients or computational graphs, and is just a way to perform
generic numeric computations.



In [8]:
import numpy as np

# Problem dimensions: N samples per batch, D_in input features,
# H hidden units, D_out output features.
N, D_in, H, D_out = 64, 1000, 100, 10

# Random training inputs and targets drawn from a standard normal.
x = np.random.randn(N, D_in)
y = np.random.randn(N, D_out)

# Random initial weights for the two bias-free linear layers.
w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)

learning_rate = 1e-6

for t in range(500):
    # Forward pass: linear -> ReLU -> linear.
    h = x @ w1
    h_relu = np.maximum(h, 0)
    y_pred = h_relu @ w2

    # Sum-of-squares loss, reported on every step.
    residual = y_pred - y
    loss = np.sum(residual ** 2)
    print(t, loss)

    # Backward pass: chain rule applied layer by layer, output side first.
    grad_y_pred = 2.0 * residual
    grad_w2 = h_relu.T @ grad_y_pred
    grad_h_relu = grad_y_pred @ w2.T
    # ReLU blocks the gradient wherever its input was negative.
    grad_h = np.where(h < 0, 0.0, grad_h_relu)
    grad_w1 = x.T @ grad_h

    # Gradient-descent step, applied to the weight arrays in place.
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 36051090.34382218
1 32931998.240041863
2 31997668.04804722
3 28064537.904867373
4 20804419.029557742
5 12863090.873992477
6 7182343.221983374
7 3947787.1211063126
8 2335109.4823108353
9 1531159.1755057988
10 1107212.6299823888
11 858185.8864985671
12 695034.1385228182
13 578009.7034577879
14 488643.99091412534
15 417605.2634637584
16 359800.08607781626
17 311894.75818676734
18 271765.2057372482
19 237956.53306421172
20 209190.8030866377
21 184568.7211418646
22 163373.6585166155
23 145037.39828048338
24 129130.22304193722
25 115275.26532313583
26 103170.30848129911
27 92555.24527743497
28 83226.35384634219
29 74997.90268492281
30 67721.71526906168
31 61276.87495752837
32 55541.81276572382
33 50429.44746616839
34 45854.8592837268
35 41759.15431075192
36 38085.38212245512
37 34785.27357765892
38 31814.124084414805
39 29134.937034354174
40 26714.348276872897
41 24522.33636569708
42 22534.31976844623
43 20730.284239006804
44 19089.953698575795
45 17596.961591131785
46 16236.110783843815
4