In [1]:
import torch
import math

In [2]:
# Report (not enforce) the MPS backend status, one value per line:
#   is_available() -> whether MPS can be used right now (requires macOS 12.3+)
#   is_built()     -> whether this PyTorch installation was built with MPS support
print(
    torch.backends.mps.is_available(),
    torch.backends.mps.is_built(),
    sep="\n",
)

True
True


In [None]:
# Same availability check as above, via the documented backend API.
# (`torch.mps` is not guaranteed to expose `is_available`; `torch.backends.mps` does.)
torch.backends.mps.is_available()

In [3]:
# Fit y = sin(x) on [-pi, pi] with a cubic a + b x + c x^2 + d x^3, using
# hand-derived gradients and plain gradient descent (no autograd).
SEED = 0          # fixed seed so the random initial weights are reproducible
N_POINTS = 2000   # number of sample points in [-pi, pi]
N_STEPS = 2000    # gradient-descent iterations
LOG_EVERY = 100   # print the loss once every LOG_EVERY steps

dtype = torch.float
# Prefer Apple's MPS backend, but fall back to the CPU so the cell still runs
# on machines without MPS support.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
torch.manual_seed(SEED)

# Create input data and the target values
x = torch.linspace(-math.pi, math.pi, N_POINTS, device=device, dtype=dtype)
y = torch.sin(x)
# Powers of x never change during training: compute them once, not every step.
x2 = x ** 2
x3 = x ** 3

# Randomly initialize weights as 0-dim (scalar) tensors
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(N_STEPS):
    # Forward pass: compute predicted y
    y_pred = a + b * x + c * x2 + d * x3

    # Sum-of-squares loss. `.item()` copies the value to the host (a sync point
    # on an accelerator), so only compute it on the steps we actually log.
    if t % LOG_EVERY == LOG_EVERY - 1:
        loss = (y_pred - y).pow(2).sum().item()
        print(t, loss)

    # Backprop: d(loss)/d(y_pred) = 2 (y_pred - y); chain rule through each term
    grad_y_pred = 2.0 * (y_pred - y)
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x2).sum()
    grad_d = (grad_y_pred * x3).sum()

    # Update weights in place using gradient descent
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d


print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')

99 65.43809509277344
199 48.837162017822266
299 37.10573959350586
399 28.814699172973633
499 22.954626083374023
599 18.81243324279785
699 15.884288787841797
799 13.814231872558594
899 12.350706100463867
999 11.315933227539062
1099 10.584269523620605
1199 10.066898345947266
1299 9.70103645324707
1399 9.442302703857422
1499 9.25932502746582
1599 9.129915237426758
1699 9.038387298583984
1799 8.973649024963379
1899 8.92785930633545
1999 8.895469665527344
Result: y = -0.00934265460819006 + 0.8561256527900696 x + 0.001611765823327005 x^2 + -0.09324290603399277 x^3
