# PyTorch: nn

A fully connected ReLU network with one hidden layer, trained to predict y from x by minimizing the squared Euclidean distance.

Note:
PyTorch autograd makes it easy to define computational graphs and take gradients, but raw autograd can be a bit too low-level for defining complex neural networks; this is where the nn package can help. The nn package defines a set of Modules, which are roughly equivalent to neural network layers: each Module produces output from input and may hold trainable weights.


In [5]:
import torch

batch_size = 64
input_dimension = 1000
hidden_dimension = 100
output_dimension = 10

# Random input and target tensors: x is (batch_size, input_dimension),
# y is (batch_size, output_dimension).
x = torch.randn(batch_size, input_dimension)
y = torch.randn(batch_size, output_dimension)

# Use the nn package to define our model as a sequence of layers.
# nn.Sequential is a Module which contains other Modules and applies them in sequence
# to produce its output. Each Linear Module computes output from input using a linear
# function, and holds internal Tensors for its weight and bias.
model = torch.nn.Sequential(
    torch.nn.Linear(input_dimension, hidden_dimension),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_dimension, output_dimension),
)

# The nn package also contains definitions of popular loss functions; here we use
# MSELoss with reduction="sum", i.e. the *sum* (not the mean) of squared errors.
# `reduction="sum"` replaces the deprecated `size_average=False` argument, which
# produces the same value but emits a deprecation warning.
loss_fn = torch.nn.MSELoss(reduction="sum")

# Small step size because the summed loss scales with batch_size * output_dimension.
learning_rate = 1e-4

for n in range(500):
    # Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__ operator so you can call them like functions. When
    # doing so you pass a Tensor of input data to the Module and it produces a
    # Tensor of output data.
    y_pred = model(x)

    # Compute and print loss. We pass Tensors containing the predicted and true
    # values of y, and the loss function returns a Tensor containing the loss.
    loss = loss_fn(y_pred, y)
    print(n, loss.item())

    # Zero the gradients before running the backward pass; otherwise gradients
    # from previous iterations would accumulate into param.grad.
    model.zero_grad()

    # Backward pass: compute the gradient of the loss with respect to all the
    # learnable parameters of the model. Internally, the parameters of each Module
    # are stored in Tensors with requires_grad=True, so this call computes
    # gradients for all learnable parameters in the model.
    loss.backward()

    # Update the weights using gradient descent. Each parameter is a Tensor, so we
    # can access its data and gradient as before. The update runs under no_grad so
    # the in-place subtraction is not recorded in the autograd graph.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

0 688.6853637695312
1 635.7244262695312
2 589.9434814453125
3 550.083984375
4 515.1185302734375
5 483.629638671875
6 455.05548095703125
7 429.166259765625
8 405.27032470703125
9 383.1319274902344
10 362.560546875
11 343.35577392578125
12 325.3801574707031
13 308.3205261230469
14 292.0964050292969
15 276.7087097167969
16 262.1640625
17 248.3519744873047
18 235.24530029296875
19 222.81497192382812
20 210.9897918701172
21 199.68276977539062
22 188.92636108398438
23 178.67356872558594
24 168.9093475341797
25 159.60931396484375
26 150.76792907714844
27 142.3562469482422
28 134.35523986816406
29 126.76847839355469
30 119.59046936035156
31 112.77050018310547
32 106.29612731933594
33 100.13944244384766
34 94.32103729248047
35 88.82373809814453
36 83.62598419189453
37 78.72298431396484
38 74.09806823730469
39 69.7412338256836
40 65.63906860351562
41 61.78065872192383
42 58.14900207519531
43 54.734642028808594
44 51.52432632446289
45 48.510807037353516
46 45.67902755737305
47 43.011749267578125


449 2.5173907488351688e-05
450 2.448266968713142e-05
451 2.3813230654923245e-05
452 2.3162168872659095e-05
453 2.2527128749061376e-05
454 2.1909416318521835e-05
455 2.1312520402716473e-05
456 2.072807001241017e-05
457 2.016134385485202e-05
458 1.9609513401519507e-05
459 1.9074073861702345e-05
460 1.8552615074440837e-05
461 1.8046144759864546e-05
462 1.7553082216181792e-05
463 1.7075306459446438e-05
464 1.6607369616394863e-05
465 1.61555472004693e-05
466 1.5715229892521165e-05
467 1.5286370398825966e-05
468 1.4870018276269548e-05
469 1.4464682863035705e-05
470 1.4068656128074508e-05
471 1.3686739293916617e-05
472 1.3314713214640506e-05
473 1.29510517581366e-05
474 1.2598925422935281e-05
475 1.2256660738785286e-05
476 1.1923724741791375e-05
477 1.1599073332035914e-05
478 1.1284442734904587e-05
479 1.097708627639804e-05
480 1.0677655154722743e-05
481 1.0389316230430268e-05
482 1.0106229638040531e-05
483 9.832571777224075e-06
484 9.565402251610067e-06
485 9.30566147872014e-06
486 9.0532439