Remove deprecated param calls
sdrobert committed Jan 30, 2024
1 parent 6c59d58 commit 1fca4d4
Showing 2 changed files with 70 additions and 36 deletions.
2 changes: 1 addition & 1 deletion src/pydrobert/torch/_datasets.py
@@ -30,7 +30,7 @@ class LangDataParams(param.Parameterized):

     subset_ids = param.List(
         [],
-        class_=str,
+        item_type=str,
         bounds=None,
         doc="A list of utterance ids. If non-empty, the data set will be "
         "restricted to these utterances",
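
For context: recent versions of the param library deprecate param.List's class_ keyword in favor of item_type for constraining element types, which is all this file's change does. A minimal sketch of the new spelling, assuming param >= 2.0 (the class name Config is illustrative, not from this commit):

    import param

    class Config(param.Parameterized):
        # item_type replaces the deprecated class_ keyword
        subset_ids = param.List([], item_type=str, doc="A list of utterance ids")

    cfg = Config(subset_ids=["utt1", "utt2"])   # validates: all items are str
    # Config(subset_ids=[1, 2]) would fail item-type validation
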
104 changes: 69 additions & 35 deletions src/pydrobert/torch/_pl_data.py
@@ -49,7 +49,7 @@ class PosixPath(param.Parameter):
     default
         Default value
     always_exists
-        If :obj:`True`, setting to
+        If :obj:`True`, setting to
     """

     __slots__ = "always_exists", "type"
@@ -104,11 +104,11 @@ class LitDataModuleParamsMetaclass(pabc.AbstractParameterizedMetaclass):
     def __init__(mcs: "LitDataModuleParams", name, bases, dict_):
         pclass = dict_["pclass"]
         super().__init__(name, bases, dict_)
-        mcs.param.params()["common"].class_ = pclass
-        mcs.param.params()["train"].class_ = pclass
-        mcs.param.params()["val"].class_ = pclass
-        mcs.param.params()["test"].class_ = pclass
-        mcs.param.params()["predict"].class_ = pclass
+        mcs.param["common"].class_ = pclass
+        mcs.param["train"].class_ = pclass
+        mcs.param["val"].class_ = pclass
+        mcs.param["test"].class_ = pclass
+        mcs.param["predict"].class_ = pclass


 class LitDataModuleParams(
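
For context: param also deprecated the .params() method; individual Parameter objects are now fetched by indexing the .param namespace directly, as the hunk above does. A minimal sketch, assuming param >= 1.12 (class P is illustrative):

    import param

    class P(param.Parameterized):
        x = param.Number(0.0, bounds=(0.0, 1.0))

    # deprecated: P.param.params()["x"].bounds = (0.0, 2.0)
    # current:    index the namespace; this mutates the shared Parameter object
    P.param["x"].bounds = (0.0, 2.0)
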
@@ -121,28 +121,28 @@ class LitDataModuleParams(
     prefer_split: bool = param.Boolean(True)

     common: Optional[P] = param.ClassSelector(
-        param.Parameterized,
+        class_=param.Parameterized,
         instantiate=False,
         doc="Common data loader parameters. If set, cannot instantiate train, val, "
         "test, or predict",
     )
     train: Optional[P] = param.ClassSelector(
-        param.Parameterized,
+        class_=param.Parameterized,
         instantiate=False,
         doc="Training data loader parameters. If set, cannot instantiate common",
     )
     val: Optional[P] = param.ClassSelector(
-        param.Parameterized,
+        class_=param.Parameterized,
         instantiate=False,
         doc="Validation data loader parameters. If set, cannot instantiate common",
     )
     test: Optional[P] = param.ClassSelector(
-        param.Parameterized,
+        class_=param.Parameterized,
         instantiate=False,
         doc="Test data loader parameters. If set, cannot instantiate common",
     )
     predict: Optional[P] = param.ClassSelector(
-        param.Parameterized,
+        class_=param.Parameterized,
         instantiate=False,
         doc="Prediction data loader parameters. If set, cannot instantiate common",
     )
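
For context: param 2.0 makes most Parameter constructor arguments keyword-only, so the class a ClassSelector selects over can no longer be passed positionally; it must be spelled class_=. A sketch, assuming param >= 2.0 (Inner and Outer are illustrative names):

    import param

    class Inner(param.Parameterized):
        pass

    class Outer(param.Parameterized):
        # deprecated: param.ClassSelector(Inner, instantiate=False)
        inner = param.ClassSelector(class_=Inner, instantiate=False)
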
@@ -320,18 +320,26 @@ def num_workers(self) -> Optional[int]:
     @property
     def pin_memory(self) -> Optional[bool]:
         """Optional[bool]: whether to pin memory to the cuda device

         If initially unset, the value will be populated during :func:`setup` based on
         whether the trainer's accelerator is on the GPU.
         """
         return self._pin_memory

     @abc.abstractmethod
-    def construct_dataset(self, partition: Partition, path: str, params: P,) -> DS:
+    def construct_dataset(
+        self,
+        partition: Partition,
+        path: str,
+        params: P,
+    ) -> DS:
         ...

     def _construct_dataset_with_checks(
-        self, partition: Partition, path: Optional[str], params: Optional[P],
+        self,
+        partition: Partition,
+        path: Optional[str],
+        params: Optional[P],
     ) -> DS:
         if path is None:
             raise ValueError(f"Cannot construct {partition} dataset: no data directory")
@@ -342,7 +350,6 @@ def _construct_dataset_with_checks(
         return self.construct_dataset(partition, path, params)

     def setup(self, stage: Optional[str] = None):
-
         if self._num_workers is None:
             if self.trainer is not None and isinstance(
                 self.trainer.strategy, pl.strategies.DDPSpawnStrategy
@@ -354,23 +361,30 @@
         if self._pin_memory is None:
             if self.trainer is not None:
                 self._pin_memory = isinstance(
-                    self.trainer.accelerator, pl.accelerators.CUDAAccelerator,
+                    self.trainer.accelerator,
+                    pl.accelerators.CUDAAccelerator,
                 )
             else:
                 self._pin_memory = True

         if stage in {"fit", None}:
             self.train_set = self._construct_dataset_with_checks(
-                "train", self.params.train_dir, self.params.train_params,
+                "train",
+                self.params.train_dir,
+                self.params.train_params,
             )
             if self.params.val_dir is not None:
                 self.val_set = self._construct_dataset_with_checks(
-                    "val", self.params.val_dir, self.params.val_params,
+                    "val",
+                    self.params.val_dir,
+                    self.params.val_params,
                 )

         if stage in {"test", None}:
             self.test_set = self._construct_dataset_with_checks(
-                "test", self.params.test_dir, self.params.test_params,
+                "test",
+                self.params.test_dir,
+                self.params.test_params,
             )

         if stage in {"predict", None}:
@@ -391,7 +405,10 @@ def construct_dataloader(self, partition: Partition, ds: DS, params: P) -> DL:
         ...

     def _construct_dataloader_with_checks(
-        self, partition: Partition, ds: Optional[DS], params: Optional[P],
+        self,
+        partition: Partition,
+        ds: Optional[DS],
+        params: Optional[P],
     ) -> DL:
         if params is None:
             raise ValueError(
@@ -420,15 +437,19 @@ def dev_dataloader(self) -> DL:

     def test_dataloader(self) -> DL:
         return self._construct_dataloader_with_checks(
-            "test", self.test_set, self.params.test_params,
+            "test",
+            self.test_set,
+            self.params.test_params,
         )

     def predict_dataloader(self) -> DL:
         params = self.params.predict_params
         if params is None:
             params = self.params.test_params
         return self._construct_dataloader_with_checks(
-            "predict", self.predict_set, params,
+            "predict",
+            self.predict_set,
+            params,
         )

     @classmethod
@@ -450,7 +471,11 @@ def add_argparse_args(
         )

         grp = pargparse.add_deserialization_group_to_parser(
-            parser, pobj, "data_params", reckless=True, flag_format_str=read_format_str,
+            parser,
+            pobj,
+            "data_params",
+            reckless=True,
+            flag_format_str=read_format_str,
         )

         if include_overloads:
@@ -488,7 +513,9 @@ def add_argparse_args(

     @classmethod
     def from_argparse_args(
-        cls, namespace: argparse.Namespace, **kwargs,
+        cls,
+        namespace: argparse.Namespace,
+        **kwargs,
     ):
         data_params = namespace.data_params
         data_params.initialize_missing()
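
Together, add_argparse_args and from_argparse_args form a deserialize-then-construct round trip: the former registers flags that read data_params from file, the latter pulls the deserialized object off the namespace, fills unset members via initialize_missing, and builds the module. A hypothetical sketch (MyDataModule and the flag name are illustrative; actual flags depend on read_format_str):

    import argparse

    parser = argparse.ArgumentParser()
    MyDataModule.add_argparse_args(parser)
    ns = parser.parse_args(["--read-data-params-json", "params.json"])  # flag name assumed
    dm = MyDataModule.from_argparse_args(ns)
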
@@ -506,7 +533,9 @@ class LitLangDataModuleParams(LitDataModuleParams[LangDataLoaderParams]):
     pclass = LangDataLoaderParams

     vocab_size: Optional[int] = param.Integer(
-        None, bounds=(1, None), doc="Vocabulary size",
+        None,
+        bounds=(1, None),
+        doc="Vocabulary size",
     )
     info_path: Optional[str] = PosixPath(
         None,
@@ -626,15 +655,15 @@ def get_info_dict_value(
     @property
     def vocab_size(self) -> Optional[int]:
         """int : vocabulary size

         Alias of ``max_ref_class + 1``.
         """
         return None if self.max_ref_class is None else self.max_ref_class + 1

     @property
     def batch_size(self) -> int:
         """int : training batch size

         This property is just the value of ``self.params.train_params.batch_size``.
         It is exposed in case ``auto_scale_batch_size`` is desired.
         """
@@ -647,36 +676,39 @@ def batch_size(self, batch_size: int):
     @property
     def feat_size(self) -> Optional[int]:
         """int : feature vector size

         Alias of `num_filts`.
         """
         return self.num_filts

     @property
     def max_ref_class(self) -> Optional[int]:
         """The maximum token id in the ref/ subdirectory (usually of training)

-        Corresponds to the
+        Corresponds to the
         """
         return self.get_info_dict_value("max_ref_class")

     def max_ali_class(self) -> Optional[int]:
         """int: the maximum token id in the ali/ subdirectory (usually of training)

         Determined in :func:`setup` if `params.info_path` is not :obj:`None`.
         """
         return self.get_info_dict_value("max_ali_class")

     @property
     def num_filts(self) -> Optional[int]:
         """int : size of the last dimension of tensors in feat/

         Determined in :func:`setup` if `params.info_path` is not :obj:`None`.
         """
         return None if self._info_dict is None else self._info_dict["num_filts"]

     def construct_dataset(
-        self, partition: Partition, path: str, params: SpectDataLoaderParams,
+        self,
+        partition: Partition,
+        path: str,
+        params: SpectDataLoaderParams,
     ) -> SpectDataSet:
         suppress_uttids = self.suppress_uttids
         if suppress_uttids is None:
@@ -693,7 +725,6 @@ def construct_dataset(
         )

     def setup(self, stage: Optional[str] = None):
-
         if self.params.info_path is not None and self._info_dict is None:
             self._info_dict = dict()
             with open(self.params.info_path) as f:
@@ -719,7 +750,10 @@ def setup(self, stage: Optional[str] = None):
         super().setup(stage)

     def construct_dataloader(
-        self, partition: Partition, ds: SpectDataSet, params: SpectDataLoaderParams,
+        self,
+        partition: Partition,
+        ds: SpectDataSet,
+        params: SpectDataLoaderParams,
     ) -> SpectDataLoader:
         shuffle = self.shuffle
         if shuffle is None:
