In [184]:
import torch as th
from tqdm.notebook import trange

In [185]:
device = "mps"
dtype = th.float32

In [186]:
class Quadratic(th.nn.Module):
    
    def __init__(self, n):
        super(Quadratic, self).__init__()
        self.A = th.nn.Parameter(th.randn(n, n))
        self.b = th.nn.Parameter(th.randn(n))
        self.c = th.nn.Parameter(th.randn(1))
    
    def forward(self, x):
        return 0.5 * th.einsum("ij,bi,bj->b", self.A, x, x) + th.einsum("i,bi->b", self.b, x) + self.c

    def string(self):
        return f"0.5 * x^T {self.A} x + {self.b}^T x + {self.c}"

In [188]:
class GradientOptimizer(th.optim.Optimizer):
    
    def __init__(self, params, lr) -> None:
        super().__init__(params, {})
        self.lr = lr
    
    def step(self, closure=None):
        with th.no_grad():
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    d = p.grad
                    p.add_(-d, alpha=self.lr)
        return None

In [189]:
model = th.nn.Sequential(
    th.nn.Linear(2, 10),
    th.nn.ReLU(),
    th.nn.Linear(10, 10),
    th.nn.ReLU(),
    Quadratic(10),
).to(device=device)

In [190]:
# use and as data
X = th.tensor([[0, 0], [1, 1], [0, 1], [1, 0]], device=device, dtype=dtype)
Y = th.tensor([0, 0, 0, 1], device=device, dtype=dtype)

In [191]:
model(X)

tensor([0.3920, 0.3726, 0.4757, 0.2120], device='mps:0',
       grad_fn=<AddBackward0>)

In [192]:
optimizer = GradientOptimizer(model.parameters(), lr=1e-1)

steps = 1000
logging_steps = 100

for step in trange(steps):
    optimizer.zero_grad()
    loss = th.nn.functional.mse_loss(model(X), Y)
    loss.backward()
    optimizer.step()
    if step % logging_steps == logging_steps - 1:
        print(f"Step {step}: Loss {loss.item()}")

  0%|          | 0/1000 [00:00<?, ?it/s]

Step 99: Loss 8.540403939605312e-08
Step 199: Loss 1.3522516439934407e-13
Step 299: Loss 6.994405055138486e-15
Step 399: Loss 1.3877787807814457e-15
Step 499: Loss 3.219646771412954e-15
Step 599: Loss 2.4980018054066022e-15
Step 699: Loss 1.942890293094024e-15
Step 799: Loss 1.1657341758564144e-15
Step 899: Loss 1.1657341758564144e-15
Step 999: Loss 3.3306690738754696e-16


In [193]:
print("Predictions:")
print(
    list(
        map(
            lambda x: "{:.4f}".format(x), 
            model(X).detach().cpu().numpy().tolist()
        )
    )
)

Predictions:
['0.0000', '0.0000', '-0.0000', '1.0000']
