Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor formatting changes (typehints) #149

Merged
merged 5 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ numpy==1.23.2
openpyxl==3.1.2
pandas==2.0.0
pytest-cov==3.0.0
scipy==1.9.0
statsmodels==0.13.2
tabulate==0.8.10
scipy==1.10.1
statsmodels==0.13.5
tabulate==0.9.0
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@
install_requires=[
'numpy>=1.19.1',
'pandas>=1.4.3',
'scipy>=1.7.0',
'statsmodels>=0.12.1',
'tabulate>=0.8.10',
'scipy>=1.10.1',
'statsmodels>=0.13.5',
'tabulate>=0.9.0',
'Jinja2==3.1.2',
'openpyxl==3.1.2'
],
Expand Down
76 changes: 46 additions & 30 deletions tableone/tableone.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
The tableone package is used for creating "Table 1" summary statistics for
research papers.
"""
from typing import Optional, Union
from typing import Optional, Tuple, Union
import warnings

import numpy as np
Expand All @@ -18,7 +18,7 @@
warnings.simplefilter('always', DeprecationWarning)


def load_dataset(name: str):
def load_dataset(name: str) -> pd.DataFrame:
"""
Load an example dataset from the online repository (requires internet).

Expand Down Expand Up @@ -60,7 +60,7 @@ class InputError(Exception):
pass


class TableOne(object):
class TableOne:
"""

If you use the tableone package, please cite:
Expand Down Expand Up @@ -200,7 +200,8 @@ class TableOne(object):

...
"""
def __init__(self, data: pd.DataFrame, columns: Optional[list] = None,
def __init__(self, data: pd.DataFrame,
columns: Optional[list] = None,
categorical: Optional[list] = None,
groupby: Optional[str] = None,
nonnormal: Optional[list] = None,
Expand Down Expand Up @@ -397,20 +398,23 @@ def __init__(self, data: pd.DataFrame, columns: Optional[list] = None,

# create overall tables if required
if self._categorical and self._groupby and overall:
self.cat_describe_all = self._create_cat_describe(data, False,
['Overall'])
self.cat_describe_all = self._create_cat_describe(data=data,
groupby=None,
groupbylvls=['Overall'])

if self._continuous and self._groupby and overall:
self.cont_describe_all = self._create_cont_describe(data, False)
self.cont_describe_all = self._create_cont_describe(data=data,
groupby=None)

# create descriptive tables
if self._categorical:
self.cat_describe = self._create_cat_describe(data, self._groupby,
self._groupbylvls)
self.cat_describe = self._create_cat_describe(data=data,
groupby=self._groupby,
groupbylvls=self._groupbylvls)

if self._continuous:
self.cont_describe = self._create_cont_describe(data,
self._groupby)
self.cont_describe = self._create_cont_describe(data=data,
groupby=self._groupby)

# compute standardized mean differences
if self._smd:
Expand Down Expand Up @@ -439,13 +443,13 @@ def __init__(self, data: pd.DataFrame, columns: Optional[list] = None,
if display_all:
self._set_display_options()

def __str__(self):
def __str__(self) -> str:
return self.tableone.to_string() + self._generate_remarks('\n')

def __repr__(self):
def __repr__(self) -> str:
return self.tableone.to_string() + self._generate_remarks('\n')

def _repr_html_(self):
def _repr_html_(self) -> str:
return self.tableone._repr_html_() + self._generate_remarks('<br />')

def _set_display_options(self):
Expand All @@ -465,7 +469,7 @@ def _set_display_options(self):
option.""".format(k)
warnings.warn(msg)

def tabulate(self, headers=None, tablefmt='grid', **kwargs):
def tabulate(self, headers=None, tablefmt='grid', **kwargs) -> str:
"""
Pretty-print tableone data. Wrapper for the Python 'tabulate' library.

Expand Down Expand Up @@ -500,7 +504,7 @@ def tabulate(self, headers=None, tablefmt='grid', **kwargs):

return tabulate(df, headers=headers, tablefmt=tablefmt, **kwargs)

def _generate_remarks(self, newline='\n'):
def _generate_remarks(self, newline='\n') -> str:
"""
Generate a series of remarks that the user should consider
when interpreting the summary statistics.
Expand Down Expand Up @@ -546,7 +550,7 @@ def _generate_remarks(self, newline='\n'):

return msg

def _detect_categorical_columns(self, data):
def _detect_categorical_columns(self, data) -> list:
"""
Detect categorical columns if they are not specified.

Expand Down Expand Up @@ -783,7 +787,7 @@ def _normality(self, x):

def _tukey(self, x, threshold):
"""
Count outliers according to Tukey's rule.
Find outliers according to Tukey's rule.

Where Q1 is the lower quartile and Q3 is the upper quartile,
an outlier is an observation outside of the range:
Expand All @@ -806,21 +810,21 @@ def _tukey(self, x, threshold):

return outliers

def _outliers(self, x):
def _outliers(self, x) -> int:
"""
Compute number of outliers
"""
outliers = self._tukey(x, threshold=1.5)
return np.size(outliers)

def _far_outliers(self, x):
def _far_outliers(self, x) -> int:
"""
Compute number of "far out" outliers
"""
outliers = self._tukey(x, threshold=3.0)
return np.size(outliers)

def _t1_summary(self, x):
def _t1_summary(self, x: pd.Series) -> str:
"""
Compute median [IQR] or mean (Std) for the input series.

Expand Down Expand Up @@ -867,7 +871,9 @@ def _t1_summary(self, x):
f = '{{:.{}f}} ({{:.{}f}})'.format(n, n)
return f.format(np.nanmean(x.values), self._std(x))

def _create_cont_describe(self, data, groupby):
def _create_cont_describe(self,
data: pd.DataFrame,
groupby: Optional[str] = None) -> pd.DataFrame:
"""
Describe the continuous data.

Expand Down Expand Up @@ -937,7 +943,10 @@ def _create_cont_describe(self, data, groupby):

return df_cont

def _format_cat(self, row, col):
def _format_cat(self, row, col) -> str:
"""
Format values to n decimal places.
"""
var = row.name[0]
if var in self._decimals:
n = self._decimals[var]
Expand All @@ -946,7 +955,9 @@ def _format_cat(self, row, col):
f = '{{:.{}f}}'.format(n)
return f.format(row[col])

def _create_cat_describe(self, data, groupby, groupbylvls):
def _create_cat_describe(self, data: pd.DataFrame,
groupby: Optional[str] = None,
groupbylvls: Optional[list] = None) -> pd.DataFrame:
"""
Describe the categorical data.

Expand Down Expand Up @@ -1054,7 +1065,7 @@ def _create_cat_describe(self, data, groupby, groupbylvls):

return df_cat

def _create_htest_table(self, data):
def _create_htest_table(self, data: pd.DataFrame) -> pd.DataFrame:
"""
Create a table containing P-Values for significance tests. Add features
of the distributions and the P-Values to the dataframe.
Expand Down Expand Up @@ -1119,7 +1130,7 @@ def _create_htest_table(self, data):

return df

def _create_smd_table(self, data):
def _create_smd_table(self, data: pd.DataFrame) -> pd.DataFrame:
"""
Create a table containing pairwise Standardized Mean Differences
(SMDs).
Expand Down Expand Up @@ -1180,8 +1191,13 @@ def _create_smd_table(self, data):

return df

def _p_test(self, v, grouped_data, is_continuous, is_categorical,
is_normal, min_observed, catlevels):
def _p_test(self, v: str,
grouped_data: dict,
is_continuous: bool,
is_categorical: bool,
is_normal: bool,
min_observed: int,
catlevels: list):
"""
Compute P-Values.

Expand Down Expand Up @@ -1267,7 +1283,7 @@ def _p_test(self, v, grouped_data, is_continuous, is_categorical,

return pval, ptest

def _create_cont_table(self, data, overall):
def _create_cont_table(self, data, overall) -> pd.DataFrame:
"""
Create tableone for continuous data.

Expand Down Expand Up @@ -1582,7 +1598,7 @@ def _create_tableone(self, data):

return table

def _create_row_labels(self):
def _create_row_labels(self) -> dict:
"""
Take the original labels for rows. Rename if alternative labels are
provided. Append label suffix if label_suffix is True.
Expand Down