Use of Lightning Unified Package Not Currently Supported #8

Closed
funnym0nk3y opened this issue Jan 21, 2023 · 3 comments

@funnym0nk3y

🐛 Bug

When running the default schedule creation instructions from the docs, I get a ValueError.

To Reproduce

See last lines of BoringModel example below.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Initially based on https://bit.ly/3oQ8Vqf
import re
from functools import partial
from typing import List, Optional
from warnings import WarningMessage

import torch
from lightning import LightningDataModule, LightningModule
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset


def multiwarn_check(
    rec_warns: List, expected_warns: List, expected_mode: bool = False
) -> List[Optional[WarningMessage]]:
    msg_search = lambda w1, w2: re.compile(w1).search(w2.message.args[0])
    if expected_mode:  # we're directed to check that multiple expected warns are obtained
        return [w_msg for w_msg in expected_warns if not any([msg_search(w_msg, w) for w in rec_warns])]
    else:  # by default we're checking that no unexpected warns are obtained
        return [w_msg for w_msg in rec_warns if not any([msg_search(w, w_msg) for w in expected_warns])]


unexpected_warns = partial(multiwarn_check, expected_mode=False)


unmatched_warns = partial(multiwarn_check, expected_mode=True)


class LinearWarmupLR(LambdaLR):
    def __init__(self, optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
        def lr_lambda(current_step: int):
            if current_step < num_warmup_steps:
                return float(current_step) / float(max(1, num_warmup_steps))
            return max(
                0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
            )

        super().__init__(optimizer, lr_lambda, last_epoch)


class CustomLRScheduler:
    def __init__(self, optimizer):
        self.optimizer = optimizer

    def step(self, epoch):
        ...

    def state_dict(self):
        ...

    def load_state_dict(self, state_dict):
        ...


class RandomDictDataset(Dataset):
    def __init__(self, size: int, length: int):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        a = self.data[index]
        b = a + 2
        return {"a": a, "b": b}

    def __len__(self):
        return self.len


class RandomDataset(Dataset):
    def __init__(self, size: int, length: int):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class RandomIterableDataset(IterableDataset):
    def __init__(self, size: int, count: int):
        self.count = count
        self.size = size

    def __iter__(self):
        for _ in range(self.count):
            yield torch.randn(self.size)


class RandomIterableDatasetWithLen(IterableDataset):
    def __init__(self, size: int, count: int):
        self.count = count
        self.size = size

    def __iter__(self):
        for _ in range(len(self)):
            yield torch.randn(self.size)

    def __len__(self):
        return self.count


class BoringModel(LightningModule):
    def __init__(self):
        """Testing PL Module.

        Use as follows:
        - subclass
        - modify the behavior for what you want

        class TestModel(BaseTestModel):
            def training_step(...):
                # do your own thing

        or:

        model = BaseTestModel()
        model.training_epoch_end = None
        """
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def step(self, x):
        x = self(x)
        out = torch.nn.functional.mse_loss(x, torch.ones_like(x))
        return out

    def training_step(self, batch, batch_idx):
        output = self(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def training_step_end(self, training_step_outputs):
        return training_step_outputs

    def training_epoch_end(self, outputs) -> None:
        torch.stack([x["loss"] for x in outputs]).mean()

    def validation_step(self, batch, batch_idx):
        output = self(batch)
        loss = self.loss(batch, output)
        return {"x": loss}

    def validation_epoch_end(self, outputs) -> None:
        torch.stack([x["x"] for x in outputs]).mean()

    def test_step(self, batch, batch_idx):
        output = self(batch)
        loss = self.loss(batch, output)
        return {"y": loss}

    def test_epoch_end(self, outputs) -> None:
        torch.stack([x["y"] for x in outputs]).mean()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64))

    def val_dataloader(self):
        return DataLoader(RandomDataset(32, 64))

    def test_dataloader(self):
        return DataLoader(RandomDataset(32, 64))

    def predict_dataloader(self):
        return DataLoader(RandomDataset(32, 64))


class BoringDataModule(LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.non_picklable = None
        self.checkpoint_state: Optional[str] = None
        self.random_full = RandomDataset(32, 64 * 4)

    def setup(self, stage: Optional[str] = None):
        if stage == "fit" or stage is None:
            self.random_train = Subset(self.random_full, indices=range(64))

        if stage in ("fit", "validate") or stage is None:
            self.random_val = Subset(self.random_full, indices=range(64, 64 * 2))

        if stage == "test" or stage is None:
            self.random_test = Subset(self.random_full, indices=range(64 * 2, 64 * 3))

        if stage == "predict" or stage is None:
            self.random_predict = Subset(self.random_full, indices=range(64 * 3, 64 * 4))

    def train_dataloader(self):
        return DataLoader(self.random_train)

    def val_dataloader(self):
        return DataLoader(self.random_val)

    def test_dataloader(self):
        return DataLoader(self.random_test)

    def predict_dataloader(self):
        return DataLoader(self.random_predict)


class ManualOptimBoringModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        output = self(batch)
        loss = self.loss(batch, output)
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        return loss


from finetuning_scheduler import FinetuningScheduler
from lightning import Trainer

trainer = Trainer(callbacks=[FinetuningScheduler(gen_ft_sched_only=True)])
print(trainer.log_dir)
model = BoringModel()
data = BoringDataModule()
trainer.fit(model,data)

Error

Traceback (most recent call last):
  File "/home/michael/work/michael/ml1-bonus/test.py", line 248, in <module>
    trainer.fit(model,data)
  File "/home/michael/work/michael/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 603, in fit
    call._call_and_handle_interrupt(
  File "/home/michael/work/michael/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 38, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/michael/work/michael/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 645, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/home/michael/work/michael/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 1024, in _run
    verify_loop_configurations(self)
  File "/home/michael/work/michael/lib/python3.8/site-packages/lightning/pytorch/trainer/configuration_validator.py", line 53, in verify_loop_configurations
    _check_deprecated_callback_hooks(trainer)
  File "/home/michael/work/michael/lib/python3.8/site-packages/lightning/pytorch/trainer/configuration_validator.py", line 238, in _check_deprecated_callback_hooks
    if is_overridden(method_name="on_load_checkpoint", instance=callback) and has_legacy_argument:
  File "/home/michael/work/michael/lib/python3.8/site-packages/lightning/pytorch/utilities/model_helpers.py", line 34, in is_overridden
    raise ValueError("Expected a parent")
ValueError: Expected a parent

Environment

  • CUDA:
    - GPU:
    - NVIDIA GeForce RTX 3070
    - NVIDIA GeForce GTX 960
    - available: True
    - version: 11.7
  • Packages:
    - finetuning-scheduler: 0.3.3
    - numpy: 1.24.1
    - pyTorch_debug: False
    - pyTorch_version: 1.13.1+cu117
    - pytorch-lightning: 1.8.4
    - tqdm: 4.64.1
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.8.10
    - version: #64~20.04.1-Ubuntu SMP Fri Jan 6 16:42:31 UTC 2023

Additional context

@funnym0nk3y funnym0nk3y added the bug Something isn't working label Jan 21, 2023
@speediedan
Owner

speediedan commented Jan 25, 2023

Thanks for the valuable feedback @funnym0nk3y!

This issue is a function of importing the unified package (lightning.pytorch) instead of the standalone package (pytorch_lightning). The former is just a mirror of the latter but isinstance doesn't know that:

https://github.com/Lightning-AI/lightning/blob/7eb5ff55580c9a2036002a471fff17a8cc471d9c/src/pytorch_lightning/utilities/model_helpers.py#L31
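
To make the failure mode concrete, here is a minimal sketch of the parent-inference pattern that produces the ValueError. It is an assumption-laden illustration, not the source of the linked helper: the three classes are stand-ins rather than real Lightning classes, and `is_overridden_sketch` is a hypothetical, simplified approximation that checks a single known parent.

```python
from typing import Optional, Type


class StandaloneCallback:
    """Stand-in for the Callback base class of the standalone package."""

    def on_load_checkpoint(self, *args, **kwargs):
        pass


class UnifiedCallback:
    """Stand-in for the mirrored Callback base class of the unified package."""

    def on_load_checkpoint(self, *args, **kwargs):
        pass


class SchedulerCallback(StandaloneCallback):
    """Stand-in for a third-party callback built on the standalone base."""

    def on_load_checkpoint(self, *args, **kwargs):
        pass


def is_overridden_sketch(method_name: str, instance: object, known_parent: Type) -> bool:
    """Infer the parent via isinstance, then compare the method objects."""
    parent: Optional[Type] = None
    if isinstance(instance, known_parent):
        parent = known_parent
    if parent is None:
        # The instance matches no base class this namespace knows about.
        raise ValueError("Expected a parent")
    return getattr(type(instance), method_name) is not getattr(parent, method_name)


callback = SchedulerCallback()
# Same namespace: the parent is found and the override is detected.
print(is_overridden_sketch("on_load_checkpoint", callback, StandaloneCallback))  # True
# Mirrored namespace: identical-looking base class, but isinstance fails.
try:
    is_overridden_sketch("on_load_checkpoint", callback, UnifiedCallback)
except ValueError as err:
    print(err)  # Expected a parent
```

The two stand-in base classes have identical bodies, yet Python treats them as unrelated types, which is the situation described above for the mirrored packages.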

The current version of Fine-Tuning Scheduler officially supports only the standalone pytorch-lightning package, but there's no reason the unified package (lightning.pytorch) can't be supported. The plan is to add official support in either a 0.4.x patch version or, at the latest, with the next non-patch FTS release (which would be paired with the first Lightning 2.x release).

In the meantime, replacing your imports of lightning* with pytorch_lightning should allow you to work around the limitation.

import pytorch_lightning as pl
isinstance(instance, pl.Callback) # True, schedule generation succeeds with the above import
import lightning.pytorch as pl
isinstance(instance, pl.Callback) # False, schedule generation will fail if using the lightning.* unified package mirror 
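
If it is unclear which of the two packages a given callback class is built on, one option is to inspect the modules in its method resolution order. The helper below is a hypothetical sketch (`lightning_roots` is not part of FTS or Lightning), and the classes are stand-ins created with `type()` so the example runs without either package installed; `my_plugin` is an invented module name.

```python
from typing import List


def lightning_roots(cls: type) -> List[str]:
    """Return the distinct top-level packages in the MRO of cls, in MRO order."""
    roots: List[str] = []
    for base in cls.__mro__:
        root = base.__module__.split(".")[0]
        if root != "builtins" and root not in roots:
            roots.append(root)
    return roots


# Stand-in classes that only mimic the module paths of the two packages.
StandaloneCallback = type(
    "Callback", (), {"__module__": "pytorch_lightning.callbacks.callback"}
)
UnifiedCallback = type(
    "Callback", (), {"__module__": "lightning.pytorch.callbacks.callback"}
)
# Hypothetical third-party callback built on the standalone stand-in.
PluginCallback = type(
    "PluginCallback", (StandaloneCallback,), {"__module__": "my_plugin.callbacks"}
)

print(lightning_roots(PluginCallback))  # ['my_plugin', 'pytorch_lightning']
print(isinstance(PluginCallback(), StandaloneCallback))  # True
print(isinstance(PluginCallback(), UnifiedCallback))  # False
```

With a real callback, passing its class to such a helper would show whether `pytorch_lightning` or `lightning` appears among its bases, which tells you which import the rest of the script needs to match.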

I'll keep this issue open until support for the unified package is officially added by FTS.

By the way, FTS 0.4.0 has just been released. Support for PyTorch 2.0 is planned as soon as it is stable, so it should arrive with the next non-patch release of FTS.

Thanks again for the valuable feedback!

@speediedan speediedan added enhancement New feature or request design includes a design discussion labels Jan 25, 2023
@speediedan speediedan self-assigned this Jan 25, 2023
@speediedan speediedan changed the title Default schedule creation does not work Use of Lightning Unified Package Not Currently Supported Jan 25, 2023
@speediedan speediedan removed the bug Something isn't working label Jan 25, 2023
@speediedan
Owner

Resolved with the release of finetuning-scheduler 2.0.0. Beginning with FTS 2.0, Fine-Tuning Scheduler (FTS) depends by default on the lightning package rather than the standalone pytorch-lightning package (though the latter can still be installed and used, similar to Lightning). Thanks for your feedback/contribution!

@jnyjxn

jnyjxn commented Jun 14, 2023

I am trying to integrate finetuning-scheduler with my code, which is built around pytorch_lightning, and I am hitting import errors that seem to be essentially the mirror image of this ticket.

import finetuning_scheduler

import pytorch_lightning as pl
issubclass(finetuning_scheduler.FinetuningScheduler, pl.Callback)  # False
issubclass(finetuning_scheduler.FTSCheckpoint, pl.Callback)  # False
issubclass(finetuning_scheduler.FTSEarlyStopping, pl.Callback)  # False

import lightning.pytorch as pl
issubclass(finetuning_scheduler.FinetuningScheduler, pl.Callback)  # True
issubclass(finetuning_scheduler.FTSCheckpoint, pl.Callback)  # True
issubclass(finetuning_scheduler.FTSEarlyStopping, pl.Callback)  # True

As a result, when I try to specify finetuning callbacks in the config, I get the following type of error:

main.py: error: Parser key "trainer.callbacks":
  Does not validate against any of the Union subtypes
  Subtypes: (typing.List[pytorch_lightning.callbacks.callback.Callback], <class 'pytorch_lightning.callbacks.callback.Callback'>, <class 'NoneType'>)
  Errors:
    - Expected a <class 'list'>
    - Import path finetuning_scheduler.FinetuningScheduler does not correspond to a subclass of <class 'pytorch_lightning.callbacks.callback.Callback'>
    - Expected a <class 'NoneType'>
  Given value type: <class 'dict'>
  Given value: {'class_path': 'finetuning_scheduler.FinetuningScheduler'}
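
The "does not correspond to a subclass" message above comes down to resolving the `class_path` string to a class and testing it against the expected base. The following is a rough sketch of that kind of check, not the parser's actual implementation; `resolve_class_path` and `check_class_path` are hypothetical names, and standard-library classes stand in for the callback types so the example runs anywhere.

```python
import importlib


def resolve_class_path(class_path: str) -> type:
    """Import the module part of a dotted path and return the named class."""
    module_name, _, class_name = class_path.rpartition(".")
    if not module_name:
        raise ValueError(f"{class_path!r} is not a dotted import path")
    obj = getattr(importlib.import_module(module_name), class_name)
    if not isinstance(obj, type):
        raise TypeError(f"{class_path!r} does not refer to a class")
    return obj


def check_class_path(class_path: str, expected_base: type) -> bool:
    """Return True if the class behind class_path subclasses expected_base."""
    return issubclass(resolve_class_path(class_path), expected_base)


# Standard-library stand-ins: OrderedDict subclasses dict, Fraction does not.
print(check_class_path("collections.OrderedDict", dict))  # True
print(check_class_path("fractions.Fraction", dict))  # False
```

Under this reading, the config entry fails because the resolved class subclasses the Callback of one package while the parser expects the Callback of the other.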

Relevant parts of pip freeze:

finetuning-scheduler==2.0.2
jsonargparse==4.21.2
jsonschema==3.2.0
lightning==2.0.0
lightning-cloud==0.5.32
lightning-utilities==0.8.0
omegaconf==2.3.0
pytorch-lightning==2.0.3
torch==2.0.0
torchmetrics==0.11.4
torchvision==0.15.1
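
Because both lightning and pytorch-lightning appear in the list above, it can help to confirm programmatically which distributions an environment actually contains. This is a small sketch using the standard library's `importlib.metadata`; `installed_versions` is a hypothetical helper name, and the package names are the ones from the freeze output.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


report = installed_versions(["lightning", "pytorch-lightning", "finetuning-scheduler"])
for name, version in report.items():
    print(f"{name}=={version}" if version else f"{name}: not installed")
```

If both Lightning distributions are reported as installed, the callbacks and the Trainer still need to come from the same one for the isinstance and issubclass checks discussed in this thread to pass.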
