# [PYTORCH: TENSORS](https://pytorch.org/tutorials/beginner/examples_tensor/two_layer_net_tensor.html)

全连接 ReLU 网络：1 个隐藏层、无偏置，根据 $x$ 预测 $y$，通过最小化预测值与真实值之间欧氏距离的平方（即平方误差之和）来训练网络。

使用 PyTorch 张量手动实现网络的前向传播、损失计算与反向传播（不使用 autograd）。

In [7]:
import torch

# Two-layer fully connected ReLU network (no biases) trained by plain gradient
# descent on raw tensors. Autograd is not used: the gradients of the loss with
# respect to w1 and w2 are derived by hand in the backprop section below.

dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0")  # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

learning_rate = 1e-6
num_steps = 500
for t in range(num_steps):
    # Forward pass: compute predicted y
    h = x.mm(w1)             # (N, H) hidden pre-activations
    h_relu = h.clamp(min=0)  # ReLU
    y_pred = h_relu.mm(w2)   # (N, D_out)

    # Compute and print loss: sum of squared errors (squared Euclidean
    # distance). The tensor and the Python float get separate names so
    # `loss` always means the scalar that is printed.
    loss_tensor = (y_pred - y).pow(2).sum()
    loss = loss_tensor.item()
    print(t, loss)

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2 * (y_pred - y)  # d(loss)/d(y_pred)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0  # ReLU passes no gradient where its input was negative
    grad_w1 = x.t().mm(grad_h)

    # Update weights in place using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2
    

0 27463732.0
1 26083364.0
2 28679328.0
3 31264960.0
4 29874694.0
5 23410122.0
6 14755237.0
7 7966616.0
8 4074660.0
9 2209865.5
10 1352627.5
11 939916.0625
12 717856.875
13 581052.875
14 485926.0625
15 414031.71875
16 356737.84375
17 309674.78125
18 270507.28125
19 237383.234375
20 209128.75
21 184889.84375
22 164018.71875
23 145951.125
24 130214.0234375
25 116467.7265625
26 104422.8515625
27 93827.921875
28 84482.5
29 76223.8515625
30 68900.5625
31 62392.1875
32 56589.42578125
33 51402.453125
34 46763.59375
35 42602.99609375
36 38863.53515625
37 35497.7109375
38 32478.13671875
39 29753.10546875
40 27287.408203125
41 25051.537109375
42 23023.640625
43 21181.205078125
44 19506.279296875
45 17980.146484375
46 16588.2421875
47 15316.6015625
48 14152.541015625
49 13086.423828125
50 12109.767578125
51 11214.08203125
52 10390.9462890625
53 9634.400390625
54 8938.43359375
55 8298.0400390625
56 7707.66748046875
57 7163.6904296875
58 6661.9404296875
59 6199.16796875
60 5770.9912109375
61 5375.09