In [5]:
def define_searchspace(trial):
    """Declare the conditional hyper-parameter space on an Optuna-style trial.

    The model family is sampled first; the architecture parameters that are
    sampled next depend on that choice. The loss, optimizer, learning rate and
    input scaler are then sampled for every model family.

    Args:
        trial: Object exposing ``suggest_categorical``, ``suggest_int`` and
            ``suggest_float`` (this function is handed to ``OptunaSearch`` as
            its ``space``).

    Returns:
        None. The space is defined purely through the ``trial.suggest_*`` calls.
    """
    hidden_sizes = [8, 16, 32, 64, 128]

    chosen_model = trial.suggest_categorical(
        'model',
        ['fsr_model.LSTM', 'fsr_model.CNN_LSTM', 'fsr_model.ANN'],
    )

    if chosen_model in ('fsr_model.LSTM', 'fsr_model.ANN'):
        # Both single-stack models share the same two architecture knobs.
        trial.suggest_categorical('model_args/hidden_size', hidden_sizes)
        trial.suggest_int('model_args/num_layer', 1, 8)
    elif chosen_model == 'fsr_model.CNN_LSTM':
        # The hybrid model tunes its CNN and LSTM stages independently:
        # widths first, then depths.
        for stage in ('cnn', 'lstm'):
            trial.suggest_categorical(f'model_args/{stage}_hidden_size', hidden_sizes)
        for stage in ('cnn', 'lstm'):
            trial.suggest_int(f'model_args/{stage}_num_layer', 1, 8)

    # Settings sampled regardless of the model family.
    trial.suggest_categorical('criterion', ['torch.nn.MSELoss'])
    optimizer_choices = [
        'torch.optim.Adam',
        'torch.optim.NAdam',
        'torch.optim.Adagrad',
        'torch.optim.RAdam',
        'torch.optim.SGD',
    ]
    trial.suggest_categorical('optimizer', optimizer_choices)
    trial.suggest_float('optimizer_args/lr', 1e-5, 1e-1, log=True)
    scaler_choices = [
        'sklearn.preprocessing.StandardScaler',
        'sklearn.preprocessing.MinMaxScaler',
        'sklearn.preprocessing.RobustScaler',
    ]
    trial.suggest_categorical('scaler', scaler_choices)

In [6]:
import ray.tune
import ray.air
import ray.air.integrations.wandb
import ray.tune.schedulers
import datasource
from trainable import Trainable
import ray.tune.search
import ray.tune.search.optuna

# Bind the dataset to the Trainable, then request 2 CPUs for each trial.
trainable_with_data = ray.tune.with_parameters(Trainable, data=datasource.get_data())
resourced_trainable = ray.tune.with_resources(trainable_with_data, {'cpu': 2})

# The scheduler and the search algorithm optimise the same objective.
objective = {'metric': 'rmse', 'mode': 'min'}

asha_scheduler = ray.tune.schedulers.ASHAScheduler(
    max_t=100,
    grace_period=1,
    reduction_factor=2,
    brackets=1,
    **objective,
)
optuna_search = ray.tune.search.optuna.OptunaSearch(
    space=define_searchspace,
    **objective,
)
tune_config = ray.tune.TuneConfig(
    # NOTE(review): -1 presumably means "keep sampling until stopped" per the
    # Ray Tune docs; no stop criterion or time budget is set here, so confirm
    # that an open-ended run is intended.
    num_samples=-1,
    scheduler=asha_scheduler,
    search_alg=optuna_search,
)

# Trial results are reported to the 'FSR-prediction' Weights & Biases project.
wandb_logger = ray.air.integrations.wandb.WandbLoggerCallback(project='FSR-prediction')
checkpoint_config = ray.air.CheckpointConfig(
    num_to_keep=3,
    checkpoint_score_attribute='rmse',
    checkpoint_score_order='min',
    checkpoint_frequency=5,
    checkpoint_at_end=True,
)
run_config = ray.air.RunConfig(
    callbacks=[wandb_logger],
    checkpoint_config=checkpoint_config,
)

tuner = ray.tune.Tuner(
    trainable=resourced_trainable,
    tune_config=tune_config,
    run_config=run_config,
)
results = tuner.fit()

[I 2023-07-02 04:23:04,111] A new study created in memory with name: optuna


0,1
Current time:,2023-07-02 04:25:21
Running for:,00:02:17.64
Memory:,4.4/7.7 GiB

Trial name,status,loc,criterion,model,model_args/cnn_hidden_size,model_args/cnn_num_layer,model_args/hidden_size,model_args/lstm_hidden_size,model_args/lstm_num_layer,model_args/num_layer,optimizer,optimizer_args/lr,scaler,iter,total time (s),rmse,mae,mape
Trainable_bc0d4e96,RUNNING,172.26.215.93:162685,torch.nn.MSELoss,fsr_model.CNN_LSTM,128.0,4.0,,16.0,1.0,,torch.optim.SGD,0.00322906,sklearn.preproc_cab0,2.0,3.55873,63552.7,25885.6,33.2739
Trainable_569dbc2d,PENDING,,torch.nn.MSELoss,fsr_model.ANN,,,16.0,,,7.0,torch.optim.NAdam,0.0296262,sklearn.preproc_cab0,,,,,
Trainable_79a0a71b,TERMINATED,172.26.215.93:161906,torch.nn.MSELoss,fsr_model.LSTM,,,16.0,,,4.0,torch.optim.NAdam,0.0197608,sklearn.preproc_cab0,100.0,90.4757,63913.7,26607.5,66.8692
Trainable_8f8decb8,TERMINATED,172.26.215.93:162153,torch.nn.MSELoss,fsr_model.CNN_LSTM,16.0,6.0,,64.0,3.0,,torch.optim.Adagrad,0.000151045,sklearn.preproc_cb70,1.0,2.30414,595574.0,238298.0,1.3376e+17
Trainable_1148c251,TERMINATED,172.26.215.93:162320,torch.nn.MSELoss,fsr_model.CNN_LSTM,32.0,8.0,,64.0,6.0,,torch.optim.RAdam,0.0111682,sklearn.preproc_cb10,2.0,8.84129,80658.1,27039.2,1.68783e+18
Trainable_792c775a,TERMINATED,172.26.215.93:162528,torch.nn.MSELoss,fsr_model.ANN,,,8.0,,,5.0,torch.optim.Adam,0.0132239,sklearn.preproc_cb70,1.0,0.800435,294226.0,147143.0,1.05496e+20


2023-07-02 04:23:04,129	INFO wandb.py:320 -- Already logged into W&B.


Trial name,date,done,hostname,iterations_since_restore,mae,mape,node_ip,pid,rmse,time_since_restore,time_this_iter_s,time_total_s,timestamp,training_iteration,trial_id
Trainable_1148c251,2023-07-02_04-25-02,True,DESKTOP-0P789CI,2,27039.2,1.68783e+18,172.26.215.93,162320,80658.1,8.84129,4.01877,8.84129,1688239502,2,1148c251
Trainable_792c775a,2023-07-02_04-25-09,True,DESKTOP-0P789CI,1,147143.0,1.05496e+20,172.26.215.93,162528,294226.0,0.800435,0.800435,0.800435,1688239509,1,792c775a
Trainable_79a0a71b,2023-07-02_04-24-40,True,DESKTOP-0P789CI,100,26607.5,66.8692,172.26.215.93,161906,63913.7,90.4757,0.863655,90.4757,1688239480,100,79a0a71b
Trainable_8f8decb8,2023-07-02_04-24-46,True,DESKTOP-0P789CI,1,238298.0,1.3376e+17,172.26.215.93,162153,595574.0,2.30414,2.30414,2.30414,1688239486,1,8f8decb8
Trainable_bc0d4e96,2023-07-02_04-25-17,False,DESKTOP-0P789CI,1,25350.6,30.8426,172.26.215.93,162685,62214.0,2.18227,2.18227,2.18227,1688239517,1,bc0d4e96


[2m[36m(_WandbLoggingActor pid=161957)[0m wandb: Currently logged in as: seokjin. Use `wandb login --relogin` to force relogin
[2m[36m(_WandbLoggingActor pid=161957)[0m wandb: Tracking run with wandb version 0.15.4
[2m[36m(_WandbLoggingActor pid=161957)[0m wandb: Run data is saved locally in /home/seokj/ray_results/Trainable_2023-07-02_04-23-04/Trainable_79a0a71b_1_criterion=torch_nn_MSELoss,model=fsr_model_LSTM,hidden_size=16,num_layer=4,optimizer=torch_optim_NAdam,lr=0.0_2023-07-02_04-23-04/wandb/run-20230702_042311-79a0a71b
[2m[36m(_WandbLoggingActor pid=161957)[0m wandb: Run `wandb offline` to turn off syncing.
[2m[36m(_WandbLoggingActor pid=161957)[0m wandb: Syncing run Trainable_79a0a71b
[2m[36m(_WandbLoggingActor pid=161957)[0m wandb: ‚≠êÔ∏è View project at https://wandb.ai/seokjin/FSR-prediction
[2m[36m(_WandbLoggingActor pid=161957)[0m wandb: üöÄ View run at https://wandb.ai/seokjin/FSR-prediction/runs/79a0a71b
[2m[36m(_WandbLoggingActor pid=161957)[0m