Commit 914715a

Some changes in comments for deepspeed saving to improve clarity
rom1504 committed Jun 4, 2021
1 parent 58c6035 commit 914715a
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions train_dalle.py
@@ -1,5 +1,6 @@
 import argparse
 from pathlib import Path
+import time
 
 import torch
 import wandb # Quit early if user doesn't have wandb installed.
@@ -325,6 +326,7 @@ def group_weight(model):
print(f"As such, they will require DeepSpeed as a dependency in order to resume from or generate with.")
print("See the deespeed conversion script for details on how to convert your ZeRO stage 2/3 checkpoint to a single file.")
print("If using a single GPU, consider running with apex automatic mixed precision instead for a similar speedup to ZeRO.")
time.sleep(2)

(distr_dalle, distr_opt, distr_dl, distr_scheduler) = distr_backend.distribute(
args=args,
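The time.sleep(2) added above exists so the warnings stay visible for a moment before training output starts scrolling. A minimal sketch of that warn-then-pause pattern, standard library only (the warn_and_pause helper is a made-up name, not part of train_dalle.py):

# Minimal sketch of the warn-then-pause pattern; warn_and_pause is hypothetical.
import time

def warn_and_pause(messages, seconds=2):
    for msg in messages:
        print(msg)
    # Brief pause so the warnings can be read before training logs scroll past.
    time.sleep(seconds)

warn_and_pause([
    "DeepSpeed ZeRO stage 2/3 checkpoints require DeepSpeed to resume from.",
    "See the conversion script to collapse them into a single file.",
])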
@@ -366,7 +368,7 @@ def save_model(path):
             ),
         }
         torch.save(save_obj, str(cp_dir / DEEPSPEED_CP_AUX_FILENAME))
-        if deepspeed_config.get('zero_optimization', {}).get('stage', 0) >= 2: # see explanation for this at the top of the file
+        if deepspeed_config.get('zero_optimization', {}).get('stage', 0) >= 2: # see https://github.com/lucidrains/DALLE-pytorch/wiki/DeepSpeed-Checkpoints
             return
 
     if not distr_backend.is_root_worker():
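For readers unfamiliar with the guard above, here is a self-contained sketch of how the chained dict .get() calls resolve the ZeRO stage; the config dict below is made up for illustration. A missing section or key falls back to stage 0, so the early return only fires when ZeRO stage 2 or 3 is explicitly configured:

# Made-up config for illustration; mirrors the shape of a DeepSpeed JSON config.
deepspeed_config = {'zero_optimization': {'stage': 2}}

# Chained .get() calls: a missing 'zero_optimization' section yields {},
# and a missing 'stage' key yields 0, i.e. "ZeRO disabled".
stage = deepspeed_config.get('zero_optimization', {}).get('stage', 0)

if stage >= 2:
    # Under ZeRO stage 2/3, optimizer (and, at stage 3, parameter) state is
    # sharded across workers, so only DeepSpeed's own checkpointing produces
    # a resumable checkpoint; a plain torch.save() of one process's view
    # would be incomplete. Hence save_model() returns early above.
    print("ZeRO stage >= 2: rely on DeepSpeed checkpointing")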
@@ -381,9 +383,8 @@ def save_model(path):

 # training
 
-# save the initial model
-# that makes sure saving is working, to avoid waiting
-# a long time and not getting any output
+# Saves a checkpoint before training begins to fail early when mis-configured.
+# See https://github.com/lucidrains/DALLE-pytorch/wiki/DeepSpeed-Checkpoints
 save_model(DALLE_OUTPUT_FILE_NAME)
 
 for epoch in range(EPOCHS):
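The rewritten comment above describes a fail-early pattern: write one checkpoint before the first epoch so that a bad output path, a full disk, or a non-serializable object surfaces immediately instead of hours into training. A minimal sketch under that assumption (the names here are illustrative, not the project's):

# Illustrative sketch of saving once before the training loop begins.
import torch

def save_model(path):
    torch.save({'weights': {}}, path)  # stand-in for the real checkpoint logic

save_model('dalle.pt')  # fails fast here if saving is broken
for epoch in range(2):
    pass  # training would run here, after the save path is known to work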