In [1]:
import argparse
import glob
import multiprocessing as mp
import os
import time
import yaml
import pickle as pkl
import hashlib
import json
import pandas as pd
import numpy as np
from itertools import product
from datetime import datetime

def get_parser():
    parser = argparse.ArgumentParser(description='Testing')
    parser.add_argument('--config-file', type=str, default='configs/bomp_default.yaml', metavar= "FILE" ,help='path to config file')
    parser.add_argument("--output", type=str, help="Output path")
    return parser


def get_cfg(config_file):
    with open(config_file, 'r') as f:
        config = yaml.safe_load(f)
    return config


def merge_cfg(default_dict, input_dict):
    merged_dict = default_dict.copy()  # Start with default values.
    sections = ['MODEL', 'TEST', 'UTILS']  # Specify sections to update

    for section in sections:
        if section in default_dict and section in input_dict:
            for key in default_dict[section]:
                # Check if the key is in the user input dictionary
                if key in input_dict[section]:
                    # If it is, update the merged dictionary
                    merged_dict[section][key] = input_dict[section][key]
                else:
                    # If not, print a message about using the default value
                    print(f"Missing parameter '{key}' in section '{section}', default value '{default_dict[section][key]}' will be used.")
        else:
            print(f"Missing section '{section}' in the user input, default values will be used.")

    # Check for invalid keys in the user input dictionary
    for section in input_dict:
        if section in sections:
            for key in input_dict[section]:
                if key not in default_dict[section]:
                    print(f"Invalid key '{key}' in section '{section}'. This key will be ignored.")

    return merged_dict
    
def get_output_path(output_path, config_filename):
    if output_path is None:
        # output file will be a pickle file in the outputs folder
        output_path = os.path.join("./memory", config_filename.split("/")[-1].split(".")[0] + ".pkl")
    else:
        # output file will be a pickle file in the specified folder
        output_path = os.path.join(output_path, config_filename.split("/")[-1].split(".")[0] + ".pkl")
    return output_path    

In [2]:
# default
cfg = get_cfg("configs/bomp_default.yaml")
cfg

{'TEST': {'n': 600,
  'p': 1000,
  'm': 20,
  'noise_level': [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5],
  'model': 'BOMP',
  'cv_num': 5,
  'trial_num': 10},
 'MODEL': {'signal_bag_flag': False,
  'signal_bag_percent': 0.7,
  'atom_bag_percent': 0.7,
  'select_atom_percent': 0.3,
  'replace_flag': False,
  'agg_func': 'weight',
  'K_start': 1,
  'K_end': 40,
  'K_step': 1,
  'random_seed': 0}}

In [3]:
def get_model_params(cfg):
    import numpy as np
    all_params = cfg['MODEL']
    param_grid = {}
    fixed_params = {}
    K_start, K_end, K_step = all_params['K_start'], all_params['K_end'], all_params['K_step']
    if K_start >= K_end:
        raise ValueError("K_start must be smaller than K_end")
    if K_step <= 0:
        raise ValueError("K_step must be positive")
    # Check if K_start, K_end, K_step are integers
    if not isinstance(K_start, int) or not isinstance(K_end, int) or not isinstance(K_step, int):
        raise ValueError("K_start, K_end, K_step must be integers")
    K_list = np.arange(K_start, K_end, K_step, dtype=int)
    # Check if the param is a list or a single value if it is a list save to param_grid or else save to fixed_params
    for param, value in all_params.items():
        if param in ['K_start', 'K_end', 'K_step']:
            continue
        if isinstance(value, list):
            param_grid[param] = value
        else:
            fixed_params[param] = value
    param_grid['K'] = K_list
    return fixed_params, param_grid

In [4]:
default_config = get_cfg("configs/bomp_default.yaml")
input_config = get_cfg("configs/bomp_test.yaml")
config = merge_cfg(default_config, input_config)
config

Missing parameter 'signal_bag_percent' in section 'MODEL', default value '0.7' will be used.
Missing parameter 'atom_bag_percent' in section 'MODEL', default value '0.7' will be used.
Missing section 'UTILS' in the user input, default values will be used.
Invalid key 'N_bag' in section 'MODEL'. This key will be ignored.


{'TEST': {'n': 600,
  'p': 1000,
  'm': 20,
  'noise_level': [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5],
  'model': 'BOMP',
  'cv_num': 5,
  'trial_num': 10},
 'MODEL': {'signal_bag_flag': False,
  'signal_bag_percent': 0.7,
  'atom_bag_percent': 0.7,
  'select_atom_percent': [0, 0.1, 0.2, 0.3, 0.4, 0.5],
  'replace_flag': False,
  'agg_func': ['avg', 'weight'],
  'K_start': 1,
  'K_end': 41,
  'K_step': 1,
  'random_seed': 0}}

In [5]:
fixed_params, param_grid = get_model_params(config)

print(fixed_params)
print(param_grid)

{'select_atom_percent': [0, 0.1, 0.2, 0.3, 0.4, 0.5], 'agg_func': ['avg', 'weight'], 'K': array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
       18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
       35, 36, 37, 38, 39, 40])}


In [6]:
fixed_params['signal_bag_percent'] = 0.5

fixed_params

{'signal_bag_flag': False,
 'signal_bag_percent': 0.5,
 'atom_bag_percent': 0.7,
 'replace_flag': False,
 'random_seed': 0}

In [7]:
from algorithms import BOMP

my_bomp = BOMP(**fixed_params)

my_bomp

In [8]:
from data_generation import GaussianDataGenerator

N = 1000
d = 400
m = 40
noise_level = 0.05
seed = 0

Data_Geneartor = GaussianDataGenerator(N, d, m, 0.05, 0)

true_signal, dictionary, true_indices, true_coefficients, perturbed_signal = Data_Geneartor.shuffle()

perturbed_signal.shape, dictionary.shape

((400, 1), (400, 1000))

In [9]:
from sklearn.metrics import mean_squared_error 
from sklearn.model_selection import GridSearchCV


gs = GridSearchCV(my_bomp, param_grid, cv=5, scoring='neg_mean_squared_error', n_jobs=-1, verbose=1)

gs.fit(dictionary, perturbed_signal)

Fitting 5 folds for each of 480 candidates, totalling 2400 fits
[333]
[333]
[333]
[333]
[644]
[333]
[514]
[34]
[223]
[368]
[412]
[869]
[323]
[386]
[504]
[890]
[776]
[312]
[236]
[503]
[236]
[312]
[503]
[776]
[833]
[799]
[724]
[890]
[77]
[939]
[333]
[333]
[333]
[333]
[333]
[514]
[644]
[412]
[34]
[368]
[223]
[869]
[386]
[323]
[504]
[236]
[312]
[503]
[890]
[236]
[503]
[724]
[799]
[826, 333]
[826, 333]
[776]
[939]
[776]
[890]
[77]
[312]
[826, 333]
[833]
[698, 333]
[808, 684]
[89, 248]
[800, 990]
[565, 795]
[826, 333]
[537, 242]
[889, 200]
[560, 349]
[430, 590]
[696, 895]
[454, 775]
[38, 309]
[899, 111]
[4, 59]
[149, 539]
[513, 454]
[162, 805]
[624, 320]
[557, 897]
[89, 248]
[826, 333]
[800, 990]
[560, 349]
[370, 715]
[190, 37]
[282, 540]
[537, 242]
[698, 333]
[454, 775]
[696, 895]
[513, 454]
[376, 718]
[826, 333]
[4, 59]
[177, 133]
[826, 333]
[808, 684]
[565, 795]
[592, 135]
[162, 805]
[908, 433]
[826, 333]
[889, 200]
[430, 590]
[38, 309]
[899, 111]
[149, 539]
[376, 718]
[370, 715]
[333, 15

In [10]:
gs.best_estimator_

In [11]:
-gs.cv_results_['mean_test_score']

array([0.09975243, 0.12779388, 0.12744629, 0.12914537, 0.12750566,
       0.12745915, 0.09964384, 0.1277956 , 0.12745076, 0.12914436,
       0.12750521, 0.1274586 , 0.08342135, 0.12677103, 0.12681165,
       0.12777034, 0.12804331, 0.127959  , 0.08344374, 0.12680039,
       0.12684667, 0.12776997, 0.12804182, 0.12795912, 0.07508505,
       0.12621397, 0.12740141, 0.1281543 , 0.12869497, 0.12762389,
       0.07507132, 0.12626268, 0.12743959, 0.12813787, 0.12867787,
       0.12762462, 0.07055754, 0.12674394, 0.12420559, 0.1253515 ,
       0.12583792, 0.12949855, 0.07053552, 0.12674331, 0.12425019,
       0.1255831 , 0.12609656, 0.12948856, 0.06335686, 0.12877127,
       0.12862854, 0.12584313, 0.12795596, 0.12457874, 0.06335132,
       0.12875967, 0.12862766, 0.12592112, 0.12797905, 0.1247747 ,
       0.05456651, 0.11994122, 0.12204721, 0.1208641 , 0.12123682,
       0.12526592, 0.05455049, 0.1203763 , 0.12234695, 0.12161922,
       0.12165542, 0.12530505, 0.04913498, 0.12509068, 0.12653