In [None]:
import matplotlib.pyplot as plt
%matplotlib notebook
import numpy as np
import numpy.random as rn
import os
import pandas as pd
from scipy.linalg import eigh, norm
import sys
import time
import torch as tr
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import gen_data_package as gen_data_package

In [None]:
# Experiment configuration consumed by sim() and by gen_data_package.gen_xy.
cfg = dict(
    # meta
    save_losses=False,

    # dataset
    dim=2,
    gen_x_func='gen_x_circle_regular',
    gen_y_func='gen_y_fourier_norm_1',
    add_phase=False,                    # True -> sim draws one random phase in (-pi, pi) per entry of ks
    n_val=1001,
    n_train=1001,
    ks=[4, 14],
    resample=False,                     # True -> a fresh train set is generated at every epoch

    # network
    n_hidden=3,
    n_units=256,
    hidden_bias='zeros',                # none/zeros/normal/default
    outer_fixed=False,                  # True -> read-out weights frozen at +-1/sqrt(n_units)
    even_only=False,
    odd_only=False,

    # optimization
    eta=.001,                           # SGD learning rate
    n_epochs_max=100000,
    n_batch=0,                          # 0 -> full batch (sim replaces it with n_train)
    stop_threshold_percent=1,           # training stops once train loss < stop_threshold_percent / 100
    max_training_time_in_minutes=600,
)

In [None]:
def sim(cfg, snapshot_epochs):
    """Train a fully-connected ReLU network on one generated dataset and collect the results.

    Generates train/val data via ``gen_data_package.gen_xy``, builds a ``Net`` from ``cfg``,
    trains it with SGD while live-plotting the train loss and the network output, and stops
    on convergence, time-out, a NaN loss, or after ``cfg['n_epochs_max']`` epochs.

    Args:
        cfg: experiment configuration dict (see the config cell). NOTE: it is updated in
            place -- ``'n_batch'`` (0 is replaced by ``n_train``), ``'phases'`` and
            ``'zero_val_loss'`` are written into it.
        snapshot_epochs: iterable of epoch indices at whose start the network output on the
            train set is recorded. It is copied, so the caller's object is never modified.

    Returns:
        A one-row ``pandas.DataFrame`` summarizing the run (ks, phases, eigenvalues, losses,
        stopping criterion, lambda_0, train targets and output snapshots, ...).

    Raises:
        ValueError: if ``cfg['n_batch']`` is larger than ``cfg['n_train']``.
    """
    # Private copy: the final epoch is appended at the end of training and must not leak
    # into the caller's list (e.g. when sim is called repeatedly with the same list).
    snapshot_epochs = list(snapshot_epochs)

    if cfg['n_batch'] == 0: cfg['n_batch'] = cfg['n_train']

    device = tr.device("cuda:0" if tr.cuda.is_available() else "cpu")
    use_parallel_gpus = False
    print('\r\ndevice is: %s\r\n' % device)

    ##################################################################################

    # define network

    class Net(nn.Module):
        """Fully-connected ReLU network: len(n_sizes) hidden layers and a bias-free linear read-out.

        hidden_bias: 'none' (layers without bias), 'zeros', 'normal' (std 1/sqrt(3*fan_in)),
        anything else (e.g. 'default') keeps PyTorch's nn.Linear bias initialization.
        outer_fixed: if True the read-out weights are frozen at +-1/sqrt(n_sizes[-1]).
        """

        def __init__(self, d, n_sizes, hidden_bias, outer_fixed, out_size = 1):
            super(Net, self).__init__()

            n_hidden = len(n_sizes)
            hidden_in_sizes = n_sizes[:]
            hidden_in_sizes.insert(0,d)
            hidden_out_sizes = n_sizes
            outer_in_size = n_sizes[-1]
            outer_out_size = out_size

            self.hidden = nn.ModuleList()
            self.hidden.extend([nn.Linear(hidden_in_sizes[i], hidden_out_sizes[i], bias=(hidden_bias!='none')) for i in range(n_hidden)])
            for i in range(n_hidden):
                # Kaiming-normal hidden weights
                tr.nn.init.kaiming_normal_(self.hidden[i].weight, a=np.sqrt(5))

                # either 0's or normal biases in hidden layers
                if hidden_bias == 'zeros':
                    self.hidden[i].bias.data = tr.zeros([hidden_out_sizes[i]])
                elif hidden_bias == 'normal':
                    std = 1 / np.sqrt(3*hidden_in_sizes[i])
                    tr.nn.init.normal_(self.hidden[i].bias, mean=0, std=std)

            self.outer = nn.Linear(outer_in_size, outer_out_size, bias=False)
            if outer_fixed:
                # frozen read-out with random signs, scaled by 1/sqrt(fan_in)
                for param in self.outer.parameters():
                    param.requires_grad = False
                self.outer.weight.data = tr.from_numpy(np.sign(rn.uniform(-1, 1, [outer_out_size, outer_in_size])) / np.sqrt(outer_in_size)).float()
            else:
                # trainable read-out, Kaiming-normal initialized
                tr.nn.init.kaiming_normal_(self.outer.weight, a=np.sqrt(5))

        def forward(self, x):
            for layer in self.hidden:
                x = F.relu(layer(x))
            return self.outer(x)

    class XYDataset():
        """In-memory float32 tensors built from a gen_xy result dict (keys 'x', 'y', 'vals', 'theta').

        x, y and theta are reshaped to (N, -1); 'vals' is stored as given.
        """
        def __init__(self, data):
            self.x = tr.from_numpy(data['x']).float().view(len(data['x']), -1)
            self.y = tr.from_numpy(data['y']).float().view(len(data['y']), -1)
            self.vals = data['vals']
            self.theta = tr.from_numpy(data['theta']).float().view(len(data['x']), -1)

        def __len__(self):
            return len(self.y)

        def __getitem__(self, idx):
            return self.x[idx], self.y[idx], self.theta[idx]

    def gen_data(cfg, pr=False):
        """Generate train and val XYDatasets (plus gen_xy's 'W') from cfg."""
        data_train = gen_data_package.gen_xy(cfg, 'train')
        if cfg['gen_y_func'] in ('gen_y_H_inf', 'gen_y_H_0', 'gen_y_H_inf_norm_1'):
            # for these targets the "val" set is a shuffled copy of the whole train set
            inds = rn.permutation(cfg['n_train'])
            data_val = {'x' : data_train['x'][inds, ...], 'y' : data_train['y'][inds], 'vals' : data_train['vals'], 'theta' : data_train['theta'][inds]}
        else:
            data_val = gen_data_package.gen_xy(cfg, 'val')

        if pr:
            plt.gray()
            plt.scatter(data_train['x'][:,0], data_train['x'][:,1], c=data_train['y'])

        trainset = XYDataset(data_train)
        valset = XYDataset(data_val)

        return {'trainset' : trainset, 'valset' : valset, 'W' : data_train['W']}

    def gen_net(cfg, W, device, use_parallel_gpus = False, pr = False):
        """Build a Net from cfg and move it to `device` (W is accepted but not used)."""
        net = Net(d=cfg['dim'],
                    n_sizes=[cfg['n_units'] for i in range(cfg['n_hidden'])],
                    hidden_bias=cfg['hidden_bias'],
                    outer_fixed=cfg['outer_fixed'])
        if pr: print('\r\nnet:\r\n', net)
        if tr.cuda.device_count() > 1 and use_parallel_gpus:
            print("Let's use", tr.cuda.device_count(), "GPUs!")
            net = nn.DataParallel(net)
        net.to(device)
        return net

    def compute_lambda_0(trainset):
        """Smallest eigenvalue of H_ij = g_ij (pi - arccos g_ij) / (2 pi), g = clip(X X^T, -1, 1).

        Clipping keeps arccos defined; the formula presumes unit-norm inputs (TODO confirm).
        """
        x = trainset.x.numpy()
        gram = np.clip(np.matmul(x, x.T), -1, 1)
        arcs = np.arccos(gram)
        H = gram * (np.pi - arcs) / (2*np.pi)
        # `eigvals=` was deprecated in SciPy 1.5 and removed in 1.12; `subset_by_index`
        # (inclusive, ascending) is its replacement.
        vals = eigh(H, eigvals_only=True, subset_by_index=[0, 0])
        return vals[0]

    def optimize(cfg, device, net, datasets, pr = False):
        """Train `net` with SGD on datasets['trainset'], live-plotting loss and outputs.

        All losses are half the sum of squared errors over a set (MSE * set size / 2).
        Returns a dict of run statistics that is merged into the results row.
        """

        trainset = datasets['trainset']
        valset = datasets['valset']

        criterion = nn.MSELoss()

        def half_sse(output, target):
            # MSELoss averages; scaling by the actual number of samples gives SSE / 2
            # (using the real set size keeps this right when val is a copy of train).
            return criterion(output, target) * target.shape[0] / 2

        # the val set never changes: move it to the device once
        x_val = valset.x.to(device)
        y_val = valset.y.to(device)

        # set live plot:
        fig = plt.figure()
        ax_loss = fig.add_subplot(121)
        ax_y = fig.add_subplot(122)
        plt.ion()
        fig.show()
        fig.canvas.draw()
        theta_val = valset.theta.numpy().reshape(-1)
        val_sort_inds = np.argsort(theta_val)
        theta_val_sorted = theta_val[val_sort_inds]
        y_val_sorted = valset.y.numpy().reshape(-1)[val_sort_inds]

        # compute eta
        lambda_0 = compute_lambda_0(trainset)
        print('\r\nlambda_0(H_\\infty) is: %g' % lambda_0)
        eta = lambda_0 / cfg['n_train']**2
        print('\r\nThreshold for learning rate is: %g' % eta)
        print('\r\nConfigured learning rate is: %g' % cfg['eta'])

        # create optimizer
        optimizer = optim.SGD(net.parameters(), lr = cfg['eta'])

        epoch_loss = np.zeros(cfg['n_epochs_max'])
        train_loss = np.zeros(cfg['n_epochs_max'])
        val_loss = np.zeros(cfg['n_epochs_max'])

        with tr.no_grad():
            val_loss_before_training = half_sse(net(x_val), y_val).item()
            train_loss_before_training = half_sse(net(trainset.x.to(device)), trainset.y.to(device)).item()
        print('\r\ninitial val/train loss is %.2f/%.2f\r\n' % (val_loss_before_training, train_loss_before_training))

        snapshot_train_outputs = []

        # train
        converged = False
        stopping_criterion = ''
        n_batches = cfg['n_train'] // cfg['n_batch']
        if n_batches == 0:
            raise ValueError('n_batch (%d) must not exceed n_train (%d)' % (cfg['n_batch'], cfg['n_train']))
        total_training_time = time.time()
        total_optimization_time = 0
        total_iter_time = 0
        total_sampling_time = 0
        total_compute_losses_time = 0
        epoch = 0

        while True:

            total_sampling_time -= time.time()
            if cfg['resample']:
                trainset = gen_data(cfg, pr=False)['trainset']
            total_sampling_time += time.time()

            total_iter_time -= time.time()
            epoch_Is = rn.permutation(cfg['n_train'])
            total_iter_time += time.time()
            accumulated_loss = .0

            if epoch in snapshot_epochs:
                with tr.no_grad():
                    train_output_before = net(trainset.x.to(device))
                print('epoch = %d, saving snapshot' % epoch)
                snapshot_train_outputs.append(train_output_before.cpu().numpy())

            # the last n_train % n_batch samples of the permutation are skipped each epoch
            for i_batch in range(n_batches):
                # get the inputs
                total_iter_time -= time.time()
                batch_Is = epoch_Is[i_batch*cfg['n_batch']:(i_batch+1)*cfg['n_batch']]
                x_train_batch, y_train_batch = trainset.x[batch_Is], trainset.y[batch_Is]
                x_train_batch, y_train_batch = x_train_batch.to(device), y_train_batch.to(device)
                total_iter_time += time.time()

                total_optimization_time -= time.time()
                optimizer.zero_grad()

                # forward + backward + optimize
                output = net(x_train_batch)
                loss = half_sse(output, y_train_batch)
                loss.backward()
                optimizer.step()
                total_optimization_time += time.time()

                accumulated_loss += loss.item()

            total_compute_losses_time -= time.time()
            epoch_loss[epoch] = accumulated_loss / n_batches
            with tr.no_grad():
                val_output = net(x_val)
                val_loss[epoch] = half_sse(val_output, y_val).item()
                train_output_after = net(trainset.x.to(device))
                train_loss[epoch] = half_sse(train_output_after, trainset.y.to(device)).item()
            total_compute_losses_time += time.time()

            if epoch % 100 == 0:
                # skip the first (up to) 50 epochs so the initial transient doesn't dominate
                start = min(50, epoch)
                ax_loss.clear()
                ax_loss.plot(np.arange(start, epoch + 1), train_loss[start:epoch + 1], '.-')
                ax_loss.set_xlabel('epochs')
                ax_loss.set_ylabel('train MSE')

                # sort the *current* train set by theta (it changes when cfg['resample'])
                theta_train = trainset.theta.numpy().reshape(-1)
                train_order = np.argsort(theta_train)
                theta_train_sorted = theta_train[train_order]

                ax_y.clear()
                ax_y.plot(theta_train_sorted, trainset.y.numpy().reshape(-1)[train_order], '--y', label='target on train set')
                ax_y.plot(theta_val_sorted, val_output.cpu().numpy().reshape(-1)[val_sort_inds], '-', label='output on val set')
                ax_y.plot(theta_train_sorted, train_output_after.cpu().numpy().reshape(-1)[train_order], '.', label='output on train set')
                ax_y.set_ylim([2*np.min(y_val_sorted), 2*np.max(y_val_sorted)])
                ax_y.set_xlabel('theta')
                ax_y.set_ylabel('label/output')
                ax_y.set_title('k = %d, epoch = %d' % (cfg['ks'][0], epoch))
                ax_y.legend()

                fig.canvas.draw()

            epoch += 1
            if pr and epoch % 500 == 0:
                print('epoch %7d val/train loss: %.2f/%.2f (%.2f%% for training)\r\n' %
                      (epoch, val_loss[epoch-1], train_loss[epoch-1], train_loss[epoch-1] / cfg['zero_val_loss'] * 100))
            # NOTE(review): this is an absolute threshold on the train loss (value / 100),
            # not a percentage of cfg['zero_val_loss'] as the message suggests.
            if train_loss[epoch-1] < cfg['stop_threshold_percent'] / 100:
                print('\r\n@@@ after %d epochs, training loss is %.2f, reached stop threshold of %.2f%% and training is done\r\n'
                      % (epoch, train_loss[epoch-1], cfg['stop_threshold_percent']))
                converged = True
                stopping_criterion = 'converged'
                break
            training_time_so_far_in_minutes = (time.time() - total_training_time) / 60
            if training_time_so_far_in_minutes > cfg['max_training_time_in_minutes']:
                print('\r\nafter %d epochs, training time exceeded %d minutes threshold and training is stopped\r\n'
                      % (epoch, cfg['max_training_time_in_minutes']))
                stopping_criterion = 'time_out'
                break
            if np.isnan(epoch_loss[epoch-1]):
                print('\r\nafter %d epochs, epoch loss is NaN and training is stopped\r\n'
                     % epoch)
                stopping_criterion = 'reached_nan'
                break
            if epoch == cfg['n_epochs_max']:
                print('\r\nreached maximal number of epochs %d and training is stopped\r\n'
                     % cfg['n_epochs_max'])
                stopping_criterion = 'epochs_over'
                break

        total_training_time = time.time() - total_training_time
        print('\r\nfinished training! training time: %.2f minutes (optimization time: %.2f%%, iter time: %.2f%%, compute train/val loss time: %.2f%%, sampling time: %.2f%%)\r\n'
              % (total_training_time/60,
              total_optimization_time/total_training_time*100,
              total_iter_time/total_training_time*100,
              total_compute_losses_time/total_training_time*100,
              total_sampling_time/total_training_time*100))

        with tr.no_grad():
            val_loss_after_training = half_sse(net(x_val), y_val).item()
            train_output_after_training = net(trainset.x.to(device))
            train_loss_after_training = half_sse(train_output_after_training, trainset.y.to(device)).item()
        print('\r\nfinal val/train loss is %.2f/%.2f\r\n (%.2f%% for validation)' % (val_loss_after_training, train_loss_after_training,
            val_loss_after_training / cfg['zero_val_loss'] * 100))
        # the final output is stored as one more snapshot, tagged with the last epoch count
        snapshot_epochs.append(epoch)
        snapshot_train_outputs.append(train_output_after_training.cpu().numpy())

        return {'num_epochs' : epoch,
                'final_train_loss' : train_loss_after_training,
                'final_val_loss' : val_loss_after_training,
                'converged' : converged,
                'training_time' : total_training_time,
                'stopping_criterion' : stopping_criterion,
                'lambda_0' : lambda_0,
                'theta_train' : trainset.theta.numpy(),
                'y_train' : trainset.y.numpy(),
                'snapshot_epochs' : np.array(snapshot_epochs),
                'snapshot_train_outputs' : np.array(snapshot_train_outputs),
               }

    if cfg['add_phase']:
        cfg['phases'] = rn.uniform(-np.pi, np.pi, len(cfg['ks']))
    else:
        cfg['phases'] = []
    print('\r\ntest cfg: \r\n', cfg, '\r\n')

    total_loops_time = time.time()

    pr = True

    # generate new dataset
    datasets = gen_data(cfg, pr=False)
    # loss of the all-zero predictor on the val set; used to report relative losses
    cfg['zero_val_loss'] = tr.sum(datasets['valset'].y**2).cpu().numpy() / 2
    if pr: print('zero val loss is: %.3f\r\n' % cfg['zero_val_loss'])

    # create new network
    net = gen_net(cfg, datasets['W'], device, use_parallel_gpus, pr)

    # optimize and save results
    my_dict = {'ks' : cfg['ks'], 'phases' : cfg['phases'], 'eigenvalues' : datasets['trainset'].vals}
    my_dict.update(optimize(cfg, device, net, datasets, pr))
    # DataFrame.append was removed in pandas 2.0; build the one-row frame directly
    results = pd.DataFrame([my_dict])

    total_loops_time = (time.time() - total_loops_time) / 60
    print('\r\nfinished looping! total looping time: %.2f minutes\r\n' % total_loops_time)

    print(results)
    print('\r\nstopping criterion: %s' % results['stopping_criterion'])

    return results


In [None]:
# Run one simulation, snapshotting the train-set output at epochs 0, 50 and 500,
# then pickle the one-row results frame.
results = sim(cfg, [0,50,500])

# to_pickle does not create missing directories; make sure 'results/' exists
os.makedirs('results', exist_ok=True)
filename = 'superposition_k1%d_k2%d_depth%d' % (cfg['ks'][0], cfg['ks'][1], cfg['n_hidden'])
results.to_pickle('results/%s.pkl' % filename)