Data leakage between train and validation sets #75

@goepp

Report

Hi,
I noticed that the train/validation split in MILClassifier allows instances from the same sample to end up in both the train and validation sets.

How to reproduce

For instance, add the following code right after mil.train(lr=1e-3) in the "Classification with multiple instance learning (MIL)" tutorial:

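# Samples with at least one instance in the train / validation split
train_patients = adata.obs.iloc[mil.train_indices_]['sample'].value_counts().index
val_patients = adata.obs.iloc[mil.validation_indices_]['sample'].value_counts().index
# Samples present in both splits (expected to be empty for a leak-free split)
train_val_overlap = set(train_patients).intersection(val_patients)
print(len(train_val_overlap))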

The output is 108, the number of samples that appear in both the train and validation splits.
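
One caveat about this kind of check, shown as a minimal sketch on toy data: if the 'sample' column is a pandas categorical, value_counts() keeps every category (with a count of 0 for unused ones), so the overlap can be overstated. Counting only the labels that actually occur in each subset, for example with unique(), avoids this. The toy `obs` table, the index lists, and the `samples_in` helper below are hypothetical stand-ins for adata.obs, mil.train_indices_ and mil.validation_indices_; they are not part of the library.

```python
import pandas as pd

# Toy stand-in for adata.obs: 6 instances (cells) from 3 samples,
# stored as a categorical column.
obs = pd.DataFrame(
    {"sample": pd.Categorical(["s1", "s1", "s2", "s2", "s3", "s3"])}
)

# Hypothetical positional indices playing the role of
# mil.train_indices_ and mil.validation_indices_.
train_idx = [0, 2, 3]  # instances from s1 and s2
val_idx = [1, 4, 5]    # instances from s1 and s3


def samples_in(obs, idx, key="sample"):
    """Return the set of sample labels that actually occur at positions idx."""
    return set(obs.iloc[idx][key].unique())


# Overlap based on labels that are really present in each split.
overlap = samples_in(obs, train_idx) & samples_in(obs, val_idx)
print(sorted(overlap))

# For comparison: value_counts() on a categorical column also lists
# categories with zero occurrences, so every sample looks shared.
naive_train = set(obs.iloc[train_idx]["sample"].value_counts().index)
naive_val = set(obs.iloc[val_idx]["sample"].value_counts().index)
naive_overlap = naive_train & naive_val
print(sorted(naive_overlap))
```

In this toy example only s1 has instances in both splits, while the value_counts()-based check reports all three samples. If the column is not categorical, both checks give the same answer.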

Is this a bug?

I believe this creates data leakage: on my data (with MILClassifier running a regression task), the validation error is consistently below the training error. This is also visible in the tutorial, although it is less obvious there because it runs a classification task.
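
To illustrate the behavior I would expect, here is a minimal sketch of a sample-wise (grouped) split, where all instances of a sample go to the same side. This is not MultiMIL's actual splitting code; the function name `group_split` and its parameters are made up for illustration.

```python
import numpy as np


def group_split(groups, val_fraction=0.2, seed=0):
    """Split instance indices so that each group lands entirely in one split.

    groups: one group label (e.g. sample/patient ID) per instance.
    Returns (train_indices, val_indices) as integer arrays.
    """
    groups = np.asarray(groups)
    rng = np.random.default_rng(seed)
    # Shuffle the distinct group labels, not the individual instances.
    unique_groups = rng.permutation(np.unique(groups))
    # Hold out at least one whole group for validation.
    n_val = max(1, int(round(val_fraction * len(unique_groups))))
    val_groups = unique_groups[:n_val]
    # An instance is in validation iff its group was held out.
    is_val = np.isin(groups, val_groups)
    return np.flatnonzero(~is_val), np.flatnonzero(is_val)


# Toy data: 5 samples with 4 instances each.
groups = np.repeat(["p1", "p2", "p3", "p4", "p5"], 4)
train_idx, val_idx = group_split(groups, val_fraction=0.2)

# No sample contributes instances to both splits.
shared = set(groups[train_idx]) & set(groups[val_idx])
print(len(shared))
```

With a split like this, the overlap check from the reproduction step would return 0. scikit-learn's GroupShuffleSplit implements the same idea and may be a more convenient option.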

Version information


anndata 0.10.8
gdown 5.2.0
multimil 0.2.0
numpy 1.26.4
pandas 2.2.3
scanpy 1.11.0
scvi 0.20.0
session_info 1.0.0

PIL 11.1.0
absl NA
asttokens NA
attr 25.1.0
bs4 4.13.3
certifi 2025.01.31
charset_normalizer 3.4.1
chex 0.1.88
comm 0.2.2
cycler 0.12.1
cython_runtime NA
dateutil 2.9.0.post0
debugpy 1.8.12
decorator 5.1.1
docrep 0.3.2
etils 1.12.0
exceptiongroup 1.2.2
executing 2.1.0
filelock 3.17.0
flax 0.10.3
fsspec 2025.2.0
h5py 3.12.1
idna 3.10
ipykernel 6.29.5
ipywidgets 8.1.5
jax 0.5.0
jaxlib 0.5.0
jedi 0.19.2
joblib 1.4.2
kiwisolver 1.4.8
legacy_api_wrap NA
lightning_fabric 1.9.5
lightning_utilities 0.12.0
llvmlite 0.44.0
matplotlib 3.10.0
matplotlib_inline 0.1.7
ml_collections 1.0.0
ml_dtypes 0.5.1
mpl_toolkits NA
mpmath 1.3.0
msgpack 1.1.0
mudata 0.3.1
multipledispatch 0.6.0
natsort 8.4.0
numba 0.61.0
numpyro 0.17.0
nvidia NA
opt_einsum 3.4.0
optax 0.2.4
packaging 24.2
parso 0.8.4
pickleshare 0.7.5
pkg_resources NA
platformdirs 4.3.6
prompt_toolkit 3.0.50
psutil 6.1.1
pure_eval 0.2.3
pydev_ipython NA
pydevconsole NA
pydevd 3.2.3
pydevd_file_utils NA
pydevd_plugins NA
pydevd_tracing NA
pygments 2.19.1
pyparsing 3.2.1
pyro 1.9.1
pytorch_lightning 1.9.5
pytz 2025.1
requests 2.32.3
rich NA
scipy 1.15.2
setuptools 59.5.0
simplejson 3.20.1
six 1.17.0
sklearn 1.5.2
socks 1.7.1
soupsieve 2.6
stack_data 0.6.3
sympy 1.13.1
threadpoolctl 3.5.0
toolz 1.0.0
torch 2.6.0+cu124
torchgen NA
torchmetrics 1.6.1
tornado 6.4.2
tqdm 4.67.1
traitlets 5.14.3
triton 3.2.0
typing_extensions NA
urllib3 2.3.0
vscode NA
wcwidth 0.2.13
yaml 6.0.2
zmq 26.2.1
zoneinfo NA

IPython 8.32.0
jupyter_client 8.6.3
jupyter_core 5.7.2

Python 3.10.16 | packaged by conda-forge | (main, Dec 5 2024, 14:16:10) [GCC 13.3.0]
Linux-4.18.0-553.36.1.el8_10.x86_64-x86_64-with-glibc2.28

Session information updated at 2025-02-21 19:01


Labels: enhancement (New feature or request)
