Description
What happened + What you expected to happen
To collect statistics, I'd like to run many TemporalFusionTransformer pipelines with different random seeds. Using Ray to parallelize these tasks leads to a situation where the remote dataloaders seem to impede each other, making training very slow. I only have access to CPUs.
Running the code snippet below (the first 80 lines are copied from here) with n_tasks = 1 shows training proceeding at more than 3 it/s:
Epoch 1: 9%|▉ | 15/161 [00:04<00:39, 3.71it/s, v_num=5.97e+7, train_loss_step=111.0, val_loss=158.0, train_loss_epoch=161.0]
But once n_tasks > 1, the training rate drops to <= 0.1 it/s. Here's what it looks like for n_tasks = 2:
Epoch 0: 1%| | 2/161 [01:57<2:35:59, 0.02it/s, v_num=5.97e+7, train_loss_step=394.0]
Epoch 0: 10%|▉ | 16/161 [01:59<18:00, 0.13it/s, v_num=5.97e+7, train_loss_step=184.0]
Completely removing Ray from the code and running it serially returns training to > 3 it/s.
Might the concurrent dataloaders all be competing with each other for resources on the driver process, instead of using their own worker processes? I've also tried using Actors instead of Tasks (a sketch of that variant follows the reproduction script), but the problem persists. The closest issue I've found so far is this one, which is still open. Adding multiprocessing_context="fork", as indicated here (with num_workers=1), has no effect.
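For context, here is a minimal, hypothetical sketch of the kind of per-task thread capping I suspect is relevant (the helper name is made up and I have not verified that it changes anything; torch.set_num_threads only limits PyTorch's intra-op CPU pool, and Ray may already cap OMP_NUM_THREADS for 1-CPU tasks):
import os

import ray
import torch

ray.init(num_cpus=2, include_dashboard=False)

@ray.remote(num_cpus=1)
def show_thread_config(seed: int) -> tuple:
    # Hypothetical diagnostic: cap PyTorch's intra-op threads for this task
    # and report what the worker process actually sees.
    torch.set_num_threads(1)
    return (seed, os.environ.get("OMP_NUM_THREADS"), torch.get_num_threads())

print(ray.get([show_thread_config.remote(s) for s in range(2)]))
ray.shutdown()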
Versions / Dependencies
ubuntu 20.04.6
slurm 23.02.5
lightning 2.1.2
python 3.10.13
pytorch-forecasting 1.0.0
pytorch-lightning 2.1.2
pytorch-optimizer 2.12.0
ray 2.9.0
Reproduction script
import lightning.pytorch as pl
import numpy as np
import ray
from pytorch_forecasting import (TemporalFusionTransformer,
                                 TimeSeriesDataSet)
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.data.examples import get_stallion_data
from pytorch_forecasting.metrics import QuantileLoss
data = get_stallion_data()
# add time index
data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month
data["time_idx"] -= data["time_idx"].min()
# add additional features
data["month"] = data.date.dt.month.astype(str).astype("category") # categories have be strings
data["log_volume"] = np.log(data.volume + 1e-8)
data["avg_volume_by_sku"] = data.groupby(["time_idx", "sku"], observed=True).volume.transform("mean")
data["avg_volume_by_agency"] = data.groupby(["time_idx", "agency"], observed=True).volume.transform("mean")
# we want to encode special days as one variable and thus need to first reverse one-hot encoding
special_days = [
    "easter_day",
    "good_friday",
    "new_year",
    "christmas",
    "labor_day",
    "independence_day",
    "revolution_day_memorial",
    "regional_games",
    "fifa_u_17_world_cup",
    "football_gold_cup",
    "beer_capital",
    "music_fest",
]
data[special_days] = data[special_days].apply(lambda x: x.map({0: "-", 1: x.name})).astype("category")
max_prediction_length = 6
max_encoder_length = 24
training_cutoff = data["time_idx"].max() - max_prediction_length
training = TimeSeriesDataSet(
    data[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="volume",
    group_ids=["agency", "sku"],
    min_encoder_length=max_encoder_length // 2,  # keep encoder length long (as it is in the validation set)
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=["agency", "sku"],
    static_reals=["avg_population_2017", "avg_yearly_household_income_2017"],
    time_varying_known_categoricals=["special_days", "month"],
    variable_groups={"special_days": special_days},  # group of categorical variables can be treated as one variable
    time_varying_known_reals=["time_idx", "price_regular", "discount_in_percent"],
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=[
        "volume",
        "log_volume",
        "industry_volume",
        "soda_volume",
        "avg_max_temp",
        "avg_volume_by_agency",
        "avg_volume_by_sku",
    ],
    target_normalizer=GroupNormalizer(
        groups=["agency", "sku"], transformation="softplus"
    ),  # use softplus and normalize by group
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)
# create validation set (predict=True) which means to predict
# the last max_prediction_length points in time for each series
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True,
                                            stop_randomization=True)
# ----- below here pertains to Ray ----- #
n_tasks = 1 # number of concurrent training tasks (>1 leads to slow training)
runtime_env = {"env_vars": {"NCCL_SOCKET_IFNAME": "lo,docker0"}}
ray.init(include_dashboard=False,
         num_cpus=n_tasks + 1,  # just some number >= n_tasks
         object_store_memory=50 * 1e9,  # large enough
         runtime_env=runtime_env)
@ray.remote(num_cpus=1)
def train_task(seed: int,
               training: TimeSeriesDataSet,
               validation: TimeSeriesDataSet) -> TemporalFusionTransformer:
    # create dataloaders for model
    batch_size = 128  # set this between 32 to 128
    # changing to num_workers=1 shows no effect.
    # num_workers=1, multiprocessing_context="fork" shows no effect.
    # neither does num_workers=1, multiprocessing_context="spawn", persistent_workers=True
    train_dataloader = training.to_dataloader(train=True,
                                              batch_size=batch_size,
                                              num_workers=0)
    val_dataloader = validation.to_dataloader(train=False,
                                              batch_size=batch_size * 10,
                                              num_workers=0)
    pl.seed_everything(seed)
    trainer = pl.Trainer(
        accelerator="cpu",
        gradient_clip_val=0.1,
        max_epochs=1000,  # just so that it runs for a while
    )
    tft = TemporalFusionTransformer.from_dataset(
        training,
        learning_rate=0.03,
        hidden_size=16,
        attention_head_size=2,
        dropout=0.1,
        hidden_continuous_size=8,
        loss=QuantileLoss(),
        optimizer="Ranger",
        reduce_on_plateau_patience=4)
    print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")
    trainer.fit(tft,
                train_dataloaders=train_dataloader,
                val_dataloaders=val_dataloader)
    return tft
tfts = []
for seed in range(n_tasks):
    tft = train_task.remote(seed, training, validation)
    tfts.append(tft)
tfts = ray.get(tfts)
ray.shutdown()
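For completeness, the Actor-based variant I mentioned above looks roughly like this (a sketch, not my exact code; it would replace the @ray.remote task and the loop at the end of the reproduction script, and it shows the same slowdown):
@ray.remote(num_cpus=1)
class TrainActor:
    # Same training logic as train_task above, wrapped in an Actor method.
    def train(self, seed: int,
              training: TimeSeriesDataSet,
              validation: TimeSeriesDataSet) -> TemporalFusionTransformer:
        batch_size = 128
        train_dataloader = training.to_dataloader(train=True,
                                                  batch_size=batch_size,
                                                  num_workers=0)
        val_dataloader = validation.to_dataloader(train=False,
                                                  batch_size=batch_size * 10,
                                                  num_workers=0)
        pl.seed_everything(seed)
        trainer = pl.Trainer(accelerator="cpu",
                             gradient_clip_val=0.1,
                             max_epochs=1000)
        tft = TemporalFusionTransformer.from_dataset(
            training,
            learning_rate=0.03,
            hidden_size=16,
            attention_head_size=2,
            dropout=0.1,
            hidden_continuous_size=8,
            loss=QuantileLoss(),
            optimizer="Ranger",
            reduce_on_plateau_patience=4)
        trainer.fit(tft,
                    train_dataloaders=train_dataloader,
                    val_dataloaders=val_dataloader)
        return tft

actors = [TrainActor.remote() for _ in range(n_tasks)]
futures = [actor.train.remote(seed, training, validation)
           for seed, actor in enumerate(actors)]
tfts = ray.get(futures)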
Issue Severity
Medium: It is a significant difficulty but I can work around it.