diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 6f58f303..ed897ccc 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -39,8 +39,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] - python-version: [3.6, 3.7, 3.8, 3.9, "3.10"] - fail-fast: false + python-version: [3.7, 3.8, 3.9, "3.10"] env: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} steps: @@ -60,6 +59,7 @@ jobs: pip install codecov - name: Run flake8, pylint, mypy + if: matrix.python-version == '3.10' run: | flake8 cmdstanpy test pylint -v cmdstanpy test diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9c846751..c32c7de3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,7 +25,7 @@ repos: - id: mypy # Copied from setup.cfg exclude: ^test/ - additional_dependencies: [ numpy >= 1.21 ] + additional_dependencies: [ numpy >= 1.22, types-ujson ] # local uses the user-installed pylint, this allows dependency checking - repo: local hooks: diff --git a/cmdstanpy/cmdstan_args.py b/cmdstanpy/cmdstan_args.py index b339f7a5..f5553153 100644 --- a/cmdstanpy/cmdstan_args.py +++ b/cmdstanpy/cmdstan_args.py @@ -202,10 +202,7 @@ def validate(self, chains: Optional[int]) -> None: if all(isinstance(elem, dict) for elem in self.metric): metric_files: List[str] = [] for i, metric in enumerate(self.metric): - assert isinstance( - metric, dict - ) # make the typechecker happy - metric_dict: Dict[str, Any] = metric + metric_dict: Dict[str, Any] = metric # type: ignore if 'inv_metric' not in metric_dict: raise ValueError( 'Entry "inv_metric" not found in metric dict ' diff --git a/cmdstanpy/stanfit/mcmc.py b/cmdstanpy/stanfit/mcmc.py index 70259b8d..571e75c9 100644 --- a/cmdstanpy/stanfit/mcmc.py +++ b/cmdstanpy/stanfit/mcmc.py @@ -95,9 +95,9 @@ def __init__( self._save_warmup = sampler_args.save_warmup self._sig_figs = runset._args.sig_figs # info from CSV values, instantiated lazily - self._metric = np.array(()) - self._step_size = np.array(()) - self._draws = np.array(()) + self._metric: np.ndarray = np.array(()) + self._step_size: np.ndarray = np.array(()) + self._draws: np.ndarray = np.array(()) # info from CSV initial comments and header config = self._validate_csv_files() self._metadata: InferenceMetadata = InferenceMetadata(config) @@ -246,7 +246,7 @@ def draws( if concat_chains: return flatten_chains(self._draws[start_idx:, :, :]) - return self._draws[start_idx:, :, :] # type: ignore + return self._draws[start_idx:, :, :] def _validate_csv_files(self) -> Dict[str, Any]: """ @@ -675,9 +675,7 @@ def stan_variable( if len(col_idxs) > 0: dims.extend(self._metadata.stan_vars_dims[var]) # pylint: disable=redundant-keyword-arg - return self._draws[draw1:, :, col_idxs].reshape( # type: ignore - dims, order='F' - ) + return self._draws[draw1:, :, col_idxs].reshape(dims, order='F') def stan_variables(self) -> Dict[str, np.ndarray]: """ @@ -748,7 +746,7 @@ def __init__( ) self.runset = runset self.mcmc_sample = mcmc_sample - self._draws = np.array(()) + self._draws: np.ndarray = np.array(()) config = self._validate_csv_files() self._metadata = InferenceMetadata(config) @@ -765,7 +763,7 @@ def __repr__(self) -> str: ) return repr - def _validate_csv_files(self) -> dict: + def _validate_csv_files(self) -> Dict[str, Any]: """ Checks that Stan CSV output files for all chains are consistent and returns dict containing config and column names. @@ -910,13 +908,13 @@ def draws( if concat_chains: return flatten_chains(self._draws[start_idx:, :, :]) if inc_sample: - return np.dstack( # type: ignore + return np.dstack( ( np.delete(self.mcmc_sample.draws(), drop_cols, axis=1), self._draws, ) )[start_idx:, :, :] - return self._draws[start_idx:, :, :] # type: ignore + return self._draws[start_idx:, :, :] def draws_pd( self, @@ -1195,9 +1193,7 @@ def stan_variable( if len(col_idxs) > 0: dims.extend(self._metadata.stan_vars_dims[var]) # pylint: disable=redundant-keyword-arg - return self._draws[draw1:, :, col_idxs].reshape( # type: ignore - dims, order='F' - ) + return self._draws[draw1:, :, col_idxs].reshape(dims, order='F') def stan_variables(self, inc_warmup: bool = False) -> Dict[str, np.ndarray]: """ @@ -1229,7 +1225,7 @@ def _assemble_generated_quantities(self) -> None: # use numpy loadtxt warmup = self.mcmc_sample.metadata.cmdstan_config['save_warmup'] num_draws = self.mcmc_sample.draws(inc_warmup=warmup).shape[0] - gq_sample = np.empty( + gq_sample: np.ndarray = np.empty( (num_draws, self.chains, len(self.column_names)), dtype=float, order='F', diff --git a/cmdstanpy/stanfit/mle.py b/cmdstanpy/stanfit/mle.py index a6f9a86d..683f680a 100644 --- a/cmdstanpy/stanfit/mle.py +++ b/cmdstanpy/stanfit/mle.py @@ -54,13 +54,9 @@ def _set_mle_attrs(self, sample_csv_0: str) -> None: meta = scan_optimize_csv(sample_csv_0, self._save_iterations) self._metadata = InferenceMetadata(meta) self._column_names: Tuple[str, ...] = meta['column_names'] - assert isinstance(meta['mle'], np.ndarray) # make the typechecker happy - self._mle = meta['mle'] + self._mle: np.ndarray = meta['mle'] if self._save_iterations: - assert isinstance( - meta['all_iters'], np.ndarray - ) # make the typechecker happy - self._all_iters = meta['all_iters'] + self._all_iters: np.ndarray = meta['all_iters'] @property def column_names(self) -> Tuple[str, ...]: @@ -202,13 +198,13 @@ def stan_variable( num_rows = self._all_iters.shape[0] else: num_rows = 1 + + result: Union[np.ndarray, float] if len(col_idxs) > 1: # container var dims = (num_rows,) + self._metadata.stan_vars_dims[var] # pylint: disable=redundant-keyword-arg if num_rows > 1: - result = self._all_iters[:, col_idxs].reshape( # type: ignore - dims, order='F' - ) + result = self._all_iters[:, col_idxs].reshape(dims, order='F') else: result = self._mle[col_idxs].reshape(dims[1:], order="F") else: # scalar var @@ -217,9 +213,7 @@ def stan_variable( result = self._all_iters[:, col_idx] else: result = float(self._mle[col_idx]) - assert isinstance( - result, (np.ndarray, float) - ) # make the typechecker happy + return result def stan_variables( diff --git a/cmdstanpy/stanfit/vb.py b/cmdstanpy/stanfit/vb.py index f094a7d3..7819c77e 100644 --- a/cmdstanpy/stanfit/vb.py +++ b/cmdstanpy/stanfit/vb.py @@ -126,6 +126,7 @@ def stan_variable( raise ValueError('Unknown variable name: {}'.format(var)) col_idxs = list(self._metadata.stan_vars_cols[var]) shape: Tuple[int, ...] = () + result: Union[np.ndarray, float] if len(col_idxs) > 1: shape = self._metadata.stan_vars_dims[var] result = np.asarray(self._variational_mean)[col_idxs].reshape( @@ -133,7 +134,6 @@ def stan_variable( ) else: result = float(self._variational_mean[col_idxs[0]]) - assert isinstance(result, (np.ndarray, float)) return result def stan_variables(self) -> Dict[str, Union[np.ndarray, float]]: diff --git a/cmdstanpy/utils.py b/cmdstanpy/utils.py index afe5c080..ccdca660 100644 --- a/cmdstanpy/utils.py +++ b/cmdstanpy/utils.py @@ -639,7 +639,7 @@ def scan_optimize_csv(path: str, save_iters: bool = False) -> Dict[str, Any]: for line in fd: iters += 1 if save_iters: - all_iters = np.empty( + all_iters: np.ndarray = np.empty( (iters, len(dict['column_names'])), dtype=float, order='F' ) # rescan to capture estimates @@ -658,7 +658,7 @@ def scan_optimize_csv(path: str, save_iters: bool = False) -> Dict[str, Any]: if save_iters: all_iters[i, :] = [float(x) for x in xs] if i == iters - 1: - mle = np.array(xs, dtype=float) + mle: np.ndarray = np.array(xs, dtype=float) dict['mle'] = mle if save_iters: dict['all_iters'] = all_iters @@ -944,7 +944,7 @@ def read_metric(path: str) -> List[int]: with open(path, 'r') as fd: metric_dict = json.load(fd) if 'inv_metric' in metric_dict: - dims_np = np.asarray(metric_dict['inv_metric']) + dims_np: np.ndarray = np.asarray(metric_dict['inv_metric']) return list(dims_np.shape) else: raise ValueError( diff --git a/pyproject.toml b/pyproject.toml index 8bdfc60b..a1f4b649 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,16 +14,17 @@ line_length = 80 disallow_untyped_defs = true disallow_incomplete_defs = true no_implicit_optional = true -# disallow_any_generics = true # disabled due to issues with numpy < 1.20 +# disallow_any_generics = true # disabled due to issues with numpy warn_return_any = true # warn_unused_ignores = true # can't be run on CI due to windows having different ctypes +check_untyped_defs = true +warn_redundant_casts = true +strict_equality = true +disallow_untyped_calls = true [[tool.mypy.overrides]] module = [ 'tqdm.auto', 'pandas', - 'ujson', - 'numpy', # these two are required for py36, which numpy 1.21 doesn't support - 'numpy.random' ] ignore_missing_imports = true diff --git a/requirements-test.txt b/requirements-test.txt index 2506c7c1..de7b489b 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -7,3 +7,4 @@ mypy testfixtures tqdm xarray +types-ujson diff --git a/setup.py b/setup.py index efe16632..fb6dd683 100644 --- a/setup.py +++ b/setup.py @@ -80,6 +80,7 @@ def get_version() -> str: ] }, install_requires=INSTALL_REQUIRES, + python_requires='>=3.7', extras_require=EXTRAS_REQUIRE, classifiers=_classifiers.strip().split('\n'), )