## This notebook implements two constrained training schemes based on Newton's method.

### Newton's method

---

### Given $g: \mathbb{R}^N\rightarrow \mathbb{R}^k$, $x\in \mathbb{R}^N$, $v \in \mathbb{R}^{N\times k}$ find $\lambda\in \mathbb{R}^k$, such that  
$$
\begin{aligned}
   g(x + \tau v \lambda) = 0 
   \end{aligned}
$$

Newton's method:

$$
\lambda_{n+1} = \lambda_n - \tau^{-1}(\nabla g(x + \tau v\lambda_n) v)^{-1} g(x + \tau v \lambda_n)
$$



### First, load necessary packages

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import torch
import math 
import torch.nn as nn
import random
import itertools 
from tqdm import tqdm
import os
import time

### Optimizer

In [14]:
def newton_projection(g, model, Y, v_direction, tau, res_tol=1e-3, max_newton_steps=100, verbose=True):
    """Move the model parameters along ``v_direction`` until ``g(model, Y)`` vanishes.

    Newton's method is applied to the multipliers ``lam``: at every step the
    k x k matrix ``mat[i, j] = <grad g_i(current parameters), v_direction[j]>``
    is assembled, ``dlam = -mat^{-1} g / tau`` is computed, and the parameters
    are shifted in place by ``tau * sum_i dlam[i] * v_direction[i]``.

    Args:
        g: Callable invoked as ``g(model, Y)``; must return a 1-D tensor with
            one entry per constraint (it is indexed and passed to ``len``).
        model: Object exposing ``parameters()``; its parameters are modified
            in place. When ``verbose`` is true the attribute ``model.x`` is
            also read for the progress message.
        Y: Opaque second argument forwarded to ``g``.
        v_direction: Sequence with one entry per constraint; each entry is a
            sequence aligned with ``model.parameters()`` whose items are
            tensors or ``None``.
        tau: Step size that scales the multipliers.
        res_tol: The iteration stops once ``||g|| < res_tol``.
        max_newton_steps: Maximum number of Newton updates.
        verbose: Print the search directions and the residual at every step.

    Returns:
        Tuple ``(converged, iter_step, lam)``: whether the residual dropped
        below ``res_tol``, the number of Newton updates performed, and the
        accumulated multipliers as a NumPy array of length k.
    """
    if verbose:
        print(v_direction)

    # Materialise once: parameters() returns a fresh generator on every call.
    params = list(model.parameters())
    iter_step = 0
    converged = False
    lam = None

    while True:
        g_now = g(model, Y)

        if lam is None:
            k = len(g_now)
            lam = np.zeros(k)

        if verbose:
            g_plain = g_now.detach()
            # .item() only works for a single constraint; fall back to the
            # full array when there are several.
            g_text = f'{g_plain.item():.3e}' if g_plain.numel() == 1 else f'{g_plain.numpy()}'
            print(f'projection step={iter_step}, g={g_text}, x={model.x.detach().numpy()}')

        if torch.linalg.norm(g_now) < res_tol:
            converged = True
            break

        # Checked after the residual test so that the iterate produced by the
        # last permitted update still gets a chance to be reported as converged.
        if iter_step >= max_newton_steps:
            break

        # Only first-order gradients are needed (they are consumed under
        # no_grad below), so no higher-order graph is built.
        grad_g = [torch.autograd.grad(g_now[idx], params, retain_graph=True, allow_unused=True)
                  for idx in range(k)]

        with torch.no_grad():
            mat = np.zeros((k, k))
            for i in range(k):
                for j in range(k):
                    for z1, z2 in zip(grad_g[i], v_direction[j]):
                        # Either side may be None for parameters g does not touch.
                        if z1 is not None and z2 is not None:
                            mat[i, j] += (z1 * z2).sum().item()

            dlam = -1.0 * np.linalg.solve(mat, g_now.detach().numpy()) / tau
            lam += dlam

            for i in range(k):
                shift = float(dlam[i]) * tau
                for param, z2 in zip(params, v_direction[i]):
                    if z2 is not None:
                        param.add_(z2 * shift)

        iter_step += 1

    return converged, iter_step, lam
        
class ProjectedSGD():
    """Noisy gradient step on ``f`` followed by a Newton projection onto ``g = 0``.

    Each call to :meth:`step` first moves every parameter by
    ``-tau * grad f + sqrt(2 * tau / beta) * N(0, I)`` and then calls
    ``newton_projection`` with the constraint gradients evaluated at the
    moved parameters as search directions.

    Args:
        model: Object exposing ``parameters()``; updated in place.
        f: Callable invoked as ``f(model, X)`` returning the scalar loss.
        g: Callable invoked as ``g(model, Y)`` returning the constraint values
            (indexed and passed to ``len``).
        beta: Divides the noise variance (``2 * tau / beta``).
        tau: Step size.
        res_tol: Residual tolerance forwarded to ``newton_projection``.
        max_newton_steps: Iteration cap forwarded to ``newton_projection``.
        verbose: Forwarded to ``newton_projection``.
    """

    def __init__(self, model, f, g, beta=1.0, tau=1e-1, res_tol=1e-3, max_newton_steps=100, verbose=True):
        self.f = f
        self.g = g
        self.tau = tau
        self.res_tol = res_tol
        self.max_newton_steps = max_newton_steps
        self.verbose = verbose
        self.model = model
        self.beta = beta
        self.X = None
        self.Y = None
        # Defined up front so the accessors below are usable before the
        # first call to step().
        self.loss = None
        self.converged = False
        self.projection_iter_steps = 0

    def reset_data(self, X, Y):
        """Store the data passed to ``f`` (X) and ``g`` (Y) on later updates."""
        self.X = X
        self.Y = Y

    def unconstrained_update(self):
        """Apply one noisy gradient step on ``f`` without enforcing the constraint."""
        params = list(self.model.parameters())
        self.loss = self.f(self.model, self.X)
        self._grad_f = torch.autograd.grad(self.loss, params, allow_unused=True)
        noise_scale = math.sqrt(2.0 * self.tau / self.beta)

        with torch.no_grad():
            for param, grad in zip(params, self._grad_f):
                # Parameters that f does not depend on receive noise only.
                if grad is not None:
                    param.add_(grad, alpha=-self.tau)
                # randn_like keeps the noise on the parameter's dtype/device.
                param.add_(torch.randn_like(param), alpha=noise_scale)

    def projection_update(self):
        """Project the current parameters back onto ``g = 0`` with Newton's method."""
        g_val = self.g(self.model, self.Y)
        grad_g_prev = [torch.autograd.grad(g_val[idx], self.model.parameters(),
                                           create_graph=True, retain_graph=True, allow_unused=True)
                       for idx in range(len(g_val))]

        self.converged, self.projection_iter_steps, _ = newton_projection(
            self.g, self.model, self.Y, grad_g_prev, self.tau,
            self.res_tol, self.max_newton_steps, self.verbose)

    def is_succeed(self):
        """Return whether the most recent projection converged (False before any step)."""
        return self.converged

    def projection_steps(self):
        """Return the Newton update count of the most recent projection (0 before any step)."""
        return self.projection_iter_steps

    def step(self, X, Y):
        """Run one full update: store the data, take the noisy step, then project."""
        self.reset_data(X, Y)
        self.unconstrained_update()
        self.projection_update()
        
class ProjectedLangevin():
    """Constrained momentum-based update scheme built on ``newton_projection``.

    One momentum buffer per model parameter is kept in ``p_list``. A call to
    :meth:`step` refreshes the momenta with noise, takes a half momentum
    step plus a full position step, projects the position back onto
    ``g = 0``, completes the momentum step, projects the momenta and finally
    flips their sign.

    Args:
        model: Object exposing ``parameters()``; updated in place.
        f: Callable invoked as ``f(model, X)`` returning the scalar loss.
        g: Callable invoked as ``g(model, Y)`` returning the constraint values
            (indexed and passed to ``len``).
        tau: Step size.
        beta: Divides the variance of the momentum noise.
        alpha: Fraction of the old momentum kept by :meth:`momentum_refresh`.
        res_tol: Residual tolerance forwarded to ``newton_projection``.
        max_newton_steps: Iteration cap forwarded to ``newton_projection``.
        verbose: Forwarded to ``newton_projection``.
    """

    def __init__(self, model, f, g, tau=0.1, beta=1.0, alpha=0.5, res_tol=1e-5, max_newton_steps=100,
                 verbose=True):
        self.model = model
        self.tau = tau
        self.beta = beta
        self.alpha = alpha
        self.res_tol = res_tol
        self.max_newton_steps = max_newton_steps
        self.verbose = verbose
        self.g = g
        self.f = f
        self.X = None
        self.Y = None
        self.loss = None
        # Outcome of the most recent position projection.
        self.converged = False
        self.projection_iter_steps = 0
        # One zero-initialised momentum buffer per parameter, matching its
        # shape, dtype and device.
        self.p_list = [torch.zeros_like(param) for param in model.parameters()]

    def reset_data(self, X, Y):
        """Store the data passed to ``f`` (X) and ``g`` (Y) on later updates."""
        self.X = X
        self.Y = Y

    def is_succeed(self):
        """Return whether the most recent position projection converged."""
        return self.converged

    def projection_steps(self):
        """Return the Newton update count of the most recent position projection."""
        return self.projection_iter_steps

    def _constraint_gradients(self):
        """Return ``(k, grads)``: constraint count and per-constraint parameter gradients."""
        g_val = self.g(self.model, self.Y)
        k = len(g_val)
        grads = [torch.autograd.grad(g_val[idx], self.model.parameters(),
                                     create_graph=True, retain_graph=True, allow_unused=True)
                 for idx in range(k)]
        return k, grads

    def unconstrained_update(self):
        """Half momentum step along ``-grad f`` followed by a full position step."""
        params = list(self.model.parameters())
        self.loss = self.f(self.model, self.X)
        grad_f = torch.autograd.grad(self.loss, params, allow_unused=True)

        with torch.no_grad():
            # momentum update
            for momentum, grad in zip(self.p_list, grad_f):
                # A None gradient means f does not depend on this parameter.
                if grad is not None:
                    momentum.add_(grad, alpha=-0.5 * self.tau)

            # position update
            for param, momentum in zip(params, self.p_list):
                param.add_(momentum, alpha=self.tau)

    def projection_momentum(self):
        """Remove from the momenta their components along the constraint gradients.

        Solves ``mat @ coeff = rhs`` with ``mat[i, j] = <grad g_i, grad g_j>``
        and ``rhs[i] = <p, grad g_i>``, then subtracts
        ``sum_i coeff[i] * grad g_i`` from the momenta.
        """
        k, grad_g = self._constraint_gradients()

        with torch.no_grad():
            mat = np.zeros((k, k))
            for i in range(k):
                for j in range(k):
                    for z1, z2 in zip(grad_g[i], grad_g[j]):
                        if z1 is not None and z2 is not None:
                            mat[i, j] += (z1 * z2).sum().item()

            rhs = np.zeros(k)
            for idx in range(k):
                for momentum, grad in zip(self.p_list, grad_g[idx]):
                    if grad is not None:
                        rhs[idx] += (momentum * grad).sum().item()

            coeff = np.linalg.solve(mat, rhs)

            for idx in range(k):
                for momentum, grad in zip(self.p_list, grad_g[idx]):
                    if grad is not None:
                        momentum.add_(grad, alpha=-float(coeff[idx]))

    def momentum_refresh(self):
        """Replace each momentum by ``alpha * p + sqrt((1 - alpha**2) / beta) * N(0, I)``, then project it."""
        noise_scale = math.sqrt((1 - self.alpha**2) / self.beta)
        with torch.no_grad():
            for momentum in self.p_list:
                momentum.mul_(self.alpha).add_(torch.randn_like(momentum), alpha=noise_scale)
        self.projection_momentum()

    def projection(self):
        """Project the position onto ``g = 0`` and finish the momentum update."""
        params = list(self.model.parameters())
        k, grad_g_prev = self._constraint_gradients()

        # project position, x -> x_{1}
        self.converged, self.projection_iter_steps, lam = newton_projection(
            self.g, self.model, self.Y, grad_g_prev, self.tau,
            self.res_tol, self.max_newton_steps, self.verbose)

        self.loss = self.f(self.model, self.X)
        grad_f = torch.autograd.grad(self.loss, params, allow_unused=True)

        with torch.no_grad():
            # modify momentum to get p_{1/2}: add the multiplier term found
            # by the position projection
            for idx in range(k):
                for momentum, grad in zip(self.p_list, grad_g_prev[idx]):
                    if grad is not None:
                        momentum.add_(grad, alpha=float(lam[idx]))

            # momentum update: p_{1/2} -> p_1
            for momentum, grad in zip(self.p_list, grad_f):
                if grad is not None:
                    momentum.add_(grad, alpha=-0.5 * self.tau)

        self.projection_momentum()

        # p_1 -> p_{1,-}
        with torch.no_grad():
            for momentum in self.p_list:
                momentum.neg_()

    def step(self, X, Y):
        """Run one full update: refresh momenta, unconstrained move, then projection."""
        self.reset_data(X, Y)
        self.momentum_refresh()
        self.unconstrained_update()
        self.projection()
#        self.momentum_refresh()

In [23]:
from test_problems import SimpleTest

# Smoke test for newton_projection on the project-local SimpleTest problem.
# NOTE(review): the meaning of SimpleTest's constructor argument and of the
# list passed to create_model is defined in test_problems, not in this notebook.
problem = SimpleTest(1)
model = problem.create_model([1, -2.0])

# Constraint values and their parameter gradients at the starting point; the
# gradients are used as the search directions of the projection.
g_val = problem.G(model, None)
grad_g_prev = [torch.autograd.grad(g_val[idx], model.parameters(), 
                                      create_graph=True, retain_graph=True, allow_unused=True) for idx in range(len(g_val))] 

print (grad_g_prev)
# Moves model's parameters in place; the returned tuple is not used here.
newton_projection(problem.G, model, None, grad_g_prev, tau=0.1, res_tol=1e-3)

# Show the projected parameters and the remaining constraint residual.
print (model.x)
print (problem.G(model))

[(tensor([ 0., -1.], grad_fn=<AddBackward0>),)]
[(tensor([ 0., -1.], grad_fn=<AddBackward0>),)]
projection step=0, g=5.000e-01, x=[ 1. -2.]
projection step=1, g=1.250e-01, x=[ 1.  -1.5]
projection step=2, g=3.125e-02, x=[ 1.   -1.25]
projection step=3, g=7.812e-03, x=[ 1.    -1.125]
projection step=4, g=1.953e-03, x=[ 1.     -1.0625]
projection step=5, g=4.883e-04, x=[ 1.      -1.03125]
Parameter containing:
tensor([ 1.0000, -1.0312], requires_grad=True)
tensor([0.0005], grad_fn=<MulBackward0>)


### Test 1

In [19]:
from test_problems import SimpleTest

# Test 1: run ProjectedSGD on the project-local SimpleTest problem.
problem = SimpleTest(1.0)

# make sure the initial parameters satisfy the constraint.
model = problem.create_model([1.0, -1.0])

opt = ProjectedSGD(model, problem.F, problem.G, tau=0.01, beta=20.0, verbose=True)
# Alternative optimizer, left disabled.
# NOTE(review): simple_f / simple_g are not defined in the visible notebook.
#opt = ProjectedLangevin(model, simple_f, simple_g, tau=0.3, beta=8.0, alpha=0.5)

n_steps = 1000
for idx in range(n_steps):
    # None is passed for both X and Y, so F and G receive None as their data argument.
    opt.step(None, None)
    # Report the loss every 100 steps.
    if idx % 100 == 0:
        print (f'step={idx}, loss={opt.loss.detach().numpy()}')
    #print (model.x.detach().numpy())
# Final parameters after all steps.
print (model.x.detach().numpy())


[(tensor([ 0.9379, -0.0242], grad_fn=<AddBackward0>),)]
projection step=0, g=-3.693e-02, x=[ 0.96205753 -0.9862185 ]
projection step=1, g=1.509e-03, x=[ 1.0014076 -0.9872322]
projection step=2, g=2.146e-06, x=[ 0.99992114 -0.98719394]
step=0, loss=0.0
[(tensor([1.0778, 0.0418], grad_fn=<AddBackward0>),)]
projection step=0, g=3.752e-02, x=[ 1.036003   -0.99419504]
projection step=1, g=1.256e-03, x=[ 1.0012394  -0.99554354]
projection step=2, g=1.609e-06, x=[ 0.99999195 -0.99559194]
[(tensor([ 0.9919, -0.0023], grad_fn=<AddBackward0>),)]
projection step=0, g=-5.769e-03, x=[ 0.9942112 -0.9965046]
projection step=1, g=3.380e-05, x=[ 1.0000277 -0.9965181]
[(tensor([1.0612, 0.0655], grad_fn=<AddBackward0>),)]
projection step=0, g=-2.210e-03, x=[ 0.995633  -0.9300993]
projection step=1, g=4.530e-06, x=[ 0.9977078  -0.92997116]
[(tensor([ 0.9494, -0.0040], grad_fn=<AddBackward0>),)]
projection step=0, g=-4.548e-02, x=[ 0.9534242 -0.9574364]
projection step=1, g=2.285e-03, x=[ 1.0013299  -0.957

projection step=0, g=-3.134e-02, x=[ 0.31062844 -1.2275931 ]
projection step=1, g=9.041e-04, x=[ 0.29490247 -1.2513756 ]
[(tensor([-0.7611, -1.0143], grad_fn=<AddBackward0>),)]
projection step=0, g=4.648e-02, x=[ 0.25322258 -1.2675383 ]
projection step=1, g=1.559e-03, x=[ 0.27522054 -1.2382214 ]
projection step=2, g=1.967e-06, x=[ 0.2760112 -1.2371677]
[(tensor([-0.7418, -1.0003], grad_fn=<AddBackward0>),)]
projection step=0, g=3.375e-02, x=[ 0.25854632 -1.2588698 ]
projection step=1, g=8.488e-04, x=[ 0.2746871 -1.2371031]
[(tensor([-0.7666, -1.0260], grad_fn=<AddBackward0>),)]
projection step=0, g=5.996e-02, x=[ 0.25941876 -1.2853944 ]
projection step=1, g=2.540e-03, x=[ 0.2874418 -1.2478878]
projection step=2, g=5.484e-06, x=[ 0.28873855 -1.2461523 ]
[(tensor([-0.6226, -0.9254], grad_fn=<AddBackward0>),)]
projection step=0, g=-2.601e-02, x=[ 0.30278257 -1.2281456 ]
projection step=1, g=6.086e-04, x=[ 0.28976294 -1.2474971 ]
[(tensor([-0.6608, -0.9517], grad_fn=<AddBackward0>),)]
proj

projection step=0, g=2.368e-02, x=[-0.06781242 -0.95334417]
projection step=1, g=3.183e-04, x=[-0.05624184 -0.9424941 ]
[(tensor([-1.0473, -1.0189], grad_fn=<AddBackward0>),)]
projection step=0, g=1.953e-02, x=[-0.0283671  -0.99058276]
projection step=1, g=2.245e-04, x=[-0.01878646 -0.9812616 ]
[(tensor([-1.1167, -1.0580], grad_fn=<AddBackward0>),)]
projection step=0, g=6.139e-02, x=[-0.05870069 -0.99928653]
projection step=1, g=2.011e-03, x=[-0.02972966 -0.9718384 ]
projection step=2, g=2.503e-06, x=[-0.028714   -0.97087616]
[(tensor([-1.0318, -0.9885], grad_fn=<AddBackward0>),)]
projection step=0, g=-1.051e-02, x=[-0.04329547 -0.94519544]
projection step=1, g=6.813e-05, x=[-0.04860457 -0.95028174]
[(tensor([-1.1309, -1.0459], grad_fn=<AddBackward0>),)]
projection step=0, g=5.059e-02, x=[-0.08497721 -0.960944  ]
projection step=1, g=1.367e-03, x=[-0.06086822 -0.9386466 ]
projection step=2, g=1.132e-06, x=[-0.06017929 -0.93800944]
[(tensor([-0.9869, -0.9644], grad_fn=<AddBackward0>),)]

projection step=1, g=4.702e-04, x=[ 0.5159585 -1.3731208]
[(tensor([-0.3394, -0.8568], grad_fn=<AddBackward0>),)]
projection step=0, g=9.420e-04, x=[ 0.5174307 -1.3742559]
[(tensor([-0.3302, -0.8498], grad_fn=<AddBackward0>),)]
projection step=0, g=-3.991e-03, x=[ 0.51953137 -1.3692987 ]
projection step=1, g=1.729e-05, x=[ 0.5179455 -1.3733793]
[(tensor([-0.4244, -0.9041], grad_fn=<AddBackward0>),)]
projection step=0, g=2.377e-02, x=[ 0.47970548 -1.3838148 ]
projection step=1, g=5.519e-04, x=[ 0.48981664 -1.362275  ]
[(tensor([-0.4076, -0.8958], grad_fn=<AddBackward0>),)]
projection step=0, g=2.041e-02, x=[ 0.48818284 -1.384008  ]
projection step=1, g=4.141e-04, x=[ 0.49677297 -1.3651305 ]
step=600, loss=-0.8724583387374878
[(tensor([-0.5299, -0.9590], grad_fn=<AddBackward0>),)]
projection step=0, g=5.185e-02, x=[ 0.42903444 -1.3880115 ]
projection step=1, g=2.330e-03, x=[ 0.45192495 -1.3465891 ]
projection step=2, g=5.603e-06, x=[ 0.4530551 -1.3445439]
[(tensor([-0.5288, -0.9635], gra

[(tensor([-0.6202, -0.9357], grad_fn=<AddBackward0>),)]
projection step=0, g=-1.248e-02, x=[ 0.3154524 -1.2511499]
projection step=1, g=1.375e-04, x=[ 0.30931017 -1.260416  ]
[(tensor([-0.7599, -1.0012], grad_fn=<AddBackward0>),)]
projection step=0, g=3.030e-02, x=[ 0.24129769 -1.2424848 ]
projection step=1, g=6.767e-04, x=[ 0.25587207 -1.2232825 ]
[(tensor([-0.7895, -1.0350], grad_fn=<AddBackward0>),)]
projection step=0, g=6.575e-02, x=[ 0.24551797 -1.2805192 ]
projection step=1, g=2.975e-03, x=[ 0.27615285 -1.2403574 ]
projection step=2, g=7.391e-06, x=[ 0.27767706 -1.2383592 ]
[(tensor([-0.7145, -0.9691], grad_fn=<AddBackward0>),)]
projection step=0, g=1.992e-03, x=[ 0.2545681 -1.2236801]
projection step=1, g=3.099e-06, x=[ 0.2555497 -1.2223488]
[(tensor([-0.7099, -0.9724], grad_fn=<AddBackward0>),)]
projection step=0, g=7.258e-03, x=[ 0.26251984 -1.2349387 ]
projection step=1, g=4.184e-05, x=[ 0.26607412 -1.2300701 ]
[(tensor([-0.8162, -1.0186], grad_fn=<AddBackward0>),)]
projectio

projection step=1, g=5.013e-05, x=[ 0.13454548 -1.1255034 ]
[(tensor([-0.8367, -0.9998], grad_fn=<AddBackward0>),)]
projection step=0, g=1.307e-02, x=[ 0.16304103 -1.1628174 ]
projection step=1, g=1.203e-04, x=[ 0.16947408 -1.1551307 ]
[(tensor([-0.8799, -0.9962], grad_fn=<AddBackward0>),)]
projection step=0, g=2.936e-03, x=[ 0.11627478 -1.1124439 ]
projection step=1, g=5.960e-06, x=[ 0.11773736 -1.1107881 ]
[(tensor([-0.9635, -1.0291], grad_fn=<AddBackward0>),)]
projection step=0, g=3.166e-02, x=[ 0.06558446 -1.0946754 ]
projection step=1, g=6.217e-04, x=[ 0.08093599 -1.0782789 ]
[(tensor([-0.9298, -1.0270], grad_fn=<AddBackward0>),)]
projection step=0, g=3.206e-02, x=[ 0.09713642 -1.1241109 ]
projection step=1, g=6.546e-04, x=[ 0.11266678 -1.1069582 ]
[(tensor([-0.8976, -1.0008], grad_fn=<AddBackward0>),)]
projection step=0, g=6.108e-03, x=[ 0.10315827 -1.1039456 ]
projection step=1, g=2.515e-05, x=[ 0.10619213 -1.100563  ]
[(tensor([-0.9253, -1.0144], grad_fn=<AddBackward0>),)]
proj