# core

> Extends pandas with common functions used in finance and economics research

Almost all these functions make a copy of the input DataFrame. When that DataFrame is large, use these functions as `df = func(df)`.

In [None]:
#| default_exp core

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#|exports
from __future__ import annotations
from typing import List, Callable 
import os, glob 
import pandas as pd
import numpy as np

First, we set up an example dataset to showcase the functions in this module.

In [None]:
# Toy panel with random columns A and B on a (permno, date) grid; the grid deliberately
# includes a NaN permno, a NaN date and a repeated '2010-02' date to exercise the cleaning code.
raw = pd.DataFrame(np.random.rand(15,2), 
                  columns=list('AB'), 
                  index=pd.MultiIndex.from_product(
                      [[1,2, np.nan],[np.nan,'2010-01','2010-02','2010-02','2010-04']],
                      names = ['permno','date'])
                    ).reset_index()
raw

Unnamed: 0,permno,date,A,B
0,1.0,,0.234344,0.698915
1,1.0,2010-01,0.104762,0.923778
2,1.0,2010-02,0.134963,0.287128
3,1.0,2010-02,0.561185,0.629761
4,1.0,2010-04,0.452953,0.137185
5,2.0,,0.242651,0.010276
6,2.0,2010-01,0.481129,0.734026
7,2.0,2010-02,0.663068,0.887246
8,2.0,2010-02,0.627125,0.525601
9,2.0,2010-04,0.413195,0.860653


## Common panel setup procedures

In [None]:
#|export
def order_columns(df: pd.DataFrame, these_first: List[str]) -> pd.DataFrame:
    """Return `df` with the `these_first` columns moved to the front. Use as `df = order_columns(df,_)`.

    All other columns follow in their original relative order.
    """

    trailing = []
    for col in df.columns:
        if col not in these_first:
            trailing.append(col)
    return df[these_first + trailing]

In [None]:
order_columns(raw, these_first=['B']).head()

Unnamed: 0,B,permno,date,A
0,0.698915,1.0,,0.234344
1,0.923778,1.0,2010-01,0.104762
2,0.287128,1.0,2010-02,0.134963
3,0.629761,1.0,2010-02,0.561185
4,0.137185,1.0,2010-04,0.452953


In [None]:
#|export
def process_dates(df: pd.DataFrame, # Function returns copy of this df with `dtdate_var` and `f'{freq}date'` cols added
                time_var: str='date', # This will be the date variable used to generate datetime var `dtdate_var`
                time_var_format: str='%Y-%m-%d', # Format of `time_var`; must be valid pandas `strftime`
                dtdate_var: str='dtdate', # Name of datetime var to be created from `time_var`
                freq: str=None, # Used to create `f'{freq}date'` period date; must be valid pandas offset string
                ) -> pd.DataFrame:
    """Makes datetime date `dtdate_var` from `time_var`; adds period date `f'{freq}date'`.

    Works on a copy of `df`. The returned frame has `time_var`, `dtdate_var` and
    `f'{freq}date'` as its first three columns, followed by the remaining columns.
    """
    
    df = df.copy()
    period_var = f'{freq}date'
    df[dtdate_var] = pd.to_datetime(df[time_var], format=time_var_format)
    # Build the period date from the column just created (whatever `dtdate_var` is called),
    # not from a hard-coded 'dtdate' column.
    df[period_var] = df[dtdate_var].dt.to_period(freq)
    return order_columns(df, [time_var, dtdate_var, period_var])

In [None]:
newdf = process_dates(raw, time_var_format="%Y-%m", freq='M')
newdf.head()

Unnamed: 0,date,dtdate,Mdate,permno,A,B
0,,NaT,NaT,1.0,0.234344,0.698915
1,2010-01,2010-01-01,2010-01,1.0,0.104762,0.923778
2,2010-02,2010-02-01,2010-02,1.0,0.134963,0.287128
3,2010-02,2010-02-01,2010-02,1.0,0.561185,0.629761
4,2010-04,2010-04-01,2010-04,1.0,0.452953,0.137185


In [None]:
#|export
def setup_tseries(df: pd.Series|pd.DataFrame, # Input DataFrame; a copy is returned
# Params passed to `process_dates`
                dates_processed: bool=False, # If True, assumes dates are already processed with `process_dates`
                time_var: str='date', # This will be the date variable used to generate datetime var `dtdate_var`
                time_var_format: str='%Y-%m-%d', # Format of `time_var`; must be valid pandas `strftime`
                dtdate_var: str='dtdate', # Name of datetime var to be created from `time_var`
                freq: str=None, # Used to create `f'{freq}date'` period date; must be valid pandas offset string
# Params for cleaning dates                
                drop_missing_index_vals: bool=True, # What to do with missing `f'{freq}date'`
                drop_index_duplicates: bool=True, # What to do with duplicates in `f'{freq}date'` values
                duplicates_which_keep: str='last', # If duplicates in index, which to keep; must be 'first', 'last' or `False`
                ) -> pd.DataFrame:
    """Applies `process_dates` to `df`; cleans up resulting `f'{freq}date'` period date and sets it as index.

    Rows are sorted by `dtdate_var`. The returned frame is indexed by `f'{freq}date'`
    and has `time_var` and `dtdate_var` as its first two columns.
    """

    period_var = f'{freq}date'
    if isinstance(df, pd.Series): df = df.to_frame()
    if not dates_processed:
        df = process_dates(df, time_var=time_var, time_var_format=time_var_format, dtdate_var=dtdate_var, freq=freq)

    if drop_missing_index_vals:
        df = df.dropna(subset=[time_var])
    # A single-column sort is not stable by default; use a stable algorithm so rows sharing a date
    # keep their input order and `duplicates_which_keep` ('first'/'last') is well defined.
    df = df.sort_values([dtdate_var], kind='mergesort')
    if drop_index_duplicates:
        df = df.drop_duplicates(subset=[period_var], keep=duplicates_which_keep)
    df = df.set_index([period_var]) 
    return order_columns(df,[time_var,dtdate_var]) 

In [None]:
raw.query('permno==1')

Unnamed: 0,permno,date,A,B
0,1.0,,0.234344,0.698915
1,1.0,2010-01,0.104762,0.923778
2,1.0,2010-02,0.134963,0.287128
3,1.0,2010-02,0.561185,0.629761
4,1.0,2010-04,0.452953,0.137185


In [None]:
df = setup_tseries(raw.query('permno==1'),
                 time_var='date', time_var_format="%Y-%m",
                 freq='M')
df

Unnamed: 0_level_0,date,dtdate,permno,A,B
Mdate,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
2010-01,2010-01,2010-01-01,1.0,0.104762,0.923778
2010-02,2010-02,2010-02-01,1.0,0.561185,0.629761
2010-04,2010-04,2010-04-01,1.0,0.452953,0.137185


In [None]:
#|export
def setup_panel(df: pd.DataFrame, # Input DataFrame; a copy is returned
                panel_ids :str=None, # Name of variable that identifies panel entities
# Params passed to `process_dates`
                dates_processed: bool=False, # If True, assumes dates are already processed with `process_dates`
                time_var: str='date', # This will be the date variable used to generate datetime var `dtdate_var`
                time_var_format: str='%Y-%m-%d', # Format of `time_var`; must be valid pandas `strftime`
                dtdate_var: str='dtdate', # Name of datetime var to be created from `time_var`
                freq: str=None, # Used to create `f'{freq}date'` period date; must be valid pandas offset string
# Params for cleaning panel_ids and dates                
                drop_missing_index_vals: bool=True, # What to do with missing `panel_ids` or `f'{freq}date'`
                panel_ids_toint: str='Int64', # Converts `panel_ids` to int in place; use falsy value if not wanted
                drop_index_duplicates: bool=True, # What to do with duplicates in (`panel_ids`, `f'{freq}date'`) values
                duplicates_which_keep: str='last', # If duplicates in index, which to keep; must be 'first', 'last' or `False`
                ) -> pd.DataFrame:
    """Applies `process_dates` to `df`; cleans up (`panel_ids` ,`f'{freq}date'`) and sets it as index.

    Rows are sorted by (`panel_ids`, `dtdate_var`). The returned frame is indexed by
    (`panel_ids`, `f'{freq}date'`) and has `time_var` and `dtdate_var` as its first two columns.
    The caller's `df` is never modified.
    """

    period_var = f'{freq}date'
    if not dates_processed:
        df = process_dates(df, time_var=time_var, time_var_format=time_var_format, dtdate_var=dtdate_var, freq=freq)
    if drop_missing_index_vals:
        df = df.dropna(subset=[panel_ids,time_var])
    if panel_ids_toint:
        # `astype` returns a new frame, so the input is left untouched even when
        # `dates_processed=True` and no earlier step has produced a copy.
        df = df.astype({panel_ids: panel_ids_toint})
    df = df.sort_values([panel_ids, dtdate_var])
    if drop_index_duplicates:
        df = df.drop_duplicates(subset=[panel_ids, period_var], keep=duplicates_which_keep)
    df = df.set_index([panel_ids, period_var])
    return order_columns(df,[time_var,dtdate_var]) 

In [None]:
df = setup_panel(raw,
                 panel_ids='permno',
                 time_var='date', time_var_format="%Y-%m",
                 freq='M')
df

Unnamed: 0_level_0,Unnamed: 1_level_0,date,dtdate,A,B
permno,Mdate,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1,2010-01,2010-01,2010-01-01,0.104762,0.923778
1,2010-02,2010-02,2010-02-01,0.561185,0.629761
1,2010-04,2010-04,2010-04-01,0.452953,0.137185
2,2010-01,2010-01,2010-01-01,0.481129,0.734026
2,2010-02,2010-02,2010-02-01,0.627125,0.525601
2,2010-04,2010-04,2010-04-01,0.413195,0.860653


## Robust lagging

Lagging with `shift` fails when we have (1) panel data, (2) gaps in the time-series, (3) duplicate dates, (4) data that is not sorted by date, or (5) NaN dates.

The `fast_lag` function below correctly lags data (using `shift()`), assuming we do not have problems (2), (3), (4), and (5).

The `lag` function below correctly lags data (using `join()` on a shifted copy of the index), assuming we do not have problem (5).

In [None]:
#|export
def fast_lag(df: pd.Series|pd.DataFrame, # Index of `df` (or level 1 of MultiIndex) must be pandas period date.
        n: int=1, # Number of periods to lag based on frequency of df.index; Negative values means lead.
        ) -> pd.Series: # Series with lagged values of `df`; Name is taken from `df.columns[0]`, with '_lag{n}' or '_lead{n}' suffixed.
    """Lag data in `df` by `n` periods. 
    ASSUMES DATA IS SORTED BY DATES AND HAS NO DUPLICATE OR NaN DATES, AND NO GAPS IN THE TIME SERIES.
    Apply `df = setup_panel(df)` before using.

    Values are shifted by `n` rows and kept only where the shifted row belongs to the same
    panel entity (MultiIndex input) and lies exactly `n` periods away; otherwise the result is NaN.
    Always returns a Series with the same index as `df`, including for one-row input.
    Raises ValueError if `df` has more than one column or the date index is not a period date.
    """

    if isinstance(df,pd.Series): df = df.to_frame()
    if len(df.columns) > 1: raise ValueError("<df> must have a single column")
    old_name = str(df.columns[0])
    new_varname = old_name + f'_lag{n}' if n>=0 else old_name + f'_lead{-n}'

    # Work on the index values directly (positionally) rather than through `reset_index()`,
    # so unnamed index levels, non-string column labels and columns that happen to be called
    # 'lag_panel'/'lag_time' do not break the computation.
    if isinstance(df.index, pd.MultiIndex):
        if df.index.nlevels != 2:
            raise ValueError('MultiIndex must have exactly two levels: (panel id, period date)')
        if not f'{df.index.levels[1].dtype}'.startswith('period'):
            raise ValueError('Dimension 1 of multiindex must be period date')
        panel = pd.Series(df.index.get_level_values(0))
        dates = pd.Series(df.index.get_level_values(1))
        valid = (panel == panel.shift(n)) & (dates == dates.shift(n) + n)
    else:
        if not f'{df.index.dtype}'.startswith('period'):
            raise ValueError('Index must be period date')
        dates = pd.Series(df.index)
        valid = dates == dates.shift(n) + n

    # Missing comparisons (e.g. nullable integer panel ids at the edges) count as "no match".
    keep = valid.to_numpy(dtype=bool, na_value=False)
    shifted = df.iloc[:, 0].shift(n)
    return pd.Series(np.where(keep, shifted, np.nan), index=df.index, name=new_varname)

In [None]:
#|export
def lag(df: pd.Series|pd.DataFrame, # Index (or level 1 of MultiIndex) must be period date with no missing values.
        n: int=1, # Number of periods to lag based on frequency of `df.index`; Negative values means lead.
        fast: bool=False, # If True, uses `fast_lag()`, which assumes data is sorted by date and has no duplicate or missing dates
        ) -> pd.Series: # Series with lagged values of `df`; Name is taken from `df.columns[0]`, with '_lag{n}' or '_lead{n}' suffixed.
    """Lag data in 'df' by 'n' periods. ASSUMES NO NaN DATES. Apply `df = setup_panel(df)` before using.

    A copy of the data has its period dates moved forward by `n` periods and is joined back
    onto the original index, so each row picks up the value observed `n` periods earlier
    (NaN when no such observation exists). Always returns a Series, including for one-row input.
    Raises ValueError if `df` has more than one column or the date index is not a period date.
    """

    if fast: return fast_lag(df,n)

    if isinstance(df,pd.Series): df = df.to_frame()
    if len(df.columns) > 1: raise ValueError("'df' parameter must have a single column")
    # `str()` in both branches so non-string column labels (e.g. an unnamed Series) work for leads too.
    old_name = str(df.columns[0])
    new_varname = old_name + f'_lag{n}' if n>=0 else old_name + f'_lead{-n}'
    dfl = df.copy()
    dfl.columns = [new_varname]

    if isinstance(df.index, pd.MultiIndex):
        if f'{df.index.levels[1].dtype}'.startswith('period'):
            dfl.index = dfl.index.set_levels(df.index.levels[1]+n, level=1)
        else:
            raise ValueError('Dimension 1 of multiindex must be period date')
    else:
        if f'{df.index.dtype}'.startswith('period'):
            dfl.index = dfl.index + n
        else:
            raise ValueError('Index must be period date')

    # Select the lagged column explicitly instead of `squeeze()`, which would collapse
    # a one-row result into a scalar.
    return df.join(dfl)[new_varname]

The index of the `df` parameter can not contain missing values.

In [None]:
lag(df['A'])

permno  Mdate  
1       2010-01         NaN
        2010-02    0.104762
        2010-04         NaN
2       2010-01         NaN
        2010-02    0.481129
        2010-04         NaN
Name: A_lag1, dtype: float64

In [None]:
lag(df['A'],fast=False)

permno  Mdate  
1       2010-01         NaN
        2010-02    0.104762
        2010-04         NaN
2       2010-01         NaN
        2010-02    0.481129
        2010-04         NaN
Name: A_lag1, dtype: float64

In [None]:
#|export
def add_lags(df: pd.Series|pd.DataFrame, # If pd.Series, it must have a name equal to `vars` param
             vars: str|List[str], # Variables to be lagged; must be a subset of `df.columns()`
             lags: int|List[int]=1, # Which lags to be added
             lag_suffix: str='_lag', # Used to create new lagged variable names
             lead_suffix: str='_lead', # Used to create new lead variable names
             use_fast_lags: bool=False, # Whether to use `fast_lag()` function when lagging
             ) -> pd.DataFrame:
    """Returns a copy of `df` with one new column per (variable, lag) pair.

    Non-negative lags are named `f'{var}{lag_suffix}{n}'`; negative lags (leads) are
    named `f'{var}{lead_suffix}{-n}'`.
    """

    out = df.copy()
    if isinstance(out, pd.Series):
        out = out.to_frame()
    # A single name / single lag is treated as a one-element list
    var_list = [vars] if isinstance(vars, str) else vars
    lag_list = [lags] if isinstance(lags, int) else lags

    for var in var_list:
        for n in lag_list:
            if n >= 0:
                new_col = f'{var}{lag_suffix}{n}'
            else:
                new_col = f'{var}{lead_suffix}{-n}'
            out[new_col] = lag(out[var], n, use_fast_lags)
    return out

Because this makes a copy of `df`, when `df` is a large dataset, this should be used as `df = add_lags(df)`.

In [None]:
add_lags(df['A'], vars='A')

Unnamed: 0_level_0,Unnamed: 1_level_0,A,A_lag1
permno,Mdate,Unnamed: 2_level_1,Unnamed: 3_level_1
1,2010-01,0.104762,
1,2010-02,0.561185,0.104762
1,2010-04,0.452953,
2,2010-01,0.481129,
2,2010-02,0.627125,0.481129
2,2010-04,0.413195,


In [None]:
add_lags(df, vars=['A','B'], lags=[2,-1])

Unnamed: 0_level_0,Unnamed: 1_level_0,date,dtdate,A,B,A_lag2,A_lead1,B_lag2,B_lead1
permno,Mdate,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
1,2010-01,2010-01,2010-01-01,0.104762,0.923778,,0.561185,,0.629761
1,2010-02,2010-02,2010-02-01,0.561185,0.629761,,,,
1,2010-04,2010-04,2010-04-01,0.452953,0.137185,0.561185,,0.629761,
2,2010-01,2010-01,2010-01-01,0.481129,0.734026,,0.627125,,0.525601
2,2010-02,2010-02,2010-02-01,0.627125,0.525601,,,,
2,2010-04,2010-04,2010-04-01,0.413195,0.860653,0.627125,,0.525601,


In [None]:
add_lags(df,vars=['A','B'],lags=[2,-2], lag_suffix='_lg', lead_suffix='_ld')

Unnamed: 0_level_0,Unnamed: 1_level_0,date,dtdate,A,B,A_lg2,A_ld2,B_lg2,B_ld2
permno,Mdate,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
1,2010-01,2010-01,2010-01-01,0.104762,0.923778,,,,
1,2010-02,2010-02,2010-02-01,0.561185,0.629761,,0.452953,,0.137185
1,2010-04,2010-04,2010-04-01,0.452953,0.137185,0.561185,,0.629761,
2,2010-01,2010-01,2010-01-01,0.481129,0.734026,,,,
2,2010-02,2010-02,2010-02-01,0.627125,0.525601,,0.413195,,0.860653
2,2010-04,2010-04,2010-04-01,0.413195,0.860653,0.627125,,0.525601,


And remember that `lag` uses `fast=False` by default; passing `fast=True` (or `use_fast_lags=True` in `add_lags`) switches to `fast_lag`, which is not robust to duplicate dates (or unsorted dates).

## Utilities using robust lagging

In [None]:
#|export
def rpct_change(df: pd.Series, n: int=1, use_fast_lags=False):
    """Percentage change of `df` relative to its value `n` periods earlier.

    The earlier value comes from `lag()` (or `fast_lag()` when `use_fast_lags` is True).
    """
    previous = lag(df, n, use_fast_lags)
    ratio = df / previous
    return ratio - 1

In [None]:
rpct_change(df['A'])

permno  Mdate  
1       2010-01         NaN
        2010-02    4.356736
        2010-04         NaN
2       2010-01         NaN
        2010-02    0.303446
        2010-04         NaN
dtype: float64

In [None]:
#|export
def rdiff(df: pd.Series, n: int=1, use_fast_lags=False):
    """Difference between `df` and its value `n` periods earlier.

    The earlier value comes from `lag()` (or `fast_lag()` when `use_fast_lags` is True).
    """
    previous = lag(df, n, use_fast_lags)
    return df - previous

In [None]:
rdiff(df['A'])

permno  Mdate  
1       2010-01         NaN
        2010-02    0.456422
        2010-04         NaN
2       2010-01         NaN
        2010-02    0.145997
        2010-04         NaN
dtype: float64

In [None]:
#|export
def rrolling(df: pd.Series|pd.DataFrame, # Must have period date Index (if Series) or (panel_id, period_date) Multiindex (if DataFrame) 
            func: str, # Name of any pandas aggregation function (to applied to `df` data within each rolling window
            window:int=None, # Rolling window length; if None, uses 'expanding' without fixing lags 
            skipna: bool|None=False, # Use None if `func` does not take `skipna` arg.
            use_fast_lags: bool=False
            ) -> pd.Series:
    """Like `pd.DataFrame.rolling` but using robust `lag`s. 
    Run `df = setup_tseries(df)` or `df = setup_panel(df)` prior to using.

    With `window` set, lags 0..window-1 of the variable are built with `add_lags` and `func`
    is applied across them row by row; `skipna` is forwarded to `func` unless it is None.
    The result is a Series.

    With `window=None`, `func` is applied to an expanding window: within each panel entity
    for a MultiIndex, over the whole series otherwise. pandas expanding windows ignore
    missing observations and their methods take no `skipna` argument, so `skipna=False` is
    honoured by blanking every result from the first missing observation onwards (per entity).
    This branch returns the single-column DataFrame produced by pandas; for a panel its rows
    are ordered by panel id.

    Raises ValueError if `df` has more than one column.
    """

    if isinstance(df,pd.Series): df = df.to_frame()
    if len(df.columns) > 1: raise ValueError("`df` must have a single column")
    varname = df.columns[0]

    if window:
        lag_cols = [f'{varname}_lag{n}' for n in range(window)]
        # `add_lags` works on its own copy, so `df` is not modified
        lagged = add_lags(df, vars=varname, lags=range(window), use_fast_lags=use_fast_lags)[lag_cols]
        extra = {} if skipna is None else {'skipna': skipna}
        return getattr(lagged, func)(axis=1, **extra)

    is_panel = isinstance(df.index, pd.MultiIndex)
    if is_panel:
        # The grouped result carries the group key as an extra leading index level; drop it
        result = getattr(df.groupby(level=0).expanding(), func)().droplevel(0)
    else:
        # A plain time series has no panel dimension to group on
        result = getattr(df.expanding(), func)()

    if skipna is not None and not skipna:
        missing = df.isna().astype(int)
        seen_missing = (missing.groupby(level=0).cumsum() if is_panel else missing.cumsum()) > 0
        result = result.mask(seen_missing)
    return result


In [None]:
df.assign(rolling_A = rrolling(df['A'], func='mean', window=2, skipna=True))

Unnamed: 0_level_0,Unnamed: 1_level_0,date,dtdate,A,B,rolling_A
permno,Mdate,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
1,2010-01,2010-01,2010-01-01,0.104762,0.923778,0.104762
1,2010-02,2010-02,2010-02-01,0.561185,0.629761,0.332974
1,2010-04,2010-04,2010-04-01,0.452953,0.137185,0.452953
2,2010-01,2010-01,2010-01-01,0.481129,0.734026,0.481129
2,2010-02,2010-02,2010-02-01,0.627125,0.525601,0.554127
2,2010-04,2010-04,2010-04-01,0.413195,0.860653,0.413195


In [None]:
df.assign(rolling_A = rrolling(df['A'], func='mean', window=2, skipna=False))

Unnamed: 0_level_0,Unnamed: 1_level_0,date,dtdate,A,B,rolling_A
permno,Mdate,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
1,2010-01,2010-01,2010-01-01,0.104762,0.923778,
1,2010-02,2010-02,2010-02-01,0.561185,0.629761,0.332974
1,2010-04,2010-04,2010-04-01,0.452953,0.137185,
2,2010-01,2010-01,2010-01-01,0.481129,0.734026,
2,2010-02,2010-02,2010-02-01,0.627125,0.525601,0.554127
2,2010-04,2010-04,2010-04-01,0.413195,0.860653,


Test that it works for a time-series, not just a panel

In [None]:
df['A'].loc[1]

Mdate
2010-01    0.104762
2010-02    0.561185
2010-04    0.452953
Freq: M, Name: A, dtype: float64

In [None]:
rrolling(df['A'].loc[1], func='mean', window=2, skipna=True)

Mdate
2010-01    0.104762
2010-02    0.332974
2010-04    0.452953
Freq: M, dtype: float64

## Very common data transformations

In [None]:
#|export
def wins(df: pd.Series|pd.DataFrame, 
         low = 0.01, # Lower quantile at which to winsorize
         high = 0.99, # Upper quantile at which to winsorize
         byvars: List[str]=None # If None, quantiles use full sample, o/w they are calculate within each group given by `byvars`
         ) -> pd.DataFrame:
    """Winsorizes all columns in `df`.

    Values below the `low` quantile are raised to it and values above the `high` quantile
    are lowered to it. With `byvars` (column names or index level names), the quantiles are
    computed separately within each group; columns named in `byvars` are returned unchanged.
    The result keeps the index and row order of `df`; a single-column result is returned
    as a Series.
    """

    if isinstance(df,pd.Series): df = df.to_frame()
    if not byvars:
        # squeeze along columns only, so a one-row input is not collapsed or transposed
        return df.clip(df.quantile(low), df.quantile(high), axis=1).squeeze(axis=1)

    if isinstance(byvars, str): byvars = [byvars]
    value_cols = [c for c in df.columns if c not in byvars]
    out = df.copy()
    groups = df.groupby(byvars)
    for col in value_cols:
        # `transform` broadcasts each group's quantile back to the rows of that group,
        # giving row-aligned bounds for the clip
        lower = groups[col].transform(lambda s: s.quantile(low))
        upper = groups[col].transform(lambda s: s.quantile(high))
        out[col] = df[col].clip(lower=lower, upper=upper)
    return out.squeeze(axis=1)

In [None]:
#|export
def norm(df: pd.Series|pd.DataFrame, 
         divide_by_mean = False
         ) -> pd.DataFrame:
    """Demean every column of `df`, then scale it by its std. deviation (or by its mean if `divide_by_mean` is True)."""

    frame = df.to_frame() if isinstance(df, pd.Series) else df
    center = frame.mean()
    scale = center if divide_by_mean else frame.std()
    return (frame - center) / scale

## I/O

In [None]:
#| export
def to_stata(df: pd.DataFrame=None,
             outfile: str=None, # Output file path; must include .dta extension
             obj_drop: bool=False, # Whether to drop all columns of `object` type
             obj_to_str: bool=False, # Whether to convert all columns of `object` type to `string` type
             **to_stata_kwargs # Other kwargs to pass to `pd.to_stata`
             ):
    """Writes `df` to stata `outfile`.

    Works on a copy of `df`. Unless the index is the default 0..len(df)-1 range, it is
    written out as regular columns. `string` columns are written as text with missing
    values replaced by ''. `object` columns are dropped if `obj_drop`, converted to text if
    `obj_to_str` and all their non-missing values are strings, and otherwise left to pandas.
    Period columns and time-zone-aware datetime columns are dropped; other datetime columns
    are written as Stata `td` dates.
    """

    if df.index.equals(pd.RangeIndex(start=0, stop=len(df), step=1)): df = df.copy()
    else: df = df.reset_index()

    #Deal with `object` and `string` data types
    for v in list(df.columns):
        if isinstance(df[v].dtype, pd.StringDtype):
            df[v] = df[v].fillna('').astype(str)
        # `elif`: a `string` column that was just converted must not be re-examined as an
        # `object` column, otherwise `obj_drop=True` would drop it.
        elif df[v].dtype=='object':
            if obj_drop: df = df.drop(v, axis=1)
            elif obj_to_str and df[v].dropna().apply(lambda x: isinstance(x, str)).all():
                df[v] = df[v].fillna('').astype(str)

    #Deal with time data
    dates_to_td = {}
    for v in list(df.columns):
        dtype = df[v].dtype
        # Time-zone-aware columns have their own dtype (not plain datetime64), so test for it explicitly
        if str(dtype).startswith('period') or isinstance(dtype, pd.DatetimeTZDtype):
            df = df.drop(v, axis=1)
        elif pd.api.types.is_datetime64_dtype(dtype):
            dates_to_td[v] = 'td'
    
    df.to_stata(outfile, convert_dates=dates_to_td, write_index=False, **to_stata_kwargs)   

Index data is automatically included in the output, unless the index is the default range from 0 to len(df).

Columns of `object` data type are by default left to `pd.to_stata` to figure out how to convert. Note that columns of strings with missing data might cause an error in this default case. In that case, the best thing to do is to convert the column to `string` type yourself before calling this function. Alternatively, setting `obj_to_str` to True will deal with this internally, but it will be slower.

For time data, columns of `period` type and columns of `datetime64[ns]` type with time zone information will be dropped. All other `datetime64[ns]` variables will be converted to `td` dates in Stata. 

In [None]:
to_stata(df, outfile='../data/df.dta', version=117)

In [None]:
df = pd.read_stata('../data/df.dta')
df

Unnamed: 0,permno,date,dtdate,A,B
0,1,2010-01,2010-01-01,0.104762,0.923778
1,1,2010-02,2010-02-01,0.561185,0.629761
2,1,2010-04,2010-04-01,0.452953,0.137185
3,2,2010-01,2010-01-01,0.481129,0.734026
4,2,2010-02,2010-02-01,0.627125,0.525601
5,2,2010-04,2010-04-01,0.413195,0.860653


## Sorting and binning

In [None]:
#| export
def bins_using_masked_cutoffs(df: pd.DataFrame=None, #Dataframe containing `sortvar` and `maskvar`. Must have panelvar x datevar multiindex 
                       sortvar:str=None, #Variable containing the values to be binned
                       maskvar: str=None, #Mask to be applied to `df[sortvar]` before bin cutoffs are calculated
                       quantiles: list=None, #List of quantiles to be applied to df.loc[df[maskvar], sortvar] to determine bin cutoffs 
                       outvar:str=None #Name to give to the column of bins created. If none, will use f"{sortvar}_bins"
) ->pd.DataFrame:
    """Returns column of bin numbers (1 to len(`quantiles`)+1) created by binning `sortvar` based on cutoffs given by `quantiles` of `df.loc[df[maskvar], sortvar]`.

    Cutoffs are computed separately for every date from the rows where `maskvar` is True and
    are then applied to all rows of that date. Bin 1 holds values below the lowest cutoff and
    bin k+1 holds values at or above the k-th lowest cutoff. The bin is NaN where `sortvar`
    is missing or the date has no cutoffs. `quantiles` may be given in any order.
    The result is a one-column DataFrame indexed like `df`.
    """

    if outvar is None: outvar = f"{sortvar}_bins"

    # Bin numbers must increase with the cutoffs, and the cutoff columns are labelled
    # positionally below, so work with the quantiles in ascending order.
    quantiles = sorted(quantiles)

    (panelvar, datevar) = df.index.names
    df = df.reset_index()[[sortvar, maskvar, datevar, panelvar]].copy()

    # Get cutoffs every time period
    cutoffs = (df.loc[df[maskvar]]
                .groupby(datevar)[sortvar]
                .quantile(quantiles).to_frame().unstack() )              

    #Clean up cutoff dataset
    cnames = [sortvar + "_" + str(x) for x in quantiles]
    cutoffs.columns = cnames
    cutoffs = cutoffs.reset_index()    
    df = df.merge(cutoffs, how = "left", on = datevar)

    df[outvar] = np.nan #code for missing sortvar
    df.loc[(df[sortvar] < df[cnames[0]]) & df[sortvar].notna(), outvar] = 1 #first bin
    # Later (higher) cutoffs overwrite earlier ones, so each value ends in the highest bin it reaches
    for c in range(1,len(cnames)+1):
        df.loc[(df[sortvar] >= df[cnames[c-1]]) & df[sortvar].notna(), outvar] = c + 1
        
    return df.set_index([panelvar,datevar])[[outvar]].copy()

In [None]:
#| hide 
# Delete everything the examples above wrote to ../data, then recreate the placeholder file
for f in glob.glob('../data/*'): os.remove(f)
with open('../data/.gitkeep', 'w') as f: pass #empty file to force github to track the data folder

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()