|
4 | 4 | import os |
5 | 5 | from enum import Enum, auto |
6 | 6 | from time import time |
7 | | -from typing import Any, Dict, List, Optional, Union |
| 7 | +from typing import Any, Dict, List, Mapping, Optional, Union |
8 | 8 |
|
9 | 9 | import numpy as np |
10 | 10 | from numpy.random import RandomState |
@@ -712,7 +712,7 @@ def __init__( |
712 | 712 | method_args: Union[ |
713 | 713 | SamplerArgs, OptimizeArgs, GenerateQuantitiesArgs, VariationalArgs |
714 | 714 | ], |
715 | | - data: Union[str, Dict[str, Any], None] = None, |
| 715 | + data: Union[Mapping[str, Any], str, None] = None, |
716 | 716 | seed: Union[int, List[int], None] = None, |
717 | 717 | inits: Union[int, float, str, List[str], None] = None, |
718 | 718 | output_dir: Optional[str] = None, |
@@ -811,6 +811,7 @@ def validate(self) -> None: |
811 | 811 | 'Argument "sig_figs" must be an integer between 1 and 18,' |
812 | 812 | ' found {}'.format(self.sig_figs) |
813 | 813 | ) |
| 814 | + # TODO: remove at some future release |
814 | 815 | if cmdstan_version_before(2, 25): |
815 | 816 | self.sig_figs = None |
816 | 817 | get_logger().warning( |
@@ -897,7 +898,8 @@ def compose_command( |
897 | 898 | csv_file: str, |
898 | 899 | *, |
899 | 900 | diagnostic_file: Optional[str] = None, |
900 | | - profile_file: Optional[str] = None |
| 901 | + profile_file: Optional[str] = None, |
| 902 | + num_chains: Optional[int] = None |
901 | 903 | ) -> List[str]: |
902 | 904 | """ |
903 | 905 | Compose CmdStan command for non-default arguments. |
@@ -932,13 +934,15 @@ def compose_command( |
932 | 934 | cmd.append('init={}'.format(self.inits[idx])) |
933 | 935 | cmd.append('output') |
934 | 936 | cmd.append('file={}'.format(csv_file)) |
935 | | - if diagnostic_file is not None: |
| 937 | + if diagnostic_file: |
936 | 938 | cmd.append('diagnostic_file={}'.format(diagnostic_file)) |
937 | | - if profile_file is not None: |
| 939 | + if profile_file: |
938 | 940 | cmd.append('profile_file={}'.format(profile_file)) |
939 | 941 | if self.refresh is not None: |
940 | 942 | cmd.append('refresh={}'.format(self.refresh)) |
941 | 943 | if self.sig_figs is not None: |
942 | 944 | cmd.append('sig_figs={}'.format(self.sig_figs)) |
943 | 945 | cmd = self.method_args.compose(idx, cmd) |
| 946 | + if num_chains: |
| 947 | + cmd.append('num_chains={}'.format(num_chains)) |
944 | 948 | return cmd |
0 commit comments