# Functions to train neural nets by solving the non-convex training problem

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import scipy
import torch
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import datetime
# Make float64 the default floating-point dtype for tensors and layers created after this point.
torch.set_default_dtype(torch.float64)
# Use a larger default font size for every matplotlib figure in this notebook.
plt.rcParams.update({'font.size': 20})

For $L=2,4$, we initialize a subset of the neurons to a solution of the min-norm version of the Lasso problem. The rest of the neurons are initialized randomly.

Train a neural net by running ```nonconvex_nn()```. Plot the results by running ```plotresults()```.

In [2]:
class PrepareData(Dataset):
    """Map-style dataset that pairs each row of ``X`` with the matching entry of ``y``.

    Inputs that are already torch tensors are stored as-is; anything else is
    converted with ``torch.from_numpy``.
    """

    def __init__(self, X, y):
        # Keep tensors untouched (no copy); convert everything else from numpy.
        self.X = X if torch.is_tensor(X) else torch.from_numpy(X)
        self.y = y if torch.is_tensor(y) else torch.from_numpy(y)

    def __len__(self):
        """Return the number of samples (length of ``X`` along its first axis)."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(features, target)`` pair stored at ``idx``."""
        features = self.X[idx]
        target = self.y[idx]
        return features, target
    
    
class DeepNarrowNet1D(nn.Module): #ez
    """ReLU network on scalar inputs with a width-``m_L`` final hidden layer.

    Args:
        m_L: number of neurons feeding the output layer.
        L: depth parameter; ``forward`` applies ``L - 2`` hidden steps before
            the ``one2m`` / ``last`` layers.
        par: selects the branch used by ``forward`` (see there).
    """

    def __init__(self, m_L, L, par):
        super().__init__()

        # NOTE: the layers are created in this exact order on purpose; changing
        # it would change which random numbers each layer receives under a
        # fixed seed.

        # generic layers: `one2one` is never called by forward(); `m2m` is only
        # called when par is truthy
        self.one2one = nn.Linear(1, 1, bias=True)
        self.m2m = nn.Linear(m_L, m_L, bias=True)

        # named layers, so individual layers can be initialised by name
        self.one2one1 = nn.Linear(1, 1, bias=True)
        self.one2one2 = nn.Linear(1, 1, bias=True)
        self.one2m = nn.Linear(1, m_L, bias=True)
        self.last = nn.Linear(m_L, 1, bias=True)

        self.L = L
        self.par = par
        self.m_L = m_L

        # boolean identity matrix used to zero the off-diagonal m2m weights
        self.mask = torch.eye(m_L, dtype=bool)

    def forward(self, x): #can I write this in matrix form
        """Map inputs of shape ``(..., 1)`` to outputs of shape ``(..., 1)``."""
        hidden_steps = self.L - 2

        if not self.par:
            # Chain of scalar ReLU layers; only the first two steps have a
            # layer attached, any further step leaves x unchanged.
            scalar_layers = (self.one2one1, self.one2one2)
            for step in range(hidden_steps):
                if step < len(scalar_layers):
                    x = F.relu(scalar_layers[step](x))
            x = F.relu(self.one2m(x))
            return self.last(x)

        # par branch: widen to m_L units, then reuse the same m2m layer for
        # every hidden step, with its weight matrix forced to be diagonal.
        x = self.one2m(x)
        for _ in range(hidden_steps):
            self.m2m.weight.data *= self.mask
            x = F.relu(self.m2m(x))
        return self.last(x)

    
def print_weights(module): #this function is from: https://wandb.ai/wandb_fc/tips/reports/How-to-Initialize-Weights-in-PyTorch--VmlldzoxNjcwOTg1
    """Print the weights (and the bias, when there is one) of an ``nn.Linear``.

    Any other module type is silently skipped, which makes the function
    suitable for ``net.apply(print_weights)``.
    """
    if not isinstance(module, nn.Linear):
        return

    print('weights = ', module.weight.data)
    bias = module.bias
    if bias is not None:
        print('bias = ', bias.data)

In [1]:
def setup(L, opt, move0toneg1):
    min_num_nuerons = 1 #default
    
    if opt:
        X = np.array([0,4,5,7,10,11]).astype(np.float64).reshape(-1,1) 
        y = np.array((0,0,0,2, 3, 3)).astype(np.float64) 
        seed= 7
        β=(1e-7)/L
        if L==2:
            m_L = 100
            num_epochs = int(1e3)
            min_num_nuerons = 3
        if L==3:
            m_L = 500 
            num_epochs = int(1e5)
        if L==4:
            m_L = 100
            num_epochs = int(1e3) 
    else:
        X = np.array([0,2,6,7]).astype(np.float64).reshape(-1,1) 
        y = np.array([0,0,3,3]).astype(np.float64) 
        seed= 8
        β=(1e-8)/L
        if L==2:
            m_L = 100
            num_epochs = int(1e6)
            min_num_nuerons = 2
        if L==3:
            m_L = 100 
            num_epochs = int(1e3)
        if L==4:           
            m_L = 100
            if move0toneg1:
                num_epochs = int(1.5e4) 
            else:
                num_epochs = int(1e3)

    if move0toneg1:
        X[0]=-1
        
    setup_dic={'X':X, 'y':y, 'seed':seed, 'β':β, 'm_L': m_L, 'num_epochs':num_epochs, 'min_num_nuerons':min_num_nuerons}
        
    return setup_dic

def prepTrainTestData(X,y,batchsize_frac=1):
    """Wrap ``(X, y)`` in a shuffled training loader and an ordered test loader.

    Both loaders are built from the same data.

    Args:
        X: 2-D array/tensor whose first dimension is the number of samples.
        y: targets, one per sample.
        batchsize_frac: fraction of the samples placed in each batch
            (1 means full-batch).

    Returns:
        ``(ds_train, ds_test)``: two ``DataLoader`` objects; only the training
        loader shuffles.
    """
    n, _ = X.shape
    # int() truncates, so a small fraction could give 0, which DataLoader
    # rejects; always keep at least one sample per batch.
    batch_size = max(1, int(n*batchsize_frac)) #n/2 can also help
    ds_train = DataLoader(PrepareData(X=X, y=y), batch_size=batch_size, shuffle=True)
    ds_test = DataLoader(PrepareData(X=X, y=y), batch_size=batch_size, shuffle=False)

    return ds_train, ds_test

def initnn(net, L, opt, initkinkat4, move0toneg1, min_num_nuerons, debug=False, custom_init=True, initrescale = True, init_normal=False, std=0.01):
    """Overwrite part of ``net``'s parameters with a hand-built sub-network.

    A minimal subset of neurons is set to fixed values; every parameter that
    is not explicitly written keeps whatever value it had (unless ``debug`` or
    ``init_normal`` resets ``one2m`` / ``last`` first). Only ``L == 2`` and
    ``L == 4`` have a custom initialisation; for any other ``L`` only the
    optional zero/normal reset is applied.

    Args:
        net: object exposing the layers ``one2m`` and ``last`` (and, for
            ``L == 4``, ``one2one1`` and ``one2one2``).
        L: depth parameter selecting which construction is used.
        opt: selects which set of hard-coded values is used.
        initkinkat4: for ``L == 4`` and falsy ``opt``, together with
            ``move0toneg1`` selects the values of ``xj1`` and ``α``.
        move0toneg1: see ``initkinkat4``.
        min_num_nuerons: number of neurons overwritten when ``L == 2``.
        debug: when truthy, ``one2m`` and ``last`` are zeroed first, so only
            the hand-built neurons are non-zero in those layers.
        custom_init: when falsy, ``net`` is returned untouched.
        initrescale: for ``L == 4``, spread ``|α|`` evenly over the layers
            (each weight gets ``|α|**(1/L)``) instead of putting ``α`` in the
            output weight only.
        init_normal: when truthy, ``one2m`` and ``last`` are drawn from
            ``N(0, std**2)`` before the hand-built neurons are written.
        std: standard deviation used by ``init_normal``.

    Returns:
        The same ``net`` object, modified in place.
    """
    if not custom_init:
        return net

    init_perfect = bool(debug)

    #initialize a minimal subset of neurons to generate an optimal solution, and initialize all other neurons randomly.
    if init_perfect: #rest of optimal net
        net.one2m.weight.data.fill_(0.0)
        net.one2m.bias.data.fill_(0.0)
        net.last.weight.data.fill_(0.0)
        net.last.bias.data.fill_(0.0)
    if init_normal:
        net.one2m.weight.data.normal_(mean=0.0,std=std)
        net.one2m.bias.data.normal_(mean=0.0,std=std)
        net.last.weight.data.normal_(mean=0.0,std=std)
        net.last.bias.data.normal_(mean=0.0,std=std)

    if L==2: #b is kink points, w is weights
        if opt:
            α = np.array((1, -2/3, -1/3))
            w = np.array((1,1,1))
            b = np.array((-5,-7,-10))
        else:
            w = np.array((1,1))
            b = np.array((-2,-5))
            slope = -3/(b[1]-b[0])
            α = np.array((slope, -1*slope))

        # split |α| evenly between the hidden and the output weight
        γ = np.abs(α)**(1/L)

        #optimal subnet
        net.one2m.weight.data[0:min_num_nuerons] = torch.tensor((w*γ).reshape(min_num_nuerons,1))
        net.one2m.bias.data[0:min_num_nuerons] = torch.tensor(b*γ)
        net.last.weight.data[0][0:min_num_nuerons]= torch.tensor(np.sign(α)*γ)

    if L==4:
        if opt:
            xj1 = 0
            xj2 = 4
            xj3 = 5
            α = 1
        else:
            if move0toneg1 and not initkinkat4:
                xj1 = -1
                α = 1
            else:
                xj1 = 0
                α = 3/2
            xj2 = 2
            xj3 = 2

        #s, j found in 1DNNs.ipynb examples
        s= np.array((-1,-1,1)) #l=3,k=0 to get ramp

        #a found from reconstruction theorem
        a1 = 2*xj2 - xj1 # R(xj1,xj2)
        a2 = max(s[0]*(xj2-a1),0)  #4
        a3 = max( s[1]*(max(s[0]*(xj3-a1),0)-a2) , 0)
        a = np.array((a1,a2,a3))

        #rescaling
        if initrescale:
            γ = abs(α)**(1/L) #1
            biasγ = np.array([γ**(i+1) for i in range(L)]) #over layers
            last_weight = np.sign(α)*γ
        else:
            γ = 1
            # Without rescaling every per-layer factor is 1. This array was
            # previously left undefined on this path, so the bias computation
            # below raised an UnboundLocalError.
            biasγ = np.ones(L)
            last_weight = α

        #optimal subnet
        w = s*γ #1 neuron weight over layers
        b = -1*s*a*biasγ[:-1] #1 neuron bias over layers
        net.one2one1.weight.data[0]=torch.tensor([w[0]])
        net.one2one1.bias.data[0]=torch.tensor(b[0])
        net.one2one2.weight.data[0]=torch.tensor([w[1]])
        net.one2one2.bias.data[0]=torch.tensor(b[1])
        net.one2m.weight.data[0]=torch.tensor([w[2]])
        net.one2m.bias.data[0]=torch.tensor(b[2])
        net.last.weight.data[0][0]=torch.tensor([last_weight])
        # the output bias of the hand-built sub-network is zero
        net.last.bias.data[0]=torch.tensor([0*biasγ[3]])

    return net
    
    
def train(net, num_epochs, ds_train, β, use_scheduler=False, opt_method='Adam', L=None):
    """Train ``net`` on a summed squared error plus a β-weighted l_L penalty.

    The objective for each batch is
    ``sum((net(x) - y)**2) + β * sum(|w|**L)`` where the second sum runs over
    all parameters whose name does not contain ``'bias'``.

    Args:
        net: the model; called as ``net(inputs)``.
        num_epochs: number of passes over ``ds_train``.
        ds_train: iterable yielding ``(inputs, labels)`` batches.
        β: weight of the l_L penalty.
        use_scheduler: when truthy, a ``ReduceLROnPlateau`` scheduler is
            stepped with the loss after every batch.
        opt_method: ``'Adam'`` or ``'SGD'``.
        L: exponent of the penalty. When omitted it is read from ``net.L``.

    Returns:
        ``(net, loss_vec)`` where ``loss_vec[e]`` is the loss of the last
        batch processed in epoch ``e``.

    Raises:
        ValueError: if ``opt_method`` is unknown, or if ``L`` is neither given
            nor available as ``net.L``.
    """
    # The exponent used to be read from an undefined/global name `L`.
    if L is None:
        L = getattr(net, 'L', None)
    if L is None:
        raise ValueError("train() needs the penalty exponent: pass L=... or use a net with an `L` attribute")

    criterion = nn.MSELoss(reduction='sum')

    if opt_method == 'Adam':
        optimizer = optim.Adam(net.parameters(), lr=0.5*1e-2, weight_decay=1e-4)
    elif opt_method == 'SGD':
        optimizer = optim.SGD(net.parameters(), lr=0.5*1e-4, momentum=0.9)
    else:
        raise ValueError(f"unknown opt_method {opt_method!r}; expected 'Adam' or 'SGD'")

    # The scheduler wraps the optimizer, so it must be created after it.
    scheduler = None
    if use_scheduler:
        scheduler_kwargs = dict(patience=200, factor=0.5, eps=1e-12)
        try:
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, **scheduler_kwargs)
        except TypeError:
            # NOTE(review): `verbose` is deprecated in recent PyTorch and may be
            # rejected outright by some releases; fall back to the quiet form.
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)

    # The set of penalised parameters never changes, so collect it once.
    non_bias_params = [param for name, param in net.named_parameters() if 'bias' not in name]

    loss_vec=np.zeros((num_epochs))
    for epoch in range(num_epochs):  # loop over the dataset multiple times

        loss_print = 0.0
        for i, data in enumerate(ds_train):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs=inputs.to("cpu").to(torch.float64)
            labels=labels.to("cpu").to(torch.float64)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels.reshape(-1,1))

            #add l_L penalty (usual regularization is l2 norm, not l_L norm^L).
            # abs() keeps the penalty a true norm when L is odd; for even L it
            # is identical to the plain power.
            lL_penalty = β * sum(p.abs().pow(L).sum() for p in non_bias_params)
            loss = loss + lL_penalty

            loss.backward()
            optimizer.step()
            loss_vec[epoch]=loss.item()
            # print statistics
            loss_print = loss.item()
            if num_epochs > 5:
                if epoch % (num_epochs//5)  == 0:
                    print(f'[{epoch + 1}, {i + 1:1d}] loss: {loss_print :.7f}') #ez changed .5 to .6
            if scheduler is not None:
                scheduler.step(loss_print)

    print('Finished Training')
    return net, loss_vec


def nonconvex_nn(L,opt,move0toneg1,debug=False,verbose=False,initkinkat4=True,par=False):
    """Build, initialise and train one network for the chosen experiment.

    Args:
        L, opt, move0toneg1: forwarded to ``setup`` to pick the data set and
            hyper-parameters, and to ``initnn`` for the initialisation.
        debug: when truthy, the net is shrunk to ``min_num_nuerons`` neurons,
            training is skipped (0 epochs) and the weights are printed.
        verbose: when truthy, the architecture and all linear-layer weights
            are printed before training.
        initkinkat4: forwarded to ``initnn``.
        par: forwarded to ``DeepNarrowNet1D``.

    Returns:
        ``(net, loss_vec, X, y, num_epochs)``.
    """
    cfg = setup(L=L, opt=opt, move0toneg1=move0toneg1)
    X, y = cfg['X'], cfg['y']
    β = cfg['β']
    m_L = cfg['m_L']
    num_epochs = cfg['num_epochs']
    min_num_nuerons = cfg['min_num_nuerons']

    # Seed before the network is constructed so its random init is reproducible.
    torch.manual_seed(cfg['seed'])

    if debug:
        # keep only the hand-initialised neurons and skip training entirely
        m_L = min_num_nuerons #3 if L=2 and opt
        num_epochs = 0
        verbose = True

    ds_train, _ = prepTrainTestData(X=X,y=y)

    net = DeepNarrowNet1D(m_L=m_L, L=L, par=par).to("cpu")
    net = initnn(
        net=net,
        L=L,
        opt=opt,
        initkinkat4=initkinkat4,
        move0toneg1=move0toneg1,
        min_num_nuerons=min_num_nuerons,
        debug=debug,
    )
    net = net.to(torch.float64)

    if verbose:
        print('net=',net)
        net.apply(print_weights)

    net, loss_vec = train(net=net, num_epochs=num_epochs, ds_train=ds_train, β=β)
    if num_epochs>1:
        # final recorded training loss
        print('objective = ', loss_vec[-1])

    return net, loss_vec, X, y,  num_epochs

def plotresults(X,y,L,net=None,opt=None,Xtest=None):
    """Plot the training points together with the network's prediction.

    Also prints the first test input at which the prediction is within 1e-2
    of 3.0 (``None`` if the prediction never gets that close).

    Args:
        X: training inputs, array of shape (n, 1), in plotting order.
        y: training targets.
        L: depth parameter; used for the y-label, the title rule and to hide
            the x ticks when ``L < 4``.
        net: model to evaluate. When omitted, the module-level variable
            ``net`` is used (the previous behaviour).
        opt: selects the y-axis limits and the title rule. When omitted, the
            module-level variable ``opt`` is used (the previous behaviour).
        Xtest: inputs at which the model is evaluated. When omitted, 1000
            evenly spaced points from ``X[0]-0.5`` to ``X[-1]+0.5`` are used.

    Raises:
        ValueError: if ``net`` or ``opt`` is neither passed nor defined at
            module level.
    """
    # Previously `net` and `opt` were read implicitly from notebook globals;
    # keep that as a fallback so existing calls continue to work.
    if net is None:
        net = globals().get('net')
        if net is None:
            raise ValueError("plotresults() needs a model: pass net=... or define a global `net`")
    if opt is None:
        opt = globals().get('opt')
        if opt is None:
            raise ValueError("plotresults() needs the data-set flag: pass opt=... or define a global `opt`")

    noncvxtitle = 'training with Adam'
    linewidth=3.0
    markersize=100
    width=6
    height=3
    datalabel='$(x_n,y_n)$'
    nnlabel = 'net'

    xaxislim = [X[0]-1,X[-1]+1]
    # one tick per integer between the first and the last training input
    xaxis = np.linspace(min(xaxislim)+1,max(xaxislim)-1,int(max(xaxislim)-min(xaxislim)-1))

    if Xtest is None:
        Xtest = np.linspace(X[0]-0.5,X[-1]+0.5,1000).reshape(-1,1)
    Xtest = torch.as_tensor(Xtest).to("cpu").to(torch.float64).reshape(-1,1)

    if opt:
        yaxislim = [-0.5,5.5]
    else:
        yaxislim = [-0.2,3.2]

    out1=net(Xtest).detach().to("cpu").numpy()
    xvals = Xtest.cpu().numpy()

    ### plot
    plt.figure(figsize=(width,height))
    plt.scatter(X.reshape(-1),y, marker="o", color="red", label = datalabel, s=markersize)
    plt.plot(xvals,out1, label = nnlabel, color='blue',linewidth=linewidth)

    plt.xlim(xaxislim)
    plt.ylim(yaxislim)
    if opt and L==2:
        plt.title(noncvxtitle)

    if not opt and L==3:
        plt.title(noncvxtitle)

    ax = plt.gca()
    ax.set_xticks(xaxis[:,0], labels=xaxis[:,0].astype(int))

    if L<4:
        plt.xticks([])
    plt.ylabel(str(L)+' layers')
    plt.legend()

    # First test point whose prediction is within 1e-2 of 3.0. An untrained or
    # badly trained net may never get there; report None instead of crashing.
    near_three = np.flatnonzero(np.abs(out1.reshape(-1)-3.0)<1e-2)
    kink = xvals[near_three[0]][0] if near_three.size else None
    print('kink = ', kink)
    

In [4]:
#L = 4
#opt = False
#move0toneg1 = False

In [5]:
#net, loss_vec, X, y, num_epochs = nonconvex_nn(L=L,opt=opt,move0toneg1=move0toneg1)

In [6]:
#plt.figure()
#plt.plot(np.arange(len(loss_vec[100:])), loss_vec[100:])
#plt.xlabel('epochs')
#plt.ylabel('training loss')

In [7]:
#plotresults(X=X,y=y,L=L)