In [0]:
# Install Pytorch.
from os import path
from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())

accelerator = 'cu80' if path.exists('/opt/bin/nvidia-smi') else 'cpu'

!pip install -q http://download.pytorch.org/whl/{accelerator}/torch-0.3.0.post4-{platform}-linux_x86_64.whl torchvision

In [0]:
%matplotlib inline


PyTorch: Defining new autograd functions
----------------------------------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation computes the forward pass using operations on PyTorch
Variables, and uses PyTorch autograd to compute gradients.

In this implementation we implement our own custom autograd function to perform
the ReLU function.



In [0]:
import torch
from torch.autograd import Variable

In [0]:
class MyReLU(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions by subclassing
    torch.autograd.Function and implementing the forward and backward passes
    which operate on Tensors.
    """
    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for backward computation. You can cache arbitrary
        objects for use in the backward pass using the 
        ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        return input.clamp(min=0)
    
    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the 
        loss with respect to the output, and we need to compute the gradient 
        of the loss with respect to the input.
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

In [6]:
# dtype = torch.FloatTensor # Run on CPU
dtype = torch.cuda.FloatTensor

# N: batch size, D_in: input dim, H: hidden dim, D_out: output dim
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs, and wrap them in Variables.
x = Variable(torch.randn(N, D_in).type(dtype), requires_grad=False)
y = Variable(torch.randn(N, D_out).type(dtype), requires_grad=False)

# Create random Tensors for weights, and wrap them in Variables.
w1 = Variable(torch.randn(D_in, H).type(dtype), requires_grad=True)
w2 = Variable(torch.randn(H, D_out).type(dtype), requires_grad=True)

learning_rate = 1e-6

for t in range(500):
    # To apply our Function, we use Function.apply method.
    # We alias this as "relu".
    relu = MyReLU.apply
    
    # Forward pass: compute predicted y using operations on Variables.
    # We compute ReLU using our custom autograd operation.
    y_pred = relu(x.mm(w1)).mm(w2)
    
    # Compute and print loss.
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.data[0])
    
    # Use autograd to compute the backward pass.
    loss.backward()
    
    # Update weights using gradient descent.
    w1.data -= learning_rate * w1.grad.data
    w2.data -= learning_rate * w2.grad.data
    
    # Manually zero the gradient after updating weights.
    w1.grad.data.zero_()
    w2.grad.data.zero_()

0 25005564.0
1 20043392.0
2 18847940.0
3 18811310.0
4 18364160.0
5 16581848.0
6 13519177.0
7 9949932.0
8 6753169.0
9 4364870.0
10 2785861.25
11 1806242.0
12 1215848.75
13 858918.5625
14 638164.5
15 496120.1875
16 400258.65625
17 332213.34375
18 281468.3125
19 241983.84375
20 210237.203125
21 184074.84375
22 162110.03125
23 143419.65625
24 127382.34375
25 113548.8359375
26 101520.4375
27 90995.921875
28 81748.90625
29 73589.4609375
30 66371.515625
31 59968.4296875
32 54273.828125
33 49195.55859375
34 44660.1875
35 40599.1484375
36 36958.3984375
37 33687.83984375
38 30745.7890625
39 28092.685546875
40 25697.416015625
41 23532.033203125
42 21569.560546875
43 19790.408203125
44 18174.978515625
45 16706.828125
46 15370.6904296875
47 14152.59375
48 13041.2216796875
49 12026.6357421875
50 11099.01171875
51 10250.18359375
52 9473.2529296875
53 8760.7080078125
54 8106.9296875
55 7506.7041015625
56 6955.05322265625
57 6447.8271484375
58 5980.8955078125
59 5550.970703125
60 5154.65087890625
61 47