# PyTorch: Tensors

A PyTorch Tensor is conceptually identical to a numpy array: a Tensor is an n-dimensional array, and PyTorch provides many functions for operating on these Tensors. Like numpy arrays, PyTorch Tensors know nothing about deep learning, computational graphs, or gradients; they are a generic tool for scientific computing.
Unlike numpy arrays, however, PyTorch Tensors can use GPUs to accelerate numeric computation, and this is their main advantage. To run a PyTorch Tensor on a GPU, you create it on a CUDA device (or move an existing Tensor there with `.to(device)`).
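As a rough illustration of both points, here is a minimal sketch, assuming NumPy is installed alongside PyTorch; the variable names are illustrative only. It shows that Tensor operations mirror numpy ones, that `torch.from_numpy` shares memory with the source array, and how a device is chosen. On a machine without CUDA it simply runs on the CPU.

```python
import numpy as np
import torch

# torch.from_numpy wraps the array without copying, so both share the same memory
a = np.arange(6, dtype=np.float32).reshape(2, 3)
t = torch.from_numpy(a)

# Tensor operations mirror numpy ones: matrix product, elementwise ops, clamping
b = t.mm(t.t())             # (2, 3) x (3, 2) -> (2, 2)
relu = (t - 2).clamp(min=0)  # elementwise ReLU of t - 2

# Because t shares memory with a, an in-place write to t is visible in a
t[0, 0] = 100.0
print(a[0, 0])  # 100.0

# Use a GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
t_dev = t.to(device)          # moves the data to the chosen device (no-op if already there)
back = t_dev.cpu().numpy()    # bring it back to host memory before converting to numpy
```

The only change needed to move the computation to the GPU is the `device` argument; the Tensor operations themselves are identical.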


Below we use PyTorch Tensors to fit a two-layer network to random data. As in the previous numpy example, we have to implement the forward and backward passes through the network by hand.
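The backward pass in the training loop is the chain rule applied by hand. Writing $x$ for the input, $W_1, W_2$ for `weight1` and `weight2`, $h$ for `dot_product`, $h_{\text{relu}}$ for `dot_product_relu`, $\eta$ for `learning_rate`, and $\odot$ for elementwise multiplication, the network, loss, and gradients are:

$$
\begin{aligned}
h &= x W_1, \qquad h_{\text{relu}} = \max(h, 0), \qquad \hat{y} = h_{\text{relu}} W_2, \qquad L = \sum (\hat{y} - y)^2 \\
\frac{\partial L}{\partial \hat{y}} &= 2(\hat{y} - y) \\
\frac{\partial L}{\partial W_2} &= h_{\text{relu}}^{\top} \frac{\partial L}{\partial \hat{y}}, \qquad
\frac{\partial L}{\partial h_{\text{relu}}} = \frac{\partial L}{\partial \hat{y}} W_2^{\top} \\
\frac{\partial L}{\partial h} &= \frac{\partial L}{\partial h_{\text{relu}}} \odot \mathbf{1}[h \geq 0], \qquad
\frac{\partial L}{\partial W_1} = x^{\top} \frac{\partial L}{\partial h} \\
W_k &\leftarrow W_k - \eta \, \frac{\partial L}{\partial W_k}, \qquad k \in \{1, 2\}
\end{aligned}
$$

The indicator $\mathbf{1}[h \geq 0]$ matches the code, which zeroes the gradient only where $h < 0$; at $h = 0$ the ReLU is not differentiable, and either choice is a valid subgradient.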

```python
import torch

dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0")  # Uncomment this line to run on a GPU

# Batch size and input, hidden and output dimensions
batch_size = 64
input_dimension = 1000
hidden_dimension = 100
output_dimension = 10

# Generate random input and target data
x = torch.randn(batch_size, input_dimension, device=device, dtype=dtype)
y = torch.randn(batch_size, output_dimension, device=device, dtype=dtype)

# Initialize random weights
weight1 = torch.randn(input_dimension, hidden_dimension, device=device, dtype=dtype)
weight2 = torch.randn(hidden_dimension, output_dimension, device=device, dtype=dtype)

learning_rate = 1e-6
for n in range(500):
    # Forward pass: compute predicted y
    dot_product = x.mm(weight1)
    dot_product_relu = dot_product.clamp(min=0)
    y_pred = dot_product_relu.mm(weight2)

    # Compute and print the sum-of-squares loss
    loss = (y_pred - y).pow(2).sum().item()
    print(n, loss)

    # Backprop: compute gradients of the loss with respect to weight1 and weight2
    grad_y_pred = 2.0 * (y_pred - y)
    grad_weight2 = dot_product_relu.t().mm(grad_y_pred)
    grad_dot_product_relu = grad_y_pred.mm(weight2.t())
    grad_dot_product = grad_dot_product_relu.clone()
    grad_dot_product[dot_product < 0] = 0  # ReLU blocks the gradient where its input was negative
    grad_weight1 = x.t().mm(grad_dot_product)

    # Update weights with gradient descent
    weight1 -= learning_rate * grad_weight1
    weight2 -= learning_rate * grad_weight2
```

0 36652176.0
1 31930776.0
2 31120002.0
3 29121236.0
4 23405800.0
5 15942288.0
6 9377715.0
7 5240545.5
8 3005233.25
9 1883019.125
10 1302721.875
11 979024.25
12 778268.375
13 640734.875
14 538844.3125
15 459374.34375
16 395290.0625
17 342565.5
18 298545.0
19 261487.546875
20 230059.609375
21 203165.640625
22 180024.1875
23 160020.546875
24 142666.09375
25 127542.8046875
26 114315.9921875
27 102702.765625
28 92476.4296875
29 83439.4140625
30 75446.2890625
31 68351.8984375
32 62035.84375
33 56406.29296875
34 51365.16015625
35 46845.8671875
36 42784.18359375
37 39126.98828125
38 35827.15234375
39 32846.1796875
40 30148.810546875
41 27703.3203125
42 25483.1640625
43 23464.86328125
44 21628.103515625
45 19954.59765625
46 18426.486328125
47 17030.748046875
48 15754.4375
49 14586.224609375
50 13515.2841796875
51 12532.896484375
52 11630.375
53 10800.337890625
54 10036.3896484375
55 9332.578125
56 8683.8056640625
57 8084.9345703125
58 7532.06494140625
59 7021.0234375
60 6548.37060546875
61 6110