## This notebook implements two constrained training schemes based on Newton's method.

### Newton's method

---

### Given $g: \mathbb{R}^N\rightarrow \mathbb{R}^k$, $x\in \mathbb{R}^N$, $v \in \mathbb{R}^{N\times k}$ find $\lambda\in \mathbb{R}^k$, such that  
$$
\begin{aligned}
   g(x + \tau v \lambda) = 0 
   \end{aligned}
$$

Newton's method:

$$
\lambda_{n+1} = \lambda_n - \tau^{-1}(\nabla g(x + \tau v\lambda_n) v)^{-1} g(x + \tau v \lambda_n)
$$



### First, load necessary packages

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import torch
import math 
import torch.nn as nn
import random
import itertools 
from tqdm import tqdm
import os
import time

### Optimizer

In [2]:
def newton_projection(g, model, X, v_direction, tau, res_tol=1e-3, max_newton_steps=100, verbose=True):
    """Project the model parameters onto {g = 0} by Newton iteration.

    Starting from the current parameters x, finds lam such that
    g(x + tau * v * lam) = 0, where v is given by `v_direction`. The model
    parameters are updated IN PLACE after every Newton step.

    Args:
        g: callable `g(model, X)` returning a 1-D tensor of k constraint values.
        model: object exposing `parameters()`; its parameters are modified.
        X: data forwarded unchanged to `g` (may be None).
        v_direction: sequence of k sequences, each aligned with
            `model.parameters()`; an entry may be None, in which case it
            contributes nothing for that parameter.
        tau: step-size factor scaling the multiplier update.
        res_tol: iteration stops once the 2-norm of g drops below this value.
        max_newton_steps: maximum number of Newton updates.
        verbose: when True (default), print the constraint value(s) at
            every iteration.

    Returns:
        Tuple `(converged, iter_step, lam)`: whether the tolerance was met,
        the number of Newton updates performed, and the accumulated
        multipliers as a NumPy array of length k.

    Raises:
        numpy.linalg.LinAlgError: if the k x k Newton matrix is singular.
    """
    params = list(model.parameters())
    iter_step = 0
    converged = False
    lam = None
    k = 0

    while iter_step < max_newton_steps:
        g_now = g(model, X)

        if lam is None:
            k = len(g_now)
            lam = np.zeros(k)

        if verbose:
            # .item() only works for a single constraint; format every
            # component so that k > 1 does not raise.
            g_values = g_now.detach().reshape(-1).tolist()
            if len(g_values) == 1:
                g_text = f'{g_values[0]:.3e}'
            else:
                g_text = '[' + ', '.join(f'{val:.3e}' for val in g_values) + ']'
            print (f'projection step={iter_step}, g={g_text}')

        if torch.linalg.norm(g_now) < res_tol:
            converged = True
            break

        # Only first-order gradients are needed here, so no create_graph;
        # retain_graph keeps the graph alive across the k components.
        grad_g = [torch.autograd.grad(g_now[idx], params,
                                      retain_graph=True, allow_unused=True) for idx in range(k)]

        with torch.no_grad():
            # mat[i, j] = <grad g_i at current parameters, v_j>
            mat = np.zeros((k, k))
            for i in range(k):
                for j in range(k):
                    for z1, z2 in zip(grad_g[i], v_direction[j]):
                        if z1 is not None and z2 is not None:
                            mat[i, j] += (z1 * z2).sum().item()

            dlam = -1.0 * np.linalg.solve(mat, g_now.detach().cpu().numpy()) / tau
            lam += dlam
            for i in range(k):
                shift = float(dlam[i]) * tau
                for param, z2 in zip(params, v_direction[i]):
                    if z2 is not None:
                        param.add_(z2 * shift)

        iter_step += 1

    if lam is None:
        # The loop never ran (max_newton_steps <= 0): no update was applied.
        lam = np.zeros(len(v_direction))

    return converged, iter_step, lam
        
class ProjectedSGD():
    """Noisy gradient step on f followed by a Newton projection onto g = 0.

    Args:
        model: object exposing `parameters()`; updated in place.
        f: callable `f(model, X)` returning the scalar objective.
        g: callable `g(model, X)` returning a 1-D tensor of constraint values.
        beta: scales the injected noise as sqrt(2 * tau / beta).
        tau: step size of the gradient step, also passed to the projection.
        res_tol: residual tolerance passed to `newton_projection`.
        max_newton_steps: iteration cap passed to `newton_projection`.
        verbose: stored on the instance; not consulted by any method here.
    """

    def __init__(self, model, f, g, beta=1.0, tau=1e-1, res_tol=1e-3, max_newton_steps=100, verbose=True):
        self.f = f
        self.g = g
        self.tau = tau
        self.res_tol = res_tol
        self.max_newton_steps = max_newton_steps
        self.verbose = verbose
        self.model = model
        self.beta = beta
        self.X = None
        # Defined up front so the accessors below are usable before the
        # first call to step().
        self.loss = None
        self.converged = False
        self.projection_iter_steps = 0

    def reset_data(self, X):
        """Set the data that is forwarded to `f` and `g`."""
        self.X = X

    def unconstrained_update(self):
        """Apply x <- x - tau * grad f(x) + sqrt(2 * tau / beta) * noise in place.

        Stores the objective evaluated before the update in `self.loss`.
        Parameters that `f` does not depend on receive the noise term only.
        """
        self.loss = self.f(self.model, self.X)
        params = list(self.model.parameters())
        self._grad_f = torch.autograd.grad(self.loss, params, allow_unused=True)
        noise_scale = math.sqrt(2.0 * self.tau / self.beta)

        with torch.no_grad():
            for param, grad in zip(params, self._grad_f):
                # randn_like matches the parameter's device and dtype.
                rn = torch.randn_like(param)
                if grad is not None:
                    param.add_(-1.0 * grad * self.tau + noise_scale * rn)
                else:
                    param.add_(noise_scale * rn)

    def projection_update(self):
        """Project the parameters back onto g = 0 along grad g at the current point."""
        g_val = self.g(self.model, self.X)
        params = list(self.model.parameters())
        # The gradients are used as fixed directions only, so a first-order
        # graph is sufficient.
        grad_g_prev = [torch.autograd.grad(g_val[idx], params,
                                           retain_graph=True, allow_unused=True)
                       for idx in range(len(g_val))]

        self.converged, self.projection_iter_steps, _ = newton_projection(
            self.g, self.model, self.X, grad_g_prev, self.tau,
            self.res_tol, self.max_newton_steps)

    def is_succeed(self):
        """Return whether the most recent projection converged (False before any step)."""
        return self.converged

    def projection_steps(self):
        """Return the Newton update count of the most recent projection (0 before any step)."""
        return self.projection_iter_steps

    def step(self, X):
        """Run one full update on data X: noisy gradient step, then projection."""
        self.reset_data(X)
        self.unconstrained_update()
        self.projection_update()
        
class ProjectedLangevin():
    """Constrained Langevin-type scheme with explicit momenta.

    Each `step` refreshes the momenta, performs a half momentum update and a
    position update, projects the positions onto g = 0 with
    `newton_projection`, completes the momentum update, projects the momenta
    against the constraint gradients and finally flips their sign.

    Args:
        model: object exposing `parameters()`; updated in place.
        f: callable `f(model, X)` returning the scalar objective.
        g: callable `g(model, X)` returning a 1-D tensor of constraint values.
        tau: step size.
        beta: scales the refresh noise as sqrt((1 - alpha**2) / beta).
        alpha: fraction of the old momentum kept at each refresh.
        res_tol: residual tolerance passed to `newton_projection`.
        max_newton_steps: iteration cap passed to `newton_projection`.
    """

    def __init__(self, model, f, g, tau=0.1, beta=1.0, alpha=0.5, res_tol=1e-5, max_newton_steps=100):
        self.model = model
        self.tau = tau
        self.beta = beta
        self.alpha = alpha
        self.res_tol = res_tol
        self.max_newton_steps = max_newton_steps
        self.g = g
        self.f = f
        self.X = None
        self.loss = None
        # Outcome of the most recent position projection.
        self.converged = False
        self.projection_iter_steps = 0
        # One momentum buffer per parameter, matching its device and dtype.
        self.p_list = [torch.zeros_like(param.detach()) for param in model.parameters()]

    def reset_data(self, X):
        """Set the data that is forwarded to `f` and `g`."""
        self.X = X

    def unconstrained_update(self):
        """Half momentum update p <- p - tau/2 * grad f, then x <- x + tau * p."""
        self.loss = self.f(self.model, self.X)
        params = list(self.model.parameters())
        grad_f = torch.autograd.grad(self.loss, params, allow_unused=True)

        with torch.no_grad():
            # momentum update; parameters unused by f yield a None gradient
            for mom, grad in zip(self.p_list, grad_f):
                if grad is not None:
                    mom.add_(-0.5 * self.tau * grad)

            # position update
            for param, mom in zip(params, self.p_list):
                param.add_(self.tau * mom)

    def projection_momentum(self):
        """Remove from the momenta their components along the gradients of g.

        Solves (grad g grad g^T) c = grad g p and applies p <- p - grad g^T c
        with the gradients evaluated at the current parameters.

        Raises:
            numpy.linalg.LinAlgError: if the Gram matrix of grad g is singular.
        """
        g_val = self.g(self.model, self.X)
        k = len(g_val)
        params = list(self.model.parameters())
        grad_g = [torch.autograd.grad(g_val[idx], params,
                                      retain_graph=True, allow_unused=True) for idx in range(k)]
        with torch.no_grad():
            mat = np.zeros((k, k))
            for i in range(k):
                for j in range(k):
                    for z1, z2 in zip(grad_g[i], grad_g[j]):
                        if z1 is not None and z2 is not None:
                            mat[i, j] += (z1 * z2).sum().item()

            rhs = np.zeros(k)
            for idx in range(k):
                for mom, grad in zip(self.p_list, grad_g[idx]):
                    if grad is not None:
                        rhs[idx] += (mom * grad).sum().item()

            coeff = np.linalg.solve(mat, rhs)

            for idx in range(k):
                weight = float(coeff[idx])
                for mom, grad in zip(self.p_list, grad_g[idx]):
                    if grad is not None:
                        mom.add_(-1.0 * weight * grad)

    def momentum_refresh(self):
        """Set p <- alpha * p + sqrt((1 - alpha**2) / beta) * noise, then project p."""
        noise_scale = math.sqrt((1 - self.alpha**2) / self.beta)
        for mom in self.p_list:
            rn = torch.randn_like(mom)
            mom.add_((self.alpha - 1.0) * mom + noise_scale * rn)
        self.projection_momentum()

    def projection(self):
        """Project positions onto g = 0, finish the momentum update, flip momenta.

        The convergence flag and Newton update count of the position
        projection are stored in `self.converged` and
        `self.projection_iter_steps`.
        """
        g_val = self.g(self.model, self.X)
        k = len(g_val)
        params = list(self.model.parameters())
        grad_g_prev = [torch.autograd.grad(g_val[idx], params,
                                           retain_graph=True, allow_unused=True) for idx in range(k)]

        # project position, x -> x_{1}
        self.converged, self.projection_iter_steps, lam = newton_projection(
            self.g, self.model, self.X, grad_g_prev, self.tau,
            self.res_tol, self.max_newton_steps)

        self.loss = self.f(self.model, self.X)
        grad_f = torch.autograd.grad(self.loss, params, allow_unused=True)

        with torch.no_grad():
            # modify momentum to get p_{1/2}: the position moved by
            # tau * grad_g_prev * lam, so the momentum moves by grad_g_prev * lam
            for idx in range(k):
                shift = float(lam[idx])
                for mom, grad in zip(self.p_list, grad_g_prev[idx]):
                    if grad is not None:
                        mom.add_(grad * shift)

            # momentum update: p_{1/2} -> p_1
            for mom, grad in zip(self.p_list, grad_f):
                if grad is not None:
                    mom.add_(-0.5 * self.tau * grad)

        self.projection_momentum()

        # p_1 -> p_{1,-}
        for mom in self.p_list:
            mom *= -1.0

    def is_succeed(self):
        """Return whether the most recent position projection converged."""
        return self.converged

    def projection_steps(self):
        """Return the Newton update count of the most recent position projection."""
        return self.projection_iter_steps

    def step(self, X):
        """Run one full update on data X."""
        self.reset_data(X)
        self.momentum_refresh()
        self.unconstrained_update()
        self.projection()

In [3]:
class func_g:
    """Callable constraint built from stored data X and a target value y.

    Calling the instance with a model returns, as a one-element tensor, the
    sum of the squared model outputs on X minus y.
    """

    def __init__(self, X, y=1):
        self.X = X
        self.y = y
        self.k = 1

    def __call__(self, model):
        outputs = model(self.X)
        residual = (outputs**2).sum() - self.y
        return residual.reshape((1))
    
class Simple(torch.nn.Module):
    """Toy module holding a single two-component parameter `x`, initialised to (3, 1)."""

    def __init__(self):
        super().__init__()
        initial = torch.tensor([3.0, 1.0])
        self.x = torch.nn.Parameter(initial)

    def forward(self, X):
        """Return the sum of all entries of X; the parameter `x` is not used."""
        total = X.sum()
        return total
    
def simple_f(model, X=None):
    """Objective: sum of the first two components of `model.x` (X is ignored)."""
    coords = model.x
    return coords[0] + coords[1]

def simple_g(model, X=None):
    """Constraint 0.5 * (x0^2 + x1^2 - 1) as a one-element tensor (X is ignored)."""
    coords = model.x
    residual = coords[0]**2 + coords[1]**2 - 1.0
    return 0.5 * residual.reshape((1))

In [7]:
class SimpleBad(torch.nn.Module):
    """Toy module holding a single three-component parameter `x`, initialised to (1, 1, 0.1)."""

    def __init__(self):
        super().__init__()
        initial = torch.tensor([1., 1.0, 0.1])
        self.x = torch.nn.Parameter(initial)

    def forward(self, X):
        """Return the sum of all entries of X; the parameter `x` is not used."""
        total = X.sum()
        return total
    
def simple_g_bad(model, X=None):
    """Constraint 0.5 * ((x1 + x2)^2 + (x0 + x2)^2 - 1) as a one-element tensor (X is ignored)."""
    coords = model.x
    first = coords[1] + coords[2]
    second = coords[0] + coords[2]
    residual = first**2 + second**2 - 1.0
    return 0.5 * residual.reshape((1))
# Alternative constraint kept for reference (disabled):
#    return 0.5 * ((model.x[1]+model.x[2])**2 + (model.x[0]+model.x[2]-1.0)**2).reshape((1))

# Standalone check of newton_projection on the three-parameter toy model.
model = SimpleBad()

# Constraint value at the initial parameters (one-element tensor).
g_val = simple_g_bad(model, None)
# One tuple of per-parameter gradients for each constraint component; passed
# below as the fixed direction(s) along which the parameters are moved.
grad_g_prev = [torch.autograd.grad(g_val[idx], model.parameters(), 
                                      create_graph=True, retain_graph=True, allow_unused=True) for idx in range(len(g_val))] 
        
# Being the last expression of the cell, the returned tuple
# (converged, iter_step, lam) is shown as the cell output.
newton_projection(simple_g_bad, model, None, grad_g_prev, tau=0.1)

projection step=0, g=7.100e-01
projection step=1, g=1.042e-01
projection step=2, g=4.489e-03
projection step=3, g=9.954e-06


(True, 3, array([-1.19056413]))

### Test 1

In [8]:
# Test 1: minimise simple_f subject to simple_g = 0 with the projected SGD scheme.
model = Simple()
opt = ProjectedSGD(model, simple_f, simple_g, tau=0.01, beta=20.0, verbose=True)
# Alternative optimizer (disabled):
#opt = ProjectedLangevin(model, simple_f, simple_g, tau=0.3, beta=8.0, alpha=0.5)

n_steps = 1000
for idx in range(n_steps):
    # Neither simple_f nor simple_g uses data, so None is passed as X.
    opt.step(None)
    if idx % 100 == 0:
        # opt.loss holds the objective evaluated at the start of this step.
        print (f'step={idx}, loss={opt.loss.detach().numpy()}')
    #print (model.x.detach().numpy())
# Final parameter values after all steps.
print (model.x.detach().numpy())

projection step=0, g=4.260e+00
projection step=1, g=9.531e-01
projection step=2, g=1.563e-01
projection step=3, g=9.305e-03
projection step=4, g=4.250e-05
step=0, loss=4.0
projection step=0, g=-2.196e-02
projection step=1, g=2.521e-04
projection step=0, g=-1.208e-02
projection step=1, g=7.468e-05
projection step=0, g=7.020e-02
projection step=1, g=2.161e-03
projection step=2, g=2.325e-06
projection step=0, g=2.696e-02
projection step=1, g=3.449e-04
projection step=0, g=-3.545e-02
projection step=1, g=6.762e-04
projection step=0, g=-3.201e-02
projection step=1, g=5.473e-04
projection step=0, g=-1.680e-02
projection step=1, g=1.460e-04
projection step=0, g=-3.042e-02
projection step=1, g=4.928e-04
projection step=0, g=-5.456e-02
projection step=1, g=1.671e-03
projection step=2, g=1.431e-06
projection step=0, g=-4.501e-03
projection step=1, g=1.025e-05
projection step=0, g=-2.022e-02
projection step=1, g=2.130e-04
projection step=0, g=5.796e-02
projection step=1, g=1.505e-03
projection st

projection step=0, g=1.124e-03
projection step=1, g=5.960e-07
projection step=0, g=4.504e-02
projection step=1, g=9.305e-04
projection step=0, g=4.460e-02
projection step=1, g=9.131e-04
projection step=0, g=4.869e-02
projection step=1, g=1.080e-03
projection step=2, g=5.960e-07
projection step=0, g=5.325e-03
projection step=1, g=1.407e-05
projection step=0, g=1.934e-04
projection step=0, g=4.281e-02
projection step=1, g=8.439e-04
projection step=0, g=1.172e-02
projection step=1, g=6.717e-05
projection step=0, g=5.304e-02
projection step=1, g=1.272e-03
projection step=2, g=8.345e-07
projection step=0, g=7.240e-03
projection step=1, g=2.587e-05
projection step=0, g=-3.179e-02
projection step=1, g=5.395e-04
projection step=0, g=6.828e-03
projection step=1, g=2.295e-05
projection step=0, g=2.180e-02
projection step=1, g=2.277e-04
projection step=0, g=6.392e-02
projection step=1, g=1.811e-03
projection step=2, g=1.609e-06
projection step=0, g=1.826e-02
projection step=1, g=1.609e-04
project

projection step=1, g=3.976e-05
projection step=0, g=5.603e-02
projection step=1, g=1.411e-03
projection step=2, g=1.013e-06
projection step=0, g=2.798e-02
projection step=1, g=3.706e-04
projection step=0, g=-2.341e-02
projection step=1, g=2.874e-04
projection step=0, g=-2.052e-02
projection step=1, g=2.195e-04
projection step=0, g=-2.464e-02
projection step=1, g=3.193e-04
projection step=0, g=1.468e-03
projection step=1, g=1.073e-06
projection step=0, g=2.943e-03
projection step=1, g=4.292e-06
projection step=0, g=4.129e-02
projection step=1, g=7.875e-04
projection step=0, g=4.234e-02
projection step=1, g=8.265e-04
projection step=0, g=8.170e-03
projection step=1, g=3.284e-05
projection step=0, g=-1.093e-02
projection step=1, g=6.104e-05
projection step=0, g=-1.602e-02
projection step=1, g=1.326e-04
projection step=0, g=3.324e-02
projection step=1, g=5.181e-04
projection step=0, g=2.121e-02
projection step=1, g=2.158e-04
projection step=0, g=5.898e-02
projection step=1, g=1.556e-03
pro

projection step=0, g=-3.047e-04
projection step=0, g=5.345e-02
projection step=1, g=1.291e-03
projection step=2, g=8.345e-07
projection step=0, g=-1.777e-02
projection step=1, g=1.638e-04
projection step=0, g=3.376e-02
projection step=1, g=5.339e-04
projection step=0, g=4.604e-02
projection step=1, g=9.707e-04
projection step=0, g=1.045e-02
projection step=1, g=5.352e-05
projection step=0, g=5.138e-02
projection step=1, g=1.197e-03
projection step=2, g=7.153e-07
projection step=0, g=4.035e-02
projection step=1, g=7.532e-04
projection step=0, g=3.533e-02
projection step=1, g=5.828e-04
projection step=0, g=2.759e-03
projection step=1, g=3.815e-06
projection step=0, g=1.499e-02
projection step=1, g=1.090e-04
projection step=0, g=1.495e-02
projection step=1, g=1.086e-04
projection step=0, g=4.042e-02
projection step=1, g=7.559e-04
projection step=0, g=4.408e-02
projection step=1, g=8.929e-04
projection step=0, g=1.023e-01
projection step=1, g=4.348e-03
projection step=2, g=9.358e-06
projec