Skip to content

Commit

Permalink
Replace deprecated hub utils in train_unconditional_ort (huggingfac…
Browse files Browse the repository at this point in the history
…e#1504)

* Replace deprecated hub utils in `train_unconditional_ort`

* typo
  • Loading branch information
anton-l authored and Thomas Capelle committed Dec 12, 2022
1 parent b9fb425 commit 40c0865
Showing 1 changed file with 31 additions and 6 deletions.
37 changes: 31 additions & 6 deletions examples/unconditional_image_generation/train_unconditional_ort.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import argparse
import math
import os
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
Expand All @@ -9,9 +11,9 @@
from accelerate.logging import get_logger
from datasets import load_dataset
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from huggingface_hub import HfFolder, Repository, whoami
from onnxruntime.training.ortmodule import ORTModule
from torchvision.transforms import (
CenterCrop,
Expand All @@ -28,6 +30,16 @@
logger = get_logger(__name__)


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"


def main(args):
logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator = Accelerator(
Expand Down Expand Up @@ -113,8 +125,22 @@ def transforms(examples):

ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)

if args.push_to_hub:
repo = init_git_repo(args, at_init=True)
# Handle the repository creation
if accelerator.is_main_process:
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name)

with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)

if accelerator.is_main_process:
run = os.path.split(__file__)[-1].split(".")[0]
Expand Down Expand Up @@ -186,10 +212,9 @@ def transforms(examples):

if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
# save the model
pipeline.save_pretrained(args.output_dir)
if args.push_to_hub:
push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
else:
pipeline.save_pretrained(args.output_dir)
repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)
accelerator.wait_for_everyone()

accelerator.end_training()
Expand Down

0 comments on commit 40c0865

Please sign in to comment.