From f63daca07b059081ef05385b491c261c74ad9843 Mon Sep 17 00:00:00 2001
From: tlpss
Date: Fri, 25 Aug 2023 09:14:26 +0200
Subject: [PATCH] add some print statements to dataset splitting

---
 keypoint_detection/data/datamodule.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/keypoint_detection/data/datamodule.py b/keypoint_detection/data/datamodule.py
index 302dae1..acd53f7 100644
--- a/keypoint_detection/data/datamodule.py
+++ b/keypoint_detection/data/datamodule.py
@@ -55,7 +55,7 @@ def __init__(
         json_validation_dataset_path: str = None,
         json_test_dataset_path=None,
         augment_train: bool = False,
-        **kwargs
+        **kwargs,
     ):
         super().__init__()
         self.batch_size = batch_size
@@ -72,6 +72,7 @@ def __init__(
                 json_validation_dataset_path, keypoint_channel_configuration, **kwargs
             )
         else:
+            print(f"splitting the train set to create a validation set with ratio {validation_split_ratio} ")
             self.train_dataset, self.validation_dataset = KeypointsDataModule._split_dataset(
                 self.train_dataset, validation_split_ratio
             )
@@ -104,6 +105,8 @@ def _split_dataset(dataset, validation_split_ratio):
         validation_size = int(validation_split_ratio * len(dataset))
         train_size = len(dataset) - validation_size
         train_dataset, validation_dataset = torch.utils.data.random_split(dataset, [train_size, validation_size])
+        print(f"train size: {len(train_dataset)}")
+        print(f"validation size: {len(validation_dataset)}")
         return train_dataset, validation_dataset
 
     def train_dataloader(self):
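
For reference, below is a minimal, standalone sketch of the split that the new print statements report on. It only assumes torch is installed; the toy TensorDataset and the 0.25 ratio are illustrative and not part of the patch.

    # Sketch of the size arithmetic used by _split_dataset (illustrative values).
    import torch
    from torch.utils.data import TensorDataset, random_split

    dataset = TensorDataset(torch.arange(100).unsqueeze(1))  # 100 dummy samples
    validation_split_ratio = 0.25                            # illustrative ratio

    validation_size = int(validation_split_ratio * len(dataset))  # 25
    train_size = len(dataset) - validation_size                   # 75
    train_dataset, validation_dataset = random_split(dataset, [train_size, validation_size])

    print(f"splitting the train set to create a validation set with ratio {validation_split_ratio}")
    print(f"train size: {len(train_dataset)}")            # 75
    print(f"validation size: {len(validation_dataset)}")  # 25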