In [28]:
import numpy as np
import torch
import math
from torch import nn

In [29]:
# Parameters
d = 10  # Dimensionality
# The number of samples is supposed to be infinite in the paper.
batch_size = 50  # Batch size for training
learning_rate = 1e-3  # Learning rate for SGD
epochs = int(2e6)  # Number of training epochs
np.random.seed(42)

In [30]:
def generate_3_spin_coefficients(d):
    """Generate random Gaussian coefficients a_pqr for the 3-spin model."""
    return np.random.normal(0, 1, size=(d, d, d))

def spherical_3_spin(x, a):
    """Compute the 3-spin model function f(x)."""
    f = np.einsum('pqr,p,q,r->', a, x, x, x)
    return f / x.shape[0]

def sample_from_sphere(d, n_samples):
    """Sample points from the d-dimensional sphere of radius sqrt(d)."""
    x = np.random.normal(0, 1, size=(n_samples, d))
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return math.sqrt(d) * x / norms


a = generate_3_spin_coefficients(d)

In [31]:
n_neurons = int(1e5)
model = nn.Sequential(
    nn.Linear(d, n_neurons),
    nn.Linear(n_neurons, 1),
    nn.ReLU(),
)


optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss(reduction='mean')

for _ in range(epochs):
    # Random batch
    X_batch = sample_from_sphere(d, batch_size)
    y_batch = np.array([spherical_3_spin(x, a) for x in X_batch]).reshape(-1, 1)

    # To tensor
    X_batch = torch.tensor(X_batch.astype(np.float32))
    y_batch = torch.tensor(y_batch.astype(np.float32))

    # Forward pass
    predictions = model(X_batch)
    loss = criterion(predictions, y_batch)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

   # Log progress
    if _ % (epochs // 100) == 0:
        print(f"Step {_}/{epochs}, Loss: {loss.item():.4f}")




Step 0/2000000, Loss: 6.9826


KeyboardInterrupt: 

In [25]:
# Step 4: Evaluate the Network
def evaluate_network(model, X, y):
    """Compute the loss and signed discrepancy."""
    y_pred = model(torch.tensor(X.astype(np.float32)))
    mse = criterion(y_pred, y)  # Mean squared error
    relative = (y_pred - y)  * (y > 0)
    signed_discrepancy = np.mean(relative.detach().numpy())  # Signed discrepancy for positive f(x)
    return mse, signed_discrepancy

mse, signed_discrepancy = evaluate_network(model, X, torch.tensor(y.astype(np.float32)))
print(f"Mean Squared Error: {mse}")
print(f"Signed Discrepancy: {signed_discrepancy}")

Mean Squared Error: 9.848605155944824
Signed Discrepancy: -1.1606932878494263
