In [33]:
import torch

In [34]:
a = torch.randn(4)
print('input', a)
grad = torch.randn(4)
print('grad', grad)
grad[ a < 0 ] = 0
print('grad', grad)

input tensor([-1.0301, -1.0086, -0.2909,  0.0084])
grad tensor([ 0.4880, -0.6135,  0.1897, -0.2371])
grad tensor([ 0.0000,  0.0000,  0.0000, -0.2371])


In [35]:
class MyRelu(torch.autograd.Function):
    
    @staticmethod
    def forward(ctx, input):
        output = input.clamp(min=0)
        ctx.save_for_backward(output)
        return output
    
    @staticmethod
    def backward(ctx, grad_output):
        output, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[output < 0] = 0
        return grad_input        
        

In [36]:
dtype = torch.float
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

cpu


In [37]:
N, D_in, H, D_out = 64, 1000, 100, 10

x = torch.randn( N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad = True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad = True)



In [38]:
lr = 1e-6
for i in range(500):
    relu1 = MyRelu.apply
    
    y_pred = relu1(x.mm(w1)).mm(w2)
    loss = (y_pred - y).pow(2).sum()
    
    print(i, loss.item())
    
    loss.backward()
    
    with torch.no_grad():
        w1 -= lr * w1.grad
        w2 -= lr * w2.grad
        
        w1.grad.zero_()
        w2.grad.zero_()

0 36747940.0
1 36226480.0
2 40208152.0
3 40451416.0
4 32164512.0
5 19270864.0
6 9336520.0
7 4363877.5
8 2323745.0
9 1489071.75
10 1096238.625
11 872355.375
12 721807.75
13 609185.375
14 519820.625
15 447153.5625
16 386999.09375
17 336756.625
18 294180.875
19 257993.453125
20 227197.375
21 200820.34375
22 178078.546875
23 158445.90625
24 141415.9375
25 126562.765625
26 113522.78125
27 102093.28125
28 92010.171875
29 83060.0390625
30 75120.359375
31 68069.375
32 61769.9453125
33 56139.8125
34 51119.4921875
35 46627.89453125
36 42594.796875
37 38963.6171875
38 35693.93359375
39 32737.5234375
40 30063.61328125
41 27644.451171875
42 25452.45703125
43 23462.95703125
44 21654.13671875
45 20006.134765625
46 18481.9140625
47 17080.396484375
48 15801.5498046875
49 14632.8330078125
50 13561.3876953125
51 12578.6220703125
52 11674.337890625
53 10841.7783203125
54 10076.19140625
55 9371.197265625
56 8721.3291015625
57 8121.89404296875
58 7568.52197265625
59 7057.4521484375
60 6584.37451171875
61 61

409 0.00026825262466445565
410 0.00025996402837336063
411 0.0002511999337002635
412 0.0002442056138534099
413 0.0002371846785536036
414 0.00022989641001913697
415 0.00022344247554428875
416 0.00021684386592824012
417 0.0002104209561366588
418 0.00020461369422264397
419 0.00019941432401537895
420 0.00019319963757880032
421 0.00018819409888237715
422 0.00018294586334377527
423 0.0001780166639946401
424 0.0001733440294628963
425 0.00016904539370443672
426 0.00016404119378421456
427 0.00015976968279574066
428 0.00015564337081741542
429 0.00015162718773353845
430 0.00014762389764655381
431 0.0001444871595595032
432 0.00014097474922891706
433 0.00013705431774724275
434 0.00013406641664914787
435 0.00013098832278046757
436 0.00012727061402983963
437 0.00012407182657625526
438 0.00012062744644936174
439 0.00011836187331937253
440 0.0001156554208137095
441 0.00011307938257232308
442 0.00011113528307760134
443 0.00010801759344758466
444 0.00010597593791317195
445 0.00010331725206924602
446 0.000