In [1]:
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

from ray import air, tune
from ray.tune.schedulers import ASHAScheduler

In [2]:
# extra imports for tablebench example
import rtdl
from tablebench.core import TabularDataset, TabularDatasetConfig

from tablebench.datasets.experiment_configs import EXPERIMENT_CONFIGS
from tablebench.models import get_estimator

In [3]:
# Name of the benchmark experiment; used as the key into EXPERIMENT_CONFIGS
# and passed as the first argument to TabularDataset.
experiment = "adult"
# Config entry for this experiment; its splitter, grouper, preprocessor_config
# and tabular_dataset_kwargs attributes are read when the dataset is built.
expt_config = EXPERIMENT_CONFIGS[experiment]

In [4]:
# Build the tabular dataset for the selected experiment. The splitter, grouper,
# preprocessor config and any extra keyword arguments all come from the
# experiment's config entry.
dataset_config = TabularDatasetConfig()
dset = TabularDataset(
    experiment,
    config=dataset_config,
    splitter=expt_config.splitter,
    grouper=expt_config.grouper,
    preprocessor_config=expt_config.preprocessor_config,
    **expt_config.tabular_dataset_kwargs,
)

# Loader for the "train" split, plus a dict of evaluation loaders keyed by
# split name. The second positional argument (512 / 2048) is presumably the
# batch size -- TODO confirm against TabularDataset.get_dataloader.
train_loader = dset.get_dataloader("train", 512)
loaders = dict(
    (split, dset.get_dataloader(split, 2048)) for split in ("validation", "test")
)

[DEBUG] not downloading https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data; exists at tmp/adult.data
[DEBUG] not downloading https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.names; exists at tmp/adult.names
[DEBUG] not downloading https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test; exists at tmp/adult.test
[DEBUG] dropping data columns not in FeatureList: ['fnlwgt']
[DEBUG] checking feature Age
[DEBUG] casting feature Age from dtype int64 to dtype float
[DEBUG] checking feature Workclass
[DEBUG] casting feature Workclass from dtype object to dtype CategoricalDtype
[DEBUG] checking feature Education-Num
[DEBUG] casting feature Education-Num from dtype int64 to dtype CategoricalDtype
[DEBUG] checking feature Marital Status
[DEBUG] casting feature Marital Status from dtype object to dtype CategoricalDtype
[DEBUG] checking feature Occupation
[DEBUG] casting feature Occupation from dtype object to dtype CategoricalDtype
[D

In [5]:
def train_adult(config):
    """Trainable passed to Ray Tune: build an "mlp" estimator and fit it.

    Args:
        config: mapping of hyperparameters. Reads "d_hidden" and
            "num_layers" to size the hidden layers, and "lr" and
            "weight_decay" when the AdamW optimizer is constructed.

    Uses the notebook-level ``dset``, ``train_loader`` and ``loaders``.
    Returns nothing; ``model.fit`` is called with
    ``tune_report_split="validation"`` so that results on the validation
    split are reported to Tune.
    """
    # d_layers is the width d_hidden repeated num_layers times.
    n_features = dset.X_shape[1]
    hidden_sizes = [config["d_hidden"]] * config["num_layers"]
    model = get_estimator("mlp", d_in=n_features, d_layers=hidden_sizes)

    # FTTransformer instances supply their own default optimizer; every other
    # model gets AdamW configured from the sampled hyperparameters.
    if isinstance(model, rtdl.FTTransformer):
        optimizer = model.make_default_optimizer()
    else:
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config["lr"],
            weight_decay=config["weight_decay"],
        )

    # Train for a fixed 5 epochs with a binary cross-entropy-with-logits loss.
    model.fit(
        train_loader,
        optimizer,
        F.binary_cross_entropy_with_logits,
        n_epochs=5,
        other_loaders=loaders,
        tune_report_split="validation",
    )

In [8]:
# Hyperparameter search space handed to the Tuner as param_space.
search_space = {
    # Learning rate: sampled in log space between 1e-4 and 1e-1 and rounded
    # to multiples of 5e-5.
    "lr": tune.qloguniform(1e-4, 1e-1, 5e-5),
    # Weight decay: uniform between 0 and 1, rounded to multiples of 0.1.
    "weight_decay": tune.quniform(0.0, 1.0, 0.1),
    # Number of hidden layers: random integer in [1, 4). The upper bound of
    # tune.randint is exclusive, so this samples 1, 2 or 3.
    "num_layers": tune.randint(1, 4),
    # Hidden layer width: one value picked from the list.
    "d_hidden": tune.choice([64, 128, 256, 512]),
}

# Sample 5 configurations for train_adult; the run is named "test_experiment"
# and uses ./results as its local directory.
tuner = tune.Tuner(
    train_adult,
    param_space=search_space,
    tune_config=tune.TuneConfig(num_samples=5),
    run_config=air.RunConfig(name="test_experiment", local_dir="./results"),
)
results = tuner.fit()

0,1
Current time:,2022-11-23 16:59:45
Running for:,00:01:00.04
Memory:,5.4/8.0 GiB

Trial name,status,loc,d_hidden,lr,num_layers,weight_decay,iter,total time (s),_metric
train_adult_00348_00000,TERMINATED,127.0.0.1:14610,512,0.08845,3,0.5,5,37.3202,0.849036
train_adult_00348_00001,TERMINATED,127.0.0.1:14620,256,0.0616,2,0.0,5,20.4035,0.856774
train_adult_00348_00002,TERMINATED,127.0.0.1:14623,128,0.0013,1,0.1,5,12.8179,0.854686
train_adult_00348_00003,TERMINATED,127.0.0.1:14626,512,0.0427,3,0.0,5,21.8927,0.857389
train_adult_00348_00004,TERMINATED,127.0.0.1:14623,256,0.0129,3,0.3,5,11.2077,0.856283




[2m[36m(train_adult pid=14610)[0m Epoch 001 train score: 0.8482 | validation score: 0.8455 | test score: 0.8480




[2m[36m(train_adult pid=14620)[0m Epoch 001 train score: 0.8511 | validation score: 0.8511 | test score: 0.8509




[2m[36m(train_adult pid=14623)[0m Epoch 001 train score: 0.8309 | validation score: 0.8299 | test score: 0.8316


Trial name,_metric,date,done,episodes_total,experiment_id,experiment_tag,hostname,iterations_since_restore,node_ip,pid,time_since_restore,time_this_iter_s,time_total_s,timestamp,timesteps_since_restore,timesteps_total,training_iteration,trial_id,warmup_time
train_adult_00348_00000,0.849036,2022-11-23_16-59-42,True,,adbea1b6041c407799163f08f6097dd7,"0_d_hidden=512,lr=0.0885,num_layers=3,weight_decay=0.5000",Joshuas-MacBook-Pro-10.local,5,127.0.0.1,14610,37.3202,3.94653,37.3202,1669240782,0,,5,00348_00000,0.00669622
train_adult_00348_00001,0.856774,2022-11-23_16-59-33,True,,3d668c8411d8422081205db6efafbc28,"1_d_hidden=256,lr=0.0616,num_layers=2,weight_decay=0.0000",Joshuas-MacBook-Pro-10.local,5,127.0.0.1,14620,20.4035,2.21829,20.4035,1669240773,0,,5,00348_00001,0.00395894
train_adult_00348_00002,0.854686,2022-11-23_16-59-30,True,,5ca2c5154bc24dc685ead74e9ce865a0,"2_d_hidden=128,lr=0.0013,num_layers=1,weight_decay=0.1000",Joshuas-MacBook-Pro-10.local,5,127.0.0.1,14623,12.8179,2.11711,12.8179,1669240770,0,,5,00348_00002,0.00308609
train_adult_00348_00003,0.857389,2022-11-23_16-59-45,True,,07b32efe9eac4daeb8b05191982228b1,"3_d_hidden=512,lr=0.0427,num_layers=3,weight_decay=0.0000",Joshuas-MacBook-Pro-10.local,5,127.0.0.1,14626,21.8927,2.50624,21.8927,1669240785,0,,5,00348_00003,0.00345922
train_adult_00348_00004,0.856283,2022-11-23_16-59-41,True,,5ca2c5154bc24dc685ead74e9ce865a0,"4_d_hidden=256,lr=0.0129,num_layers=3,weight_decay=0.3000",Joshuas-MacBook-Pro-10.local,5,127.0.0.1,14623,11.2077,2.21475,11.2077,1669240781,0,,5,00348_00004,0.00308609


[2m[36m(train_adult pid=14623)[0m Epoch 002 train score: 0.8483 | validation score: 0.8511 | test score: 0.8498
[2m[36m(train_adult pid=14620)[0m Epoch 002 train score: 0.8489 | validation score: 0.8444 | test score: 0.8458
[2m[36m(train_adult pid=14623)[0m Epoch 003 train score: 0.8527 | validation score: 0.8528 | test score: 0.8514
[2m[36m(train_adult pid=14620)[0m Epoch 003 train score: 0.8586 | validation score: 0.8559 | test score: 0.8548
[2m[36m(train_adult pid=14623)[0m Epoch 004 train score: 0.8547 | validation score: 0.8542 | test score: 0.8557
[2m[36m(train_adult pid=14610)[0m Epoch 002 train score: 0.8552 | validation score: 0.8536 | test score: 0.8533
[2m[36m(train_adult pid=14626)[0m Epoch 001 train score: 0.8512 | validation score: 0.8526 | test score: 0.8468
[2m[36m(train_adult pid=14623)[0m Epoch 005 train score: 0.8563 | validation score: 0.8547 | test score: 0.8562
[2m[36m(train_adult pid=14620)[0m Epoch 004 train score: 0.8605 | validation 

2022-11-23 16:59:45,369	INFO tune.py:777 -- Total run time: 61.44 seconds (60.03 seconds for the tuning loop).


[2m[36m(train_adult pid=14626)[0m Epoch 005 train score: 0.8628 | validation score: 0.8574 | test score: 0.8538


In [9]:
# Inspect the first trial in the result grid: print its log directory ...
print(results[0].log_dir)
# ... and leave its metrics dataframe as the last expression so the notebook
# renders it as the cell output.
results[0].metrics_dataframe

/Users/jpgard/Documents/github/tablebench/notebooks/results/test_experiment/train_adult_00348_00000_0_d_hidden=512,lr=0.0885,num_layers=3,weight_decay=0.5000_2022-11-23_16-58-55


Unnamed: 0,_metric,time_this_iter_s,done,timesteps_total,episodes_total,training_iteration,trial_id,experiment_id,date,timestamp,time_total_s,pid,hostname,node_ip,time_since_restore,timesteps_since_restore,iterations_since_restore,warmup_time
0,0.845474,3.81252,False,,,1,00348_00000,adbea1b6041c407799163f08f6097dd7,2022-11-23_16-59-09,1669240749,3.81252,14610,Joshuas-MacBook-Pro-10.local,127.0.0.1,3.81252,0,1,0.006696
1,0.853581,19.890008,False,,,2,00348_00000,adbea1b6041c407799163f08f6097dd7,2022-11-23_16-59-28,1669240768,23.702528,14610,Joshuas-MacBook-Pro-10.local,127.0.0.1,23.702528,0,2,0.006696
2,0.850878,5.514633,False,,,3,00348_00000,adbea1b6041c407799163f08f6097dd7,2022-11-23_16-59-34,1669240774,29.217161,14610,Joshuas-MacBook-Pro-10.local,127.0.0.1,29.217161,0,3,0.006696
3,0.843262,4.156528,False,,,4,00348_00000,adbea1b6041c407799163f08f6097dd7,2022-11-23_16-59-38,1669240778,33.373689,14610,Joshuas-MacBook-Pro-10.local,127.0.0.1,33.373689,0,4,0.006696
4,0.849036,3.946527,False,,,5,00348_00000,adbea1b6041c407799163f08f6097dd7,2022-11-23_16-59-42,1669240782,37.320216,14610,Joshuas-MacBook-Pro-10.local,127.0.0.1,37.320216,0,5,0.006696


In [None]:
# Map each trial's log directory to its metrics dataframe.
dfs = {result.log_dir: result.metrics_dataframe for result in results}

# Overlay every trial's `_metric` column on one shared set of axes. A plain
# loop replaces the list comprehension, which was used only for its side
# effects and echoed a throwaway list of Axes objects as the cell output.
ax = None
for trial_idx, trial_df in enumerate(dfs.values()):
    if trial_df is None:
        # NOTE(review): metrics_dataframe is presumably None for a trial that
        # never reported a result -- confirm against the Ray Tune Result docs.
        continue
    # ax=None on the first pass lets pandas pick the axes; later passes
    # reuse the returned axes so all trials share one figure.
    ax = trial_df["_metric"].plot(ax=ax, label=f"trial {trial_idx}")

# Label the figure only if at least one trial was actually plotted.
if ax is not None:
    ax.set_xlabel("report index")
    ax.set_ylabel("_metric")
    ax.legend()

In [None]:
# Display the first value stored in `dfs` (in insertion order).
tuple(dfs.values())[0]

In [None]:
# Exploration aid: list the instance attribute names held by `results`.
vars(results).keys()