In [1]:
# Re-import edited modules (e.g. the sladsnet package) before each cell runs,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path

import sladsnet.code.training as training
from sladsnet.code.erd import SladsSklearnModel 
from sladsnet.input_params import TrainingInputParams, SladsModelParams
from tqdm.notebook import tqdm

In [3]:
# Repository root: the notebook is expected to run from a directory one level
# below it, so every data path below is resolved relative to that parent.
notebook_dir = Path.cwd()
base_path = notebook_dir.parent

In [4]:
# Root of the cameraman training data: source images live in `Images/`, and
# generated databases / fitted models are written alongside them.
cameraman_path = base_path.joinpath('ResultsAndData', 'TrainingData', 'cameraman')

In [9]:
# Database-generation settings for the cameraman image set. Any field not
# passed here falls back to the TrainingInputParams default.
images_dir = cameraman_path / 'Images'
c_values_to_test = [2, 4, 8, 16, 32, 64]
train_params = TrainingInputParams(
    input_images_path=images_dir,
    output_dir=cameraman_path,
    num_repeats_per_mask=1,
    measurements_per_initial_mask=10,
    test_c_values=c_values_to_test,
)

In [10]:
# Echo the full parameter set, including the defaults not passed explicitly.
train_params

TrainingInputParams(input_images_path=PosixPath('/gpfs/fs1/home/skandel/code/SLADS-Net/ResultsAndData/TrainingData/cameraman/Images'), output_dir=PosixPath('/gpfs/fs1/home/skandel/code/SLADS-Net/ResultsAndData/TrainingData/cameraman'), initial_scan_ratio=0.01, initial_mask_type='random', stop_ratio=0.8, scan_method='random', scan_type='transmission', sampling_type='fast_limited', num_repeats_per_mask=1, measurements_per_initial_mask=10, random_seed=111, training_split=0.9, test_c_values=[2, 4, 8, 16, 32, 64], calculate_full_erd_per_step=True)

In [11]:
# Generate training data for every c in train_params.test_c_values under
# train_params.output_dir. The cells below read c_<c>/training_database.pkl
# from that directory, which is presumably written by this call.
training.generate_training_databases(train_params)

  0%|          | 0/6 [00:00<?, ?it/s]

Testing for c=   2; Samples:   0%|          | 0/1 [00:00<?, ?it/s]

Masks:   0%|          | 0/1 [00:00<?, ?it/s]

Iterating through test sampling ratios.:   0%|          | 0/10 [00:00<?, ?it/s]

Testing for c=   4; Samples:   0%|          | 0/1 [00:00<?, ?it/s]

Masks:   0%|          | 0/1 [00:00<?, ?it/s]

Iterating through test sampling ratios.:   0%|          | 0/10 [00:00<?, ?it/s]

Testing for c=   8; Samples:   0%|          | 0/1 [00:00<?, ?it/s]

Masks:   0%|          | 0/1 [00:00<?, ?it/s]

Iterating through test sampling ratios.:   0%|          | 0/10 [00:00<?, ?it/s]

Testing for c=  16; Samples:   0%|          | 0/1 [00:00<?, ?it/s]

Masks:   0%|          | 0/1 [00:00<?, ?it/s]

Iterating through test sampling ratios.:   0%|          | 0/10 [00:00<?, ?it/s]

Testing for c=  32; Samples:   0%|          | 0/1 [00:00<?, ?it/s]

Masks:   0%|          | 0/1 [00:00<?, ?it/s]

Iterating through test sampling ratios.:   0%|          | 0/10 [00:00<?, ?it/s]

Testing for c=  64; Samples:   0%|          | 0/1 [00:00<?, ?it/s]

Masks:   0%|          | 0/1 [00:00<?, ?it/s]

Iterating through test sampling ratios.:   0%|          | 0/10 [00:00<?, ?it/s]

In [12]:

# Read databases from exactly where generate_training_databases was told to
# write them, instead of re-typing the path (which could silently drift).
output_path = train_params.output_dir

In [13]:
# Fit one ERD model per c value on the cameraman databases, save it next to
# its database, and report R^2 on the held-out part of the same database.
# Fit and validate share one split/seed pair so the two calls cannot drift
# apart. Presumably both functions partition the database deterministically
# from (training_split, random_seed); confirm in sladsnet.code.training.
# NOTE(review): 0.8 differs from train_params.training_split (0.9) -- confirm
# this is intentional.
FIT_TRAINING_SPLIT = 0.8
FIT_RANDOM_SEED = 111
FIT_MAX_ITER = 50

val_scores = {}
for c_value in tqdm(train_params.test_c_values, desc='fit + validate per c'):
    print('test c', c_value)

    c_dir = output_path / f'c_{c_value}'
    train_db_path = c_dir / 'training_database.pkl'
    save_path = c_dir / 'erd_model_relu.pkl'

    erd_model, _ = training.fit_erd_model(training_db_path=train_db_path,
                                          model_params=SladsModelParams(max_iter=FIT_MAX_ITER),
                                          save_path=save_path,
                                          training_split=FIT_TRAINING_SPLIT,
                                          random_seed=FIT_RANDOM_SEED)
    score = training.validate_erd_model_r_squared(training_db_path=train_db_path,
                                                  erd_model_path=save_path,
                                                  training_split=FIT_TRAINING_SPLIT,
                                                  random_seed=FIT_RANDOM_SEED)
    print('R squared score', score)
    # One list per c value, so repeated fits could append further scores.
    val_scores[c_value] = [score]
        
    


  0%|          | 0/6 [00:00<?, ?it/s]

test c 2




Validation score is 0.8396186554723385
R squared score 0.8396186554723385
test c 4




Validation score is 0.830788028453655
R squared score 0.830788028453655
test c 8




Validation score is 0.6004770412287905
R squared score 0.6004770412287905
test c 16




Validation score is 0.5878485402930633
R squared score 0.5878485402930633
test c 32




Validation score is 0.5651887365869837
R squared score 0.5651887365869837
test c 64




Validation score is 0.6184885106300815
R squared score 0.6184885106300815


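### Cross-database test: cameraman-trained models on the `93` databases

Each ERD model fitted above (one per training c) is scored against the `93` training database for every test c, which shows how well the models generalize to a different image set.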

In [14]:
# Score every cameraman-trained ERD model (one per training c) against the
# `93` databases for every test c. Results are kept per (train c, test c)
# pair in a separate dict, so no training c overwrites another and the
# cameraman scores in `val_scores` stay intact.
cross_db_root = base_path / 'ResultsAndData/TrainingData/93'
# NOTE(review): a small training_split presumably leaves most of the foreign
# database for validation -- confirm against validate_erd_model_r_squared.
CROSS_DB_TRAINING_SPLIT = 0.1
CROSS_DB_RANDOM_SEED = 111

cross_val_scores = {}  # {train c: {test c: R^2}}
for c_value_train in train_params.test_c_values:
    print(f"For training c {c_value_train}")
    # The model depends only on the training c, so resolve it once per outer step.
    model_path = output_path / f'c_{c_value_train}' / 'erd_model_relu.pkl'
    cross_val_scores[c_value_train] = {}
    for c_value in tqdm(train_params.test_c_values, desc=f'train c={c_value_train}'):
        db_path = cross_db_root / f'c_{c_value}' / 'training_database.pkl'
        score = training.validate_erd_model_r_squared(training_db_path=db_path,
                                                      erd_model_path=model_path,
                                                      training_split=CROSS_DB_TRAINING_SPLIT,
                                                      random_seed=CROSS_DB_RANDOM_SEED)
        print('R squared score', score)
        cross_val_scores[c_value_train][c_value] = score
        
    


For training c 2


  0%|          | 0/6 [00:00<?, ?it/s]

Validation score is -35.58546639403132
R squared score -35.58546639403132
Validation score is -253.91279389019058
R squared score -253.91279389019058
Validation score is -3662.95931385278
R squared score -3662.95931385278
Validation score is -2059.9515470586043
R squared score -2059.9515470586043
Validation score is -3326.166371467548
R squared score -3326.166371467548
Validation score is -2614.581722916016
R squared score -2614.581722916016
For training c 4


  0%|          | 0/6 [00:00<?, ?it/s]

Validation score is -2.3142547403402665
R squared score -2.3142547403402665
Validation score is -20.816272020665263
R squared score -20.816272020665263
Validation score is -384.84173276401344
R squared score -384.84173276401344
Validation score is -146.9806945688581
R squared score -146.9806945688581
Validation score is -309.7253318182813
R squared score -309.7253318182813
Validation score is -224.09596759452907
R squared score -224.09596759452907
For training c 8


  0%|          | 0/6 [00:00<?, ?it/s]

Validation score is -0.06921462870912709
R squared score -0.06921462870912709
Validation score is -1.2594007627464756
R squared score -1.2594007627464756
Validation score is -28.493263146556465
R squared score -28.493263146556465
Validation score is -7.217609684765604
R squared score -7.217609684765604
Validation score is -22.203044132306562
R squared score -22.203044132306562
Validation score is -14.818490871435294
R squared score -14.818490871435294
For training c 16


  0%|          | 0/6 [00:00<?, ?it/s]

Validation score is -0.1509448024578628
R squared score -0.1509448024578628
Validation score is -1.8706494920420753
R squared score -1.8706494920420753
Validation score is -37.3779175614544
R squared score -37.3779175614544
Validation score is -9.516837549069342
R squared score -9.516837549069342
Validation score is -30.585486885708544
R squared score -30.585486885708544
Validation score is -20.048240680335283
R squared score -20.048240680335283
For training c 32


  0%|          | 0/6 [00:00<?, ?it/s]

Validation score is -0.18260302020782748
R squared score -0.18260302020782748
Validation score is -2.3409752064988947
R squared score -2.3409752064988947
Validation score is -49.87020738984036
R squared score -49.87020738984036
Validation score is -12.243508154381782
R squared score -12.243508154381782
Validation score is -36.777259267667674
R squared score -36.777259267667674
Validation score is -25.02343333329216
R squared score -25.02343333329216
For training c 64


  0%|          | 0/6 [00:00<?, ?it/s]

Validation score is 0.041473671055817896
R squared score 0.041473671055817896
Validation score is -0.004170542981893233
R squared score -0.004170542981893233
Validation score is -4.541022459360924
R squared score -4.541022459360924
Validation score is -1.3259857268197162
R squared score -1.3259857268197162
Validation score is -3.9291725753856053
R squared score -3.9291725753856053
Validation score is -2.3914548031435143
R squared score -2.3914548031435143
