In [5]:
import numpy as np
import torch
import torch.nn as nn

In [6]:
class TanhNewtonImplicitLayer(nn.Module):
    """Implicit layer whose output z* solves z = tanh(W z + x), found with Newton's method.

    The root of g(z) = z - tanh(W z + x) is located without recording an
    autograd graph.  One extra tanh step is then taken with autograd enabled,
    and a backward hook multiplies the incoming gradient by J^{-T}, where
    J = I - diag(sech^2(W z* + x)) W is the Jacobian of g at the returned point.
    Together with autograd's own tanh backward this yields the implicit
    function theorem gradient.

    Args:
        out_features: width of the layer (input and output have the same size).
        tol: stop iterating once the norm of g over the whole batch is below this.
        max_iter: maximum number of Newton steps per forward pass.

    Attributes set by ``forward``:
        iterations: number of Newton steps taken in the last forward pass.
        err: norm of the residual g at the last residual evaluation
            (not set when ``max_iter`` is 0).
    """

    def __init__(self, out_features, tol = 1e-4, max_iter=50):
        super().__init__()
        self.linear = nn.Linear(out_features, out_features, bias=False)
        self.tol = tol
        self.max_iter = max_iter

    def _jacobian(self, z_linear):
        """Return the batched Jacobian I - diag(sech^2(z_linear)) W, shape (batch, n, n)."""
        # Build the identity on the same device/dtype as the activations so the
        # layer also works on GPU and in double precision.
        eye = torch.eye(z_linear.shape[1], dtype=z_linear.dtype, device=z_linear.device)
        sech2 = 1 / torch.cosh(z_linear)**2
        return eye[None,:,:] - sech2[:,:,None]*self.linear.weight[None,:,:]

    def forward(self, x):
        """Solve for the fixed point given ``x`` of shape (batch, out_features) and return it."""
        # Run Newton's method outside of the autograd framework
        with torch.no_grad():
            z = torch.tanh(x)
            self.iterations = 0
            while self.iterations < self.max_iter:
                z_linear = self.linear(z) + x
                g = z - torch.tanh(z_linear)
                self.err = torch.norm(g)
                if self.err < self.tol:
                    break

                # newton step: z <- z - J^{-1} g
                J = self._jacobian(z_linear)
                z = z - torch.linalg.solve(J, g[:,:,None])[:,:,0]
                self.iterations += 1

        # reengage autograd and add the gradient hook
        z_linear = self.linear(z) + x
        z_out = torch.tanh(z_linear)
        # A hook can only be attached to a tensor that takes part in autograd
        # (this is not the case under torch.no_grad()).
        if z_out.requires_grad:
            # Evaluate the Jacobian at the point actually returned; the loop may
            # never have built one (immediate convergence) or may hold one from
            # the previous iterate.
            with torch.no_grad():
                J_star = self._jacobian(z_linear)
            z_out.register_hook(
                lambda grad : torch.linalg.solve(J_star.transpose(1,2), grad[:,:,None])[:,:,0])
        return z_out

In [14]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.optim as optim

In [15]:
# MNIST train/test splits as tensors; downloaded into the working directory if missing.
mnist_train = datasets.MNIST(
    ".",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
mnist_test = datasets.MNIST(
    ".",
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

# Mini-batches of 100; only the training split is shuffled.
train_loader = DataLoader(mnist_train, shuffle=True, batch_size=100)
test_loader = DataLoader(mnist_test, shuffle=False, batch_size=100)

# Use the first GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [16]:
# a generic function for running a single epoch (training or evaluation)
def epoch(loader, model, opt=None, monitor=None):
    """Run one pass over ``loader``: train when ``opt`` is given, otherwise evaluate.

    Args:
        loader: iterable of ``(X, y)`` batches that also exposes ``len(loader)``
            and ``loader.dataset``.
        model: module mapping a batch ``X`` to class scores.
        opt: optimizer; when ``None`` the model is put in eval mode and no
            parameter update is made.
        monitor: optional callable ``monitor(model)`` evaluated after every
            batch; its values are averaged over the batches.

    Returns:
        Tuple ``(error_rate, mean_loss, mean_monitor)``: misclassification rate
        and cross-entropy loss averaged over ``len(loader.dataset)``, and the
        monitor value averaged over ``len(loader)`` (``0.0`` without a monitor).

    Note:
        Autograd is left enabled during evaluation, so forward passes that
        register tensor hooks keep working.
    """
    total_loss, total_err, total_monitor = 0.,0.,0.
    training = opt is not None
    if training:
        model.train()
    else:
        model.eval()

    criterion = nn.CrossEntropyLoss()
    for X,y in loader:
        X,y = X.to(device), y.to(device)
        yp = model(X)
        loss = criterion(yp,y)
        if training:
            opt.zero_grad()
            loss.backward()
            # Parameters that took no part in the loss (frozen or unused) have
            # grad None; ignore them instead of crashing in torch.isnan.
            grads = [p.grad for p in model.parameters() if p.grad is not None]
            # Skip the update when any gradient contains NaN.
            if not any(torch.isnan(g).any() for g in grads):
                opt.step()

        total_err += (yp.max(dim=1)[1] != y).sum().item()
        total_loss += loss.item() * X.shape[0]
        if monitor is not None:
            total_monitor += monitor(model)
    return total_err / len(loader.dataset), total_loss / len(loader.dataset), total_monitor / len(loader)

In [17]:
# Fix the seed so the layer initialisations below are reproducible.
torch.manual_seed(0)

# Flatten -> linear (784 to 100) -> implicit tanh layer -> linear classifier head.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 100),
    TanhNewtonImplicitLayer(100, max_iter=40),
    nn.Linear(100, 10),
)
model = model.to(device)

# Plain SGD over all parameters.
opt = optim.SGD(model.parameters(), lr=0.1)

In [18]:
def _newton_iters(net):
    # The implicit layer is the module at index 2 of the Sequential model.
    return net[2].iterations

for i in range(10):
    # Lower the learning rate from the sixth epoch onwards.
    if i == 5:
        opt.param_groups[0]["lr"] = 0.01

    train_err, train_loss, train_fpiter = epoch(train_loader, model, opt, monitor=_newton_iters)
    test_err, test_loss, test_fpiter = epoch(test_loader, model, monitor=_newton_iters)

    train_report = f"Train Error: {train_err:.4f}, Loss: {train_loss:.4f}, Newton Iters: {train_fpiter:.2f} | "
    test_report = f"Test Error: {test_err:.4f}, Loss: {test_loss:.4f}, Newton Iters: {test_fpiter:.2f}"
    print(train_report + test_report)

Train Error: 0.1113, Loss: 0.4020, Newton Iters: 6.66 | Test Error: 0.0661, Loss: 0.2231, Newton Iters: 6.73
Train Error: 0.0582, Loss: 0.1938, Newton Iters: 7.27 | Test Error: 0.1081, Loss: 0.3645, Newton Iters: 6.63
Train Error: 0.0435, Loss: 0.1476, Newton Iters: 6.96 | Test Error: 0.0437, Loss: 0.1473, Newton Iters: 6.63
Train Error: 0.0362, Loss: 0.1220, Newton Iters: 7.20 | Test Error: 0.0382, Loss: 0.1311, Newton Iters: 6.57
Train Error: 0.0310, Loss: 0.1047, Newton Iters: 8.06 | Test Error: 0.0332, Loss: 0.1129, Newton Iters: 7.60
Train Error: 0.0204, Loss: 0.0733, Newton Iters: 8.74 | Test Error: 0.0316, Loss: 0.1052, Newton Iters: 8.11
Train Error: 0.0191, Loss: 0.0681, Newton Iters: 9.12 | Test Error: 0.0311, Loss: 0.1057, Newton Iters: 8.22
Train Error: 0.0183, Loss: 0.0656, Newton Iters: 9.48 | Test Error: 0.0304, Loss: 0.1051, Newton Iters: 8.73
Train Error: 0.0177, Loss: 0.0629, Newton Iters: 9.91 | Test Error: 0.0302, Loss: 0.1051, Newton Iters: 9.53
Train Error: 0.0170