In [24]:
import torch


class MyReLU(torch.autograd.Function):
    """ReLU written as a custom autograd Function with an explicit backward."""

    @staticmethod
    def forward(ctx, input):
        """Return ``input`` with every negative entry replaced by zero.

        The input tensor is saved on ``ctx`` so that ``backward`` can tell
        which entries were negative.
        """
        ctx.save_for_backward(input)
        return torch.clamp(input, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient with respect to the forward input.

        The upstream gradient passes through unchanged, except that it is
        zeroed wherever the saved input was strictly negative.
        """
        (saved_input,) = ctx.saved_tensors
        # masked_fill is out-of-place, so grad_output itself is never mutated.
        return grad_output.masked_fill(saved_input < 0, 0)
    
class MyNet(torch.nn.Module):
    """Two-layer network: Linear -> MyReLU -> Linear."""

    def __init__(self, D_in, H, D_out):
        """Create the two linear layers.

        Args:
            D_in: number of input features of the first linear layer.
            H: number of features between the two linear layers.
            D_out: number of output features of the second linear layer.
        """
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """Run ``x`` through linear1, the custom ReLU, then linear2."""
        # Custom autograd Functions are invoked through their ``apply`` method.
        hidden = MyReLU.apply(self.linear1(x))
        return self.linear2(hidden)

# Train MyNet on a fixed batch of random data and print the loss at each step.

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# N: batch size, D_in: input features, H: hidden features, D_out: output features.
N, D_in, H, D_out = 64, 1000, 100, 10
x = torch.randn(N, D_in).to(device)
y = torch.randn(N, D_out).to(device)
model = MyNet(D_in, H, D_out).to(device)

# Summed (not averaged) squared error. ``reduction="sum"`` is the supported
# spelling of the deprecated ``size_average=False`` argument.
criterion = torch.nn.MSELoss(reduction="sum")
# Backward-compatible alias for the original misspelled name.
ceritrion = criterion
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in range(500):
    y_pred = model(x)

    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Clear gradients left over from the previous step before backpropagating.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 679.5858764648438
1 632.1129150390625
2 591.4473876953125
3 555.7750244140625
4 523.6234130859375
5 494.5653076171875
6 468.1240234375
7 443.7981872558594
8 421.26873779296875
9 400.3902587890625
10 380.7362060546875
11 362.162109375
12 344.54931640625
13 327.8200378417969
14 311.86126708984375
15 296.6468505859375
16 282.1530456542969
17 268.2584228515625
18 254.95907592773438
19 242.19790649414062
20 229.9599609375
21 218.24819946289062
22 206.9921112060547
23 196.24472045898438
24 185.965087890625
25 176.13938903808594
26 166.74365234375
27 157.79937744140625
28 149.29173278808594
29 141.20631408691406
30 133.51644897460938
31 126.1762466430664
32 119.212158203125
33 112.59922790527344
34 106.33858489990234
35 100.41510772705078
36 94.79867553710938
37 89.49909973144531
38 84.48545837402344
39 79.73721313476562
40 75.2606201171875
41 71.04267883300781
42 67.05499267578125
43 63.29179763793945
44 59.7442626953125
45 56.398067474365234
46 53.251251220703125
47 50.2790641784668
48 47

351 0.00019270722987130284
352 0.00018635383457876742
353 0.00018021685536950827
354 0.00017428088176529855
355 0.0001685445022303611
356 0.00016299221897497773
357 0.0001576306822244078
358 0.00015244598034769297
359 0.00014743287465535104
360 0.00014258864393923432
361 0.00013790834054816514
362 0.00013338474673219025
363 0.00012900616275146604
364 0.00012477784184738994
365 0.00012068977230228484
366 0.0001167354712379165
367 0.00011291260307189077
368 0.0001092210368369706
369 0.00010564984404481947
370 0.00010219588875770569
371 9.885618783300743e-05
372 9.562785999150947e-05
373 9.250712173525244e-05
374 8.948880713433027e-05
375 8.656815771246329e-05
376 8.37488187244162e-05
377 8.102032006718218e-05
378 7.838472083676606e-05
379 7.583442493341863e-05
380 7.337067654589191e-05
381 7.098808418959379e-05
382 6.867759657325223e-05
383 6.644790119025856e-05
384 6.428840424632654e-05
385 6.220323120942339e-05
386 6.018249405315146e-05
387 5.82317152293399e-05
388 5.634537956211716e-0