
Overfit batches parameter gives a validation batch #15021

@HekpoMaH


### Bug description

When overfitting on a single batch (`overfit_batches=1`) with the dataloaders defined inside the `LightningModule`, the batch passed to `validation_step` is different from the batch passed to `training_step`. I was told in the Slack community that this is NOT the intended behaviour.

### How to reproduce the bug

```python
import pytorch_lightning as pl
import torch
import torch_geometric

# Ten training graphs with x = 0..9 and ten validation graphs with x = 10..19,
# so the printed values show which split each batch came from.
dataset = [torch_geometric.data.Data(x=torch.tensor([i])) for i in range(10)]
val_dataset = [torch_geometric.data.Data(x=torch.tensor([j])) for j in range(10, 20)]


class LitModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # A single dummy parameter; the model itself is irrelevant to the bug.
        self.param = torch.nn.Parameter(torch.tensor([0.]))

    def train_dataloader(self):
        return torch_geometric.loader.DataLoader(dataset, batch_size=2)

    def val_dataloader(self):
        return torch_geometric.loader.DataLoader(val_dataset, batch_size=2)

    def training_step(self, batch, batch_idx):
        print('train', batch.x)  # samples that reach the training step
        return torch.nn.functional.mse_loss(self.param, torch.tensor([1.]).to(self.param))

    def validation_step(self, batch, batch_idx):
        print('val', batch.x)  # samples that reach the validation step
        return torch.nn.functional.mse_loss(self.param, torch.tensor([1.]).to(self.param))

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=.0001)
        return optimizer


litmod = LitModule()
trainer = pl.Trainer(
    overfit_batches=1,           # overfit on a single batch
    accelerator='cuda',
    max_epochs=20,
    check_val_every_n_epoch=10,  # validate every 10th epoch
)
trainer.fit(litmod)
print(litmod)
```

The validation batch is the `[10, 11]` tensor, while the training batch is the `[0, 1]` tensor.
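The batch selection reported above can be modeled in plain Python. This is a minimal sketch, assuming (as the printed output suggests) that `overfit_batches=1` keeps the first batch of each dataloader independently and that neither loader shuffles. `make_batches` and `first_n_batches` are hypothetical helpers, not Lightning APIs, and the integer lists stand in for the `torch_geometric` datasets.

```python
from itertools import islice


def make_batches(items, batch_size):
    """Split items into consecutive batches, in order (like a DataLoader with shuffle=False)."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


def first_n_batches(batches, n):
    """Keep only the first n batches, the truncation overfit_batches=n appears to apply."""
    return list(islice(batches, n))


train_items = list(range(10))     # stands in for `dataset` (x = 0..9)
val_items = list(range(10, 20))   # stands in for `val_dataset` (x = 10..19)

# Each loader is truncated on its own, so the two splits never share a batch.
train_batches = first_n_batches(make_batches(train_items, 2), 1)
val_batches = first_n_batches(make_batches(val_items, 2), 1)

print("train", train_batches)  # train [[0, 1]]
print("val", val_batches)      # val [[10, 11]]
```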


### Environment

  • CUDA:
    • GPU:
      • NVIDIA GeForce RTX 3080 Laptop GPU
    • available: True
    • version: 11.6
  • Lightning:
    • pytorch-lightning: 1.7.7
    • torch: 1.12.1+cu116
    • torch-cluster: 1.6.0
    • torch-geometric: 2.1.0.post1
    • torch-scatter: 2.0.9
    • torch-sparse: 0.6.15
    • torch-spline-conv: 1.2.1
    • torchaudio: 0.12.1+cu116
    • torchmetrics: 0.10.0
    • torchvision: 0.13.1+cu116
  • Packages:
    • absl-py: 1.2.0
    • aiohttp: 3.8.3
    • aiosignal: 1.2.0
    • anndata: 0.8.0
    • astroid: 2.12.10
    • astunparse: 1.6.3
    • async-timeout: 4.0.2
    • attrs: 22.1.0
    • blinker: 1.4
    • brotlipy: 0.7.0
    • cachetools: 5.2.0
    • certifi: 2022.9.24
    • cffi: 1.15.1
    • charset-normalizer: 2.1.1
    • chex: 0.1.5
    • click: 8.0.4
    • colorama: 0.4.5
    • contourpy: 1.0.5
    • cryptography: 37.0.2
    • cycler: 0.11.0
    • dill: 0.3.5.1
    • distlib: 0.3.6
    • dm-clrs: 1.0.0
    • dm-haiku: 0.0.8
    • dm-tree: 0.1.7
    • etils: 0.8.0
    • filelock: 3.8.0
    • flatbuffers: 1.12
    • fonttools: 4.37.3
    • frozenlist: 1.3.1
    • fsspec: 2022.8.2
    • gast: 0.4.0
    • google-auth: 2.12.0
    • google-auth-oauthlib: 0.4.6
    • google-pasta: 0.2.0
    • googleapis-common-protos: 1.56.4
    • grpcio: 1.49.1
    • h5py: 3.7.0
    • idna: 3.4
    • importlib-metadata: 4.11.4
    • importlib-resources: 5.9.0
    • isort: 5.10.1
    • jax: 0.3.21
    • jaxlib: 0.3.20
    • jinja2: 3.1.2
    • jmp: 0.0.2
    • joblib: 1.2.0
    • jsonschema: 4.16.0
    • keras: 2.9.0
    • keras-preprocessing: 1.1.2
    • kiwisolver: 1.4.4
    • lazy-object-proxy: 1.7.1
    • libclang: 14.0.6
    • llvmlite: 0.39.1
    • markdown: 3.4.1
    • markupsafe: 2.1.1
    • matplotlib: 3.6.0
    • mccabe: 0.7.0
    • mkl-fft: 1.3.1
    • mkl-random: 1.2.2
    • mkl-service: 2.4.0
    • msgpack: 1.0.4
    • multidict: 6.0.2
    • natsort: 8.2.0
    • networkx: 2.8.6
    • numba: 0.56.2
    • numexpr: 2.8.3
    • numpy: 1.23.3
    • oauthlib: 3.2.1
    • opt-einsum: 3.3.0
    • optax: 0.1.3
    • packaging: 21.3
    • pandas: 1.5.0
    • patsy: 0.5.2
    • pillow: 9.2.0
    • pip: 22.1.2
    • platformdirs: 2.5.2
    • promise: 2.3
    • protobuf: 3.19.6
    • pyasn1: 0.4.8
    • pyasn1-modules: 0.2.8
    • pycparser: 2.21
    • pydeprecate: 0.3.2
    • pyjwt: 2.5.0
    • pylint: 2.15.3
    • pynndescent: 0.5.7
    • pyopenssl: 22.0.0
    • pyparsing: 3.0.9
    • pyrsistent: 0.18.1
    • pysocks: 1.7.1
    • python-dateutil: 2.8.2
    • pytorch-lightning: 1.7.7
    • pytz: 2022.2.1
    • pyu2f: 0.1.5
    • pyyaml: 6.0
    • ray: 2.0.0
    • requests: 2.28.1
    • requests-oauthlib: 1.3.1
    • rsa: 4.9
    • scanpy: 1.9.1
    • scikit-learn: 1.1.2
    • scikit-misc: 0.1.4
    • scipy: 1.9.1
    • seaborn: 0.12.0
    • session-info: 1.0.0
    • setuptools: 65.4.1
    • six: 1.16.0
    • statsmodels: 0.13.2
    • stdlib-list: 0.8.0
    • tables: 3.7.0
    • tabulate: 0.8.10
    • tensorboard: 2.9.1
    • tensorboard-data-server: 0.6.1
    • tensorboard-plugin-wit: 1.8.1
    • tensorboardx: 2.5.1
    • tensorflow: 2.9.1
    • tensorflow-estimator: 2.9.0
    • tensorflow-io-gcs-filesystem: 0.27.0
    • tensorflow-metadata: 1.10.0
    • termcolor: 2.0.1
    • tfds-nightly: 4.5.2.dev202204190046
    • threadpoolctl: 3.1.0
    • toml: 0.10.2
    • tomli: 2.0.1
    • tomlkit: 0.11.5
    • toolz: 0.12.0
    • torch: 1.12.1+cu116
    • torch-cluster: 1.6.0
    • torch-geometric: 2.1.0.post1
    • torch-scatter: 2.0.9
    • torch-sparse: 0.6.15
    • torch-spline-conv: 1.2.1
    • torchaudio: 0.12.1+cu116
    • torchmetrics: 0.10.0
    • torchvision: 0.13.1+cu116
    • tqdm: 4.64.1
    • typing-extensions: 4.3.0
    • umap-learn: 0.5.3
    • urllib3: 1.26.12
    • virtualenv: 20.16.5
    • werkzeug: 2.2.2
    • wheel: 0.37.1
    • wrapt: 1.14.1
    • yapf: 0.32.0
    • yarl: 1.8.1
    • zipp: 3.8.1
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.10.4
    • version: #202203181321-Ubuntu SMP PREEMPT Fri Mar 18 13:28:32 UTC 2022


### More info

_No response_

cc @justusschock @awaelchli

Activity

added the `bug` label (Something isn't working) and removed the `needs triage` label (Waiting to be triaged by maintainers) on Oct 9, 2022
added this to the pl:1.7.x milestone on Oct 9, 2022
modified the milestones: pl:1.7.x, v1.8.x on Oct 13, 2022
modified the milestones: v1.8.x, v1.9 on Jan 6, 2023
modified the milestones: v1.9, v1.9.x on Jan 16, 2023
removed this from the v1.9.x milestone on Dec 31, 2023
israfelsr commented on Jan 11, 2024

I had the same problem. I was going crazy because, according to the documentation, they are supposed to be the same 😅.

dgcnz commented on Jan 11, 2025

Same here

nilsleh commented on Mar 13, 2025

@Borda
The latest documentation, including the video snippet, suggests that `train_batch` and `val_batch` will be identical, but it seems that `overfit_batches` instead uses one fixed train batch and a separate fixed val batch.
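The gap between the two readings can be made concrete with a small plain-Python model. It is a sketch under stated assumptions: neither loader shuffles, `overfit_batches=n` keeps the first n batches of each loader, and `make_batches` and `overfit_pair` are hypothetical helpers rather than Lightning APIs. The `val_uses_train_data=True` branch models one possible workaround that is not confirmed in this thread: having `val_dataloader` in the repro return `self.train_dataloader()`.

```python
def make_batches(items, batch_size):
    """Split items into consecutive, unshuffled batches."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


def overfit_pair(train_items, val_items, n, batch_size, val_uses_train_data=False):
    """Return (train_batches, val_batches), each truncated to its first n batches.

    val_uses_train_data=True models a val_dataloader built from the training data.
    """
    val_source = train_items if val_uses_train_data else val_items
    train = make_batches(train_items, batch_size)[:n]
    val = make_batches(val_source, batch_size)[:n]
    return train, val


train_items = list(range(10))
val_items = list(range(10, 20))

# What this issue reports: a fixed train batch and a separate fixed val batch.
print(overfit_pair(train_items, val_items, n=1, batch_size=2))
# ([[0, 1]], [[10, 11]])

# What the documentation seems to imply: validation sees the overfit batch.
print(overfit_pair(train_items, val_items, n=1, batch_size=2, val_uses_train_data=True))
# ([[0, 1]], [[0, 1]])
```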

adosar (Contributor) commented on Apr 28, 2025

@nilsleh Indeed, the video snippet suggests that they are identical, which IMO they shouldn't be; see #20731 (comment).

Regarding the documentation (version 2.5.1):

```python
# default used by the Trainer
trainer = Trainer(overfit_batches=0.0)

# use only 1% of the train & val set
trainer = Trainer(overfit_batches=0.01)

# overfit on 10 of the same batches   <--- This seems confusing
trainer = Trainer(overfit_batches=10)
```

I think the last comment is a little confusing, since it suggests that the same 10 batches are used for both training and validation. Maybe it should read `# overfit on 10 (same) train batches & 10 (same) val batches`.
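The proposed wording can be checked against a small model of how the int and float forms would apply to each loader separately. This is a hedged sketch: `num_kept_batches` is a hypothetical helper, the truncating `int(fraction * num_batches)` rule for floats is an assumption, and Lightning's own validation of these values (for example, fractions that round down to zero batches) is not modeled.

```python
def num_kept_batches(overfit_batches, num_batches):
    """Batches kept from ONE dataloader under this reading of overfit_batches.

    0 / 0.0 -> disabled, every batch is kept
    float   -> that fraction of the loader, truncated (assumed rounding rule)
    int     -> that many batches, capped at the loader length
    """
    if overfit_batches == 0:
        return num_batches
    if isinstance(overfit_batches, float):
        return int(overfit_batches * num_batches)
    return min(overfit_batches, num_batches)


# Hypothetical loader sizes: 100 training batches, 40 validation batches.
for value in (0.0, 0.5, 10):
    print(value, "train:", num_kept_batches(value, 100), "val:", num_kept_batches(value, 40))
# 0.0 train: 100 val: 40
# 0.5 train: 50 val: 20
# 10 train: 10 val: 10
```

Under this model, `overfit_batches=10` yields 10 fixed training batches plus 10 fixed batches taken from the validation loader, which matches the behavior reported in this issue rather than a single shared set of 10 batches.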


Metadata

Assignees: No one assigned
Type: No type
Milestone: No milestone
Relationships: None yet
Participants: @awaelchli @Borda @HekpoMaH @israfelsr @carmocca
