In [20]:
import torch


class MyReLU(torch.autograd.Function):
    """ReLU activation written as a custom autograd ``Function``.

    Call it through ``MyReLU.apply(tensor)``; both passes are static
    methods that exchange state through the ``ctx`` object.
    """

    @staticmethod
    def forward(ctx, input):
        """Return ``input`` with every negative entry replaced by zero.

        The unmodified ``input`` is stored on ``ctx`` because the backward
        pass needs to know which entries were negative.
        """
        ctx.save_for_backward(input)
        return torch.clamp(input, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient of the loss with respect to ``input``.

        The incoming gradient passes through unchanged except at positions
        where the saved input was strictly negative, which are set to zero.
        """
        (saved_input,) = ctx.saved_tensors
        # masked_fill is out-of-place, so grad_output itself is never mutated.
        return grad_output.masked_fill(saved_input < 0, 0)
    
class MyNet(torch.nn.Module):
    """Two-layer network: Linear -> MyReLU -> Linear."""

    def __init__(self, D_in, H, D_out):
        """Create the two linear layers.

        Args:
            D_in: number of input features of the first layer.
            H: width of the hidden layer.
            D_out: number of output features of the second layer.
        """
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """Return the network's prediction for ``x``."""
        # A custom autograd Function is invoked through its ``apply`` method.
        hidden = MyReLU.apply(self.linear1(x))
        return self.linear2(hidden)

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# N: batch size, D_in: input width, H: hidden width, D_out: output width.
N, D_in, H, D_out = 64, 1000, 100, 10
# Random inputs and targets; the loop below fits the model to this one batch.
x = torch.randn(N, D_in).to(device)
y = torch.randn(N, D_out).to(device)
model = MyNet(D_in, H, D_out).to(device)

# Summed (not averaged) squared error. ``reduction="sum"`` is the supported
# replacement for the deprecated ``size_average=False`` argument.
criterion = torch.nn.MSELoss(reduction="sum")
# Alias preserving the original (misspelled) name in case other cells use it.
ceritrion = criterion
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in range(500):
    y_pred = model(x)

    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Gradients accumulate across backward() calls, so clear them first.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 621.9244384765625
1 575.5208129882812
2 535.8607177734375
3 500.95867919921875
4 469.84149169921875
5 441.9232177734375
6 416.41656494140625
7 392.99151611328125
8 371.4518127441406
9 351.2593078613281
10 332.28765869140625
11 314.4669189453125
12 297.7701416015625
13 282.0753173828125
14 267.284912109375
15 253.34262084960938
16 240.12428283691406
17 227.51284790039062
18 215.4942169189453
19 204.04367065429688
20 193.1083984375
21 182.68704223632812
22 172.78221130371094
23 163.363525390625
24 154.37173461914062
25 145.82127380371094
26 137.69317626953125
27 129.94041442871094
28 122.57762908935547
29 115.59627532958984
30 108.96263122558594
31 102.68822479248047
32 96.76607513427734
33 91.18742370605469
34 85.92356872558594
35 80.95348358154297
36 76.28157043457031
37 71.8812484741211
38 67.74893188476562
39 63.848472595214844
40 60.18348693847656
41 56.737361907958984
42 53.500343322753906
43 50.45751953125
44 47.600128173828125
45 44.91846466064453
46 42.39643096923828
47 40.024

463 7.688200275879353e-05
464 7.51631596358493e-05
465 7.348708459176123e-05
466 7.184722926467657e-05
467 7.0245485403575e-05
468 6.867792399134487e-05
469 6.714754999848083e-05
470 6.564964860444888e-05
471 6.418689736165106e-05
472 6.275794294197112e-05
473 6.135978765087202e-05
474 5.9993413742631674e-05
475 5.8658799389377236e-05
476 5.735241938964464e-05
477 5.607708953903057e-05
478 5.4830641602166e-05
479 5.3609037422575057e-05
480 5.241730832494795e-05
481 5.1254730351502076e-05
482 5.011454049963504e-05
483 4.900058047496714e-05
484 4.791377796209417e-05
485 4.684888699557632e-05
486 4.580941822496243e-05
487 4.47938873548992e-05
488 4.379854362923652e-05
489 4.28262646892108e-05
490 4.187750528217293e-05
491 4.09483109251596e-05
492 4.0041180909611285e-05
493 3.9154580008471385e-05
494 3.8286554627120495e-05
495 3.743721026694402e-05
496 3.660858783405274e-05
497 3.5796459997072816e-05
498 3.5005279642064124e-05
499 3.423124144319445e-05
