In [5]:
"""
Fit a two-layer fully connected network to random data using Adam.

N: batch size;
D_in: input dimension;
H: hidden dimension;
D_out: output dimension.
x, y: random inputs (N, D_in) and random regression targets (N, D_out),
      both drawn from a standard normal distribution
model: Linear(D_in, H) -> ReLU() -> Linear(H, D_out)
loss_fn: MSELoss(reduction='sum')
learning_rate: 1e-4
optimizer: Adam
num_steps: 500 training iterations; the loss is printed at every step
"""
import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Create random Tensors to hold inputs and outputs. Allocating them directly
# on `device` avoids building them on the CPU first and then copying.
x = torch.randn(N, D_in, device=device)
y = torch.randn(N, D_out, device=device)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
).to(device)
# reduction='sum' makes the reported loss the total squared error over every
# element of the batch rather than the per-element mean.
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use Adam; the optim package contains many other
# optimization algorithms. The first argument to the Adam constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1e-4
num_steps = 500
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for t in range(num_steps):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss (this is the loss *before* this step's update).
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the Tensors it will update (which are the learnable
    # weights of the model); otherwise .backward() would accumulate into the
    # gradients left over from the previous iteration.
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model
    # parameters.
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters.
    optimizer.step()


0 671.3015747070312
1 654.6463623046875
2 638.4153442382812
3 622.6648559570312
4 607.4295654296875
5 592.6456298828125
6 578.397216796875
7 564.5054931640625
8 550.9749755859375
9 537.8048095703125
10 524.911376953125
11 512.368408203125
12 500.21441650390625
13 488.4389343261719
14 476.9700927734375
15 465.7798156738281
16 454.8664855957031
17 444.258544921875
18 433.9541015625
19 423.964111328125
20 414.2348937988281
21 404.78717041015625
22 395.5699462890625
23 386.6094665527344
24 377.86309814453125
25 369.3631896972656
26 361.10858154296875
27 353.0291748046875
28 345.1308288574219
29 337.4386291503906
30 329.9320068359375
31 322.603759765625
32 315.4443054199219
33 308.4097900390625
34 301.51593017578125
35 294.786376953125
36 288.2240295410156
37 281.8104553222656
38 275.52838134765625
39 269.38299560546875
40 263.3674011230469
41 257.46630859375
42 251.69009399414062
43 246.03567504882812
44 240.4913787841797
45 235.04873657226562
46 229.711181640625
47 224.46707153320312
48 2

394 1.0502270015422255e-05
395 9.796547601581551e-06
396 9.136340850091074e-06
397 8.519376933691092e-06
398 7.9441651905654e-06
399 7.404574716929346e-06
400 6.901797860336956e-06
401 6.431964720832184e-06
402 5.993091235723114e-06
403 5.583691745414399e-06
404 5.20109415447223e-06
405 4.844121576752514e-06
406 4.510873623075895e-06
407 4.199841896479484e-06
408 3.9096212276490405e-06
409 3.6387973523233086e-06
410 3.386348907952197e-06
411 3.1514146030531265e-06
412 2.9315233405213803e-06
413 2.7268897611065768e-06
414 2.536526380936266e-06
415 2.3583811525895726e-06
416 2.1925047803961206e-06
417 2.0384691197250504e-06
418 1.8943451323139016e-06
419 1.7605169659873354e-06
420 1.635759076634713e-06
421 1.5195731748463004e-06
422 1.4115088333710446e-06
423 1.3111310863678227e-06
424 1.217416411236627e-06
425 1.1303105793558643e-06
426 1.0492477713341941e-06
427 9.737962045619497e-07
428 9.03587988432264e-07
429 8.384731131627632e-07
430 7.777392738717026e-07
431 7.216435733425897e-07
