# PyTorch: Defining new autograd functions

In PyTorch we can easily define our own autograd operator by defining a subclass of `torch.autograd.Function` and implementing its static `forward` and `backward` methods. We can then use our new autograd operator by calling the subclass's `apply` method like a function, passing Tensors containing input data.



In [5]:
import torch

class MyReLU(torch.autograd.Function):
    """
    Custom ReLU autograd operator.

    Subclasses torch.autograd.Function and supplies static forward and
    backward passes that operate on Tensors. Invoke it through
    ``MyReLU.apply`` rather than by calling ``forward`` directly.
    """

    @staticmethod
    def forward(ctx, input):
        """
        Forward pass: receive a Tensor containing the input and return a
        Tensor containing the output, i.e. the input clamped below at 0.

        ctx is a context object used to stash information for the backward
        computation; the input Tensor is stored on it here with
        ctx.save_for_backward so that backward can read it back.
        """
        ctx.save_for_backward(input)
        return torch.clamp(input, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass: receive a Tensor containing the gradient of the loss
        with respect to the output and return the gradient of the loss with
        respect to the input.

        The incoming gradient is passed through unchanged except at positions
        where the saved input was strictly negative, which are set to 0.
        """
        (saved_input,) = ctx.saved_tensors
        # Out-of-place fill: grad_output itself is never modified.
        return grad_output.masked_fill(saved_input < 0, 0)
    
dtype = torch.float
device = torch.device("cpu")

batch_size = 64
input_dimension = 1000
hidden_dimension = 100
output_dimension = 10

# Generate random input and target data
x = torch.randn(batch_size, input_dimension, device=device, dtype=dtype)
y = torch.randn(batch_size, output_dimension, device=device, dtype=dtype)

# Initialize random weights; requires_grad=True makes autograd track them
weight1 = torch.randn(input_dimension, hidden_dimension, device=device, dtype=dtype, requires_grad=True)
weight2 = torch.randn(hidden_dimension, output_dimension, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6

# To apply our Function, we use the Function.apply method.
# We alias this as 'relu'. The alias is the same on every iteration,
# so it is bound once here instead of inside the loop.
relu = MyReLU.apply

for n in range(500):
    # Forward pass: compute predicted y using operations on Tensors;
    # the ReLU is computed with our custom autograd operation
    y_pred = relu(x.mm(weight1)).mm(weight2)

    # Compute and print the loss (sum of squared errors)
    loss = (y_pred - y).pow(2).sum()
    print(n, loss.item())

    # Use autograd to compute the backward pass
    loss.backward()

    # Update weights using gradient descent; no_grad keeps the in-place
    # updates from being tracked by autograd
    with torch.no_grad():
        weight1 -= learning_rate * weight1.grad
        weight2 -= learning_rate * weight2.grad

        # Manually zero the gradients after updating weights, since
        # backward() accumulates into .grad
        weight1.grad.zero_()
        weight2.grad.zero_()
        
    
    

0 32772474.0
1 34797584.0
2 39464240.0
3 39190744.0
4 29969980.0
5 16956194.0
6 7743437.5
7 3491127.75
8 1857910.625
9 1215071.25
10 914041.3125
11 737843.3125
12 615360.6875
13 521814.65625
14 446892.0
15 385454.125
16 334393.875
17 291574.34375
18 255395.59375
19 224658.6875
20 198357.734375
21 175747.609375
22 156214.234375
23 139268.28125
24 124476.359375
25 111529.8828125
26 100175.1015625
27 90187.5703125
28 81374.8125
29 73577.140625
30 66653.5234375
31 60481.078125
32 54966.046875
33 50025.23828125
34 45598.41796875
35 41621.765625
36 38040.91015625
37 34811.95703125
38 31891.271484375
39 29246.474609375
40 26850.9453125
41 24673.46875
42 22692.921875
43 20889.822265625
44 19245.728515625
45 17745.619140625
46 16375.1259765625
47 15122.6455078125
48 13975.779296875
49 12924.609375
50 11960.6083984375
51 11076.25390625
52 10263.673828125
53 9516.46484375
54 8828.9951171875
55 8196.22265625
56 7612.61865234375
57 7074.4189453125
58 6577.650390625
59 6118.8837890625
60 5695.527343

418 0.00019025443179998547
419 0.0001861627824837342
420 0.00018186139641329646
421 0.00017762243805918843
422 0.00017407358973287046
423 0.0001700558204902336
424 0.0001661398564465344
425 0.00016269586922135204
426 0.00015920177975203842
427 0.0001555293711135164
428 0.00015228443953674287
429 0.00014878281217534095
430 0.0001455090387025848
431 0.0001430082629667595
432 0.0001393901911797002
433 0.00013678503455594182
434 0.00013393785047810525
435 0.0001310396910412237
436 0.00012833144864998758
437 0.000125556078273803
438 0.0001234533847309649
439 0.00012115157005609944
440 0.00011887362052220851
441 0.00011667267244774848
442 0.00011443143739597872
443 0.00011235303099965677
444 0.00010994490003213286
445 0.00010822820331668481
446 0.0001065387186827138
447 0.00010415763244964182
448 0.00010214298526989296
449 0.00010056769679067656
450 9.865203173831105e-05
451 9.679447975941002e-05
452 9.508022776572034e-05
453 9.35966891120188e-05
454 9.176741150440648e-05
455 9.0350564278196