Skip to content

Commit

Permalink
add some print statements to dataset splitting
Browse files Browse the repository at this point in the history
  • Loading branch information
tlpss committed Aug 25, 2023
1 parent 2a75411 commit f63daca
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion keypoint_detection/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(
json_validation_dataset_path: str = None,
json_test_dataset_path=None,
augment_train: bool = False,
**kwargs
**kwargs,
):
super().__init__()
self.batch_size = batch_size
Expand All @@ -72,6 +72,7 @@ def __init__(
json_validation_dataset_path, keypoint_channel_configuration, **kwargs
)
else:
print(f"splitting the train set to create a validation set with ratio {validation_split_ratio} ")
self.train_dataset, self.validation_dataset = KeypointsDataModule._split_dataset(
self.train_dataset, validation_split_ratio
)
Expand Down Expand Up @@ -104,6 +105,8 @@ def _split_dataset(dataset, validation_split_ratio):
validation_size = int(validation_split_ratio * len(dataset))
train_size = len(dataset) - validation_size
train_dataset, validation_dataset = torch.utils.data.random_split(dataset, [train_size, validation_size])
print(f"train size: {len(train_dataset)}")
print(f"validation size: {len(validation_dataset)}")
return train_dataset, validation_dataset

def train_dataloader(self):
Expand Down

0 comments on commit f63daca

Please sign in to comment.