# PyTorch: optim


A fully connected ReLU network with one hidden layer, trained to predict y from x by minimizing the squared Euclidean distance.
This implementation uses the nn package to define the model and loss function, and the optim package to update the model's weights.

In [7]:
import torch 

# Train a two-layer fully connected ReLU network on random data with Adam.
#
# The data are random tensors, so this only demonstrates the mechanics of
# the nn + optim training loop; the network simply memorises the batch.

# Seed the global RNG so the random data, the initial weights, and therefore
# the printed losses are reproducible across "Restart & Run All".
torch.manual_seed(0)

batch_size = 64
input_dimension = 1000
hidden_dimension = 100
output_dimension = 10

# Random inputs and targets.
x = torch.randn(batch_size, input_dimension)
y = torch.randn(batch_size, output_dimension)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(input_dimension, hidden_dimension),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_dimension, output_dimension),
)
# reduction='sum' gives the summed squared error. It replaces the deprecated
# `size_average=False` argument, which selects the same behaviour but emits
# a deprecation warning.
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use Adam; the optim package contains many other
# optimization algorithms. The first argument to the Adam constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for n in range(500):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss.
    loss = loss_fn(y_pred, y)
    print(n, loss.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the variables it will update (which are the learnable
    # weights of the model). This is because by default, gradients are
    # accumulated in buffers (i.e. not overwritten) whenever .backward() is
    # called. Check out the docs of torch.autograd.backward for more details.
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model parameters.
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its parameters.
    optimizer.step()

0 707.2461547851562
1 689.624755859375
2 672.4673461914062
3 655.7445678710938
4 639.459716796875
5 623.6341552734375
6 608.2202758789062
7 593.2911987304688
8 578.7227172851562
9 564.597412109375
10 550.947509765625
11 537.6646118164062
12 524.6492919921875
13 511.9114685058594
14 499.555908203125
15 487.562255859375
16 475.9190368652344
17 464.64373779296875
18 453.71282958984375
19 443.0516052246094
20 432.6463317871094
21 422.5628662109375
22 412.7955322265625
23 403.25543212890625
24 393.93475341796875
25 384.82745361328125
26 375.92999267578125
27 367.25732421875
28 358.7690734863281
29 350.4398193359375
30 342.36224365234375
31 334.4931640625
32 326.8040466308594
33 319.268310546875
34 311.89306640625
35 304.6750793457031
36 297.60986328125
37 290.69879150390625
38 283.93817138671875
39 277.32940673828125
40 270.83209228515625
41 264.47344970703125
42 258.2449645996094
43 252.13458251953125
44 246.1531219482422
45 240.28086853027344
46 234.5074920654297
47 228.86866760253906
48 

452 0.001252606394700706
453 0.0012229221174493432
454 0.001193898031488061
455 0.0011655132984742522
456 0.0011377526680007577
457 0.0011106128804385662
458 0.0010840827599167824
459 0.0010581373935565352
460 0.0010327757336199284
461 0.0010079717030748725
462 0.0009837268153205514
463 0.0009600347257219255
464 0.0009368690080009401
465 0.0009142222697846591
466 0.0008920911932364106
467 0.0008704559877514839
468 0.0008493185159750283
469 0.0008286475786007941
470 0.0008084558648988605
471 0.0007887159590609372
472 0.0007694301311857998
473 0.0007505895337089896
474 0.0007321697776205838
475 0.0007141815731301904
476 0.000696595583576709
477 0.0006794251967221498
478 0.0006626436370424926
479 0.0006462547462433577
480 0.0006302366964519024
481 0.0006146028754301369
482 0.0005993209779262543
483 0.000584397348575294
484 0.0005698190070688725
485 0.0005555886309593916
486 0.0005416838685050607
487 0.0005281048361212015
488 0.0005148449563421309
489 0.000501900736708194
490 0.00048925745