11import abc
22import inspect
3- from typing import TYPE_CHECKING , Iterator , Type
3+ from typing import TYPE_CHECKING , Any , Dict , Iterator , Optional , Type , Union
44
55import numpy as np
66
1818if TYPE_CHECKING :
1919 from pandas import DataFrame , Series , Index
2020
21+ ResType = Dict [int , Any ]
22+
2123
2224def frame_apply (
2325 obj : "DataFrame" ,
@@ -64,10 +66,15 @@ def result_index(self) -> "Index":
6466 def result_columns (self ) -> "Index" :
6567 pass
6668
69+ @property
6770 @abc .abstractmethod
6871 def series_generator (self ) -> Iterator ["Series" ]:
6972 pass
7073
74+ @abc .abstractmethod
75+ def wrap_results_for_axis (self , results : ResType ) -> Union ["Series" , "DataFrame" ]:
76+ pass
77+
7178 # ---------------------------------------------------------------
7279
7380 def __init__ (
@@ -107,8 +114,16 @@ def f(x):
107114
108115 # results
109116 self .result = None
110- self .res_index = None
111- self .res_columns = None
117+ self ._res_index : Optional ["Index" ] = None
118+
119+ @property
120+ def res_index (self ) -> "Index" :
121+ assert self ._res_index is not None
122+ return self ._res_index
123+
124+ @property
125+ def res_columns (self ) -> "Index" :
126+ return self .result_columns
112127
113128 @property
114129 def columns (self ) -> "Index" :
@@ -298,12 +313,12 @@ def apply_standard(self):
298313 return self .obj ._constructor_sliced (result , index = labels )
299314
300315 # compute the result using the series generator
301- self .apply_series_generator ()
316+ results = self .apply_series_generator ()
302317
303318 # wrap results
304- return self .wrap_results ()
319+ return self .wrap_results (results )
305320
306- def apply_series_generator (self ):
321+ def apply_series_generator (self ) -> ResType :
307322 series_gen = self .series_generator
308323 res_index = self .result_index
309324
@@ -330,17 +345,15 @@ def apply_series_generator(self):
330345 results [i ] = self .f (v )
331346 keys .append (v .name )
332347
333- self .results = results
334- self .res_index = res_index
335- self .res_columns = self .result_columns
348+ self ._res_index = res_index
349+ return results
336350
337- def wrap_results (self ):
338- results = self .results
351+ def wrap_results (self , results : ResType ) -> Union ["Series" , "DataFrame" ]:
339352
340353 # see if we can infer the results
341354 if len (results ) > 0 and 0 in results and is_sequence (results [0 ]):
342355
343- return self .wrap_results_for_axis ()
356+ return self .wrap_results_for_axis (results )
344357
345358 # dict of scalars
346359 result = self .obj ._constructor_sliced (results )
@@ -367,10 +380,9 @@ def result_index(self) -> "Index":
367380 def result_columns (self ) -> "Index" :
368381 return self .index
369382
370- def wrap_results_for_axis (self ) :
383+ def wrap_results_for_axis (self , results : ResType ) -> "DataFrame" :
371384 """ return the results for the rows """
372385
373- results = self .results
374386 result = self .obj ._constructor (data = results )
375387
376388 if not isinstance (results [0 ], ABCSeries ):
@@ -406,13 +418,13 @@ def result_index(self) -> "Index":
406418 def result_columns (self ) -> "Index" :
407419 return self .columns
408420
409- def wrap_results_for_axis (self ) :
421+ def wrap_results_for_axis (self , results : ResType ) -> Union [ "Series" , "DataFrame" ] :
410422 """ return the results for the columns """
411- results = self . results
423+ result : Union [ "Series" , "DataFrame" ]
412424
413425 # we have requested to expand
414426 if self .result_type == "expand" :
415- result = self .infer_to_same_shape ()
427+ result = self .infer_to_same_shape (results )
416428
417429 # we have a non-series and don't want inference
418430 elif not isinstance (results [0 ], ABCSeries ):
@@ -423,13 +435,12 @@ def wrap_results_for_axis(self):
423435
424436 # we may want to infer results
425437 else :
426- result = self .infer_to_same_shape ()
438+ result = self .infer_to_same_shape (results )
427439
428440 return result
429441
430- def infer_to_same_shape (self ) -> "DataFrame" :
442+ def infer_to_same_shape (self , results : ResType ) -> "DataFrame" :
431443 """ infer the results to the same shape as the input object """
432- results = self .results
433444
434445 result = self .obj ._constructor (data = results )
435446 result = result .T
0 commit comments