In [1]:
import numpy as np
import torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import random

In [2]:
torch.set_printoptions(10)
seed = 0

def set_seed(seed):
#     cudnn.benchmark = False
#     cudnn.deterministic = True

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    
set_seed(seed)

# Numpy

In [3]:
"""
A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x using Euclidean error.
This implementation uses numpy to manually compute the forward pass, loss, and
backward pass.
A numpy array is a generic n-dimensional array; it does not know anything about
deep learning or gradients or computational graphs, and is just a way to perform
generic numeric computations.
"""

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = np.random.randn(N, D_in)
y = np.random.randn(N, D_out)

# Randomly initialize weights
w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)

learning_rate = 1e-6
for t in range(500):
  # Forward pass: compute predicted y
  h = x.dot(w1)
  h_relu = np.maximum(h, 0)
  y_pred = h_relu.dot(w2)
  
  # Compute and print loss
  loss = np.square(y_pred - y).sum()
  print(t, loss)
  
  # Backprop to compute gradients of w1 and w2 with respect to loss
  grad_y_pred = 2.0 * (y_pred - y)
  grad_w2 = h_relu.T.dot(grad_y_pred)
  grad_h_relu = grad_y_pred.dot(w2.T)
  grad_h = grad_h_relu.copy()
  grad_h[h < 0] = 0
  grad_w1 = x.T.dot(grad_h)
 
  # Update weights
  w1 -= learning_rate * grad_w1
  w2 -= learning_rate * grad_w2

0 26365524.193454307
1 18539004.95677144
2 14019258.774463901
3 10761426.115604792
4 8169072.895881062
5 6099576.785099281
6 4499933.628590757
7 3309552.858061184
8 2448877.062830102
9 1836295.6953177939
10 1401510.3554378161
11 1091450.0026216193
12 867583.8290616142
13 702751.8395318142
14 578821.849254452
15 483705.55141397787
16 409258.65372018237
17 349984.5458304787
18 301912.7517767879
19 262373.29954350623
20 229437.54386099748
21 201725.3396193426
22 178190.95402275835
23 158046.65045700333
24 140704.4563696527
25 125687.00958495442
26 112617.03015166077
27 101208.33623653787
28 91198.10600890784
29 82378.02880559987
30 74580.71681630617
31 67667.6643213151
32 61521.61810792578
33 56042.554621769705
34 51144.28138319467
35 46756.47637796733
36 42816.53245462748
37 39270.62895496186
38 36073.4944180353
39 33183.64832382827
40 30567.245427718397
41 28193.99703700848
42 26037.525430212638
43 24075.757799959956
44 22287.68086269561
45 20656.11883451095
46 19165.497847104092
47 178

392 0.004525930275686374
393 0.004351654886141213
394 0.0041841284481591275
395 0.004022949950899415
396 0.0038680222168257497
397 0.00371905873440962
398 0.003575891214921854
399 0.003438231755966013
400 0.003305885880486539
401 0.003178683495108549
402 0.0030563631889825374
403 0.002938771324879172
404 0.0028257721916459515
405 0.002717084484711289
406 0.002612595453575259
407 0.002512140370290178
408 0.0024155500935104943
409 0.0023226908611876556
410 0.0022333890451342792
411 0.002147563905077541
412 0.0020650236914509764
413 0.0019856709605726138
414 0.0019093752213397842
415 0.001836070524655552
416 0.0017655699544486551
417 0.0016977524239640514
418 0.0016325572019738148
419 0.0015698663901494336
420 0.0015095896557237296
421 0.0014516523243327736
422 0.0013959191519695726
423 0.0013423453452926933
424 0.0012908312218957857
425 0.0012413218103345704
426 0.0011937396246933583
427 0.0011479336459637936
428 0.0011039013533103445
429 0.0010615621872878895
430 0.001020854265060344
43

# Pytorch

## Manual Calculate

In [3]:
set_seed(0)
device = torch.device('cpu')
# device = torch.device('cuda') # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device)
y = torch.randn(N, D_out, device=device)

# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device)
w2 = torch.randn(H, D_out, device=device)


learning_rate = 1e-6
for t in range(500):
  # Forward pass: compute predicted y
  h = x.mm(w1)
  h_relu = h.clamp(min=0)
  y_pred = h_relu.mm(w2)

  # Compute and print loss; loss is a scalar, and is stored in a PyTorch Tensor
  # of shape (); we can get its value as a Python number with loss.item().
  scale = torch.sqrt(torch.FloatTensor([3]))
  loss = (y_pred - y).pow(2).sum() / 2
  

  # Backprop to compute gradients of w1 and w2 with respect to loss
  grad_y_pred = 2 * (y_pred - y) / 2
  grad_w2 = h_relu.t().mm(grad_y_pred)
  grad_h_relu = grad_y_pred.mm(w2.t())
  grad_h = grad_h_relu.clone()
  grad_h[h < 0] = 0
  grad_w1 = x.t().mm(grad_h)

  print(t, loss.item(), grad_w2[0,:3].tolist())

  # Update weights using gradient descent
  w1 -= learning_rate * grad_w1
  w2 -= learning_rate * grad_w2
    
#   print(w2.shape, grad_y_pred.shape, grad_w2.shape)
#   break

0 14714332.0 [-4652.74853515625, -76917.609375, 73483.203125]
1 8884345.0 [-9479.4384765625, 60128.60546875, 29030.208984375]
2 7004241.5 [-8675.3896484375, 36310.81640625, 25557.88671875]
3 5653118.0 [-8898.234375, 32313.70703125, 21330.892578125]
4 4635046.0 [-8925.4560546875, 27715.87109375, 17925.59375]
5 3851391.0 [-8939.767578125, 24101.111328125, 15254.0185546875]
6 3236659.5 [-8804.087890625, 21117.6015625, 13099.0185546875]
7 2747155.0 [-8569.4111328125, 18612.884765625, 11348.5771484375]
8 2352047.0 [-8322.4248046875, 16494.74609375, 9859.8505859375]
9 2029288.875 [-8073.0087890625, 14687.08984375, 8608.298828125]
10 1763313.875 [-7809.8447265625, 13165.904296875, 7503.20556640625]
11 1541914.75 [-7532.95703125, 11860.017578125, 6551.36328125]
12 1355897.0 [-7251.3857421875, 10719.490234375, 5737.4365234375]
13 1198771.875 [-6929.63818359375, 9725.6962890625, 5064.2724609375]
14 1064955.0 [-6644.86083984375, 8851.4541015625, 4519.36279296875]
15 950065.5 [-6372.900390625, 807

183 290.15399169921875 [-36.55781173706055, 44.135494232177734, 46.39582824707031]
184 280.7472229003906 [-35.693031311035156, 43.397613525390625, 45.28382110595703]
185 271.66058349609375 [-34.764339447021484, 42.761436462402344, 44.101654052734375]
186 262.8811340332031 [-33.874027252197266, 42.069793701171875, 43.07806396484375]
187 254.3979034423828 [-33.05713653564453, 41.51272201538086, 41.93484878540039]
188 246.1983642578125 [-32.209110260009766, 40.890380859375, 40.92081069946289]
189 238.27879333496094 [-31.42056655883789, 40.23341751098633, 39.91129684448242]
190 230.63031005859375 [-30.638004302978516, 39.62453842163086, 38.92250061035156]
191 223.23924255371094 [-29.846479415893555, 39.09217071533203, 37.90645980834961]
192 216.08914184570312 [-29.165414810180664, 38.47395324707031, 37.04184341430664]
193 209.18133544921875 [-28.44709587097168, 37.910675048828125, 36.0563850402832]
194 202.5047149658203 [-27.751869201660156, 37.318058013916016, 35.24732971191406]
195 196.0

303 7.209768295288086 [-2.5959432125091553, 7.07023811340332, 2.5956380367279053]
304 7.0005035400390625 [-2.5460870265960693, 6.960391044616699, 2.5409178733825684]
305 6.798269271850586 [-2.4960155487060547, 6.867752552032471, 2.498365640640259]
306 6.601235866546631 [-2.4782392978668213, 6.777451515197754, 2.4844329357147217]
307 6.411051273345947 [-2.385404586791992, 6.647656440734863, 2.3919339179992676]
308 6.225625991821289 [-2.3631033897399902, 6.602870464324951, 2.36100697517395]
309 6.0461530685424805 [-2.3442625999450684, 6.433340072631836, 2.294483184814453]
310 5.871260166168213 [-2.277806043624878, 6.3679351806640625, 2.2662222385406494]
311 5.702021598815918 [-2.224031448364258, 6.241219520568848, 2.166627883911133]
312 5.537458419799805 [-2.1888322830200195, 6.168207168579102, 2.1674866676330566]
313 5.377825736999512 [-2.120651960372925, 6.081921577453613, 2.0902366638183594]
314 5.223066329956055 [-2.1017041206359863, 6.020215034484863, 2.0752511024475098]
315 5.07312

475 0.052974577993154526 [-0.1300024539232254, 0.4993228316307068, 0.15454301238059998]
476 0.051524873822927475 [-0.13930252194404602, 0.4776793122291565, 0.13939636945724487]
477 0.05009768158197403 [-0.14213064312934875, 0.4698408842086792, 0.11554893851280212]
478 0.04874010384082794 [-0.14911919832229614, 0.46668460965156555, 0.11704361438751221]
479 0.04740041866898537 [-0.15839076042175293, 0.4542801082134247, 0.11909297108650208]
480 0.04609513655304909 [-0.1575051248073578, 0.4729258716106415, 0.11978384852409363]
481 0.04484662041068077 [-0.14717556536197662, 0.46324336528778076, 0.1381816864013672]
482 0.043605539947748184 [-0.14108797907829285, 0.45161136984825134, 0.12020903825759888]
483 0.04239983484148979 [-0.13496387004852295, 0.46159037947654724, 0.11343953013420105]
484 0.04122699424624443 [-0.14587602019309998, 0.4342973530292511, 0.12777987122535706]
485 0.04010186344385147 [-0.1322351098060608, 0.4310513436794281, 0.12194488942623138]
486 0.039004985243082047 [-0.

## AutoGrad

In [11]:
set_seed(0)
device = torch.device('cpu')
# device = torch.device('cuda') # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs
x = torch.randn(N, D_in, device=device)
y = torch.randn(N, D_out, device=device)

# Create random Tensors for weights; setting requires_grad=True means that we
# want to compute gradients for these Tensors during the backward pass.
w1 = torch.randn(D_in, H, device=device, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, requires_grad=True)

learning_rate = 1e-6
loss_fn = torch.nn.MSELoss(reduction='sum') 

for t in range(500):
  # Forward pass: compute predicted y using operations on Tensors. Since w1 and
  # w2 have requires_grad=True, operations involving these Tensors will cause
  # PyTorch to build a computational graph, allowing automatic computation of
  # gradients. Since we are no longer implementing the backward pass by hand we
  # don't need to keep references to intermediate values.
  y_pred = x.mm(w1).clamp(min=0).mm(w2)
    
  # Compute and print loss. Loss is a Tensor of shape (), and loss.item()
  # is a Python number giving its value.
  loss = loss_fn(y_pred, y)

  # Use autograd to compute the backward pass. This call will compute the
  # gradient of loss with respect to all Tensors with requires_grad=True.
  # After this call w1.grad and w2.grad will be Tensors holding the gradient
  # of the loss with respect to w1 and w2 respectively.
  loss.backward()
    
  print(t, loss.item(), w2.grad[0,:3].tolist())

  # Update weights using gradient descent. For this step we just want to mutate
  # the values of w1 and w2 in-place; we don't want to build up a computational
  # graph for the update steps, so we use the torch.no_grad() context manager
  # to prevent PyTorch from building a computational graph for the updates
  with torch.no_grad():
    w1 -= learning_rate * w1.grad
    w2 -= learning_rate * w2.grad

    # Manually zero the gradients after running the backward pass
    w1.grad.zero_()
    w2.grad.zero_()

0 29428664.0 [-9305.4970703125, -153835.21875, 146966.40625]
1 22739450.0 [-27854.61328125, 389329.40625, -27661.73046875]
2 20605262.0 [-8343.005859375, -299725.59375, 108428.3828125]
3 19520382.0 [-28527.59375, 413785.84375, -49037.6796875]
4 17810230.0 [-8476.041015625, -337188.875, 77443.9375]
5 14999208.0 [-20768.025390625, 364682.8125, -35075.90625]
6 11483335.0 [-11058.28515625, -275198.875, 41033.2265625]
7 8096648.0 [-11576.8603515625, 257116.9375, -11200.5166015625]
8 5398717.0 [-13027.2900390625, -177206.796875, 13753.0712890625]
9 3521559.5 [-6166.935546875, 155551.015625, 3862.72119140625]
10 2315860.75 [-12848.0791015625, -100034.671875, 858.1062622070312]
11 1570272.875 [-3837.48095703125, 88219.5, 9417.345703125]
12 1111700.0 [-11415.5888671875, -53094.3515625, -3065.14208984375]
13 825062.875 [-3120.274169921875, 49372.10546875, 9827.544921875]
14 639684.5625 [-9664.62890625, -27214.619140625, -3110.80078125]
15 514220.40625 [-3022.580322265625, 27865.11328125, 8416.69

137 44.82472610473633 [-12.84365177154541, 23.92343521118164, 7.657754898071289]
138 42.33126449584961 [-12.339293479919434, 23.239879608154297, 7.155277252197266]
139 39.984073638916016 [-11.97021484375, 22.58727264404297, 6.792696952819824]
140 37.76753616333008 [-11.485427856445312, 22.012325286865234, 6.349021911621094]
141 35.67774200439453 [-11.170778274536133, 21.38927459716797, 5.966006278991699]
142 33.70558166503906 [-10.769586563110352, 20.90813446044922, 5.719681739807129]
143 31.845827102661133 [-10.454221725463867, 20.20360565185547, 5.517154693603516]
144 30.08890724182129 [-10.008511543273926, 19.7643985748291, 5.027897834777832]
145 28.43230628967285 [-9.744449615478516, 19.156110763549805, 4.863913059234619]
146 26.86838722229004 [-9.424986839294434, 18.716920852661133, 4.494365692138672]
147 25.391687393188477 [-9.091997146606445, 18.185760498046875, 4.292512893676758]
148 23.997304916381836 [-8.773167610168457, 17.69623565673828, 4.090745449066162]
149 22.6856155395

293 0.010062022134661674 [-0.16820906102657318, 0.2964915931224823, -0.002321600914001465]
294 0.00956580601632595 [-0.13674452900886536, 0.3144276738166809, 0.008333176374435425]
295 0.009096969850361347 [-0.157889723777771, 0.28191980719566345, -0.003982603549957275]
296 0.008644966408610344 [-0.1501692533493042, 0.32068929076194763, 0.024810418486595154]
297 0.008219253271818161 [-0.1658172905445099, 0.2795201539993286, 0.010682269930839539]
298 0.007824203930795193 [-0.1085706278681755, 0.2994142770767212, 0.027415543794631958]
299 0.007437130436301231 [-0.18695712089538574, 0.2267884463071823, 0.003833457827568054]
300 0.00706620654091239 [-0.1477024406194687, 0.2698102295398712, 0.017071247100830078]
301 0.0067255450412631035 [-0.18157002329826355, 0.21868184208869934, 0.004744350910186768]
302 0.006394234951585531 [-0.13619272410869598, 0.2673295736312866, 0.02948462963104248]
303 0.006079384591430426 [-0.1520506739616394, 0.19619783759117126, 0.01155775785446167]
304 0.00578629

387 0.00022185167472343892 [-0.02562292292714119, 0.05537642911076546, -0.007961045950651169]
388 0.0002147048944607377 [-0.027907783165574074, 0.035381823778152466, 0.010813392698764801]
389 0.0002095061936415732 [-0.032567813992500305, 0.06586732715368271, -0.006649453192949295]
390 0.00020361166389193386 [-0.015999548137187958, 0.03292552009224892, 0.02860778197646141]
391 0.0001977954088943079 [-0.02596951648592949, 0.05296607315540314, -0.010655317455530167]
392 0.00019300948770251125 [-0.01810530573129654, 0.049620673060417175, 0.033792536705732346]
393 0.0001883586373878643 [-0.018120817840099335, 0.05490904673933983, -0.002691984176635742]
394 0.00018324566190131009 [-0.02057313546538353, 0.04688652232289314, 0.02382834255695343]
395 0.00017954289796762168 [-0.0223260298371315, 0.05490701273083687, 0.008267730474472046]
396 0.00017467296856921166 [-0.033630404621362686, 0.0395556278526783, 0.013218965381383896]
397 0.0001704572350718081 [-0.011503898538649082, 0.064430519938468

482 3.675548214232549e-05 [-0.010206244885921478, 0.0279594324529171, -0.0026673590764403343]
483 3.630343417171389e-05 [-0.009003333747386932, 0.013808704912662506, 0.0027780579403042793]
484 3.5991124605061486e-05 [-0.013345684856176376, 0.024089088663458824, -0.0037104617804288864]
485 3.532845585141331e-05 [-0.015237238258123398, 0.0234688613563776, 0.0012816330417990685]
486 3.474869663477875e-05 [-0.014149971306324005, 0.023904630914330482, -0.003014274872839451]
487 3.4417429560562596e-05 [-0.01417732797563076, 0.01526755839586258, -0.006449089385569096]
488 3.3852396882139146e-05 [-0.01194061804562807, 0.014578443020582199, -0.0055290888994932175]
489 3.353159627295099e-05 [-0.008941861800849438, 0.029035016894340515, -0.008525657467544079]
490 3.320854011690244e-05 [-0.004592878744006157, 0.020323175936937332, -0.001531936228275299]
491 3.279506927356124e-05 [-0.0009937426075339317, 0.03062223084270954, 0.00010042078793048859]
492 3.232156814192422e-05 [-0.0028765536844730377,

## Accumlate Gradients

In [12]:
def batch_generator(tensor, bs):
    while len(tensor) != 0:
        yield tensor[:bs]
        tensor = tensor[bs::]

### Small graph

In [13]:
set_seed(0)
device = torch.device('cpu')
# device = torch.device('cuda') # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs
x_data = torch.randn(N, D_in, device=device)
y_data = torch.randn(N, D_out, device=device)

net_subdiv = 8


# Create random Tensors for weights; setting requires_grad=True means that we
# want to compute gradients for these Tensors during the backward pass.
w1 = torch.randn(D_in, H, device=device, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, requires_grad=True)

loss_fn = torch.nn.MSELoss(reduction='sum') 
learning_rate = 1e-6
loss_scalar = 0
loss = 0
for t in range(500):
  x_dl = batch_generator(x_data, int(N / net_subdiv))
  y_dl = batch_generator(y_data, int(N / net_subdiv))
  for i, (x,y) in enumerate(zip(x_dl, y_dl)):
      # Forward pass: compute predicted y using operations on Tensors. Since w1 and
      # w2 have requires_grad=True, operations involving these Tensors will cause
      # PyTorch to build a computational graph, allowing automatic computation of
      # gradients. Since we are no longer implementing the backward pass by hand we
      # don't need to keep references to intermediate values.
      y_pred = x.mm(w1).clamp(min=0).mm(w2)

      
      # Compute and print loss. Loss is a Tensor of shape (), and loss.item()
      # is a Python number giving its value.
      loss = loss_fn(y_pred, y)
      loss.backward()
              
      loss_scalar += loss.item()     

      # Update for every 8 subdivisions --> 8 subdiv x 8 data pt per batch = 64 data pt
      if ((i+1) % net_subdiv == 0):
          print(t, loss_scalar, w2.grad[0,:3].tolist())
          with torch.no_grad():
            w1 -= learning_rate * w1.grad
            w2 -= learning_rate * w2.grad

            # Manually zero the gradients after running the backward pass
            w1.grad.zero_()
            w2.grad.zero_()
            loss_scalar = 0

0 29428662.5 [-9305.494140625, -153835.203125, 146966.40625]
1 22739450.125 [-27854.6015625, 389329.375, -27661.734375]
2 20605262.625 [-8343.0302734375, -299725.625, 108428.390625]
3 19520380.625 [-28527.5703125, 413785.9375, -49037.67578125]
4 17810229.625 [-8476.076171875, -337188.84375, 77443.9296875]
5 14999207.0 [-20767.98828125, 364682.8125, -35075.92578125]
6 11483334.625 [-11058.296875, -275198.875, 41033.2265625]
7 8096647.6875 [-11576.8798828125, 257116.953125, -11200.5068359375]
8 5398716.6875 [-13027.2529296875, -177206.734375, 13753.0703125]
9 3521559.40625 [-6166.9462890625, 155550.96875, 3862.71044921875]
10 2315860.84375 [-12848.0859375, -100034.65625, 858.0849609375]
11 1570272.953125 [-3837.461181640625, 88219.5, 9417.3896484375]
12 1111700.171875 [-11415.603515625, -53094.34375, -3065.16259765625]
13 825062.9609375 [-3120.264404296875, 49372.12109375, 9827.537109375]
14 639684.6015625 [-9664.626953125, -27214.62890625, -3110.826171875]
15 514220.51953125 [-3022.5769

122 106.89443254470825 [-22.234935760498047, 35.88512420654297, 17.615320205688477]
123 100.8066143989563 [-21.465633392333984, 35.015838623046875, 16.578521728515625]
124 95.07434272766113 [-20.590007781982422, 33.99181365966797, 15.795533180236816]
125 89.67505836486816 [-19.812612533569336, 33.15281677246094, 14.907546997070312]
126 84.59376192092896 [-19.20404815673828, 32.176971435546875, 14.071706771850586]
127 79.81018590927124 [-18.42662239074707, 31.422502517700195, 13.287269592285156]
128 75.30274367332458 [-17.745018005371094, 30.544645309448242, 12.57802677154541]
129 71.05760455131531 [-17.095138549804688, 29.776100158691406, 12.051403999328613]
130 67.0622193813324 [-16.474424362182617, 28.897563934326172, 11.195901870727539]
131 63.29185080528259 [-15.897478103637695, 28.144001007080078, 10.667259216308594]
132 59.74002814292908 [-15.30720043182373, 27.332443237304688, 9.986013412475586]
133 56.39441728591919 [-14.858638763427734, 26.655345916748047, 9.51531982421875]
13

234 0.221977305598557 [-0.7204787731170654, 1.6515882015228271, 0.09141212701797485]
235 0.2104644374921918 [-0.7820798754692078, 1.5400588512420654, 0.01391148567199707]
236 0.19953362364321947 [-0.6157040596008301, 1.5364259481430054, 0.07717108726501465]
237 0.18919373583048582 [-0.6869995594024658, 1.4531536102294922, 0.043710947036743164]
238 0.1794134508818388 [-0.5722784996032715, 1.4276502132415771, 0.002635776996612549]
239 0.17010038997977972 [-0.5912312865257263, 1.37339448928833, -0.000997781753540039]
240 0.16136763244867325 [-0.552939772605896, 1.354780673980713, -0.03560584783554077]
241 0.1530034840106964 [-0.5596973896026611, 1.3435747623443604, 0.016857624053955078]
242 0.1450972631573677 [-0.5498310327529907, 1.32014799118042, 0.03367882966995239]
243 0.13759103883057833 [-0.4754636287689209, 1.2456920146942139, 0.051730990409851074]
244 0.13048065034672618 [-0.517414391040802, 1.1855247020721436, 0.010033607482910156]
245 0.12370654474943876 [-0.49982357025146484, 1

351 0.0007157662657846231 [-0.023859329521656036, 0.07858918607234955, 0.02840738743543625]
352 0.0006904128749738447 [-0.04484705999493599, 0.06982196122407913, 0.023365776985883713]
353 0.000664517156110378 [-0.024579815566539764, 0.0871773213148117, 0.017612792551517487]
354 0.0006413689734472428 [-0.03253001719713211, 0.06135464832186699, 0.03015272319316864]
355 0.0006156679592095315 [-0.032006945461034775, 0.08827189356088638, 0.014213815331459045]
356 0.0005938266294833738 [-0.019028140231966972, 0.056676749140024185, 0.008631870150566101]
357 0.0005734284550271695 [-0.03495260328054428, 0.08281657099723816, 0.027370132505893707]
358 0.0005535410600714386 [-0.004035456106066704, 0.05673401057720184, -0.005873866379261017]
359 0.0005344499732018448 [-0.047507062554359436, 0.07791036367416382, 0.017070475965738297]
360 0.0005168565548956394 [0.00241694413125515, 0.0474836602807045, -0.015040621161460876]
361 0.0004998777567379875 [-0.023943252861499786, 0.08090531826019287, -0.007

454 5.4359602472686674e-05 [-0.0026387423276901245, 0.0437348373234272, 9.801238775253296e-05]
455 5.32227027179033e-05 [-0.0036278944462537766, 0.02437191642820835, -0.006146136671304703]
456 5.269048324407777e-05 [0.000464417040348053, 0.02948753535747528, 0.003818703815340996]
457 5.2305651934148045e-05 [-0.002530043013393879, 0.024083640426397324, -0.001468723639845848]
458 5.124415883983602e-05 [-0.000203598290681839, 0.03664162755012512, -0.008334625512361526]
459 5.038397375756176e-05 [-0.007386712357401848, 0.018973400816321373, -0.0025221146643161774]
460 4.9486309308122145e-05 [-0.0028668828308582306, 0.02939847856760025, -0.0053486451506614685]
461 4.868464588980714e-05 [-0.0015364410355687141, 0.027363311499357224, -0.003122655674815178]
462 4.786598015016352e-05 [-0.008806722238659859, 0.031130671501159668, -0.004776248708367348]
463 4.72533433821809e-05 [-0.0021360553801059723, 0.025822002440690994, -0.0019983649253845215]
464 4.6545337227144046e-05 [-0.003504477441310882

### Huge Graph and memory consumption

In [14]:
set_seed(0)
device = torch.device('cpu')
# device = torch.device('cuda') # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs
x_data = torch.randn(N, D_in, device=device)
y_data = torch.randn(N, D_out, device=device)

net_subdiv = 8


# Create random Tensors for weights; setting requires_grad=True means that we
# want to compute gradients for these Tensors during the backward pass.
w1 = torch.randn(D_in, H, device=device, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, requires_grad=True)

loss_fn = torch.nn.MSELoss(reduction='sum') 
learning_rate = 1e-6
loss_scalar = 0
loss = 0
for t in range(500):
  x_dl = batch_generator(x_data, int(N / net_subdiv))
  y_dl = batch_generator(y_data, int(N / net_subdiv))
  for i, (x,y) in enumerate(zip(x_dl, y_dl)):
      # Forward pass: compute predicted y using operations on Tensors. Since w1 and
      # w2 have requires_grad=True, operations involving these Tensors will cause
      # PyTorch to build a computational graph, allowing automatic computation of
      # gradients. Since we are no longer implementing the backward pass by hand we
      # don't need to keep references to intermediate values.
      y_pred = x.mm(w1).clamp(min=0).mm(w2)

      
      # Compute and print loss. Loss is a Tensor of shape (), and loss.item()
      # is a Python number giving its value.
      loss += loss_fn(y_pred, y)
        
#       loss_scalar += loss.item()



      # Use autograd to compute the backward pass. This call will compute the
      # gradient of loss with respect to all Tensors with requires_grad=True.
      # After this call w1.grad and w2.grad will be Tensors holding the gradient
      # of the loss with respect to w1 and w2 respectively.
#       loss.backward()

#       print(t, loss_scalar, w2.grad[0,:3].tolist())

      # Update weights using gradient descent. For this step we just want to mutate
      # the values of w1 and w2 in-place; we don't want to build up a computational
      # graph for the update steps, so we use the torch.no_grad() context manager
      # to prevent PyTorch from building a computational graph for the updates
#       print((i+1) % net_subdiv)
      if ((i+1) % net_subdiv == 0):
#           print('{} call backward'.format(i))
          loss.backward()
          print(t, loss.item(), w2.grad[0,:3].tolist())
          with torch.no_grad():
            w1 -= learning_rate * w1.grad
            w2 -= learning_rate * w2.grad

            # Manually zero the gradients after running the backward pass
            w1.grad.zero_()
            w2.grad.zero_()
            loss = 0

0 29428660.0 [-9305.49609375, -153835.203125, 146966.40625]
1 22739452.0 [-27854.609375, 389329.375, -27661.732421875]
2 20605264.0 [-8342.98828125, -299725.65625, 108428.3515625]
3 19520382.0 [-28527.599609375, 413785.9375, -49037.6640625]
4 17810230.0 [-8476.037109375, -337188.84375, 77443.9453125]
5 14999208.0 [-20768.048828125, 364682.8125, -35075.91015625]
6 11483336.0 [-11058.2412109375, -275198.875, 41033.21484375]
7 8096649.0 [-11576.90234375, 257116.96875, -11200.521484375]
8 5398717.0 [-13027.2470703125, -177206.8125, 13753.0576171875]
9 3521559.5 [-6166.95556640625, 155551.015625, 3862.72021484375]
10 2315860.75 [-12848.048828125, -100034.7109375, 858.114990234375]
11 1570272.875 [-3837.494873046875, 88219.5390625, 9417.373046875]
12 1111700.125 [-11415.5830078125, -53094.37109375, -3065.139404296875]
13 825062.875 [-3120.298583984375, 49372.13671875, 9827.5244140625]
14 639684.5625 [-9664.6201171875, -27214.666015625, -3110.81201171875]
15 514220.5 [-3022.58203125, 27865.13

130 67.06112670898438 [-16.497737884521484, 28.850326538085938, 11.187420845031738]
131 63.292110443115234 [-15.906176567077637, 28.15760040283203, 10.620555877685547]
132 59.740028381347656 [-15.296890258789062, 27.332563400268555, 10.055136680603027]
133 56.39472961425781 [-14.894037246704102, 26.62440299987793, 9.493632316589355]
134 53.24454116821289 [-14.269472122192383, 25.931182861328125, 9.036550521850586]
135 50.26869201660156 [-13.792027473449707, 25.18771743774414, 8.450213432312012]
136 47.467681884765625 [-13.2879638671875, 24.561756134033203, 8.002201080322266]
137 44.82502365112305 [-12.90540599822998, 23.90290641784668, 7.668695449829102]
138 42.33099365234375 [-12.314541816711426, 23.238245010375977, 7.0969696044921875]
139 39.983673095703125 [-11.966636657714844, 22.59593963623047, 6.8670148849487305]
140 37.7686767578125 [-11.490520477294922, 22.022235870361328, 6.271659851074219]
141 35.677528381347656 [-11.177983283996582, 21.364696502685547, 6.097528457641602]
142

246 0.11736037582159042 [-0.5452367663383484, 1.1723713874816895, 0.06261301040649414]
247 0.11127562075853348 [-0.4625897705554962, 1.1108810901641846, 0.05357524752616882]
248 0.10554354637861252 [-0.478925883769989, 1.086905598640442, 0.06570899486541748]
249 0.10007624328136444 [-0.440180242061615, 1.0978866815567017, -0.01884511113166809]
250 0.09496058523654938 [-0.46213212609291077, 1.004744052886963, 0.09188888967037201]
251 0.09010285884141922 [-0.42758938670158386, 0.9986081123352051, 0.07143145799636841]
252 0.08547645807266235 [-0.40651318430900574, 0.9894540905952454, 0.04807758331298828]
253 0.08104435354471207 [-0.3834814131259918, 0.9409458637237549, -0.0067027658224105835]
254 0.07687924057245255 [-0.38620200753211975, 0.8883283138275146, 0.01753363013267517]
255 0.07293808460235596 [-0.40059733390808105, 0.8803210258483887, 0.03474313020706177]
256 0.06914795190095901 [-0.3543255925178528, 0.842383623123169, -0.004376031458377838]
257 0.06562826782464981 [-0.407565236

347 0.0008374816388823092 [-0.03741130605340004, 0.10937657952308655, 0.015134800225496292]
348 0.0008058507810346782 [-0.008497720584273338, 0.07720434665679932, 0.014086637645959854]
349 0.0007754936814308167 [-0.035523757338523865, 0.1264377236366272, 0.02528982050716877]
350 0.0007478738552890718 [-0.0071055348962545395, 0.06943398714065552, 0.00854765996336937]
351 0.0007171310717239976 [-0.018345847725868225, 0.09789396077394485, 0.021555066108703613]
352 0.0006909200455993414 [-0.008837169036269188, 0.08575333654880524, 0.01864100620150566]
353 0.000666856940370053 [-0.044064708054065704, 0.0708199292421341, 0.01637352630496025]
354 0.0006427427288144827 [0.00635418388992548, 0.0947391539812088, 0.01717233471572399]
355 0.0006180016789585352 [-0.03325650840997696, 0.0777435153722763, 0.033930446952581406]
356 0.000597791513428092 [-0.011905724182724953, 0.07512092590332031, 0.011970733292400837]
357 0.0005755414022132754 [-0.012660293839871883, 0.06270140409469604, 0.00821595918

447 6.0886319261044264e-05 [-0.0042286217212677, 0.02090591751039028, -0.027572695165872574]
448 6.010740980855189e-05 [-0.012101350352168083, 0.05201668292284012, -0.006080204155296087]
449 5.893987326999195e-05 [-0.011201266199350357, 0.04199391230940819, -0.005394276697188616]
450 5.8001129218610004e-05 [-0.00940634123980999, 0.045266926288604736, -0.010282228700816631]
451 5.7230234233429655e-05 [-0.006836654152721167, 0.03326328843832016, -0.014170914888381958]
452 5.611560118268244e-05 [-0.010723065584897995, 0.04722491279244423, -0.007202731911092997]
453 5.4866555728949606e-05 [-0.008380641229450703, 0.0205314289778471, -0.004266553092747927]
454 5.4326668760040775e-05 [-0.0077810161747038364, 0.041754335165023804, -0.004505848977714777]
455 5.30112738488242e-05 [-0.017177078872919083, 0.022974830120801926, -0.007120747584849596]
456 5.232503463048488e-05 [-0.008803716860711575, 0.04197540506720543, -0.0026329285465180874]
457 5.16133222845383e-05 [-0.011869567446410656, 0.0399