In [1]:
import torch

dtype = torch.float

# to run on GPU
device = torch.device("cuda:0")

In [3]:
# check cuda settings

torch.cuda.current_device()

torch.cuda.device(0)

torch.cuda.device_count()

torch.cuda.get_device_name(0)

torch.cuda.is_available()

True

In [5]:
# initialization
# N:=batch size
# D_in:=input dimension
# H:=hidden dimension
# D_out:=output dimension
N, D_in, H, D_out = 64, 1000, 100, 10

# create random input and output data
x = torch.randn(N, D_in, device=device, dtype = dtype)
y = torch.randn(N, D_out, device=device, dtype = dtype)

In [9]:
# randomly initialize weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

In [11]:
learning_rate = 1e-6
for t in range(500):
    # forward pass: compute predicted y
    h = x.mm(w1)
    # torch.mm: performs a matrix multiplication
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)
    
    # compute and print loss
    loss = (y_pred - y).pow(2).sum().item()
    print(t, loss)
    
    # backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)
    
    # update weights using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 22993956.0
1 16102764.0
2 13048335.0
3 11488416.0
4 10448305.0
5 9427003.0
6 8259731.0
7 6946088.0
8 5604575.5
9 4359265.5
10 3303070.0
11 2459244.0
12 1819158.375
13 1347167.625
14 1006146.375
15 761088.0
16 585473.375
17 458687.625
18 366220.3125
19 297819.40625
20 246416.359375
21 207053.71875
22 176340.125
23 151933.9375
24 132215.265625
25 116029.875
26 102571.21875
27 91220.5625
28 81541.5859375
29 73208.71875
30 65973.1875
31 59648.90625
32 54085.2109375
33 49169.875
34 44804.765625
35 40910.75
36 37425.4765625
37 34297.25
38 31481.49609375
39 28939.37109375
40 26639.7265625
41 24555.50390625
42 22674.07421875
43 20963.3671875
44 19404.810546875
45 17981.97265625
46 16680.521484375
47 15488.1767578125
48 14394.158203125
49 13389.115234375
50 12464.763671875
51 11613.8525390625
52 10829.1953125
53 10104.88671875
54 9435.658203125
55 8817.0009765625
56 8244.0498046875
57 7713.0546875
58 7220.705078125
59 6763.7138671875
60 6339.0986328125
61 5944.33203125
62 5576.9921875
63 5235

492 0.00020469704759307206
493 0.00020058234804309905
494 0.00019726663595065475
495 0.00019380725279916078
496 0.00019063030777033418
497 0.00018722054664976895
498 0.00018418370746076107
499 0.00018074018589686602
