11from collections import Counter
22from logging import getLogger
3+ from typing import Optional , Tuple
34import pandas
45import numpy
56from .dataframe_helpers import dataframe_shuffle
@@ -449,14 +450,15 @@ def double_merge(d):
449450
450451
451452def train_test_apart_stratify (
452- df ,
453+ df : pandas . DataFrame ,
453454 group ,
454- test_size = 0.25 ,
455- train_size = None ,
456- stratify = None ,
457- force = False ,
458- random_state = None ,
459- ):
455+ test_size : Optional [float ] = 0.25 ,
456+ train_size : Optional [float ] = None ,
457+ stratify : Optional [str ] = None ,
458+ force : bool = False ,
459+ random_state : Optional [int ] = None ,
460+ sorted_indices : bool = False ,
461+ ) -> Tuple ["StreamingDataFrame" , "StreamingDataFrame" ]: # noqa: F821
460462 """
461463 This split is for a specific case where data is linked
462464 in one way. Let's assume we have two ids as we have
@@ -474,6 +476,8 @@ def train_test_apart_stratify(
474476 :param force: if True, tries to get at least one example on the test side
475477 for each value of the column *stratify*
476478 :param random_state: seed for random generators
479+ :param sorted_indices: if True, the ids of each group are processed in sorted order,
480+ see issue `41 <https://github.com/sdpython/pandas-streaming/issues/41>`_
477481 :return: Two :class:`StreamingDataFrame
478482 <pandas_streaming.df.dataframe.StreamingDataFrame>`, one
479483 for train, one for test.
@@ -540,10 +544,15 @@ def train_test_apart_stratify(
540544
541545 split = {}
542546 for _ , k in sorted_hist :
543- not_assigned = [c for c in ids [k ] if c not in split ]
547+ indices = sorted (ids [k ]) if sorted_indices else ids [k ]
548+ not_assigned , assigned = [], []
549+ for c in indices :
550+ if c in split :
551+ assigned .append (c )
552+ else :
553+ not_assigned .append (c )
544554 if len (not_assigned ) == 0 :
545555 continue
546- assigned = [c for c in ids [k ] if c in split ]
547556 nb_test = sum (split [c ] for c in assigned )
548557 expected = min (len (ids [k ]), int (test_size * len (ids [k ]) + 0.5 )) - nb_test
549558 if force and expected == 0 and nb_test == 0 :
0 commit comments