
[tune] Difference in performance and best model between analysis.best_config and the best model in analysis.dataframe #17923

Closed
2 tasks
marcovaresi opened this issue Aug 18, 2021 · 3 comments · Fixed by #18850
Assignees
Labels
bug Something that is supposed to be working; but isn't P1 Issue that should be fixed within a few weeks tune Tune-related issues

Comments

@marcovaresi

What is the problem?

pytorch_lightning: '1.4.2'
ray: '1.5.2'
tensorboardX: '2.4'

callbacks = TuneReportCheckpointCallback
scheduler = ASHAScheduler
search_alg = HyperOptSearch
reporter = CLIReporter

At the end of tune.run (searching for the best performance on the validation loss) I obtain the best model (best config) and the best checkpoint.
When I move the results into a DataFrame using analysis.dataframe and search for the minimum value of the validation loss, I frequently obtain a different configuration than the one of the best model.
The PyTorch Lightning model runs with early stopping on the validation loss.
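
In short, the comparison I am doing looks roughly like this (just a sketch; the metric name is the one used in the full reproduction below):

best_config = analysis.best_config  # config of the trial that tune.run considers best
df = analysis.dataframe(metric="LOSS_VALIDATION", mode="min")
best_row = df[df.LOSS_VALIDATION == df.LOSS_VALIDATION.min()]
# the config/* columns of best_row frequently do not match best_config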

Reproduction (REQUIRED)

Please provide a short code snippet (less than 50 lines if possible) that can be copy-pasted to reproduce the issue. The snippet should have no external library dependencies (i.e., use fake or mock data / environments):

If the code snippet cannot be run by itself, the issue will be closed with "needs-repro-script".

  • I have verified my script runs in a clean environment and reproduces the issue.
  • I have verified the issue also occurs with the latest wheels.
@marcovaresi marcovaresi added bug Something that is supposed to be working; but isn't triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Aug 18, 2021
@amogkam
Contributor

amogkam commented Aug 18, 2021

Hey @marcovaresi thanks for raising this!

Would you be able to provide more info here? What is analysis.best_config giving you vs. what does the full analysis.dataframe look like? What is the default metric and mode you pass into tune.run?

If you could provide a small reproducible example showing this bug, that would be great.

@amogkam amogkam added the needs-repro-script Issue needs a runnable script to be reproduced label Aug 18, 2021
@marcovaresi
Author

I made this little example using a tutorial model from PyTorch Lightning:

import ray
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback
from ray.tune.suggest.hyperopt import HyperOptSearch
from pytorch_lightning.callbacks import EarlyStopping

import sys   #ray 'ConsoleBuffer' object has no attribute 'fileno'     
sys.stdout.fileno = lambda: False   #https://forums.databricks.com/questions/45772/is-there-step-by-step-guide-how-to-setup-ray-clust.html

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split
from typing import *

import pytorch_lightning as pl
from pl_examples import _DATASETS_PATH, cli_lightning_logo
from pl_examples.basic_examples.mnist_datamodule import MNIST
from pytorch_lightning.utilities.cli import LightningCLI
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
    from torchvision import transforms


class LitAutoEncoder(pl.LightningModule):
      """
      >>> LitAutoEncoder()  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
      LitAutoEncoder(
        (encoder): ...
        (decoder): ...
      )
      """

    def __init__(self, conf: Dict):
        super().__init__()
        self.conf = conf
        self.encoder = nn.Sequential(nn.Linear(28 * 28, self.conf["hidden_dim"]), nn.ReLU(), nn.Linear(self.conf["hidden_dim"], 3))
        self.decoder = nn.Sequential(nn.Linear(3, self.conf["hidden_dim"]), nn.ReLU(), nn.Linear(self.conf["hidden_dim"], 28 * 28))

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("valid_loss", loss, on_step=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("test_loss", loss, on_step=True)

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        return self.decoder(z)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

class MyDataModule(pl.LightningDataModule):

    def __init__(self, config: Dict):
        super().__init__()
        dataset = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor())
        self.mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())
        self.mnist_train, self.mnist_val = random_split(dataset, [55000, 5000])
        self.config = config

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.config["batch_size"])

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.config["batch_size"])

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.config["batch_size"])

    def predict_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.config["batch_size"])



def train_tune(config: dict, num_epochs: int, patience: int, num_gpus: int) -> None:
  model = LitAutoEncoder(config)
  datamodule = MyDataModule(config)
  early_stopping = EarlyStopping(
    monitor='valid_loss',
    patience=patience,
    mode='min',
    strict=True,
  )
  
  trainer = pl.Trainer(
    gpus=num_gpus,
    progress_bar_refresh_rate=0,
    callbacks=[TuneReportCheckpointCallback(
      metrics={
        "LOSS_VALIDATION": "valid_loss",
      },
      on="validation_end"
    ), early_stopping],
  )
  torch.cuda.empty_cache()
  trainer.fit(model, datamodule=datamodule)

def tune_asha(num_samples: int = 10, num_epochs: int = 200, patience: int = 20, grace_period: int = 50, cpus_per_trial: int = 6, gpus_per_trial : int = 1) -> pl.LightningModule:
  config = {
    "hidden_dim": tune.choice([32, 64, 128]),
    "batch_size": tune.choice([32, 64, 128]), 
  }
  
  
  scheduler = ASHAScheduler(
    max_t=num_epochs,
    grace_period=grace_period,
    reduction_factor=2,
  )
  
  algo = HyperOptSearch(
        space=config, 
        metric="LOSS_VALIDATION",
        mode="min",
        n_initial_points = 10,
        random_state_seed=12345 
    )
  
  
  reporter = CLIReporter(
    parameter_columns=[
      "hidden_dim",
      "batch_size",
    ],
    metric_columns=[
      "LOSS_VALIDATION",
      "training_iteration"
    ]
  )
  analysis = tune.run(
    tune.with_parameters(
      train_tune,
      num_epochs = num_epochs,
      num_gpus = gpus_per_trial,
      patience = patience,
    ),
    resources_per_trial={
      "cpu": cpus_per_trial,
      "gpu": gpus_per_trial,
    },
    search_alg=algo,
    metric="LOSS_VALIDATION",
    mode="min",
    num_samples=num_samples,
    scheduler=scheduler,
    progress_reporter=reporter,
    name="logs",
  )
  return analysis

###run###

analysis = tune_asha(num_samples=4, num_epochs=100, patience = 20, grace_period=20, gpus_per_trial=1)

checkpoint = torch.load(analysis.best_checkpoint + "checkpoint")

print("Best hyperparameters ", analysis.best_config)
print("Best checkpoint ", analysis.best_checkpoint)
model = LitAutoEncoder.load_from_checkpoint(analysis.best_checkpoint + "checkpoint", conf = analysis.best_config)
datamodule = MyDataModule(analysis.best_config)

trainer = pl.Trainer(gpus=1)
trainer.test(model, datamodule=datamodule)

Here the best model has its checkpoint at epoch 70 (my run is probably not reproducible).

trainer.test(model, dataloaders=datamodule.val_dataloader()) # prints the validation loss at the best epoch

df = analysis.dataframe(metric="LOSS_VALIDATION", mode="min")
display(df[df.LOSS_VALIDATION == df.LOSS_VALIDATION.min()]) 

Here the best model is again the same one, but the dataframe row is at epoch 71, so the validation loss value is different; in this case it is bigger. During the run of my more complex model, however, the best model is not the same, and the minimum validation loss in the dataframe is smaller than that of the best model from the analysis.
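
One way to see where the two numbers come from (just a sketch, assuming the ExperimentAnalysis API of this Ray version) is to look at the full per-trial history instead of a single aggregated row:

history = analysis.trial_dataframes[analysis.best_logdir]  # every result reported by the best trial
print(history[["training_iteration", "LOSS_VALIDATION"]])
# the last row is what the analysis reports for the best trial, while the row with the
# minimum LOSS_VALIDATION is the one that dataframe(mode="min") picks up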

@richardliaw richardliaw added P1 Issue that should be fixed within a few weeks tune Tune-related issues and removed needs-repro-script Issue needs a runnable script to be reproduced triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Sep 7, 2021
@krfricke
Contributor

krfricke commented Sep 23, 2021

I think this is due to the somewhat confusing semantics of the analysis.dataframe() method.

analysis.dataframe() fetches the results from each trial. But because there are multiple results for each trial, we have to tell it which result to get.

If you call analysis.dataframe(mode="min"), it will fetch the smallest ever-observed result from each trial and report it. This is not necessarily the last result, and that is where the divergence happens.

Here is an example that illustrates the problem you're running into:

from ray import tune


def train(config):
    if config["var"] == 1:
        tune.report(loss=9)
        tune.report(loss=7)
        # This is the smallest last loss, but not the smallest ever loss
        tune.report(loss=5)
    else:
        tune.report(loss=10)
        # This is the smallest ever loss, but it's not the last!
        tune.report(loss=4)
        tune.report(loss=10)


analysis = tune.run(
    train,
    config={
        "var": tune.grid_search([1, 2])
    },
    metric="loss",
    mode="min")

print("Actual", analysis.best_trial, analysis.best_result["loss"])
# Will print: Actual train_da3d2_00000 5

df = analysis.dataframe(metric="loss", mode="min")
row = df[df.loss == df.loss.min()]
print("Wrong", row.trial_id.values[0], row.loss.values[0])
# Will print: Wrong da3d2_00001 4

There are two problems here. First, the dataframe() documentation should be very explicit about what it does - and possibly not re-use the names metric and mode here. Second, fetching the last result should work when you pass mode=None here - but there is a small bug that currently prevents that. I'll file a PR to fix this.

In the meantime, you can fix this behavior like this (for the example above):

print("Actual", analysis.best_trial, analysis.best_result["loss"])
# Will print: Actual train_da3d2_00000 5

analysis.default_mode = None  # This is a workaround for the bug
df = analysis.dataframe(metric="loss", mode=None)  # Pass None here!
row = df[df.loss == df.loss.min()]
print("Fixed", row.trial_id.values[0], row.loss.values[0])
# Will print: Fixed da3d2_00000 5

We'll revise the experiment analysis experience in the near future to hopefully prevent these kinds of problems.
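
As a side note, the same last-vs-ever distinction also exists at the trial level: get_best_trial takes a scope argument (the snippet below is illustrative, using the example above):

# scope="last" (the default) ranks trials by their last reported result,
# while scope="all" ranks them by their best result ever reported
best_last = analysis.get_best_trial(metric="loss", mode="min", scope="last")
best_ever = analysis.get_best_trial(metric="loss", mode="min", scope="all")
print(best_last, best_ever)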
