torchtrain/config_manager.py (6 additions, 0 deletions)

@@ -249,6 +249,12 @@ def init_args_from_command_line(
         ],  # TODO: add "delayed" option back in when supported
         help="Type of fp8 linear quantization to apply to the model",
     )
+    parser.add_argument(
+        "--training.gc_freq",
+        type=int,
+        default=50,
+        help="Python garbage collection scheduling interval, in steps",
+    )
 
     # activation checkpointing
     parser.add_argument(
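The new flag follows torchtrain's dotted --section.option naming, so the value surfaces as job_config.training.gc_freq. As a hedged illustration only (not torchtrain's actual JobConfig code, whose internals may differ), the sketch below shows how such a dotted flag can be parsed with argparse and grouped into a two-level config:

import argparse
from collections import defaultdict

# Hypothetical stand-in for mapping a dotted flag onto a nested config;
# torchtrain's real JobConfig may implement this differently.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--training.gc_freq",
    type=int,
    default=50,
    help="Python garbage collection scheduling interval, in steps",
)
args = parser.parse_args(["--training.gc_freq", "100"])

# Group parsed flags into sections by splitting "section.option" keys.
config = defaultdict(dict)
for key, value in vars(args).items():
    section, option = key.split(".", 1)
    config[section][option] = value

print(config["training"]["gc_freq"])  # -> 100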
train.py (9 additions, 0 deletions)

@@ -2,6 +2,7 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import contextlib
+import gc
 import os
 
 from dataclasses import dataclass, field
@@ -101,6 +102,11 @@ def main(job_config: JobConfig):
     init_logger()
     logger.info(f"Starting job: {job_config.job.description}")
 
+    # take control of garbage collection to avoid stragglers
+    _gc_freq = job_config.training.gc_freq
+    gc.disable()
+    gc.collect(1)
+
     # init world mesh
     world_size = int(os.environ["WORLD_SIZE"])
     parallel_dims = ParallelDims(
@@ -231,6 +237,9 @@ def main(job_config: JobConfig):
 
     while train_state.step < job_config.training.steps:
         train_state.step += 1
+        if train_state.step > 1 and train_state.step % _gc_freq == 0:
+            gc.collect(1)
+
         # get batch
         data_load_start = timer()
         batch = next(data_iterator)
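For context on why the PR does this: with automatic garbage collection enabled, each rank's collector fires at unpredictable times, so one rank pausing mid-step stalls every other rank at the next collective and shows up as a straggler. Disabling automatic GC and collecting on a fixed step schedule keeps the pauses synchronized across ranks. Below is a minimal self-contained sketch of the pattern; the loop and run_training_step are illustrative stand-ins, not torchtrain code:

import gc
import time

def run_training_step(step: int) -> None:
    # Stand-in workload that allocates (and drops) some objects each step.
    _ = [object() for _ in range(10_000)]

def train(num_steps: int, gc_freq: int = 50) -> None:
    # Disable automatic (cyclic) garbage collection; reference counting still
    # frees most objects immediately, so only cycle collection is deferred.
    gc.disable()
    # Start from a clean slate; collect(1) sweeps generations 0 and 1 but
    # skips the expensive full generation-2 pass.
    gc.collect(1)

    for step in range(1, num_steps + 1):
        run_training_step(step)
        # Collect on a fixed step schedule so that, in a distributed job,
        # every rank pauses at the same point instead of at random times.
        if step > 1 and step % gc_freq == 0:
            gc.collect(1)

if __name__ == "__main__":
    start = time.perf_counter()
    train(num_steps=200)
    print(f"finished in {time.perf_counter() - start:.3f}s")

Note that gc.collect(1) only sweeps generations 0 and 1; long-lived objects promoted to generation 2 are never collected during training, which keeps each pause short at the cost of letting generation 2 grow.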