In [1]:
# Reload edited modules (e.g. sladsnet) automatically before each cell executes,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path

import sladsnet.code.training as training
from sladsnet.code.erd import SladsSklearnModel 
from sladsnet.input_params import TrainingInputParams, SladsModelParams
from tqdm.notebook import tqdm

In [3]:
# Project root: the parent of the directory the kernel was started in.
# NOTE(review): assumes the notebook is launched from a direct subdirectory of the repo — confirm.
base_path = Path.cwd().parents[0]

In [4]:
# Root of the cameraman training data; its 'Images' subfolder holds the input images.
cameraman_path = base_path.joinpath('ResultsAndData/TrainingData/cameraman/')

In [16]:
# Settings for generating the cameraman training databases. Anything not passed
# here falls back to the TrainingInputParams defaults.
train_params = TrainingInputParams(
    input_images_path=cameraman_path / 'Images',
    output_dir=cameraman_path,
    # One database is built per c value; each later gets its own ERD model.
    test_c_values=[2, 4, 8, 16, 32, 64],
    num_repeats_per_mask=1,
    measurements_per_initial_mask=10,
)

In [17]:
# Echo the resolved parameters, including the defaults that were not passed explicitly.
train_params

TrainingInputParams(input_images_path=PosixPath('/gpfs/fs1/home/skandel/code/SLADS-Net/ResultsAndData/TrainingData/cameraman/Images'), output_dir=PosixPath('/gpfs/fs1/home/skandel/code/SLADS-Net/ResultsAndData/TrainingData/cameraman'), initial_scan_ratio=0.01, stop_ratio=0.8, scan_method='random', scan_type='transmission', sampling_type='fast_limited', num_repeats_per_mask=1, measurements_per_initial_mask=10, random_seed=111, training_split=0.9, test_c_values=[2, 4, 8, 16, 32, 64], calculate_full_erd_per_step=True)

In [18]:
# Generate the training databases for every c in train_params.test_c_values.
# The fitting cells below read `<output_dir>/c_<c>/training_database.pkl`.
# NOTE(review): that layout is inferred from how those cells build paths — confirm.
training.generate_training_databases(train_params)

  0%|          | 0/6 [00:00<?, ?it/s]

Testing for c=   2; Samples:   0%|          | 0/1 [00:00<?, ?it/s]

Masks:   0%|          | 0/1 [00:00<?, ?it/s]

Iterating through test sampling ratios.:   0%|          | 0/10 [00:00<?, ?it/s]

Testing for c=   4; Samples:   0%|          | 0/1 [00:00<?, ?it/s]

Masks:   0%|          | 0/1 [00:00<?, ?it/s]

Iterating through test sampling ratios.:   0%|          | 0/10 [00:00<?, ?it/s]

Testing for c=   8; Samples:   0%|          | 0/1 [00:00<?, ?it/s]

Masks:   0%|          | 0/1 [00:00<?, ?it/s]

Iterating through test sampling ratios.:   0%|          | 0/10 [00:00<?, ?it/s]

Testing for c=  16; Samples:   0%|          | 0/1 [00:00<?, ?it/s]

Masks:   0%|          | 0/1 [00:00<?, ?it/s]

Iterating through test sampling ratios.:   0%|          | 0/10 [00:00<?, ?it/s]

Testing for c=  32; Samples:   0%|          | 0/1 [00:00<?, ?it/s]

Masks:   0%|          | 0/1 [00:00<?, ?it/s]

Iterating through test sampling ratios.:   0%|          | 0/10 [00:00<?, ?it/s]

Testing for c=  64; Samples:   0%|          | 0/1 [00:00<?, ?it/s]

Masks:   0%|          | 0/1 [00:00<?, ?it/s]

Iterating through test sampling ratios.:   0%|          | 0/10 [00:00<?, ?it/s]

In [19]:

# Directory holding the per-c training databases and fitted models. Reuse the
# configured output_dir instead of repeating the path literal, so the two
# cannot silently drift apart.
output_path = train_params.output_dir

In [20]:
# Fit one ERD model per c value on the cameraman databases and record its R^2.
# fit_erd_model and validate_erd_model_r_squared both take training_split /
# random_seed. These presumably define the train / held-out partition, so they
# are kept in a single place to guarantee the two calls always agree.
ERD_TRAINING_SPLIT = 0.8
ERD_SPLIT_SEED = 111

val_scores = {}
for c_value in tqdm(train_params.test_c_values):
    print('test c', c_value)

    train_db_path = output_path / f'c_{c_value}' / 'training_database.pkl'
    val_scores[c_value] = []

    # Only 'relu' for now; add entries here to compare other activations.
    for activation in ['relu']:
        # The fitted model is saved next to its database; validation reloads it from this path.
        save_path = train_db_path.parent / f'erd_model_{activation}.pkl'

        erd_model, _ = training.fit_erd_model(training_db_path=train_db_path,
                                              model_params=SladsModelParams(activation=activation),
                                              save_path=save_path,
                                              training_split=ERD_TRAINING_SPLIT,
                                              random_seed=ERD_SPLIT_SEED)
        score = training.validate_erd_model_r_squared(training_db_path=train_db_path,
                                                      erd_model_path=save_path,
                                                      training_split=ERD_TRAINING_SPLIT,
                                                      random_seed=ERD_SPLIT_SEED)
        print('R squared score', score)
        val_scores[c_value].append(score)
        
    


  0%|          | 0/6 [00:00<?, ?it/s]

test c 2
Validation score is 0.8608224749396053
R squared score 0.8608224749396053
test c 4
Validation score is 0.8544570249680723
R squared score 0.8544570249680723
test c 8
Validation score is 0.5872562725344859
R squared score 0.5872562725344859
test c 16
Validation score is 0.5865042118785144
R squared score 0.5865042118785144
test c 32
Validation score is 0.5489955135457498
R squared score 0.5489955135457498
test c 64
Validation score is 0.6014144720680303
R squared score 0.6014144720680303


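### Test on a different database

Apply each cameraman-trained ERD model (one per training `c`) to the `93` training databases, for every test `c`.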

In [31]:
# Cross-database check: score each cameraman-trained ERD model (one per training c)
# against the `93` databases for every test c.
# Results: cross_val_scores[c_train][c_test] = R^2.
# A separate dict is used so the in-domain val_scores from the previous cell survive.
# (The old code reset a flat dict keyed only by test c inside the inner loop,
# so every training c overwrote the previous one.)
cross_db_root = base_path / 'ResultsAndData/TrainingData/93'

cross_val_scores = {}
for c_value_train in train_params.test_c_values:
    print(f"For training c {c_value_train}")
    # The model depends only on the training c, so resolve its path once per outer iteration.
    model_path = output_path / f'c_{c_value_train}' / 'erd_model_relu.pkl'
    cross_val_scores[c_value_train] = {}

    for c_value in tqdm(train_params.test_c_values, desc=f'train c={c_value_train}'):
        db_path = cross_db_root / f'c_{c_value}' / 'training_database.pkl'

        # NOTE(review): training_split=0.1 presumably leaves ~90% of each foreign
        # database for validation — confirm against validate_erd_model_r_squared.
        score = training.validate_erd_model_r_squared(training_db_path=db_path,
                                                      erd_model_path=model_path,
                                                      training_split=0.1,
                                                      random_seed=111)
        print(f'  test c {c_value}: R squared score', score)
        cross_val_scores[c_value_train][c_value] = score
        
    


For training c 2


  0%|          | 0/6 [00:00<?, ?it/s]

Validation score is -0.7307725117217063
R squared score -0.7307725117217063
Validation score is -5.558852323991638
R squared score -5.558852323991638
Validation score is -4.908835300362522
R squared score -4.908835300362522
Validation score is -102.53879994668428
R squared score -102.53879994668428
Validation score is -100.00398826682067
R squared score -100.00398826682067
Validation score is -75.01764612564924
R squared score -75.01764612564924
For training c 4


  0%|          | 0/6 [00:00<?, ?it/s]

Validation score is -1.645783563486892
R squared score -1.645783563486892
Validation score is -30.36775898602022
R squared score -30.36775898602022
Validation score is -190.70457263227587
R squared score -190.70457263227587
Validation score is -1163.1733491527689
R squared score -1163.1733491527689
Validation score is -922.6472551965551
R squared score -922.6472551965551
Validation score is -1111.7538519365435
R squared score -1111.7538519365435
For training c 8


  0%|          | 0/6 [00:00<?, ?it/s]

Validation score is -9.692923349218255
R squared score -9.692923349218255
Validation score is -69.41398440088132
R squared score -69.41398440088132
Validation score is -165.20101030706513
R squared score -165.20101030706513
Validation score is -1430.4980379230337
R squared score -1430.4980379230337
Validation score is -1566.8235752540572
R squared score -1566.8235752540572
Validation score is -1807.2038895849905
R squared score -1807.2038895849905
For training c 16


  0%|          | 0/6 [00:00<?, ?it/s]

Validation score is -0.08581632706375819
R squared score -0.08581632706375819
Validation score is -0.138865834442917
R squared score -0.138865834442917
Validation score is -0.28114563604233145
R squared score -0.28114563604233145
Validation score is -2.9516832469287833
R squared score -2.9516832469287833
Validation score is -3.2000719025752833
R squared score -3.2000719025752833
Validation score is -3.8246206177043742
R squared score -3.8246206177043742
For training c 32


  0%|          | 0/6 [00:00<?, ?it/s]

Validation score is -1.414086521826131
R squared score -1.414086521826131
Validation score is -14.993106368860763
R squared score -14.993106368860763
Validation score is -23.19903300683912
R squared score -23.19903300683912
Validation score is -271.5013143838746
R squared score -271.5013143838746
Validation score is -319.530424790054
R squared score -319.530424790054
Validation score is -350.98577022798884
R squared score -350.98577022798884
For training c 64


  0%|          | 0/6 [00:00<?, ?it/s]

Validation score is -30.41664662959345
R squared score -30.41664662959345
Validation score is -258.4177751403839
R squared score -258.4177751403839
Validation score is -486.27815771224095
R squared score -486.27815771224095
Validation score is -5521.129776262753
R squared score -5521.129776262753
Validation score is -5737.691771738541
R squared score -5737.691771738541
Validation score is -6360.783354937342
R squared score -6360.783354937342
