Open
Labels
enhancement: New feature or request
Description
Report
Hi,
I noticed that the train/validation split in MILClassifier allows instances from the same sample to end up in both the training and validation sets.
How to reproduce
For instance, add the following code right after mil.train(lr=1e-3) in the Classification with multiple instance learning (MIL) tutorial:
train_patients = adata.obs.iloc[mil.train_indices_]['sample'].value_counts().index
val_patients = adata.obs.iloc[mil.validation_indices_]['sample'].value_counts().index
train_val_overlap = set(train_patients).intersection(val_patients)
print(len(train_val_overlap))
The output is 108, meaning that 108 samples have instances in both the training and validation splits.
Is this a bug?
I believe this creates data leakage: on my data (with MILClassifier running a regression task), the validation error is consistently below the training error. This is also visible in the tutorial, although it is less obvious there because the tutorial runs a classification task.
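For illustration, here is a minimal sketch of what a sample-aware split could look like. It does not reflect how MILClassifier builds its splits internally; it assumes scikit-learn's GroupShuffleSplit is an acceptable stand-in, and the helper name grouped_train_val_split, the train_size value, and the toy sample labels are all hypothetical.

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit


def grouped_train_val_split(groups, train_size=0.9, seed=0):
    """Split instance indices so that each group (sample) lands in exactly one split.

    Hypothetical helper for illustration only; not part of multimil.
    """
    groups = np.asarray(groups)
    splitter = GroupShuffleSplit(n_splits=1, train_size=train_size, random_state=seed)
    # GroupShuffleSplit only needs the number of rows, so a dummy X is enough.
    train_idx, val_idx = next(splitter.split(np.zeros(len(groups)), groups=groups))
    return train_idx, val_idx


# Toy data: 20 samples with 50 instances (cells) each.
# With real data this would be something like adata.obs["sample"].to_numpy().
sample_labels = np.repeat([f"sample_{i}" for i in range(20)], 50)

train_idx, val_idx = grouped_train_val_split(sample_labels)

# Same overlap check as in the reproduction snippet above.
overlap = set(sample_labels[train_idx]) & set(sample_labels[val_idx])
print(len(overlap))  # 0: no sample appears in both splits
```

With a split like this, the overlap check from the reproduction steps returns 0 by construction, since whole samples rather than individual instances are assigned to each split.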
Version information
anndata 0.10.8
gdown 5.2.0
multimil 0.2.0
numpy 1.26.4
pandas 2.2.3
scanpy 1.11.0
scvi 0.20.0
session_info 1.0.0
PIL 11.1.0
absl NA
asttokens NA
attr 25.1.0
bs4 4.13.3
certifi 2025.01.31
charset_normalizer 3.4.1
chex 0.1.88
comm 0.2.2
cycler 0.12.1
cython_runtime NA
dateutil 2.9.0.post0
debugpy 1.8.12
decorator 5.1.1
docrep 0.3.2
etils 1.12.0
exceptiongroup 1.2.2
executing 2.1.0
filelock 3.17.0
flax 0.10.3
fsspec 2025.2.0
h5py 3.12.1
idna 3.10
ipykernel 6.29.5
ipywidgets 8.1.5
jax 0.5.0
jaxlib 0.5.0
jedi 0.19.2
joblib 1.4.2
kiwisolver 1.4.8
legacy_api_wrap NA
lightning_fabric 1.9.5
lightning_utilities 0.12.0
llvmlite 0.44.0
matplotlib 3.10.0
matplotlib_inline 0.1.7
ml_collections 1.0.0
ml_dtypes 0.5.1
mpl_toolkits NA
mpmath 1.3.0
msgpack 1.1.0
mudata 0.3.1
multipledispatch 0.6.0
natsort 8.4.0
numba 0.61.0
numpyro 0.17.0
nvidia NA
opt_einsum 3.4.0
optax 0.2.4
packaging 24.2
parso 0.8.4
pickleshare 0.7.5
pkg_resources NA
platformdirs 4.3.6
prompt_toolkit 3.0.50
psutil 6.1.1
pure_eval 0.2.3
pydev_ipython NA
pydevconsole NA
pydevd 3.2.3
pydevd_file_utils NA
pydevd_plugins NA
pydevd_tracing NA
pygments 2.19.1
pyparsing 3.2.1
pyro 1.9.1
pytorch_lightning 1.9.5
pytz 2025.1
requests 2.32.3
rich NA
scipy 1.15.2
setuptools 59.5.0
simplejson 3.20.1
six 1.17.0
sklearn 1.5.2
socks 1.7.1
soupsieve 2.6
stack_data 0.6.3
sympy 1.13.1
threadpoolctl 3.5.0
toolz 1.0.0
torch 2.6.0+cu124
torchgen NA
torchmetrics 1.6.1
tornado 6.4.2
tqdm 4.67.1
traitlets 5.14.3
triton 3.2.0
typing_extensions NA
urllib3 2.3.0
vscode NA
wcwidth 0.2.13
yaml 6.0.2
zmq 26.2.1
zoneinfo NA
IPython 8.32.0
jupyter_client 8.6.3
jupyter_core 5.7.2
Python 3.10.16 | packaged by conda-forge | (main, Dec 5 2024, 14:16:10) [GCC 13.3.0]
Linux-4.18.0-553.36.1.el8_10.x86_64-x86_64-with-glibc2.28
Session information updated at 2025-02-21 19:01