# PYTORCH: DEFINING NEW AUTOGRAD FUNCTIONS

一个全连接 ReLU 网络：只有 1 层隐藏层、不含偏置，根据输入 𝑥 预测 𝑦，通过最小化预测值与真实值之间的平方欧氏距离来训练网络。

本实现使用 PyTorch 张量计算网络的前向传播和损失，并使用 autograd 计算梯度；其中 ReLU 通过自定义的 autograd Function 实现。

In [1]:
import torch

In [2]:
class MyReLU(torch.autograd.Function):
    """
    ReLU written as a custom autograd Function.

    The class subclasses torch.autograd.Function and supplies static
    forward and backward methods that both work on Tensors. It is meant to
    be invoked through MyReLU.apply rather than instantiated.
    """

    @staticmethod
    def forward(ctx, input):
        """
        Clamp the input Tensor from below at 0 and return the result.

        ctx is the context object shared with the backward pass. The input
        is stored on it with ctx.save_for_backward so that backward can
        later see which elements were negative.
        """
        ctx.save_for_backward(input)
        return torch.clamp(input, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Turn the gradient w.r.t. the output into the gradient w.r.t. the input.

        grad_output holds the gradient of the loss with respect to the value
        returned by forward. The returned Tensor is a copy of it in which
        every position whose saved input was strictly negative is set to 0;
        all other positions (including inputs equal to 0) are passed through
        unchanged.
        """
        (saved_input,) = ctx.saved_tensors
        was_negative = saved_input < 0
        # masked_fill is out-of-place, so grad_output itself is never mutated.
        return grad_output.masked_fill(was_negative, 0)
        

In [3]:
# Train a two-layer, bias-free network with plain gradient descent, using the
# custom MyReLU Function as the hidden-layer non-linearity.
dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Create random Tensors for weights. requires_grad=True makes autograd
# accumulate gradients for them in w1.grad / w2.grad on backward().
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6

# To apply our Function, we use the Function.apply method, aliased as 'relu'.
# The alias does not depend on the loop variable, so it is bound once here
# instead of being rebound on every iteration.
relu = MyReLU.apply

for t in range(500):
    # Forward pass: compute predicted y using operations; we compute
    # ReLU using our custom autograd operation.
    y_pred = relu(x.mm(w1)).mm(w2)

    # Compute and print loss (sum of squared differences).
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Use autograd to compute the backward pass.
    loss.backward()

    # Update weights using gradient descent. The in-place updates are done
    # under no_grad so that they are not recorded by autograd.
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

    # Gradients accumulate across backward() calls, so clear them before
    # the next iteration.
    w1.grad.zero_()
    w2.grad.zero_()


0 26704418.0
1 23119984.0
2 25058864.0
3 28962344.0
4 31388814.0
5 28815458.0
6 21401456.0
7 12828680.0
8 6737296.0
9 3436374.5
10 1890187.5
11 1180184.125
12 835876.5625
13 648153.1875
14 530947.375
15 448334.90625
16 385095.3125
17 334073.0
18 291720.625
19 255881.9375
20 225280.25
21 198968.90625
22 176228.4375
23 156488.671875
24 139253.40625
25 124198.6640625
26 111008.359375
27 99411.5390625
28 89191.4609375
29 80164.046875
30 72170.7734375
31 65073.453125
32 58759.4609375
33 53128.921875
34 48102.93359375
35 43606.64453125
36 39578.04296875
37 35961.640625
38 32711.66015625
39 29782.12890625
40 27142.009765625
41 24760.3359375
42 22609.142578125
43 20661.951171875
44 18897.98046875
45 17298.53515625
46 15844.810546875
47 14521.8798828125
48 13319.41015625
49 12224.1806640625
50 11225.951171875
51 10315.96484375
52 9485.552734375
53 8727.4765625
54 8034.54345703125
55 7400.7265625
56 6820.50927734375
57 6289.14013671875
58 5802.30517578125
59 5355.6865234375
60 4946.60400390625
6

464 3.429989374126308e-05
465 3.3892090868903324e-05
466 3.345632285345346e-05
467 3.295509668532759e-05
468 3.249102519475855e-05
469 3.2345924410037696e-05
470 3.188969276379794e-05
471 3.145761365885846e-05
472 3.106479562120512e-05
473 3.057441062992439e-05
474 3.0312317903735675e-05
475 3.00664032693021e-05
476 2.971201502077747e-05
477 2.93389311991632e-05
478 2.9042430469417013e-05
479 2.878546183637809e-05
480 2.837507054209709e-05
481 2.8106516765546985e-05
482 2.7762069294112734e-05
483 2.749171653704252e-05
484 2.7232586944592185e-05
485 2.693791248020716e-05
486 2.6588038963382132e-05
487 2.642482468218077e-05
488 2.6083354896400124e-05
489 2.5749788619577885e-05
490 2.550400677137077e-05
491 2.5228298909496516e-05
492 2.507366480131168e-05
493 2.4756120183155872e-05
494 2.465858415234834e-05
495 2.447017322992906e-05
496 2.4188431780203246e-05
497 2.3868838979979046e-05
498 2.3644581233384088e-05
499 2.3362030333373696e-05
