[Train] LightningTrainer converts relative checkpoint dirpath to absolute path #35894

Closed
woshiyyya opened this issue May 30, 2023 · 1 comment
Assignees: woshiyyya
Labels: bug, P1, ray-team-created, train

@woshiyyya
Member

What happened + What you expected to happen

During restoration, PyTorch Lightning expects all workers to share the same directory structure. However, when a relative path is specified for dirpath, PyTorch Lightning creates a new folder under each worker's current working directory (its rank_x folder), e.g. .../LightningTrainer_7282d_00000_0_2023-05-26_01-43-54/rank_x/{dirpath}.

As a result, the internal state of the ModelCheckpoint callback is not properly restored, which leads to inconsistent NCCL operations and eventually to timeouts. We need a proper way for LightningTrainer to convert relative checkpoint dirpaths to absolute paths to eliminate this issue.
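
Below is a minimal sketch of the idea, assuming the path is resolved once on the driver before the configuration reaches the per-worker ModelCheckpoint callbacks; the helper name is hypothetical and is not the actual LightningTrainer implementation:

import os

def resolve_checkpoint_dirpath(dirpath):
    # Hypothetical helper: anchor a relative dirpath to a single absolute
    # location so that every worker writes checkpoints to the same directory
    # instead of <trial_dir>/rank_x/{dirpath}.
    if dirpath is not None and not os.path.isabs(dirpath):
        return os.path.abspath(os.path.expanduser(dirpath))
    return dirpath

# Example: a relative "chkpts/run1" resolves to an absolute path under the
# driver's working directory, so all ranks see the same checkpoint directory
# and the ModelCheckpoint state can be restored consistently.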

Versions / Dependencies

master

Reproduction script

import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from filelock import FileLock
from torch.utils.data import DataLoader, random_split, Subset
from torchmetrics import Accuracy
from torchvision.datasets import MNIST
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning.loggers.csv_logs import CSVLogger

class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=100):
        super().__init__()
        self.data_dir = os.getcwd()
        self.batch_size = batch_size
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def setup(self, stage=None):
        with FileLock(f"{self.data_dir}.lock"):
            mnist = MNIST(
                self.data_dir, train=True, download=True, transform=self.transform
            )

            
            self.mnist_train, self.mnist_val = random_split(mnist, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=4)

    def test_dataloader(self):
        with FileLock(f"{self.data_dir}.lock"):
            self.mnist_test = MNIST(
                self.data_dir, train=False, download=True, transform=self.transform
            )
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=4)


datamodule = MNISTDataModule(batch_size=128)

class MNISTClassifier(pl.LightningModule):
    def __init__(self, lr=1e-3, feature_dim=128):
        super().__init__()
        torch.manual_seed(421)
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, 10),
            nn.ReLU(),
        )
        self.lr = lr
        self.accuracy = Accuracy(task="multiclass", num_classes=10)
        self.eval_loss = []
        self.eval_accuracy = []

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.linear_relu_stack(x)
        return x

    def training_step(self, batch, batch_idx):
        # Simulate a one-time failure on rank 0 at epoch 3 to trigger restoration.
        if self.current_epoch == 3 and self.global_rank == 0:
            if not os.path.exists("/mnt/cluster_storage/error_flag"):
                os.system("touch /mnt/cluster_storage/error_flag")
                raise RuntimeError("Mock Error")

        time.sleep(0.01)
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        loss, acc = self._shared_eval(val_batch)
        self.log("val_accuracy", acc)
        self.eval_loss.append(loss)
        self.eval_accuracy.append(acc)
        return {"val_loss": loss, "val_accuracy": acc}

    def test_step(self, test_batch, batch_idx):
        loss, acc = self._shared_eval(test_batch)
        self.log("test_accuracy", acc)
        return {"test_loss": loss, "test_accuracy": acc}

    def _shared_eval(self, batch):
        x, y = batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        return loss, acc

    def on_validation_epoch_end(self):
        avg_loss = torch.stack(self.eval_loss).mean()
        avg_acc = torch.stack(self.eval_accuracy).mean()
        self.log("val_loss", avg_loss, sync_dist=True)
        self.log("val_accuracy", avg_acc, sync_dist=True)
        self.eval_loss.clear()
        self.eval_accuracy.clear()
    
        with open(f"/mnt/cluster_storage/ckpt_callback_{self.global_rank}.txt", 'w') as f:
            f.write(f"{self.trainer.checkpoint_callback.__dict__}")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

from pytorch_lightning.callbacks import ModelCheckpoint
from ray.air.config import FailureConfig, RunConfig, ScalingConfig, CheckpointConfig
from ray.train.lightning import (
    LightningTrainer,
    LightningConfigBuilder,
)

def build_lightning_config_from_existing_code(use_gpu):
    config_builder = LightningConfigBuilder()
    config_builder.module(cls=MNISTClassifier, lr=1e-3, feature_dim=128)
    config_builder.checkpointing(
        # Specifying a relative dirpath (like the commented line below) triggers
        # the bug; an absolute dirpath behaves as expected.
        # dirpath=f'chkpts/{datetime.datetime.now().strftime("%d-%m-%y_%H:%M:%S")}/',
        dirpath="/tmp/my_ckpt_dir",
        monitor="val_accuracy",
        mode="max",
        save_top_k=3,
    )
    config_builder.trainer(
        max_epochs=10000,
        accelerator="gpu" if use_gpu else "cpu",
        log_every_n_steps=100,
        logger=CSVLogger("logs"),
    )
    config_builder.fit_params(datamodule=datamodule)
    lightning_config = config_builder.build()
    return lightning_config


use_gpu = True
lightning_config = build_lightning_config_from_existing_code(use_gpu=use_gpu)
scaling_config = ScalingConfig(num_workers=2, use_gpu=use_gpu)

run_config = RunConfig(
    name="test-sync-artifacts",
    storage_path="/tmp/ray_results",
    checkpoint_config=CheckpointConfig(
        num_to_keep=3,
        checkpoint_score_attribute="val_accuracy",
        checkpoint_score_order="max",
    ),
    failure_config=FailureConfig(max_failures=-1),
)

trainer = LightningTrainer(
    lightning_config=lightning_config,
    scaling_config=scaling_config,
    run_config=run_config,
)

result = trainer.fit()

Issue Severity

Medium: It is a significant difficulty but I can work around it.

@woshiyyya added the bug, P1, train, ray-team-created labels May 30, 2023
@woshiyyya self-assigned this May 30, 2023
@woshiyyya
Member Author

Closed by PR #36165
