In [None]:
# ChatGPT for-loop version

import torch
import torch.nn as nn

# Run on Apple's Metal Performance Shaders (MPS) backend when PyTorch reports
# it as available; otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Define a simple neural network for coefficients
class CoefficientNet(nn.Module):
    """Small MLP mapping a 3-vector to a (2, 2, 2, 2, 2, 2) coefficient tensor.

    Architecture: Linear(3 -> hidden_dim) -> ReLU -> Linear(hidden_dim -> 64),
    with the 64 outputs reshaped to six axes of size 2.

    Args:
        hidden_dim: Width of the hidden layer. Defaults to 128, the value
            previously hard-coded here.
    """

    # Shape of the tensor returned by forward(); 2**6 == 64 elements.
    COEFF_SHAPE = (2, 2, 2, 2, 2, 2)

    def __init__(self, hidden_dim=128):
        super().__init__()
        self.fc = nn.Linear(3, hidden_dim)  # Hidden fully connected layer
        self.out = nn.Linear(hidden_dim, 2 * 2 * 2 * 2 * 2 * 2)  # One output per coefficient

    def forward(self, x):
        """Return the coefficient tensor of shape ``COEFF_SHAPE`` for input ``x``.

        ``x`` is expected to be a single unbatched input of shape (3,). The
        final ``view`` requires exactly 64 output elements, so a batch with
        more than one row raises a RuntimeError.
        """
        x = torch.relu(self.fc(x))
        return self.out(x).view(self.COEFF_SHAPE)

# Build the coefficient network on the selected device.
net = CoefficientNet().to(device)

# One fixed input yields a single coefficient tensor of shape
# (2, 2, 2, 2, 2, 2) that is shared by every query point below.
dummy_input = torch.tensor((1.2, 2.4, 3.0), device=device)
coefficients = net(dummy_input)

# Eight anchor points, one per (a, b, c) coefficient index
# (row n = a * 4 + b * 2 + c), drawn from a standard normal.
# To use the corners of the unit cube instead (x varies fastest):
# points = torch.tensor(
#     [[x, y, z] for z in (0.0, 1.0) for y in (0.0, 1.0) for x in (0.0, 1.0)],
#     requires_grad=True, device=device,
# )
points = torch.randn((8, 3), device=device, requires_grad=True)

# Define basis functions
def h0(x):
    """Falling linear basis function: 1 at x == 0, 0 at x == 1."""
    complement = -x + 1
    return complement

def h1(x):
    """Rising linear basis function (the identity): 0 at x == 0, 1 at x == 1."""
    rising = x
    return rising

# Per-axis basis functions: index i/j/k of a coefficient's last three axes
# selects which function is applied along x/y/z respectively.
basis_functions = [h0, h1]

# Batch size (number of random query points)
batch_size = 10

# Random query points, uniform in [0, 1)^3, on the same device.
query_points = torch.rand((batch_size, 3), requires_grad=True, device=device)


def _evaluate_field(queries, anchors, coeff_tensor, basis_fns):
    """Evaluate the basis-function expansion at every query point at once.

    Vectorized equivalent of the nested loop, per query point ``q``::

        for a, b, c, i, j, k in itertools.product(range(2), repeat=6):
            n = a * 4 + b * 2 + c
            result += (coeff_tensor[a, b, c, i, j, k]
                       * basis_fns[i](q[0] - anchors[n, 0])
                       * basis_fns[j](q[1] - anchors[n, 1])
                       * basis_fns[k](q[2] - anchors[n, 2]))

    Args:
        queries: (Q, 3) tensor of evaluation locations.
        anchors: (N, 3) tensor of anchor points, where N is the product of
            ``coeff_tensor``'s first three dims. Row n pairs with
            ``coeff_tensor[a, b, c]``, n being the row-major flat index of (a, b, c).
        coeff_tensor: (A, B, C, nb, nb, nb) tensor with nb == len(basis_fns).
        basis_fns: sequence of elementwise callables that accept tensors.

    Returns:
        (Q,) tensor holding one value per query point. It stays differentiable
        with respect to ``queries``, ``anchors`` and ``coeff_tensor``.
    """
    # Flatten (a, b, c) row-major so flat index n == a * 4 + b * 2 + c for 2x2x2.
    flat_coeffs = coeff_tensor.reshape(-1, *coeff_tensor.shape[3:])  # (N, nb, nb, nb)
    # Offset from every anchor point to every query point: (Q, N, 3).
    offsets = queries[:, None, :] - anchors[None, :, :]
    # Every basis function applied to every offset component: (Q, N, 3, nb).
    basis = torch.stack([fn(offsets) for fn in basis_fns], dim=-1)
    hx, hy, hz = basis.unbind(dim=2)  # each (Q, N, nb)
    return torch.einsum("nijk,qni,qnj,qnk->q", flat_coeffs, hx, hy, hz)


# Average the field value over all query points. This replaces a 6-deep
# Python loop (64 terms per query point, each issuing several tiny device
# ops) whose inner `for b` also shadowed the outer batch index `b`.
loss = _evaluate_field(query_points, points, coefficients, basis_functions).mean()

# Backpropagate
loss.backward()

# Outputs
print("Loss:", loss.item())
print("Gradients (query_points):", query_points.grad)
print("Gradients (points):", points.grad)
for name, param in net.named_parameters():
    print(f"Gradients ({name}):", param.grad)


Using device: mps
Loss: -3.4387755393981934
Gradients (query_points): tensor([[-0.0391, -0.0227, -0.0520],
        [ 0.0031, -0.0034, -0.2291],
        [ 0.1975,  0.0723, -0.2301],
        [ 0.0049, -0.0039, -0.2625],
        [ 0.0452,  0.0088, -0.0945],
        [ 0.1012,  0.0353, -0.1524],
        [ 0.1980,  0.0629, -0.0838],
        [-0.0802, -0.0395, -0.2021],
        [ 0.0174,  0.0006, -0.1736],
        [ 0.1164,  0.0281, -0.0615]], device='mps:0')
Gradients (points): tensor([[ 0.1569, -0.0616, -0.0685],
        [ 0.9880,  0.2717, -0.2354],
        [-0.4236, -0.0290,  0.4851],
        [-0.9378, -2.5863,  2.1161],
        [ 0.1350,  0.4964, -0.2464],
        [-1.2858,  1.1961,  1.3405],
        [ 0.4930,  0.5327, -0.9045],
        [ 0.3099,  0.0415, -0.9454]], device='mps:0')
Gradients (fc.weight): tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 6.4231e-01,  1.2846e+00,  1.60