In [1]:
%matplotlib inline

In [14]:
# This cell sits above the notebook's own `import torch` cell, so import here
# to keep it runnable top-to-bottom on a fresh kernel.
import torch

# Show the cuDNN version visible to PyTorch.
torch.backends.cudnn.version()

7005


PyTorch: Defining new autograd functions
----------------------------------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation computes the forward pass using operations on PyTorch
Tensors, and uses PyTorch autograd to compute gradients.

Here we define our own custom autograd function to compute the ReLU
nonlinearity, including its backward pass.



In [3]:
import torch

class MyReLu(torch.autograd.Function):
    """Custom autograd Function computing ReLU with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, input):
        """Clamp negatives in ``input`` to zero.

        ``ctx`` is the context object; the input tensor is saved on it so that
        ``backward`` can tell which elements were negative.
        """
        ctx.save_for_backward(input)
        return torch.clamp(input, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """Return a copy of ``grad_output`` zeroed wherever the saved input was negative."""
        (saved_input,) = ctx.saved_tensors
        was_negative = saved_input < 0
        # masked_fill is out-of-place, so grad_output itself is never modified.
        return grad_output.masked_fill(was_negative, 0)

    

In [4]:
# Element type used for every tensor in this notebook.
dtype = torch.float
# Use the first GPU when CUDA is usable; otherwise fall back to the CPU so the
# notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
# N: batch size, D_in: input dimension, H: hidden dimension, D_out: output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Random inputs and targets; these are data, so no gradients are tracked.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Random initial weights. requires_grad=True is required so that
# loss.backward() populates w1.grad / w2.grad, which the training loop reads;
# without it backward() fails because nothing in the graph needs a gradient.
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)




In [6]:
# Step size: each update subtracts learning_rate * gradient from the weights.
learning_rate = 1e-6

In [12]:
# A custom autograd Function is invoked through its .apply method; bind it once.
relu = MyReLu.apply

for t in range(500):
    # Forward pass: linear -> custom ReLU -> linear.
    hidden = relu(x.mm(w1))
    y_pred = hidden.mm(w2)

    # Sum of squared errors between prediction and target.
    residual = y_pred - y
    loss = residual.pow(2).sum()
    print(t, loss.item())

    # Backpropagate to fill w1.grad and w2.grad.
    loss.backward()

    # Gradient-descent step, performed outside autograd tracking.
    with torch.no_grad():
        w1.sub_(learning_rate * w1.grad)
        w2.sub_(learning_rate * w2.grad)
        # Clear the accumulated gradients before the next iteration.
        w1.grad.zero_()
        w2.grad.zero_()

0 4.9764745199354365e-05
1 4.5935888920212165e-05
2 4.524800897343084e-05
3 4.464765515876934e-05
4 4.394744973978959e-05
5 4.3624462705338374e-05
6 4.3103900679852813e-05
7 4.2705050873337314e-05
8 4.243510193191469e-05
9 4.200885450700298e-05
10 4.152847395744175e-05
11 4.0935174183687195e-05
12 4.052142321597785e-05
13 4.001589331892319e-05
14 3.9619240851607174e-05
15 3.9311242289841175e-05
16 3.878085044561885e-05
17 3.846123581752181e-05
18 3.802759601967409e-05
19 3.7658137443941087e-05
20 3.7243640690576285e-05
21 3.670294609037228e-05
22 3.641609873739071e-05
23 3.598678813432343e-05
24 3.571982597350143e-05
25 3.5376531741349027e-05
26 3.500149614410475e-05
27 3.465908957878128e-05
28 3.4532775316620246e-05
29 3.409846976865083e-05
30 3.365150769241154e-05
31 3.3303451346000656e-05
32 3.305594509583898e-05
33 3.2697342248866335e-05
34 3.251175076002255e-05
35 3.224558895453811e-05
36 3.1921434128889814e-05
37 3.173962977598421e-05
38 3.164495137752965e-05
39 3.12862939608749e

341 6.225112429092405e-06
342 6.18920057604555e-06
343 6.137946456874488e-06
344 6.155156370368786e-06
345 6.143699465610553e-06
346 6.099370239098789e-06
347 6.085968834668165e-06
348 6.065021807444282e-06
349 6.074290922697401e-06
350 6.071797088225139e-06
351 6.048549039405771e-06
352 6.0642823882517405e-06
353 6.052992830518633e-06
354 6.060552095732419e-06
355 5.999417680868646e-06
356 5.996284471621038e-06
357 5.962559953331947e-06
358 5.939471975580091e-06
359 5.945589691691566e-06
360 5.917326234339271e-06
361 5.898187282582512e-06
362 5.868364041816676e-06
363 5.850581146660261e-06
364 5.81329140914022e-06
365 5.795682227471843e-06
366 5.779948423878523e-06
367 5.737366791436216e-06
368 5.721458364860155e-06
369 5.729511030949652e-06
370 5.674793555954238e-06
371 5.655017957906239e-06
372 5.640689778374508e-06
373 5.640156359731918e-06
374 5.6094340834533796e-06
375 5.607092589343665e-06
376 5.5680234254396055e-06
377 5.552483344217762e-06
378 5.528969268198125e-06
379 5.50698

In [9]:
import torch


class MyReLU(torch.autograd.Function):
    """
    A custom autograd Function: subclass torch.autograd.Function and provide
    static forward and backward methods that operate on Tensors.
    """

    @staticmethod
    def forward(ctx, input):
        """
        Receive the input Tensor and return the output Tensor (negatives
        clamped to zero). ctx is a context object used to pass information to
        the backward pass; here the input Tensor is stored on it with
        ctx.save_for_backward.
        """
        ctx.save_for_backward(input)
        return torch.clamp(input, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Receive the gradient of the loss with respect to the output and return
        the gradient of the loss with respect to the input: grad_output where
        the saved input was not negative, zero where it was.
        """
        (saved_input,) = ctx.saved_tensors
        # torch.where builds a new tensor, so grad_output is left untouched.
        return torch.where(saved_input < 0, torch.zeros_like(grad_output), grad_output)


# Precision and placement shared by every tensor below.
dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# Problem dimensions.
N = 64        # batch size
D_in = 1000   # input dimension
H = 100       # hidden dimension
D_out = 10    # output dimension

# Random input and target Tensors (no gradient tracking needed for data).
x = torch.randn((N, D_in), device=device, dtype=dtype)
y = torch.randn((N, D_out), device=device, dtype=dtype)

# Random weight Tensors, flagged so autograd records gradients for them.
w1 = torch.randn((D_in, H), device=device, dtype=dtype).requires_grad_(True)
w2 = torch.randn((H, D_out), device=device, dtype=dtype).requires_grad_(True)

# Step size for the gradient-descent updates below.
learning_rate = 1e-6

# A custom Function is applied through Function.apply; alias it as 'relu' once.
relu = MyReLU.apply

for t in range(500):
    # Forward pass: predicted y from x, with our custom autograd ReLU between
    # the two matrix multiplies.
    y_pred = torch.mm(relu(torch.mm(x, w1)), w2)

    # Squared-error loss, printed every iteration.
    loss = torch.sum((y_pred - y).pow(2))
    print(t, loss.item())

    # Let autograd compute the gradients of loss w.r.t. w1 and w2.
    loss.backward()

    # Gradient-descent update; no_grad keeps the update out of the graph.
    with torch.no_grad():
        w1.sub_(learning_rate * w1.grad)
        w2.sub_(learning_rate * w2.grad)

        # Gradients accumulate across backward() calls, so zero them by hand.
        w1.grad.zero_()
        w2.grad.zero_()

0 42173552.0
1 43682488.0
2 45808468.0
3 39609448.0
4 25287888.0
5 12328001.0
6 5464715.0
7 2758648.5
8 1721592.875
9 1257839.0
10 996414.125
11 819235.6875
12 685961.75
13 580758.75
14 495726.90625
15 425791.5625
16 367733.5625
17 319178.90625
18 278294.90625
19 243635.578125
20 214081.609375
21 188807.125
22 167047.15625
23 148227.53125
24 131873.984375
25 117616.890625
26 105149.90625
27 94210.5234375
28 84572.6953125
29 76063.1171875
30 68529.265625
31 61842.9375
32 55894.54296875
33 50594.42578125
34 45870.5625
35 41641.6171875
36 37848.42578125
37 34440.28125
38 31373.376953125
39 28609.373046875
40 26114.857421875
41 23862.6796875
42 21824.97265625
43 19978.060546875
44 18302.728515625
45 16781.451171875
46 15398.1552734375
47 14139.3564453125
48 12992.3349609375
49 11946.6083984375
50 10992.302734375
51 10121.1279296875
52 9324.990234375
53 8596.75
54 7929.87451171875
55 7318.7802734375
56 6758.5751953125
57 6245.44921875
58 5774.2998046875
59 5341.39111328125
60 4943.391601562

389 0.0003833176742773503
390 0.0003735351492650807
391 0.00036407707375474274
392 0.0003546085499692708
393 0.0003453810350038111
394 0.00033786287531256676
395 0.00032963600824587047
396 0.0003212891169823706
397 0.00031433856929652393
398 0.0003062333562411368
399 0.0002991486981045455
400 0.0002917553938459605
401 0.00028528011171147227
402 0.0002780233626253903
403 0.0002716136514209211
404 0.0002654274576343596
405 0.00025929315597750247
406 0.0002541420399211347
407 0.00024775281781330705
408 0.00024241516075562686
409 0.00023630379291716963
410 0.00023138510005082935
411 0.00022669101599603891
412 0.00022202625405043364
413 0.00021744590776506811
414 0.0002126364124706015
415 0.00020852680609095842
416 0.000203644871362485
417 0.0001994050689972937
418 0.00019472902931738645
419 0.00019021061598323286
420 0.00018620635091792792
421 0.00018253407324664295
422 0.00017943288548849523
423 0.00017564449808560312
424 0.00017196011322084814
425 0.00016905255324672908
426 0.00016592560