# PyTorch basics

In [1]:
import torch
from torch.autograd import Variable

### PyTorch has tensors, too (its counterpart to NumPy arrays)

In [2]:
# Tensors default to CPU floats here; set `dtype = torch.cuda.FloatTensor`
# instead to place them on a GPU.
dtype = torch.FloatTensor

# A small 3x4 tensor of standard-normal samples, shown as the cell output.
X = torch.randn((3, 4)).type(dtype)
X


 1.4318  0.0409  0.0046 -0.6761
 0.4244  0.7086 -1.2760 -0.4752
 0.6011 -1.8419  1.8842 -2.6688
[torch.FloatTensor of size 3x4]

### Tensors are wrapped in Variables, which also store gradients (since PyTorch 0.4 `Variable` is deprecated: tensors carry `requires_grad` and `.grad` themselves)

In [3]:
# Wrap the demo tensor so autograd can work with it.
X_Var = Variable(X)

### Porting the NumPy network to PyTorch

In [4]:
# Fix the RNG seed so the random data and initial weights below (and hence the
# logged losses) are reproducible across Restart & Run All.
seed = 0
torch.manual_seed(seed)

# Problem sizes: X is (n, num_features), W1 is (num_features, hidden_dim),
# W2 is (hidden_dim, output_dim) and y is (n, output_dim).
n = 64
num_features = 1000
hidden_dim = 100
output_dim = 10

# Plain gradient-descent hyperparameters.
learning_rate = 1e-6
num_epochs = 500

In [5]:
# Tensor type for the network below. This cell re-assigns `dtype`, so the GPU
# toggle in the earlier demo cell does not carry over; set `use_gpu = True`
# here to run the network on a CUDA device instead of the CPU.
use_gpu = False
dtype = torch.cuda.FloatTensor if use_gpu else torch.FloatTensor

In [6]:
# Inputs X (n x num_features) and targets y (n x output_dim) are fixed data,
# so autograd never needs gradients with respect to them.
X, y = (
    Variable(torch.randn(n, cols).type(dtype), requires_grad=False)
    for cols in (num_features, output_dim)
)

In [7]:
# The weights are what we learn, so autograd must track their gradients.
W1, W2 = (
    Variable(torch.randn(fan_in, fan_out).type(dtype), requires_grad=True)
    for fan_in, fan_out in ((num_features, hidden_dim), (hidden_dim, output_dim))
)

In [9]:
# Train the two-layer ReLU network with plain gradient descent, letting autograd
# compute the gradients. Reads X, y, W1, W2, learning_rate and num_epochs from
# the cells above and updates W1 and W2 in place.
for epoch in range(num_epochs):
    # Forward pass: linear -> ReLU (clamp at 0) -> linear. No intermediate
    # values are kept by hand, because we do not write the backward pass.
    y_pred = X.mm(W1).clamp(min=0).mm(W2)

    # Sum-of-squares loss. `loss` holds a single value, and `.item()` extracts it
    # as a Python float. Indexing with `loss.data[0]` fails on the 0-dim tensor
    # that `.sum()` returns in current PyTorch.
    loss = (y_pred - y).pow(2).sum()
    print(epoch, loss.item())

    # Backward pass: fills W1.grad and W2.grad with the gradient of the loss
    # with respect to W1 and W2.
    loss.backward()

    # Gradient-descent step. Updating through `.data` keeps the update itself
    # out of the autograd graph.
    W1.data -= learning_rate * W1.grad.data
    W2.data -= learning_rate * W2.grad.data

    # backward() accumulates into .grad, so reset it before the next epoch.
    W1.grad.data.zero_()
    W2.grad.data.zero_()


0 35941164.0
1 28096576.0
2 21300124.0
3 14758547.0
4 9429425.0
5 5839253.0
6 3691893.5
7 2465083.0
8 1753819.375
9 1319161.75
10 1035111.375
11 836602.875
12 690356.625
13 577948.5
14 489062.46875
15 417352.40625
16 358785.8125
17 310220.78125
18 269584.0625
19 235370.390625
20 206338.046875
21 181599.765625
22 160409.1875
23 142152.8125
24 126334.3203125
25 112580.6796875
26 100582.03125
27 90074.7890625
28 80839.2421875
29 72699.796875
30 65507.11328125
31 59134.2109375
32 53477.16015625
33 48439.2578125
34 43947.5546875
35 39933.75390625
36 36337.8984375
37 33118.34375
38 30234.5234375
39 27634.59765625
40 25287.16015625
41 23165.3203125
42 21243.43359375
43 19498.94921875
44 17915.166015625
45 16475.32421875
46 15164.4931640625
47 13969.3134765625
48 12877.97265625
49 11880.6689453125
50 10969.158203125
51 10133.82421875
52 9368.427734375
53 8666.5322265625
54 8021.8447265625
55 7429.76708984375
56 6884.86083984375
57 6383.84130859375
58 5922.94921875
59 5498.048828125
60 5106.296

415 0.00010211514018010348
416 0.00010035572631750256
417 9.809323091758415e-05
418 9.626703831600025e-05
419 9.406750177731737e-05
420 9.264861728297547e-05
421 9.045824117492884e-05
422 8.905331196729094e-05
423 8.764121594140306e-05
424 8.602118032285944e-05
425 8.397189958486706e-05
426 8.235948189394549e-05
427 8.072752098087221e-05
428 7.931103027658537e-05
429 7.787132199155167e-05
430 7.689026824664325e-05
431 7.548759458586574e-05
432 7.407170051010326e-05
433 7.264372106874362e-05
434 7.177959923865274e-05
435 7.053721492411569e-05
436 6.931585812708363e-05
437 6.82772442814894e-05
438 6.734146882081404e-05
439 6.609225238207728e-05
440 6.504130578832701e-05
441 6.392353679984808e-05
442 6.303474947344512e-05
443 6.195652531459928e-05
444 6.106679938966408e-05
445 5.9770889492938295e-05
446 5.908001912757754e-05
447 5.829135261592455e-05
448 5.750300988438539e-05
449 5.650284583680332e-05
450 5.565451647271402e-05
451 5.492006312124431e-05
452 5.390077058109455e-05
453 5.3258