In [None]:
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

from ray import air, tune
from ray.tune.schedulers import ASHAScheduler

# Tuner

In [None]:
# extra imports for tablebench example
import rtdl
from tablebench.core import TabularDataset, TabularDatasetConfig

from tablebench.datasets.experiment_configs import EXPERIMENT_CONFIGS
from tablebench.models import get_estimator

In [None]:
# Name of the experiment to run; it is used both to look up the predefined
# experiment configuration here and to construct the dataset in a later cell.
experiment = "adult"
expt_config = EXPERIMENT_CONFIGS[experiment]

In [None]:
# Build the dataset for the selected experiment. The splitter, grouper,
# preprocessor settings and any extra keyword arguments all come from the
# experiment's predefined configuration.
dataset_config = TabularDatasetConfig()
dset = TabularDataset(
    experiment,
    config=dataset_config,
    splitter=expt_config.splitter,
    grouper=expt_config.grouper,
    preprocessor_config=expt_config.preprocessor_config,
    **expt_config.tabular_dataset_kwargs,
)

# One loader for training plus one per evaluation split.
# NOTE(review): the second argument to get_dataloader is presumably the batch
# size (512 for training, 2048 for evaluation) -- confirm against tablebench.
train_loader = dset.get_dataloader("train", 512)
loaders = {
    "validation": dset.get_dataloader("validation", 2048),
    "test": dset.get_dataloader("test", 2048),
}

In [None]:
def train_adult(config):
    """Ray Tune trainable: fit an MLP estimator for one sampled configuration.

    Relies on the notebook-level ``dset``, ``train_loader`` and ``loaders``
    objects defined in earlier cells.

    Args:
        config: Hyperparameter mapping. ``"d_hidden"`` and ``"num_layers"``
            define the hidden layers; ``"lr"`` and ``"weight_decay"`` are
            passed to AdamW when the model is not an ``rtdl.FTTransformer``.
    """
    # Every hidden layer has the same width: d_hidden repeated num_layers times.
    hidden_sizes = [config["d_hidden"]] * config["num_layers"]
    model = get_estimator("mlp", d_in=dset.X_shape[1], d_layers=hidden_sizes)

    # FT-Transformer models supply their own default optimizer; everything
    # else is trained with AdamW using the sampled hyperparameters.
    if isinstance(model, rtdl.FTTransformer):
        optimizer = model.make_default_optimizer()
    else:
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config["lr"],
            weight_decay=config["weight_decay"],
        )

    # Fit the model; results on validation split are reported to tune.
    model.fit(
        train_loader,
        optimizer,
        F.binary_cross_entropy_with_logits,
        n_epochs=5,
        other_loaders=loaders,
        tune_report_split="validation",
    )

In [None]:
# Hyperparameter search space sampled by Ray Tune for each trial.
search_space = {
    # Sample a float between 0.0001 and 0.1, uniformly in log space,
    # rounding to multiples of 0.00005
    "lr": tune.qloguniform(1e-4, 1e-1, 5e-5),
    
    # Sample a float uniformly between 0 and 1,
    # rounding to multiples of 0.1
    "weight_decay": tune.quniform(0., 1., 0.1),
    
    # Random integer in [1, 4): tune.randint excludes the upper bound,
    # so this samples 1, 2 or 3 layers (never 4)
    "num_layers": tune.randint(1,4),
    
    # Random choice from the listed hidden widths
    "d_hidden": tune.choice([64, 128, 256, 512])
}


# Run 5 trials of `train_adult`, each with hyperparameters drawn from
# `search_space`. The run is named "test_experiment" with local_dir "./results".
tuner = tune.Tuner(
    train_adult,
    param_space=search_space,
    # TuneConfig is referenced through its public alias on the `tune` package.
    tune_config=tune.TuneConfig(num_samples=5),
    run_config=air.RunConfig(name="test_experiment", local_dir="./results"),
)
results = tuner.fit()

In [None]:
# Inspect the first trial: print its log directory, then display its metrics.
first_result = results[0]
print(first_result.log_dir)
first_result.metrics_dataframe

In [None]:
# Map each trial's log directory to its metrics DataFrame.
dfs = {result.log_dir: result.metrics_dataframe for result in results}

# Plot the reported "_metric" column of every trial. A plain loop replaces the
# list comprehension so the cell no longer echoes a list of Axes objects, and
# trials without metrics (metrics_dataframe is None) are skipped instead of
# raising AttributeError.
for metrics_df in dfs.values():
    if metrics_df is not None:
        metrics_df["_metric"].plot()

In [None]:
# Display the first value stored in `dfs` (dicts preserve insertion order).
list(dfs.values())[0]

In [None]:
# Debug aid: list the instance attribute names of `results`.
vars(results).keys()

# Performance Improvements

## Ray docs example

Adapted from the Ray docs Torch image example: https://docs.ray.io/en/latest/ray-air/examples/torch_image_example.html

In [1]:
import ray
from ray.data.datasource import SimpleTorchDatasource
import torchvision
import torchvision.transforms as transforms

# Preprocessing applied to every image: convert to a tensor, then normalize
# each of the three channels with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)


def _cifar10_factory(train):
    """Return one CIFAR-10 split, stored under ./data, with ``transform`` applied.

    Args:
        train: ``True`` for the training split, ``False`` for the test split.
    """
    return torchvision.datasets.CIFAR10(
        root="./data", download=True, train=train, transform=transform
    )


def train_dataset_factory():
    """Zero-argument factory for the CIFAR-10 training split."""
    return _cifar10_factory(train=True)


def test_dataset_factory():
    """Zero-argument factory for the CIFAR-10 test split."""
    return _cifar10_factory(train=False)


# Read each CIFAR-10 split into a Ray Dataset. The factories are passed as
# callables (not called here), so read_datasource receives the functions themselves.
train_dataset: ray.data.Dataset = ray.data.read_datasource(
    SimpleTorchDatasource(), dataset_factory=train_dataset_factory
)
test_dataset: ray.data.Dataset = ray.data.read_datasource(
    SimpleTorchDatasource(), dataset_factory=test_dataset_factory
)

2022-11-24 22:36:57,514	INFO worker.py:1525 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m


[2m[36m(_execute_read_task_nosplit pid=7758)[0m Files already downloaded and verified




[2m[36m(_execute_read_task_nosplit pid=7758)[0m Files already downloaded and verified


In [2]:
# Display the training dataset's repr as the cell output.
train_dataset



In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three linear layers.

    The first linear layer expects 16 * 5 * 5 features, which is what the two
    conv/pool stages produce for inputs of shape (N, 3, 32, 32). The network
    returns 10 raw scores per sample; no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order and under the same attribute
        # names as before, so state-dict keys and seeded initialisation match.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the 10 class scores for a batch of images ``x``."""
        # Convolutional feature extractor; the same pooling layer is reused.
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))

        # Flatten every dimension except the batch dimension.
        flat = features.flatten(start_dim=1)

        # Fully connected classifier head.
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [4]:
from ray import train
from ray.air import session, Checkpoint
from ray.train.torch import TorchCheckpoint
import torch.nn as nn
import torch.optim as optim
import torchvision


def train_loop_per_worker(config):
    """Per-worker training loop: train ``Net`` on this worker's "train" shard.

    Args:
        config: Training configuration mapping. ``"batch_size"`` is required.
            Optional keys, whose defaults are the previously hard-coded
            values: ``"num_epochs"`` (2), ``"lr"`` (0.001), ``"momentum"`` (0.9).

    After each epoch the loss accumulated since the last printout and a
    checkpoint of the model weights are passed to ``session.report``.
    """
    num_epochs = config.get("num_epochs", 2)
    log_every = 2000  # number of mini-batches between loss printouts

    model = train.torch.prepare_model(Net())

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=config.get("lr", 0.001),
        momentum=config.get("momentum", 0.9),
    )

    train_dataset_shard = session.get_dataset_shard("train")

    for epoch in range(num_epochs):
        running_loss = 0.0
        train_dataset_batches = train_dataset_shard.iter_torch_batches(
            batch_size=config["batch_size"],
        )
        for i, batch in enumerate(train_dataset_batches):
            # get the inputs and labels
            inputs, labels = batch["image"], batch["label"]

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % log_every == log_every - 1:  # print every `log_every` mini-batches
                print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_every:.3f}")
                running_loss = 0.0

        # running_loss is reset after every printout, so this is the loss
        # accumulated since the last printout, not the loss of the whole epoch.
        metrics = dict(running_loss=running_loss)

        # The model only exposes `.module` when prepare_model wrapped it in
        # DistributedDataParallel; fall back to the model itself otherwise
        # (e.g. a single worker) instead of raising AttributeError.
        unwrapped_model = getattr(model, "module", model)
        checkpoint = TorchCheckpoint.from_state_dict(unwrapped_model.state_dict())
        session.report(metrics, checkpoint=checkpoint)

In [5]:
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig

# Configure a TorchTrainer with 2 workers. `train_dataset` is registered under
# the "train" key, the same key the training loop requests via
# session.get_dataset_shard("train"); the loop config supplies batch_size=2.
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"batch_size": 2},
    datasets={"train": train_dataset},
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
# Keep the checkpoint attached to the fit result for later use.
latest_checkpoint = result.checkpoint

0,1
Current time:,2022-11-24 22:38:07
Running for:,00:00:39.30
Memory:,5.2/8.0 GiB

Trial name,# failures,error file
TorchTrainer_7c272_00000,1,/Users/jpgard/ray_results/TorchTrainer_2022-11-24_22-37-28/TorchTrainer_7c272_00000_0_2022-11-24_22-37-28/error.txt

Trial name,status,loc
TorchTrainer_7c272_00000,ERROR,127.0.0.1:7771


[2m[36m(RayTrainWorker pid=7779)[0m 2022-11-24 22:37:38,538	INFO config.py:88 -- Setting up process group for: env:// [rank=0, world_size=2]


[2m[36m(_map_block_nosplit pid=7789)[0m Files already downloaded and verified


[2m[33m(raylet)[0m [2022-11-24 22:38:07,170 E 7748 352103] (raylet) file_system_monitor.cc:105: /tmp/ray/session_2022-11-24_22-36-53_113079_7730 is over 95% full, available space: 12514869248; capacity: 250685575168. Object creation will fail if spilling is required.
2022-11-24 22:38:07,661	ERROR trial_runner.py:993 -- Trial TorchTrainer_7c272_00000: Error processing event.
Traceback (most recent call last):
  File "/Users/jpgard/Documents/github/tablebench/venv3.7/lib/python3.7/site-packages/ray/tune/execution/ray_trial_executor.py", line 1050, in get_next_executor_event
    future_result = ray.get(ready_future)
  File "/Users/jpgard/Documents/github/tablebench/venv3.7/lib/python3.7/site-packages/ray/_private/client_mode_hook.py", line 105, in wrapper
    return func(*args, **kwargs)
  File "/Users/jpgard/Documents/github/tablebench/venv3.7/lib/python3.7/site-packages/ray/_private/worker.py", line 2289, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError: [36m

Trial name,date,experiment_id,hostname,node_ip,pid,timestamp,trial_id
TorchTrainer_7c272_00000,2022-11-24_22-37-33,390045c512d349feabb57aa39b055f79,Joshuas-MacBook-Pro-10.local,127.0.0.1,7771,1669347453,7c272_00000


2022-11-24 22:38:07,863	ERROR tune.py:773 -- Trials did not complete: [TorchTrainer_7c272_00000]
2022-11-24 22:38:07,864	INFO tune.py:778 -- Total run time: 39.53 seconds (39.29 seconds for the tuning loop).


RayTaskError: [36mray::_Inner.train()[39m (pid=7771, ip=127.0.0.1, repr=TorchTrainer)
  File "/Users/jpgard/Documents/github/tablebench/venv3.7/lib/python3.7/site-packages/ray/tune/trainable/trainable.py", line 355, in train
    raise skipped from exception_cause(skipped)
  File "/Users/jpgard/Documents/github/tablebench/venv3.7/lib/python3.7/site-packages/ray/tune/trainable/function_trainable.py", line 328, in entrypoint
    self._status_reporter.get_checkpoint(),
  File "/Users/jpgard/Documents/github/tablebench/venv3.7/lib/python3.7/site-packages/ray/train/base_trainer.py", line 475, in _trainable_func
    super()._trainable_func(self._merged_config, reporter, checkpoint_dir)
  File "/Users/jpgard/Documents/github/tablebench/venv3.7/lib/python3.7/site-packages/ray/tune/trainable/function_trainable.py", line 651, in _trainable_func
    output = fn()
  File "/Users/jpgard/Documents/github/tablebench/venv3.7/lib/python3.7/site-packages/ray/train/base_trainer.py", line 390, in train_func
    trainer.training_loop()
  File "/Users/jpgard/Documents/github/tablebench/venv3.7/lib/python3.7/site-packages/ray/train/data_parallel_trainer.py", line 368, in training_loop
    checkpoint_strategy=None,
  File "/Users/jpgard/Documents/github/tablebench/venv3.7/lib/python3.7/site-packages/ray/train/trainer.py", line 154, in __init__
    checkpoint_strategy=checkpoint_strategy,
  File "/Users/jpgard/Documents/github/tablebench/venv3.7/lib/python3.7/site-packages/ray/train/trainer.py", line 179, in _start_training
    lambda: self._backend_executor.start_training(
  File "/Users/jpgard/Documents/github/tablebench/venv3.7/lib/python3.7/site-packages/ray/train/trainer.py", line 188, in _run_with_error_handling
    return func()
  File "/Users/jpgard/Documents/github/tablebench/venv3.7/lib/python3.7/site-packages/ray/train/trainer.py", line 182, in <lambda>
    checkpoint=checkpoint,
  File "/Users/jpgard/Documents/github/tablebench/venv3.7/lib/python3.7/site-packages/ray/train/_internal/backend_executor.py", line 332, in start_training
    self.dataset_shards = dataset_spec.get_dataset_shards(actors)
  File "/Users/jpgard/Documents/github/tablebench/venv3.7/lib/python3.7/site-packages/ray/train/_internal/dataset_spec.py", line 211, in get_dataset_shards
    locality_hints=training_worker_handles,
  File "/Users/jpgard/Documents/github/tablebench/venv3.7/lib/python3.7/site-packages/ray/data/dataset.py", line 984, in split
    blocks = self._plan.execute()
  File "/Users/jpgard/Documents/github/tablebench/venv3.7/lib/python3.7/site-packages/ray/data/_internal/plan.py", line 309, in execute
    blocks, clear_input_blocks, self._run_by_consumer
  File "/Users/jpgard/Documents/github/tablebench/venv3.7/lib/python3.7/site-packages/ray/data/_internal/plan.py", line 672, in __call__
    fn_constructor_kwargs=self.fn_constructor_kwargs,
  File "/Users/jpgard/Documents/github/tablebench/venv3.7/lib/python3.7/site-packages/ray/data/_internal/compute.py", line 128, in _apply
    raise e from None
  File "/Users/jpgard/Documents/github/tablebench/venv3.7/lib/python3.7/site-packages/ray/data/_internal/compute.py", line 115, in _apply
    results = map_bar.fetch_until_complete(refs)
  File "/Users/jpgard/Documents/github/tablebench/venv3.7/lib/python3.7/site-packages/ray/data/_internal/progress_bar.py", line 75, in fetch_until_complete
    for ref, result in zip(done, ray.get(done)):
ray.exceptions.RayTaskError: [36mray::_map_block_nosplit()[39m (pid=7789, ip=127.0.0.1)
ray.exceptions.OutOfDiskError: Local disk is full
The object cannot be created because the local object store is full and the local disk's utilization is over capacity (95% by default).Tip: Use `df` on this node to check disk usage and `ray memory` to check object store memory usage.

[2m[33m(raylet)[0m [2022-11-24 22:38:17,232 E 7748 352103] (raylet) file_system_monitor.cc:105: /tmp/ray/session_2022-11-24_22-36-53_113079_7730 is over 95% full, available space: 11884355584; capacity: 250685575168. Object creation will fail if spilling is required.
[2m[33m(raylet)[0m [2022-11-24 22:38:27,302 E 7748 352103] (raylet) file_system_monitor.cc:105: /tmp/ray/session_2022-11-24_22-36-53_113079_7730 is over 95% full, available space: 11884359680; capacity: 250685575168. Object creation will fail if spilling is required.
[2m[33m(raylet)[0m [2022-11-24 22:38:37,370 E 7748 352103] (raylet) file_system_monitor.cc:105: /tmp/ray/session_2022-11-24_22-36-53_113079_7730 is over 95% full, available space: 11883663360; capacity: 250685575168. Object creation will fail if spilling is required.
[2m[33m(raylet)[0m [2022-11-24 22:38:47,427 E 7748 352103] (raylet) file_system_monitor.cc:105: /tmp/ray/session_2022-11-24_22-36-53_113079_7730 is over 95% full, available space: 1188