In [1]:
%run base_test.ipynb

import numpy as np

from base import ml
from base import plot

In [2]:
def _num_neurons(name):
    """Return the product of the 'x'-separated integers in ``name``'s prefix.

    The prefix is the part of ``name`` before the first '_'. E.g. '32x64_foo'
    gives 32 * 64. Presumably the saved-model names encode layer sizes this
    way (TODO confirm against ml.dnn_factories_from_dir). Returns '' when the
    prefix is not of that form (e.g. 'AverageReg', 'Poly3Reg'), matching how
    the report previously left the column blank.

    This deliberately parses the integers instead of calling ``eval`` on the
    name: names come from a directory listing and must not be executed.
    """
    total = 1
    try:
        for size in name.split('_')[0].split('x'):
            total *= int(size)
    except ValueError:
        return ''
    return total


def _print_ranking(stat, ranked):
    """Print a '<neurons> | <name> | <value>' table for ``stat``.

    ``ranked`` is an already-sorted sequence of (result, name) pairs.
    """
    print(stat)
    print('-' * len(stat))
    for result, name in ranked:
        print('{:>4} | {:>34} | {}'.format(_num_neurons(name),
                                           name,
                                           result[stat]))
    print('')


def test_regressors(type_, average=False, poly=False, bilinear=False,
                    gen_one_data=None, test_size=10000, average_top=5,
                    poly_degs=(3, 5, 7, 9), common_top=6):
    """Evaluate saved DNN regressors (and optional baselines), printing rankings.

    Loads every saved DNN model for ``type_`` via ml.dnn_factories_from_dir and
    optionally adds three kinds of extra regressor: an averaging ensemble, one
    polynomial interpolant per degree, and a bilinear interpolant. All of them
    are evaluated with ml.eval_regressors on data from ``gen_one_data``.

    Two tables are printed, ranked ascending by 'average_loss' and by
    'max_deviation'. Finally the names found in the first ``common_top``
    entries of every ranking are printed as the "common best".

    Args:
        type_: 'point' or 'grid'. This selects the saved-model directory and
            the matching data-generator and interpolator classes.
        average: If True, also evaluate an ml.RegressorAverager over the DNNs.
        poly: If True, also evaluate a polynomial interpolant for each degree
            in ``poly_degs``. Warning: this is slow.
        bilinear: If True, also evaluate a bilinear interpolant.
        gen_one_data: Data generator. If None, a default generator for
            ``type_`` is built with ``fenics_from_save=True``.
        test_size: Batch size passed to ml.eval_regressors. It is also passed
            to the averager's auto_mask.
        average_top: If not None, passed as ``top`` to the averager's
            auto_mask.
        poly_degs: Polynomial degrees to try when ``poly`` is True.
        common_top: How many leading entries of each ranking are considered
            when computing the common best. None means all.

    Returns:
        Whatever ml.eval_regressors returns. It is assumed to hold one result
        per regressor, in the order the factories were passed (DNNs first,
        then extras). Each result is indexable by 'average_loss' and
        'max_deviation'.

    Raises:
        RuntimeError: If ``type_`` is neither 'point' nor 'grid'.
    """
    if type_ == 'point':
        dir_ = '../saved_models/point'
        if gen_one_data is None:
            gen_one_data = ml.GenSolutionPoint(fenics_from_save=True)
        PolyInterp = ml.PolyInterpPoint
        BilinearInterp = ml.BilinearInterpPoint
    elif type_ == 'grid':
        dir_ = '../saved_models/grid'
        if gen_one_data is None:
            gen_one_data = ml.GenSolutionGrid(fenics_from_save=True)
        PolyInterp = ml.PolyInterpGrid
        BilinearInterp = ml.BilinearInterpGrid
    else:
        raise RuntimeError('Unknown type_ {}'.format(type_))

    dnn_factories, names = ml.dnn_factories_from_dir(dir_)
    extra_facs = []
    extra_names = []

    if average:
        average_reg = ml.RegressorAverager(regressor_factories=dnn_factories)
        if average_top is not None:
            # Restrict the ensemble to the best `average_top` DNNs.
            average_reg.auto_mask(gen_one_data=gen_one_data, top=average_top,
                                  batch_size=test_size)
        extra_facs.append(ml.RegressorFactory(regressor=average_reg))
        extra_names.append('AverageReg')

    if poly:
        # Warning: the polynomial interpolation takes a while.
        for poly_deg in poly_degs:
            poly_reg = PolyInterp(poly_deg=poly_deg)
            extra_facs.append(ml.RegressorFactory(regressor=poly_reg))
            extra_names.append('Poly{}Reg'.format(poly_deg))

    if bilinear:
        extra_facs.append(ml.RegressorFactory(regressor=BilinearInterp()))
        extra_names.append('BilinearReg')

    results = ml.eval_regressors([*dnn_factories, *extra_facs],
                                 gen_one_data,
                                 batch_size=test_size)
    all_names = [*names, *extra_names]

    bests = []
    for stat in ('average_loss', 'max_deviation'):
        ranked = sorted(zip(results, all_names), key=lambda pair: pair[0][stat])
        bests.append([name for _, name in ranked[:common_top]])
        _print_ranking(stat, ranked)

    # Names that are near the top of the first ranking AND of every other one.
    best = [x for x in bests[0] if all(x in best_ for best_ in bests[1:])]
    print('common best')
    print('-----------')
    for b in best:
        print(b)
    if not best:
        print('No common best')

    return results