Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: [3.6, 3.7, 3.8, 3.9, "3.10"]
fail-fast: false
python-version: [3.7, 3.8, 3.9, "3.10"]
env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
steps:
Expand All @@ -60,6 +59,7 @@ jobs:
pip install codecov

- name: Run flake8, pylint, mypy
if: matrix.python-version == '3.10'
run: |
flake8 cmdstanpy test
pylint -v cmdstanpy test
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ repos:
- id: mypy
# Copied from setup.cfg
exclude: ^test/
additional_dependencies: [ numpy >= 1.21 ]
additional_dependencies: [ numpy >= 1.22, types-ujson ]
# local uses the user-installed pylint, this allows dependency checking
- repo: local
hooks:
Expand Down
5 changes: 1 addition & 4 deletions cmdstanpy/cmdstan_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,7 @@ def validate(self, chains: Optional[int]) -> None:
if all(isinstance(elem, dict) for elem in self.metric):
metric_files: List[str] = []
for i, metric in enumerate(self.metric):
assert isinstance(
metric, dict
) # make the typechecker happy
metric_dict: Dict[str, Any] = metric
metric_dict: Dict[str, Any] = metric # type: ignore
if 'inv_metric' not in metric_dict:
raise ValueError(
'Entry "inv_metric" not found in metric dict '
Expand Down
26 changes: 11 additions & 15 deletions cmdstanpy/stanfit/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ def __init__(
self._save_warmup = sampler_args.save_warmup
self._sig_figs = runset._args.sig_figs
# info from CSV values, instantiated lazily
self._metric = np.array(())
self._step_size = np.array(())
self._draws = np.array(())
self._metric: np.ndarray = np.array(())
self._step_size: np.ndarray = np.array(())
self._draws: np.ndarray = np.array(())
# info from CSV initial comments and header
config = self._validate_csv_files()
self._metadata: InferenceMetadata = InferenceMetadata(config)
Expand Down Expand Up @@ -246,7 +246,7 @@ def draws(

if concat_chains:
return flatten_chains(self._draws[start_idx:, :, :])
return self._draws[start_idx:, :, :] # type: ignore
return self._draws[start_idx:, :, :]

def _validate_csv_files(self) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -675,9 +675,7 @@ def stan_variable(
if len(col_idxs) > 0:
dims.extend(self._metadata.stan_vars_dims[var])
# pylint: disable=redundant-keyword-arg
return self._draws[draw1:, :, col_idxs].reshape( # type: ignore
dims, order='F'
)
return self._draws[draw1:, :, col_idxs].reshape(dims, order='F')

def stan_variables(self) -> Dict[str, np.ndarray]:
"""
Expand Down Expand Up @@ -748,7 +746,7 @@ def __init__(
)
self.runset = runset
self.mcmc_sample = mcmc_sample
self._draws = np.array(())
self._draws: np.ndarray = np.array(())
config = self._validate_csv_files()
self._metadata = InferenceMetadata(config)

Expand All @@ -765,7 +763,7 @@ def __repr__(self) -> str:
)
return repr

def _validate_csv_files(self) -> dict:
def _validate_csv_files(self) -> Dict[str, Any]:
"""
Checks that Stan CSV output files for all chains are consistent
and returns dict containing config and column names.
Expand Down Expand Up @@ -910,13 +908,13 @@ def draws(
if concat_chains:
return flatten_chains(self._draws[start_idx:, :, :])
if inc_sample:
return np.dstack( # type: ignore
return np.dstack(
(
np.delete(self.mcmc_sample.draws(), drop_cols, axis=1),
self._draws,
)
)[start_idx:, :, :]
return self._draws[start_idx:, :, :] # type: ignore
return self._draws[start_idx:, :, :]

def draws_pd(
self,
Expand Down Expand Up @@ -1195,9 +1193,7 @@ def stan_variable(
if len(col_idxs) > 0:
dims.extend(self._metadata.stan_vars_dims[var])
# pylint: disable=redundant-keyword-arg
return self._draws[draw1:, :, col_idxs].reshape( # type: ignore
dims, order='F'
)
return self._draws[draw1:, :, col_idxs].reshape(dims, order='F')

def stan_variables(self, inc_warmup: bool = False) -> Dict[str, np.ndarray]:
"""
Expand Down Expand Up @@ -1229,7 +1225,7 @@ def _assemble_generated_quantities(self) -> None:
# use numpy loadtxt
warmup = self.mcmc_sample.metadata.cmdstan_config['save_warmup']
num_draws = self.mcmc_sample.draws(inc_warmup=warmup).shape[0]
gq_sample = np.empty(
gq_sample: np.ndarray = np.empty(
(num_draws, self.chains, len(self.column_names)),
dtype=float,
order='F',
Expand Down
18 changes: 6 additions & 12 deletions cmdstanpy/stanfit/mle.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,9 @@ def _set_mle_attrs(self, sample_csv_0: str) -> None:
meta = scan_optimize_csv(sample_csv_0, self._save_iterations)
self._metadata = InferenceMetadata(meta)
self._column_names: Tuple[str, ...] = meta['column_names']
assert isinstance(meta['mle'], np.ndarray) # make the typechecker happy
self._mle = meta['mle']
self._mle: np.ndarray = meta['mle']
if self._save_iterations:
assert isinstance(
meta['all_iters'], np.ndarray
) # make the typechecker happy
self._all_iters = meta['all_iters']
self._all_iters: np.ndarray = meta['all_iters']

@property
def column_names(self) -> Tuple[str, ...]:
Expand Down Expand Up @@ -202,13 +198,13 @@ def stan_variable(
num_rows = self._all_iters.shape[0]
else:
num_rows = 1

result: Union[np.ndarray, float]
if len(col_idxs) > 1: # container var
dims = (num_rows,) + self._metadata.stan_vars_dims[var]
# pylint: disable=redundant-keyword-arg
if num_rows > 1:
result = self._all_iters[:, col_idxs].reshape( # type: ignore
dims, order='F'
)
result = self._all_iters[:, col_idxs].reshape(dims, order='F')
else:
result = self._mle[col_idxs].reshape(dims[1:], order="F")
else: # scalar var
Expand All @@ -217,9 +213,7 @@ def stan_variable(
result = self._all_iters[:, col_idx]
else:
result = float(self._mle[col_idx])
assert isinstance(
result, (np.ndarray, float)
) # make the typechecker happy

return result

def stan_variables(
Expand Down
2 changes: 1 addition & 1 deletion cmdstanpy/stanfit/vb.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,14 @@ def stan_variable(
raise ValueError('Unknown variable name: {}'.format(var))
col_idxs = list(self._metadata.stan_vars_cols[var])
shape: Tuple[int, ...] = ()
result: Union[np.ndarray, float]
if len(col_idxs) > 1:
shape = self._metadata.stan_vars_dims[var]
result = np.asarray(self._variational_mean)[col_idxs].reshape(
shape, order="F"
)
else:
result = float(self._variational_mean[col_idxs[0]])
assert isinstance(result, (np.ndarray, float))
return result

def stan_variables(self) -> Dict[str, Union[np.ndarray, float]]:
Expand Down
6 changes: 3 additions & 3 deletions cmdstanpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ def scan_optimize_csv(path: str, save_iters: bool = False) -> Dict[str, Any]:
for line in fd:
iters += 1
if save_iters:
all_iters = np.empty(
all_iters: np.ndarray = np.empty(
(iters, len(dict['column_names'])), dtype=float, order='F'
)
# rescan to capture estimates
Expand All @@ -658,7 +658,7 @@ def scan_optimize_csv(path: str, save_iters: bool = False) -> Dict[str, Any]:
if save_iters:
all_iters[i, :] = [float(x) for x in xs]
if i == iters - 1:
mle = np.array(xs, dtype=float)
mle: np.ndarray = np.array(xs, dtype=float)
dict['mle'] = mle
if save_iters:
dict['all_iters'] = all_iters
Expand Down Expand Up @@ -944,7 +944,7 @@ def read_metric(path: str) -> List[int]:
with open(path, 'r') as fd:
metric_dict = json.load(fd)
if 'inv_metric' in metric_dict:
dims_np = np.asarray(metric_dict['inv_metric'])
dims_np: np.ndarray = np.asarray(metric_dict['inv_metric'])
return list(dims_np.shape)
else:
raise ValueError(
Expand Down
9 changes: 5 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,17 @@ line_length = 80
disallow_untyped_defs = true
disallow_incomplete_defs = true
no_implicit_optional = true
# disallow_any_generics = true # disabled due to issues with numpy < 1.20
# disallow_any_generics = true # disabled due to issues with numpy
warn_return_any = true
# warn_unused_ignores = true # can't be run on CI due to windows having different ctypes
check_untyped_defs = true
warn_redundant_casts = true
strict_equality = true
disallow_untyped_calls = true

[[tool.mypy.overrides]]
module = [
'tqdm.auto',
'pandas',
'ujson',
'numpy', # these two are required for py36, which numpy 1.21 doesn't support
'numpy.random'
]
ignore_missing_imports = true
1 change: 1 addition & 0 deletions requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ mypy
testfixtures
tqdm
xarray
types-ujson
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def get_version() -> str:
]
},
install_requires=INSTALL_REQUIRES,
python_requires='>=3.7',
extras_require=EXTRAS_REQUIRE,
classifiers=_classifiers.strip().split('\n'),
)