def __len__(self):
    """Return the number of samples this process will receive from the shard.

    The length is counted in whole per-process batches: the dataset is
    consumed in global batches of ``batch_size * num_processes`` samples and
    each process keeps ``batch_size`` of them.

    NOTE(review): this assumes ``__iter__`` only ever emits complete
    per-process batches (dropping the trailing partial global batch when
    ``drop_last`` is true, completing it otherwise) -- confirm against the
    full ``__iter__`` implementation, which is only partly visible here.

    Returns:
        int: Number of samples yielded to this process.

    Raises:
        TypeError: If the underlying dataset is not sized (has no ``__len__``).
    """
    # Will raise an error if the underlying dataset is not sized.
    total_samples = len(self.dataset)

    # Fall back to a batch size of 1 when the attribute is absent; in that
    # case the result is simply the dataset length split across processes.
    per_process_batch = getattr(self, "batch_size", 1)
    global_batch = per_process_batch * self.num_processes

    if self.drop_last:
        # Only complete global batches are served.
        num_batches = total_samples // global_batch
    else:
        # Ceiling division in pure integer arithmetic: exact for any size,
        # unlike float division followed by math.ceil.
        num_batches = -(-total_samples // global_batch)

    return num_batches * per_process_batch