In [28]:
import torch

In [29]:
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

In [30]:
# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

In [31]:
# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

In [32]:
loss_fn = torch.nn.MSELoss(size_average=False)

In [33]:
class MSE_Loss(torch.nn.Module):
    
    def __init__(self):
        super(MSE_Loss,self).__init__()
        
    def forward(self, x, y):
        return ((x-y)**2).sum()

loss_fn = MSE_Loss()

In [34]:
# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use Adam; the optim package contains many other
# optimization algoriths. The first argument to the Adam constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [35]:
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the variables it will update (which are the learnable
    # weights of the model). This is because by default, gradients are
    # accumulated in buffers( i.e, not overwritten) whenever .backward()
    # is called. Checkout docs of torch.autograd.backward for more details.
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model
    # parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters
    optimizer.step()

0 647.56298828125
1 630.1468505859375
2 613.2070922851562
3 596.7378540039062
4 580.75341796875
5 565.2120971679688
6 550.1463623046875
7 535.4425048828125
8 521.1622314453125
9 507.30810546875
10 493.8898620605469
11 480.90478515625
12 468.2868347167969
13 456.1055603027344
14 444.31048583984375
15 432.8438415527344
16 421.7044982910156
17 410.9455261230469
18 400.4736633300781
19 390.3260498046875
20 380.40643310546875
21 370.7688293457031
22 361.4028015136719
23 352.2725524902344
24 343.34912109375
25 334.6623840332031
26 326.2045593261719
27 317.9803161621094
28 309.9455871582031
29 302.1341247558594
30 294.5419616699219
31 287.140380859375
32 279.94268798828125
33 272.9185791015625
34 266.0706787109375
35 259.3677673339844
36 252.81710815429688
37 246.41844177246094
38 240.1749725341797
39 234.06951904296875
40 228.089111328125
41 222.2665557861328
42 216.59866333007812
43 211.07916259765625
44 205.71148681640625
45 200.4720458984375
46 195.34849548339844
47 190.34239196777344
48 

419 2.1711559838877292e-06
420 2.0143611436651554e-06
421 1.8684930864765192e-06
422 1.7330808077531401e-06
423 1.6063102066254942e-06
424 1.4893421393935569e-06
425 1.3800377018924337e-06
426 1.2787431842298247e-06
427 1.1846875622723019e-06
428 1.0974988526868401e-06
429 1.0160624697164167e-06
430 9.404150205227779e-07
431 8.702593845555384e-07
432 8.054854561123648e-07
433 7.452569548149768e-07
434 6.892271926517424e-07
435 6.372411007760093e-07
436 5.892921990380273e-07
437 5.446584054880077e-07
438 5.031262730881281e-07
439 4.648641152016353e-07
440 4.292813855499844e-07
441 3.9639510873712425e-07
442 3.6614233067666646e-07
443 3.378477231308352e-07
444 3.118392157830385e-07
445 2.876862481571152e-07
446 2.650645285484643e-07
447 2.4455081870655704e-07
448 2.2564188384421868e-07
449 2.0791343047221744e-07
450 1.9159574549121317e-07
451 1.764155399541778e-07
452 1.6246323752966418e-07
453 1.495789661021263e-07
454 1.377746912112343e-07
455 1.2680654037922068e-07
456 1.1658906373668